mlxrunner: tokenize prompts in request handler goroutines

Move tokenization out of the single GPU processing goroutine and
into each request's HTTP handler goroutine. This allows the next
request's prompt to be tokenized on the CPU while the current
request is executing on the GPU.
This commit is contained in:
Jesse Gross 2026-04-03 16:25:33 -07:00
parent 04f5f0cdb4
commit ce99f24731
4 changed files with 41 additions and 29 deletions

View file

@ -226,7 +226,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
return fmt.Errorf("%s", strings.TrimSpace(string(respBody)))
return api.StatusError{StatusCode: resp.StatusCode, ErrorMessage: strings.TrimSpace(string(respBody))}
}
scanner := bufio.NewScanner(resp.Body)

View file

@ -6,11 +6,9 @@ import (
"errors"
"fmt"
"log/slog"
"net/http"
"sort"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/mlx"
@ -22,13 +20,37 @@ func prefillChunkSize() int {
return 2 << 10
}
func (r *Runner) TextGenerationPipeline(request Request) error {
// Prepare tokenizes the prompt and validates it against the model's
// context length. It is safe to call from any goroutine. On success it
// populates request.Tokens and adjusts request.Options.NumPredict.
func (r *Runner) Prepare(request *Request) error {
if r.Model == nil {
return errors.New("model not loaded")
}
tokens := r.Tokenizer.Encode(request.Prompt, r.Tokenizer.AddBOS())
if len(tokens) == 0 {
return errors.New("empty prompt")
}
if len(tokens) >= r.contextLength {
return fmt.Errorf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(tokens), r.contextLength)
}
// Cap generation to stay within the model's context length
maxGenerate := r.contextLength - len(tokens)
if request.Options.NumPredict <= 0 {
request.Options.NumPredict = maxGenerate
} else {
request.Options.NumPredict = min(request.Options.NumPredict, maxGenerate)
}
request.Tokens = tokens
return nil
}
func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) error {
mlx.ResetPeakMemory()
ctx := request.Ctx
var sample, nextSample sampler.Result
defer func() {
@ -47,26 +69,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory()))
}()
inputs := r.Tokenizer.Encode(request.Prompt, r.Tokenizer.AddBOS())
if len(inputs) == 0 {
return errors.New("empty prompt")
}
if len(inputs) >= r.contextLength {
return api.StatusError{
StatusCode: http.StatusBadRequest,
ErrorMessage: fmt.Sprintf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(inputs), r.contextLength),
}
}
// Cap generation to stay within the model's context length
maxGenerate := r.contextLength - len(inputs)
if request.Options.NumPredict <= 0 {
request.Options.NumPredict = maxGenerate
} else {
request.Options.NumPredict = min(request.Options.NumPredict, maxGenerate)
}
inputs := request.Tokens
request.Sampler.ResetHistory(inputs)
session := r.cache.begin(r.Model, inputs)

View file

@ -18,13 +18,17 @@ import (
"github.com/ollama/ollama/x/tokenizer"
)
// Request is a short-lived struct that carries a completion request through
// a channel from the HTTP handler to the runner goroutine. The ctx field
// must travel with the request so that cancellation propagates across the
// channel boundary.
type Request struct {
CompletionRequest
Responses chan CompletionResponse
Pipeline func(Request) error
Ctx context.Context
Pipeline func(context.Context, Request) error
Ctx context.Context //nolint:containedctx
Tokens []int32
Sampler *sample.Sampler
}
@ -131,7 +135,7 @@ func (r *Runner) Run(host, port string, mux http.Handler) error {
case <-ctx.Done():
return nil
case request := <-r.Requests:
if err := request.Pipeline(request); err != nil {
if err := request.Pipeline(request.Ctx, request); err != nil {
slog.Info("Request terminated", "error", err)
var statusErr api.StatusError
if !errors.As(err, &statusErr) {

View file

@ -106,6 +106,11 @@ func Execute(args []string) error {
TopLogprobs: request.TopLogprobs,
})
if err := runner.Prepare(&request); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
var cancel context.CancelFunc
request.Ctx, cancel = context.WithCancel(r.Context())
defer cancel()