ollama/x/mlxrunner/pipeline.go
Jesse Gross ce99f24731 mlxrunner: tokenize prompts in request handler goroutines
Move tokenization out of the single GPU processing goroutine and
into each request's HTTP handler goroutine. This allows the next
request's prompt to be tokenized on the CPU while the current
request is executing on the GPU.
2026-04-21 14:38:49 -07:00

261 lines
6.7 KiB
Go

package mlxrunner
import (
"bytes"
"context"
"errors"
"fmt"
"log/slog"
"sort"
"time"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/mlx"
sampler "github.com/ollama/ollama/x/mlxrunner/sample"
"github.com/ollama/ollama/x/tokenizer"
)
// prefillChunkSize reports the largest number of prompt tokens submitted
// in a single forward pass while prefilling: 2048.
func prefillChunkSize() int {
	const chunkTokens = 1 << 11
	return chunkTokens
}
// Prepare encodes request.Prompt with the runner's tokenizer and checks
// the result against the model's context length. It is meant to be
// callable from any goroutine.
//
// It returns an error when no model is loaded, when the prompt encodes to
// zero tokens, or when the token count is at least the context length.
// On success it stores the tokens in request.Tokens and sets
// request.Options.NumPredict to the number of tokens that may still be
// generated: the caller's value when it is positive and fits, otherwise
// the space remaining in the context.
func (r *Runner) Prepare(request *Request) error {
	if r.Model == nil {
		return errors.New("model not loaded")
	}

	encoded := r.Tokenizer.Encode(request.Prompt, r.Tokenizer.AddBOS())
	promptLen := len(encoded)
	switch {
	case promptLen == 0:
		return errors.New("empty prompt")
	case promptLen >= r.contextLength:
		return fmt.Errorf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", promptLen, r.contextLength)
	}

	// Whatever the caller asked for, generation may not run past the end
	// of the context window.
	budget := r.contextLength - promptLen
	if requested := request.Options.NumPredict; requested > 0 && requested < budget {
		budget = requested
	}
	request.Options.NumPredict = budget

	request.Tokens = encoded
	return nil
}
// TextGenerationPipeline runs one already-tokenized request through prompt
// prefill and token-by-token generation. It reads the prompt from
// request.Tokens, streams decoded chunks to request.Responses, and finishes
// by sending a CompletionResponse with Done set and the count/duration
// fields filled in.
//
// It returns ctx.Err() if ctx is cancelled before the final response has
// been sent, and nil once the final response is delivered. On every exit
// path the deferred cleanup frees the sampler (when non-nil), unpins any
// sample arrays still held, sweeps and clears the mlx cache, and logs peak
// memory.
func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) error {
// Reset the peak-memory counter up front; the deferred cleanup logs it.
mlx.ResetPeakMemory()
// sample is the result for the token currently being emitted; nextSample is
// the result already queued for the following token. Both live at function
// scope so the deferred cleanup can unpin whatever is held on any exit path.
var sample, nextSample sampler.Result
defer func() {
if request.Sampler != nil {
request.Sampler.Free()
}
mlx.Unpin(sample.Arrays()...)
mlx.Unpin(nextSample.Arrays()...)
mlx.Sweep()
mlx.ClearCache()
// Array and cache-tree dumps are only produced when trace logging is on.
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
mlx.LogArrays()
r.cache.dumpTree()
}
slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory()))
}()
inputs := request.Tokens
request.Sampler.ResetHistory(inputs)
// NOTE(review): session.remaining is presumably the suffix of inputs that
// the cache does not already cover (the offset math below treats it as a
// suffix of session.inputs) — confirm against the cache implementation.
session := r.cache.begin(r.Model, inputs)
defer session.close()
caches := session.caches
tokens := session.remaining
prefillChunk := prefillChunkSize()
// Request periodic snapshots during prefill and near the end of the
// prompt so that long prompts can be partially restored and
// thinking/generation can be retried without full reprocessing.
const snapshotInterval = 8192
for offset := snapshotInterval; offset < len(inputs); offset += snapshotInterval {
session.requestSnapshot(offset)
}
// Also request a snapshot preThinking tokens before the end of the prompt,
// when the prompt is longer than that.
const preThinking = 4
if end := len(inputs) - preThinking; end > 0 {
session.requestSnapshot(end)
}
// materializeCaches gathers the state arrays of every cache and evaluates
// them together; it does nothing when the caches expose no state.
materializeCaches := func() {
state := make([]*mlx.Array, 0, 2*len(caches))
for _, c := range caches {
state = append(state, c.State()...)
}
if len(state) == 0 {
return
}
mlx.Eval(state...)
}
// now marks the start of prompt evaluation; it is restarted after the first
// sampled token to time generation.
now := time.Now()
total, processed := len(tokens), 0
// Prefill in chunks, always leaving at least one token unprocessed so the
// first step call below has input to sample from.
for total-processed > 1 {
if err := ctx.Err(); err != nil {
return err
}
n := min(prefillChunk, total-processed-1)
// If there's a pending snapshot, split the batch so we can
// capture it at the exact offset.
if snapOffset := session.nextPendingSnapshot(); snapOffset > 0 {
// baseOffset converts a position in tokens to a position in
// session.inputs.
baseOffset := len(session.inputs) - len(tokens)
tokensUntilSnapshot := snapOffset - (baseOffset + processed)
if tokensUntilSnapshot > 0 && tokensUntilSnapshot < n {
n = tokensUntilSnapshot
}
}
// The forward result is discarded here; the pass is run with the caches
// as an argument and the cache state is evaluated right after.
r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
mlx.Sweep()
materializeCaches()
processed += n
slog.Info("Prompt processing progress", "processed", processed, "total", total)
// Create snapshot if we've reached a pending offset.
if snapOffset := session.nextPendingSnapshot(); snapOffset > 0 {
baseOffset := len(session.inputs) - len(tokens)
if baseOffset+processed >= snapOffset {
session.snapshot()
}
}
mlx.ClearCache()
}
// step runs one forward pass over token, unembeds the result, keeps the
// logits starting at the last index of axis 1, and samples from them. The
// sample's arrays are pinned before the sweep and then submitted for
// asynchronous evaluation, so the caller gets the result back before it has
// necessarily been computed.
step := func(token *mlx.Array) sampler.Result {
fwd := r.Model.Forward(token.ExpandDims(0), caches)
logits := r.Model.Unembed(fwd)
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
sample := request.Sampler.Sample(logits)
mlx.Pin(sample.Arrays()...)
mlx.Sweep()
mlx.AsyncEval(sample.Arrays()...)
return sample
}
// The first step consumes whatever prompt tokens prefill left over.
sample = step(mlx.FromValues(tokens[processed:], total-processed))
dec := decoder{tokenizer: r.Tokenizer}
// final starts out describing a run that used the whole NumPredict budget;
// the EOS branch below overwrites DoneReason and EvalCount.
// NOTE(review): DoneReason 1 presumably means the length limit was reached
// and 0 means a stop token — confirm against the DoneReason definition.
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.NumPredict, DoneReason: 1}
for i := range request.Options.NumPredict {
if err := ctx.Err(); err != nil {
return err
}
// Record the current token in the sampler history and queue the next step
// before the current token's value is read below.
request.Sampler.AppendToken(sample.Token)
nextSample = step(sample.Token)
if i == 0 {
// Force the first sample to finish so the prompt-eval duration
// includes it, then restart the clock for generation.
mlx.Eval(sample.Arrays()...)
final.PromptEvalDuration = time.Since(now)
now = time.Now()
}
output := int32(sample.Token.Int())
session.outputs = append(session.outputs, output)
if r.Tokenizer.IsEOS(output) {
// The EOS token is recorded in session.outputs but is not decoded or
// sent. sample and nextSample stay pinned here and are released by the
// deferred cleanup.
final.DoneReason = 0
final.EvalCount = i
break
}
// decode may hold text back (ok == false); in that case nothing is sent
// for this token.
if resp, ok := dec.decode(sample); ok {
select {
case <-ctx.Done():
return ctx.Err()
case request.Responses <- resp:
}
}
// The current sample has been consumed: release it and promote the queued
// one. nextSample is zeroed so the deferred cleanup does not unpin the same
// arrays twice.
mlx.Unpin(sample.Arrays()...)
sample, nextSample = nextSample, sampler.Result{}
// Clear the mlx cache every 256 iterations (this includes i == 0).
if i%256 == 0 {
mlx.ClearCache()
}
}
final.EvalDuration = time.Since(now)
// Deliver the final response unless the context is cancelled first.
select {
case <-ctx.Done():
return ctx.Err()
case request.Responses <- final:
return nil
}
}
// decoder serializes sampled tokens into response chunks, holding bytes
// whose UTF-8 sequence hasn't completed yet and the logprobs that belong
// with those bytes so Content and Logprobs stay aligned when a chunk does
// flush.
//
// buf and logprobs are usable at their zero values; only tokenizer has to
// be set when constructing a decoder.
type decoder struct {
// tokenizer turns token ids back into text.
tokenizer *tokenizer.Tokenizer
// buf accumulates decoded text that has not been flushed into a response
// yet.
buf bytes.Buffer
// logprobs collects the logprob entries for tokens whose text is still
// waiting in buf; it is reset to nil each time a chunk is emitted.
logprobs []llm.Logprob
}
// decode feeds one sampled token into the decoder. The token's text is
// appended to the pending buffer and its logprob entry (if any) to the
// pending logprobs. When flushValidUTF8Prefix yields text, that text is
// returned together with all pending logprobs and true, and the pending
// logprobs are cleared. When it yields nothing, decode returns a zero
// CompletionResponse and false, and keeps everything buffered for a later
// call.
func (d *decoder) decode(res sampler.Result) (CompletionResponse, bool) {
	id := int32(res.Token.Int())
	piece := d.tokenizer.Decode([]int32{id})
	d.buf.WriteString(piece)

	entries := buildLogprob(res, d.tokenizer.Decode)
	d.logprobs = append(d.logprobs, entries...)

	var resp CompletionResponse
	text := flushValidUTF8Prefix(&d.buf)
	if text == "" {
		return resp, false
	}

	resp.Content, resp.Logprobs = text, d.logprobs
	d.logprobs = nil
	return resp, true
}
// buildLogprob converts a sampler result into the llm.Logprob form used in
// responses. decode maps a slice of token ids to text and is called with
// one id at a time.
//
// It returns nil when the sample has no Logprob. Otherwise it returns a
// one-element slice holding the sampled token's text and logprob. When the
// sample has TopTokens, that element also carries the alternatives from
// TopTokens/TopLogprobs, ordered from highest to lowest logprob.
func buildLogprob(sample sampler.Result, decode func([]int32) string) []llm.Logprob {
	if sample.Logprob == nil {
		return nil
	}

	text := func(id int32) string {
		return decode([]int32{id})
	}

	chosen := llm.TokenLogprob{
		Token:   text(int32(sample.Token.Int())),
		Logprob: float64(sample.Logprob.Floats()[0]),
	}
	entry := llm.Logprob{TokenLogprob: chosen}
	if sample.TopTokens == nil {
		return []llm.Logprob{entry}
	}

	topIDs := sample.TopTokens.Ints()
	topVals := sample.TopLogprobs.Floats()
	alts := make([]llm.TokenLogprob, 0, len(topIDs))
	for k := range topIDs {
		alts = append(alts, llm.TokenLogprob{
			Token:   text(int32(topIDs[k])),
			Logprob: float64(topVals[k]),
		})
	}

	// Most likely alternative first.
	sort.Slice(alts, func(a, b int) bool {
		return alts[b].Logprob < alts[a].Logprob
	})
	entry.TopLogprobs = alts

	return []llm.Logprob{entry}
}