mirror of
https://github.com/ollama/ollama
synced 2026-04-23 08:45:14 +00:00
mlx: perf improvements (#14768)
* mlx: perf improvements. Fix nn.go to call mlx_fast_layer_norm instead of implementing layer normalization manually (mean, subtract, variance, rsqrt, multiply, add — 6 ops). Fix llama.go and gemma3.go to stop using RepeatKV to tile the K/V tensors to match the Q head count, since scaled_dot_product_attention natively handles GQA (it only requires n_q_heads % n_kv_heads == 0). * review comments
This commit is contained in:
parent
8f45236d09
commit
539741199e
|
|
@ -303,7 +303,7 @@ func BenchmarkLinearSmall(b *testing.B) {
|
|||
mlx.Eval(x)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for range b.N {
|
||||
out := linear.Forward(x)
|
||||
mlx.Eval(out)
|
||||
}
|
||||
|
|
@ -320,7 +320,7 @@ func BenchmarkLinearLarge(b *testing.B) {
|
|||
mlx.Eval(x)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for range b.N {
|
||||
out := linear.Forward(x)
|
||||
mlx.Eval(out)
|
||||
}
|
||||
|
|
@ -337,7 +337,7 @@ func BenchmarkRMSNorm(b *testing.B) {
|
|||
mlx.Eval(x)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for range b.N {
|
||||
out := norm.Forward(x, 0)
|
||||
mlx.Eval(out)
|
||||
}
|
||||
|
|
@ -356,7 +356,7 @@ func BenchmarkEmbedding(b *testing.B) {
|
|||
mlx.Eval(indices)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for range b.N {
|
||||
out := emb.Forward(indices)
|
||||
mlx.Eval(out)
|
||||
}
|
||||
|
|
@ -369,7 +369,7 @@ func BenchmarkRepeatKV(b *testing.B) {
|
|||
mlx.Eval(x)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for range b.N {
|
||||
out := RepeatKV(x, 4)
|
||||
mlx.Eval(out)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -131,6 +131,22 @@ func init() {
|
|||
|
||||
if cwd, err := os.Getwd(); err == nil {
|
||||
searchDirs = append(searchDirs, filepath.Join(cwd, "build", "lib", "ollama"))
|
||||
|
||||
// Walk up from cwd to find the repo root (containing go.mod) so that
|
||||
// tests running from a package subdirectory can find the build output.
|
||||
for dir := cwd; ; {
|
||||
if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil {
|
||||
if dir != cwd {
|
||||
searchDirs = append(searchDirs, filepath.Join(dir, "build", "lib", "ollama"))
|
||||
}
|
||||
break
|
||||
}
|
||||
parent := filepath.Dir(dir)
|
||||
if parent == dir {
|
||||
break
|
||||
}
|
||||
dir = parent
|
||||
}
|
||||
}
|
||||
|
||||
// Also scan mlx_* subdirectories within each search dir
|
||||
|
|
|
|||
|
|
@ -331,6 +331,19 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b
|
|||
return out
|
||||
}
|
||||
|
||||
func LayerNormFn(x, weight, bias *Array, eps float32) *Array {
|
||||
out := New("FAST_LAYERNORM")
|
||||
var w, b C.mlx_array
|
||||
if weight != nil {
|
||||
w = weight.ctx
|
||||
}
|
||||
if bias != nil {
|
||||
b = bias.ctx
|
||||
}
|
||||
C.mlx_fast_layer_norm(&out.ctx, x.ctx, w, b, C.float(eps), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func RMSNormFn(x, weight *Array, eps float32) *Array {
|
||||
out := New("FAST_RMSNORM")
|
||||
var w C.mlx_array
|
||||
|
|
|
|||
|
|
@ -502,12 +502,8 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
|
|||
k, v = c.Update(k, v)
|
||||
}
|
||||
|
||||
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
|
||||
if repeatFactor > 1 {
|
||||
k = nn.RepeatKV(k, repeatFactor)
|
||||
v = nn.RepeatKV(v, repeatFactor)
|
||||
}
|
||||
|
||||
// MLX SDPA supports grouped-query attention directly (Q heads can be a
|
||||
// multiple of K/V heads), so avoid materializing repeated K/V tensors.
|
||||
out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
return a.OProj.Forward(out)
|
||||
|
|
|
|||
|
|
@ -306,12 +306,8 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
|
|||
k, v = c.Update(k, v)
|
||||
}
|
||||
|
||||
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
|
||||
if repeatFactor > 1 {
|
||||
k = nn.RepeatKV(k, repeatFactor)
|
||||
v = nn.RepeatKV(v, repeatFactor)
|
||||
}
|
||||
|
||||
// MLX SDPA supports grouped-query attention directly (Q heads can be a
|
||||
// multiple of K/V heads), so avoid materializing repeated K/V tensors.
|
||||
out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
return a.OProj.Forward(out)
|
||||
|
|
|
|||
|
|
@ -152,15 +152,7 @@ func (ln *LayerNorm) Forward(x *mlx.Array) *mlx.Array {
|
|||
if eps == 0 {
|
||||
eps = 1e-5
|
||||
}
|
||||
mean := mlx.Mean(x, -1, true)
|
||||
centered := x.Subtract(mean)
|
||||
variance := mlx.Mean(centered.Multiply(centered), -1, true)
|
||||
normalized := centered.Multiply(mlx.RSqrt(mlx.AddScalar(variance, eps)))
|
||||
out := normalized.Multiply(ln.Weight)
|
||||
if ln.Bias != nil && ln.Bias.Valid() {
|
||||
out = out.Add(ln.Bias)
|
||||
}
|
||||
return out
|
||||
return mlx.LayerNormFn(x, ln.Weight, ln.Bias, eps)
|
||||
}
|
||||
|
||||
// MultiLinearLayer is an interface for per-head linear layers.
|
||||
|
|
@ -183,17 +175,6 @@ func (ml *MultiLinear) Forward(x *mlx.Array) *mlx.Array {
|
|||
return x.Matmul(wT)
|
||||
}
|
||||
|
||||
// RepeatKV repeats K/V tensors for grouped query attention.
|
||||
func RepeatKV(x *mlx.Array, repeatFactor int32) *mlx.Array {
|
||||
if repeatFactor == 1 {
|
||||
return x
|
||||
}
|
||||
shape := x.Dims()
|
||||
x = x.ExpandDims(2)
|
||||
reps := []int32{1, 1, repeatFactor, 1, 1}
|
||||
x = mlx.Tile(x, reps)
|
||||
return mlx.Reshape(x, int32(shape[0]), int32(shape[1])*repeatFactor, int32(shape[2]), int32(shape[3]))
|
||||
}
|
||||
|
||||
// ApplyCausalMask applies causal (lower triangular) mask to attention scores.
|
||||
func ApplyCausalMask(scores *mlx.Array) *mlx.Array {
|
||||
|
|
|
|||
146
x/models/nn/nn_test.go
Normal file
146
x/models/nn/nn_test.go
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
package nn
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
func skipIfNoMLX(t *testing.T) {
|
||||
t.Helper()
|
||||
if err := mlx.CheckInit(); err != nil {
|
||||
t.Skipf("MLX not available: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// approxEqual reports whether a and b are equal to within tol.
//
// Exactly equal values always compare equal, regardless of tol. This covers
// matching infinities (whose difference is NaN) and the tol == 0 case, both of
// which a bare |a-b| < tol check would reject. Otherwise the absolute
// difference must be strictly less than tol. NaN never compares equal to
// anything, including itself.
func approxEqual(a, b, tol float32) bool {
	if a == b {
		return true
	}
	return float32(math.Abs(float64(a-b))) < tol
}
|
||||
|
||||
// TestLayerNormNoBias verifies LayerNorm without bias against manual computation.
|
||||
func TestLayerNormNoBias(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
// Input: [1, 4] — single row, 4 features
|
||||
x := mlx.FromValues([]float32{1, 2, 3, 4}, 1, 4)
|
||||
weight := mlx.FromValues([]float32{1, 1, 1, 1}, 4)
|
||||
mlx.Eval(x, weight)
|
||||
|
||||
ln := &LayerNorm{Weight: weight, Eps: 1e-5}
|
||||
out := ln.Forward(x)
|
||||
mlx.Eval(out)
|
||||
|
||||
data := out.Floats()
|
||||
if len(data) != 4 {
|
||||
t.Fatalf("expected 4 values, got %d", len(data))
|
||||
}
|
||||
|
||||
// Manual LayerNorm: mean=2.5, var=1.25, std=sqrt(1.25+1e-5)
|
||||
// normalized = (x - mean) / std
|
||||
mean := float32(2.5)
|
||||
variance := float32(1.25)
|
||||
std := float32(math.Sqrt(float64(variance + 1e-5)))
|
||||
for i, v := range []float32{1, 2, 3, 4} {
|
||||
expected := (v - mean) / std
|
||||
if !approxEqual(data[i], expected, 1e-4) {
|
||||
t.Errorf("index %d: expected %.6f, got %.6f", i, expected, data[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestLayerNormWithBias verifies LayerNorm with weight and bias.
|
||||
func TestLayerNormWithBias(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
x := mlx.FromValues([]float32{1, 2, 3, 4}, 1, 4)
|
||||
weight := mlx.FromValues([]float32{2, 2, 2, 2}, 4)
|
||||
bias := mlx.FromValues([]float32{10, 20, 30, 40}, 4)
|
||||
mlx.Eval(x, weight, bias)
|
||||
|
||||
ln := &LayerNorm{Weight: weight, Bias: bias, Eps: 1e-5}
|
||||
out := ln.Forward(x)
|
||||
mlx.Eval(out)
|
||||
|
||||
data := out.Floats()
|
||||
if len(data) != 4 {
|
||||
t.Fatalf("expected 4 values, got %d", len(data))
|
||||
}
|
||||
|
||||
mean := float32(2.5)
|
||||
variance := float32(1.25)
|
||||
std := float32(math.Sqrt(float64(variance + 1e-5)))
|
||||
biases := []float32{10, 20, 30, 40}
|
||||
for i, v := range []float32{1, 2, 3, 4} {
|
||||
expected := ((v-mean)/std)*2 + biases[i]
|
||||
if !approxEqual(data[i], expected, 1e-4) {
|
||||
t.Errorf("index %d: expected %.6f, got %.6f", i, expected, data[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestLayerNormBatched verifies LayerNorm normalizes each row independently.
|
||||
func TestLayerNormBatched(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
// Input: [2, 3] — two rows
|
||||
x := mlx.FromValues([]float32{
|
||||
1, 2, 3,
|
||||
10, 20, 30,
|
||||
}, 2, 3)
|
||||
weight := mlx.FromValues([]float32{1, 1, 1}, 3)
|
||||
mlx.Eval(x, weight)
|
||||
|
||||
ln := &LayerNorm{Weight: weight, Eps: 1e-5}
|
||||
out := ln.Forward(x)
|
||||
mlx.Eval(out)
|
||||
|
||||
data := out.Floats()
|
||||
if len(data) != 6 {
|
||||
t.Fatalf("expected 6 values, got %d", len(data))
|
||||
}
|
||||
|
||||
// Each row should be independently normalized.
|
||||
// Row 0: [1,2,3] mean=2, var=2/3
|
||||
// Row 1: [10,20,30] mean=20, var=200/3
|
||||
// After normalization both rows should have the same pattern
|
||||
// since [10,20,30] = 10*[1,2,3], the normalized values are identical.
|
||||
for i := range 3 {
|
||||
if !approxEqual(data[i], data[i+3], 1e-4) {
|
||||
t.Errorf("row 0 elem %d (%.6f) != row 1 elem %d (%.6f); expected identical normalized values",
|
||||
i, data[i], i, data[i+3])
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the normalized values sum to ~0 (mean-centered)
|
||||
sum := data[0] + data[1] + data[2]
|
||||
if !approxEqual(sum, 0, 1e-4) {
|
||||
t.Errorf("normalized row sum should be ~0, got %.6f", sum)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLayerNormDefaultEps verifies the default epsilon of 1e-5 is used when Eps is 0.
|
||||
func TestLayerNormDefaultEps(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
x := mlx.FromValues([]float32{1, 2, 3, 4}, 1, 4)
|
||||
weight := mlx.FromValues([]float32{1, 1, 1, 1}, 4)
|
||||
mlx.Eval(x, weight)
|
||||
|
||||
// Eps=0 should use default 1e-5
|
||||
ln0 := &LayerNorm{Weight: weight, Eps: 0}
|
||||
out0 := ln0.Forward(x)
|
||||
mlx.Eval(out0)
|
||||
|
||||
lnExplicit := &LayerNorm{Weight: weight, Eps: 1e-5}
|
||||
outExplicit := lnExplicit.Forward(x)
|
||||
mlx.Eval(outExplicit)
|
||||
|
||||
d0 := out0.Floats()
|
||||
dE := outExplicit.Floats()
|
||||
for i := range d0 {
|
||||
if !approxEqual(d0[i], dE[i], 1e-6) {
|
||||
t.Errorf("index %d: Eps=0 gave %.6f, Eps=1e-5 gave %.6f", i, d0[i], dE[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue