mirror of
https://github.com/ollama/ollama
synced 2026-04-23 08:45:14 +00:00
Match the ollamarunner and OpenAI semantics: raw, full-vocab log-softmax with the top-K ranked by probability. The computation is skipped on the GPU when the request doesn't ask for logprobs, so decode doesn't pay for it.
258 lines
6.5 KiB
Go
258 lines
6.5 KiB
Go
package mlxrunner
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"net/http"
|
|
"sort"
|
|
"time"
|
|
|
|
"github.com/ollama/ollama/api"
|
|
"github.com/ollama/ollama/llm"
|
|
"github.com/ollama/ollama/logutil"
|
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
|
sampler "github.com/ollama/ollama/x/mlxrunner/sample"
|
|
"github.com/ollama/ollama/x/tokenizer"
|
|
)
|
|
|
|
// prefillChunkSize reports how many prompt tokens are processed per
// forward pass during prefill (2048).
func prefillChunkSize() int {
	const chunk = 2048 // 2 << 10
	return chunk
}
|
|
|
|
// TextGenerationPipeline runs a full prompt-prefill + token-generation loop
// for a single request, streaming CompletionResponse chunks on
// request.Responses and finishing with a final Done response carrying
// timing/count metadata. It returns an error if the model is not loaded,
// the prompt is empty or exceeds the context window, or ctx is cancelled.
func (r *Runner) TextGenerationPipeline(request Request) error {
	if r.Model == nil {
		return errors.New("model not loaded")
	}

	mlx.ResetPeakMemory()
	ctx := request.Ctx
	// sample is the in-flight result being emitted; nextSample is the
	// pipelined result for the following token. Declared here so the
	// deferred cleanup below can unpin whichever arrays are still live.
	var sample, nextSample sampler.Result

	defer func() {
		if request.Sampler != nil {
			request.Sampler.Free()
		}
		// Unpin before Sweep so the arrays become collectible.
		mlx.Unpin(sample.Arrays()...)
		mlx.Unpin(nextSample.Arrays()...)
		mlx.Sweep()
		mlx.ClearCache()

		if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
			mlx.LogArrays()
			r.cache.dumpTree()
		}
		slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory()))
	}()

	inputs := r.Tokenizer.Encode(request.Prompt, r.Tokenizer.AddBOS())
	if len(inputs) == 0 {
		return errors.New("empty prompt")
	}

	// >= (not >): at least one slot must remain for generation.
	if len(inputs) >= r.contextLength {
		return api.StatusError{
			StatusCode:   http.StatusBadRequest,
			ErrorMessage: fmt.Sprintf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(inputs), r.contextLength),
		}
	}

	// Cap generation to stay within the model's context length
	maxGenerate := r.contextLength - len(inputs)
	if request.Options.NumPredict <= 0 {
		request.Options.NumPredict = maxGenerate
	} else {
		request.Options.NumPredict = min(request.Options.NumPredict, maxGenerate)
	}

	request.Sampler.ResetHistory(inputs)

	// begin() reuses any cached prefix; session.remaining is the suffix of
	// inputs that still needs a forward pass.
	session := r.cache.begin(r.Model, inputs)
	defer session.close()

	caches := session.caches
	tokens := session.remaining
	prefillChunk := prefillChunkSize()

	// Request periodic snapshots during prefill and near the end of the
	// prompt so that long prompts can be partially restored and
	// thinking/generation can be retried without full reprocessing.
	const snapshotInterval = 8192
	for offset := snapshotInterval; offset < len(inputs); offset += snapshotInterval {
		session.requestSnapshot(offset)
	}

	// Also snapshot a few tokens before the end of the prompt
	// (preThinking), enabling a cheap retry from just before generation.
	const preThinking = 4
	if end := len(inputs) - preThinking; end > 0 {
		session.requestSnapshot(end)
	}

	// materializeCaches forces evaluation of every cache's backing arrays
	// so lazily-built graph state doesn't accumulate across chunks.
	materializeCaches := func() {
		state := make([]*mlx.Array, 0, 2*len(caches))
		for _, c := range caches {
			state = append(state, c.State()...)
		}
		if len(state) == 0 {
			return
		}
		mlx.Eval(state...)
	}

	now := time.Now()
	total, processed := len(tokens), 0
	// Prefill in chunks, deliberately leaving the final prompt token
	// unprocessed (> 1 / n capped at total-processed-1): it is fed to the
	// first step() call below, which also produces the first sample.
	for total-processed > 1 {
		if err := ctx.Err(); err != nil {
			return err
		}

		n := min(prefillChunk, total-processed-1)

		// If there's a pending snapshot, split the batch so we can
		// capture it at the exact offset.
		if snapOffset := session.nextPendingSnapshot(); snapOffset > 0 {
			// baseOffset = how many of session.inputs were already
			// covered by the reused cache prefix.
			baseOffset := len(session.inputs) - len(tokens)
			tokensUntilSnapshot := snapOffset - (baseOffset + processed)
			if tokensUntilSnapshot > 0 && tokensUntilSnapshot < n {
				n = tokensUntilSnapshot
			}
		}

		r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
		mlx.Sweep()
		materializeCaches()
		processed += n
		slog.Info("Prompt processing progress", "processed", processed, "total", total)

		// Create snapshot if we've reached a pending offset.
		if snapOffset := session.nextPendingSnapshot(); snapOffset > 0 {
			baseOffset := len(session.inputs) - len(tokens)
			if baseOffset+processed >= snapOffset {
				session.snapshot()
			}
		}

		mlx.ClearCache()
	}

	// step runs one forward pass on token(s), takes the logits of the last
	// position only, samples, pins the result's arrays so Sweep can't free
	// them, and kicks off async evaluation so the GPU works ahead of the
	// CPU-side loop.
	step := func(token *mlx.Array) sampler.Result {
		fwd := r.Model.Forward(token.ExpandDims(0), caches)
		logits := r.Model.Unembed(fwd)
		logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)

		sample := request.Sampler.Sample(logits)
		mlx.Pin(sample.Arrays()...)
		mlx.Sweep()
		mlx.AsyncEval(sample.Arrays()...)
		return sample
	}

	// Feed the remaining unprocessed prompt tail; its output is the first
	// generated token.
	sample = step(mlx.FromValues(tokens[processed:], total-processed))

	dec := decoder{tokenizer: r.Tokenizer}

	// Defaults assume the length limit is hit; DoneReason 1 presumably
	// means "length" and 0 (set on EOS below) "stop" — TODO confirm
	// against the CompletionResponse/DoneReason definition.
	final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.NumPredict, DoneReason: 1}
	for i := range request.Options.NumPredict {
		if err := ctx.Err(); err != nil {
			return err
		}

		request.Sampler.AppendToken(sample.Token)
		// Pipeline: launch the next token's compute before consuming the
		// current sample.
		nextSample = step(sample.Token)

		if i == 0 {
			// First sample becoming available marks the end of prefill.
			mlx.Eval(sample.Arrays()...)
			final.PromptEvalDuration = time.Since(now)
			now = time.Now()
		}

		output := int32(sample.Token.Int())
		session.outputs = append(session.outputs, output)

		if r.Tokenizer.IsEOS(output) {
			final.DoneReason = 0
			// EOS token itself is not counted as generated output.
			final.EvalCount = i
			break
		}

		// decode may buffer (incomplete UTF-8) and report ok=false; only
		// send when a chunk actually flushed.
		if resp, ok := dec.decode(sample); ok {
			select {
			case <-ctx.Done():
				return ctx.Err()
			case request.Responses <- resp:
			}
		}

		// Current sample is fully consumed; release it and promote the
		// pipelined one.
		mlx.Unpin(sample.Arrays()...)
		sample, nextSample = nextSample, sampler.Result{}

		// Periodically return free buffers to the allocator.
		if i%256 == 0 {
			mlx.ClearCache()
		}
	}

	final.EvalDuration = time.Since(now)
	select {
	case <-ctx.Done():
		return ctx.Err()
	case request.Responses <- final:
		return nil
	}
}
|
|
|
|
// decoder serializes sampled tokens into response chunks, holding bytes
// whose UTF-8 sequence hasn't completed yet and the logprobs that belong
// with those bytes so Content and Logprobs stay aligned when a chunk does
// flush.
type decoder struct {
	tokenizer *tokenizer.Tokenizer // decodes token IDs into text pieces
	buf       bytes.Buffer         // pending bytes awaiting a complete UTF-8 sequence
	logprobs  []llm.Logprob        // logprobs accumulated for the bytes still buffered in buf
}
|
|
|
|
func (d *decoder) decode(res sampler.Result) (CompletionResponse, bool) {
|
|
output := int32(res.Token.Int())
|
|
d.buf.WriteString(d.tokenizer.Decode([]int32{output}))
|
|
d.logprobs = append(d.logprobs, buildLogprob(res, d.tokenizer.Decode)...)
|
|
|
|
content := flushValidUTF8Prefix(&d.buf)
|
|
if content == "" {
|
|
return CompletionResponse{}, false
|
|
}
|
|
resp := CompletionResponse{Content: content, Logprobs: d.logprobs}
|
|
d.logprobs = nil
|
|
return resp, true
|
|
}
|
|
|
|
func buildLogprob(sample sampler.Result, decode func([]int32) string) []llm.Logprob {
|
|
if sample.Logprob == nil {
|
|
return nil
|
|
}
|
|
tok := func(id int32) string { return decode([]int32{id}) }
|
|
|
|
out := llm.Logprob{
|
|
TokenLogprob: llm.TokenLogprob{
|
|
Token: tok(int32(sample.Token.Int())),
|
|
Logprob: float64(sample.Logprob.Floats()[0]),
|
|
},
|
|
}
|
|
|
|
if sample.TopTokens != nil {
|
|
ids := sample.TopTokens.Ints()
|
|
vals := sample.TopLogprobs.Floats()
|
|
pairs := make([]llm.TokenLogprob, len(ids))
|
|
for i, id := range ids {
|
|
pairs[i] = llm.TokenLogprob{
|
|
Token: tok(int32(id)),
|
|
Logprob: float64(vals[i]),
|
|
}
|
|
}
|
|
sort.Slice(pairs, func(i, j int) bool {
|
|
return pairs[i].Logprob > pairs[j].Logprob
|
|
})
|
|
out.TopLogprobs = pairs
|
|
}
|
|
return []llm.Logprob{out}
|
|
}
|