diff --git a/x/mlxrunner/client.go b/x/mlxrunner/client.go index fa55fbd71..4730a10d0 100644 --- a/x/mlxrunner/client.go +++ b/x/mlxrunner/client.go @@ -226,7 +226,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f if resp.StatusCode != http.StatusOK { respBody, _ := io.ReadAll(resp.Body) - return fmt.Errorf("%s", strings.TrimSpace(string(respBody))) + return api.StatusError{StatusCode: resp.StatusCode, ErrorMessage: strings.TrimSpace(string(respBody))} } scanner := bufio.NewScanner(resp.Body) diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go index b1eef58a5..6297c220b 100644 --- a/x/mlxrunner/pipeline.go +++ b/x/mlxrunner/pipeline.go @@ -6,11 +6,9 @@ import ( "errors" "fmt" "log/slog" - "net/http" "sort" "time" - "github.com/ollama/ollama/api" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/x/mlxrunner/mlx" @@ -22,13 +20,37 @@ func prefillChunkSize() int { return 2 << 10 } -func (r *Runner) TextGenerationPipeline(request Request) error { +// Prepare tokenizes the prompt and validates it against the model's +// context length. It is called from the HTTP handler goroutine, before +// the request is handed to the runner. On success it +// populates request.Tokens and adjusts request.Options.NumPredict.
+func (r *Runner) Prepare(request *Request) error { if r.Model == nil { return errors.New("model not loaded") } + tokens := r.Tokenizer.Encode(request.Prompt, r.Tokenizer.AddBOS()) + if len(tokens) == 0 { + return errors.New("empty prompt") + } + + if len(tokens) >= r.contextLength { + return fmt.Errorf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(tokens), r.contextLength) + } + + // Cap generation to stay within the model's context length + maxGenerate := r.contextLength - len(tokens) + if request.Options.NumPredict <= 0 { + request.Options.NumPredict = maxGenerate + } else { + request.Options.NumPredict = min(request.Options.NumPredict, maxGenerate) + } + + request.Tokens = tokens + return nil +} + +func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) error { mlx.ResetPeakMemory() - ctx := request.Ctx var sample, nextSample sampler.Result defer func() { @@ -47,26 +69,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error { slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory())) }() - inputs := r.Tokenizer.Encode(request.Prompt, r.Tokenizer.AddBOS()) - if len(inputs) == 0 { - return errors.New("empty prompt") - } - - if len(inputs) >= r.contextLength { - return api.StatusError{ - StatusCode: http.StatusBadRequest, - ErrorMessage: fmt.Sprintf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(inputs), r.contextLength), - } - } - - // Cap generation to stay within the model's context length - maxGenerate := r.contextLength - len(inputs) - if request.Options.NumPredict <= 0 { - request.Options.NumPredict = maxGenerate - } else { - request.Options.NumPredict = min(request.Options.NumPredict, maxGenerate) - } - + inputs := request.Tokens request.Sampler.ResetHistory(inputs) session := r.cache.begin(r.Model, inputs) diff --git a/x/mlxrunner/runner.go b/x/mlxrunner/runner.go index 95e1f2b62..2ab5e323a 100644 --- a/x/mlxrunner/runner.go +++ 
b/x/mlxrunner/runner.go @@ -18,13 +18,17 @@ import ( "github.com/ollama/ollama/x/tokenizer" ) +// Request is a short-lived struct that carries a completion request through +// a channel from the HTTP handler to the runner goroutine. The Ctx field +// must travel with the request so that cancellation propagates across the +// channel boundary. type Request struct { CompletionRequest Responses chan CompletionResponse - Pipeline func(Request) error - - Ctx context.Context + Pipeline func(context.Context, Request) error + Ctx context.Context //nolint:containedctx + Tokens []int32 Sampler *sample.Sampler } @@ -131,7 +135,7 @@ func (r *Runner) Run(host, port string, mux http.Handler) error { case <-ctx.Done(): return nil case request := <-r.Requests: - if err := request.Pipeline(request); err != nil { + if err := request.Pipeline(request.Ctx, request); err != nil { slog.Info("Request terminated", "error", err) var statusErr api.StatusError if !errors.As(err, &statusErr) { diff --git a/x/mlxrunner/server.go b/x/mlxrunner/server.go index 089e5c7f5..41d8c976f 100644 --- a/x/mlxrunner/server.go +++ b/x/mlxrunner/server.go @@ -106,6 +106,11 @@ func Execute(args []string) error { TopLogprobs: request.TopLogprobs, }) + if err := runner.Prepare(&request); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + var cancel context.CancelFunc request.Ctx, cancel = context.WithCancel(r.Context()) defer cancel()