mlx: perf improvements (#14768)

* mlx: perf improvements

Fix nn.go to call mlx_fast_layer_norm instead of manually implementing the
normalization (mean, subtract, variance, rsqrt, multiply, add — 6 ops)

Fix llama.go and gemma3.go to remove the RepeatKV calls that tiled the K/V
tensors up to the Q head count, since scaled_dot_product_attention natively
handles GQA (it only requires n_q_heads % n_kv_heads == 0)

* review comments
This commit is contained in:
Daniel Hiltgen 2026-03-12 12:01:28 -07:00 committed by GitHub
parent 8f45236d09
commit 539741199e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 185 additions and 37 deletions

View file

@ -303,7 +303,7 @@ func BenchmarkLinearSmall(b *testing.B) {
mlx.Eval(x)
b.ResetTimer()
for i := 0; i < b.N; i++ {
for range b.N {
out := linear.Forward(x)
mlx.Eval(out)
}
@ -320,7 +320,7 @@ func BenchmarkLinearLarge(b *testing.B) {
mlx.Eval(x)
b.ResetTimer()
for i := 0; i < b.N; i++ {
for range b.N {
out := linear.Forward(x)
mlx.Eval(out)
}
@ -337,7 +337,7 @@ func BenchmarkRMSNorm(b *testing.B) {
mlx.Eval(x)
b.ResetTimer()
for i := 0; i < b.N; i++ {
for range b.N {
out := norm.Forward(x, 0)
mlx.Eval(out)
}
@ -356,7 +356,7 @@ func BenchmarkEmbedding(b *testing.B) {
mlx.Eval(indices)
b.ResetTimer()
for i := 0; i < b.N; i++ {
for range b.N {
out := emb.Forward(indices)
mlx.Eval(out)
}
@ -369,7 +369,7 @@ func BenchmarkRepeatKV(b *testing.B) {
mlx.Eval(x)
b.ResetTimer()
for i := 0; i < b.N; i++ {
for range b.N {
out := RepeatKV(x, 4)
mlx.Eval(out)
}

View file

@ -131,6 +131,22 @@ func init() {
if cwd, err := os.Getwd(); err == nil {
searchDirs = append(searchDirs, filepath.Join(cwd, "build", "lib", "ollama"))
// Walk up from cwd to find the repo root (containing go.mod) so that
// tests running from a package subdirectory can find the build output.
for dir := cwd; ; {
if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil {
if dir != cwd {
searchDirs = append(searchDirs, filepath.Join(dir, "build", "lib", "ollama"))
}
break
}
parent := filepath.Dir(dir)
if parent == dir {
break
}
dir = parent
}
}
// Also scan mlx_* subdirectories within each search dir

View file

@ -331,6 +331,19 @@ func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask b
return out
}
// LayerNormFn normalizes x with MLX's fused fast layer-norm kernel,
// applying the optional weight (scale) and bias (shift) arrays. A nil
// weight or bias is forwarded as a zero-valued C handle — presumably the
// kernel treats that as "no scale/shift"; confirm against the MLX C API.
func LayerNormFn(x, weight, bias *Array, eps float32) *Array {
	result := New("FAST_LAYERNORM")
	var wCtx, bCtx C.mlx_array
	if weight != nil {
		wCtx = weight.ctx
	}
	if bias != nil {
		bCtx = bias.ctx
	}
	C.mlx_fast_layer_norm(&result.ctx, x.ctx, wCtx, bCtx, C.float(eps), DefaultStream().ctx)
	return result
}
func RMSNormFn(x, weight *Array, eps float32) *Array {
out := New("FAST_RMSNORM")
var w C.mlx_array

View file

@ -502,12 +502,8 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
k, v = c.Update(k, v)
}
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
if repeatFactor > 1 {
k = nn.RepeatKV(k, repeatFactor)
v = nn.RepeatKV(v, repeatFactor)
}
// MLX SDPA supports grouped-query attention directly (Q heads can be a
// multiple of K/V heads), so avoid materializing repeated K/V tensors.
out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)

View file

@ -306,12 +306,8 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config
k, v = c.Update(k, v)
}
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
if repeatFactor > 1 {
k = nn.RepeatKV(k, repeatFactor)
v = nn.RepeatKV(v, repeatFactor)
}
// MLX SDPA supports grouped-query attention directly (Q heads can be a
// multiple of K/V heads), so avoid materializing repeated K/V tensors.
out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)

View file

@ -152,15 +152,7 @@ func (ln *LayerNorm) Forward(x *mlx.Array) *mlx.Array {
if eps == 0 {
eps = 1e-5
}
mean := mlx.Mean(x, -1, true)
centered := x.Subtract(mean)
variance := mlx.Mean(centered.Multiply(centered), -1, true)
normalized := centered.Multiply(mlx.RSqrt(mlx.AddScalar(variance, eps)))
out := normalized.Multiply(ln.Weight)
if ln.Bias != nil && ln.Bias.Valid() {
out = out.Add(ln.Bias)
}
return out
return mlx.LayerNormFn(x, ln.Weight, ln.Bias, eps)
}
// MultiLinearLayer is an interface for per-head linear layers.
@ -183,17 +175,6 @@ func (ml *MultiLinear) Forward(x *mlx.Array) *mlx.Array {
return x.Matmul(wT)
}
// RepeatKV repeats K/V tensors for grouped query attention.
func RepeatKV(x *mlx.Array, repeatFactor int32) *mlx.Array {
if repeatFactor == 1 {
return x
}
shape := x.Dims()
x = x.ExpandDims(2)
reps := []int32{1, 1, repeatFactor, 1, 1}
x = mlx.Tile(x, reps)
return mlx.Reshape(x, int32(shape[0]), int32(shape[1])*repeatFactor, int32(shape[2]), int32(shape[3]))
}
// ApplyCausalMask applies causal (lower triangular) mask to attention scores.
func ApplyCausalMask(scores *mlx.Array) *mlx.Array {

146
x/models/nn/nn_test.go Normal file
View file

@ -0,0 +1,146 @@
package nn
import (
"math"
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// skipIfNoMLX skips the calling test when the MLX runtime cannot be
// initialized on this machine (e.g. the native library is absent).
func skipIfNoMLX(t *testing.T) {
	t.Helper()
	err := mlx.CheckInit()
	if err != nil {
		t.Skipf("MLX not available: %v", err)
	}
}
// approxEqual reports whether a and b differ by strictly less than tol.
func approxEqual(a, b, tol float32) bool {
	diff := math.Abs(float64(a - b))
	return float32(diff) < tol
}
// TestLayerNormNoBias checks LayerNorm (no bias) against a hand-computed
// reference on a single row.
func TestLayerNormNoBias(t *testing.T) {
	skipIfNoMLX(t)
	// One row with four features.
	input := mlx.FromValues([]float32{1, 2, 3, 4}, 1, 4)
	gamma := mlx.FromValues([]float32{1, 1, 1, 1}, 4)
	mlx.Eval(input, gamma)
	layer := &LayerNorm{Weight: gamma, Eps: 1e-5}
	result := layer.Forward(input)
	mlx.Eval(result)
	got := result.Floats()
	if len(got) != 4 {
		t.Fatalf("expected 4 values, got %d", len(got))
	}
	// Reference LayerNorm: mean=2.5, var=1.25, std=sqrt(var+eps),
	// normalized = (x - mean) / std (weight is all ones).
	mean := float32(2.5)
	variance := float32(1.25)
	std := float32(math.Sqrt(float64(variance + 1e-5)))
	for i, x := range []float32{1, 2, 3, 4} {
		expected := (x - mean) / std
		if !approxEqual(got[i], expected, 1e-4) {
			t.Errorf("index %d: expected %.6f, got %.6f", i, expected, got[i])
		}
	}
}
// TestLayerNormWithBias checks LayerNorm with both a scale (weight) and a
// per-feature shift (bias) against a hand-computed reference.
func TestLayerNormWithBias(t *testing.T) {
	skipIfNoMLX(t)
	input := mlx.FromValues([]float32{1, 2, 3, 4}, 1, 4)
	gamma := mlx.FromValues([]float32{2, 2, 2, 2}, 4)
	beta := mlx.FromValues([]float32{10, 20, 30, 40}, 4)
	mlx.Eval(input, gamma, beta)
	layer := &LayerNorm{Weight: gamma, Bias: beta, Eps: 1e-5}
	result := layer.Forward(input)
	mlx.Eval(result)
	got := result.Floats()
	if len(got) != 4 {
		t.Fatalf("expected 4 values, got %d", len(got))
	}
	// Reference: normalize, scale by 2, then add the per-feature bias.
	mean := float32(2.5)
	variance := float32(1.25)
	std := float32(math.Sqrt(float64(variance + 1e-5)))
	shifts := []float32{10, 20, 30, 40}
	for i, x := range []float32{1, 2, 3, 4} {
		expected := ((x-mean)/std)*2 + shifts[i]
		if !approxEqual(got[i], expected, 1e-4) {
			t.Errorf("index %d: expected %.6f, got %.6f", i, expected, got[i])
		}
	}
}
// TestLayerNormBatched checks that LayerNorm normalizes each row of a batch
// independently.
func TestLayerNormBatched(t *testing.T) {
	skipIfNoMLX(t)
	// Two rows of three features. Row 1 is row 0 scaled by 10, so after
	// independent per-row normalization the two rows must be identical:
	// row 0 has mean=2, var=2/3; row 1 has mean=20, var=200/3.
	input := mlx.FromValues([]float32{
		1, 2, 3,
		10, 20, 30,
	}, 2, 3)
	gamma := mlx.FromValues([]float32{1, 1, 1}, 3)
	mlx.Eval(input, gamma)
	layer := &LayerNorm{Weight: gamma, Eps: 1e-5}
	result := layer.Forward(input)
	mlx.Eval(result)
	got := result.Floats()
	if len(got) != 6 {
		t.Fatalf("expected 6 values, got %d", len(got))
	}
	for col := range 3 {
		if !approxEqual(got[col], got[col+3], 1e-4) {
			t.Errorf("row 0 elem %d (%.6f) != row 1 elem %d (%.6f); expected identical normalized values",
				col, got[col], col, got[col+3])
		}
	}
	// A normalized row is mean-centered, so its elements sum to ~0.
	rowSum := got[0] + got[1] + got[2]
	if !approxEqual(rowSum, 0, 1e-4) {
		t.Errorf("normalized row sum should be ~0, got %.6f", rowSum)
	}
}
// TestLayerNormDefaultEps checks that a zero Eps falls back to the default
// of 1e-5, producing the same output as an explicit 1e-5.
func TestLayerNormDefaultEps(t *testing.T) {
	skipIfNoMLX(t)
	input := mlx.FromValues([]float32{1, 2, 3, 4}, 1, 4)
	gamma := mlx.FromValues([]float32{1, 1, 1, 1}, 4)
	mlx.Eval(input, gamma)
	// Run once with Eps=0 (should default) and once with Eps=1e-5.
	defaulted := &LayerNorm{Weight: gamma, Eps: 0}
	defaultOut := defaulted.Forward(input)
	mlx.Eval(defaultOut)
	explicit := &LayerNorm{Weight: gamma, Eps: 1e-5}
	explicitOut := explicit.Forward(input)
	mlx.Eval(explicitOut)
	gotDefault := defaultOut.Floats()
	gotExplicit := explicitOut.Floats()
	for i := range gotDefault {
		if !approxEqual(gotDefault[i], gotExplicit[i], 1e-6) {
			t.Errorf("index %d: Eps=0 gave %.6f, Eps=1e-5 gave %.6f", i, gotDefault[i], gotExplicit[i])
		}
	}
}