From ce99f24731773635408d43fc82b3f83eb6f5a570 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Fri, 3 Apr 2026 16:25:33 -0700 Subject: [PATCH] mlxrunner: tokenize prompts in request handler goroutines Move tokenization out of the single GPU processing goroutine and into each request's HTTP handler goroutine. This allows the next request's prompt to be tokenized on the CPU while the current request is executing on the GPU. --- x/mlxrunner/client.go | 2 +- x/mlxrunner/pipeline.go | 51 ++++++++++++++++++++++------------------- x/mlxrunner/runner.go | 12 ++++++---- x/mlxrunner/server.go | 5 ++++ 4 files changed, 41 insertions(+), 29 deletions(-) diff --git a/x/mlxrunner/client.go b/x/mlxrunner/client.go index fa55fbd71..4730a10d0 100644 --- a/x/mlxrunner/client.go +++ b/x/mlxrunner/client.go @@ -226,7 +226,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f if resp.StatusCode != http.StatusOK { respBody, _ := io.ReadAll(resp.Body) - return fmt.Errorf("%s", strings.TrimSpace(string(respBody))) + return api.StatusError{StatusCode: resp.StatusCode, ErrorMessage: strings.TrimSpace(string(respBody))} } scanner := bufio.NewScanner(resp.Body) diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go index b1eef58a5..6297c220b 100644 --- a/x/mlxrunner/pipeline.go +++ b/x/mlxrunner/pipeline.go @@ -6,11 +6,9 @@ import ( "errors" "fmt" "log/slog" - "net/http" "sort" "time" - "github.com/ollama/ollama/api" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/x/mlxrunner/mlx" @@ -22,13 +20,37 @@ func prefillChunkSize() int { return 2 << 10 } -func (r *Runner) TextGenerationPipeline(request Request) error { +// Prepare tokenizes the prompt and validates it against the model's +// context length. It is safe to call from any goroutine. On success it +// populates request.Tokens and adjusts request.Options.NumPredict. +func (r *Runner) Prepare(request *Request) error { if r.Model == nil { return errors.New("model not loaded") } + tokens := r.Tokenizer.Encode(request.Prompt, r.Tokenizer.AddBOS()) + if len(tokens) == 0 { + return errors.New("empty prompt") + } + + if len(tokens) >= r.contextLength { + return fmt.Errorf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(tokens), r.contextLength) + } + + // Cap generation to stay within the model's context length + maxGenerate := r.contextLength - len(tokens) + if request.Options.NumPredict <= 0 { + request.Options.NumPredict = maxGenerate + } else { + request.Options.NumPredict = min(request.Options.NumPredict, maxGenerate) + } + + request.Tokens = tokens + return nil +} + +func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) error { mlx.ResetPeakMemory() - ctx := request.Ctx var sample, nextSample sampler.Result defer func() { @@ -47,26 +69,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error { slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory())) }() - inputs := r.Tokenizer.Encode(request.Prompt, r.Tokenizer.AddBOS()) - if len(inputs) == 0 { - return errors.New("empty prompt") - } - - if len(inputs) >= r.contextLength { - return api.StatusError{ - StatusCode: http.StatusBadRequest, - ErrorMessage: fmt.Sprintf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(inputs), r.contextLength), - } - } - - // Cap generation to stay within the model's context length - maxGenerate := r.contextLength - len(inputs) - if request.Options.NumPredict <= 0 { - request.Options.NumPredict = maxGenerate - } else { - request.Options.NumPredict = min(request.Options.NumPredict, maxGenerate) - } - + inputs := request.Tokens request.Sampler.ResetHistory(inputs) session := r.cache.begin(r.Model, inputs) diff --git a/x/mlxrunner/runner.go b/x/mlxrunner/runner.go index 95e1f2b62..2ab5e323a 100644 --- a/x/mlxrunner/runner.go +++ b/x/mlxrunner/runner.go @@ -18,13 +18,17 @@ import ( "github.com/ollama/ollama/x/tokenizer" ) +// Request is a short-lived struct that carries a completion request through +// a channel from the HTTP handler to the runner goroutine. The ctx field +// must travel with the request so that cancellation propagates across the +// channel boundary. type Request struct { CompletionRequest Responses chan CompletionResponse - Pipeline func(Request) error - - Ctx context.Context + Pipeline func(context.Context, Request) error + Ctx context.Context //nolint:containedctx + Tokens []int32 Sampler *sample.Sampler } @@ -131,7 +135,7 @@ func (r *Runner) Run(host, port string, mux http.Handler) error { case <-ctx.Done(): return nil case request := <-r.Requests: - if err := request.Pipeline(request); err != nil { + if err := request.Pipeline(request.Ctx, request); err != nil { slog.Info("Request terminated", "error", err) var statusErr api.StatusError if !errors.As(err, &statusErr) { diff --git a/x/mlxrunner/server.go b/x/mlxrunner/server.go index 089e5c7f5..41d8c976f 100644 --- a/x/mlxrunner/server.go +++ b/x/mlxrunner/server.go @@ -106,6 +106,11 @@ func Execute(args []string) error { TopLogprobs: request.TopLogprobs, }) + if err := runner.Prepare(&request); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + var cancel context.CancelFunc request.Ctx, cancel = context.WithCancel(r.Context()) defer cancel()