mirror of
https://github.com/ollama/ollama
synced 2026-04-23 08:45:14 +00:00
Match the ollamarunner and OpenAI semantics: raw, full-vocab log-softmax with the top-K ranked by probability. Skipped on the GPU when the request doesn't ask for logprobs so decode doesn't pay for it otherwise.
257 lines
6.6 KiB
Go
package sample
|
|
|
|
import (
|
|
"math"
|
|
|
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
|
)
|
|
|
|
// Transform is one stage of the sampling pipeline. It receives the sampler
// (for configuration and penalty history) and the current scores, and
// returns the transformed scores; the terminal stage returns token ids.
type Transform func(*Sampler, *mlx.Array) *mlx.Array
|
// Options configures a Sampler. The zero value samples greedily with no
// penalties; New normalizes RepeatPenalty <= 0 to 1 (disabled).
type Options struct {
	Temperature float32 // 0 selects argmax; otherwise scores are divided by this before sampling
	TopP        float32 // nucleus threshold; active only when 0 < TopP < 1
	MinP        float32 // minimum probability relative to the best token; active when in (0, 1]
	TopK        int     // keep only the K highest-scoring tokens when > 0
	RepeatLastN int     // when > 0, cap on how many recent tokens the penalties consider

	RepeatPenalty    float32 // multiplicative penalty on seen tokens; 1 disables
	PresencePenalty  float32 // flat amount subtracted from tokens present in history
	FrequencyPenalty float32 // amount subtracted once per occurrence in history

	// Logprobs causes Sample to populate Result.Logprob with the selected
	// token's log-probability. TopLogprobs (when > 0) adds top-K pairs.
	Logprobs    bool
	TopLogprobs int
}
// Sampler applies a fixed chain of transforms to per-token logits and
// tracks recent token ids for the repetition penalties.
type Sampler struct {
	Options

	history    *mlx.Array  // pinned int32 tensor of recent token ids; nil when empty
	historyLen int         // number of tokens held in history
	transforms []Transform // pipeline built once by New; last stage emits token ids
}
// Result bundles the outputs of one decode step. The logprob tensors are
// populated only when the sampler is configured to report them.
type Result struct {
	Token       *mlx.Array // sampled token id, shape [B]
	Logprob     *mlx.Array // sampled-token logprob, shape [B,1]; nil unless Logprobs
	TopTokens   *mlx.Array // top-K token ids, shape [B,K]; nil unless TopLogprobs > 0
	TopLogprobs *mlx.Array // top-K logprobs, shape [B,K]; nil unless TopLogprobs > 0
}
|
|
// Arrays returns the tensor fields as a slice so callers can drive the mlx
|
|
// lifecycle verbs (Pin, Unpin, Eval, AsyncEval) over the whole group. Unset
|
|
// fields stay nil; the mlx helpers skip them.
|
|
func (r Result) Arrays() []*mlx.Array {
|
|
return []*mlx.Array{r.Token, r.Logprob, r.TopTokens, r.TopLogprobs}
|
|
}
|
|
|
|
func New(opts Options) *Sampler {
|
|
if opts.RepeatPenalty <= 0 {
|
|
opts.RepeatPenalty = 1
|
|
}
|
|
|
|
s := &Sampler{Options: opts}
|
|
|
|
var transforms []Transform
|
|
if s.usesHistory() {
|
|
transforms = append(transforms, penalty)
|
|
}
|
|
|
|
if opts.TopP > 0 && opts.TopP < 1 {
|
|
transforms = append(transforms, topP)
|
|
}
|
|
|
|
if opts.MinP != 0 {
|
|
transforms = append(transforms, minP)
|
|
}
|
|
|
|
if opts.TopK > 0 {
|
|
transforms = append(transforms, topK)
|
|
}
|
|
|
|
if opts.Temperature == 0 {
|
|
transforms = append(transforms, greedy)
|
|
} else {
|
|
transforms = append(transforms, temperature)
|
|
}
|
|
|
|
s.transforms = transforms
|
|
return s
|
|
}
|
|
|
|
func (s *Sampler) usesHistory() bool {
|
|
return s.RepeatPenalty != 1 || s.PresencePenalty != 0 || s.FrequencyPenalty != 0
|
|
}
|
|
|
|
func (s *Sampler) setHistory(history *mlx.Array, historyLen int) {
|
|
if history != nil {
|
|
mlx.Pin(history)
|
|
}
|
|
if s.history != nil {
|
|
mlx.Unpin(s.history)
|
|
}
|
|
s.history = history
|
|
s.historyLen = historyLen
|
|
}
|
|
|
|
func (s *Sampler) ResetHistory(history []int32) {
|
|
if !s.usesHistory() {
|
|
return
|
|
}
|
|
if s.RepeatLastN > 0 && len(history) > s.RepeatLastN {
|
|
history = history[len(history)-s.RepeatLastN:]
|
|
}
|
|
if len(history) == 0 {
|
|
s.setHistory(nil, 0)
|
|
return
|
|
}
|
|
|
|
tokens := append([]int32(nil), history...)
|
|
s.setHistory(mlx.NewArrayInt32(tokens, []int32{int32(len(tokens))}), len(tokens))
|
|
}
|
|
|
|
func (s *Sampler) AppendToken(token *mlx.Array) {
|
|
if !s.usesHistory() || token == nil {
|
|
return
|
|
}
|
|
|
|
next := token.AsType(mlx.DTypeInt32)
|
|
nextLen := next.Size()
|
|
|
|
if s.history != nil && s.historyLen > 0 {
|
|
next = s.history.Concatenate(0, next)
|
|
nextLen += s.historyLen
|
|
}
|
|
|
|
if s.RepeatLastN > 0 && nextLen > s.RepeatLastN {
|
|
trim := nextLen - s.RepeatLastN
|
|
next = next.Slice(mlx.Slice(trim, nextLen))
|
|
nextLen = s.RepeatLastN
|
|
}
|
|
|
|
s.setHistory(next, nextLen)
|
|
}
|
|
|
|
// Free releases the pinned history tensor, if any, so the mlx array can be
// reclaimed. Safe to call when no history is held.
func (s *Sampler) Free() {
	s.setHistory(nil, 0)
}
|
|
|
// Sample runs the configured transform chain on the raw per-token logits
|
|
// and returns the sampled token id plus, when configured, the reported
|
|
// log-probability tensors for the selected token and the top-K tokens.
|
|
func (s *Sampler) Sample(logits *mlx.Array) Result {
|
|
scores := logits
|
|
for _, transform := range s.transforms {
|
|
scores = transform(s, scores)
|
|
}
|
|
res := Result{Token: scores}
|
|
|
|
if s.Logprobs {
|
|
// Compute log_softmax in fp32 and subtract the max before
|
|
// logsumexp so the final subtraction stays on small values.
|
|
// Otherwise it cancels two large numbers and loses precision.
|
|
lp := logits.AsType(mlx.DTypeFloat32)
|
|
lp = lp.Subtract(lp.MaxAxis(-1, true))
|
|
lp = lp.Subtract(lp.Logsumexp(true))
|
|
res.Logprob = lp.TakeAlongAxis(res.Token.ExpandDims(-1), -1)
|
|
if k := s.TopLogprobs; k > 0 {
|
|
if vocab := lp.Dim(lp.NumDims() - 1); k > vocab {
|
|
k = vocab
|
|
}
|
|
// Argpartition on the negated values places the K largest
|
|
// (unsorted) in positions [0:K].
|
|
idx := lp.Negative().ArgpartitionAxis(k-1, -1).Slice(mlx.Slice(), mlx.Slice(0, k))
|
|
res.TopTokens = idx.AsType(mlx.DTypeInt32)
|
|
res.TopLogprobs = lp.TakeAlongAxis(idx, -1)
|
|
}
|
|
}
|
|
return res
|
|
}
|
|
|
|
func greedy(_ *Sampler, scores *mlx.Array) *mlx.Array {
|
|
return scores.Argmax(-1, false)
|
|
}
|
|
|
|
func temperature(s *Sampler, scores *mlx.Array) *mlx.Array {
|
|
return mlx.DivScalar(scores, s.Temperature).Categorical(-1)
|
|
}
|
|
|
|
func topP(s *Sampler, scores *mlx.Array) *mlx.Array {
|
|
if s.TopP <= 0 || s.TopP >= 1 {
|
|
return scores
|
|
}
|
|
|
|
order := scores.Negative().ArgsortAxis(-1)
|
|
sortedScores := scores.TakeAlongAxis(order, -1)
|
|
sortedProbs := mlx.SoftmaxAxis(sortedScores, -1, true)
|
|
prevCumProbs := sortedProbs.Cumsum(-1, false, true).Subtract(sortedProbs)
|
|
keep := prevCumProbs.Less(mlx.FromValue(s.TopP))
|
|
filtered := mlx.Where(keep, sortedScores, mlx.FromValue(float32(math.Inf(-1))))
|
|
return scores.PutAlongAxis(order, filtered, -1)
|
|
}
|
|
|
|
func minP(s *Sampler, scores *mlx.Array) *mlx.Array {
|
|
if s.MinP <= 0 || s.MinP > 1 {
|
|
return scores
|
|
}
|
|
|
|
maxScore := scores.TakeAlongAxis(scores.Argmax(-1, true), -1)
|
|
threshold := mlx.AddScalar(maxScore, float32(math.Log(float64(s.MinP))))
|
|
|
|
return mlx.Where(
|
|
scores.Less(threshold),
|
|
mlx.FromValue(float32(math.Inf(-1))),
|
|
scores,
|
|
)
|
|
}
|
|
|
|
func topK(s *Sampler, scores *mlx.Array) *mlx.Array {
|
|
if s.TopK <= 0 {
|
|
return scores
|
|
}
|
|
|
|
vocab := scores.Dim(scores.NumDims() - 1)
|
|
if s.TopK >= vocab {
|
|
return scores
|
|
}
|
|
|
|
mask := scores.Negative().ArgpartitionAxis(s.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(s.TopK, mlx.End))
|
|
return scores.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
|
|
}
|
|
|
|
// penalty applies the repeat, presence, and frequency penalties to the
// scores of tokens present in the tracked history. It is a no-op when the
// history window is empty.
func penalty(s *Sampler, scores *mlx.Array) *mlx.Array {
	if s.historyLen == 0 {
		return scores
	}

	// Add a leading axis so the gather/scatter indices broadcast against
	// batched scores (the last axis is the vocab axis either way).
	tokenIndices := s.history
	if scores.NumDims() > 1 {
		tokenIndices = tokenIndices.ExpandDims(0)
	}

	if s.RepeatPenalty != 1 || s.PresencePenalty != 0 {
		// Gather the penalized tokens' scores, adjust them, then scatter
		// the results back in place.
		adjusted := scores.TakeAlongAxis(tokenIndices, -1)
		if s.RepeatPenalty != 1 {
			// Negative scores are multiplied by the penalty and positive
			// scores divided by it, so the score moves down either way.
			factor := mlx.Where(
				adjusted.Less(mlx.FromValue(float32(0))),
				mlx.FromValue(s.RepeatPenalty),
				mlx.FromValue(1/s.RepeatPenalty),
			)
			adjusted = adjusted.Multiply(factor)
		}
		if s.PresencePenalty != 0 {
			adjusted = mlx.AddScalar(adjusted, -s.PresencePenalty)
		}
		scores = scores.PutAlongAxis(tokenIndices, adjusted, -1)
	}

	if s.FrequencyPenalty != 0 {
		// ScatterAdd accumulates per index occurrence, so a token that
		// appears n times in the history is penalized n times over.
		scores = scores.ScatterAddAxis(tokenIndices, mlx.FromValue(-s.FrequencyPenalty), -1)
	}

	return scores
}