mirror of
https://github.com/ollama/ollama
synced 2026-04-23 08:45:14 +00:00
wip sampling
This commit is contained in:
parent
dd497534c4
commit
67ce53b9b5
|
|
@ -1453,11 +1453,12 @@ type ImageData struct {
|
|||
}
|
||||
|
||||
type CompletionRequest struct {
|
||||
Prompt string
|
||||
Format json.RawMessage
|
||||
Images []ImageData
|
||||
Options *api.Options
|
||||
Think *api.ThinkValue
|
||||
Prompt string
|
||||
Format json.RawMessage
|
||||
Images []ImageData
|
||||
Options *api.Options
|
||||
Think *api.ThinkValue
|
||||
ExplicitOptions map[string]struct{}
|
||||
|
||||
Grammar string // set before sending the request to the subprocess
|
||||
Shift bool
|
||||
|
|
|
|||
|
|
@ -130,6 +130,35 @@ func (s *Server) modelOptions(model *Model, requestOpts map[string]any) (api.Opt
|
|||
return opts, nil
|
||||
}
|
||||
|
||||
func explicitOptions(modelOpts, requestOpts map[string]any) map[string]struct{} {
|
||||
keys := []string{
|
||||
"temperature",
|
||||
"top_p",
|
||||
"min_p",
|
||||
"top_k",
|
||||
"repeat_last_n",
|
||||
"repeat_penalty",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
}
|
||||
|
||||
explicit := make(map[string]struct{}, len(keys))
|
||||
for _, key := range keys {
|
||||
if optionSpecified(modelOpts, requestOpts, key) {
|
||||
explicit[key] = struct{}{}
|
||||
}
|
||||
}
|
||||
return explicit
|
||||
}
|
||||
|
||||
func optionSpecified(modelOpts, requestOpts map[string]any, key string) bool {
|
||||
if _, ok := requestOpts[key]; ok {
|
||||
return true
|
||||
}
|
||||
_, ok := modelOpts[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
|
||||
// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
|
||||
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
|
||||
|
|
@ -539,15 +568,16 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||
var sb strings.Builder
|
||||
defer close(ch)
|
||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: req.Format,
|
||||
Options: opts,
|
||||
Think: req.Think,
|
||||
Shift: req.Shift == nil || *req.Shift,
|
||||
Truncate: req.Truncate == nil || *req.Truncate,
|
||||
Logprobs: req.Logprobs,
|
||||
TopLogprobs: req.TopLogprobs,
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: req.Format,
|
||||
Options: opts,
|
||||
Think: req.Think,
|
||||
ExplicitOptions: explicitOptions(m.Options, req.Options),
|
||||
Shift: req.Shift == nil || *req.Shift,
|
||||
Truncate: req.Truncate == nil || *req.Truncate,
|
||||
Logprobs: req.Logprobs,
|
||||
TopLogprobs: req.TopLogprobs,
|
||||
}, func(cr llm.CompletionResponse) {
|
||||
res := api.GenerateResponse{
|
||||
Model: req.Model,
|
||||
|
|
@ -2299,15 +2329,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
// sets up new context given parent context per request
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
err := r.Completion(ctx, llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: currentFormat,
|
||||
Options: opts,
|
||||
Think: req.Think,
|
||||
Shift: req.Shift == nil || *req.Shift,
|
||||
Truncate: truncate,
|
||||
Logprobs: req.Logprobs,
|
||||
TopLogprobs: req.TopLogprobs,
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: currentFormat,
|
||||
Options: opts,
|
||||
Think: req.Think,
|
||||
ExplicitOptions: explicitOptions(m.Options, req.Options),
|
||||
Shift: req.Shift == nil || *req.Shift,
|
||||
Truncate: truncate,
|
||||
Logprobs: req.Logprobs,
|
||||
TopLogprobs: req.TopLogprobs,
|
||||
}, func(r llm.CompletionResponse) {
|
||||
res := api.ChatResponse{
|
||||
Model: req.Model,
|
||||
|
|
|
|||
|
|
@ -187,11 +187,15 @@ type completionRequest struct {
|
|||
}
|
||||
|
||||
type completionOpts struct {
|
||||
Temperature float32 `json:"temperature,omitempty"`
|
||||
TopP float32 `json:"top_p,omitempty"`
|
||||
MinP float32 `json:"min_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
NumPredict int `json:"num_predict,omitempty"`
|
||||
Temperature *float32 `json:"temperature,omitempty"`
|
||||
TopP *float32 `json:"top_p,omitempty"`
|
||||
MinP *float32 `json:"min_p,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
RepeatLastN *int `json:"repeat_last_n,omitempty"`
|
||||
RepeatPenalty *float32 `json:"repeat_penalty,omitempty"`
|
||||
PresencePenalty *float32 `json:"presence_penalty,omitempty"`
|
||||
FrequencyPenalty *float32 `json:"frequency_penalty,omitempty"`
|
||||
NumPredict int `json:"num_predict,omitempty"`
|
||||
}
|
||||
|
||||
type CompletionResponse struct {
|
||||
|
|
@ -241,11 +245,15 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
|||
}
|
||||
if req.Options != nil {
|
||||
creq.Options = &completionOpts{
|
||||
Temperature: req.Options.Temperature,
|
||||
TopP: req.Options.TopP,
|
||||
MinP: req.Options.MinP,
|
||||
TopK: req.Options.TopK,
|
||||
NumPredict: req.Options.NumPredict,
|
||||
Temperature: float32Ptr(req.Options.Temperature, hasExplicitOption(req.ExplicitOptions, "temperature")),
|
||||
TopP: float32Ptr(req.Options.TopP, hasExplicitOption(req.ExplicitOptions, "top_p")),
|
||||
MinP: float32Ptr(req.Options.MinP, hasExplicitOption(req.ExplicitOptions, "min_p")),
|
||||
TopK: intPtr(req.Options.TopK, hasExplicitOption(req.ExplicitOptions, "top_k")),
|
||||
RepeatLastN: intPtr(req.Options.RepeatLastN, hasExplicitOption(req.ExplicitOptions, "repeat_last_n")),
|
||||
RepeatPenalty: float32Ptr(req.Options.RepeatPenalty, hasExplicitOption(req.ExplicitOptions, "repeat_penalty")),
|
||||
PresencePenalty: float32Ptr(req.Options.PresencePenalty, hasExplicitOption(req.ExplicitOptions, "presence_penalty")),
|
||||
FrequencyPenalty: float32Ptr(req.Options.FrequencyPenalty, hasExplicitOption(req.ExplicitOptions, "frequency_penalty")),
|
||||
NumPredict: req.Options.NumPredict,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -304,6 +312,25 @@ func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
|||
return scanner.Err()
|
||||
}
|
||||
|
||||
func hasExplicitOption(explicit map[string]struct{}, key string) bool {
|
||||
_, ok := explicit[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
func float32Ptr(v float32, ok bool) *float32 {
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return &v
|
||||
}
|
||||
|
||||
func intPtr(v int, ok bool) *int {
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return &v
|
||||
}
|
||||
|
||||
func (c *Client) ContextLength() int {
|
||||
return int(c.contextLength.Load())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -78,6 +78,88 @@ func TestCompletionForwardsThink(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCompletionForwardsOnlySpecifiedSamplingOptions(t *testing.T) {
|
||||
var got completionRequest
|
||||
|
||||
rt := roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
if err := json.NewDecoder(r.Body).Decode(&got); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(strings.NewReader("{\"done\":true}\n")),
|
||||
Request: r,
|
||||
}, nil
|
||||
})
|
||||
|
||||
c := &Client{
|
||||
port: 11434,
|
||||
client: &http.Client{
|
||||
Transport: rt,
|
||||
},
|
||||
}
|
||||
|
||||
opts := &api.Options{
|
||||
Temperature: 1.0,
|
||||
TopP: 0.95,
|
||||
MinP: 0.1,
|
||||
TopK: 20,
|
||||
RepeatLastN: 128,
|
||||
RepeatPenalty: 1.2,
|
||||
PresencePenalty: 1.5,
|
||||
FrequencyPenalty: 0.25,
|
||||
NumPredict: 64,
|
||||
}
|
||||
|
||||
err := c.Completion(context.Background(), llm.CompletionRequest{
|
||||
Prompt: "hello",
|
||||
Options: opts,
|
||||
ExplicitOptions: map[string]struct{}{
|
||||
"temperature": {},
|
||||
"top_k": {},
|
||||
"repeat_penalty": {},
|
||||
"presence_penalty": {},
|
||||
},
|
||||
}, func(llm.CompletionResponse) {})
|
||||
if err != nil {
|
||||
t.Fatalf("completion request failed: %v", err)
|
||||
}
|
||||
|
||||
if got.Options == nil {
|
||||
t.Fatal("options = nil, want serialized options")
|
||||
}
|
||||
|
||||
if got.Options.Temperature == nil || *got.Options.Temperature != opts.Temperature {
|
||||
t.Fatalf("temperature = %v, want %v", got.Options.Temperature, opts.Temperature)
|
||||
}
|
||||
if got.Options.TopK == nil || *got.Options.TopK != opts.TopK {
|
||||
t.Fatalf("top_k = %v, want %v", got.Options.TopK, opts.TopK)
|
||||
}
|
||||
if got.Options.RepeatPenalty == nil || *got.Options.RepeatPenalty != opts.RepeatPenalty {
|
||||
t.Fatalf("repeat_penalty = %v, want %v", got.Options.RepeatPenalty, opts.RepeatPenalty)
|
||||
}
|
||||
if got.Options.PresencePenalty == nil || *got.Options.PresencePenalty != opts.PresencePenalty {
|
||||
t.Fatalf("presence_penalty = %v, want %v", got.Options.PresencePenalty, opts.PresencePenalty)
|
||||
}
|
||||
if got.Options.TopP != nil {
|
||||
t.Fatalf("top_p = %v, want nil", *got.Options.TopP)
|
||||
}
|
||||
if got.Options.MinP != nil {
|
||||
t.Fatalf("min_p = %v, want nil", *got.Options.MinP)
|
||||
}
|
||||
if got.Options.RepeatLastN != nil {
|
||||
t.Fatalf("repeat_last_n = %v, want nil", *got.Options.RepeatLastN)
|
||||
}
|
||||
if got.Options.FrequencyPenalty != nil {
|
||||
t.Fatalf("frequency_penalty = %v, want nil", *got.Options.FrequencyPenalty)
|
||||
}
|
||||
if got.Options.NumPredict != opts.NumPredict {
|
||||
t.Fatalf("num_predict = %d, want %d", got.Options.NumPredict, opts.NumPredict)
|
||||
}
|
||||
}
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
|
|
|
|||
|
|
@ -93,6 +93,12 @@ func (t *Array) Divide(other *Array) *Array {
|
|||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Cumsum(axis int, reverse, inclusive bool) *Array {
|
||||
out := New("CUMSUM")
|
||||
C.mlx_cumsum(&out.ctx, t.ctx, C.int(axis), C.bool(reverse), C.bool(inclusive), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) ExpandDims(axis int) *Array {
|
||||
out := New("EXPAND_DIMS")
|
||||
C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
|
||||
|
|
@ -123,12 +129,30 @@ func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
|
|||
return out
|
||||
}
|
||||
|
||||
func (t *Array) GreaterEqual(other *Array) *Array {
|
||||
out := New("GREATER_EQUAL")
|
||||
C.mlx_greater_equal(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Logsumexp(keepDims bool) *Array {
|
||||
out := New("LOGSUMEXP")
|
||||
C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Less(other *Array) *Array {
|
||||
out := New("LESS")
|
||||
C.mlx_less(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) LogicalOr(other *Array) *Array {
|
||||
out := New("LOGICAL_OR")
|
||||
C.mlx_logical_or(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Matmul(other *Array) *Array {
|
||||
out := New("MATMUL")
|
||||
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
|
|
|
|||
|
|
@ -82,6 +82,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||
defer session.close()
|
||||
caches := session.caches
|
||||
tokens := session.remaining
|
||||
history := append([]int32(nil), session.inputs...)
|
||||
prefillChunk := prefillChunkSize()
|
||||
|
||||
materializeCaches := func() {
|
||||
|
|
@ -114,13 +115,13 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||
mlx.ClearCache()
|
||||
}
|
||||
|
||||
step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
step := func(token *mlx.Array, history []int32) (*mlx.Array, *mlx.Array) {
|
||||
fwd := r.Model.Forward(token.ExpandDims(0), caches)
|
||||
logits := r.Model.Unembed(fwd)
|
||||
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
|
||||
|
||||
logprobs := logits.Subtract(logits.Logsumexp(true))
|
||||
sample := request.Sample(logprobs)
|
||||
sample := request.Sample(logprobs, history)
|
||||
|
||||
mlx.Pin(sample, logprobs)
|
||||
mlx.Sweep()
|
||||
|
|
@ -129,7 +130,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||
return sample, logprobs
|
||||
}
|
||||
|
||||
sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed))
|
||||
sample, logprobs = step(mlx.FromValues(tokens[processed:], total-processed), history)
|
||||
|
||||
var b bytes.Buffer
|
||||
|
||||
|
|
@ -139,8 +140,6 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||
return err
|
||||
}
|
||||
|
||||
nextSample, nextLogprobs = step(sample)
|
||||
|
||||
if i == 0 {
|
||||
mlx.Eval(sample)
|
||||
final.PromptEvalDuration = time.Since(now)
|
||||
|
|
@ -149,6 +148,7 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||
|
||||
output := int32(sample.Int())
|
||||
session.outputs = append(session.outputs, output)
|
||||
history = append(history, output)
|
||||
|
||||
if r.Tokenizer.IsEOS(output) {
|
||||
final.DoneReason = 0
|
||||
|
|
@ -164,6 +164,8 @@ func (r *Runner) TextGenerationPipeline(request Request) error {
|
|||
}:
|
||||
}
|
||||
|
||||
nextSample, nextLogprobs = step(sample, history)
|
||||
|
||||
mlx.Unpin(sample, logprobs)
|
||||
sample, logprobs = nextSample, nextLogprobs
|
||||
nextSample, nextLogprobs = nil, nil
|
||||
|
|
|
|||
|
|
@ -34,11 +34,15 @@ type TextCompletionsRequest struct {
|
|||
Prompt string `json:"prompt"`
|
||||
Think *bool `json:"think,omitempty"`
|
||||
Options struct {
|
||||
Temperature float32 `json:"temperature"`
|
||||
TopP float32 `json:"top_p"`
|
||||
MinP float32 `json:"min_p"`
|
||||
TopK int `json:"top_k"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Temperature *float32 `json:"temperature"`
|
||||
TopP *float32 `json:"top_p"`
|
||||
MinP *float32 `json:"min_p"`
|
||||
TopK *int `json:"top_k"`
|
||||
RepeatLastN *int `json:"repeat_last_n"`
|
||||
RepeatPenalty *float32 `json:"repeat_penalty"`
|
||||
PresencePenalty *float32 `json:"presence_penalty"`
|
||||
FrequencyPenalty *float32 `json:"frequency_penalty"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
|
||||
// Deprecated: use MaxTokens instead
|
||||
NumPredict int `json:"num_predict"`
|
||||
|
|
|
|||
|
|
@ -9,69 +9,204 @@ import (
|
|||
)
|
||||
|
||||
type Sampler interface {
|
||||
Sample(*mlx.Array) *mlx.Array
|
||||
Sample(*mlx.Array, []int32) *mlx.Array
|
||||
}
|
||||
|
||||
func New(temp, top_p, min_p float32, top_k int) Sampler {
|
||||
if temp == 0 {
|
||||
return greedy{}
|
||||
}
|
||||
|
||||
func New(temp, top_p, min_p float32, top_k, repeatLastN int, repeatPenalty, presencePenalty, frequencyPenalty float32) Sampler {
|
||||
var samplers []Sampler
|
||||
if top_p > 0 && top_p < 1 {
|
||||
samplers = append(samplers, TopP(top_p))
|
||||
if repeatLastN > 0 && (repeatPenalty != 1 || presencePenalty != 0 || frequencyPenalty != 0) {
|
||||
samplers = append(samplers, Penalty{
|
||||
RepeatLastN: repeatLastN,
|
||||
RepeatPenalty: repeatPenalty,
|
||||
PresencePenalty: presencePenalty,
|
||||
FrequencyPenalty: frequencyPenalty,
|
||||
})
|
||||
}
|
||||
|
||||
if min_p != 0 {
|
||||
samplers = append(samplers, MinP(min_p))
|
||||
if temp == 0 {
|
||||
samplers = append(samplers, greedy{})
|
||||
} else {
|
||||
samplers = append(samplers, Distribution{
|
||||
Temperature: temp,
|
||||
TopK: top_k,
|
||||
TopP: top_p,
|
||||
MinP: min_p,
|
||||
})
|
||||
}
|
||||
|
||||
if top_k > 0 {
|
||||
samplers = append(samplers, TopK(top_k))
|
||||
}
|
||||
|
||||
samplers = append(samplers, Temperature(temp))
|
||||
return chain(samplers)
|
||||
}
|
||||
|
||||
type greedy struct{}
|
||||
|
||||
func (greedy) Sample(logits *mlx.Array) *mlx.Array {
|
||||
func (greedy) Sample(logits *mlx.Array, _ []int32) *mlx.Array {
|
||||
return logits.Argmax(-1, false)
|
||||
}
|
||||
|
||||
type chain []Sampler
|
||||
|
||||
func (c chain) Sample(logits *mlx.Array) *mlx.Array {
|
||||
func (c chain) Sample(logits *mlx.Array, history []int32) *mlx.Array {
|
||||
for _, sampler := range c {
|
||||
logits = sampler.Sample(logits)
|
||||
logits = sampler.Sample(logits, history)
|
||||
}
|
||||
return logits
|
||||
}
|
||||
|
||||
type Temperature float32
|
||||
|
||||
func (t Temperature) Sample(logits *mlx.Array) *mlx.Array {
|
||||
return mlx.DivScalar(logits, float32(t)).Categorical(-1)
|
||||
type Distribution struct {
|
||||
Temperature float32
|
||||
TopK int
|
||||
TopP float32
|
||||
MinP float32
|
||||
}
|
||||
|
||||
type TopP float32
|
||||
func (d Distribution) Sample(logits *mlx.Array, _ []int32) *mlx.Array {
|
||||
filtered, indices := d.filter(logits)
|
||||
sample := filtered.Categorical(-1)
|
||||
if indices == nil {
|
||||
return sample
|
||||
}
|
||||
|
||||
func (p TopP) Sample(logprobs *mlx.Array) *mlx.Array {
|
||||
// TODO: implement
|
||||
return logprobs
|
||||
positions := sample.ExpandDims(1)
|
||||
return indices.TakeAlongAxis(positions, -1).Squeeze(1)
|
||||
}
|
||||
|
||||
type MinP float32
|
||||
func (d Distribution) filter(logits *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
candidates := logits
|
||||
var candidateIndices *mlx.Array
|
||||
|
||||
func (p MinP) Sample(logprobs *mlx.Array) *mlx.Array {
|
||||
// TODO: implement
|
||||
return logprobs
|
||||
if d.TopK > 0 && d.TopK < logits.Dim(logits.NumDims()-1) {
|
||||
partitions := logits.Negative().ArgpartitionAxis(d.TopK-1, -1)
|
||||
switch logits.NumDims() {
|
||||
case 1:
|
||||
candidateIndices = partitions.Slice(mlx.Slice(0, d.TopK))
|
||||
default:
|
||||
candidateIndices = partitions.Slice(mlx.Slice(), mlx.Slice(0, d.TopK))
|
||||
}
|
||||
candidates = logits.TakeAlongAxis(candidateIndices, -1)
|
||||
}
|
||||
|
||||
if d.Temperature != 1 {
|
||||
candidates = mlx.DivScalar(candidates, d.Temperature)
|
||||
}
|
||||
|
||||
if !d.needsProbabilityFilters() {
|
||||
return candidates, candidateIndices
|
||||
}
|
||||
|
||||
order := candidates.Negative().ArgsortAxis(-1)
|
||||
sortedLogits := candidates.TakeAlongAxis(order, -1)
|
||||
sortedProbs := mlx.SoftmaxAxis(candidates, -1, true).TakeAlongAxis(order, -1)
|
||||
|
||||
remove := d.topPRemovalMask(sortedProbs)
|
||||
if d.MinP > 0 {
|
||||
minPRemove := d.minPRemovalMask(sortedProbs)
|
||||
if remove == nil {
|
||||
remove = minPRemove
|
||||
} else {
|
||||
remove = remove.LogicalOr(minPRemove)
|
||||
}
|
||||
}
|
||||
|
||||
if remove == nil {
|
||||
return candidates, candidateIndices
|
||||
}
|
||||
|
||||
negInf := mlx.FromValue(float32(math.Inf(-1)))
|
||||
filtered := mlx.Where(remove, negInf, sortedLogits)
|
||||
return candidates.PutAlongAxis(order, filtered, -1), candidateIndices
|
||||
}
|
||||
|
||||
type TopK int
|
||||
|
||||
func (k TopK) Sample(logprobs *mlx.Array) *mlx.Array {
|
||||
mask := logprobs.Negative().ArgpartitionAxis(int(k)-1, -1).Slice(mlx.Slice(), mlx.Slice(int(k), 0))
|
||||
return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
|
||||
func (d Distribution) needsProbabilityFilters() bool {
|
||||
return (d.TopP > 0 && d.TopP < 1) || d.MinP > 0
|
||||
}
|
||||
|
||||
func (d Distribution) topPRemovalMask(sortedProbs *mlx.Array) *mlx.Array {
|
||||
if d.TopP <= 0 || d.TopP >= 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
threshold := mlx.NewScalarArray(d.TopP)
|
||||
prevCum := sortedProbs.Cumsum(-1, false, true).Subtract(sortedProbs)
|
||||
return prevCum.GreaterEqual(threshold)
|
||||
}
|
||||
|
||||
func (d Distribution) minPRemovalMask(sortedProbs *mlx.Array) *mlx.Array {
|
||||
if d.MinP <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var maxProb *mlx.Array
|
||||
switch sortedProbs.NumDims() {
|
||||
case 1:
|
||||
maxProb = sortedProbs.Slice(mlx.Slice(0, 1))
|
||||
default:
|
||||
maxProb = sortedProbs.Slice(mlx.Slice(), mlx.Slice(0, 1))
|
||||
}
|
||||
|
||||
threshold := mlx.MulScalar(maxProb, d.MinP)
|
||||
return sortedProbs.Less(threshold)
|
||||
}
|
||||
|
||||
type Penalty struct {
|
||||
RepeatLastN int
|
||||
RepeatPenalty float32
|
||||
PresencePenalty float32
|
||||
FrequencyPenalty float32
|
||||
}
|
||||
|
||||
func (p Penalty) Sample(logprobs *mlx.Array, history []int32) *mlx.Array {
|
||||
if len(history) == 0 {
|
||||
return logprobs
|
||||
}
|
||||
|
||||
window := p.RepeatLastN
|
||||
if window <= 0 || window > len(history) {
|
||||
window = len(history)
|
||||
}
|
||||
|
||||
counts := make(map[int32]int, window)
|
||||
order := make([]int32, 0, window)
|
||||
for _, token := range history[len(history)-window:] {
|
||||
if token < 0 {
|
||||
continue
|
||||
}
|
||||
if counts[token] == 0 {
|
||||
order = append(order, token)
|
||||
}
|
||||
counts[token]++
|
||||
}
|
||||
if len(order) == 0 {
|
||||
return logprobs
|
||||
}
|
||||
|
||||
indexShape := []int32{int32(len(order))}
|
||||
valueShape := []int{len(order)}
|
||||
if logprobs.NumDims() > 1 {
|
||||
indexShape = []int32{1, int32(len(order))}
|
||||
valueShape = []int{1, len(order)}
|
||||
}
|
||||
|
||||
indices := mlx.NewArrayInt32(order, indexShape)
|
||||
selected := logprobs.TakeAlongAxis(indices, -1)
|
||||
mlx.Eval(selected)
|
||||
|
||||
values := selected.Floats()
|
||||
for i, token := range order {
|
||||
v := values[i]
|
||||
if p.RepeatPenalty != 1 {
|
||||
if v < 0 {
|
||||
v *= p.RepeatPenalty
|
||||
} else {
|
||||
v /= p.RepeatPenalty
|
||||
}
|
||||
}
|
||||
if p.PresencePenalty != 0 {
|
||||
v -= p.PresencePenalty
|
||||
}
|
||||
if p.FrequencyPenalty != 0 {
|
||||
v -= p.FrequencyPenalty * float32(counts[token])
|
||||
}
|
||||
values[i] = v
|
||||
}
|
||||
|
||||
return logprobs.PutAlongAxis(indices, mlx.FromValues(values, valueShape...), -1)
|
||||
}
|
||||
|
|
|
|||
104
x/mlxrunner/sample/sample_test.go
Normal file
104
x/mlxrunner/sample/sample_test.go
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
//go:build mlx
|
||||
|
||||
package sample
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
func TestPenaltySample(t *testing.T) {
|
||||
if err := mlx.CheckInit(); err != nil {
|
||||
t.Skipf("MLX not available: %v", err)
|
||||
}
|
||||
|
||||
logprobs := mlx.FromValues([]float32{
|
||||
1.0, -2.0, 3.0, 4.0,
|
||||
}, 1, 4)
|
||||
|
||||
got := Penalty{
|
||||
RepeatLastN: 3,
|
||||
RepeatPenalty: 2.0,
|
||||
PresencePenalty: 1.5,
|
||||
FrequencyPenalty: 0.25,
|
||||
}.Sample(logprobs, []int32{2, 1, 2})
|
||||
|
||||
mlx.Eval(got)
|
||||
|
||||
want := []float32{1.0, -5.75, -0.5, 4.0}
|
||||
values := got.Floats()
|
||||
if len(values) != len(want) {
|
||||
t.Fatalf("len(values) = %d, want %d", len(values), len(want))
|
||||
}
|
||||
|
||||
for i := range want {
|
||||
if math.Abs(float64(values[i]-want[i])) > 1e-5 {
|
||||
t.Fatalf("values[%d] = %v, want %v", i, values[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPenaltySampleHonorsRepeatWindow(t *testing.T) {
|
||||
if err := mlx.CheckInit(); err != nil {
|
||||
t.Skipf("MLX not available: %v", err)
|
||||
}
|
||||
|
||||
logprobs := mlx.FromValues([]float32{
|
||||
1.0, 2.0, 3.0,
|
||||
}, 1, 3)
|
||||
|
||||
got := Penalty{
|
||||
RepeatLastN: 1,
|
||||
PresencePenalty: 1.0,
|
||||
}.Sample(logprobs, []int32{0, 1})
|
||||
|
||||
mlx.Eval(got)
|
||||
|
||||
want := []float32{1.0, 1.0, 3.0}
|
||||
values := got.Floats()
|
||||
for i := range want {
|
||||
if math.Abs(float64(values[i]-want[i])) > 1e-5 {
|
||||
t.Fatalf("values[%d] = %v, want %v", i, values[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDistributionFilterTopP(t *testing.T) {
|
||||
if err := mlx.CheckInit(); err != nil {
|
||||
t.Skipf("MLX not available: %v", err)
|
||||
}
|
||||
|
||||
logits := mlx.FromValues([]float32{
|
||||
10.0, 9.0, 1.0, 0.0,
|
||||
}, 1, 4)
|
||||
|
||||
filtered, indices := Distribution{
|
||||
Temperature: 1.0,
|
||||
TopK: 2,
|
||||
TopP: 0.55,
|
||||
}.filter(logits)
|
||||
|
||||
got := materializeFilteredLogits(filtered, indices, 4)
|
||||
mlx.Eval(got)
|
||||
|
||||
values := got.Floats()
|
||||
if values[0] != 10.0 {
|
||||
t.Fatalf("values[0] = %v, want 10", values[0])
|
||||
}
|
||||
for i := 1; i < len(values); i++ {
|
||||
if !math.IsInf(float64(values[i]), -1) {
|
||||
t.Fatalf("values[%d] = %v, want -Inf", i, values[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func materializeFilteredLogits(filtered, indices *mlx.Array, width int) *mlx.Array {
|
||||
if indices == nil {
|
||||
return filtered
|
||||
}
|
||||
|
||||
base := mlx.AddScalar(mlx.Zeros(mlx.DTypeFloat32, 1, width), float32(math.Inf(-1)))
|
||||
return base.PutAlongAxis(indices, filtered, -1)
|
||||
}
|
||||
|
|
@ -16,12 +16,89 @@ import (
|
|||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/mlxrunner/sample"
|
||||
"github.com/ollama/ollama/x/models/qwen3_5"
|
||||
)
|
||||
|
||||
type samplingConfig struct {
|
||||
temperature float32
|
||||
topP float32
|
||||
minP float32
|
||||
topK int
|
||||
repeatLastN int
|
||||
repeatPenalty float32
|
||||
presencePenalty float32
|
||||
frequencyPenalty float32
|
||||
}
|
||||
|
||||
func defaultSamplingConfig(m base.Model, think *bool) samplingConfig {
|
||||
if _, ok := m.(*qwen3_5.Model); ok {
|
||||
cfg := samplingConfig{
|
||||
temperature: 1.0,
|
||||
topP: 0.95,
|
||||
minP: 0.0,
|
||||
topK: 20,
|
||||
repeatLastN: 64,
|
||||
repeatPenalty: 1.0,
|
||||
presencePenalty: 1.5,
|
||||
frequencyPenalty: 0.0,
|
||||
}
|
||||
if think != nil && !*think {
|
||||
cfg.temperature = 0.7
|
||||
cfg.topP = 0.8
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
opts := api.DefaultOptions()
|
||||
return samplingConfig{
|
||||
temperature: opts.Temperature,
|
||||
topP: opts.TopP,
|
||||
minP: opts.MinP,
|
||||
topK: opts.TopK,
|
||||
repeatLastN: opts.RepeatLastN,
|
||||
repeatPenalty: opts.RepeatPenalty,
|
||||
presencePenalty: opts.PresencePenalty,
|
||||
frequencyPenalty: opts.FrequencyPenalty,
|
||||
}
|
||||
}
|
||||
|
||||
func resolveSamplingConfig(m base.Model, req Request) samplingConfig {
|
||||
cfg := defaultSamplingConfig(m, req.Think)
|
||||
|
||||
if req.Options.Temperature != nil {
|
||||
cfg.temperature = *req.Options.Temperature
|
||||
}
|
||||
if req.Options.TopP != nil {
|
||||
cfg.topP = *req.Options.TopP
|
||||
}
|
||||
if req.Options.MinP != nil {
|
||||
cfg.minP = *req.Options.MinP
|
||||
}
|
||||
if req.Options.TopK != nil {
|
||||
cfg.topK = *req.Options.TopK
|
||||
}
|
||||
if req.Options.RepeatLastN != nil {
|
||||
cfg.repeatLastN = *req.Options.RepeatLastN
|
||||
}
|
||||
if req.Options.RepeatPenalty != nil {
|
||||
cfg.repeatPenalty = *req.Options.RepeatPenalty
|
||||
}
|
||||
if req.Options.PresencePenalty != nil {
|
||||
cfg.presencePenalty = *req.Options.PresencePenalty
|
||||
}
|
||||
if req.Options.FrequencyPenalty != nil {
|
||||
cfg.frequencyPenalty = *req.Options.FrequencyPenalty
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
func Execute(args []string) error {
|
||||
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
|
||||
|
||||
|
|
@ -90,12 +167,18 @@ func Execute(args []string) error {
|
|||
|
||||
request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
|
||||
|
||||
sampling := resolveSamplingConfig(runner.Model, request)
|
||||
|
||||
request.Pipeline = runner.TextGenerationPipeline
|
||||
request.Sampler = sample.New(
|
||||
request.Options.Temperature,
|
||||
request.Options.TopP,
|
||||
request.Options.MinP,
|
||||
request.Options.TopK,
|
||||
sampling.temperature,
|
||||
sampling.topP,
|
||||
sampling.minP,
|
||||
sampling.topK,
|
||||
sampling.repeatLastN,
|
||||
sampling.repeatPenalty,
|
||||
sampling.presencePenalty,
|
||||
sampling.frequencyPenalty,
|
||||
)
|
||||
|
||||
var cancel context.CancelFunc
|
||||
|
|
|
|||
172
x/mlxrunner/server_test.go
Normal file
172
x/mlxrunner/server_test.go
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
//go:build mlx
|
||||
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/models/qwen3_5"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
type stubModel struct{}
|
||||
|
||||
func (stubModel) Forward(*mlx.Array, []cache.Cache) *mlx.Array { return nil }
|
||||
func (stubModel) Unembed(*mlx.Array) *mlx.Array { return nil }
|
||||
func (stubModel) NumLayers() int { return 0 }
|
||||
func (stubModel) Tokenizer() *tokenizer.Tokenizer { return nil }
|
||||
func (stubModel) LoadWeights(map[string]*mlx.Array) error { return nil }
|
||||
|
||||
func TestResolveSamplingConfigDefaults(t *testing.T) {
|
||||
trueValue := true
|
||||
falseValue := false
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model base.Model
|
||||
req Request
|
||||
want samplingConfig
|
||||
}{
|
||||
{
|
||||
name: "generic model uses api defaults",
|
||||
model: stubModel{},
|
||||
req: Request{},
|
||||
want: samplingConfig{
|
||||
temperature: 0.8,
|
||||
topP: 0.9,
|
||||
minP: 0.0,
|
||||
topK: 40,
|
||||
repeatLastN: 64,
|
||||
repeatPenalty: 1.1,
|
||||
presencePenalty: 0.0,
|
||||
frequencyPenalty: 0.0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "qwen3.5 defaults to thinking profile when think unset",
|
||||
model: &qwen3_5.Model{},
|
||||
req: Request{},
|
||||
want: samplingConfig{
|
||||
temperature: 1.0,
|
||||
topP: 0.95,
|
||||
minP: 0.0,
|
||||
topK: 20,
|
||||
repeatLastN: 64,
|
||||
repeatPenalty: 1.0,
|
||||
presencePenalty: 1.5,
|
||||
frequencyPenalty: 0.0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "qwen3.5 thinking disabled defaults",
|
||||
model: &qwen3_5.Model{},
|
||||
req: Request{TextCompletionsRequest: TextCompletionsRequest{Think: &falseValue}},
|
||||
want: samplingConfig{
|
||||
temperature: 0.7,
|
||||
topP: 0.8,
|
||||
minP: 0.0,
|
||||
topK: 20,
|
||||
repeatLastN: 64,
|
||||
repeatPenalty: 1.0,
|
||||
presencePenalty: 1.5,
|
||||
frequencyPenalty: 0.0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "qwen3.5 thinking enabled defaults",
|
||||
model: &qwen3_5.Model{},
|
||||
req: Request{TextCompletionsRequest: TextCompletionsRequest{Think: &trueValue}},
|
||||
want: samplingConfig{
|
||||
temperature: 1.0,
|
||||
topP: 0.95,
|
||||
minP: 0.0,
|
||||
topK: 20,
|
||||
repeatLastN: 64,
|
||||
repeatPenalty: 1.0,
|
||||
presencePenalty: 1.5,
|
||||
frequencyPenalty: 0.0,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := resolveSamplingConfig(tt.model, tt.req); got != tt.want {
|
||||
t.Fatalf("resolveSamplingConfig() = %+v, want %+v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSamplingConfigOverridesSpecifiedValues(t *testing.T) {
|
||||
trueValue := true
|
||||
temperature := float32(0.4)
|
||||
topP := float32(0.6)
|
||||
minP := float32(0.05)
|
||||
topK := 12
|
||||
repeatLastN := 32
|
||||
repeatPenalty := float32(1.1)
|
||||
presencePenalty := float32(0.7)
|
||||
frequencyPenalty := float32(0.2)
|
||||
|
||||
got := resolveSamplingConfig(stubModel{}, Request{
|
||||
TextCompletionsRequest: TextCompletionsRequest{
|
||||
Think: &trueValue,
|
||||
Options: struct {
|
||||
Temperature *float32 `json:"temperature"`
|
||||
TopP *float32 `json:"top_p"`
|
||||
MinP *float32 `json:"min_p"`
|
||||
TopK *int `json:"top_k"`
|
||||
RepeatLastN *int `json:"repeat_last_n"`
|
||||
RepeatPenalty *float32 `json:"repeat_penalty"`
|
||||
PresencePenalty *float32 `json:"presence_penalty"`
|
||||
FrequencyPenalty *float32 `json:"frequency_penalty"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
NumPredict int `json:"num_predict"`
|
||||
}{
|
||||
Temperature: &temperature,
|
||||
TopP: &topP,
|
||||
MinP: &minP,
|
||||
TopK: &topK,
|
||||
RepeatLastN: &repeatLastN,
|
||||
RepeatPenalty: &repeatPenalty,
|
||||
PresencePenalty: &presencePenalty,
|
||||
FrequencyPenalty: &frequencyPenalty,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
want := samplingConfig{
|
||||
temperature: temperature,
|
||||
topP: topP,
|
||||
minP: minP,
|
||||
topK: topK,
|
||||
repeatLastN: repeatLastN,
|
||||
repeatPenalty: repeatPenalty,
|
||||
presencePenalty: presencePenalty,
|
||||
frequencyPenalty: frequencyPenalty,
|
||||
}
|
||||
if got != want {
|
||||
t.Fatalf("resolveSamplingConfig() = %+v, want %+v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSamplingConfigMatchesGenericDefaults(t *testing.T) {
|
||||
want := api.DefaultOptions()
|
||||
got := defaultSamplingConfig(stubModel{}, nil)
|
||||
|
||||
if got.temperature != want.Temperature ||
|
||||
got.topP != want.TopP ||
|
||||
got.minP != want.MinP ||
|
||||
got.topK != want.TopK ||
|
||||
got.repeatLastN != want.RepeatLastN ||
|
||||
got.repeatPenalty != want.RepeatPenalty ||
|
||||
got.presencePenalty != want.PresencePenalty ||
|
||||
got.frequencyPenalty != want.FrequencyPenalty {
|
||||
t.Fatalf("defaultSamplingConfig() = %+v, want api defaults %+v", got, want)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue