diff --git a/integration/api_test.go b/integration/api_test.go index 5d06c7590..6e101ccdd 100644 --- a/integration/api_test.go +++ b/integration/api_test.go @@ -406,10 +406,6 @@ func TestAPIShowModel(t *testing.T) { } func TestAPIGenerateLogprobs(t *testing.T) { - if testModel != "" { - // Logprobs requires runner support (e.g. llama.cpp has it, MLX does not). - t.Skip("logprobs not supported by all runners") - } ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() @@ -523,10 +519,6 @@ func TestAPIGenerateLogprobs(t *testing.T) { } func TestAPIChatLogprobs(t *testing.T) { - if testModel != "" { - // Logprobs requires runner support (e.g. llama.cpp has it, MLX does not). - t.Skip("logprobs not supported by all runners") - } ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() diff --git a/x/mlxrunner/client.go b/x/mlxrunner/client.go index a4d639533..fa55fbd71 100644 --- a/x/mlxrunner/client.go +++ b/x/mlxrunner/client.go @@ -151,22 +151,11 @@ func (c *Client) WaitUntilRunning(ctx context.Context) error { } } -// completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization. -type completionRequest struct { - Prompt string `json:"prompt"` - Options *completionOpts `json:"options,omitempty"` -} - -type completionOpts struct { - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - MinP float32 `json:"min_p,omitempty"` - TopK int `json:"top_k,omitempty"` - RepeatLastN int `json:"repeat_last_n,omitempty"` - RepeatPenalty float32 `json:"repeat_penalty,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` - NumPredict int `json:"num_predict,omitempty"` +type CompletionRequest struct { + Prompt string + Options api.Options + Logprobs bool + TopLogprobs int } type CompletionResponse struct { @@ -179,6 +168,8 @@ type CompletionResponse struct { EvalCount int EvalDuration time.Duration + Logprobs []llm.Logprob + Error *api.StatusError } @@ -203,21 +194,13 @@ func (c *Client) Close() error { // Completion implements llm.LlamaServer. func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error { - creq := completionRequest{ - Prompt: req.Prompt, + creq := CompletionRequest{ + Prompt: req.Prompt, + Logprobs: req.Logprobs, + TopLogprobs: req.TopLogprobs, } if req.Options != nil { - creq.Options = &completionOpts{ - Temperature: req.Options.Temperature, - TopP: req.Options.TopP, - MinP: req.Options.MinP, - TopK: req.Options.TopK, - RepeatLastN: req.Options.RepeatLastN, - RepeatPenalty: req.Options.RepeatPenalty, - PresencePenalty: req.Options.PresencePenalty, - FrequencyPenalty: req.Options.FrequencyPenalty, - NumPredict: req.Options.NumPredict, - } + creq.Options = *req.Options } body, err := json.Marshal(creq) @@ -266,6 +249,7 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f PromptEvalDuration: raw.PromptEvalDuration, EvalCount: raw.EvalCount, EvalDuration: raw.EvalDuration, + Logprobs: raw.Logprobs, } fn(cresp) diff --git a/x/mlxrunner/mlx/array.go b/x/mlxrunner/mlx/array.go index a41aee9cb..c4c80dae4 100644 --- a/x/mlxrunner/mlx/array.go +++ b/x/mlxrunner/mlx/array.go @@ -238,6 +238,9 @@ func (t Array) Float() float64 { } func (t Array) Ints() []int { + if dt := t.DType(); dt != DTypeInt32 { + panic(fmt.Sprintf("mlx: Ints requires DTypeInt32, got %v", dt)) + } ints := make([]int, t.Size()) for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) { ints[i] = int(f) @@ -246,6 +249,9 @@ func (t Array) Ints() []int { } func (t Array) Floats() []float32 { + if dt := t.DType(); dt != DTypeFloat32 { + panic(fmt.Sprintf("mlx: Floats requires DTypeFloat32, got %v", dt)) + } floats := make([]float32, t.Size()) for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) { floats[i] = float32(f) diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go index 8641d213b..b1eef58a5 100644 --- a/x/mlxrunner/pipeline.go +++ b/x/mlxrunner/pipeline.go @@ -7,11 +7,15 @@ import ( "fmt" "log/slog" "net/http" + "sort" "time" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/x/mlxrunner/mlx" + sampler "github.com/ollama/ollama/x/mlxrunner/sample" + "github.com/ollama/ollama/x/tokenizer" ) func prefillChunkSize() int { @@ -25,17 +29,14 @@ func (r *Runner) TextGenerationPipeline(request Request) error { mlx.ResetPeakMemory() ctx := request.Ctx - var ( - sample *mlx.Array - nextSample *mlx.Array - ) + var sample, nextSample sampler.Result defer func() { if request.Sampler != nil { request.Sampler.Free() } - mlx.Unpin(sample) - mlx.Unpin(nextSample) + mlx.Unpin(sample.Arrays()...) + mlx.Unpin(nextSample.Arrays()...) mlx.Sweep() mlx.ClearCache() @@ -60,10 +61,10 @@ func (r *Runner) TextGenerationPipeline(request Request) error { // Cap generation to stay within the model's context length maxGenerate := r.contextLength - len(inputs) - if request.Options.MaxTokens <= 0 { - request.Options.MaxTokens = maxGenerate + if request.Options.NumPredict <= 0 { + request.Options.NumPredict = maxGenerate } else { - request.Options.MaxTokens = min(request.Options.MaxTokens, maxGenerate) + request.Options.NumPredict = min(request.Options.NumPredict, maxGenerate) } request.Sampler.ResetHistory(inputs) @@ -135,40 +136,38 @@ func (r *Runner) TextGenerationPipeline(request Request) error { mlx.ClearCache() } - step := func(token *mlx.Array) *mlx.Array { + step := func(token *mlx.Array) sampler.Result { fwd := r.Model.Forward(token.ExpandDims(0), caches) logits := r.Model.Unembed(fwd) logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1) sample := request.Sampler.Sample(logits) - - mlx.Pin(sample) + mlx.Pin(sample.Arrays()...) mlx.Sweep() - mlx.AsyncEval(sample) - + mlx.AsyncEval(sample.Arrays()...) return sample } sample = step(mlx.FromValues(tokens[processed:], total-processed)) - var b bytes.Buffer + dec := decoder{tokenizer: r.Tokenizer} - final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.MaxTokens, DoneReason: 1} - for i := range request.Options.MaxTokens { + final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.NumPredict, DoneReason: 1} + for i := range request.Options.NumPredict { if err := ctx.Err(); err != nil { return err } - request.Sampler.AppendToken(sample) - nextSample = step(sample) + request.Sampler.AppendToken(sample.Token) + nextSample = step(sample.Token) if i == 0 { - mlx.Eval(sample) + mlx.Eval(sample.Arrays()...) final.PromptEvalDuration = time.Since(now) now = time.Now() } - output := int32(sample.Int()) + output := int32(sample.Token.Int()) session.outputs = append(session.outputs, output) if r.Tokenizer.IsEOS(output) { @@ -177,17 +176,16 @@ func (r *Runner) TextGenerationPipeline(request Request) error { break } - select { - case <-ctx.Done(): - return ctx.Err() - case request.Responses <- CompletionResponse{ - Content: r.Decode(output, &b), - }: + if resp, ok := dec.decode(sample); ok { + select { + case <-ctx.Done(): + return ctx.Err() + case request.Responses <- resp: + } } - mlx.Unpin(sample) - sample = nextSample - nextSample = nil + mlx.Unpin(sample.Arrays()...) + sample, nextSample = nextSample, sampler.Result{} if i%256 == 0 { mlx.ClearCache() @@ -203,13 +201,57 @@ func (r *Runner) TextGenerationPipeline(request Request) error { } } -func (r Runner) Decode(sample int32, b *bytes.Buffer) string { - token := r.Tokenizer.Decode([]int32{sample}) +// decoder serializes sampled tokens into response chunks, holding bytes +// whose UTF-8 sequence hasn't completed yet and the logprobs that belong +// with those bytes so Content and Logprobs stay aligned when a chunk does +// flush. +type decoder struct { + tokenizer *tokenizer.Tokenizer + buf bytes.Buffer + logprobs []llm.Logprob +} - if _, err := b.WriteString(token); err != nil { - slog.Error("Failed to write token to buffer", "error", err) - return "" +func (d *decoder) decode(res sampler.Result) (CompletionResponse, bool) { + output := int32(res.Token.Int()) + d.buf.WriteString(d.tokenizer.Decode([]int32{output})) + d.logprobs = append(d.logprobs, buildLogprob(res, d.tokenizer.Decode)...) + + content := flushValidUTF8Prefix(&d.buf) + if content == "" { + return CompletionResponse{}, false + } + resp := CompletionResponse{Content: content, Logprobs: d.logprobs} + d.logprobs = nil + return resp, true +} + +func buildLogprob(sample sampler.Result, decode func([]int32) string) []llm.Logprob { + if sample.Logprob == nil { + return nil + } + tok := func(id int32) string { return decode([]int32{id}) } + + out := llm.Logprob{ + TokenLogprob: llm.TokenLogprob{ + Token: tok(int32(sample.Token.Int())), + Logprob: float64(sample.Logprob.Floats()[0]), + }, } - return flushValidUTF8Prefix(b) + if sample.TopTokens != nil { + ids := sample.TopTokens.Ints() + vals := sample.TopLogprobs.Floats() + pairs := make([]llm.TokenLogprob, len(ids)) + for i, id := range ids { + pairs[i] = llm.TokenLogprob{ + Token: tok(int32(id)), + Logprob: float64(vals[i]), + } + } + sort.Slice(pairs, func(i, j int) bool { + return pairs[i].Logprob > pairs[j].Logprob + }) + out.TopLogprobs = pairs + } + return []llm.Logprob{out} } diff --git a/x/mlxrunner/runner.go b/x/mlxrunner/runner.go index 3e9680304..95e1f2b62 100644 --- a/x/mlxrunner/runner.go +++ b/x/mlxrunner/runner.go @@ -19,7 +19,7 @@ import ( ) type Request struct { - TextCompletionsRequest + CompletionRequest Responses chan CompletionResponse Pipeline func(Request) error @@ -28,24 +28,6 @@ type Request struct { Sampler *sample.Sampler } -type TextCompletionsRequest struct { - Prompt string `json:"prompt"` - Options struct { - Temperature float32 `json:"temperature"` - TopP float32 `json:"top_p"` - MinP float32 `json:"min_p"` - TopK int `json:"top_k"` - RepeatLastN int `json:"repeat_last_n"` - RepeatPenalty float32 `json:"repeat_penalty"` - PresencePenalty float32 `json:"presence_penalty"` - FrequencyPenalty float32 `json:"frequency_penalty"` - MaxTokens int `json:"max_tokens"` - - // Deprecated: use MaxTokens instead - NumPredict int `json:"num_predict"` - } `json:"options"` -} - type Runner struct { Model base.Model Tokenizer *tokenizer.Tokenizer diff --git a/x/mlxrunner/sample/logprob_test.go b/x/mlxrunner/sample/logprob_test.go new file mode 100644 index 000000000..fa46d6389 --- /dev/null +++ b/x/mlxrunner/sample/logprob_test.go @@ -0,0 +1,249 @@ +//go:build mlx + +package sample + +import ( + "math" + "sort" + "testing" + + "github.com/ollama/ollama/x/mlxrunner/mlx" +) + +// logprobEntry is the (token id, logprob) pair returned by the sampler's +// top-K extraction, used after the test-side descending sort. +type logprobEntry struct { + id int + logprob float64 +} + +// runSampleLogprobs drives Sample on a fresh Sampler configured for logprobs +// and returns the greedily-sampled token id, its logprob, and the top-K +// entries sorted descending by logprob. Logits must be a [vocab]-shaped +// slice; the helper reshapes it to [1, vocab] before calling the sampler. +func runSampleLogprobs(t *testing.T, logits []float32, topK int) (int, float64, []logprobEntry) { + t.Helper() + + s := New(Options{Logprobs: true, TopLogprobs: topK}) + defer func() { + s.Free() + mlx.Sweep() + }() + + tensor := mlx.FromValues(logits, 1, len(logits)) + res := s.Sample(tensor) + + mlx.Pin(res.Arrays()...) + defer mlx.Unpin(res.Arrays()...) + mlx.Sweep() + mlx.Eval(res.Arrays()...) + + selected := res.Token.Int() + selLP := float64(res.Logprob.Floats()[0]) + + var top []logprobEntry + if topK > 0 && res.TopTokens != nil { + ids := res.TopTokens.Ints() + vals := res.TopLogprobs.Floats() + top = make([]logprobEntry, len(ids)) + for i, id := range ids { + top[i] = logprobEntry{id: id, logprob: float64(vals[i])} + } + sort.Slice(top, func(i, j int) bool { return top[i].logprob > top[j].logprob }) + } + return selected, selLP, top +} + +func TestSampleLogprobsBasic(t *testing.T) { + tests := []struct { + name string + logits []float32 + topK int + wantSelectedID int + wantTopLen int + }{ + { + name: "single token without top logprobs", + logits: []float32{1.0, 0.5, 0.3, 0.1}, + topK: 0, + wantSelectedID: 0, + wantTopLen: 0, + }, + { + name: "single token with top logprobs", + logits: []float32{1.0, 0.5, 0.3, 0.1}, + topK: 3, + wantSelectedID: 0, + wantTopLen: 3, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + selected, _, top := runSampleLogprobs(t, tt.logits, tt.topK) + if selected != tt.wantSelectedID { + t.Errorf("selected = %d, want %d", selected, tt.wantSelectedID) + } + if len(top) != tt.wantTopLen { + t.Errorf("top-K length = %d, want %d", len(top), tt.wantTopLen) + } + }) + } +} + +func TestSampleLogprobsNumericalStability(t *testing.T) { + logits := []float32{1000.0, 999.0, 998.0} + _, selLP, top := runSampleLogprobs(t, logits, 3) + + if math.IsInf(selLP, 0) || math.IsNaN(selLP) { + t.Errorf("selected logprob is not finite: %f", selLP) + } + for i, e := range top { + if math.IsInf(e.logprob, 0) || math.IsNaN(e.logprob) { + t.Errorf("top[%d] logprob is not finite: %f", i, e.logprob) + } + } + for i := 1; i < len(top); i++ { + if top[i].logprob > top[i-1].logprob { + t.Errorf("top logprobs not descending: %f > %f", top[i].logprob, top[i-1].logprob) + } + } +} + +func TestSampleLogprobsProbabilityCorrectness(t *testing.T) { + tests := []struct { + name string + logits []float32 + }{ + {"uniform", []float32{1.0, 1.0, 1.0, 1.0}}, + {"different", []float32{2.0, 1.0, 0.5, 0.1}}, + {"negative", []float32{-1.0, -2.0, -3.0, -4.0}}, + {"mixed", []float32{5.0, -5.0, 0.0, 2.5}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + selected, selLP, top := runSampleLogprobs(t, tt.logits, len(tt.logits)) + + if selLP > 0 { + t.Errorf("selected logprob should be <= 0, got %f", selLP) + } + for i, e := range top { + if e.logprob > 0 { + t.Errorf("top[%d] logprob should be <= 0, got %f", i, e.logprob) + } + } + + if tt.name == "uniform" { + want := 1.0 / float64(len(tt.logits)) + got := math.Exp(selLP) + if math.Abs(got-want) > 1e-6 { + t.Errorf("uniform logits: selected prob = %f, want %f", got, want) + } + } + + for i := 1; i < len(top); i++ { + if top[i].logprob > top[i-1].logprob { + t.Errorf("top logprobs not descending at %d: %f > %f", + i, top[i].logprob, top[i-1].logprob) + } + } + + found := false + for _, e := range top { + if e.id == selected { + found = true + if math.Abs(e.logprob-selLP) > 1e-6 { + t.Errorf("selected logprob mismatch: selLP=%f top=%f", selLP, e.logprob) + } + break + } + } + if !found { + t.Errorf("selected token %d not present in top-K", selected) + } + }) + } +} + +func TestSampleLogprobsSoftmaxCorrectness(t *testing.T) { + tests := []struct { + name string + logits []float32 + }{ + {"small vocabulary", []float32{1.0, 2.0, 3.0}}, + {"large differences", []float32{10.0, 0.0, -10.0}}, + {"all equal", []float32{5.0, 5.0, 5.0, 5.0, 5.0}}, + {"very large values", []float32{500.0, 499.0, 498.0}}, + {"very small values", []float32{-500.0, -499.0, -498.0}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, _, top := runSampleLogprobs(t, tt.logits, len(tt.logits)) + if len(top) != len(tt.logits) { + t.Fatalf("top-K length = %d, want %d", len(top), len(tt.logits)) + } + + var sum float64 + for _, e := range top { + p := math.Exp(e.logprob) + if p < 0 || p > 1 { + t.Errorf("token %d: probability %f out of [0,1]", e.id, p) + } + sum += p + } + + if math.Abs(sum-1.0) > 1e-5 { + t.Errorf("probabilities sum = %f, want 1.0", sum) + } + }) + } +} + +func TestSampleLogprobsSelectedTokenCorrectness(t *testing.T) { + logits := []float32{3.0, 1.0, 2.0, 0.5} + + maxIdx := 0 + for i, v := range logits[1:] { + if v > logits[maxIdx] { + maxIdx = i + 1 + } + } + + selected, selLP, top := runSampleLogprobs(t, logits, len(logits)) + + if selected != maxIdx { + t.Errorf("selected = %d, want argmax %d", selected, maxIdx) + } + + if top[0].id != maxIdx { + t.Errorf("top[0].id = %d, want argmax %d", top[0].id, maxIdx) + } + if math.Abs(top[0].logprob-selLP) > 1e-6 { + t.Errorf("top[0].logprob = %f, want selected %f", top[0].logprob, selLP) + } +} + +func TestSampleLogprobsTopKOrdering(t *testing.T) { + // Logits chosen so argmax order differs from index order. + logits := []float32{2.0, 5.0, 1.0, 4.0, 3.0} + wantOrder := []int{1, 3, 4, 0, 2} + + _, _, top := runSampleLogprobs(t, logits, len(logits)) + + if len(top) != len(wantOrder) { + t.Fatalf("top-K length = %d, want %d", len(top), len(wantOrder)) + } + for i, e := range top { + if e.id != wantOrder[i] { + t.Errorf("top[%d].id = %d, want %d", i, e.id, wantOrder[i]) + } + } + for i := 1; i < len(top); i++ { + if top[i].logprob > top[i-1].logprob { + t.Errorf("top[%d].logprob (%f) > top[%d].logprob (%f)", + i, top[i].logprob, i-1, top[i-1].logprob) + } + } +} diff --git a/x/mlxrunner/sample/sample.go b/x/mlxrunner/sample/sample.go index 2331d72f7..1bf5b82c7 100644 --- a/x/mlxrunner/sample/sample.go +++ b/x/mlxrunner/sample/sample.go @@ -8,7 +8,7 @@ import ( type Transform func(*Sampler, *mlx.Array) *mlx.Array -type Sampler struct { +type Options struct { Temperature float32 TopP float32 MinP float32 @@ -18,45 +18,61 @@ type Sampler struct { PresencePenalty float32 FrequencyPenalty float32 + // Logprobs causes Sample to populate Result.Logprob with the selected + // token's log-probability. TopLogprobs (when > 0) adds top-K pairs. + Logprobs bool + TopLogprobs int +} + +type Sampler struct { + Options + history *mlx.Array historyLen int transforms []Transform } -func New(temp, top_p, min_p float32, top_k, repeatLastN int, repeatPenalty, presencePenalty, frequencyPenalty float32) *Sampler { - if repeatPenalty <= 0 { - repeatPenalty = 1 +// Result bundles the outputs of one decode step. The logprob tensors are +// populated only when the sampler is configured to report them. +type Result struct { + Token *mlx.Array // sampled token id, shape [B] + Logprob *mlx.Array // sampled-token logprob, shape [B,1]; nil unless Logprobs + TopTokens *mlx.Array // top-K token ids, shape [B,K]; nil unless TopLogprobs > 0 + TopLogprobs *mlx.Array // top-K logprobs, shape [B,K]; nil unless TopLogprobs > 0 +} + +// Arrays returns the tensor fields as a slice so callers can drive the mlx +// lifecycle verbs (Pin, Unpin, Eval, AsyncEval) over the whole group. Unset +// fields stay nil; the mlx helpers skip them. +func (r Result) Arrays() []*mlx.Array { + return []*mlx.Array{r.Token, r.Logprob, r.TopTokens, r.TopLogprobs} +} + +func New(opts Options) *Sampler { + if opts.RepeatPenalty <= 0 { + opts.RepeatPenalty = 1 } - s := &Sampler{ - Temperature: temp, - TopP: top_p, - MinP: min_p, - TopK: top_k, - RepeatLastN: repeatLastN, - RepeatPenalty: repeatPenalty, - PresencePenalty: presencePenalty, - FrequencyPenalty: frequencyPenalty, - } + s := &Sampler{Options: opts} var transforms []Transform if s.usesHistory() { transforms = append(transforms, penalty) } - if top_p > 0 && top_p < 1 { + if opts.TopP > 0 && opts.TopP < 1 { transforms = append(transforms, topP) } - if min_p != 0 { + if opts.MinP != 0 { transforms = append(transforms, minP) } - if top_k > 0 { + if opts.TopK > 0 { transforms = append(transforms, topK) } - if temp == 0 { + if opts.Temperature == 0 { transforms = append(transforms, greedy) } else { transforms = append(transforms, temperature) @@ -123,76 +139,101 @@ func (s *Sampler) Free() { s.setHistory(nil, 0) } -func (s *Sampler) Sample(logits *mlx.Array) *mlx.Array { +// Sample runs the configured transform chain on the raw per-token logits +// and returns the sampled token id plus, when configured, the reported +// log-probability tensors for the selected token and the top-K tokens. +func (s *Sampler) Sample(logits *mlx.Array) Result { + scores := logits for _, transform := range s.transforms { - logits = transform(s, logits) + scores = transform(s, scores) } - return logits + res := Result{Token: scores} + + if s.Logprobs { + // Compute log_softmax in fp32 and subtract the max before + // logsumexp so the final subtraction stays on small values. + // Otherwise it cancels two large numbers and loses precision. + lp := logits.AsType(mlx.DTypeFloat32) + lp = lp.Subtract(lp.MaxAxis(-1, true)) + lp = lp.Subtract(lp.Logsumexp(true)) + res.Logprob = lp.TakeAlongAxis(res.Token.ExpandDims(-1), -1) + if k := s.TopLogprobs; k > 0 { + if vocab := lp.Dim(lp.NumDims() - 1); k > vocab { + k = vocab + } + // Argpartition on the negated values places the K largest + // (unsorted) in positions [0:K]. + idx := lp.Negative().ArgpartitionAxis(k-1, -1).Slice(mlx.Slice(), mlx.Slice(0, k)) + res.TopTokens = idx.AsType(mlx.DTypeInt32) + res.TopLogprobs = lp.TakeAlongAxis(idx, -1) + } + } + return res } -func greedy(_ *Sampler, logits *mlx.Array) *mlx.Array { - return logits.Argmax(-1, false) +func greedy(_ *Sampler, scores *mlx.Array) *mlx.Array { + return scores.Argmax(-1, false) } -func temperature(s *Sampler, logits *mlx.Array) *mlx.Array { - return mlx.DivScalar(logits, s.Temperature).Categorical(-1) +func temperature(s *Sampler, scores *mlx.Array) *mlx.Array { + return mlx.DivScalar(scores, s.Temperature).Categorical(-1) } -func topP(s *Sampler, logits *mlx.Array) *mlx.Array { +func topP(s *Sampler, scores *mlx.Array) *mlx.Array { if s.TopP <= 0 || s.TopP >= 1 { - return logits + return scores } - order := logits.Negative().ArgsortAxis(-1) - sortedLogits := logits.TakeAlongAxis(order, -1) - sortedProbs := mlx.SoftmaxAxis(sortedLogits, -1, true) + order := scores.Negative().ArgsortAxis(-1) + sortedScores := scores.TakeAlongAxis(order, -1) + sortedProbs := mlx.SoftmaxAxis(sortedScores, -1, true) prevCumProbs := sortedProbs.Cumsum(-1, false, true).Subtract(sortedProbs) keep := prevCumProbs.Less(mlx.FromValue(s.TopP)) - filtered := mlx.Where(keep, sortedLogits, mlx.FromValue(float32(math.Inf(-1)))) - return logits.PutAlongAxis(order, filtered, -1) + filtered := mlx.Where(keep, sortedScores, mlx.FromValue(float32(math.Inf(-1)))) + return scores.PutAlongAxis(order, filtered, -1) } -func minP(s *Sampler, logits *mlx.Array) *mlx.Array { +func minP(s *Sampler, scores *mlx.Array) *mlx.Array { if s.MinP <= 0 || s.MinP > 1 { - return logits + return scores } - maxLogits := logits.TakeAlongAxis(logits.Argmax(-1, true), -1) - minLogits := mlx.AddScalar(maxLogits, float32(math.Log(float64(s.MinP)))) + maxScore := scores.TakeAlongAxis(scores.Argmax(-1, true), -1) + threshold := mlx.AddScalar(maxScore, float32(math.Log(float64(s.MinP)))) return mlx.Where( - logits.Less(minLogits), + scores.Less(threshold), mlx.FromValue(float32(math.Inf(-1))), - logits, + scores, ) } -func topK(s *Sampler, logits *mlx.Array) *mlx.Array { +func topK(s *Sampler, scores *mlx.Array) *mlx.Array { if s.TopK <= 0 { - return logits + return scores } - vocab := logits.Dim(logits.NumDims() - 1) + vocab := scores.Dim(scores.NumDims() - 1) if s.TopK >= vocab { - return logits + return scores } - mask := logits.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End)) - return logits.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1) + mask := scores.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End)) + return scores.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1) } -func penalty(s *Sampler, logits *mlx.Array) *mlx.Array { +func penalty(s *Sampler, scores *mlx.Array) *mlx.Array { if s.historyLen == 0 { - return logits + return scores } tokenIndices := s.history - if logits.NumDims() > 1 { + if scores.NumDims() > 1 { tokenIndices = tokenIndices.ExpandDims(0) } if s.RepeatPenalty != 1 || s.PresencePenalty != 0 { - adjusted := logits.TakeAlongAxis(tokenIndices, -1) + adjusted := scores.TakeAlongAxis(tokenIndices, -1) if s.RepeatPenalty != 1 { factor := mlx.Where( adjusted.Less(mlx.FromValue(float32(0))), @@ -204,12 +245,12 @@ func penalty(s *Sampler, logits *mlx.Array) *mlx.Array { if s.PresencePenalty != 0 { adjusted = mlx.AddScalar(adjusted, -s.PresencePenalty) } - logits = logits.PutAlongAxis(tokenIndices, adjusted, -1) + scores = scores.PutAlongAxis(tokenIndices, adjusted, -1) } if s.FrequencyPenalty != 0 { - logits = logits.ScatterAddAxis(tokenIndices, mlx.FromValue(-s.FrequencyPenalty), -1) + scores = scores.ScatterAddAxis(tokenIndices, mlx.FromValue(-s.FrequencyPenalty), -1) } - return logits + return scores } diff --git a/x/mlxrunner/sample/sample_test.go b/x/mlxrunner/sample/sample_test.go index a53b49e72..5372e1d38 100644 --- a/x/mlxrunner/sample/sample_test.go +++ b/x/mlxrunner/sample/sample_test.go @@ -10,8 +10,7 @@ import ( ) func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) { - // RepeatLastN = 1, PresencePenalty = 6 - s := New(0, 0, 0, 0, 1, 1, 6, 0) + s := New(Options{RepeatLastN: 1, PresencePenalty: 6}) defer func() { s.Free() mlx.Sweep() @@ -21,7 +20,7 @@ func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) { s.AppendToken(mlx.NewArrayInt32([]int32{1}, []int32{1})) logits := mlx.FromValues([]float32{0, 5, 4}, 3) - got := s.Sample(logits) + got := s.Sample(logits).Token mlx.Eval(got) // logits will be [0, -1, 4] after the penalty @@ -33,7 +32,7 @@ func TestPresencePenaltyUsesAppendedTokenImmediately(t *testing.T) { } func TestRepeatPenaltyUsesHistoryWithoutPresencePenalty(t *testing.T) { - s := New(0, 0, 0, 0, 1, 2, 0, 0) + s := New(Options{RepeatLastN: 1, RepeatPenalty: 2}) defer func() { s.Free() mlx.Sweep() @@ -42,7 +41,7 @@ func TestRepeatPenaltyUsesHistoryWithoutPresencePenalty(t *testing.T) { s.ResetHistory([]int32{1}) logits := mlx.FromValues([]float32{0, 5, 4}, 3) - got := s.Sample(logits) + got := s.Sample(logits).Token mlx.Eval(got) // token 1 is repeated and positive, so 5 / 2 falls below token 2. @@ -53,7 +52,7 @@ func TestRepeatPenaltyUsesHistoryWithoutPresencePenalty(t *testing.T) { } func TestFrequencyPenaltyUsesTokenCounts(t *testing.T) { - s := New(0, 0, 0, 0, 4, 1, 0, 2) + s := New(Options{RepeatLastN: 4, FrequencyPenalty: 2}) defer func() { s.Free() mlx.Sweep() @@ -62,7 +61,7 @@ func TestFrequencyPenaltyUsesTokenCounts(t *testing.T) { s.ResetHistory([]int32{1, 1}) logits := mlx.FromValues([]float32{0, 5, 4}, 3) - got := s.Sample(logits) + got := s.Sample(logits).Token mlx.Eval(got) // token 1 appears twice, so 5 - (2 * 2) falls below token 2. @@ -73,7 +72,7 @@ func TestFrequencyPenaltyUsesTokenCounts(t *testing.T) { } func TestMinPMasksTokensBelowThreshold(t *testing.T) { - s := New(0, 0, 0.5, 0, 0, 1, 0, 0) + s := New(Options{MinP: 0.5}) defer func() { s.Free() mlx.Sweep() diff --git a/x/mlxrunner/server.go b/x/mlxrunner/server.go index d83a55744..089e5c7f5 100644 --- a/x/mlxrunner/server.go +++ b/x/mlxrunner/server.go @@ -2,7 +2,6 @@ package mlxrunner import ( "bytes" - "cmp" "context" "encoding/json" "flag" @@ -87,25 +86,25 @@ func Execute(args []string) error { mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) { request := Request{Responses: make(chan CompletionResponse)} - if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil { + if err := json.NewDecoder(r.Body).Decode(&request.CompletionRequest); err != nil { slog.Error("Failed to decode request", "error", err) http.Error(w, "Bad Request", http.StatusBadRequest) return } - request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict) - request.Pipeline = runner.TextGenerationPipeline - request.Sampler = sample.New( - request.Options.Temperature, - request.Options.TopP, - request.Options.MinP, - request.Options.TopK, - request.Options.RepeatLastN, - request.Options.RepeatPenalty, - request.Options.PresencePenalty, - request.Options.FrequencyPenalty, - ) + request.Sampler = sample.New(sample.Options{ + Temperature: request.Options.Temperature, + TopP: request.Options.TopP, + MinP: request.Options.MinP, + TopK: request.Options.TopK, + RepeatLastN: request.Options.RepeatLastN, + RepeatPenalty: request.Options.RepeatPenalty, + PresencePenalty: request.Options.PresencePenalty, + FrequencyPenalty: request.Options.FrequencyPenalty, + Logprobs: request.Logprobs, + TopLogprobs: request.TopLogprobs, + }) var cancel context.CancelFunc request.Ctx, cancel = context.WithCancel(r.Context()) diff --git a/x/models/gemma4/gemma4_moe_test.go b/x/models/gemma4/gemma4_moe_test.go index ab390ae59..ae6a09ed8 100644 --- a/x/models/gemma4/gemma4_moe_test.go +++ b/x/models/gemma4/gemma4_moe_test.go @@ -144,6 +144,8 @@ func TestRouterForwardMatchesLegacy(t *testing.T) { gotScores, gotInds := r.Forward(x, cfg) wantScores, wantInds := legacyRouterForward(r, x, cfg) + gotInds = gotInds.AsType(mlx.DTypeInt32) + wantInds = wantInds.AsType(mlx.DTypeInt32) mlx.Eval(gotScores, gotInds, wantScores, wantInds) if got, want := gotInds.Ints(), wantInds.Ints(); !intSlicesEqual(got, want) { diff --git a/x/models/nn/nn_test.go b/x/models/nn/nn_test.go index 72bb22245..8ec4b7575 100644 --- a/x/models/nn/nn_test.go +++ b/x/models/nn/nn_test.go @@ -169,8 +169,8 @@ func TestQuantizedLinearMXFP4MatchesDequantizedWeight(t *testing.T) { dequantizedWeight := mlx.Dequantize(ql.Weight, ql.Scales, ql.QBiases, 32, 4, "mxfp4") mlx.Eval(dequantizedWeight) - qOut := ql.Forward(input) - dOut := NewLinear(dequantizedWeight, nil).Forward(input) + qOut := ql.Forward(input).AsType(mlx.DTypeFloat32) + dOut := NewLinear(dequantizedWeight, nil).Forward(input).AsType(mlx.DTypeFloat32) mlx.Eval(qOut, dOut) got := qOut.Floats()