From 05e0f21bec6b92c3071175625d6021843cf811da Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Tue, 14 Apr 2026 23:45:28 -0700 Subject: [PATCH] mlx: fuse sigmoid router head in glm4_moe_lite DeepSeek-V2-style aux-loss-free routing computes sigmoid(gates) once but needs it twice: the raw sigmoid output is gathered after top-k, while the post-bias negation is the argpartition key. Fuse into a single multi-output Compiled kernel returning both, saving two launches on the routing path per token. Exposed as a general SigmoidRouter since the same pattern is shared across DeepSeek-V2 descendants. Improves glm4.7 generation performance by approximately 1%. --- x/mlxrunner/mlx/act.go | 22 ++++++++++++++++++++++ x/models/glm4_moe_lite/glm4_moe_lite.go | 12 ++++++------ 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/x/mlxrunner/mlx/act.go b/x/mlxrunner/mlx/act.go index 1563de60a..3a67da279 100644 --- a/x/mlxrunner/mlx/act.go +++ b/x/mlxrunner/mlx/act.go @@ -62,3 +62,25 @@ var LogitSoftcap = Compile2( }, Shapeless(), ) + +// sigmoidRouterFused traces the DeepSeek-V2 / GLM-MoE aux-loss-free router +// head. Two outputs are returned so the pre-bias sigmoid (used to gather +// per-expert scores after top-k) and the post-bias negation (used as the +// argpartition key for top-k) share a single kernel. +var sigmoidRouterFused = Compile( + "SigmoidRouter", + func(in ...*Array) []*Array { + gates, bias := in[0], in[1] + orig := gates.Sigmoid() + neg := orig.Add(bias).Negative() + return []*Array{orig, neg} + }, + Shapeless(), +) + +// SigmoidRouter returns (sigmoid(gates), -(sigmoid(gates)+bias)) as a fused +// kernel — the DeepSeek-V2 / GLM-MoE aux-loss-free router head. +func SigmoidRouter(gates, bias *Array) (origScores, negScores *Array) { + out := sigmoidRouterFused(gates, bias) + return out[0], out[1] +} diff --git a/x/models/glm4_moe_lite/glm4_moe_lite.go b/x/models/glm4_moe_lite/glm4_moe_lite.go index 8b40c9348..aac320806 100644 --- a/x/models/glm4_moe_lite/glm4_moe_lite.go +++ b/x/models/glm4_moe_lite/glm4_moe_lite.go @@ -161,21 +161,21 @@ type MoEGate struct { func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) { gates := g.Gate.Forward(x) - scores := mlx.Sigmoid(gates) - origScores := scores - + var origScores, negScores *mlx.Array if g.EScoreCorrectionBias != nil { - scores = mlx.Add(scores, g.EScoreCorrectionBias) + origScores, negScores = mlx.SigmoidRouter(gates, g.EScoreCorrectionBias) + } else { + origScores = mlx.Sigmoid(gates) + negScores = mlx.Neg(origScores) } topK := cfg.NumExpertsPerTok - negScores := mlx.Neg(scores) inds := mlx.Argpartition(negScores, int(topK)-1, -1) dims := inds.Dims() inds = mlx.SliceStartStop(inds, []int32{0, 0, 0}, []int32{int32(dims[0]), int32(dims[1]), topK}) - scores = mlx.TakeAlongAxis(origScores, inds, -1) + scores := mlx.TakeAlongAxis(origScores, inds, -1) if topK > 1 && cfg.NormTopKProb { sumScores := mlx.Sum(scores, -1, true)