This commit is contained in:
Michael Verrilli 2026-04-21 19:44:32 -04:00 committed by GitHub
commit a31703aa73
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
66 changed files with 17881 additions and 75 deletions

View file

@ -41,6 +41,7 @@ set(GGML_LLAMAFILE ON)
set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128)
set(GGML_CUDA_GRAPHS ON)
set(GGML_CUDA_FA ON)
set(GGML_CUDA_FA_ALL_QUANTS OFF)
set(GGML_CUDA_COMPRESSION_MODE default)
if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")

View file

@ -363,6 +363,16 @@ The currently available K/V cache quantization types are:
- `f16` - high precision and memory usage (default).
- `q8_0` - 8-bit quantization, uses approximately 1/2 the memory of `f16` with a very small loss in precision, this usually has no noticeable impact on the model's quality (recommended if not using f16).
- `q4_0` - 4-bit quantization, uses approximately 1/4 the memory of `f16` with a small-medium loss in precision that may be more noticeable at higher context sizes.
- `tq3k` - TurboQuant 3-bit K-only quantization, uses approximately 3/5 the memory of `f16` with a negligible quality impact, and unlike `q4_0`/`q8_0` does not require Flash Attention to be enabled.
- `tq2k` - TurboQuant 2-bit K-only quantization, uses approximately 3/5 the memory of `f16` with a small quality impact, and unlike `q4_0`/`q8_0` does not require Flash Attention to be enabled.
- `tq3` - TurboQuant 3-bit K and V quantization, uses approximately 1/5 the memory of `f16` with a negligible quality impact. Requires Flash Attention.
- `tq2` - TurboQuant 2-bit K and V quantization, uses approximately 1/7 the memory of `f16` with a small quality impact. Requires Flash Attention.
<Note>
The `tq*` (TurboQuant) cache types require an NVIDIA GPU with compute capability 6.0 or newer (Pascal or later) and run only through Ollama's native engine. TurboQuant applies to full-context attention layers; sliding-window layers (e.g. in Gemma) continue to use `f16` storage.
TurboQuant has been validated on Llama and Gemma 3 architectures. Models in the Qwen 2 family (including Qwen 2.5) are a known weak spot: their learned K bias gives them a per-channel asymmetric distribution that the paper-grounded rotation does not model well, so output quality under `tq3k`/`tq2k` is currently noticeably worse than `q8_0` or `q4_0`. For stable long-context use on Qwen 2, prefer `q8_0` or `q4_0`.
</Note>
How much the cache quantization impacts the model's response quality will depend on the model and the task. Models that have a high GQA count (e.g. Qwen2) may see a larger impact on precision from quantization than models with a low GQA count.

View file

@ -852,7 +852,7 @@ func (f GGML) SupportsKVCacheType(cacheType string) bool {
return true
}
return slices.Contains([]string{"q8_0", "q4_0"}, cacheType)
return slices.Contains([]string{"q8_0", "q4_0", "tq2", "tq3", "tq3k", "tq2k"}, cacheType)
}
// KVCacheTypeIsQuantized checks if the requested cache type is a quantized type
@ -863,6 +863,20 @@ func (f GGML) KVCacheTypeIsQuantized(cacheType string) bool {
return true
}
// KVCacheTypeRequiresFlashAttention reports whether a KV cache type can only be
// used when flash attention is enabled. Ggml-native quantized types (q8_0,
// q4_0) store K/V as quantized tensors that the non-FA softmax+matmul attention
// path cannot consume directly. TurboQuant K-only presets (tq2k, tq3k) dequant
// the packed K buffer to f16 before the attention op runs, so they work with
// either FA or the standard attention path.
func (f GGML) KVCacheTypeRequiresFlashAttention(cacheType string) bool {
switch cacheType {
case "tq2k", "tq3k":
return false
}
return f.KVCacheTypeIsQuantized(cacheType)
}
// SupportsFlashAttention checks if the model supports flash attention
func (f GGML) SupportsFlashAttention() bool {
_, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]
@ -913,6 +927,18 @@ func kvCacheBytesPerElement(cacheType string) float64 {
return 1 // 1/2 of fp16
case "q4_0":
return 0.5 // 1/4 of fp16
case "tq3":
// 3-bit TQ K (~0.41 B/elem) + 3-bit TQ V (~0.41 B/elem), averaged over K+V
return 0.41
case "tq3k":
// 3-bit TQ K (~0.41 B/elem) + f16 V (2 B/elem), averaged over K+V
return 1.205
case "tq2":
// 2-bit TQ K (~0.28 B/elem) + 2-bit TQ V (~0.28 B/elem), averaged over K+V
return 0.28
case "tq2k":
// 2-bit TQ K (~0.28 B/elem) + f16 V (2 B/elem), averaged over K+V
return 1.14
case "f32":
return 4 // f32 (default for recurrent)
default:

View file

@ -82,3 +82,12 @@ type Cache interface {
type CheckpointCache interface {
PrepareRestore(seq int, targetPos int32) (int32, bool)
}
// CausalConfigurable is implemented by caches that accept CausalOptions
// (the concrete *Causal type and any wrappers that delegate to it, e.g.,
// *TurboQuantCache). Model code that needs to set per-layer masking options
// should assert against this interface rather than the concrete *Causal so
// that TurboQuant-wrapped global sub-caches still receive SetCausal calls.
type CausalConfigurable interface {
SetCausal(ctx ml.Context, opts CausalOptions)
}

View file

@ -18,7 +18,21 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
// The tensors are of shape embed dim, kv heads, batch size
// The mask is of shape history size, batch size
type Causal struct {
DType ml.DType
// DTypeK and DTypeV are the storage dtypes for K and V respectively.
// Init takes a single dtype and assigns it to both unless a caller has
// explicitly set one or both before Init runs (e.g., TurboQuantCache
// which forces its inner Causal to f16 on both sides while routing
// compressed K through a separate manager via SkipK).
DTypeK ml.DType
DTypeV ml.DType
// SkipK suppresses K tensor allocation and writes. When true, an external
// manager (e.g., TurboQuantCache) handles K storage and returns K from Get.
SkipK bool
// SkipV suppresses V tensor allocation and writes. When true, an external
// manager (e.g., TurboQuantCache) handles V storage and returns V from Get.
SkipV bool
// swaWindowSize is the number of tokens that will be included in the mask
// during attention operations. swaMemorySize is the number of tokens that
@ -44,6 +58,10 @@ type Causal struct {
// locations for data storage for this batch
curLoc ml.Tensor
// curLocs mirrors curLoc as a raw int slice for TurboQuant's per-cell
// byte-map indexing. Only populated when SkipK or SkipV is set — non-TQ
// users never pay the allocation cost.
curLocs []int
// mask of the cache as used by this batch
curMask ml.Tensor
@ -175,7 +193,15 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
cacheSize = roundUp(cacheSize, c.config.CachePadding)
c.cells = make([]cacheCell, cacheSize)
c.DType = dtype
// Resolve effective K/V dtypes. If a caller (e.g. TurboQuantCache) hasn't
// already set DTypeK or DTypeV before Init runs, fall back to the single
// dtype parameter for both — the historical single-dtype behaviour.
if c.DTypeK == ml.DTypeOther {
c.DTypeK = dtype
}
if c.DTypeV == ml.DTypeOther {
c.DTypeV = dtype
}
c.cellRanges = make(map[int]cellRange)
c.backend = backend
c.maxBatch = maxBatch
@ -242,6 +268,17 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
}
c.curLoc = ctx.Input().FromInts(locs, len(locs))
// curLocs is only needed by TurboQuantCache's per-cell byte-map indexing.
// Skip the allocation and conversion loop on the non-TQ path so vanilla
// users don't pay per-batch overhead for an unused slice.
if c.SkipK || c.SkipV {
c.curLocs = make([]int, len(locs))
for i, l := range locs {
c.curLocs[i] = int(l)
}
} else {
c.curLocs = nil
}
c.curMask = c.buildMask(ctx)
return nil
@ -254,8 +291,39 @@ func newRange() cellRange {
}
}
// Returns a slice of locations where each token in the batch should be stored
// Returns a slice of locations where each token in the batch should be stored.
//
// When SkipK/SkipV is set (TurboQuant's compressed path), the backing TQ
// encode kernels write a contiguous run of cells starting at loc[0]; a
// fragmented allocation would desynchronize the compressed K/V buffers
// from the per-cell metadata. Require a contiguous empty run of the
// required length and surface ErrKvCacheFull otherwise — the runner treats
// that the same as an out-of-space condition, which lets a fragmented
// cache recover as other sequences complete rather than crashing the
// process on a gapped write.
func (c *Causal) findLocs() ([]int32, error) {
if c.SkipK || c.SkipV {
runStart, runLen := -1, 0
for i := range c.cells {
if len(c.cells[i].sequences) == 0 {
if runLen == 0 {
runStart = i
}
runLen++
if runLen >= c.curBatchSize {
loc := make([]int32, c.curBatchSize)
for j := range loc {
loc[j] = int32(runStart + j)
}
return loc, nil
}
} else {
runLen = 0
}
}
return nil, fmt.Errorf("%w (cache: %v batch: %v, no contiguous run)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
}
loc := make([]int32, 0, c.curBatchSize)
for i := range c.cells {
@ -409,40 +477,47 @@ func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
}
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
key := c.keys[c.curLayer]
value := c.values[c.curLayer]
kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
cachedSize := c.curMask.Dim(0)
key = key.View(ctx, rowSize*c.curCellRange.min,
kHeadDim, key.Stride(1),
numKVHeads, key.Stride(2),
cachedSize,
)
if c.config.PermutedV {
vHeadDim := value.Dim(1)
elemSize := value.Stride(0)
value = value.View(ctx, elemSize*c.curCellRange.min,
cachedSize, value.Stride(1),
vHeadDim, value.Stride(2),
numKVHeads,
)
} else {
vHeadDim := value.Dim(0)
rowSize := value.Stride(2)
value = value.View(ctx, rowSize*c.curCellRange.min,
vHeadDim, value.Stride(1),
numKVHeads, value.Stride(2),
var key ml.Tensor
if !c.SkipK {
k := c.keys[c.curLayer]
kHeadDim := k.Dim(0)
numKVHeads := k.Dim(1)
rowSize := k.Stride(2)
key = k.View(ctx, rowSize*c.curCellRange.min,
kHeadDim, k.Stride(1),
numKVHeads, k.Stride(2),
cachedSize,
)
}
var value ml.Tensor
if !c.SkipV {
v := c.values[c.curLayer]
if c.config.PermutedV {
vHeadDim := v.Dim(1)
vKVHeads := v.Dim(2)
elemSize := v.Stride(0)
value = v.View(ctx, elemSize*c.curCellRange.min,
cachedSize, v.Stride(1),
vHeadDim, v.Stride(2),
vKVHeads,
)
} else {
vHeadDim := v.Dim(0)
vKVHeads := v.Dim(1)
rowSize := v.Stride(2)
value = v.View(ctx, rowSize*c.curCellRange.min,
vHeadDim, v.Stride(1),
vKVHeads, v.Stride(2),
cachedSize,
)
}
}
return key, value, c.curMask
}
@ -460,37 +535,45 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
}
if _, ok := c.keys[c.curLayer]; !ok {
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, len(c.cells))
}
if _, ok := c.values[c.curLayer]; !ok {
if c.config.PermutedV {
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vHeadDim, numKVHeads)
} else {
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, len(c.cells))
if !c.SkipK {
if _, ok := c.keys[c.curLayer]; !ok {
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DTypeK, kHeadDim, numKVHeads, len(c.cells))
}
}
key = key.Reshape(ctx, kHeadDim*numKVHeads, batchSize)
keyCache := c.keys[c.curLayer]
keyCache = keyCache.Reshape(ctx, kHeadDim*numKVHeads, len(c.cells))
ctx.Forward(keyCache.SetRows(ctx, key, c.curLoc))
if !c.SkipV {
if _, ok := c.values[c.curLayer]; !ok {
if c.config.PermutedV {
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DTypeV, len(c.cells), vHeadDim, numKVHeads)
} else {
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DTypeV, vHeadDim, numKVHeads, len(c.cells))
}
}
}
if c.config.PermutedV {
value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
value = value.Permute(ctx, 2, 0, 1, 3)
if !c.SkipK {
key = key.Reshape(ctx, kHeadDim*numKVHeads, batchSize)
keyCache := c.keys[c.curLayer]
keyCache = keyCache.Reshape(ctx, kHeadDim*numKVHeads, len(c.cells))
ctx.Forward(keyCache.SetRows(ctx, key, c.curLoc))
}
valueCache := c.values[c.curLayer]
valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)
if !c.SkipV {
if c.config.PermutedV {
value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
value = value.Permute(ctx, 2, 0, 1, 3)
ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
} else {
value = value.Reshape(ctx, vHeadDim*numKVHeads, batchSize)
valueCache := c.values[c.curLayer]
valueCache = valueCache.Reshape(ctx, vHeadDim*numKVHeads, len(c.cells))
valueCache := c.values[c.curLayer]
valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)
ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
} else {
value = value.Reshape(ctx, vHeadDim*numKVHeads, batchSize)
valueCache := c.values[c.curLayer]
valueCache = valueCache.Reshape(ctx, vHeadDim*numKVHeads, len(c.cells))
ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
}
}
}

View file

@ -40,7 +40,7 @@ var (
// Conv state shape (per layer, per sequence): [convDim, convChannels]
// Recurrent state shape (per layer, per sequence): [recurrentStateSize]
type Recurrent struct {
kv *Causal
kv Cache
backend ml.Backend
dtype ml.DType
@ -95,6 +95,19 @@ type Recurrent struct {
writableError error
}
// AttentionKV returns the inner attention KV cache when it is still the
// original *Causal. Returns nil after the cache has been replaced (e.g. by
// *TurboQuantCache), making WrapWithTurboQuant idempotently skip a second wrap.
func (c *Recurrent) AttentionKV() *Causal {
causal, _ := c.kv.(*Causal)
return causal
}
// SetAttentionKV replaces the inner attention KV cache. Used by
// WrapWithTurboQuant to inject compression into the attention path
// without disturbing the embedded conv/recurrent state tensors.
func (c *Recurrent) SetAttentionKV(kv Cache) { c.kv = kv }
func NewRecurrentCache(config RecurrentConfig) *Recurrent {
return &Recurrent{
kv: NewCausalCache(config.Shift),

619
kvcache/turboquant.go Normal file
View file

@ -0,0 +1,619 @@
package kvcache
import (
"fmt"
"log/slog"
"math"
"sync"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/turboquant"
)
type TurboQuantCache struct {
meta *Causal
preset turboquant.Preset
isReserve bool
compressedK ml.TQCompressedKManager
// phase2Checked ensures GPU encode is activated at most once.
phase2Checked bool
headDim int
numKVHeads int
// encodeResults stores per-layer EncodeK result tensors for the current
// forward pass. DequantK uses them as src[0] to establish the graph
// dependency (encode before dequant in the ggml scheduler).
encodeResults map[int]ml.Tensor
// vEncodeResults stores per-layer EncodeV result tensors for the current
// forward pass. DequantV uses them to establish the encode→dequant ordering.
vEncodeResults map[int]ml.Tensor
// logPathOnce ensures each active Get() path is logged at most once per
// cache instance (avoids log spam: Get() is called every layer every step).
logPathOnce [5]sync.Once
// fusedFallbackEligible gates the inline-decode fused-FA fallback paths
// (Get paths 2 and 4). The CUDA fused kernel is template-instantiated only
// at D=128, so models with a larger head dim (gemma4 D=512) must skip it
// to avoid a kernel-side GGML_ASSERT. The Metal fused kernel has both
// D=128 and D=256 variants (kernel_tq_fattn_vec_*{,_d256}), so gemma3
// D=256 is eligible on Metal but not on CUDA until the CUDA kernel gains
// a D=256 instantiation. The DequantK + stock FA path (Get paths 0/1/5)
// works at any head dim — this gate is specific to the inline-decode
// variants.
fusedFallbackEligible bool
// preferFusedAttn is true on Metal. At long context, DequantKV + stock FA
// writes a full f16 intermediate buffer (doubling KV bandwidth) before
// attention reads it. The fused kernel (kernel_tq_fattn_vec_packed) reads
// packed K+V directly and skips the intermediate write — dramatically
// faster on Metal for *decode* (Q=1). On CUDA the DequantKV path is
// preferred because the intermediate buffer stays in L2 and stock flash
// attention is highly tuned.
//
// For prefill (Q>1, see curQueryLen) DequantKV wins on every backend
// because it decodes each packed cell once then runs stock FA with full
// batch amortisation, whereas the fused kernel re-decodes each cell per
// Q-token. Get() uses the combination of preferFusedAttn and curQueryLen
// to route prefill vs decode independently.
preferFusedAttn bool
// curQueryLen is the number of query tokens in the current forward pass,
// captured at StartForward from len(batch.Positions). curQueryLen==1
// means decode; curQueryLen>1 means prefill (or a batched prompt chunk).
curQueryLen int
// rotMatrix is the R^T rotation matrix sized for this cache's headDim.
// Set in activateGPUEncode. Get() sets the backend's tqRotationMatrix to
// this value per-call (consume-once) right before returning rotated K.
rotMatrix ml.Tensor
// vRotMatrix is the R (inverse) matrix used to undo V rotation after
// attention in K+V presets (tq3/tq2). Nil for K-only presets. Get() sets
// the backend's tqVRotationMatrix to this value per-call along with
// rotMatrix so SDPA applies R @ attn_out after flash attention.
vRotMatrix ml.Tensor
// rotSetter is the cached type assertion of c.meta.backend onto the TQ
// rotation setter interface, populated once in activateGPUEncode. nil
// when the backend doesn't support TQ rotation hooks (the fallback case).
rotSetter tqRotSetter
}
// tqRotSetter is the backend hook TurboQuantCache uses to arm the per-call
// rotation matrices SDPA consumes. Implemented by ml/backend/ggml.Backend.
type tqRotSetter interface {
SetTQRotationMatrix(ml.Tensor)
SetTQVRotationMatrix(ml.Tensor)
}
// isSWACausal reports whether a *Causal has sliding-window attention
// active. Plain Causal caches have swaWindowSize either 0 (before Init
// normalizes the default) or math.MaxInt32 (after); SWA constructors set
// it to the actual window size.
func isSWACausal(c *Causal) bool {
return c.swaWindowSize > 0 && c.swaWindowSize != math.MaxInt32
}
// AttentionKVWrapper is implemented by caches that embed *Recurrent and
// expose the attention half of a hybrid (SSM/recurrent + attention) cache.
// WrapWithTurboQuant uses it to inject TurboQuant compression into the
// attention KV path without disturbing conv/recurrent state buffers.
// *kvcache.Recurrent implements this interface, so any model HybridCache that
// embeds *Recurrent satisfies it automatically via Go method promotion.
type AttentionKVWrapper interface {
AttentionKV() *Causal
SetAttentionKV(Cache)
}
// WrapWithTurboQuant returns a cache that applies TurboQuant compression to
// global-attention Causal layers and a bool reporting whether any wrapping
// took effect. For a top-level *Causal (non-SWA), it returns a new
// *TurboQuantCache. For a *WrapperCache, it mutates the caches slice in
// place, replacing every non-SWA *Causal sub-cache with a *TurboQuantCache,
// and returns the same *WrapperCache pointer. This enables TQ on SWA models
// like gemma3/gemma4 where the global attention layers dominate KV memory
// at long context. Returns (cache, false) if no eligible sub-caches were
// found.
func WrapWithTurboQuant(cache Cache, preset turboquant.Preset) (Cache, bool) {
switch c := cache.(type) {
case *Causal:
// Reject SWA caches. Plain NewCausalCache leaves swaWindowSize=0
// until Init() normalizes it to math.MaxInt32; SWA constructors set
// it to the actual window size. "Plain causal" means the field is
// either 0 (uninitialized default) or math.MaxInt32 (post-Init).
if isSWACausal(c) {
slog.Warn("turboquant: top-level Causal is sliding-window, cannot wrap")
return cache, false
}
return &TurboQuantCache{
meta: c,
preset: preset,
encodeResults: make(map[int]ml.Tensor),
vEncodeResults: make(map[int]ml.Tensor),
}, true
case *WrapperCache:
// Mutate sub-caches in place: replace every non-SWA *Causal with a
// *TurboQuantCache wrapping it. SWA sub-caches (SWACache, SWAMemCache)
// are left untouched — they still allocate f16 K/V as before.
wrapped := 0
for i, sub := range c.caches {
inner, ok := sub.(*Causal)
if !ok || isSWACausal(inner) {
continue
}
c.caches[i] = &TurboQuantCache{
meta: inner,
preset: preset,
encodeResults: make(map[int]ml.Tensor),
vEncodeResults: make(map[int]ml.Tensor),
}
wrapped++
}
if wrapped == 0 {
slog.Warn("turboquant: no eligible Causal sub-caches in WrapperCache, falling back to unwrapped cache")
return cache, false
}
slog.Info("turboquant: wrapped Causal sub-caches inside WrapperCache",
"count", wrapped, "preset", preset.Name)
return cache, true
case AttentionKVWrapper:
inner := c.AttentionKV()
if inner == nil {
slog.Warn("turboquant: hybrid cache inner kv is not *Causal (already wrapped?), leaving as-is")
return cache, false
}
if isSWACausal(inner) {
slog.Warn("turboquant: hybrid cache inner *Causal is sliding-window, cannot wrap")
return cache, false
}
c.SetAttentionKV(&TurboQuantCache{
meta: inner,
preset: preset,
encodeResults: make(map[int]ml.Tensor),
vEncodeResults: make(map[int]ml.Tensor),
})
slog.Info("turboquant: wrapped attention KV in hybrid recurrent cache",
"preset", preset.Name)
return cache, true
default:
slog.Warn("turboquant: underlying cache is not *Causal or *WrapperCache, falling back to unwrapped cache")
return cache, false
}
}
func (c *TurboQuantCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
// K is always compressed; suppress inner Causal from allocating it.
c.meta.SkipK = true
// V is compressed only when ValueBits > 0 (tq2/tq3). K-only presets
// (tq2k/tq3k) have ValueBits=0 and keep V in the f16 Causal cache.
if c.preset.ValueBits > 0 {
c.meta.SkipV = true
}
c.meta.Init(backend, ml.DTypeF16, maxSequences, capacity, maxBatch)
slog.Info("turboquant cache initialized", "preset", c.preset.Name,
"K_bits", c.preset.KeyPrimaryBits, "V_bits", c.preset.ValueBits)
}
func (c *TurboQuantCache) Close() {
if c.compressedK != nil {
c.compressedK.Close()
c.compressedK = nil
}
c.meta.Close()
}
func (c *TurboQuantCache) SetLayer(layer int) { c.meta.SetLayer(layer) }
func (c *TurboQuantCache) SetConfig(config ml.CacheConfig) { c.meta.SetConfig(config) }
func (c *TurboQuantCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
c.isReserve = reserve
c.curQueryLen = len(batch.Positions)
clear(c.encodeResults)
clear(c.vEncodeResults)
return c.meta.StartForward(ctx, batch, reserve)
}
func (c *TurboQuantCache) Put(ctx ml.Context, key, value ml.Tensor) {
// Capture headDim early — needed even during reserve to size placeholders.
if c.headDim == 0 && key != nil {
c.headDim = key.Dim(0)
c.numKVHeads = key.Dim(1)
}
// Activate the GPU encode path on the first Put (reserve or not) once we
// know headDim/numKVHeads. Doing this during reserve lets EnsureLayer run
// on the reserve pass, which books TQ's per-layer persistent K/V buffers
// into btDeviceMemory.Cache[layer] (via newTensor's ctx.layer accounting
// in EnsureLayer). Without this, the scheduler's fit/alloc probe sees a
// 0-byte Cache for tq2/tq3 and only the f16 V half for tq2k/tq3k, which
// is wrong for both headline-footprint reporting and long-context fit math.
if !c.phase2Checked && c.headDim > 0 {
c.phase2Checked = true
c.activateGPUEncode()
}
if c.isReserve {
c.meta.Put(ctx, key, value)
// Eagerly allocate TQ persistent buffers during reserve so the scheduler's
// per-layer Cache totals reflect the real post-compression footprint. No
// encode kernels run on reserve — we only need the tensors in the layer
// context so newTensor's c.b.btDeviceMemory[...].Cache[layer] bump fires.
if c.compressedK != nil {
layer := c.meta.curLayer
capacity := len(c.meta.cells)
c.compressedK.EnsureLayer(layer, capacity)
if c.preset.ValueBits > 0 {
c.compressedK.EnsureVLayer(layer, capacity)
}
}
return
}
if c.compressedK != nil {
layer := c.meta.curLayer
capacity := len(c.meta.cells)
c.compressedK.EnsureLayer(layer, capacity)
if c.preset.ValueBits > 0 {
c.compressedK.EnsureVLayer(layer, capacity)
}
// The TQ encode kernels (CUDA + Metal) write a contiguous run of
// cells starting at firstCell. Causal.findLocs guarantees a
// contiguous run when SkipK/SkipV is set (otherwise it returns
// ErrKvCacheFull before we reach here); this loop is a defensive
// invariant check that should never fire in practice.
firstCell := 0
if len(c.meta.curLocs) > 0 {
firstCell = c.meta.curLocs[0]
for i := 1; i < len(c.meta.curLocs); i++ {
if c.meta.curLocs[i] != firstCell+i {
panic(fmt.Sprintf("turboquant: non-contiguous cache slots %v — findLocs invariant violated", c.meta.curLocs))
}
}
}
if c.preset.ValueBits > 0 {
// Combined K+V encode: single GGML op, two back-to-back kernels.
kResult, vResult := c.compressedK.EncodeKV(ctx, layer, key, value, firstCell)
if kResult != nil {
ctx.Forward(kResult)
c.encodeResults[layer] = kResult
c.vEncodeResults[layer] = vResult
}
} else {
// K-only presets (tq2k/tq3k): V stays as f16 in the Causal cache.
encodeResult := c.compressedK.EncodeK(ctx, layer, key, firstCell)
if encodeResult != nil {
ctx.Forward(encodeResult)
c.encodeResults[layer] = encodeResult
}
}
// Inner Causal.Put() tracks cell metadata (positions, masks).
// SkipK and SkipV suppress the actual K/V tensor writes.
c.meta.Put(ctx, key, value)
return
}
c.meta.Put(ctx, key, value)
}
func (c *TurboQuantCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
if c.isReserve {
key, value, mask := c.meta.Get(ctx)
// SkipK: synthesize zero f16 K placeholder for graph sizing.
if key == nil && c.headDim > 0 {
nCells := 1
if c.meta.curMask != nil {
nCells = c.meta.curMask.Dim(0)
}
key = ctx.Input().Zeros(ml.DTypeF16, c.headDim, c.numKVHeads, nCells)
}
// SkipV: synthesize zero f16 V placeholder for graph sizing.
if value == nil && c.headDim > 0 {
nCells := 1
if c.meta.curMask != nil {
nCells = c.meta.curMask.Dim(0)
}
value = ctx.Input().Zeros(ml.DTypeF16, c.headDim, c.numKVHeads, nCells)
}
return key, value, mask
}
if c.compressedK != nil {
layer := c.meta.curLayer
firstCell := c.meta.curCellRange.min
nCells := c.meta.curMask.Dim(0)
encodeResult := c.encodeResults[layer]
vEncodeResult := c.vEncodeResults[layer]
// 0. K-only presets (tq2k/tq3k): DequantK + f16 V from Causal → stock FA.
// Skips the fused FA kernel (slower than stock FA on Pascal).
if c.preset.ValueBits == 0 && encodeResult != nil {
key := c.compressedK.DequantK(ctx, layer, encodeResult, firstCell, nCells)
if key != nil {
c.logPathOnce[0].Do(func() {
slog.Info("turboquant: using K-only DequantK + f16 V path")
})
_, value, mask := c.meta.Get(ctx)
c.armRotationForNextSDPA()
return key, value, mask
}
}
// Prefill vs decode routing on Metal: the fused inline-decode kernel
// re-decodes each packed K/V cell per Q-token, so its cost scales
// O(nCells × nTokensQ). For prefill (nTokensQ ≫ 1) DequantKV + stock
// FA wins by decoding each cell once and then letting stock FA
// amortise the read across all Q-tokens. For decode (nTokensQ=1) the
// fused path wins by avoiding the f16 intermediate write. This flag
// routes independently of preferFusedAttn so CUDA (preferFusedAttn=
// false) continues to take path 1 for everything.
useDequantKVForPrefill := c.preferFusedAttn && c.curQueryLen > 1
// 1. Combined K+V dequant → stock FA.
// * CUDA/ROCm: always (preferFusedAttn=false).
// * Metal: prefill only (batched Q). Metal decode drops to path 2.
if vEncodeResult != nil && (!c.preferFusedAttn || useDequantKVForPrefill) {
key, value := c.compressedK.DequantKV(ctx, layer, encodeResult, vEncodeResult, firstCell, nCells)
if key != nil && value != nil {
c.logPathOnce[1].Do(func() {
if useDequantKVForPrefill {
slog.Info("turboquant: using combined DequantKV + stock FA path (Metal prefill)")
} else {
slog.Info("turboquant: using combined DequantKV + stock FA path")
}
})
_, _, mask := c.meta.Get(ctx)
c.armRotationForNextSDPA()
return key, value, mask
}
}
// 2. K+V fused inline-decode: reads packed K+V directly, no f16 intermediate.
// Primary path on Metal decode (Q=1); fallback on CUDA/ROCm.
// Instantiated at D=128 always; D=256 on Metal only.
if vEncodeResult != nil && c.fusedFallbackEligible {
if tqkv, ok := c.compressedK.GetAsTQTensorKV(ctx, layer, encodeResult, vEncodeResult, firstCell, nCells); ok {
c.logPathOnce[2].Do(func() {
if c.preferFusedAttn {
slog.Info("turboquant: using K+V fused inline-decode path (Metal decode: avoids f16 intermediate)")
} else {
slog.Warn("turboquant: falling back to K+V inline-decode fused kernel (slower)")
}
})
_, _, mask := c.meta.Get(ctx)
c.armRotationForNextSDPA()
return tqkv, nil, mask
}
}
// 1b. DequantKV fallback when Metal decode tried path 2 and the fused
// path was unavailable (e.g. headDim outside 128/256).
if vEncodeResult != nil && c.preferFusedAttn {
key, value := c.compressedK.DequantKV(ctx, layer, encodeResult, vEncodeResult, firstCell, nCells)
if key != nil && value != nil {
c.logPathOnce[1].Do(func() {
slog.Info("turboquant: using combined DequantKV + stock FA path (fused unavailable)")
})
_, _, mask := c.meta.Get(ctx)
c.armRotationForNextSDPA()
return key, value, mask
}
}
// 3. V-only dequant for K-only fused or separate K dequant fallback.
var value ml.Tensor
if vEncodeResult != nil {
value = c.compressedK.DequantV(ctx, layer, vEncodeResult, firstCell, nCells)
}
// 4. Try K-only fused: K decoded inline, V is dequanted f16. Gated on
// fusedFallbackEligible for the same D=128 reason as path 2.
if c.fusedFallbackEligible {
if tqk, ok := c.compressedK.GetAsTQTensor(ctx, layer, encodeResult, firstCell, nCells); ok {
c.logPathOnce[3].Do(func() {
slog.Warn("turboquant: falling back to K-only inline-decode fused kernel")
})
_, metaValue, mask := c.meta.Get(ctx)
if value == nil {
value = metaValue
}
c.armRotationForNextSDPA()
return tqk, value, mask
}
}
// 5. Separate K + V dequant fallback (last resort).
c.logPathOnce[4].Do(func() {
slog.Warn("turboquant: falling back to separate K + V dequant path")
})
key := c.compressedK.DequantK(ctx, layer, encodeResult, firstCell, nCells)
_, metaValue, mask := c.meta.Get(ctx)
if value == nil {
value = metaValue
}
c.armRotationForNextSDPA()
return key, value, mask
}
return c.meta.Get(ctx)
}
// armRotationForNextSDPA sets the backend's tqRotationMatrix (and
// tqVRotationMatrix for K+V presets) so the next SDPA call — which happens
// immediately after this Get returns — rotates Q to match the TQ-rotated K
// and applies the V rotation undo after attention. SDPA reads-and-clears the
// flags, so non-TQ sub-cache layers in a WrapperCache (e.g. gemma3 SWA
// layers) are unaffected.
func (c *TurboQuantCache) armRotationForNextSDPA() {
if c.rotMatrix == nil || c.rotSetter == nil {
return
}
c.rotSetter.SetTQRotationMatrix(c.rotMatrix)
// For K+V presets (tq3/tq2), also arm the V rotation undo on the next
// SDPA call. For K-only presets (tq3k/tq2k), c.vRotMatrix is nil so the
// V rotation is not armed and SDPA's consumed vRot stays nil.
if c.vRotMatrix != nil {
c.rotSetter.SetTQVRotationMatrix(c.vRotMatrix)
}
}
// activateGPUEncode initialises the TQ compressed-K manager if the backend
// supports it and re-enables Q rotation (stored K is in rotated space).
func (c *TurboQuantCache) activateGPUEncode() {
// fallbackToF16 un-skips K/V on the inner Causal so subsequent Put/Get
// on this cache store and read f16 tensors like an ordinary non-TQ cache.
// Init() sets SkipK/SkipV unconditionally, so any failure to activate GPU
// encode must reverse those flags before the current Put() continues into
// c.meta.Put — otherwise SDPA receives a nil K/V from Get and segfaults.
// K allocation in Causal.Put is lazy (keyed on presence of c.keys[layer])
// so flipping the flag on the first Put is sufficient.
fallbackToF16 := func() {
c.meta.SkipK = false
c.meta.SkipV = false
}
tqb, ok := c.meta.backend.(ml.TQCompressedKBackend)
if !ok {
fallbackToF16()
return
}
// Pass the preset's outlier config so the manager can enable post-rotation
// outlier split on the GPU encode path. This is required for correct
// output on models with learned K bias (e.g. qwen2 family) and matches
// the TurboQuant paper's validated experimental setup.
mgr := tqb.NewTQCompressedKManager(
c.headDim, c.numKVHeads, c.preset.KeyPrimaryBits, c.preset.RotationSeed,
c.preset.ValueBits, c.preset.OutlierBits, c.preset.OutlierCount,
)
if mgr == nil {
slog.Info("turboquant: GPU encode not available, using f16 K fallback")
fallbackToF16()
return
}
c.compressedK = mgr
type fusedAttnPreferrer interface {
PreferFusedAttention() bool
}
if ff, ok := mgr.(fusedAttnPreferrer); ok {
c.preferFusedAttn = ff.PreferFusedAttention()
if c.preferFusedAttn {
slog.Info("turboquant: preferring fused flash-attention path (Metal: avoids f16 intermediate buffer)")
}
}
// The inline-decode fused-FA fallback paths (Get paths 2 and 4) dispatch
// to a kernel that is D-specialised. CUDA has only D=128 today; Metal has
// D=128 and D=256 (kernel_tq_fattn_vec_*{,_d256}). Models with an
// unsupported head dim (e.g. gemma4 D=512, or gemma3 D=256 on CUDA) must
// skip these fallbacks to avoid a kernel-side GGML_ASSERT; path 5
// (separate K+V dequant) handles them correctly.
c.fusedFallbackEligible = c.headDim == 128 ||
(c.headDim == 256 && c.preferFusedAttn)
if !c.fusedFallbackEligible {
reason := "headDim != 128"
if c.headDim == 256 {
reason = "headDim == 256 but backend lacks D=256 fused kernel"
}
slog.Info("turboquant: inline-decode fused-FA fallback paths disabled",
"reason", reason, "headDim", c.headDim)
}
// Cache the rotation matrices and the backend's rotation-setter hook on
// TurboQuantCache so Get() can arm them per-call without re-running a
// type assertion every layer. We do NOT set them at activate time — a
// sticky backend-global rotation would corrupt attention on unwrapped
// SWA layers in mixed-head-dim models like gemma3.
c.rotMatrix = mgr.RotationMatrix(nil, 0)
if rs, ok := c.meta.backend.(tqRotSetter); ok {
c.rotSetter = rs
}
type vRotFusedSetter interface {
SetTQVRotFusedInDequant(bool)
}
type vRotProvider interface {
RotationMatrixR() ml.Tensor
}
if c.preset.ValueBits > 0 {
if vp, ok := mgr.(vRotProvider); ok {
c.vRotMatrix = vp.RotationMatrixR()
}
// DequantKV outputs R^T @ v (still rotated). SDPA applies the rotation
// undo as R @ attn_out via mulmat, which is dramatically faster than
// the per-cell matmul of the fused dequant kernel.
if rs, ok := c.meta.backend.(vRotFusedSetter); ok {
rs.SetTQVRotFusedInDequant(false)
}
}
slog.Info("turboquant: GPU-native encode active",
"headDim", c.headDim, "numKVHeads", c.numKVHeads,
"K_bits", c.preset.KeyPrimaryBits, "V_bits", c.preset.ValueBits)
}
func (c *TurboQuantCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
c.meta.CopyPrefix(srcSeq, dstSeq, prefixLen)
}
func (c *TurboQuantCache) CanResume(seq int, pos int32) bool {
return c.meta.CanResume(seq, pos)
}
// Remove returns ErrNotSupported for any partial eviction when GPU-compressed
// K is active: the compressed buffer cannot be RoPE-shifted in-place, and the
// inner Causal.Remove path silently skips its shiftFn loop because SkipK
// leaves c.keys empty — survivors would keep their old RoPE embeddings while
// their positions are decremented, producing wrong attention scores. Only a
// full-sequence eviction (Remove(seq, 0, MaxInt32)) avoids the shift path;
// other callers must fall back to full reprocessing via ErrReprocessInputs.
func (c *TurboQuantCache) Remove(seq int, beginIndex, endIndex int32) error {
if c.compressedK != nil && !(beginIndex == 0 && endIndex == math.MaxInt32) {
return ErrNotSupported
}
return c.meta.Remove(seq, beginIndex, endIndex)
}
func (c *TurboQuantCache) SetCausal(ctx ml.Context, opts CausalOptions) {
c.meta.SetCausal(ctx, opts)
}
func PresetFromDType(dtype ml.DType) (turboquant.Preset, bool) {
switch dtype {
case ml.DTypeTQ2:
return turboquant.PresetTQ2, true
case ml.DTypeTQ3:
return turboquant.PresetTQ3, true
case ml.DTypeTQ3K:
return turboquant.PresetTQ3K, true
case ml.DTypeTQ2K:
return turboquant.PresetTQ2K, true
default:
return turboquant.Preset{}, false
}
}
var _ Cache = (*TurboQuantCache)(nil)
// Note: TurboQuantCache intentionally does NOT implement CheckpointCache.
// CheckpointCache is for recurrent caches that need special per-sequence
// state restoration. The inner *Causal cache supports plain CanResume, which
// is the right semantics for TQ-wrapped caches too — the runner will fall
// through to the CanResume branch when TurboQuantCache is in use (see
// runner/ollamarunner/cache.go:163-170). An earlier implementation had a
// stub PrepareRestore that always returned (0, false), which forced a full
// prompt reprocess on every resume for long-context gemma/llama runs.

144
kvcache/turboquant_test.go Normal file
View file

@ -0,0 +1,144 @@
package kvcache
import (
"testing"
"github.com/ollama/ollama/turboquant"
)
// TestWrapWithTurboQuantWrapperCache verifies Path C: when wrapping a
// WrapperCache that contains a SWA Causal + a plain Causal, TurboQuant
// replaces only the plain Causal sub-cache and leaves the SWA sub-cache
// untouched. This is the gemma3/gemma4 pattern.
func TestWrapWithTurboQuantWrapperCache(t *testing.T) {
swa := NewSWAMemCache(1024, 4096, nil)
global := NewCausalCache(nil)
wc := NewWrapperCache(swa, global)
wrapped, active := WrapWithTurboQuant(wc, turboquant.PresetTQ3K)
if !active {
t.Fatalf("expected active=true when wrapping a WrapperCache with a plain Causal sub-cache")
}
// The same WrapperCache pointer should be returned (mutated in place).
if wrapped != wc {
t.Fatalf("expected the same WrapperCache pointer to be returned")
}
if len(wc.caches) != 2 {
t.Fatalf("expected 2 sub-caches, got %d", len(wc.caches))
}
// SWA sub-cache untouched.
if swaSub, ok := wc.caches[0].(*Causal); !ok || swaSub != swa {
t.Fatalf("expected wc.caches[0] to remain the original SWA Causal, got %T", wc.caches[0])
}
// Global sub-cache replaced with *TurboQuantCache wrapping the original.
tqc, ok := wc.caches[1].(*TurboQuantCache)
if !ok {
t.Fatalf("expected wc.caches[1] to be *TurboQuantCache, got %T", wc.caches[1])
}
if tqc.meta != global {
t.Fatalf("expected TurboQuantCache.meta to point at the original inner Causal")
}
if tqc.preset.Name != turboquant.PresetTQ3K.Name {
t.Fatalf("expected preset %q, got %q", turboquant.PresetTQ3K.Name, tqc.preset.Name)
}
}
// TestWrapWithTurboQuantSWAOnly verifies that wrapping a WrapperCache
// containing only SWA sub-caches yields active=false and leaves the
// cache unchanged.
func TestWrapWithTurboQuantSWAOnly(t *testing.T) {
swa := NewSWAMemCache(1024, 4096, nil)
wc := NewWrapperCache(swa)
wrapped, active := WrapWithTurboQuant(wc, turboquant.PresetTQ3K)
if active {
t.Fatalf("expected active=false for a SWA-only WrapperCache")
}
if wrapped != wc {
t.Fatalf("expected the same WrapperCache pointer to be returned")
}
// Sub-cache must remain the original SWA Causal.
if sub, ok := wc.caches[0].(*Causal); !ok || sub != swa {
t.Fatalf("expected wc.caches[0] unchanged, got %T", wc.caches[0])
}
}
// TestWrapWithTurboQuantTopLevelSWA verifies that a top-level SWA Causal
// cannot be wrapped — TurboQuant needs full-context Causal semantics.
func TestWrapWithTurboQuantTopLevelSWA(t *testing.T) {
swa := NewSWAMemCache(1024, 4096, nil)
wrapped, active := WrapWithTurboQuant(swa, turboquant.PresetTQ3K)
if active {
t.Fatalf("expected active=false for a top-level SWA Causal")
}
if wrapped != swa {
t.Fatalf("expected the SWA Causal to be returned unchanged")
}
}
// TestWrapWithTurboQuantTopLevelCausal verifies the existing top-level
// non-SWA Causal case still works: a new *TurboQuantCache is returned
// wrapping the input.
func TestWrapWithTurboQuantTopLevelCausal(t *testing.T) {
c := NewCausalCache(nil)
wrapped, active := WrapWithTurboQuant(c, turboquant.PresetTQ3K)
if !active {
t.Fatalf("expected active=true for a top-level plain Causal")
}
tqc, ok := wrapped.(*TurboQuantCache)
if !ok {
t.Fatalf("expected *TurboQuantCache, got %T", wrapped)
}
if tqc.meta != c {
t.Fatalf("expected TurboQuantCache.meta to point at the input Causal")
}
}
// TestWrapWithTurboQuantHybridCache verifies that a cache embedding *Recurrent
// (matching the model/models/*/HybridCache shape) has its inner *Causal swapped
// for a *TurboQuantCache in place, and the outer pointer is returned unchanged.
func TestWrapWithTurboQuantHybridCache(t *testing.T) {
type HybridCache struct{ *Recurrent }
r := NewRecurrentCache(RecurrentConfig{ConvDim: 4, ConvChannels: 2, RecurrentStateSize: 4})
hc := &HybridCache{Recurrent: r}
wrapped, active := WrapWithTurboQuant(hc, turboquant.PresetTQ3K)
if !active {
t.Fatal("expected active=true for AttentionKVWrapper")
}
if wrapped != hc {
t.Fatalf("expected same pointer returned, got %T", wrapped)
}
tqc, ok := r.kv.(*TurboQuantCache)
if !ok {
t.Fatalf("expected inner kv to be *TurboQuantCache, got %T", r.kv)
}
if tqc.preset.Name != turboquant.PresetTQ3K.Name {
t.Fatalf("preset mismatch: got %q", tqc.preset.Name)
}
}
// TestWrapWithTurboQuantHybridCacheIdempotent verifies that a second wrap attempt
// on an already-wrapped hybrid cache returns (cache, false) rather than double-wrapping.
func TestWrapWithTurboQuantHybridCacheIdempotent(t *testing.T) {
type HybridCache struct{ *Recurrent }
r := NewRecurrentCache(RecurrentConfig{ConvDim: 4, ConvChannels: 2, RecurrentStateSize: 4})
hc := &HybridCache{Recurrent: r}
_, _ = WrapWithTurboQuant(hc, turboquant.PresetTQ3K)
wrapped, active := WrapWithTurboQuant(hc, turboquant.PresetTQ3K)
if active {
t.Fatal("expected active=false on second wrap (already wrapped)")
}
if wrapped != hc {
t.Fatalf("expected same pointer returned, got %T", wrapped)
}
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,98 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Verrilli <msv@pobox.com>
Date: Sun, 19 Apr 2026 23:28:06 -0400
Subject: [PATCH] =?UTF-8?q?turboquant:=20fix=20HIP=20compile=20=E2=80=94?=
=?UTF-8?q?=20drop=20cuda=5Ffp16.h,=20pass=20=5F=5Fshfl=5Fsync=20width?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Two minimal, behavior-neutral edits so the TQ kernels compile under
hipcc on the ROCm 7 preset:
1. Remove direct `#include <cuda_fp16.h>` from tq-encode.cu,
tq-encode-v.cu, and tq-dequant.cu. Header is transitively pulled
in via common.cuh → vendors/{cuda,hip,musa}.h; direct include
is redundant on CUDA and fatal on HIP.
2. Pass width=32 explicitly on four __shfl_sync sites in tq-dequant.cu.
HIP's shim is a 4-arg macro with no defaults; CUDA's 3-arg overload
defaults to width=warpSize=32, so passing 32 is a compile fix on
HIP and a no-op on CUDA.
---
ggml/src/ggml-cuda/tq-dequant.cu | 14 ++++++++------
ggml/src/ggml-cuda/tq-encode-v.cu | 1 -
ggml/src/ggml-cuda/tq-encode.cu | 1 -
3 files changed, 8 insertions(+), 8 deletions(-)
diff --git a/ggml/src/ggml-cuda/tq-dequant.cu b/ggml/src/ggml-cuda/tq-dequant.cu
index 03ed64543..4a4cc4d64 100644
--- a/ggml/src/ggml-cuda/tq-dequant.cu
+++ b/ggml/src/ggml-cuda/tq-dequant.cu
@@ -1,5 +1,4 @@
#include "tq-dequant.cuh"
-#include <cuda_fp16.h>
// Optimized TQ dequant kernel: warp-shuffle codebook + hardcoded bit extraction.
//
@@ -52,8 +51,11 @@ __global__ void tq_dequant_multihead_kernel(
}
// Codebook lookup via warp shuffle: zero global memory latency.
- // Width = warpSize (32) works because cb_lane is periodic with period (1<<bits).
- float val = __shfl_sync(0xFFFFFFFF, cb_lane, idx) * scale;
+ // Width = 32 works because cb_lane is periodic with period (1<<bits).
+ // Pass width explicitly — the HIP __shfl_sync shim is a 4-arg macro
+ // that doesn't default, and CUDA's 3-arg overload is width=warpSize=32
+ // anyway, so this is behavior-neutral on NVIDIA.
+ float val = __shfl_sync(0xFFFFFFFF, cb_lane, idx, 32) * scale;
cell_out[elem] = __float2half_rn(val);
}
}
@@ -180,7 +182,7 @@ __global__ void tq_dequant_multihead_kernel_outlier(
if (reg_shift + bits > 8) {
reg_idx |= (cell_reg[reg_byte_idx + 1] << (8 - reg_shift)) & cb_mask;
}
- float reg_val = __shfl_sync(0xFFFFFFFF, cb_lane_reg, reg_idx) * regScale;
+ float reg_val = __shfl_sync(0xFFFFFFFF, cb_lane_reg, reg_idx, 32) * regScale;
int out_slot_safe = (outlier_slot >= 0) ? outlier_slot : 0;
int out_bit_offset = out_slot_safe * outlier_bits;
@@ -190,7 +192,7 @@ __global__ void tq_dequant_multihead_kernel_outlier(
if (out_shift + outlier_bits > 8) {
out_idx |= (cell_outl[out_byte_idx + 1] << (8 - out_shift)) & ocb_mask;
}
- float out_val = __shfl_sync(0xFFFFFFFF, cb_lane_out, out_idx) * outScale;
+ float out_val = __shfl_sync(0xFFFFFFFF, cb_lane_out, out_idx, 32) * outScale;
float val = (outlier_slot >= 0) ? out_val : reg_val;
cell_out[elem] = __float2half_rn(val);
@@ -308,7 +310,7 @@ __global__ void tq_dequant_v_rotated_kernel(
if (shift + bits > 8) {
idx |= (cell_packed[byte_idx + 1] << (8 - shift)) & cb_mask;
}
- s_rotV[elem] = __shfl_sync(0xFFFFFFFF, cb_lane, idx) * scale;
+ s_rotV[elem] = __shfl_sync(0xFFFFFFFF, cb_lane, idx, 32) * scale;
__syncthreads();
// Phase 2: each thread computes one output element = dot(R[elem,:], s_rotV).
diff --git a/ggml/src/ggml-cuda/tq-encode-v.cu b/ggml/src/ggml-cuda/tq-encode-v.cu
index f9673861d..09518dccd 100644
--- a/ggml/src/ggml-cuda/tq-encode-v.cu
+++ b/ggml/src/ggml-cuda/tq-encode-v.cu
@@ -1,5 +1,4 @@
#include "tq-encode-v.cuh"
-#include <cuda_fp16.h>
#include <math.h>
#define TQ_ENCODE_V_BLOCK_SIZE 128
diff --git a/ggml/src/ggml-cuda/tq-encode.cu b/ggml/src/ggml-cuda/tq-encode.cu
index 12e7f7ba7..9c9eafa39 100644
--- a/ggml/src/ggml-cuda/tq-encode.cu
+++ b/ggml/src/ggml-cuda/tq-encode.cu
@@ -1,5 +1,4 @@
#include "tq-encode.cuh"
-#include <cuda_fp16.h>
#include <math.h>
#define TQ_ENCODE_BLOCK_SIZE 128

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,667 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Verrilli <msv@pobox.com>
Date: Tue, 21 Apr 2026 21:02:21 +0000
Subject: [PATCH] ml/backend/ggml: add D=256 TurboQuant flash-attention kernels
on Metal
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
The fused inline-decode kernels were only instantiated at D=128. gemma3
runs at headDim=256 and was therefore forced onto the DequantKV + stock
FA path — which materialises a ~3.5 GB f16 intermediate K+V tensor at
32k context and OOMs on Apple Silicon with tq3.
Adds D=256 variants of the two fused kernels:
* kernel_tq_fattn_vec_f16_d256
* kernel_tq_fattn_vec_packed_d256
Thread layout is unchanged (32x4 = 128 threads); each thread now covers
32 D-positions in the Q and K loops, accumulates four V-passes, and
writes two output elements in the final phase. Shared memory grows
from 2048 to 4096 floats (16 KiB, well under the 32 KiB per-threadgroup
limit).
The dispatch picks D=128 or D=256 based on Q->ne[0]. Go-side eligibility
extends to headDim==256 only when preferFusedAttention is true (Metal);
CUDA continues to gate on headDim==128 until the CUDA kernel gains the
same instantiation.
---
ggml/src/ggml-metal/ggml-metal-device.cpp | 6 +-
ggml/src/ggml-metal/ggml-metal-device.h | 6 +-
ggml/src/ggml-metal/ggml-metal-ops.cpp | 11 +-
ggml/src/ggml-metal/ggml-metal.metal | 562 ++++++++++++++++++++++
4 files changed, 579 insertions(+), 6 deletions(-)
diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index 2686d2d30..c99fafe00 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -1720,5 +1720,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_dequan
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_encode"); }
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_v (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_encode_v"); }
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_outlier(ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_encode_outlier"); }
-struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16 (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_f16"); }
-struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed(ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_packed"); }
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16 (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_f16"); }
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_packed"); }
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16_d256 (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_f16_d256"); }
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed_d256(ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_packed_d256"); }
diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h
index aadd82659..fb45cbbfd 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.h
+++ b/ggml/src/ggml-metal/ggml-metal-device.h
@@ -194,8 +194,10 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_dequan
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode (ggml_metal_library_t lib);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_v (ggml_metal_library_t lib);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_outlier(ggml_metal_library_t lib);
-struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16 (ggml_metal_library_t lib);
-struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed(ggml_metal_library_t lib);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16 (ggml_metal_library_t lib);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed (ggml_metal_library_t lib);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16_d256 (ggml_metal_library_t lib);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed_d256(ggml_metal_library_t lib);
// MTLResidencySet wrapper
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index 7abc0eaac..b5ab1c14e 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -4602,9 +4602,16 @@ int ggml_metal_op_tq_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.nb31 =*/ mask ? mask->nb[1] : 0,
};
+ // Select D=128 vs D=256 pipeline. Gemma3 runs at headDim=256; everything
+ // else supported so far is D=128.
+ GGML_ASSERT(D == 128 || D == 256);
auto pipeline = v_packed
- ? ggml_metal_library_get_pipeline_tq_fattn_vec_packed(lib)
- : ggml_metal_library_get_pipeline_tq_fattn_vec_f16(lib);
+ ? (D == 256
+ ? ggml_metal_library_get_pipeline_tq_fattn_vec_packed_d256(lib)
+ : ggml_metal_library_get_pipeline_tq_fattn_vec_packed(lib))
+ : (D == 256
+ ? ggml_metal_library_get_pipeline_tq_fattn_vec_f16_d256(lib)
+ : ggml_metal_library_get_pipeline_tq_fattn_vec_f16(lib));
ggml_metal_buffer_id bid_mask = hasMask ? ggml_metal_get_buffer_id(mask) : ggml_metal_get_buffer_id(op);
ggml_metal_buffer_id bid_v_scales = v_packed ? ggml_metal_get_buffer_id(op->src[6]) : ggml_metal_get_buffer_id(op);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index d0bc6bf9e..2718e8bb1 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -11184,6 +11184,293 @@ kernel void kernel_tq_fattn_vec_f16(
}
}
+// ─────────────────────────────────────────────────────────────────────────────
+// kernel_tq_fattn_vec_f16_d256
+// TQ fused flash-attention at head dim 256: K packed i8, V f16.
+// Thread layout identical to the D=128 variant (32×4 = 128 threads) but each
+// thread now produces 2 output elements (D/nthreads = 2) and covers twice as
+// many D positions in the Q/K/V loops.
+// Grid: (ntiles_x, 1, nHeadsQ*nSeq)
+// ─────────────────────────────────────────────────────────────────────────────
+kernel void kernel_tq_fattn_vec_f16_d256(
+ constant ggml_metal_kargs_tq_fattn_vec & args,
+ device const char * Q_data [[buffer(1)]],
+ device const uint8_t * K_packed [[buffer(2)]],
+ device const half * V_data [[buffer(3)]],
+ device const half * mask_data [[buffer(4)]],
+ device const float * K_scales [[buffer(5)]],
+ device const float * K_cb [[buffer(6)]],
+ device const float * dummy_vs [[buffer(7)]],
+ device const float * dummy_vc [[buffer(8)]],
+ device float * dst [[buffer(9)]],
+ uint3 tgpig [[threadgroup_position_in_grid]],
+ uint tiisg [[thread_index_in_simdgroup]],
+ uint sgitg [[simdgroup_index_in_threadgroup]])
+{
+ constexpr int D = 256;
+ constexpr int nthreads = 128;
+ constexpr int nthreads_KQ = 8;
+ constexpr int nthreads_V = 8;
+ constexpr int V_cols_per_iter = 4;
+ constexpr int nwarps = 4;
+
+ const int ic0 = (int)tgpig.x * args.ncols;
+ const int blk_z = (int)tgpig.z;
+ const int sequence = blk_z / args.nHeadsQ;
+ const int head = blk_z % args.nHeadsQ;
+ const int gqa_ratio = args.nHeadsQ / args.nKVHeads;
+ const int head_kv = head / gqa_ratio;
+
+ const int tid = (int)sgitg * 32 + (int)tiisg;
+
+ device const float * Q = (device const float *)Q_data
+ + (long)sequence * (args.nb03 / sizeof(float))
+ + (long)head * (args.nb02 / sizeof(float))
+ + (long)ic0 * (args.nb01 / sizeof(float));
+
+ device const uint8_t * K_p = K_packed
+ + (long)args.firstCell * args.nKVHeads * args.packedBytes
+ + (long)head_kv * args.packedBytes;
+ device const float * K_sc = K_scales
+ + (long)args.firstCell * args.nKVHeads + head_kv;
+
+ device const half * V = V_data
+ + (long)sequence * (args.nb23 / sizeof(half))
+ + (long)head_kv * (args.nb22 / sizeof(half));
+
+ device const half * maskh = args.hasMask
+ ? (mask_data + (long)ic0 * (args.nb31 / sizeof(half)))
+ : nullptr;
+
+ const int k_cb_mask = (1 << args.bits) - 1;
+ const float k_cb_lane = K_cb[tiisg & k_cb_mask];
+
+ const int tid_kq = (int)tiisg % nthreads_KQ;
+
+ // D=256: Q_reg holds 16 float2 per thread per query slot (D/(2*nthreads_KQ)).
+ float2 Q_reg[2][16];
+ for (int j = 0; j < args.ncols; j++) {
+ device const float2 * Q_j = (device const float2 *)(Q + (long)j * (args.nb01 / sizeof(float)));
+ for (int i = 0; i < 16; i++) {
+ const int elem = tid_kq * 16 + i; // float2 index within [0, 127]
+ Q_reg[j][i] = (elem < D/2) ? Q_j[elem] : float2(0.0f, 0.0f);
+ }
+ for (int i = 0; i < 16; i++) {
+ Q_reg[j][i].x *= args.scale;
+ Q_reg[j][i].y *= args.scale;
+ }
+ }
+
+ // D=256: VKQ holds 4 passes × 4 float2 = 16 float2 per query slot.
+ float2 VKQ[2][16];
+ for (int j = 0; j < 2; j++)
+ for (int i = 0; i < 16; i++)
+ VKQ[j][i] = float2(0.0f, 0.0f);
+
+ float KQ_max[2] = { -FLT_MAX/2.0f, -FLT_MAX/2.0f };
+ float KQ_sum[2] = { 0.0f, 0.0f };
+
+ // D=256: KQ_tg sized nwarps*V_cols_per_iter*D = 4*4*256 = 4096 floats (16 KiB).
+ threadgroup float KQ_tg[4096];
+ threadgroup float KQ_max_tg[2][32];
+ threadgroup float KQ_sum_tg[2][32];
+
+ for (int k_VKQ_0 = 0; k_VKQ_0 < args.nCells; k_VKQ_0 += nthreads) {
+
+ float KQ_max_new[2] = { KQ_max[0], KQ_max[1] };
+
+ for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; i_KQ_0++) {
+ const int kq_grp_start = ((int)tiisg & ~(nthreads_KQ - 1));
+ const int i_KQ = (int)sgitg * 32 + kq_grp_start + i_KQ_0;
+ const int cell_rel = k_VKQ_0 + i_KQ;
+ const bool in_range = (cell_rel < args.nCells);
+
+ for (int j = 0; j < args.ncols; j++) {
+ device const uint8_t * packed_row = K_p + (long)cell_rel * args.nKVHeads * args.packedBytes;
+ const float rms_scale = in_range ? K_sc[cell_rel * args.nKVHeads] : 0.0f;
+
+ // D=256: 16 k-iterations × 2 elements each = 32 D-positions per thread.
+ float sum = 0.0f;
+ for (int k = 0; k < 16; k++) {
+ const int start_elem = tid_kq * 32 + k * 2; // float index [0..254]
+ float k_dec[2];
+ if (args.bits == 3) {
+ const int bit_pos0 = start_elem * 3;
+ const int byte0 = bit_pos0 >> 3, sh0 = bit_pos0 & 7;
+ const uint w0 = (uint)packed_row[byte0] | ((uint)packed_row[byte0+1] << 8);
+ int idx0 = (int)((w0 >> sh0) & 7);
+ k_dec[0] = simd_shuffle(k_cb_lane, (ushort)idx0) * rms_scale;
+ const int bit_pos1 = (start_elem + 1) * 3;
+ const int byte1 = bit_pos1 >> 3, sh1 = bit_pos1 & 7;
+ const uint w1 = (uint)packed_row[byte1] | ((uint)packed_row[byte1+1] << 8);
+ int idx1 = (int)((w1 >> sh1) & 7);
+ k_dec[1] = simd_shuffle(k_cb_lane, (ushort)idx1) * rms_scale;
+ } else {
+ const int byte0 = start_elem >> 2, sh0 = (start_elem & 3) * 2;
+ k_dec[0] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte0] >> sh0) & 3)) * rms_scale;
+ const int byte1 = (start_elem + 1) >> 2, sh1 = ((start_elem + 1) & 3) * 2;
+ k_dec[1] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte1] >> sh1) & 3)) * rms_scale;
+ }
+ sum += Q_reg[j][k].x * k_dec[0] + Q_reg[j][k].y * k_dec[1];
+ }
+ sum += simd_shuffle_xor(sum, 4);
+ sum += simd_shuffle_xor(sum, 2);
+ sum += simd_shuffle_xor(sum, 1);
+
+ if (args.logit_softcap != 0.0f) {
+ sum = args.logit_softcap * tanh(sum);
+ }
+
+ if (maskh && (args.ncols == 1 || ic0 + j < args.nTokensQ)) {
+ sum += float(maskh[(long)j * args.ne31 + i_KQ]);
+ }
+
+ if (!in_range) sum = -FLT_MAX/2.0f;
+
+ KQ_max_new[j] = max(KQ_max_new[j], sum + 0.6931f);
+
+ if (tid_kq == (uint)i_KQ_0) {
+ KQ_tg[j * nthreads + tid] = sum;
+ }
+ }
+ }
+
+ for (int j = 0; j < args.ncols; j++) {
+ KQ_max_new[j] = simd_max(KQ_max_new[j]);
+
+ const float KQ_max_scale = exp(KQ_max[j] - KQ_max_new[j]);
+ KQ_max[j] = KQ_max_new[j];
+
+ const float kq_val = KQ_tg[j * nthreads + tid];
+ const float kq_exp = exp(kq_val - KQ_max[j]);
+ KQ_sum[j] = KQ_sum[j] * KQ_max_scale + kq_exp;
+ KQ_tg[j * nthreads + tid] = kq_exp;
+
+ for (int i = 0; i < 16; i++) {
+ VKQ[j][i].x *= KQ_max_scale;
+ VKQ[j][i].y *= KQ_max_scale;
+ }
+ }
+
+ // D=256: 4 passes × 64 V-elements each = full 256-element coverage.
+ // pass 0 → V[ 0.. 63] → VKQ[ 0.. 3]
+ // pass 1 → V[ 64..127] → VKQ[ 4.. 7]
+ // pass 2 → V[128..191] → VKQ[ 8..11]
+ // pass 3 → V[192..255] → VKQ[12..15]
+ for (int k0 = 0; k0 < 32; k0 += V_cols_per_iter) {
+ const int k = (int)sgitg * 32 + k0 + (int)tiisg / nthreads_V;
+ const int cell_rel = k_VKQ_0 + k;
+
+ float KQ_k[2];
+ for (int j = 0; j < args.ncols; j++) {
+ KQ_k[j] = KQ_tg[j * nthreads + k];
+ }
+
+ device const half * V_cell = (cell_rel < args.nCells)
+ ? V + (long)cell_rel * (args.nb21 / sizeof(half))
+ : nullptr;
+
+ const int v_tid = (int)tiisg % nthreads_V;
+ for (int pass = 0; pass < 4; pass++) {
+ for (int i = 0; i < 8; i++) {
+ const int elem = pass * 64 + v_tid * 8 + i;
+ float v_val = (V_cell && elem < D) ? float(V_cell[elem]) : 0.0f;
+ const int vkq_idx = pass * 4 + i / 2;
+ for (int j = 0; j < args.ncols; j++) {
+ if (i % 2 == 0) VKQ[j][vkq_idx].x += v_val * KQ_k[j];
+ else VKQ[j][vkq_idx].y += v_val * KQ_k[j];
+ }
+ }
+ }
+ }
+ } // end KV loop
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (int j = 0; j < args.ncols; j++) {
+ if (sgitg == 0) {
+ KQ_max_tg[j][tiisg] = -FLT_MAX/2.0f;
+ KQ_sum_tg[j][tiisg] = 0.0f;
+ }
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (int j = 0; j < args.ncols; j++) {
+ if (tiisg == 0) {
+ KQ_max_tg[j][sgitg] = KQ_max[j];
+ }
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (int j = 0; j < args.ncols; j++) {
+ if (args.ncols > 1 && ic0 + j >= args.nTokensQ) break;
+
+ float kqmax_new = KQ_max_tg[j][tiisg];
+ kqmax_new = simd_max(kqmax_new);
+ const float kqmax_scale = exp(KQ_max[j] - kqmax_new);
+ KQ_max[j] = kqmax_new;
+
+ for (int i = 0; i < 16; i++) {
+ VKQ[j][i].x *= kqmax_scale;
+ VKQ[j][i].y *= kqmax_scale;
+ }
+
+ // D=256: VKQ_tg layout per (sgitg,v-group) region of D/2 = 128 float2:
+ // VKQ[ 0.. 3] → VKQ_tg[v_tid*4 + 0..3] (pass 0, float2 slots 0..31)
+ // VKQ[ 4.. 7] → VKQ_tg[32 + v_tid*4 + 0..3] (pass 1, float2 slots 32..63)
+ // VKQ[ 8..11] → VKQ_tg[64 + v_tid*4 + 0..3] (pass 2, float2 slots 64..95)
+ // VKQ[12..15] → VKQ_tg[96 + v_tid*4 + 0..3] (pass 3, float2 slots 96..127)
+ const int v_tid = (int)tiisg % nthreads_V;
+ threadgroup float2 * VKQ_tg = (threadgroup float2 *)KQ_tg
+ + (long)sgitg * (V_cols_per_iter * D/2)
+ + (long)((int)tiisg / nthreads_V) * (D/2);
+ VKQ_tg[v_tid * 4 + 0] = VKQ[j][0];
+ VKQ_tg[v_tid * 4 + 1] = VKQ[j][1];
+ VKQ_tg[v_tid * 4 + 2] = VKQ[j][2];
+ VKQ_tg[v_tid * 4 + 3] = VKQ[j][3];
+ VKQ_tg[32 + v_tid * 4 + 0] = VKQ[j][4];
+ VKQ_tg[32 + v_tid * 4 + 1] = VKQ[j][5];
+ VKQ_tg[32 + v_tid * 4 + 2] = VKQ[j][6];
+ VKQ_tg[32 + v_tid * 4 + 3] = VKQ[j][7];
+ VKQ_tg[64 + v_tid * 4 + 0] = VKQ[j][8];
+ VKQ_tg[64 + v_tid * 4 + 1] = VKQ[j][9];
+ VKQ_tg[64 + v_tid * 4 + 2] = VKQ[j][10];
+ VKQ_tg[64 + v_tid * 4 + 3] = VKQ[j][11];
+ VKQ_tg[96 + v_tid * 4 + 0] = VKQ[j][12];
+ VKQ_tg[96 + v_tid * 4 + 1] = VKQ[j][13];
+ VKQ_tg[96 + v_tid * 4 + 2] = VKQ[j][14];
+ VKQ_tg[96 + v_tid * 4 + 3] = VKQ[j][15];
+
+ KQ_sum[j] *= kqmax_scale;
+ KQ_sum[j] = simd_sum(KQ_sum[j]);
+ if (tiisg == 0) {
+ KQ_sum_tg[j][sgitg] = KQ_sum[j];
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // D=256: each thread writes 2 output positions (tid, tid+128). Compute
+ // KQ_sum with ALL lanes participating in the simd_sum (required for
+ // convergence), then two writes per thread.
+ KQ_sum[j] = KQ_sum_tg[j][tiisg];
+ KQ_sum[j] = simd_sum(KQ_sum[j]);
+
+ const long out_idx = ((long)sequence * args.nTokensQ + ic0 + j) * args.nHeadsQ + head;
+ for (int out_offset = 0; out_offset < D; out_offset += nthreads) {
+ const int out_elem = out_offset + tid;
+ float dst_val = 0.0f;
+ for (int w = 0; w < nwarps; w++) {
+ for (int v = 0; v < V_cols_per_iter; v++) {
+ dst_val += ((threadgroup float *)KQ_tg)[w * V_cols_per_iter * D + v * D + out_elem];
+ }
+ }
+ dst_val /= KQ_sum[j];
+ dst[out_idx * D + out_elem] = dst_val;
+ }
+
+ if (j < args.ncols - 1) threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+}
+
// ─────────────────────────────────────────────────────────────────────────────
// kernel_tq_fattn_vec_packed
// TQ fused flash-attention: K packed i8, V packed i8.
@@ -11451,3 +11738,278 @@ kernel void kernel_tq_fattn_vec_packed(
if (j < args.ncols - 1) threadgroup_barrier(mem_flags::mem_threadgroup);
}
}
+
+// ─────────────────────────────────────────────────────────────────────────────
+// kernel_tq_fattn_vec_packed_d256
+// TQ fused flash-attention at head dim 256: K packed i8, V packed i8.
+// Mirrors kernel_tq_fattn_vec_packed with 4 V-passes and 2 outputs per thread.
+// ─────────────────────────────────────────────────────────────────────────────
+kernel void kernel_tq_fattn_vec_packed_d256(
+ constant ggml_metal_kargs_tq_fattn_vec & args,
+ device const char * Q_data [[buffer(1)]],
+ device const uint8_t * K_packed [[buffer(2)]],
+ device const uint8_t * V_packed [[buffer(3)]],
+ device const half * mask_data [[buffer(4)]],
+ device const float * K_scales [[buffer(5)]],
+ device const float * K_cb [[buffer(6)]],
+ device const float * V_scales [[buffer(7)]],
+ device const float * V_cb [[buffer(8)]],
+ device float * dst [[buffer(9)]],
+ uint3 tgpig [[threadgroup_position_in_grid]],
+ uint tiisg [[thread_index_in_simdgroup]],
+ uint sgitg [[simdgroup_index_in_threadgroup]])
+{
+ constexpr int D = 256;
+ constexpr int nthreads = 128;
+ constexpr int nthreads_KQ = 8;
+ constexpr int nthreads_V = 8;
+ constexpr int V_cols_per_iter = 4;
+ constexpr int nwarps = 4;
+
+ const int ic0 = (int)tgpig.x * args.ncols;
+ const int blk_z = (int)tgpig.z;
+ const int sequence = blk_z / args.nHeadsQ;
+ const int head = blk_z % args.nHeadsQ;
+ const int gqa_ratio = args.nHeadsQ / args.nKVHeads;
+ const int head_kv = head / gqa_ratio;
+
+ const int tid = (int)sgitg * 32 + (int)tiisg;
+
+ device const float * Q = (device const float *)Q_data
+ + (long)sequence * (args.nb03 / sizeof(float))
+ + (long)head * (args.nb02 / sizeof(float))
+ + (long)ic0 * (args.nb01 / sizeof(float));
+
+ device const uint8_t * K_p = K_packed
+ + (long)args.firstCell * args.nKVHeads * args.packedBytes
+ + (long)head_kv * args.packedBytes;
+ device const float * K_sc = K_scales
+ + (long)args.firstCell * args.nKVHeads + head_kv;
+
+ device const uint8_t * V_p = V_packed
+ + (long)args.firstCell * args.nKVHeads * args.v_packedBytes
+ + (long)head_kv * args.v_packedBytes;
+ device const float * V_sc = V_scales
+ + (long)args.firstCell * args.nKVHeads + head_kv;
+
+ device const half * maskh = args.hasMask
+ ? (mask_data + (long)ic0 * (args.nb31 / sizeof(half)))
+ : nullptr;
+
+ const int k_cb_mask = (1 << args.bits) - 1;
+ const float k_cb_lane = K_cb[tiisg & k_cb_mask];
+ const int v_cb_mask = (1 << args.v_bits) - 1;
+ const float v_cb_lane = V_cb[tiisg & v_cb_mask];
+
+ const int tid_kq = (int)tiisg % nthreads_KQ;
+
+ float2 Q_reg[2][16];
+ for (int j = 0; j < args.ncols; j++) {
+ device const float2 * Q_j = (device const float2 *)(Q + (long)j * (args.nb01 / sizeof(float)));
+ for (int i = 0; i < 16; i++) {
+ const int elem = tid_kq * 16 + i;
+ Q_reg[j][i] = (elem < D/2) ? Q_j[elem] : float2(0.0f, 0.0f);
+ }
+ for (int i = 0; i < 16; i++) {
+ Q_reg[j][i].x *= args.scale;
+ Q_reg[j][i].y *= args.scale;
+ }
+ }
+
+ float2 VKQ[2][16];
+ for (int j = 0; j < 2; j++)
+ for (int i = 0; i < 16; i++)
+ VKQ[j][i] = float2(0.0f, 0.0f);
+
+ float KQ_max[2] = { -FLT_MAX/2.0f, -FLT_MAX/2.0f };
+ float KQ_sum[2] = { 0.0f, 0.0f };
+
+ threadgroup float KQ_tg[4096];
+ threadgroup float KQ_max_tg[2][32];
+ threadgroup float KQ_sum_tg[2][32];
+
+ for (int k_VKQ_0 = 0; k_VKQ_0 < args.nCells; k_VKQ_0 += nthreads) {
+
+ float KQ_max_new[2] = { KQ_max[0], KQ_max[1] };
+
+ for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; i_KQ_0++) {
+ const int kq_grp_start = ((int)tiisg & ~(nthreads_KQ - 1));
+ const int i_KQ = (int)sgitg * 32 + kq_grp_start + i_KQ_0;
+ const int cell_rel = k_VKQ_0 + i_KQ;
+ const bool in_range = (cell_rel < args.nCells);
+
+ for (int j = 0; j < args.ncols; j++) {
+ device const uint8_t * packed_row = K_p + (long)cell_rel * args.nKVHeads * args.packedBytes;
+ const float rms_scale = in_range ? K_sc[cell_rel * args.nKVHeads] : 0.0f;
+
+ float sum = 0.0f;
+ for (int k = 0; k < 16; k++) {
+ const int start_elem = tid_kq * 32 + k * 2;
+ float k_dec[2];
+ if (args.bits == 3) {
+ const int bit_pos0 = start_elem * 3;
+ const int byte0 = bit_pos0 >> 3, sh0 = bit_pos0 & 7;
+ const uint w0 = (uint)packed_row[byte0] | ((uint)packed_row[byte0+1] << 8);
+ k_dec[0] = simd_shuffle(k_cb_lane, (ushort)((w0 >> sh0) & 7)) * rms_scale;
+ const int bit_pos1 = (start_elem + 1) * 3;
+ const int byte1 = bit_pos1 >> 3, sh1 = bit_pos1 & 7;
+ const uint w1 = (uint)packed_row[byte1] | ((uint)packed_row[byte1+1] << 8);
+ k_dec[1] = simd_shuffle(k_cb_lane, (ushort)((w1 >> sh1) & 7)) * rms_scale;
+ } else {
+ const int byte0 = start_elem >> 2, sh0 = (start_elem & 3) * 2;
+ k_dec[0] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte0] >> sh0) & 3)) * rms_scale;
+ const int byte1 = (start_elem + 1) >> 2, sh1 = ((start_elem + 1) & 3) * 2;
+ k_dec[1] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte1] >> sh1) & 3)) * rms_scale;
+ }
+ sum += Q_reg[j][k].x * k_dec[0] + Q_reg[j][k].y * k_dec[1];
+ }
+ sum += simd_shuffle_xor(sum, 4);
+ sum += simd_shuffle_xor(sum, 2);
+ sum += simd_shuffle_xor(sum, 1);
+
+ if (args.logit_softcap != 0.0f) {
+ sum = args.logit_softcap * tanh(sum);
+ }
+ if (maskh && (args.ncols == 1 || ic0 + j < args.nTokensQ)) {
+ sum += float(maskh[(long)j * args.ne31 + i_KQ]);
+ }
+ if (!in_range) sum = -FLT_MAX/2.0f;
+
+ KQ_max_new[j] = max(KQ_max_new[j], sum + 0.6931f);
+
+ if (tid_kq == (uint)i_KQ_0) {
+ KQ_tg[j * nthreads + tid] = sum;
+ }
+ }
+ }
+
+ for (int j = 0; j < args.ncols; j++) {
+ KQ_max_new[j] = simd_max(KQ_max_new[j]);
+
+ const float KQ_max_scale = exp(KQ_max[j] - KQ_max_new[j]);
+ KQ_max[j] = KQ_max_new[j];
+
+ const float kq_val = KQ_tg[j * nthreads + tid];
+ const float kq_exp = exp(kq_val - KQ_max[j]);
+ KQ_sum[j] = KQ_sum[j] * KQ_max_scale + kq_exp;
+ KQ_tg[j * nthreads + tid] = kq_exp;
+
+ for (int i = 0; i < 16; i++) {
+ VKQ[j][i].x *= KQ_max_scale;
+ VKQ[j][i].y *= KQ_max_scale;
+ }
+ }
+
+ // D=256: 4 passes cover V[0..63], V[64..127], V[128..191], V[192..255].
+ for (int k0 = 0; k0 < 32; k0 += V_cols_per_iter) {
+ const int k = (int)sgitg * 32 + k0 + (int)tiisg / nthreads_V;
+ const int cell_rel = k_VKQ_0 + k;
+ const bool in_range_v = (cell_rel < args.nCells);
+
+ float KQ_k[2];
+ for (int j = 0; j < args.ncols; j++) {
+ KQ_k[j] = KQ_tg[j * nthreads + k];
+ }
+
+ device const uint8_t * v_row = in_range_v
+ ? V_p + (long)cell_rel * args.nKVHeads * args.v_packedBytes
+ : nullptr;
+ const float v_rms = in_range_v ? V_sc[cell_rel * args.nKVHeads] : 0.0f;
+
+ const int v_tid = (int)tiisg % nthreads_V;
+ for (int pass = 0; pass < 4; pass++) {
+ float v_dec[8];
+ const int start_elem = pass * 64 + v_tid * 8;
+ if (v_row && start_elem < D) {
+ tq_decode_8_shfl(v_row, v_cb_lane, v_rms, start_elem, args.v_bits, v_dec);
+ } else {
+ tq_decode_8_shfl(K_p, v_cb_lane, 0.0f, 0, args.v_bits, v_dec);
+ }
+ for (int i = 0; i < 8; i++) {
+ const int vkq_idx = pass * 4 + i / 2;
+ for (int j = 0; j < args.ncols; j++) {
+ if (i % 2 == 0) VKQ[j][vkq_idx].x += v_dec[i] * KQ_k[j];
+ else VKQ[j][vkq_idx].y += v_dec[i] * KQ_k[j];
+ }
+ }
+ }
+ }
+ } // end KV loop
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (int j = 0; j < args.ncols; j++) {
+ if (sgitg == 0) {
+ KQ_max_tg[j][tiisg] = -FLT_MAX/2.0f;
+ KQ_sum_tg[j][tiisg] = 0.0f;
+ }
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (int j = 0; j < args.ncols; j++) {
+ if (tiisg == 0) {
+ KQ_max_tg[j][sgitg] = KQ_max[j];
+ }
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (int j = 0; j < args.ncols; j++) {
+ if (args.ncols > 1 && ic0 + j >= args.nTokensQ) break;
+
+ float kqmax_new = KQ_max_tg[j][tiisg];
+ kqmax_new = simd_max(kqmax_new);
+ const float kqmax_scale = exp(KQ_max[j] - kqmax_new);
+ KQ_max[j] = kqmax_new;
+
+ for (int i = 0; i < 16; i++) {
+ VKQ[j][i].x *= kqmax_scale;
+ VKQ[j][i].y *= kqmax_scale;
+ }
+
+ const int v_tid = (int)tiisg % nthreads_V;
+ threadgroup float2 * VKQ_tg = (threadgroup float2 *)KQ_tg
+ + (long)sgitg * (V_cols_per_iter * D/2)
+ + (long)((int)tiisg / nthreads_V) * (D/2);
+ VKQ_tg[v_tid * 4 + 0] = VKQ[j][0];
+ VKQ_tg[v_tid * 4 + 1] = VKQ[j][1];
+ VKQ_tg[v_tid * 4 + 2] = VKQ[j][2];
+ VKQ_tg[v_tid * 4 + 3] = VKQ[j][3];
+ VKQ_tg[32 + v_tid * 4 + 0] = VKQ[j][4];
+ VKQ_tg[32 + v_tid * 4 + 1] = VKQ[j][5];
+ VKQ_tg[32 + v_tid * 4 + 2] = VKQ[j][6];
+ VKQ_tg[32 + v_tid * 4 + 3] = VKQ[j][7];
+ VKQ_tg[64 + v_tid * 4 + 0] = VKQ[j][8];
+ VKQ_tg[64 + v_tid * 4 + 1] = VKQ[j][9];
+ VKQ_tg[64 + v_tid * 4 + 2] = VKQ[j][10];
+ VKQ_tg[64 + v_tid * 4 + 3] = VKQ[j][11];
+ VKQ_tg[96 + v_tid * 4 + 0] = VKQ[j][12];
+ VKQ_tg[96 + v_tid * 4 + 1] = VKQ[j][13];
+ VKQ_tg[96 + v_tid * 4 + 2] = VKQ[j][14];
+ VKQ_tg[96 + v_tid * 4 + 3] = VKQ[j][15];
+
+ KQ_sum[j] *= kqmax_scale;
+ KQ_sum[j] = simd_sum(KQ_sum[j]);
+ if (tiisg == 0) {
+ KQ_sum_tg[j][sgitg] = KQ_sum[j];
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ KQ_sum[j] = KQ_sum_tg[j][tiisg];
+ KQ_sum[j] = simd_sum(KQ_sum[j]);
+
+ const long out_idx = ((long)sequence * args.nTokensQ + ic0 + j) * args.nHeadsQ + head;
+ for (int out_offset = 0; out_offset < D; out_offset += nthreads) {
+ const int out_elem = out_offset + tid;
+ float dst_val = 0.0f;
+ for (int w = 0; w < nwarps; w++) {
+ for (int v = 0; v < V_cols_per_iter; v++) {
+ dst_val += ((threadgroup float *)KQ_tg)[w * V_cols_per_iter * D + v * D + out_elem];
+ }
+ }
+ dst_val /= KQ_sum[j];
+ dst[out_idx * D + out_elem] = dst_val;
+ }
+
+ if (j < args.ncols - 1) threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+}

View file

@ -0,0 +1,135 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Verrilli <msv@pobox.com>
Date: Tue, 21 Apr 2026 21:02:22 +0000
Subject: [PATCH] ml/backend/ggml: optimize the Metal TurboQuant dequant kernel
Two independent improvements to kernel_tq_dequant.
Dispatch: the kernel uses only [[thread_index_in_simdgroup]] (tiisg,
0..31 per SIMDgroup), has no sgitg stride, no threadgroup barriers, and
no atomics. It was nonetheless dispatched with 128-thread threadgroups
(four SIMDgroups x 32), so all four SIMDgroups ran the outer loop
identically and wrote the same f16 bytes four times. Drop non-outlier
dispatches to 32-thread threadgroups. The outlier kernel still
dispatches at 128 - it uses s_mask atomics and a popcount reduction
that legitimately need the full threadgroup.
Inner loop: replace the 1-element-per-iteration scalar path with a
4-elements-per-iteration vectorised path that issues a single half4
store per iteration. For bits=2 the 4 elements fit in one byte; for
bits=3 they fit in a 16-bit window (shift0 in {0,4}). The per-cell
scale is pre-multiplied into the codebook lane at kernel entry so the
decode path drops one fmul per element. The scalar fallback is
preserved for head dims that aren't a multiple of 128.
Decode throughput on llama3.2:3b tq3 at 32k context improves ~9% on
Apple Silicon (~42 -> ~46 tok/s); the K-only DequantK path used by
tq2k/tq3k benefits from the same kernel.
---
ggml/src/ggml-metal/ggml-metal-ops.cpp | 15 ++++++++---
ggml/src/ggml-metal/ggml-metal.metal | 37 +++++++++++++++++++++++---
2 files changed, 45 insertions(+), 7 deletions(-)
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index b5ab1c14e..ea3580b0b 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -4197,7 +4197,12 @@ int ggml_metal_op_tq_dequant(ggml_metal_op_t ctx, int idx) {
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
- const int block_size = std::min(128, headDim);
+ // Outlier kernel uses a 128-thread TG: threadgroup barriers + atomics on
+ // s_mask require all threads. Non-outlier kernel uses a single simdgroup
+ // (32 threads): it only reads tiisg and has no barriers, so a larger TG
+ // just replicates work across idle simdgroups.
+ const int outlier_block_size = 128;
+ const int nonoutlier_block_size = 32;
if (outlierCount > 0 && outlierBits > 0 && outlierCount < headDim) {
const int regular_count = headDim - outlierCount;
@@ -4234,7 +4239,7 @@ int ggml_metal_op_tq_dequant(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_buffer (enc, bid_dst, 8);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
- ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, block_size, 1, 1);
+ ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, outlier_block_size, 1, 1);
return 1;
}
@@ -4259,7 +4264,7 @@ int ggml_metal_op_tq_dequant(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); // codebook
ggml_metal_encoder_set_buffer (enc, bid_dst, 4);
- ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, block_size, 1, 1);
+ ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, nonoutlier_block_size, 1, 1);
return 1;
}
@@ -4283,7 +4288,9 @@ int ggml_metal_op_tq_dequant_kv(ggml_metal_op_t ctx, int idx) {
const int k_codebook_len = (int)op->src[2]->ne[0];
const int v_codebook_len = (int)op->src[5]->ne[0];
- const int block_size = std::min(128, headDim);
+ // kernel_tq_dequant is single-simdgroup (uses only tiisg, no barriers,
+ // no atomics) — 32-thread TGs eliminate 4× redundant work vs 128-thread.
+ const int block_size = 32;
const size_t plane_size = (size_t)headDim * numKVHeads * nCells;
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 2718e8bb1..b61f755aa 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -10246,8 +10246,40 @@ kernel void kernel_tq_dequant(
const int cb_mask = (1 << args.bits) - 1;
// Load one codebook entry per lane; period ≤ 8 divides 32, so simd_shuffle is exact.
- const float cb_lane = codebook[tiisg & cb_mask];
+ // Pre-multiply by the per-cell scale so the decode path drops 1 fmul per element.
+ const float scaled_cb_lane = codebook[tiisg & cb_mask] * scale;
+
+ // Fast path: when headDim is a multiple of 128 (=32 lanes × 4 elements per thread),
+ // each thread decodes 4 consecutive D-positions per iter and writes a single half4.
+ // bits=2: 4 elems = 8 bits, always byte-aligned (shift0=0 since elem_base mod 4 == 0).
+ // bits=3: 4 elems = 12 bits, shift0 ∈ {0,4}, always fits in a 16-bit window.
+ // A 16-bit window (2 packed bytes) suffices for both. The scalar fallback
+ // covers non-multiple-of-128 head dims.
+ if ((args.headDim & 127) == 0) {
+ const int iters = args.headDim >> 7;
+ for (int iter = 0; iter < iters; iter++) {
+ const int elem_base = iter * 128 + (int)tiisg * 4;
+ const int bit_offset = elem_base * args.bits;
+ const int byte_base = bit_offset >> 3;
+ const int shift0 = bit_offset & 7;
+
+ uint w = (uint)cell_packed[byte_base];
+ if (args.bits == 3) {
+ w |= ((uint)cell_packed[byte_base + 1] << 8);
+ }
+
+ half4 v4;
+ v4[0] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> shift0 ) & cb_mask)));
+ v4[1] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> (shift0 + args.bits)) & cb_mask)));
+ v4[2] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> (shift0 + 2 * args.bits)) & cb_mask)));
+ v4[3] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> (shift0 + 3 * args.bits)) & cb_mask)));
+
+ *((device half4 *)(cell_out + elem_base)) = v4;
+ }
+ return;
+ }
+ // Scalar fallback for head dims that aren't a multiple of 128.
for (uint elem = tiisg; elem < (uint)args.headDim; elem += 32) {
const int bit_offset = (int)elem * args.bits;
const int byte_idx = bit_offset >> 3;
@@ -10256,8 +10288,7 @@ kernel void kernel_tq_dequant(
if (shift + args.bits > 8) {
idx |= ((int)(cell_packed[byte_idx + 1] << (8 - shift))) & cb_mask;
}
- const float val = simd_shuffle(cb_lane, (ushort)idx) * scale;
- cell_out[elem] = half(val);
+ cell_out[elem] = half(simd_shuffle(scaled_cb_lane, (ushort)idx));
}
}

View file

@ -258,16 +258,20 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
if fa {
slog.Info("enabling flash attention")
loadRequest.FlashAttention = ml.FlashAttentionEnabled
}
// Flash Attention also supports kv cache quantization
// Enable if the requested and kv cache type is supported by the model
if f.SupportsKVCacheType(kvct) {
loadRequest.KvCacheType = kvct
} else {
// Most quantized KV cache types require flash attention, but TurboQuant
// K-only presets (tq2k/tq3k) dequant to f16 before attention and work
// with either FA or the standard softmax+matmul attention path.
if kvct != "" {
switch {
case f.KVCacheTypeRequiresFlashAttention(kvct) && !fa:
slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
case !f.SupportsKVCacheType(kvct):
slog.Warn("kv cache type not supported by model", "type", kvct)
default:
loadRequest.KvCacheType = kvct
}
} else if kvct != "" && kvct != "f16" {
slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
}
}

View file

@ -365,7 +365,7 @@ func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string)
prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
sb.WriteString("[")
defer func() { sb.WriteString("]") }()
for i := 0; i < dims[0]; i++ {
for i := range dims[0] {
if i >= items && i < dims[0]-items {
sb.WriteString("..., ")
// skip to next printable element
@ -408,9 +408,82 @@ const (
DTypeQ80
DTypeQ40
DTypeI32
DTypeI8
DTypeMXFP4
DTypeTQ2
DTypeTQ3
DTypeTQ3K
DTypeTQ2K
)
// TQCompressedKManager manages GPU-resident packed N-bit key indices for
// TurboQuant VRAM compression. Implemented by the ggml backend.
type TQCompressedKManager interface {
// EnsureLayer allocates per-layer GPU tensors (packed + scales) on first use.
// capacity = total cache cell count for this layer.
EnsureLayer(layer, capacity int)
// EncodeK creates a GGML_OP_TQ_ENCODE graph node that encodes key vectors
// into the persistent compressed buffer. Returns a view of the packed buffer;
// pass this result to DequantK to establish the graph ordering dependency.
// key: [headDim, numKVHeads, batchSize] f16
// firstCell: first cache slot index; cells are sequential (firstCell+0, +1, ...)
EncodeK(ctx Context, layer int, key Tensor, firstCell int) Tensor
// DequantK creates a GGML_OP_TQ_DEQUANT graph node returning
// [headDim, numKVHeads, nCells] f16 ready for flash attention.
// encodeResult is the tensor returned by EncodeK for this layer+step,
// establishing encode→dequant ordering in the ggml graph.
DequantK(ctx Context, layer int, encodeResult Tensor, firstCell, nCells int) Tensor
// GetAsTQTensor returns a tqTensor wrapping the packed K buffer for the
// fused TQ flash-attention path. Returns (nil, false) when the config is
// not supported by the fused kernel — callers must fall back to DequantK in that case.
GetAsTQTensor(ctx Context, layer int, encodeResult Tensor, firstCell, nCells int) (Tensor, bool)
// GetAsTQTensorKV returns a tqTensor wrapping both packed K and packed V
// buffers for the fully fused K+V TQ flash-attention path. Returns (nil, false)
// when fused is not supported or V compression is not active.
GetAsTQTensorKV(ctx Context, layer int, kEncodeResult, vEncodeResult Tensor, firstCell, nCells int) (Tensor, bool)
// RotationMatrix returns the persistent [headDim, headDim] f32 R^T tensor.
RotationMatrix(ctx Context, layer int) Tensor
// EnsureVLayer allocates per-layer V packed and scales tensors on first use.
EnsureVLayer(layer, capacity int)
// EncodeV creates a GGML_OP_TQ_ENCODE_V graph node encoding value vectors
// into the persistent compressed V buffer. Returns a view of the V packed buffer.
EncodeV(ctx Context, layer int, value Tensor, firstCell int) Tensor
// DequantV creates a GGML_OP_TQ_DEQUANT graph node returning
// [headDim, numKVHeads, nCells] f16 for V, ready for flash attention.
DequantV(ctx Context, layer int, encodeResult Tensor, firstCell, nCells int) Tensor
// DequantKV creates a single GGML_OP_TQ_DEQUANT_KV graph node that dequants
// both K and V in one op, halving GGML scheduler overhead.
// Returns (key, value) as [headDim, numKVHeads, nCells] f16 views.
DequantKV(ctx Context, layer int, kEncodeResult, vEncodeResult Tensor, firstCell, nCells int) (Tensor, Tensor)
// EncodeKV creates a single GGML_OP_TQ_ENCODE_KV graph node encoding both
// K and V, halving scheduler overhead vs separate EncodeK + EncodeV.
EncodeKV(ctx Context, layer int, key, value Tensor, firstCell int) (Tensor, Tensor)
// Close frees all GPU buffers.
Close()
}
// TQCompressedKBackend is implemented by backends that support TQ compressed K.
//
// outlierBits/outlierCount: optional post-rotation outlier split. When
// outlierCount > 0, the manager allocates additional per-layer tensors for an
// outlier sub-block at the specified bit width and the encode/dequant kernels
// route through the outlier-aware path. Set outlierCount = 0 for pure uniform
// per-channel Lloyd-Max at `bits` (historical behavior).
type TQCompressedKBackend interface {
NewTQCompressedKManager(headDim, numKVHeads, bits int, rotationSeed uint64, vBits, outlierBits, outlierCount int) TQCompressedKManager
}
type SamplingMode int
const (

View file

@ -111,6 +111,30 @@ type Backend struct {
flashAttention ml.FlashAttentionType
// tqRotationMatrix is a per-call flag set by TurboQuantCache.Get() right
// before it returns rotated K. SDPA reads and clears it; if non-nil, SDPA
// applies R^T @ Q to match the K rotation. This per-call (not sticky)
// semantics is required for mixed-head-dim models like gemma3, where only
// the global sub-cache of a WrapperCache is TQ-wrapped and the SWA sub-
// cache passes through plain f16 K. A sticky rotation would be applied to
// every SDPA call and corrupt attention on the unwrapped SWA layers.
tqRotationMatrix ml.Tensor
// tqVRotationMatrix is a per-call flag set by TurboQuantCache.Get() right
// before returning K+V (tq3/tq2) when V was encoded with Hadamard rotation
// (R^T @ v). SDPA reads and clears it; if non-nil, SDPA applies R @
// attn_out after the attention op to undo the V rotation. Per-call (not
// sticky) semantics is required for mixed-head-dim models like gemma3,
// where only the global sub-cache of a WrapperCache is TQ-wrapped and the
// SWA sub-cache's V is plain f16. A sticky rotation would corrupt
// attention on unwrapped SWA layers.
tqVRotationMatrix ml.Tensor
// tqVRotFusedInDequant is true when DequantKV applies the V rotation undo
// internally (via the rotated V kernel). When true, the stock FA path in
// SDPA skips the mulmat undo — V is already in the unrotated domain.
tqVRotFusedInDequant bool
// maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
maxGraphNodes int
@ -664,6 +688,97 @@ func (b *Backend) NewContext() ml.Context {
return b.NewContextSize(b.maxGraphNodes)
}
// TQDeviceScan describes the GPU devices discovered in the scheduler from the
// perspective of TurboQuant: which one TQ will use, plus the names of any GPUs
// that were skipped because they're not wave32-capable (NVIDIA < Pascal, or
// AMD wave64 Vega/GCN/CDNA). Used to emit actionable warnings and to avoid
// dispatching TQ kernels to an unsupported card — either one that would hit
// the compute-capability assert in tq-dequant.cu or one whose HIP __shfl_sync
// shim would silently produce garbage on 64-lane warps.
type TQDeviceScan struct {
// selected is the buffer type TQ will place its tensors on. Zero-valued
// if no TQ-capable GPU is present.
selected C.ggml_backend_buffer_type_t
selectedOK bool
SelectedName string // e.g. "NVIDIA Tesla P40"
SelectedCC string // e.g. "6.1"
SelectedLibrary string // e.g. "Metal", "CUDA", "ROCm"
// Accepted lists "<name> (cc X.Y)" for every TQ-capable GPU in schedBufts.
Accepted []string
// Skipped lists "<name> (cc X.Y, <library>): <reason>" for every non-host GPU
// in schedBufts that fails the wave32 gate (CUDA < Pascal, ROCm wave64, or a
// non-CUDA/ROCm backend). The reason is included so operators can diagnose
// without reading the source tree.
Skipped []string
}
// scanTQDevices walks the scheduler buffer types and classifies each GPU via
// tqDeviceAccepted: accepted GPUs (NVIDIA Pascal+, AMD RDNA1+) are eligible to
// host TQ tensors; others are skipped with a diagnosable reason. The first
// accepted buffer type is marked as selected; TQ tensors will be placed there
// regardless of which scheduler index it occupies.
func (b *Backend) scanTQDevices() TQDeviceScan {
var scan TQDeviceScan
for _, buft := range b.schedBufts {
if C.ggml_backend_buft_is_host(buft) {
continue
}
dev := C.ggml_backend_buft_get_device(buft)
if dev == nil {
continue
}
var props C.struct_ggml_backend_dev_props
C.ggml_backend_dev_get_props(dev, &props)
name := C.GoString(props.name)
var library string
if props.library != nil {
library = C.GoString(props.library)
}
cc := fmt.Sprintf("%d.%d", int(props.compute_major), int(props.compute_minor))
accepted, skipReason := tqDeviceAccepted(library, int(props.compute_major))
if !accepted {
scan.Skipped = append(scan.Skipped,
fmt.Sprintf("%s (cc %s, %s): %s", name, cc, library, skipReason))
continue
}
scan.Accepted = append(scan.Accepted, fmt.Sprintf("%s (cc %s)", name, cc))
if !scan.selectedOK {
scan.selected = buft
scan.selectedOK = true
scan.SelectedName = name
scan.SelectedCC = cc
scan.SelectedLibrary = library
}
}
return scan
}
// newTQContext creates a GGML context whose tensors are allocated in GPU
// memory (CUDA, HIP, or Metal). Used by the TQ compressed KV cache manager:
// TQ encode/decode ops require their tensors (packed buffers, scales,
// codebook, rotation matrix) to reside on the GPU regardless of which model
// layers are on CPU vs GPU. TQ tensors always land on the first TQ-capable
// GPU — NVIDIA Pascal (cc 6.0)+, AMD RDNA1 (gfx1010)+, or Apple Silicon
// (Metal, always wave32) — in the scheduler. In a mixed rig, unsupported
// cards are skipped: older NVIDIA would hit the compute-capability assert in
// tq-dequant.cu, and wave64 AMD (Vega/CDNA) would silently corrupt through
// the HIP __shfl_sync shim.
func (b *Backend) newTQContext(n int) *Context {
var allocatedBuffers []C.ggml_backend_buffer_t
scan := b.scanTQDevices()
return &Context{
b: b,
ctx: C.ggml_init(C.struct_ggml_init_params{
mem_size: C.size_t(n)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(n), false),
no_alloc: true,
}),
buft: scan.selected,
allocatedBuffers: &allocatedBuffers,
maxGraphNodes: n,
layer: -1,
}
}
func (b *Backend) NewContextSize(n int) ml.Context {
if n > b.maxGraphNodes {
panic(fmt.Errorf("requested number of graph nodes (%v) for new context exceeds maximum (%v)", n, b.maxGraphNodes))
@ -683,6 +798,23 @@ func (b *Backend) NewContextSize(n int) ml.Context {
}
}
// SetTQRotationMatrix registers the TQ rotation matrix for Q rotation in SDPA.
// Called by TurboQuantCache when Phase 2 CUDA dequant activates.
func (b *Backend) SetTQRotationMatrix(m ml.Tensor) {
b.tqRotationMatrix = m
}
// SetTQVRotationMatrix registers the rotation matrix used for V encoding.
// When non-nil, SDPA applies R @ attn_out after the TQ fused flash attention
// to undo the V rotation (V was stored as R^T @ v).
func (b *Backend) SetTQVRotationMatrix(m ml.Tensor) {
b.tqVRotationMatrix = m
}
func (b *Backend) SetTQVRotFusedInDequant(fused bool) {
b.tqVRotFusedInDequant = fused
}
func (b *Backend) CacheConfig() ml.CacheConfig {
if b.flashAttention == ml.FlashAttentionEnabled {
return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16}
@ -1118,6 +1250,8 @@ func (t *Tensor) DType() ml.DType {
return ml.DTypeQ40
case C.GGML_TYPE_I32:
return ml.DTypeI32
case C.GGML_TYPE_I8:
return ml.DTypeI8
case C.GGML_TYPE_MXFP4:
return ml.DTypeMXFP4
default:
@ -1137,6 +1271,8 @@ func ggmlDType(dtype ml.DType) uint32 {
return C.GGML_TYPE_Q4_0
case ml.DTypeI32:
return C.GGML_TYPE_I32
case ml.DTypeI8:
return C.GGML_TYPE_I8
case ml.DTypeMXFP4:
return C.GGML_TYPE_MXFP4
default:
@ -1367,6 +1503,191 @@ func (t *Tensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tenso
}
}
// TQEncode creates a GGML_OP_TQ_ENCODE graph node.
// t = packed buffer [packedBytes*numKVHeads, capacity] i8 (dst, view returned)
// scales = [numKVHeads, capacity] f32 (written as side output via src[3])
// k = [headDim, numKVHeads, batchSize] f16
// rot = [headDim, headDim] f32 (R^T row-major)
// cidx = [batchSize] i32
// bounds = [(1<<bits)-1] f32
func (t *Tensor) TQEncode(ctx ml.Context, scales, k, rot ml.Tensor, firstCell int, bounds ml.Tensor, bits int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_tq_encode(
ctx.(*Context).ctx,
t.t,
scales.(*Tensor).t,
k.(*Tensor).t,
rot.(*Tensor).t,
C.int32_t(firstCell),
bounds.(*Tensor).t,
C.int32_t(bits),
),
}
}
// TQEncodeV creates a GGML_OP_TQ_ENCODE_V graph node.
// t = packed buffer [packedBytes*numKVHeads, capacity] i8 (dst, view returned)
// scales = [numKVHeads, capacity] f32 (written as side output via src[3])
// v = [headDim, numKVHeads, batchSize] f16 or f32
// rot = [headDim, headDim] f32 R^T row-major, or nil (no rotation)
// bounds = [(1<<bits)-1] f32
func (t *Tensor) TQEncodeV(ctx ml.Context, scales, v ml.Tensor, rot ml.Tensor, firstCell int, bounds ml.Tensor, bits int) ml.Tensor {
var rotT *C.struct_ggml_tensor
if rot != nil {
rotT = rot.(*Tensor).t
}
return &Tensor{
b: t.b,
t: C.ggml_tq_encode_v(
ctx.(*Context).ctx,
t.t,
scales.(*Tensor).t,
v.(*Tensor).t,
rotT,
C.int32_t(firstCell),
bounds.(*Tensor).t,
C.int32_t(bits),
),
}
}
// TQEncodeKV creates a GGML_OP_TQ_ENCODE_KV graph node encoding both K and V
// in a single GGML op. t = K packed buffer (view returned). V packed buffer
// is written as a side effect via src[5].
func (t *Tensor) TQEncodeKV(ctx ml.Context,
kScales, k, rot, kBounds ml.Tensor,
vPacked, vScales, v, vBounds ml.Tensor,
firstCell, kBits, vBits int,
) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_tq_encode_kv(
ctx.(*Context).ctx,
t.t,
kScales.(*Tensor).t,
k.(*Tensor).t,
rot.(*Tensor).t,
kBounds.(*Tensor).t,
vPacked.(*Tensor).t,
vScales.(*Tensor).t,
v.(*Tensor).t,
vBounds.(*Tensor).t,
C.int32_t(firstCell),
C.int32_t(kBits),
C.int32_t(vBits),
),
}
}
// TQDequant creates a GGML_OP_TQ_DEQUANT graph node.
// t = encode result (view of packed, establishes graph dependency)
// scales = [numKVHeads, capacity] f32
// codebook = [1<<bits] f32
// Returns [headDim, numKVHeads, nCells] f16.
func (t *Tensor) TQDequant(ctx ml.Context, scales, codebook ml.Tensor, headDim, numKVHeads, nCells, firstCell, bits int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_tq_dequant(
ctx.(*Context).ctx,
t.t,
scales.(*Tensor).t,
codebook.(*Tensor).t,
C.int(headDim),
C.int(numKVHeads),
C.int(nCells),
C.int(firstCell),
C.int(bits),
),
}
}
// TQEncodeOutlier creates a GGML_OP_TQ_ENCODE graph node with the outlier
// sub-block extension. op_params[3] (outlier_count) > 0 signals to the CUDA
// dispatcher that the outlier kernel should run.
func (t *Tensor) TQEncodeOutlier(ctx ml.Context, scales, k, rot ml.Tensor, firstCell int, bounds ml.Tensor, bits int,
outlierPacked, outlierScales, outlierIndices, outlierBounds ml.Tensor, outlierBits, outlierCount int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_tq_encode_outlier(
ctx.(*Context).ctx,
t.t,
scales.(*Tensor).t,
k.(*Tensor).t,
rot.(*Tensor).t,
C.int32_t(firstCell),
bounds.(*Tensor).t,
C.int32_t(bits),
outlierPacked.(*Tensor).t,
outlierScales.(*Tensor).t,
outlierIndices.(*Tensor).t,
outlierBounds.(*Tensor).t,
C.int32_t(outlierBits),
C.int32_t(outlierCount),
),
}
}
// TQDequantOutlier creates a GGML_OP_TQ_DEQUANT graph node with the outlier
// overwrite pass. op_params[3] (outlier_count) > 0 signals the dispatcher.
func (t *Tensor) TQDequantOutlier(ctx ml.Context, scales, codebook ml.Tensor, headDim, numKVHeads, nCells, firstCell, bits int,
outlierPacked, outlierScales, outlierIndices, outlierCodebook ml.Tensor, outlierBits, outlierCount int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_tq_dequant_outlier(
ctx.(*Context).ctx,
t.t,
scales.(*Tensor).t,
codebook.(*Tensor).t,
C.int(headDim),
C.int(numKVHeads),
C.int(nCells),
C.int(firstCell),
C.int(bits),
outlierPacked.(*Tensor).t,
outlierScales.(*Tensor).t,
outlierIndices.(*Tensor).t,
outlierCodebook.(*Tensor).t,
C.int32_t(outlierBits),
C.int32_t(outlierCount),
),
}
}
// TQDequantKV creates a GGML_OP_TQ_DEQUANT_KV graph node that dequants both
// K and V in a single GGML op. Returns a [headDim, numKVHeads, nCells, 2] f16
// tensor; the caller splits it into K (ne[3]=0) and V (ne[3]=1) views.
func TQDequantKV(ctx ml.Context, b *Backend,
kEncode, kScales, kCodebook *Tensor,
vEncode, vScales, vCodebook *Tensor,
vRotation *Tensor,
headDim, numKVHeads, nCells, firstCell, kBits, vBits int,
) *Tensor {
var vRotT *C.struct_ggml_tensor
if vRotation != nil {
vRotT = vRotation.t
}
return &Tensor{
b: b,
t: C.ggml_tq_dequant_kv(
ctx.(*Context).ctx,
kEncode.t,
kScales.t,
kCodebook.t,
vEncode.t,
vScales.t,
vCodebook.t,
vRotT,
C.int(headDim),
C.int(numKVHeads),
C.int(nCells),
C.int(firstCell),
C.int(kBits),
C.int(vBits),
),
}
}
func (t *Tensor) SetInplace(ctx ml.Context, src ml.Tensor, nb1, nb2, nb3, offset int) ml.Tensor {
return &Tensor{
b: t.b,
@ -1744,9 +2065,60 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
}
query := t.Permute(ctx, 0, 2, 1, 3)
// TQ consume-once state: TurboQuantCache.Get() sets tqRotationMatrix and
// tqVRotationMatrix right before returning rotated K/V for a TQ-wrapped
// layer. Read and clear here so the NEXT SDPA call (potentially for a
// non-TQ sub-cache of a WrapperCache, e.g. the SWA side of gemma3) doesn't
// see a stale rotation and corrupt attention on unrotated tensors. The
// nil pre-check keeps non-TQ workloads from dirtying the backend cache
// line on every attention op.
var rot, vRot ml.Tensor
if t.b.tqRotationMatrix != nil || t.b.tqVRotationMatrix != nil {
rot = t.b.tqRotationMatrix
vRot = t.b.tqVRotationMatrix
t.b.tqRotationMatrix = nil
t.b.tqVRotationMatrix = nil
}
// TQ: K is stored in rotated space (R^T @ k). Rotate Q to match so
// attention = (R^T q)^T (R^T k) = q^T k.
// rotTensor stores R^T row-major; ggml_mul_mat(rotTensor, Q) = R^T @ Q.
if rot != nil && query.Dim(0) == rot.Dim(0) {
// Make query contiguous before mul_mat; permuted (non-contiguous) tensors
// may cause incorrect results with cuBLAS batched matmul.
query = query.Contiguous(ctx)
query = rot.Mulmat(ctx, query)
}
key = key.Permute(ctx, 0, 2, 1, 3)
if t.b.flashAttention == ml.FlashAttentionEnabled {
// TQ fused flash attention: check for tqTensor BEFORE permuting value,
// because the K+V fused path passes packed V directly (no permute needed).
if tqk, ok := key.(*tqTensor); ok {
if sinks != nil || vmla != nil {
panic("ggml: TQ compressed K does not support sinks or vmla attention")
}
var attnOut ml.Tensor
if tqk.vPacked != nil {
// K+V fused: V is packed i8 inside tqTensor; pass it directly.
attnOut = t.b.tqFlashAttention(ctx, query.(*Tensor), tqk, tqk.vPacked, mask, scale, 0)
} else {
// K-only fused: V is f16, permute normally.
value = value.Permute(ctx, 0, 2, 1, 3)
attnOut = t.b.tqFlashAttention(ctx, query.(*Tensor), tqk, value.(*Tensor), mask, scale, 0)
}
// If V was encoded with Hadamard rotation (R^T @ v), the FA output is
// R^T @ output_orig. Recover output_orig by applying R.
// tqVRotationMatrix stores R (not R^T); mul_mat(R, x) = R @ x = output_orig.
// Uses the consumed vRot local from the top of SDPA.
if vRot != nil && attnOut.Dim(0) == vRot.Dim(0) {
attnOut = vRot.(*Tensor).Mulmat(ctx, attnOut)
}
return attnOut
}
value = value.Permute(ctx, 0, 2, 1, 3)
kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
@ -1764,7 +2136,17 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
kqv = cur.(*Tensor).t
}
return &Tensor{b: t.b, t: kqv}
attnOut := ml.Tensor(&Tensor{b: t.b, t: kqv})
// Two-pass TQ path: if DequantKV fused the V rotation undo into the
// dequant kernel, V is already unrotated — skip the mulmat.
// Otherwise (no rotation fusion), apply R @ attn_out to undo rotation.
// Uses the consumed vRot local from the top of SDPA.
if vRot != nil && !t.b.tqVRotFusedInDequant && attnOut.Dim(0) == vRot.Dim(0) {
attnOut = vRot.(*Tensor).Mulmat(ctx, attnOut)
}
return attnOut
} else {
kq := key.MulmatFullPrec(ctx, query)
kq = &Tensor{

View file

@ -567,6 +567,13 @@ extern "C" {
GGML_OP_GLU,
GGML_OP_TQ_ENCODE,
GGML_OP_TQ_DEQUANT,
GGML_OP_TQ_DEQUANT_KV,
GGML_OP_TQ_FLASH_ATTN_EXT,
GGML_OP_TQ_ENCODE_V,
GGML_OP_TQ_ENCODE_KV,
GGML_OP_COUNT,
};
@ -2714,6 +2721,154 @@ extern "C" {
GGML_API void ggml_threadpool_params_init (struct ggml_threadpool_params * p, int n_threads);
GGML_API bool ggml_threadpool_params_match (const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1);
// TurboQuant GPU-native key encoding: f16 K + rotation → packed N-bit + scales.
// packed: [packedBytes*numKVHeads, capacity] GGML_TYPE_I8
// scales: [numKVHeads, capacity] f32 (written as side output via src[3])
// k: [headDim, numKVHeads, batchSize] f16
// rotation: [headDim, headDim] f32 (R^T row-major)
// cell_idx: [batchSize] i32
// boundaries: [(1<<bits)-1] f32
GGML_API struct ggml_tensor * ggml_tq_encode(
struct ggml_context * ctx,
struct ggml_tensor * packed,
struct ggml_tensor * scales,
struct ggml_tensor * k,
struct ggml_tensor * rotation,
int32_t firstCell,
struct ggml_tensor * boundaries,
int32_t bits);
// TurboQuant GPU-native key dequant: packed N-bit → f16 [headDim, numKVHeads, nCells].
// encode_result: view of packed buffer returned by ggml_tq_encode (establishes graph
// dependency: encode must run before dequant)
// scales: [numKVHeads, capacity] f32
// codebook: [1<<bits] f32
GGML_API struct ggml_tensor * ggml_tq_dequant(
struct ggml_context * ctx,
struct ggml_tensor * encode_result,
struct ggml_tensor * scales,
struct ggml_tensor * codebook,
int headDim,
int numKVHeads,
int nCells,
int firstCell,
int bits);
// TurboQuant GPU-native key encode with outlier split: same op family as
// ggml_tq_encode but adds top-K outlier sub-block. After rotation, the top
// outlier_count channels by absolute magnitude go into a higher-bit sub-block;
// the remaining channels go into a lower-bit regular sub-block. Matches the
// TurboQuant paper (arXiv 2504.19874, Sec 4.3) experimental setup. outlier_count
// is stored in op_params[3] so the CUDA backend dispatcher can branch on it.
// packed, scales, outlier_packed, outlier_scales, outlier_indices are written
// as side effects (persistent GPU buffers, not graph intermediates).
GGML_API struct ggml_tensor * ggml_tq_encode_outlier(
struct ggml_context * ctx,
struct ggml_tensor * packed,
struct ggml_tensor * scales,
struct ggml_tensor * k,
struct ggml_tensor * rotation,
int32_t firstCell,
struct ggml_tensor * boundaries,
int32_t bits,
struct ggml_tensor * outlier_packed,
struct ggml_tensor * outlier_scales,
struct ggml_tensor * outlier_indices,
struct ggml_tensor * outlier_boundaries,
int32_t outlier_bits,
int32_t outlier_count);
// TurboQuant GPU-native key dequant with outlier overwrite pass: reconstructs
// [headDim, numKVHeads, nCells] f16 by decoding the regular sub-block for all
// positions and then overwriting the outlier channels from the outlier
// sub-block. Paired with ggml_tq_encode_outlier.
GGML_API struct ggml_tensor * ggml_tq_dequant_outlier(
struct ggml_context * ctx,
struct ggml_tensor * encode_result,
struct ggml_tensor * scales,
struct ggml_tensor * codebook,
int headDim,
int numKVHeads,
int nCells,
int firstCell,
int bits,
struct ggml_tensor * outlier_packed,
struct ggml_tensor * outlier_scales,
struct ggml_tensor * outlier_indices,
struct ggml_tensor * outlier_codebook,
int32_t outlier_bits,
int32_t outlier_count);
// Combined K+V dequant: single op dequants both K and V packed buffers.
// Output: [headDim, numKVHeads, nCells, 2] f16 where ne[3]=0 is K, ne[3]=1 is V.
// Halves scheduler overhead vs separate DequantK + DequantV.
GGML_API struct ggml_tensor * ggml_tq_dequant_kv(
struct ggml_context * ctx,
struct ggml_tensor * k_encode_result,
struct ggml_tensor * k_scales,
struct ggml_tensor * k_codebook,
struct ggml_tensor * v_encode_result,
struct ggml_tensor * v_scales,
struct ggml_tensor * v_codebook,
struct ggml_tensor * v_rotation, // R [headDim, headDim] f32 — undo V rotation during dequant; NULL = no rotation
int headDim,
int numKVHeads,
int nCells,
int firstCell,
int k_bits,
int v_bits);
GGML_API struct ggml_tensor * ggml_tq_flash_attn_ext(
struct ggml_context * ctx,
struct ggml_tensor * q,
struct ggml_tensor * k_packed,
struct ggml_tensor * v, // f16 when v_scales==NULL; packed i8 when v_scales!=NULL
struct ggml_tensor * mask,
struct ggml_tensor * scales, // K scales [nKVHeads, capacity] f32
struct ggml_tensor * codebook, // K codebook [1<<bits] f32
float scale, float logit_softcap,
int32_t bits, int32_t firstCell,
struct ggml_tensor * v_scales, // NULL → V is f16; non-NULL → V is packed i8
struct ggml_tensor * v_codebook, // V codebook (required when v_scales != NULL)
int32_t v_bits); // V bit width (ignored when v_scales == NULL)
// TurboQuant GPU-native value encode: V → RMS scale → Lloyd-Max quantize → packed bits.
// Like ggml_tq_encode but for V. rotation may be NULL (no rotation) or a [headDim, headDim]
// R^T matrix to apply before quantization. Using the same rotation as K spreads outlier
// energy evenly so the Lloyd-Max codebook applies correctly.
// packed: [packedBytes*numKVHeads, capacity] i8 (destination buffer, view returned)
// scales: [numKVHeads, capacity] f32 (scale per head per cell)
// v: [headDim, numKVHeads, batchSize] f16 or f32
// rotation: [headDim, headDim] f32 R^T row-major, or NULL
// boundaries: [(1<<bits)-1] f32
GGML_API struct ggml_tensor * ggml_tq_encode_v(
struct ggml_context * ctx,
struct ggml_tensor * packed,
struct ggml_tensor * scales,
struct ggml_tensor * v,
struct ggml_tensor * rotation,
int32_t firstCell,
struct ggml_tensor * boundaries,
int32_t bits);
// Combined K+V encode: single GGML op encoding both K and V, halving scheduler overhead.
// k_packed: [packedBytes*numKVHeads, capacity] i8 (K destination, view returned)
// v_packed: [packedBytes*numKVHeads, capacity] i8 (V destination, side-effect write)
GGML_API struct ggml_tensor * ggml_tq_encode_kv(
struct ggml_context * ctx,
struct ggml_tensor * k_packed,
struct ggml_tensor * k_scales,
struct ggml_tensor * k,
struct ggml_tensor * rotation,
struct ggml_tensor * k_boundaries,
struct ggml_tensor * v_packed,
struct ggml_tensor * v_scales,
struct ggml_tensor * v,
struct ggml_tensor * v_boundaries,
int32_t firstCell,
int32_t k_bits,
int32_t v_bits);
#ifdef __cplusplus
}
#endif

View file

@ -692,7 +692,12 @@ static bool ggml_is_view_op(enum ggml_op op) {
#endif
#ifndef GGML_SCHED_MAX_SPLIT_INPUTS
#define GGML_SCHED_MAX_SPLIT_INPUTS 30
// Increased from 30 to 128 to support TurboQuant K+V compression on large MoE
// models (e.g. qwen3-coder:30b/48-layer) where per-layer TQ encode ops and
// MoE expert routing create more cross-backend split inputs than the original
// limit allows. Upstream GGML has a FIXME here: the check only fires when the
// split is exactly full, so multi-input ops can overshoot the limit.
#define GGML_SCHED_MAX_SPLIT_INPUTS 128
#endif
#ifndef GGML_SCHED_MAX_COPIES

View file

@ -2080,6 +2080,20 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
// nop
} break;
case GGML_OP_TQ_ENCODE:
case GGML_OP_TQ_DEQUANT:
case GGML_OP_TQ_DEQUANT_KV:
case GGML_OP_TQ_FLASH_ATTN_EXT:
case GGML_OP_TQ_ENCODE_V:
case GGML_OP_TQ_ENCODE_KV:
{
// CUDA-only ops. If these reach CPU, it means the scheduler
// incorrectly assigned them to CPU — abort to diagnose.
fprintf(stderr, "[FATAL] TQ op %s reached CPU backend! This is a scheduler bug.\n",
ggml_op_name(tensor->op));
fflush(stderr);
GGML_ABORT("TQ op reached CPU backend");
} break;
case GGML_OP_COUNT:
{
GGML_ABORT("fatal error");
@ -2401,6 +2415,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
{
n_tasks = 1;
} break;
case GGML_OP_TQ_ENCODE:
case GGML_OP_TQ_DEQUANT:
case GGML_OP_TQ_DEQUANT_KV:
case GGML_OP_TQ_FLASH_ATTN_EXT:
case GGML_OP_TQ_ENCODE_V:
case GGML_OP_TQ_ENCODE_KV:
{
n_tasks = 1; // CUDA-only; handled as no-op in compute_forward
} break;
case GGML_OP_COUNT:
{
GGML_ABORT("fatal error");

View file

@ -426,7 +426,16 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
}
}
// TQ ops are CUDA-only; CPU has no implementation.
// Returning false here prevents the scheduler from routing them to CPU in pass 3.
switch (op->op) {
case GGML_OP_TQ_ENCODE:
case GGML_OP_TQ_DEQUANT:
case GGML_OP_TQ_DEQUANT_KV:
case GGML_OP_TQ_FLASH_ATTN_EXT:
case GGML_OP_TQ_ENCODE_V:
case GGML_OP_TQ_ENCODE_KV:
return false;
case GGML_OP_CPY:
case GGML_OP_SET_ROWS:
return

View file

@ -58,6 +58,10 @@
#include "ggml-cuda/tri.cuh"
#include "ggml-cuda/cumsum.cuh"
#include "ggml-cuda/fill.cuh"
#include "ggml-cuda/tq-encode.cuh"
#include "ggml-cuda/tq-dequant.cuh"
#include "ggml-cuda/tq-fattn.cuh"
#include "ggml-cuda/tq-encode-v.cuh"
#include "ggml.h"
#include <algorithm>
@ -2872,6 +2876,24 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_FILL:
ggml_cuda_op_fill(ctx, dst);
break;
case GGML_OP_TQ_ENCODE:
ggml_cuda_tq_encode(ctx, dst);
break;
case GGML_OP_TQ_DEQUANT:
ggml_cuda_tq_dequant(ctx, dst);
break;
case GGML_OP_TQ_DEQUANT_KV:
ggml_cuda_tq_dequant_kv(ctx, dst);
break;
case GGML_OP_TQ_FLASH_ATTN_EXT:
ggml_cuda_tq_flash_attn_ext(ctx, dst);
break;
case GGML_OP_TQ_ENCODE_V:
ggml_cuda_tq_encode_v(ctx, dst);
break;
case GGML_OP_TQ_ENCODE_KV:
ggml_cuda_tq_encode_kv(ctx, dst);
break;
default:
return false;
}
@ -4910,6 +4932,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_TRI:
case GGML_OP_DIAG:
case GGML_OP_SOLVE_TRI:
case GGML_OP_TQ_ENCODE:
case GGML_OP_TQ_DEQUANT:
case GGML_OP_TQ_DEQUANT_KV:
case GGML_OP_TQ_FLASH_ATTN_EXT:
case GGML_OP_TQ_ENCODE_V:
case GGML_OP_TQ_ENCODE_KV:
return true;
default:

View file

@ -0,0 +1,394 @@
#include "tq-dequant.cuh"
// Optimized TQ dequant kernel: warp-shuffle codebook + hardcoded bit extraction.
//
// Grid: (nCells, numKVHeads). Block: 128 threads.
// For D=128 and 128 threads: each thread decodes exactly 1 element.
// Codebook lookup via __shfl_sync eliminates global memory reads.
// Output is written as f16.
//
// This kernel is the "separate dequant" path — paired with the stock f16
// flash attention kernel, it avoids injecting decode ALU into the
// bandwidth-bound FA loop.
#if __CUDA_ARCH__ >= 600 || !defined(__CUDA_ARCH__)
__global__ void tq_dequant_multihead_kernel(
const uint8_t *packed, // [(firstCell+c)*numKVHeads+h]*packed_bytes
const float *scales, // [(firstCell+c)*numKVHeads+h]
const float *codebook, // [codebook_len]
uint16_t *output, // [nCells * numKVHeads * headDim] f16
int headDim,
int numKVHeads,
int bits,
int packed_bytes,
int codebook_len,
int firstCell
) {
int c = blockIdx.x; // cell index within [0, nCells)
int h = blockIdx.y; // head index within [0, numKVHeads)
int cell = firstCell + c;
int slot = cell * numKVHeads + h;
float scale = scales[slot];
const uint8_t *cell_packed = packed + (size_t)slot * packed_bytes;
__half *cell_out = (__half *)(output + ((size_t)c * numKVHeads + h) * headDim);
// Load one codebook entry per lane for warp-shuffle lookup.
// For 3-bit (8 entries): lanes 0-7 hold codebook[0-7], repeated every 8 lanes.
// For 2-bit (4 entries): lanes 0-3 hold codebook[0-3], repeated.
const int cb_mask = (1 << bits) - 1;
const float cb_lane = codebook[threadIdx.x & cb_mask];
for (int elem = threadIdx.x; elem < headDim; elem += blockDim.x) {
// Generic bit extraction (handles any alignment).
const int bit_offset = elem * bits;
const int byte_idx = bit_offset >> 3;
const int shift = bit_offset & 7;
int idx = (cell_packed[byte_idx] >> shift) & cb_mask;
if (shift + bits > 8) {
idx |= (cell_packed[byte_idx + 1] << (8 - shift)) & cb_mask;
}
// Codebook lookup via warp shuffle: zero global memory latency.
// Width = 32 works because cb_lane is periodic with period (1<<bits).
// Pass width explicitly — the HIP __shfl_sync shim is a 4-arg macro
// that doesn't default, and CUDA's 3-arg overload is width=warpSize=32
// anyway, so this is behavior-neutral on NVIDIA.
float val = __shfl_sync(0xFFFFFFFF, cb_lane, idx, 32) * scale;
cell_out[elem] = __float2half_rn(val);
}
}
#else
// Stub for sm < 600 (no __shfl_sync). Never executed — the launcher
// asserts compute capability >= 6.0. Only exists so the kernel launch
// in ggml_cuda_tq_dequant compiles for Maxwell targets.
__global__ void tq_dequant_multihead_kernel(
const uint8_t *, const float *, const float *, uint16_t *,
int, int, int, int, int, int) {}
#endif
// ── outlier-split dequant ──────────────────────────────────────────────────
// Paired with tq_encode_kernel_outlier. Reads the regular packed sub-block and
// the outlier packed sub-block from two separate per-layer tensors, decodes
// each with its own codebook/scale, and writes a single [headDim] f16 K
// vector per (cell, head) to the output tensor.
//
// The outlier_indices tensor holds the top-K channel positions the encoder
// selected. For each output position elem, the kernel scans outlier_indices
// to determine whether elem is an outlier (and if so, which outlier slot k)
// or a regular channel (and if so, its contiguous regular slot r = elem minus
// the number of outlier indices less than elem). This scan is O(outlier_count)
// per thread, which is cheap: at outlier_count=32 it's 32 compares per thread,
// all in registers.
#if __CUDA_ARCH__ >= 600 || !defined(__CUDA_ARCH__)
__global__ void tq_dequant_multihead_kernel_outlier(
const uint8_t *reg_packed,
const float *reg_scales,
const float *reg_codebook,
const uint8_t *out_packed,
const float *out_scales,
const uint8_t *out_indices,
const float *out_codebook,
uint16_t *output,
int headDim,
int numKVHeads,
int bits,
int reg_packed_bytes,
int outlier_bits,
int outlier_count,
int out_packed_bytes,
int firstCell
) {
int c = blockIdx.x;
int h = blockIdx.y;
int cell = firstCell + c;
int slot = cell * numKVHeads + h;
float regScale = reg_scales[slot];
float outScale = out_scales[slot];
const uint8_t *cell_reg = reg_packed + (size_t)slot * reg_packed_bytes;
const uint8_t *cell_outl = out_packed + (size_t)slot * out_packed_bytes;
const uint8_t *cell_idx = out_indices + (size_t)slot * outlier_count;
__half *cell_out = (__half *)(output + ((size_t)c * numKVHeads + h) * headDim);
// Shared memory layout:
// s_outl_slot[headDim] int8_t — outlier slot k or -1 if not outlier
// s_mask[headDim / 32 (min 4)] uint32 — outlier bitmap (bit i = channel i outlier)
//
// Per-element `regular_slot` is computed at decode time via popcount over the
// mask bits below `elem` — O(1) for headDim <= 256 (up to 8 popc ops). This
// replaces the per-element O(outlier_count) classification scan.
// Setup is fully parallel: no serial prefix sum, no idle threads.
const int mask_words = (headDim + 31) >> 5; // e.g. 4 for headDim=128, 8 for 256
extern __shared__ char s_mem_dq[];
int8_t *s_outl_slot = (int8_t *)s_mem_dq;
uint32_t *s_mask = (uint32_t *)(s_outl_slot + headDim);
// Step A: init s_outl_slot to -1 and s_mask to 0 (parallel over all threads).
for (int i = threadIdx.x; i < headDim; i += blockDim.x) {
s_outl_slot[i] = -1;
}
for (int w = threadIdx.x; w < mask_words; w += blockDim.x) {
s_mask[w] = 0u;
}
__syncthreads();
// Step B: threads 0..outlier_count-1 each register one outlier — write its
// slot into s_outl_slot AND set its bit in s_mask. atomicOr needed because
// two outliers can land in the same word.
if ((int)threadIdx.x < outlier_count) {
int pos = (int)cell_idx[threadIdx.x];
s_outl_slot[pos] = (int8_t)threadIdx.x;
atomicOr(&s_mask[pos >> 5], 1u << (pos & 31));
}
__syncthreads();
// Warp-shuffle register-resident codebooks.
const int cb_mask = (1 << bits) - 1;
const int ocb_mask = (1 << outlier_bits) - 1;
const float cb_lane_reg = reg_codebook[threadIdx.x & cb_mask];
const float cb_lane_out = out_codebook[threadIdx.x & ocb_mask];
for (int elem = threadIdx.x; elem < headDim; elem += blockDim.x) {
int outlier_slot = (int)s_outl_slot[elem];
// regular_slot = elem - popcount(outlier_mask bits below elem).
// Sum popcount of fully-covered 32-bit chunks, then partial chunk.
int outliers_below = 0;
const int full_words = elem >> 5;
#pragma unroll
for (int w = 0; w < 8; w++) { // hard-unrolled; loop body is a no-op past mask_words
if (w < full_words && w < mask_words) {
outliers_below += __popc(s_mask[w]);
}
}
if (full_words < mask_words) {
uint32_t partial_bits = (1u << (elem & 31)) - 1u;
outliers_below += __popc(s_mask[full_words] & partial_bits);
}
int regular_slot = elem - outliers_below;
// Both sub-block decodes must execute unconditionally because
// __shfl_sync with mask 0xFFFFFFFF requires every lane in the warp
// to be at the same instruction. Putting a shuffle inside a divergent
// if-branch is undefined behavior (observed as all-zero decodes for
// some (cell, head) blocks on multi-head models). Compute both values
// up front, then select.
int reg_bit_offset = regular_slot * bits;
int reg_byte_idx = reg_bit_offset >> 3;
int reg_shift = reg_bit_offset & 7;
int reg_idx = (cell_reg[reg_byte_idx] >> reg_shift) & cb_mask;
if (reg_shift + bits > 8) {
reg_idx |= (cell_reg[reg_byte_idx + 1] << (8 - reg_shift)) & cb_mask;
}
float reg_val = __shfl_sync(0xFFFFFFFF, cb_lane_reg, reg_idx, 32) * regScale;
int out_slot_safe = (outlier_slot >= 0) ? outlier_slot : 0;
int out_bit_offset = out_slot_safe * outlier_bits;
int out_byte_idx = out_bit_offset >> 3;
int out_shift = out_bit_offset & 7;
int out_idx = (cell_outl[out_byte_idx] >> out_shift) & ocb_mask;
if (out_shift + outlier_bits > 8) {
out_idx |= (cell_outl[out_byte_idx + 1] << (8 - out_shift)) & ocb_mask;
}
float out_val = __shfl_sync(0xFFFFFFFF, cb_lane_out, out_idx, 32) * outScale;
float val = (outlier_slot >= 0) ? out_val : reg_val;
cell_out[elem] = __float2half_rn(val);
}
}
#else
__global__ void tq_dequant_multihead_kernel_outlier(
const uint8_t *, const float *, const float *,
const uint8_t *, const float *, const uint8_t *, const float *,
uint16_t *, int, int, int, int, int, int, int, int) {}
#endif
void ggml_cuda_tq_dequant(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
GGML_ASSERT(ggml_cuda_info().devices[ctx.device].cc >= 600 &&
"TurboQuant dequant requires compute capability 6.0+ (Pascal or newer)");
const struct ggml_tensor * encode_result = dst->src[0]; // view of packed
const struct ggml_tensor * scales = dst->src[1];
const struct ggml_tensor * codebook = dst->src[2];
const int headDim = (int)dst->ne[0];
const int numKVHeads = (int)dst->ne[1];
const int nCells = (int)dst->ne[2];
const int bits = (int)((const int32_t *)dst->op_params)[0];
const int firstCell = (int)((const int32_t *)dst->op_params)[1];
const int outlierBits = (int)((const int32_t *)dst->op_params)[2];
const int outlierCount = (int)((const int32_t *)dst->op_params)[3];
dim3 grid(nCells, numKVHeads);
int block_size = 128;
if (headDim < block_size) block_size = headDim;
if (outlierCount > 0 && outlierBits > 0 && outlierCount < headDim) {
// Outlier-split dequant: read both sub-blocks.
const struct ggml_tensor * outlier_packed = dst->src[3];
const struct ggml_tensor * outlier_scales = dst->src[4];
const struct ggml_tensor * outlier_indices = dst->src[5];
const struct ggml_tensor * outlier_codebook = dst->src[6];
const int regular_count = headDim - outlierCount;
// Per-head stride for both packed tensors is padded up to a 4-byte
// multiple so atomicOr-on-word stays aligned in the encode kernel.
// The Go-side ggmlTQCompressedK.regularPackedBytes() applies the
// same padding, so the kernel-visible layout matches the allocator.
const int reg_packed_raw = (regular_count * bits + 7) / 8;
const int reg_packed_bytes = (reg_packed_raw + 3) & ~3;
const int out_packed_raw = (outlierCount * outlierBits + 7) / 8;
const int out_packed_bytes = (out_packed_raw + 3) & ~3;
// Shared memory: s_outl_slot (headDim * int8) + s_mask (ceil(headDim/32) * u32).
const int mask_words = (headDim + 31) >> 5;
size_t smem = (size_t)headDim * sizeof(int8_t)
+ (size_t)mask_words * sizeof(uint32_t);
tq_dequant_multihead_kernel_outlier<<<grid, block_size, smem, ctx.stream()>>>(
(const uint8_t *)encode_result->data,
(const float *)scales->data,
(const float *)codebook->data,
(const uint8_t *)outlier_packed->data,
(const float *)outlier_scales->data,
(const uint8_t *)outlier_indices->data,
(const float *)outlier_codebook->data,
(uint16_t *)dst->data,
headDim, numKVHeads, bits, reg_packed_bytes,
outlierBits, outlierCount, out_packed_bytes, firstCell
);
return;
}
const int packed_bytes = (headDim * bits + 7) / 8;
const int codebook_len = (int)codebook->ne[0];
tq_dequant_multihead_kernel<<<grid, block_size, 0, ctx.stream()>>>(
(const uint8_t *)encode_result->data,
(const float *)scales->data,
(const float *)codebook->data,
(uint16_t *)dst->data,
headDim, numKVHeads, bits, packed_bytes, codebook_len, firstCell
);
}
// V dequant with fused rotation undo: decode packed V to shared memory
// (rotated domain), then multiply by R [headDim × headDim] to produce
// unrotated f16 output. Eliminates the per-layer mulmat op in SDPA.
//
// Grid: (nCells, numKVHeads). Block: 128 (= headDim for D=128).
// Shared memory: headDim floats = 512 bytes.
#if __CUDA_ARCH__ >= 600 || !defined(__CUDA_ARCH__)
__global__ void tq_dequant_v_rotated_kernel(
const uint8_t * __restrict__ packed,
const float * __restrict__ scales,
const float * __restrict__ codebook,
const float * __restrict__ rotation, // R [headDim, headDim] row-major
uint16_t * __restrict__ output,
int headDim, int numKVHeads, int bits, int packed_bytes,
int codebook_len, int firstCell)
{
extern __shared__ float s_rotV[]; // headDim floats
int c = blockIdx.x;
int h = blockIdx.y;
int cell = firstCell + c;
int slot = cell * numKVHeads + h;
float scale = scales[slot];
const uint8_t *cell_packed = packed + (size_t)slot * packed_bytes;
// Phase 1: decode one element per thread into shared memory (rotated domain).
const int cb_mask = (1 << bits) - 1;
const float cb_lane = codebook[threadIdx.x & cb_mask];
int elem = threadIdx.x;
int bit_offset = elem * bits;
int byte_idx = bit_offset >> 3;
int shift = bit_offset & 7;
int idx = (cell_packed[byte_idx] >> shift) & cb_mask;
if (shift + bits > 8) {
idx |= (cell_packed[byte_idx + 1] << (8 - shift)) & cb_mask;
}
s_rotV[elem] = __shfl_sync(0xFFFFFFFF, cb_lane, idx, 32) * scale;
__syncthreads();
// Phase 2: each thread computes one output element = dot(R[elem,:], s_rotV).
// R is in L2 (64KB, fits in P40's 3MB L2; read-only, broadcast across blocks).
const float *R_row = rotation + elem * headDim;
float sum = 0.0f;
for (int j = 0; j < headDim; j++) {
sum += R_row[j] * s_rotV[j];
}
__half *cell_out = (__half *)(output + ((size_t)c * numKVHeads + h) * headDim);
cell_out[elem] = __float2half_rn(sum);
}
#else
// Stub for sm < 600 (no __shfl_sync). Not currently launched, but kept
// so future call sites compile for Maxwell targets without special casing.
__global__ void tq_dequant_v_rotated_kernel(
const uint8_t * __restrict__, const float * __restrict__,
const float * __restrict__, const float * __restrict__,
uint16_t * __restrict__,
int, int, int, int, int, int) {}
#endif
// Combined K+V dequant: two back-to-back kernel launches in a single GGML op.
// Output: [headDim, numKVHeads, nCells, 2] f16 — K at ne[3]=0, V at ne[3]=1.
// When src[6] (v_rotation) is non-NULL, V is dequanted with rotation fused.
void ggml_cuda_tq_dequant_kv(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
GGML_ASSERT(ggml_cuda_info().devices[ctx.device].cc >= 600 &&
"TurboQuant dequant requires compute capability 6.0+ (Pascal or newer)");
const struct ggml_tensor * k_encode = dst->src[0];
const struct ggml_tensor * k_scales = dst->src[1];
const struct ggml_tensor * k_cb = dst->src[2];
const struct ggml_tensor * v_encode = dst->src[3];
const struct ggml_tensor * v_scales = dst->src[4];
const struct ggml_tensor * v_cb = dst->src[5];
const struct ggml_tensor * v_rotation = dst->src[6]; // NULL = no rotation fusion
const int headDim = (int)dst->ne[0];
const int numKVHeads = (int)dst->ne[1];
const int nCells = (int)dst->ne[2];
int32_t k_bits, v_bits, firstCell;
memcpy(&k_bits, (const int32_t *)dst->op_params + 0, sizeof(int32_t));
memcpy(&v_bits, (const int32_t *)dst->op_params + 1, sizeof(int32_t));
memcpy(&firstCell, (const int32_t *)dst->op_params + 2, sizeof(int32_t));
const int k_packed_bytes = (headDim * k_bits + 7) / 8;
const int v_packed_bytes = (headDim * v_bits + 7) / 8;
dim3 grid(nCells, numKVHeads);
int block_size = 128;
if (headDim < block_size) {
block_size = headDim;
}
const size_t plane_size = (size_t)headDim * numKVHeads * nCells;
uint16_t * out_base = (uint16_t *)dst->data;
cudaStream_t stream = ctx.stream();
// K dequant → first plane (offset 0) — always unrotated
tq_dequant_multihead_kernel<<<grid, block_size, 0, stream>>>(
(const uint8_t *)k_encode->data,
(const float *)k_scales->data,
(const float *)k_cb->data,
out_base,
headDim, numKVHeads, k_bits, k_packed_bytes, (int)k_cb->ne[0], firstCell
);
// V dequant → second plane (offset plane_size). Plain dequant only — the
// rotation undo (R @ attn_out) is handled by SDPA via mulmat, which is
// dramatically faster than the per-cell matmul the fused kernel did.
(void)v_rotation;
tq_dequant_multihead_kernel<<<grid, block_size, 0, stream>>>(
(const uint8_t *)v_encode->data,
(const float *)v_scales->data,
(const float *)v_cb->data,
out_base + plane_size,
headDim, numKVHeads, v_bits, v_packed_bytes, (int)v_cb->ne[0], firstCell
);
}

View file

@ -0,0 +1,6 @@
#pragma once
#include "common.cuh"
void ggml_cuda_tq_dequant(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst);
void ggml_cuda_tq_dequant_kv(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst);

View file

@ -0,0 +1,154 @@
#include "tq-encode-v.cuh"
#include <math.h>
#define TQ_ENCODE_V_BLOCK_SIZE 128
__global__ void tq_encode_v_kernel(
const void *v,
const float *rotation, // [headDim, headDim] R^T row-major, or NULL for no rotation
uint8_t *packed_out,
float *scales_out,
int firstCell,
const float *boundaries,
int headDim,
int numKVHeads,
int bits,
int numBoundaries,
int vIsF32
) {
int batch = blockIdx.x;
int head = blockIdx.y;
int cell = firstCell + batch;
extern __shared__ char s_mem[];
float *s_v = (float *)s_mem;
float *s_rot = s_v + headDim;
float *s_reduce = s_rot + headDim;
uint8_t *s_idx = (uint8_t *)(s_reduce + blockDim.x);
// Step 1: Load V[batch, head] into shared memory as f32
int base_v = batch * numKVHeads * headDim + head * headDim;
for (int d = threadIdx.x; d < headDim; d += blockDim.x) {
if (vIsF32) {
s_v[d] = ((const float *)v)[base_v + d];
} else {
s_v[d] = __half2float(__ushort_as_half(((const uint16_t *)v)[base_v + d]));
}
}
__syncthreads();
// Step 2: Apply Hadamard rotation if provided (R^T @ v spreads outlier energy evenly)
float *s_input = s_v;
if (rotation != NULL) {
for (int i = threadIdx.x; i < headDim; i += blockDim.x) {
float sum = 0.0f;
for (int j = 0; j < headDim; j++) {
sum += rotation[i * headDim + j] * s_v[j];
}
s_rot[i] = sum;
}
__syncthreads();
s_input = s_rot;
}
// Step 3: RMS scale = sqrt(mean(v^2))
float local_sq = 0.0f;
for (int i = threadIdx.x; i < headDim; i += blockDim.x) {
local_sq += s_input[i] * s_input[i];
}
s_reduce[threadIdx.x] = local_sq;
__syncthreads();
for (int stride = blockDim.x >> 1; stride > 0; stride >>= 1) {
if (threadIdx.x < stride)
s_reduce[threadIdx.x] += s_reduce[threadIdx.x + stride];
__syncthreads();
}
float scale = 0.0f;
if (threadIdx.x == 0) {
float sum_sq = s_reduce[0];
if (sum_sq > 1e-12f)
scale = sqrtf(sum_sq / (float)headDim);
scales_out[cell * numKVHeads + head] = scale;
s_reduce[0] = scale;
}
__syncthreads();
scale = s_reduce[0];
// Step 4: Quantize via boundary binary search
for (int i = threadIdx.x; i < headDim; i += blockDim.x) {
float val = (scale > 0.0f) ? (s_input[i] / scale) : 0.0f;
int idx = 0;
for (int b = 0; b < numBoundaries; b++) {
if (val >= boundaries[b]) idx++;
}
s_idx[i] = (uint8_t)idx;
}
__syncthreads();
// Step 5: Pack bits LSB-first
{
int packed_bytes = (headDim * bits + 7) / 8;
uint8_t *out = packed_out + (cell * numKVHeads + head) * packed_bytes;
uint8_t bitmask = (uint8_t)((1 << bits) - 1);
for (int p = threadIdx.x; p < packed_bytes; p += blockDim.x) {
out[p] = 0;
}
__syncthreads();
for (int elem = threadIdx.x; elem < headDim; elem += blockDim.x) {
int bit_offset = elem * bits;
int byte_idx = bit_offset >> 3;
int shift = bit_offset & 7;
uint8_t val = s_idx[elem] & bitmask;
atomicOr((unsigned int *)(out + (byte_idx & ~3)),
(unsigned int)(val << shift) << ((byte_idx & 3) * 8));
if (shift + bits > 8) {
int byte_idx2 = byte_idx + 1;
atomicOr((unsigned int *)(out + (byte_idx2 & ~3)),
(unsigned int)(val >> (8 - shift)) << ((byte_idx2 & 3) * 8));
}
}
}
}
void ggml_cuda_tq_encode_v(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
const struct ggml_tensor * v = dst->src[0];
const struct ggml_tensor * rotation = dst->src[1]; // NULL when no rotation
const struct ggml_tensor * scales = dst->src[3];
const struct ggml_tensor * boundaries = dst->src[4];
const int headDim = (int)v->ne[0];
const int numKVHeads = (int)v->ne[1];
const int batchSize = (int)v->ne[2];
const int bits = (int)((const int32_t *)dst->op_params)[0];
const int firstCell = (int)((const int32_t *)dst->op_params)[1];
const int numBoundaries = (1 << bits) - 1;
const int vIsF32 = (v->type == GGML_TYPE_F32) ? 1 : 0;
const float * rotation_ptr = rotation ? (const float *)rotation->data : nullptr;
dim3 grid(batchSize, numKVHeads);
int block_size = (headDim < TQ_ENCODE_V_BLOCK_SIZE) ? headDim : TQ_ENCODE_V_BLOCK_SIZE;
int bs = 1;
while (bs < block_size) bs <<= 1;
block_size = bs;
size_t smem = (size_t)headDim * 2 * sizeof(float) // s_v + s_rot
+ (size_t)block_size * sizeof(float)
+ (size_t)headDim * sizeof(uint8_t);
cudaStream_t stream = ctx.stream();
tq_encode_v_kernel<<<grid, block_size, smem, stream>>>(
v->data,
rotation_ptr,
(uint8_t *)dst->data,
(float *)scales->data,
firstCell,
(const float *)boundaries->data,
headDim, numKVHeads, bits, numBoundaries, vIsF32
);
}

View file

@ -0,0 +1,5 @@
#pragma once
#include "common.cuh"
void ggml_cuda_tq_encode_v(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst);

View file

@ -0,0 +1,531 @@
#include "tq-encode.cuh"
#include <math.h>
#define TQ_ENCODE_BLOCK_SIZE 128
// ── kernel ──────────────────────────────────────────────────────────────────
__global__ void tq_encode_kernel(
const void *k, // f16 or f32, ggml layout [headDim, numKVHeads, batchSize]
const float *rotation, // [headDim, headDim] R^T stored row-major
uint8_t *packed_out, // [(c*numKVHeads+h)*packedBytes] interleaved
float *scales_out, // [c*numKVHeads+h] interleaved
int firstCell, // first cache cell index (cell = firstCell + batch)
const float *boundaries, // [numLevels-1]
int headDim,
int numKVHeads,
int bits,
int numBoundaries, // = (1<<bits) - 1
int kIsF32 // non-zero when k is float32 (vs float16)
) {
int batch = blockIdx.x;
int head = blockIdx.y;
int cell = firstCell + batch;
// Shared memory layout:
// s_k[headDim] K values as f32
// s_rot[headDim] rotated values
// s_reduce[BLOCK] warp reduction scratch
// s_idx[headDim] quantized indices (uint8)
extern __shared__ char s_mem[];
float *s_k = (float *)s_mem;
float *s_rot = s_k + headDim;
float *s_reduce = s_rot + headDim;
uint8_t *s_idx = (uint8_t *)(s_reduce + blockDim.x);
// ── Step 1: Load K[batch, head] into shared memory as f32 ────────────────
// K layout: element (dim=d, head=h, batch=b) at b*numKVHeads*headDim + h*headDim + d
int base_k = batch * numKVHeads * headDim + head * headDim;
for (int d = threadIdx.x; d < headDim; d += blockDim.x) {
if (kIsF32) {
s_k[d] = ((const float *)k)[base_k + d];
} else {
s_k[d] = __half2float(__ushort_as_half(((const uint16_t *)k)[base_k + d]));
}
}
__syncthreads();
// ── Step 2: Rotation matmul: rotated[i] = Σ_j rotation[i*headDim+j] * s_k[j] ──
// rotation stores Q^T row-major (rotTensor[i][j] = Q^T[i][j]).
// This computes rotated = Q^T @ k.
// Q is also rotated as Q^T @ q (via ggml_mul_mat(rotTensor, q)),
// so attention = (Q^T q)^T (Q^T k) = q^T Q Q^T k = q^T k.
for (int i = threadIdx.x; i < headDim; i += blockDim.x) {
float sum = 0.0f;
for (int j = 0; j < headDim; j++) {
sum += rotation[i * headDim + j] * s_k[j];
}
s_rot[i] = sum;
}
__syncthreads();
// ── Step 3: RMS scale = sqrt(mean(rotated²)) ─────────────────────────────
float local_sq = 0.0f;
for (int i = threadIdx.x; i < headDim; i += blockDim.x) {
local_sq += s_rot[i] * s_rot[i];
}
s_reduce[threadIdx.x] = local_sq;
__syncthreads();
for (int stride = blockDim.x >> 1; stride > 0; stride >>= 1) {
if (threadIdx.x < stride)
s_reduce[threadIdx.x] += s_reduce[threadIdx.x + stride];
__syncthreads();
}
float scale = 0.0f;
if (threadIdx.x == 0) {
float sum_sq = s_reduce[0];
if (sum_sq > 1e-12f)
scale = sqrtf(sum_sq / (float)headDim);
scales_out[cell * numKVHeads + head] = scale;
s_reduce[0] = scale; // broadcast via shared mem
}
__syncthreads();
scale = s_reduce[0];
// ── Step 4: Quantize each element via boundary binary search ─────────────
for (int i = threadIdx.x; i < headDim; i += blockDim.x) {
float v = (scale > 0.0f) ? (s_rot[i] / scale) : 0.0f;
int idx = 0;
for (int b = 0; b < numBoundaries; b++) {
if (v >= boundaries[b]) idx++;
}
s_idx[i] = (uint8_t)idx;
}
__syncthreads();
// ── Step 5: Pack bits LSB-first into output (parallel via atomicOr) ─────
{
int packed_bytes = (headDim * bits + 7) / 8;
uint8_t *out = packed_out + (cell * numKVHeads + head) * packed_bytes;
uint8_t bitmask = (uint8_t)((1 << bits) - 1);
// Zero output buffer in parallel.
for (int p = threadIdx.x; p < packed_bytes; p += blockDim.x) {
out[p] = 0;
}
__syncthreads();
// Each thread packs its own elements using atomicOr on 4-byte aligned words.
for (int elem = threadIdx.x; elem < headDim; elem += blockDim.x) {
int bit_offset = elem * bits;
int byte_idx = bit_offset >> 3;
int shift = bit_offset & 7;
uint8_t v = s_idx[elem] & bitmask;
// Pack into 4-byte aligned word: position byte within 32-bit word.
atomicOr((unsigned int *)(out + (byte_idx & ~3)),
(unsigned int)(v << shift) << ((byte_idx & 3) * 8));
if (shift + bits > 8) {
int byte_idx2 = byte_idx + 1;
atomicOr((unsigned int *)(out + (byte_idx2 & ~3)),
(unsigned int)(v >> (8 - shift)) << ((byte_idx2 & 3) * 8));
}
}
}
}
// ── outlier-split kernel ─────────────────────────────────────────────────────
//
// Implements the TurboQuant paper's actual experimental configuration
// (arXiv 2504.19874 Sec 4.3): split channels into a top-K outlier set and a
// regular set, each encoded with its own RMS scale and codebook at different
// bit widths. Outlier selection happens in ROTATED space (single rotation
// matmul shared between both sub-blocks), not in the original space like the
// CPU reference — this is the only paper-realistic algorithm we can run
// cheaply on the GPU without a second rotation pass. For the near-orthogonal
// rotations produced by QR on a Gaussian matrix, the top-K in rotated space
// is a close proxy for the top-K in original space.
__global__ void tq_encode_kernel_outlier(
const void *k, // f16 or f32, [headDim, numKVHeads, batchSize]
const float *rotation, // [headDim, headDim] R^T row-major
uint8_t *packed_out, // regular packed [regularPackedBytes*numKVHeads*cells]
float *scales_out, // regular scales [numKVHeads*cells]
uint8_t *outlier_packed, // outlier packed [outlierPackedBytes*numKVHeads*cells]
float *outlier_scales, // outlier scales [numKVHeads*cells]
uint8_t *outlier_indices, // outlier channel idx [outlierCount*numKVHeads*cells] (interpreted as uint8 so positions 0..255 fit)
int firstCell,
const float *boundaries, // regular boundaries [(1<<bits)-1]
const float *outlier_boundaries, // outlier boundaries [(1<<outlierBits)-1]
int headDim,
int numKVHeads,
int bits,
int numBoundaries,
int outlierBits,
int outlierCount,
int numOutlierBoundaries,
int kIsF32
) {
int batch = blockIdx.x;
int head = blockIdx.y;
int cell = firstCell + batch;
// Shared memory layout (laid out contiguously — launch passes total size):
// s_k[headDim] - f32 K values
// s_rot[headDim] - f32 rotated values
// s_reduce[blockDim.x] - reduction scratch (float)
// s_idx[headDim] - regular quantized indices (uint8)
// s_is_outlier[headDim] - 1 = outlier, 0 = regular (uint8)
// s_reg_pos[headDim] - regular position index map (int, up to headDim entries)
// s_outl_pos[outlierCount] - outlier channel positions (int)
// s_outl_val[outlierCount] - outlier rotated values (float)
// s_outl_idx[outlierCount] - outlier quantized indices (uint8)
extern __shared__ char s_mem[];
float *s_k = (float *)s_mem;
float *s_rot = s_k + headDim;
float *s_reduce = s_rot + headDim;
uint8_t *s_idx = (uint8_t *)(s_reduce + blockDim.x);
uint8_t *s_is_outlier = s_idx + headDim;
int *s_reg_pos = (int *)(s_is_outlier + headDim);
int *s_outl_pos = s_reg_pos + headDim;
float *s_outl_val = (float *)(s_outl_pos + outlierCount);
uint8_t *s_outl_idx = (uint8_t *)(s_outl_val + outlierCount);
// Step 1: Load K into s_k as f32.
int base_k = batch * numKVHeads * headDim + head * headDim;
for (int d = threadIdx.x; d < headDim; d += blockDim.x) {
if (kIsF32) {
s_k[d] = ((const float *)k)[base_k + d];
} else {
s_k[d] = __half2float(__ushort_as_half(((const uint16_t *)k)[base_k + d]));
}
}
__syncthreads();
// Step 2: Rotate: s_rot[i] = sum_j rotation[i*headDim+j] * s_k[j] = (R^T @ k)[i].
for (int i = threadIdx.x; i < headDim; i += blockDim.x) {
float sum = 0.0f;
for (int j = 0; j < headDim; j++) {
sum += rotation[i * headDim + j] * s_k[j];
}
s_rot[i] = sum;
}
__syncthreads();
// Step 3: Clear outlier mask.
for (int i = threadIdx.x; i < headDim; i += blockDim.x) {
s_is_outlier[i] = 0;
}
__syncthreads();
// Step 4: Top-K outlier selection (serial on thread 0). O(K * headDim)
// comparisons. At outlierCount=32, headDim=128 that's ~4k ops on one
// thread — negligible next to the rotation matmul (16k ops) that ran
// in parallel above. Also builds s_reg_pos as the ordered list of
// regular (non-outlier) channel positions, and s_outl_pos for outliers.
if (threadIdx.x == 0) {
for (int r = 0; r < outlierCount; r++) {
float best_val = -1.0f;
int best_idx = 0;
for (int i = 0; i < headDim; i++) {
if (s_is_outlier[i]) continue;
float a = fabsf(s_rot[i]);
if (a > best_val) {
best_val = a;
best_idx = i;
}
}
s_is_outlier[best_idx] = 1;
s_outl_pos[r] = best_idx;
s_outl_val[r] = s_rot[best_idx];
}
// Build regular position map after outlier selection is complete.
int pos = 0;
for (int i = 0; i < headDim; i++) {
if (!s_is_outlier[i]) {
s_reg_pos[pos++] = i;
}
}
}
__syncthreads();
int regularCount = headDim - outlierCount;
// Step 5: Per-sub-block RMS scales via parallel reduction.
// Each thread accumulates its dims' squared contributions into two
// locals; we reduce regular and outlier sums back to back using
// s_reduce as scratch.
float local_sq_reg = 0.0f;
float local_sq_out = 0.0f;
for (int i = threadIdx.x; i < headDim; i += blockDim.x) {
float v = s_rot[i];
float sq = v * v;
if (s_is_outlier[i]) {
local_sq_out += sq;
} else {
local_sq_reg += sq;
}
}
s_reduce[threadIdx.x] = local_sq_reg;
__syncthreads();
for (int stride = blockDim.x >> 1; stride > 0; stride >>= 1) {
if (threadIdx.x < stride) s_reduce[threadIdx.x] += s_reduce[threadIdx.x + stride];
__syncthreads();
}
float regScale = 0.0f;
if (threadIdx.x == 0) {
float sum_sq = s_reduce[0];
if (sum_sq > 1e-12f && regularCount > 0) {
regScale = sqrtf(sum_sq / (float)regularCount);
}
scales_out[cell * numKVHeads + head] = regScale;
s_reduce[0] = regScale;
}
__syncthreads();
regScale = s_reduce[0];
__syncthreads();
s_reduce[threadIdx.x] = local_sq_out;
__syncthreads();
for (int stride = blockDim.x >> 1; stride > 0; stride >>= 1) {
if (threadIdx.x < stride) s_reduce[threadIdx.x] += s_reduce[threadIdx.x + stride];
__syncthreads();
}
float outScale = 0.0f;
if (threadIdx.x == 0) {
float sum_sq = s_reduce[0];
if (sum_sq > 1e-12f && outlierCount > 0) {
outScale = sqrtf(sum_sq / (float)outlierCount);
}
outlier_scales[cell * numKVHeads + head] = outScale;
s_reduce[0] = outScale;
}
__syncthreads();
outScale = s_reduce[0];
// Step 6: Quantize regular channels. s_idx[r] stores the code at the
// CONTIGUOUS regular position r, not the original channel index, so the
// dequant kernel can read them directly into the packed bit stream.
for (int r = threadIdx.x; r < regularCount; r += blockDim.x) {
int orig = s_reg_pos[r];
float v = (regScale > 0.0f) ? (s_rot[orig] / regScale) : 0.0f;
int idx = 0;
for (int b = 0; b < numBoundaries; b++) {
if (v >= boundaries[b]) idx++;
}
s_idx[r] = (uint8_t)idx;
}
__syncthreads();
// Step 7: Pack regular bits into packed_out.
// Per-head stride is padded up to a 4-byte multiple so atomicOr on
// unsigned-int words stays aligned for every head, regardless of
// how many regular channels (bit count) are in play. The Go-side
// allocator uses the same padded value (regularPackedBytes()).
const int regular_packed_bytes_raw = (regularCount * bits + 7) / 8;
const int regular_packed_bytes = (regular_packed_bytes_raw + 3) & ~3;
uint8_t *reg_out = packed_out + (cell * numKVHeads + head) * regular_packed_bytes;
for (int p = threadIdx.x; p < regular_packed_bytes; p += blockDim.x) {
reg_out[p] = 0;
}
__syncthreads();
{
uint8_t bitmask = (uint8_t)((1 << bits) - 1);
for (int r = threadIdx.x; r < regularCount; r += blockDim.x) {
int bit_offset = r * bits;
int byte_idx = bit_offset >> 3;
int shift = bit_offset & 7;
uint8_t v = s_idx[r] & bitmask;
atomicOr((unsigned int *)(reg_out + (byte_idx & ~3)),
(unsigned int)(v << shift) << ((byte_idx & 3) * 8));
if (shift + bits > 8) {
int byte_idx2 = byte_idx + 1;
atomicOr((unsigned int *)(reg_out + (byte_idx2 & ~3)),
(unsigned int)(v >> (8 - shift)) << ((byte_idx2 & 3) * 8));
}
}
}
__syncthreads();
// Step 8: Quantize outlier channels with their own codebook.
for (int r = threadIdx.x; r < outlierCount; r += blockDim.x) {
float v = (outScale > 0.0f) ? (s_outl_val[r] / outScale) : 0.0f;
int idx = 0;
for (int b = 0; b < numOutlierBoundaries; b++) {
if (v >= outlier_boundaries[b]) idx++;
}
s_outl_idx[r] = (uint8_t)idx;
}
__syncthreads();
// Step 9: Pack outlier bits. Same 4-byte alignment as regular packing.
const int outlier_packed_bytes_raw = (outlierCount * outlierBits + 7) / 8;
const int outlier_packed_bytes = (outlier_packed_bytes_raw + 3) & ~3;
uint8_t *out_out = outlier_packed + (cell * numKVHeads + head) * outlier_packed_bytes;
for (int p = threadIdx.x; p < outlier_packed_bytes; p += blockDim.x) {
out_out[p] = 0;
}
__syncthreads();
{
uint8_t obmask = (uint8_t)((1 << outlierBits) - 1);
for (int r = threadIdx.x; r < outlierCount; r += blockDim.x) {
int bit_offset = r * outlierBits;
int byte_idx = bit_offset >> 3;
int shift = bit_offset & 7;
uint8_t v = s_outl_idx[r] & obmask;
atomicOr((unsigned int *)(out_out + (byte_idx & ~3)),
(unsigned int)(v << shift) << ((byte_idx & 3) * 8));
if (shift + outlierBits > 8) {
int byte_idx2 = byte_idx + 1;
atomicOr((unsigned int *)(out_out + (byte_idx2 & ~3)),
(unsigned int)(v >> (8 - shift)) << ((byte_idx2 & 3) * 8));
}
}
}
// Step 10: Write outlier channel indices. uint8_t covers 0..255 safely.
uint8_t *idx_out = outlier_indices + (cell * numKVHeads + head) * outlierCount;
for (int r = threadIdx.x; r < outlierCount; r += blockDim.x) {
idx_out[r] = (uint8_t)s_outl_pos[r];
}
}
// ── ggml dispatch ─────────────────────────────────────────────────────────────
// Extern declaration for the V encode kernel (defined in tq-encode-v.cu).
extern __global__ void tq_encode_v_kernel(
const void *v, const float *rotation, uint8_t *packed_out, float *scales_out,
int firstCell, const float *boundaries,
int headDim, int numKVHeads, int bits, int numBoundaries, int vIsF32);
void ggml_cuda_tq_encode(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
const struct ggml_tensor * k = dst->src[0];
const struct ggml_tensor * rotation = dst->src[1];
// src[2] unused (was cell_idx; now firstCell is in op_params[1])
const struct ggml_tensor * scales = dst->src[3];
const struct ggml_tensor * boundaries = dst->src[4];
const int headDim = (int)k->ne[0];
const int numKVHeads = (int)k->ne[1];
const int batchSize = (int)k->ne[2];
const int bits = (int)((const int32_t *)dst->op_params)[0];
const int firstCell = (int)((const int32_t *)dst->op_params)[1];
const int outlierBits = (int)((const int32_t *)dst->op_params)[2];
const int outlierCount = (int)((const int32_t *)dst->op_params)[3];
const int numBoundaries = (1 << bits) - 1;
const int kIsF32 = (k->type == GGML_TYPE_F32) ? 1 : 0;
dim3 grid(batchSize, numKVHeads);
int block_size = (headDim < TQ_ENCODE_BLOCK_SIZE) ? headDim : TQ_ENCODE_BLOCK_SIZE;
int bs = 1;
while (bs < block_size) bs <<= 1;
block_size = bs;
cudaStream_t stream = ctx.stream();
if (outlierCount > 0 && outlierBits > 0 && outlierCount < headDim) {
const struct ggml_tensor * outlier_packed = dst->src[5];
const struct ggml_tensor * outlier_scales = dst->src[6];
const struct ggml_tensor * outlier_indices = dst->src[7];
const struct ggml_tensor * outlier_boundaries = dst->src[8];
const int numOutlierBoundaries = (1 << outlierBits) - 1;
// Shared memory layout for outlier kernel:
// s_k[headDim] f32
// s_rot[headDim] f32
// s_reduce[block_size] f32
// s_idx[headDim] u8
// s_is_outlier[headDim] u8
// s_reg_pos[headDim] i32 (over-sized to headDim; only regularCount used)
// s_outl_pos[outlierCount] i32
// s_outl_val[outlierCount] f32
// s_outl_idx[outlierCount] u8
size_t smem = (size_t)headDim * 2 * sizeof(float) // s_k + s_rot
+ (size_t)block_size * sizeof(float) // s_reduce
+ (size_t)headDim * 2 * sizeof(uint8_t) // s_idx + s_is_outlier
+ (size_t)headDim * sizeof(int) // s_reg_pos
+ (size_t)outlierCount * (sizeof(int) + sizeof(float) + sizeof(uint8_t));
tq_encode_kernel_outlier<<<grid, block_size, smem, stream>>>(
k->data,
(const float *)rotation->data,
(uint8_t *)dst->data,
(float *)scales->data,
(uint8_t *)outlier_packed->data,
(float *)outlier_scales->data,
(uint8_t *)outlier_indices->data,
firstCell,
(const float *)boundaries->data,
(const float *)outlier_boundaries->data,
headDim, numKVHeads, bits, numBoundaries,
outlierBits, outlierCount, numOutlierBoundaries, kIsF32
);
return;
}
size_t smem = (size_t)headDim * 2 * sizeof(float)
+ (size_t)block_size * sizeof(float)
+ (size_t)headDim * sizeof(uint8_t);
tq_encode_kernel<<<grid, block_size, smem, stream>>>(
k->data,
(const float *)rotation->data,
(uint8_t *)dst->data,
(float *)scales->data,
firstCell,
(const float *)boundaries->data,
headDim, numKVHeads, bits, numBoundaries, kIsF32
);
}
// Combined K+V encode: two back-to-back kernel launches in a single GGML op.
void ggml_cuda_tq_encode_kv(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
// src layout: [0]=K, [1]=rotation, [2]=V, [3]=K_scales, [4]=K_bounds,
// [5]=V_packed, [6]=V_scales, [7]=V_bounds
const struct ggml_tensor * k = dst->src[0];
const struct ggml_tensor * rotation = dst->src[1];
const struct ggml_tensor * v = dst->src[2];
const struct ggml_tensor * k_scales = dst->src[3];
const struct ggml_tensor * k_bounds = dst->src[4];
const struct ggml_tensor * v_packed = dst->src[5];
const struct ggml_tensor * v_scales = dst->src[6];
const struct ggml_tensor * v_bounds = dst->src[7];
int32_t k_bits, v_bits, firstCell;
memcpy(&k_bits, (const int32_t *)dst->op_params + 0, sizeof(int32_t));
memcpy(&v_bits, (const int32_t *)dst->op_params + 1, sizeof(int32_t));
memcpy(&firstCell, (const int32_t *)dst->op_params + 2, sizeof(int32_t));
const int headDim = (int)k->ne[0];
const int numKVHeads = (int)k->ne[1];
const int batchSize = (int)k->ne[2];
const int kIsF32 = (k->type == GGML_TYPE_F32) ? 1 : 0;
const int vIsF32 = (v->type == GGML_TYPE_F32) ? 1 : 0;
dim3 grid(batchSize, numKVHeads);
int block_size = (headDim < TQ_ENCODE_BLOCK_SIZE) ? headDim : TQ_ENCODE_BLOCK_SIZE;
int bs = 1;
while (bs < block_size) bs <<= 1;
block_size = bs;
size_t smem = (size_t)headDim * 2 * sizeof(float)
+ (size_t)block_size * sizeof(float)
+ (size_t)headDim * sizeof(uint8_t);
cudaStream_t stream = ctx.stream();
// K encode
tq_encode_kernel<<<grid, block_size, smem, stream>>>(
k->data,
(const float *)rotation->data,
(uint8_t *)dst->data,
(float *)k_scales->data,
firstCell,
(const float *)k_bounds->data,
headDim, numKVHeads, k_bits, (1 << k_bits) - 1, kIsF32
);
// V encode
const float * rotation_ptr = rotation ? (const float *)rotation->data : nullptr;
tq_encode_v_kernel<<<grid, block_size, smem, stream>>>(
v->data,
rotation_ptr,
(uint8_t *)v_packed->data,
(float *)v_scales->data,
firstCell,
(const float *)v_bounds->data,
headDim, numKVHeads, v_bits, (1 << v_bits) - 1, vIsF32
);
}

View file

@ -0,0 +1,6 @@
#pragma once
#include "common.cuh"
void ggml_cuda_tq_encode(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst);
void ggml_cuda_tq_encode_kv(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst);

View file

@ -0,0 +1,552 @@
#pragma once
#include "common.cuh"
#include "fattn-common.cuh"
// TurboQuant inline decode: extract a N-bit Lloyd-Max index from a packed byte row
// and return codebook[idx] * rms_scale. Handles cross-byte boundaries for bits=2,3.
static __device__ __forceinline__ float tq_decode_elem(
const uint8_t * packed_row, const float * codebook, float rms_scale, int elem, int bits)
{
const int bit_pos = elem * bits;
const int byte_idx = bit_pos >> 3;
const int shift = bit_pos & 7;
const int mask_val = (1 << bits) - 1;
int idx = ((int)(packed_row[byte_idx] >> shift)) & mask_val;
if (shift + bits > 8) {
idx |= ((int)(packed_row[byte_idx + 1] << (8 - shift))) & mask_val;
}
return codebook[idx] * rms_scale;
}
// --------------------------------------------------------------------
// Hardcoded warp-shuffle TQ decode: combine packed bytes into a single
// integer, shift to align, then extract N indices with compile-time
// offsets. Eliminates per-element multiply, byte_idx computation, and
// boundary-crossing branches.
//
// cb_lane: codebook[threadIdx.x % (1<<bits)], loaded once per kernel.
// shfl_w: shuffle width (power of 2, >= 1<<bits).
// --------------------------------------------------------------------
#if __CUDA_ARCH__ >= 600 || !defined(__CUDA_ARCH__)
// 3-bit, 4 elements. start_elem is a multiple of 4.
// 12 bits needed from 2 bytes; bit offset within the first byte is 0 or 4.
static __device__ __forceinline__ void tq_decode_4_3bit(
const uint8_t * __restrict__ packed_row,
float cb_lane, float rms, int start_elem, int shfl_w,
float * __restrict__ out)
{
const int byte_off = (start_elem * 3) >> 3;
const int bit_off = (start_elem * 3) & 7; // 0 or 4
const uint32_t w = (uint32_t)packed_row[byte_off] | ((uint32_t)packed_row[byte_off + 1] << 8);
const uint32_t s = w >> bit_off; // align element 0 to bit 0
out[0] = __shfl_sync(0xFFFFFFFF, cb_lane, (s >> 0) & 7, shfl_w) * rms;
out[1] = __shfl_sync(0xFFFFFFFF, cb_lane, (s >> 3) & 7, shfl_w) * rms;
out[2] = __shfl_sync(0xFFFFFFFF, cb_lane, (s >> 6) & 7, shfl_w) * rms;
out[3] = __shfl_sync(0xFFFFFFFF, cb_lane, (s >> 9) & 7, shfl_w) * rms;
}
// 3-bit, 8 elements. start_elem is a multiple of 8 (Volta+ path).
// 24 bits from 3 bytes; bit offset is always 0.
static __device__ __forceinline__ void tq_decode_8_3bit(
const uint8_t * __restrict__ packed_row,
float cb_lane, float rms, int start_elem, int shfl_w,
float * __restrict__ out)
{
const int byte_off = (start_elem * 3) >> 3;
const uint32_t w = (uint32_t)packed_row[byte_off]
| ((uint32_t)packed_row[byte_off + 1] << 8)
| ((uint32_t)packed_row[byte_off + 2] << 16);
out[0] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 0) & 7, shfl_w) * rms;
out[1] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 3) & 7, shfl_w) * rms;
out[2] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 6) & 7, shfl_w) * rms;
out[3] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 9) & 7, shfl_w) * rms;
out[4] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 12) & 7, shfl_w) * rms;
out[5] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 15) & 7, shfl_w) * rms;
out[6] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 18) & 7, shfl_w) * rms;
out[7] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 21) & 7, shfl_w) * rms;
}
// 2-bit, 4 elements. start_elem is a multiple of 4.
// 8 bits = 1 byte, always byte-aligned.
static __device__ __forceinline__ void tq_decode_4_2bit(
const uint8_t * __restrict__ packed_row,
float cb_lane, float rms, int start_elem, int shfl_w,
float * __restrict__ out)
{
const uint32_t b = packed_row[start_elem >> 2];
out[0] = __shfl_sync(0xFFFFFFFF, cb_lane, (b >> 0) & 3, shfl_w) * rms;
out[1] = __shfl_sync(0xFFFFFFFF, cb_lane, (b >> 2) & 3, shfl_w) * rms;
out[2] = __shfl_sync(0xFFFFFFFF, cb_lane, (b >> 4) & 3, shfl_w) * rms;
out[3] = __shfl_sync(0xFFFFFFFF, cb_lane, (b >> 6) & 3, shfl_w) * rms;
}
// 2-bit, 8 elements. start_elem is a multiple of 8.
// 16 bits = 2 bytes.
static __device__ __forceinline__ void tq_decode_8_2bit(
const uint8_t * __restrict__ packed_row,
float cb_lane, float rms, int start_elem, int shfl_w,
float * __restrict__ out)
{
const int byte_off = start_elem >> 2;
const uint32_t w = (uint32_t)packed_row[byte_off] | ((uint32_t)packed_row[byte_off + 1] << 8);
out[0] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 0) & 3, shfl_w) * rms;
out[1] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 2) & 3, shfl_w) * rms;
out[2] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 4) & 3, shfl_w) * rms;
out[3] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 6) & 3, shfl_w) * rms;
out[4] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 8) & 3, shfl_w) * rms;
out[5] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 10) & 3, shfl_w) * rms;
out[6] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 12) & 3, shfl_w) * rms;
out[7] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 14) & 3, shfl_w) * rms;
}
#else
// Stubs for sm < 600 (no __shfl_sync). Never executed — the launcher
// asserts compute capability >= 6.0. These only exist so that the
// __device__ dispatch template below can compile for Maxwell targets.
static __device__ __forceinline__ void tq_decode_4_3bit(
const uint8_t * __restrict__, float, float, int, int,
float * __restrict__ out)
{ out[0] = out[1] = out[2] = out[3] = 0.0f; }
static __device__ __forceinline__ void tq_decode_8_3bit(
const uint8_t * __restrict__, float, float, int, int,
float * __restrict__ out)
{ out[0] = out[1] = out[2] = out[3] = out[4] = out[5] = out[6] = out[7] = 0.0f; }
static __device__ __forceinline__ void tq_decode_4_2bit(
const uint8_t * __restrict__, float, float, int, int,
float * __restrict__ out)
{ out[0] = out[1] = out[2] = out[3] = 0.0f; }
static __device__ __forceinline__ void tq_decode_8_2bit(
const uint8_t * __restrict__, float, float, int, int,
float * __restrict__ out)
{ out[0] = out[1] = out[2] = out[3] = out[4] = out[5] = out[6] = out[7] = 0.0f; }
#endif
// Dispatch: decode N elements (N=4 on Pascal, N=8 on Volta+).
template<int N>
static __device__ __forceinline__ void tq_decode_N_shfl(
const uint8_t * __restrict__ packed_row,
float cb_lane, float rms,
int start_elem, int bits, int shfl_w,
float * __restrict__ out)
{
if constexpr (N == 4) {
if (bits == 3) { tq_decode_4_3bit(packed_row, cb_lane, rms, start_elem, shfl_w, out); }
else { tq_decode_4_2bit(packed_row, cb_lane, rms, start_elem, shfl_w, out); }
} else {
if (bits == 3) { tq_decode_8_3bit(packed_row, cb_lane, rms, start_elem, shfl_w, out); }
else { tq_decode_8_2bit(packed_row, cb_lane, rms, start_elem, shfl_w, out); }
}
}
// Compute Q·K dot product where K is TQ-compressed.
// Uses warp-shuffle for codebook lookup (1-cycle register-to-register).
template<int D, int nthreads, int cpy_ne>
static __device__ __forceinline__ float tq_vec_dot_KQ(
const uint8_t * packed_row,
float cb_lane,
float rms_scale,
const float2 * Q_f,
int bits)
{
float sum = 0.0f;
const int tid_kq = threadIdx.x % nthreads;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
// Decode 2*cpy_ne K elements at once.
float k_dec[2*cpy_ne];
tq_decode_N_shfl<2*cpy_ne>(packed_row, cb_lane, rms_scale,
2*(k_KQ_0 + tid_kq*cpy_ne), bits, nthreads, k_dec);
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
const float2 q_pair = Q_f[k_KQ_0/nthreads + k_KQ_1];
sum += q_pair.x * k_dec[2*k_KQ_1] + q_pair.y * k_dec[2*k_KQ_1 + 1];
}
}
return sum;
}
// TurboQuant fused flash-attention kernel.
//
// Q: [D, nTokensQ, nHeadsQ, nSeq] f32 (after SDPA Permute(0,2,1,3))
// K_packed: [packedBytes*nKVHeads, capacity] i8 (encode result; base = full buffer)
// V: [D, nCells, nKVHeads, nSeq] f16 when !V_PACKED (after FA-branch Permute(0,2,1,3))
// [v_packedBytes*nKVHeads, capacity] i8 when V_PACKED (packed buffer, no permute)
// mask: [nCells, nTokensQ] f16 or NULL
// scales: [nKVHeads, capacity] f32 (K scales; base = full buffer)
// codebook: [1<<bits] f32 (K codebook)
//
// V_PACKED == false: existing K-only fused path (V is f16, src[6] == NULL)
// V_PACKED == true: new K+V fused path (V is packed i8, src[6] = v_scales, src[7] = v_codebook)
//
// Phase 1: D == 128, bits runtime param, gridDim.y == 1 (no multi-block).
template<int D, int ncols, bool use_logit_softcap, bool V_PACKED>
__launch_bounds__(128, 2)
static __global__ void tq_flash_attn_ext_vec(
const char * __restrict__ Q,
const uint8_t * __restrict__ K_packed,
const char * __restrict__ V,
const char * __restrict__ mask,
float * __restrict__ dst,
const float * __restrict__ scales,
const float * __restrict__ codebook,
float scale,
float logit_softcap,
int bits,
int firstCell,
int nCells,
int nKVHeads,
int packedBytes,
// Q geometry
int32_t ne00,
uint3 ne01, // init_fastdiv_values(nTokensQ); .z == nTokensQ
int32_t ne02, // nHeadsQ
int32_t ne03, // nSeq
int32_t nb01, // Q stride: bytes between consecutive tokens
int32_t nb02, // Q stride: bytes between consecutive heads
int64_t nb03, // Q stride: bytes between consecutive sequences
// V geometry (after permute: [D, nCells, nKVHeads]) — only used when !V_PACKED
int32_t nb21, // bytes between V cells (= D*nKVHeads*sizeof(half))
int32_t nb22, // bytes between V heads (= D*sizeof(half))
int64_t nb23, // bytes between V seqs
// mask geometry
int32_t ne31, // nCells (mask row width)
int32_t nb31, // mask stride: bytes between token rows
// V packed params (only used when V_PACKED == true)
const float * __restrict__ v_scales, // [nKVHeads, capacity] f32
const float * __restrict__ v_codebook, // [1<<v_bits] f32
int v_bits,
int v_packedBytes
)
{
#ifdef FLASH_ATTN_AVAILABLE
// Skip logit_softcap variants for unsupported D values (mirrors original kernel guard).
if (use_logit_softcap && D != 128 && D != 256) {
GGML_UNUSED_VARS(Q, K_packed, V, mask, dst, scales, codebook,
scale, logit_softcap, bits, firstCell, nCells, nKVHeads, packedBytes,
ne00, ne01, ne02, ne03, nb01, nb02, nb03,
nb21, nb22, nb23, ne31, nb31,
v_scales, v_codebook, v_bits, v_packedBytes);
NO_DEVICE_CODE;
return;
}
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); // 16 on Volta+
constexpr int cpy_ne = cpy_nb / 4; // 4
constexpr int nthreads = 128;
constexpr int nthreads_KQ = nthreads / cpy_nb; // 8
constexpr int nthreads_V = nthreads / cpy_nb; // 8
static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_KQ");
static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V");
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 64");
constexpr int V_rows_per_thread = 2 * cpy_ne; // 8
constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V; // 4
// dequantize_V is only needed for the f16 V path
[[maybe_unused]] constexpr dequantize_V_t dequantize_V =
V_PACKED ? (dequantize_V_t)nullptr
: get_dequantize_V<GGML_TYPE_F16, float, V_rows_per_thread>();
const int ic0 = blockIdx.x * ncols;
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z % ne02;
const int gqa_ratio = ne02 / nKVHeads;
const int head_kv = head / gqa_ratio;
// Advance base pointers.
Q += (int64_t)nb03*sequence + nb02*head + nb01*ic0;
K_packed += (int64_t)firstCell * nKVHeads * packedBytes + (int64_t)head_kv * packedBytes;
scales += (int64_t)firstCell * nKVHeads + head_kv;
// V pointer setup: f16 path uses stride-based addressing; packed path uses cell-index addressing.
const uint8_t * V_packed_base = nullptr;
const float * v_scales_base = nullptr;
if constexpr (V_PACKED) {
V_packed_base = (const uint8_t *)V
+ (int64_t)firstCell * nKVHeads * v_packedBytes
+ (int64_t)head_kv * v_packedBytes;
v_scales_base = v_scales
+ (int64_t)firstCell * nKVHeads + head_kv;
} else {
V += (int64_t)nb23*sequence + (int64_t)nb22*head_kv;
}
const half * maskh = mask ? (const half *)(mask + (int64_t)nb31*ic0) : nullptr;
// Load one codebook entry per lane for warp-shuffle lookups.
// For 3-bit (8 entries): lanes 0-7 hold codebook[0-7], lanes 8-15 repeat, etc.
// __shfl_sync with width=8 handles the wrap-around.
const float k_cb_lane = codebook[threadIdx.x & ((1 << bits) - 1)];
[[maybe_unused]] const float v_cb_lane = V_PACKED
? v_codebook[threadIdx.x & ((1 << v_bits) - 1)]
: 0.0f;
constexpr int nwarps = nthreads / WARP_SIZE;
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
__builtin_assume(tid < nthreads);
constexpr int ne_KQ = ncols * D;
constexpr int ne_combine = nwarps * V_cols_per_iter * D;
float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
__shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
float KQ_max[ncols];
float KQ_sum[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
KQ_max[j] = -FLT_MAX/2.0f;
KQ_sum[j] = 0.0f;
}
// Load Q into registers (float2 per half-pair, scale applied).
float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}};
#pragma unroll
for (int j = 0; j < ncols; ++j) {
const float2 * Q_j = (const float2 *)(Q + j*nb01);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
if (ncols == 1 || ic0 + j < (int)ne01.z) {
ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ], &Q_j[i]);
ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]);
}
}
// Apply attention scale.
#pragma unroll
for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {
Q_reg[j][k].x *= scale;
Q_reg[j][k].y *= scale;
}
}
// Main KV loop — single block (gridDim.y == 1).
for (int k_VKQ_0 = 0; k_VKQ_0 < nCells; k_VKQ_0 += nthreads,
V += (V_PACKED ? 0 : (int64_t)nthreads * nb21),
maskh += (maskh ? nthreads : 0)) {
float KQ_reg[ncols];
float KQ_max_new[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
KQ_max_new[j] = KQ_max[j];
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; ++i_KQ_0) {
const int i_KQ = threadIdx.y*WARP_SIZE
+ (nthreads_KQ == WARP_SIZE ? 0 : (threadIdx.x & ~(nthreads_KQ-1)))
+ i_KQ_0;
const int cell_rel = k_VKQ_0 + i_KQ;
const bool in_range = (cell_rel < nCells);
#pragma unroll
for (int j = 0; j < ncols; ++j) {
// Always execute the dot product to keep all lanes convergent
// for warp shuffles. rms_scale=0 for out-of-range cells.
const uint8_t * packed_row = K_packed + (int64_t)cell_rel * nKVHeads * packedBytes;
const float rms_scale = in_range ? scales[cell_rel * nKVHeads] : 0.0f;
float sum = tq_vec_dot_KQ<D, nthreads_KQ, cpy_ne>(
packed_row, k_cb_lane, rms_scale, Q_reg[j], bits);
sum = warp_reduce_sum<nthreads_KQ>(sum);
if (use_logit_softcap) {
sum = logit_softcap * tanhf(sum);
}
if (maskh && (ncols == 1 || ic0 + j < (int)ne01.z)) {
sum += __half2float(maskh[j*ne31 + i_KQ]);
}
if (!in_range) {
sum = -FLT_MAX/2.0f;
}
KQ_max_new[j] = fmaxf(KQ_max_new[j], sum + FATTN_KQ_MAX_OFFSET);
if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == (uint32_t)i_KQ_0) {
KQ_reg[j] = sum;
}
}
}
#pragma unroll
for (int j = 0; j < ncols; ++j) {
#pragma unroll
for (int offset = nthreads_KQ; offset < WARP_SIZE; offset <<= 1) {
KQ_max_new[j] = fmaxf(KQ_max_new[j], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[j], offset, WARP_SIZE));
}
const float KQ_max_scale = expf(KQ_max[j] - KQ_max_new[j]);
KQ_max[j] = KQ_max_new[j];
KQ_reg[j] = expf(KQ_reg[j] - KQ_max[j]);
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
KQ[j*nthreads + tid] = KQ_reg[j];
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
}
}
#ifndef GGML_USE_HIP
__syncwarp();
#endif
#pragma unroll
for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) {
const int k = threadIdx.y*WARP_SIZE + k0
+ (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V);
float KQ_k[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
KQ_k[j] = KQ[j*nthreads + k];
}
if constexpr (V_PACKED) {
// Decode V from packed buffer inline — warp-shuffle batch decode.
// Always execute decodes to keep all lanes convergent for shuffles.
// v_rms=0 for out-of-range cells produces zero contributions.
const int cell_rel = k_VKQ_0 + k;
const uint8_t * v_row = V_packed_base
+ (int64_t)cell_rel * nKVHeads * v_packedBytes;
const float v_rms = (cell_rel < nCells)
? v_scales_base[cell_rel * nKVHeads] : 0.0f;
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
const int base_elem = 2*i_VKQ_0
+ (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)
* V_rows_per_thread;
// Batch-decode V elements using shuffle codebook.
float v_dec[V_rows_per_thread];
tq_decode_N_shfl<V_rows_per_thread>(v_row, v_cb_lane, v_rms,
base_elem, v_bits, nthreads_V, v_dec);
#pragma unroll
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
#pragma unroll
for (int j = 0; j < ncols; ++j) {
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x += v_dec[2*i_VKQ_1] * KQ_k[j];
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y += v_dec[2*i_VKQ_1+1] * KQ_k[j];
}
}
}
} else {
// Original f16 V path (unchanged).
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
float2 tmp[V_rows_per_thread/2];
dequantize_V(V + k*nb21, tmp,
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
#pragma unroll
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
#pragma unroll
for (int j = 0; j < ncols; ++j) {
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x += tmp[i_VKQ_1].x * KQ_k[j];
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y += tmp[i_VKQ_1].y * KQ_k[j];
}
}
}
}
}
} // end KV loop
// --- Reduce across warps and write output ---
__shared__ float KQ_max_shared[ncols][WARP_SIZE];
__shared__ float KQ_sum_shared[ncols][WARP_SIZE];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (threadIdx.y == 0) {
KQ_max_shared[j][threadIdx.x] = -FLT_MAX/2.0f;
KQ_sum_shared[j][threadIdx.x] = 0.0f;
}
}
__syncthreads();
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (threadIdx.x == 0) {
KQ_max_shared[j][threadIdx.y] = KQ_max[j];
}
}
__syncthreads();
#pragma unroll
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
if (ncols > 1 && ic0 + j_VKQ >= (int)ne01.z) {
break;
}
float kqmax_new = KQ_max_shared[j_VKQ][threadIdx.x];
kqmax_new = warp_reduce_max(kqmax_new);
const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new);
KQ_max[j_VKQ] = kqmax_new;
float2 * VKQ_tmp = (float2 *)KQ + threadIdx.y*(V_cols_per_iter*D/2)
+ (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
VKQ[j_VKQ][i_VKQ_0/nthreads_V].x *= kqmax_scale;
VKQ[j_VKQ][i_VKQ_0/nthreads_V].y *= kqmax_scale;
}
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
const int i_VKQ = i_VKQ_0
+ (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*(V_rows_per_thread/2);
ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ,
&VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4,
&VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);
}
KQ_sum[j_VKQ] *= kqmax_scale;
KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);
if (threadIdx.x == 0) {
KQ_sum_shared[j_VKQ][threadIdx.y] = KQ_sum[j_VKQ];
}
__syncthreads();
if (nthreads <= D || tid < D) {
KQ_sum[j_VKQ] = KQ_sum_shared[j_VKQ][threadIdx.x];
KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);
#pragma unroll
for (int i0 = 0; i0 < D; i0 += nthreads) {
float dst_val = 0;
#pragma unroll
for (int w = 0; w < nwarps; ++w) {
#pragma unroll
for (int v = 0; v < V_cols_per_iter; ++v) {
dst_val += float(KQ[w*V_cols_per_iter*D + v*D + i0 + tid]);
}
}
dst_val /= KQ_sum[j_VKQ];
// Output layout: [D, nHeadsQ, nTokensQ, nSeq] — matches ggml_flash_attn_ext layout.
dst[(((int64_t)sequence*(int)ne01.z + ic0 + j_VKQ)*ne02 + head)*D + i0 + tid] = dst_val;
}
}
if (j_VKQ < ncols-1) {
__syncthreads();
}
}
#else
GGML_UNUSED_VARS(Q, K_packed, V, mask, dst, scales, codebook,
scale, logit_softcap, bits, firstCell, nCells, nKVHeads, packedBytes,
ne00, ne01, ne02, ne03, nb01, nb02, nb03,
nb21, nb22, nb23, ne31, nb31,
v_scales, v_codebook, v_bits, v_packedBytes);
NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}

View file

@ -0,0 +1,139 @@
#include "tq-fattn.cuh"
#include "tq-fattn-vec.cuh"
// Launch the TQ fused flash-attention kernel for a given (D, ncols, use_logit_softcap, V_PACKED).
template<int D, int ncols, bool use_logit_softcap, bool V_PACKED>
static void tq_fattn_vec_launch(ggml_backend_cuda_context & ctx, ggml_tensor * dst,
float scale, float logit_softcap,
int bits, int firstCell, int nCells, int nKVHeads, int packedBytes,
int v_bits, int v_packedBytes,
const float * v_scales_ptr, const float * v_codebook_ptr)
{
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K_p = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
const ggml_tensor * scales = dst->src[4];
const ggml_tensor * codebook = dst->src[5];
GGML_ASSERT(Q->ne[0] == D);
GGML_ASSERT(Q->type == GGML_TYPE_F32);
if constexpr (!V_PACKED) {
GGML_ASSERT(V->type == GGML_TYPE_F16);
} else {
GGML_ASSERT(V->type == GGML_TYPE_I8);
}
const int nTokensQ = (int)Q->ne[1];
const int nHeadsQ = (int)Q->ne[2];
const int nSeq = (int)Q->ne[3];
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
const uint3 ne01 = init_fastdiv_values((uint64_t)nTokensQ);
const int ntiles_x = (nTokensQ + ncols - 1) / ncols;
dim3 blocks(ntiles_x, 1, nHeadsQ * nSeq);
dim3 threads(WARP_SIZE, 4);
// V strides: only used by !V_PACKED path; pass V strides for the f16 case.
// For the V_PACKED case these are passed but ignored by the kernel.
tq_flash_attn_ext_vec<D, ncols, use_logit_softcap, V_PACKED><<<blocks, threads, 0, ctx.stream()>>>(
(const char *)Q->data,
(const uint8_t *)K_p->data,
(const char *)V->data,
mask ? (const char *)mask->data : nullptr,
(float *)dst->data,
(const float *)scales->data,
(const float *)codebook->data,
scale, logit_softcap, bits, firstCell, nCells, nKVHeads, packedBytes,
(int32_t)Q->ne[0],
ne01,
(int32_t)Q->ne[2],
(int32_t)Q->ne[3],
(int32_t)Q->nb[1],
(int32_t)Q->nb[2],
(int64_t)Q->nb[3],
(int32_t)V->nb[1],
(int32_t)V->nb[2],
(int64_t)V->nb[3],
mask ? (int32_t)mask->ne[0] : 0, // nCells (mask row width)
mask ? (int32_t)mask->nb[1] : 0, // bytes between token rows
v_scales_ptr, v_codebook_ptr, v_bits, v_packedBytes
);
}
void ggml_cuda_tq_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(ggml_cuda_info().devices[ctx.device].cc >= 600 &&
"TurboQuant fused flash attention requires compute capability 6.0+");
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * v_scales_t = dst->src[6]; // NULL for K-only fused
const ggml_tensor * v_codebook_t = dst->src[7];
float scale;
float logit_softcap;
int32_t bits;
int32_t firstCell;
int32_t v_bits;
memcpy(&scale, (const float *)dst->op_params + 0, sizeof(float));
memcpy(&logit_softcap, (const float *)dst->op_params + 1, sizeof(float));
memcpy(&bits, (const int32_t *)dst->op_params + 2, sizeof(int32_t));
memcpy(&firstCell, (const int32_t *)dst->op_params + 3, sizeof(int32_t));
memcpy(&v_bits, (const int32_t *)dst->op_params + 4, sizeof(int32_t));
const int D = (int)Q->ne[0];
const int nTokensQ = (int)Q->ne[1];
// nCells: when V_PACKED, V is the raw buffer so ne[1] is capacity not nCells.
// Compute nCells from the mask if available, else from V geometry.
const bool v_packed = (v_scales_t != nullptr);
int nCells;
if (v_packed) {
// V is packed i8 buffer: nCells is not directly in V->ne[].
// The mask always carries nCells; require mask when V is packed.
GGML_ASSERT(dst->src[3] != nullptr && "TQ K+V fused requires mask to determine nCells");
nCells = (int)dst->src[3]->ne[0];
} else {
nCells = (int)V->ne[1]; // after SDPA permute: [D, nCells, nKVHeads]
}
// packedBytes for K:
const int packedBytes = (D * bits + 7) / 8;
// packedBytes for V (computed before nKVHeads so we can derive nKVHeads from packed V dims):
const int v_packedBytes = v_packed ? ((D * v_bits + 7) / 8) : 0;
// nKVHeads: for K-only fused, V is a permuted f16 tensor with ne[2]=nKVHeads.
// For K+V fused, V is the raw packed i8 buffer with ne[0]=v_packedBytes*nKVHeads.
const int nKVHeads = v_packed ? (int)V->ne[0] / v_packedBytes : (int)V->ne[2];
const float * v_scales_ptr = v_scales_t ? (const float *)v_scales_t->data : nullptr;
const float * v_codebook_ptr = v_codebook_t ? (const float *)v_codebook_t->data : nullptr;
GGML_ASSERT(D == 128); // Phase 1: head_dim=128 only
if (logit_softcap != 0.0f) { scale /= logit_softcap; }
const int ncols = (nTokensQ == 1) ? 1 : 2;
#define DISPATCH(NCOLS, SOFTCAP) \
if (v_packed) { \
tq_fattn_vec_launch<128, NCOLS, SOFTCAP, true>(ctx, dst, scale, logit_softcap, \
bits, firstCell, nCells, nKVHeads, packedBytes, \
v_bits, v_packedBytes, v_scales_ptr, v_codebook_ptr); \
} else { \
tq_fattn_vec_launch<128, NCOLS, SOFTCAP, false>(ctx, dst, scale, logit_softcap, \
bits, firstCell, nCells, nKVHeads, packedBytes, \
0, 0, nullptr, nullptr); \
}
if (ncols == 1) {
if (logit_softcap == 0.0f) { DISPATCH(1, false); }
else { DISPATCH(1, true); }
} else {
if (logit_softcap == 0.0f) { DISPATCH(2, false); }
else { DISPATCH(2, true); }
}
#undef DISPATCH
CUDA_CHECK(cudaGetLastError());
}

View file

@ -0,0 +1,4 @@
#pragma once
#include "ggml-cuda/common.cuh"
void ggml_cuda_tq_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View file

@ -1704,3 +1704,23 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd(ggm
return res;
}
// TurboQuant pipeline getters — fixed-name kernels, no variants
static struct ggml_metal_pipeline_with_params tq_get_pipeline(ggml_metal_library_t lib, const char * name) {
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, name, name, nullptr);
}
return res;
}
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_dequant (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_dequant"); }
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_dequant_outlier(ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_dequant_outlier"); }
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_encode"); }
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_v (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_encode_v"); }
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_outlier(ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_encode_outlier"); }
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16 (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_f16"); }
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_packed"); }
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16_d256 (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_f16_d256"); }
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed_d256(ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_packed_d256"); }

View file

@ -188,6 +188,17 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_att
int32_t dv,
int32_t nwg);
// TurboQuant pipeline getters
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_dequant (ggml_metal_library_t lib);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_dequant_outlier(ggml_metal_library_t lib);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode (ggml_metal_library_t lib);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_v (ggml_metal_library_t lib);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_outlier(ggml_metal_library_t lib);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16 (ggml_metal_library_t lib);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed (ggml_metal_library_t lib);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16_d256 (ggml_metal_library_t lib);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed_d256(ggml_metal_library_t lib);
// MTLResidencySet wrapper
typedef void * ggml_metal_rset_t;

View file

@ -1181,6 +1181,13 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return has_simdgroup_reduction;
case GGML_OP_TQ_ENCODE:
case GGML_OP_TQ_ENCODE_V:
case GGML_OP_TQ_ENCODE_KV:
case GGML_OP_TQ_DEQUANT:
case GGML_OP_TQ_DEQUANT_KV:
case GGML_OP_TQ_FLASH_ATTN_EXT:
return has_simdgroup_reduction;
default:
return false;
}

File diff suppressed because it is too large Load diff

View file

@ -942,4 +942,82 @@ typedef struct {
int64_t np;
} ggml_metal_kargs_opt_step_sgd;
// TurboQuant (TQ) — dynamic KV-cache compression at inference time
typedef struct {
int32_t headDim;
int32_t numKVHeads;
int32_t bits;
int32_t firstCell;
int32_t packed_bytes; // (headDim * bits + 7) / 8
int32_t codebook_len;
} ggml_metal_kargs_tq_dequant;
typedef struct {
int32_t headDim;
int32_t numKVHeads;
int32_t bits;
int32_t firstCell;
int32_t reg_packed_bytes; // padded: (reg_count * bits + 7) / 8, aligned to 4
int32_t outlier_bits;
int32_t outlier_count;
int32_t out_packed_bytes; // padded: (outlierCount * outlierBits + 7) / 8, aligned to 4
} ggml_metal_kargs_tq_dequant_outlier;
typedef struct {
int32_t headDim;
int32_t numKVHeads;
int32_t k_bits;
int32_t v_bits;
int32_t firstCell;
int32_t k_packed_bytes;
int32_t v_packed_bytes;
int32_t k_codebook_len;
int32_t v_codebook_len;
} ggml_metal_kargs_tq_dequant_kv;
typedef struct {
int32_t headDim;
int32_t numKVHeads;
int32_t bits;
int32_t firstCell;
int32_t kIsF32; // 1 = f32 input, 0 = f16
int32_t hasRotation; // 1 = apply rotation matrix, 0 = skip
} ggml_metal_kargs_tq_encode;
typedef struct {
int32_t headDim;
int32_t numKVHeads;
int32_t bits;
int32_t firstCell;
int32_t kIsF32;
int32_t outlierBits;
int32_t outlierCount;
} ggml_metal_kargs_tq_encode_outlier;
typedef struct {
int32_t ncols; // 1 or 2 (nTokensQ == 1 ? 1 : 2)
int32_t nTokensQ;
int32_t nHeadsQ;
int32_t nSeq;
int32_t nCells;
int32_t nKVHeads;
int32_t bits; // K bits
int32_t firstCell;
int32_t packedBytes; // K packed bytes per head
int32_t v_bits; // 0 when V is f16
int32_t v_packedBytes;
int32_t hasMask; // 1 = mask buffer valid
int32_t ne31; // mask row width (= nCells)
float scale;
float logit_softcap;
uint64_t nb01; // Q stride: bytes between consecutive tokens
uint64_t nb02; // Q stride: bytes between consecutive heads
uint64_t nb03; // Q stride: bytes between consecutive sequences
uint64_t nb21; // V stride: bytes between cells (f16 path only)
uint64_t nb22; // V stride: bytes between heads (f16 path only)
uint64_t nb23; // V stride: bytes between seqs (f16 path only)
uint64_t nb31; // mask stride: bytes between token rows
} ggml_metal_kargs_tq_fattn_vec;
#endif // GGML_METAL_IMPL

View file

@ -452,6 +452,30 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
} break;
case GGML_OP_TQ_ENCODE:
{
n_fuse = ggml_metal_op_tq_encode(ctx, idx);
} break;
case GGML_OP_TQ_ENCODE_V:
{
n_fuse = ggml_metal_op_tq_encode_v(ctx, idx);
} break;
case GGML_OP_TQ_ENCODE_KV:
{
n_fuse = ggml_metal_op_tq_encode_kv(ctx, idx);
} break;
case GGML_OP_TQ_DEQUANT:
{
n_fuse = ggml_metal_op_tq_dequant(ctx, idx);
} break;
case GGML_OP_TQ_DEQUANT_KV:
{
n_fuse = ggml_metal_op_tq_dequant_kv(ctx, idx);
} break;
case GGML_OP_TQ_FLASH_ATTN_EXT:
{
n_fuse = ggml_metal_op_tq_flash_attn_ext(ctx, idx);
} break;
default:
{
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
@ -4154,3 +4178,465 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
return 1;
}
// ── TurboQuant Metal ops ───────────────────────────────────────────────────
int ggml_metal_op_tq_dequant(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
const int headDim = (int)op->ne[0];
const int numKVHeads = (int)op->ne[1];
const int nCells = (int)op->ne[2];
const int bits = (int)((const int32_t *)op->op_params)[0];
const int firstCell = (int)((const int32_t *)op->op_params)[1];
const int outlierBits = (int)((const int32_t *)op->op_params)[2];
const int outlierCount = (int)((const int32_t *)op->op_params)[3];
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
// Outlier kernel uses a 128-thread TG: threadgroup barriers + atomics on
// s_mask require all threads. Non-outlier kernel uses a single simdgroup
// (32 threads): it only reads tiisg and has no barriers, so a larger TG
// just replicates work across idle simdgroups.
const int outlier_block_size = 128;
const int nonoutlier_block_size = 32;
if (outlierCount > 0 && outlierBits > 0 && outlierCount < headDim) {
const int regular_count = headDim - outlierCount;
const int reg_packed_raw = (regular_count * bits + 7) / 8;
const int reg_packed_bytes = (reg_packed_raw + 3) & ~3;
const int out_packed_raw = (outlierCount * outlierBits + 7) / 8;
const int out_packed_bytes = (out_packed_raw + 3) & ~3;
ggml_metal_kargs_tq_dequant_outlier args = {
/*.headDim =*/ headDim,
/*.numKVHeads =*/ numKVHeads,
/*.bits =*/ bits,
/*.firstCell =*/ firstCell,
/*.reg_packed_bytes=*/ reg_packed_bytes,
/*.outlier_bits =*/ outlierBits,
/*.outlier_count =*/ outlierCount,
/*.out_packed_bytes=*/ out_packed_bytes,
};
auto pipeline = ggml_metal_library_get_pipeline_tq_dequant_outlier(lib);
const int mask_words = (headDim + 31) >> 5;
const size_t smem = GGML_PAD((size_t)headDim * sizeof(int8_t) + (size_t)mask_words * sizeof(uint32_t), 16);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); // reg packed
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); // reg scales
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); // reg codebook
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), 4); // outlier packed
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), 5); // outlier scales
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), 6); // outlier indices
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 7); // outlier codebook
ggml_metal_encoder_set_buffer (enc, bid_dst, 8);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, outlier_block_size, 1, 1);
return 1;
}
const int packed_bytes = (headDim * bits + 7) / 8;
const int codebook_len = (int)op->src[2]->ne[0];
ggml_metal_kargs_tq_dequant args = {
/*.headDim =*/ headDim,
/*.numKVHeads =*/ numKVHeads,
/*.bits =*/ bits,
/*.firstCell =*/ firstCell,
/*.packed_bytes =*/ packed_bytes,
/*.codebook_len =*/ codebook_len,
};
auto pipeline = ggml_metal_library_get_pipeline_tq_dequant(lib);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); // packed
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); // scales
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); // codebook
ggml_metal_encoder_set_buffer (enc, bid_dst, 4);
ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, nonoutlier_block_size, 1, 1);
return 1;
}
int ggml_metal_op_tq_dequant_kv(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
const int headDim = (int)op->ne[0];
const int numKVHeads = (int)op->ne[1];
const int nCells = (int)op->ne[2];
int32_t k_bits, v_bits, firstCell;
memcpy(&k_bits, (const int32_t *)op->op_params + 0, sizeof(int32_t));
memcpy(&v_bits, (const int32_t *)op->op_params + 1, sizeof(int32_t));
memcpy(&firstCell, (const int32_t *)op->op_params + 2, sizeof(int32_t));
const int k_packed_bytes = (headDim * k_bits + 7) / 8;
const int v_packed_bytes = (headDim * v_bits + 7) / 8;
const int k_codebook_len = (int)op->src[2]->ne[0];
const int v_codebook_len = (int)op->src[5]->ne[0];
// kernel_tq_dequant is single-simdgroup (uses only tiisg, no barriers,
// no atomics) — 32-thread TGs eliminate 4× redundant work vs 128-thread.
const int block_size = 32;
const size_t plane_size = (size_t)headDim * numKVHeads * nCells;
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
ggml_metal_buffer_id bid_dst_v = bid_dst;
bid_dst_v.offs += plane_size * sizeof(uint16_t);
auto pipeline = ggml_metal_library_get_pipeline_tq_dequant(lib);
// K dequant → first plane
{
ggml_metal_kargs_tq_dequant args_k = {
/*.headDim =*/ headDim,
/*.numKVHeads =*/ numKVHeads,
/*.bits =*/ k_bits,
/*.firstCell =*/ firstCell,
/*.packed_bytes =*/ k_packed_bytes,
/*.codebook_len =*/ k_codebook_len,
};
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args_k, sizeof(args_k), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
ggml_metal_encoder_set_buffer (enc, bid_dst, 4);
ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, block_size, 1, 1);
}
// V dequant → second plane
{
ggml_metal_kargs_tq_dequant args_v = {
/*.headDim =*/ headDim,
/*.numKVHeads =*/ numKVHeads,
/*.bits =*/ v_bits,
/*.firstCell =*/ firstCell,
/*.packed_bytes =*/ v_packed_bytes,
/*.codebook_len =*/ v_codebook_len,
};
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args_v, sizeof(args_v), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), 3);
ggml_metal_encoder_set_buffer (enc, bid_dst_v, 4);
ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, block_size, 1, 1);
}
return 1;
}
static int tq_encode_block_size(int headDim) {
int bs = 1;
int target = std::min(headDim, 128);
while (bs < target) bs <<= 1;
return bs;
}
int ggml_metal_op_tq_encode(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
const int headDim = (int)op->src[0]->ne[0];
const int numKVHeads = (int)op->src[0]->ne[1];
const int batchSize = (int)op->src[0]->ne[2];
const int bits = (int)((const int32_t *)op->op_params)[0];
const int firstCell = (int)((const int32_t *)op->op_params)[1];
const int outlierBits = (int)((const int32_t *)op->op_params)[2];
const int outlierCount = (int)((const int32_t *)op->op_params)[3];
const int kIsF32 = (op->src[0]->type == GGML_TYPE_F32) ? 1 : 0;
const int block_size = tq_encode_block_size(headDim);
if (outlierCount > 0 && outlierBits > 0 && outlierCount < headDim) {
ggml_metal_kargs_tq_encode_outlier args = {
/*.headDim =*/ headDim,
/*.numKVHeads =*/ numKVHeads,
/*.bits =*/ bits,
/*.firstCell =*/ firstCell,
/*.kIsF32 =*/ kIsF32,
/*.outlierBits =*/ outlierBits,
/*.outlierCount =*/ outlierCount,
};
auto pipeline = ggml_metal_library_get_pipeline_tq_encode_outlier(lib);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); // k
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); // rotation
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); // packed_out (dst)
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), 4); // scales
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), 5); // boundaries
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), 6); // outlier_packed
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 7); // outlier_scales
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[7]), 8); // outlier_indices
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[8]), 9); // outlier_boundaries
ggml_metal_encoder_dispatch_threadgroups(enc, batchSize, numKVHeads, 1, block_size, 1, 1);
return 1;
}
const int hasRotation = (op->src[1] != nullptr) ? 1 : 0;
ggml_metal_kargs_tq_encode args = {
/*.headDim =*/ headDim,
/*.numKVHeads =*/ numKVHeads,
/*.bits =*/ bits,
/*.firstCell =*/ firstCell,
/*.kIsF32 =*/ kIsF32,
/*.hasRotation =*/ hasRotation,
};
auto pipeline = ggml_metal_library_get_pipeline_tq_encode(lib);
ggml_metal_buffer_id bid_rot = hasRotation
? ggml_metal_get_buffer_id(op->src[1])
: ggml_metal_get_buffer_id(op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); // k
ggml_metal_encoder_set_buffer (enc, bid_rot, 2); // rotation (or dummy)
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); // packed_out (dst)
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), 4); // scales
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), 5); // boundaries
ggml_metal_encoder_dispatch_threadgroups(enc, batchSize, numKVHeads, 1, block_size, 1, 1);
return 1;
}
int ggml_metal_op_tq_encode_v(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
const int headDim = (int)op->src[0]->ne[0];
const int numKVHeads = (int)op->src[0]->ne[1];
const int batchSize = (int)op->src[0]->ne[2];
const int bits = (int)((const int32_t *)op->op_params)[0];
const int firstCell = (int)((const int32_t *)op->op_params)[1];
const int vIsF32 = (op->src[0]->type == GGML_TYPE_F32) ? 1 : 0;
const int hasRotation = (op->src[1] != nullptr) ? 1 : 0;
const int block_size = tq_encode_block_size(headDim);
ggml_metal_kargs_tq_encode args = {
/*.headDim =*/ headDim,
/*.numKVHeads =*/ numKVHeads,
/*.bits =*/ bits,
/*.firstCell =*/ firstCell,
/*.kIsF32 =*/ vIsF32,
/*.hasRotation =*/ hasRotation,
};
auto pipeline = ggml_metal_library_get_pipeline_tq_encode_v(lib);
ggml_metal_buffer_id bid_rot = hasRotation
? ggml_metal_get_buffer_id(op->src[1])
: ggml_metal_get_buffer_id(op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); // v
ggml_metal_encoder_set_buffer (enc, bid_rot, 2); // rotation (or dummy)
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); // packed_out (dst)
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), 4); // scales
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), 5); // boundaries
ggml_metal_encoder_dispatch_threadgroups(enc, batchSize, numKVHeads, 1, block_size, 1, 1);
return 1;
}
int ggml_metal_op_tq_encode_kv(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
// src layout: [0]=K, [1]=rotation, [2]=V, [3]=K_scales, [4]=K_bounds,
// [5]=V_packed, [6]=V_scales, [7]=V_bounds
const int headDim = (int)op->src[0]->ne[0];
const int numKVHeads = (int)op->src[0]->ne[1];
const int batchSize = (int)op->src[0]->ne[2];
const int kIsF32 = (op->src[0]->type == GGML_TYPE_F32) ? 1 : 0;
const int vIsF32 = (op->src[2]->type == GGML_TYPE_F32) ? 1 : 0;
const int hasRotation = (op->src[1] != nullptr) ? 1 : 0;
int32_t k_bits, v_bits, firstCell;
memcpy(&k_bits, (const int32_t *)op->op_params + 0, sizeof(int32_t));
memcpy(&v_bits, (const int32_t *)op->op_params + 1, sizeof(int32_t));
memcpy(&firstCell, (const int32_t *)op->op_params + 2, sizeof(int32_t));
const int block_size = tq_encode_block_size(headDim);
ggml_metal_buffer_id bid_rot = hasRotation
? ggml_metal_get_buffer_id(op->src[1])
: ggml_metal_get_buffer_id(op);
// K encode
{
ggml_metal_kargs_tq_encode args_k = {
/*.headDim =*/ headDim,
/*.numKVHeads =*/ numKVHeads,
/*.bits =*/ k_bits,
/*.firstCell =*/ firstCell,
/*.kIsF32 =*/ kIsF32,
/*.hasRotation =*/ hasRotation,
};
auto pipeline_k = ggml_metal_library_get_pipeline_tq_encode(lib);
ggml_metal_encoder_set_pipeline(enc, pipeline_k);
ggml_metal_encoder_set_bytes (enc, &args_k, sizeof(args_k), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); // K
ggml_metal_encoder_set_buffer (enc, bid_rot, 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); // dst = K packed
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), 4); // K scales
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), 5); // K bounds
ggml_metal_encoder_dispatch_threadgroups(enc, batchSize, numKVHeads, 1, block_size, 1, 1);
}
// V encode
{
ggml_metal_kargs_tq_encode args_v = {
/*.headDim =*/ headDim,
/*.numKVHeads =*/ numKVHeads,
/*.bits =*/ v_bits,
/*.firstCell =*/ firstCell,
/*.kIsF32 =*/ vIsF32,
/*.hasRotation =*/ hasRotation,
};
auto pipeline_v = ggml_metal_library_get_pipeline_tq_encode_v(lib);
ggml_metal_encoder_set_pipeline(enc, pipeline_v);
ggml_metal_encoder_set_bytes (enc, &args_v, sizeof(args_v), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 1); // V
ggml_metal_encoder_set_buffer (enc, bid_rot, 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), 3); // V packed
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 4); // V scales
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[7]), 5); // V bounds
ggml_metal_encoder_dispatch_threadgroups(enc, batchSize, numKVHeads, 1, block_size, 1, 1);
}
return 1;
}
int ggml_metal_op_tq_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
const ggml_tensor * Q = op->src[0];
const ggml_tensor * K_p = op->src[1];
const ggml_tensor * V = op->src[2];
const ggml_tensor * mask = op->src[3];
// src[4]=K_scales, src[5]=K_codebook, src[6]=V_scales (NULL for K-only), src[7]=V_codebook
const bool v_packed = (op->src[6] != nullptr);
float scale;
float logit_softcap;
int32_t bits;
int32_t firstCell;
int32_t v_bits;
memcpy(&scale, (const float *)op->op_params + 0, sizeof(float));
memcpy(&logit_softcap, (const float *)op->op_params + 1, sizeof(float));
memcpy(&bits, (const int32_t *)op->op_params + 2, sizeof(int32_t));
memcpy(&firstCell, (const int32_t *)op->op_params + 3, sizeof(int32_t));
memcpy(&v_bits, (const int32_t *)op->op_params + 4, sizeof(int32_t));
if (logit_softcap != 0.0f) { scale /= logit_softcap; }
const int D = (int)Q->ne[0];
const int nTokensQ = (int)Q->ne[1];
const int nHeadsQ = (int)Q->ne[2];
const int nSeq = (int)Q->ne[3];
const int packedBytes = (D * bits + 7) / 8;
const int v_packedBytes = v_packed ? ((D * v_bits + 7) / 8) : 0;
const int nKVHeads = v_packed ? ((int)V->ne[0] / v_packedBytes) : (int)V->ne[2];
int nCells;
if (v_packed) {
GGML_ASSERT(mask != nullptr);
nCells = (int)mask->ne[0];
} else {
nCells = (int)V->ne[1];
}
const int ncols = (nTokensQ == 1) ? 1 : 2;
const int ntiles_x = (nTokensQ + ncols - 1) / ncols;
const int hasMask = mask ? 1 : 0;
const int ne31 = mask ? (int)mask->ne[0] : 0;
ggml_metal_kargs_tq_fattn_vec args = {
/*.ncols =*/ ncols,
/*.nTokensQ =*/ nTokensQ,
/*.nHeadsQ =*/ nHeadsQ,
/*.nSeq =*/ nSeq,
/*.nCells =*/ nCells,
/*.nKVHeads =*/ nKVHeads,
/*.bits =*/ bits,
/*.firstCell =*/ firstCell,
/*.packedBytes =*/ packedBytes,
/*.v_bits =*/ v_bits,
/*.v_packedBytes=*/ v_packedBytes,
/*.hasMask =*/ hasMask,
/*.ne31 =*/ ne31,
/*.scale =*/ scale,
/*.logit_softcap=*/ logit_softcap,
/*.nb01 =*/ Q->nb[1],
/*.nb02 =*/ Q->nb[2],
/*.nb03 =*/ Q->nb[3],
/*.nb21 =*/ V->nb[1],
/*.nb22 =*/ V->nb[2],
/*.nb23 =*/ V->nb[3],
/*.nb31 =*/ mask ? mask->nb[1] : 0,
};
// Select D=128 vs D=256 pipeline. Gemma3 runs at headDim=256; everything
// else supported so far is D=128.
GGML_ASSERT(D == 128 || D == 256);
auto pipeline = v_packed
? (D == 256
? ggml_metal_library_get_pipeline_tq_fattn_vec_packed_d256(lib)
: ggml_metal_library_get_pipeline_tq_fattn_vec_packed(lib))
: (D == 256
? ggml_metal_library_get_pipeline_tq_fattn_vec_f16_d256(lib)
: ggml_metal_library_get_pipeline_tq_fattn_vec_f16(lib));
ggml_metal_buffer_id bid_mask = hasMask ? ggml_metal_get_buffer_id(mask) : ggml_metal_get_buffer_id(op);
ggml_metal_buffer_id bid_v_scales = v_packed ? ggml_metal_get_buffer_id(op->src[6]) : ggml_metal_get_buffer_id(op);
ggml_metal_buffer_id bid_v_codebook = v_packed ? ggml_metal_get_buffer_id(op->src[7]) : ggml_metal_get_buffer_id(op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(Q), 1); // Q
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(K_p), 2); // K packed
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(V), 3); // V (f16 or packed i8)
ggml_metal_encoder_set_buffer (enc, bid_mask, 4); // mask
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), 5); // K scales
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), 6); // K codebook
ggml_metal_encoder_set_buffer (enc, bid_v_scales, 7); // V scales (or dummy)
ggml_metal_encoder_set_buffer (enc, bid_v_codebook, 8); // V codebook (or dummy)
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 9); // dst
// Block: (32, 4, 1) = 128 threads; Grid: (ntiles_x, 1, nHeadsQ*nSeq)
ggml_metal_encoder_dispatch_threadgroups(enc, ntiles_x, 1, nHeadsQ * nSeq, 32, 4, 1);
return 1;
}

View file

@ -89,6 +89,14 @@ int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
// TurboQuant ops
int ggml_metal_op_tq_encode (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_tq_encode_v (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_tq_encode_kv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_tq_dequant (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_tq_dequant_kv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_tq_flash_attn_ext (ggml_metal_op_t ctx, int idx);
#ifdef __cplusplus
}
#endif

File diff suppressed because it is too large Load diff

View file

@ -1048,9 +1048,15 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"OPT_STEP_SGD",
"GLU",
"TQ_ENCODE",
"TQ_DEQUANT",
"TQ_DEQUANT_KV",
"TQ_FLASH_ATTN_EXT",
"TQ_ENCODE_V",
"TQ_ENCODE_KV",
};
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static_assert(GGML_OP_COUNT == 101, "GGML_OP_COUNT != 101");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@ -1157,9 +1163,14 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"sgd(x)",
"glu(x)",
"tq_encode(k,rot,idx)->packed",
"tq_dequant(packed,scales)->f16",
"tq_flash_attn_ext(q,k_packed,v)->f32",
"tq_encode_v(v)->packed",
"tq_encode_kv(k,v)->packed",
};
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
static_assert(GGML_OP_COUNT == 101, "GGML_OP_COUNT != 101");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@ -7603,3 +7614,227 @@ bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, cons
if (p0->strict_cpu != p1->strict_cpu ) return false;
return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0;
}
struct ggml_tensor * ggml_tq_encode(
struct ggml_context * ctx,
struct ggml_tensor * packed,
struct ggml_tensor * scales,
struct ggml_tensor * k,
struct ggml_tensor * rotation,
int32_t firstCell,
struct ggml_tensor * boundaries,
int32_t bits) {
struct ggml_tensor * result = ggml_view_tensor(ctx, packed);
result->op = GGML_OP_TQ_ENCODE;
result->src[0] = k;
result->src[1] = rotation;
result->src[2] = NULL; // cell_idx removed; kernel computes firstCell + batch
result->src[3] = scales;
result->src[4] = boundaries;
ggml_set_op_params_i32(result, 0, bits);
ggml_set_op_params_i32(result, 1, firstCell);
ggml_set_op_params_i32(result, 2, 0); // outlier_bits (0 = uniform)
ggml_set_op_params_i32(result, 3, 0); // outlier_count (0 = uniform)
return result;
}
struct ggml_tensor * ggml_tq_dequant(
struct ggml_context * ctx,
struct ggml_tensor * encode_result,
struct ggml_tensor * scales,
struct ggml_tensor * codebook,
int headDim, int numKVHeads, int nCells, int firstCell, int bits) {
struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F16,
headDim, numKVHeads, nCells);
result->op = GGML_OP_TQ_DEQUANT;
result->src[0] = encode_result;
result->src[1] = scales;
result->src[2] = codebook;
ggml_set_op_params_i32(result, 0, (int32_t)bits);
ggml_set_op_params_i32(result, 1, (int32_t)firstCell);
ggml_set_op_params_i32(result, 2, 0); // outlier_bits (0 = uniform)
ggml_set_op_params_i32(result, 3, 0); // outlier_count (0 = uniform)
return result;
}
// ggml_tq_encode_outlier: extends ggml_tq_encode with an outlier sub-block.
// Same op (GGML_OP_TQ_ENCODE) with outlier_count > 0 in op_params[3]; the
// CUDA backend dispatches to the outlier-aware kernel when it sees a non-zero
// outlier_count in op_params. The regular packed buffer is the dst (view);
// the outlier packed, scales, and indices are written as side effects via
// src[5..8], same pattern as the regular scales src[3].
struct ggml_tensor * ggml_tq_encode_outlier(
struct ggml_context * ctx,
struct ggml_tensor * packed,
struct ggml_tensor * scales,
struct ggml_tensor * k,
struct ggml_tensor * rotation,
int32_t firstCell,
struct ggml_tensor * boundaries,
int32_t bits,
struct ggml_tensor * outlier_packed,
struct ggml_tensor * outlier_scales,
struct ggml_tensor * outlier_indices,
struct ggml_tensor * outlier_boundaries,
int32_t outlier_bits,
int32_t outlier_count) {
struct ggml_tensor * result = ggml_view_tensor(ctx, packed);
result->op = GGML_OP_TQ_ENCODE;
result->src[0] = k;
result->src[1] = rotation;
result->src[2] = NULL;
result->src[3] = scales;
result->src[4] = boundaries;
result->src[5] = outlier_packed;
result->src[6] = outlier_scales;
result->src[7] = outlier_indices;
result->src[8] = outlier_boundaries;
ggml_set_op_params_i32(result, 0, bits);
ggml_set_op_params_i32(result, 1, firstCell);
ggml_set_op_params_i32(result, 2, outlier_bits);
ggml_set_op_params_i32(result, 3, outlier_count);
return result;
}
// ggml_tq_dequant_outlier: extends ggml_tq_dequant with an outlier overwrite
// pass. Reconstructs [headDim, numKVHeads, nCells] f16 by decoding the
// regular sub-block for all 128 positions, then overwriting the outlier
// channel positions from the outlier sub-block.
struct ggml_tensor * ggml_tq_dequant_outlier(
struct ggml_context * ctx,
struct ggml_tensor * encode_result,
struct ggml_tensor * scales,
struct ggml_tensor * codebook,
int headDim, int numKVHeads, int nCells, int firstCell, int bits,
struct ggml_tensor * outlier_packed,
struct ggml_tensor * outlier_scales,
struct ggml_tensor * outlier_indices,
struct ggml_tensor * outlier_codebook,
int32_t outlier_bits,
int32_t outlier_count) {
struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F16,
headDim, numKVHeads, nCells);
result->op = GGML_OP_TQ_DEQUANT;
result->src[0] = encode_result;
result->src[1] = scales;
result->src[2] = codebook;
result->src[3] = outlier_packed;
result->src[4] = outlier_scales;
result->src[5] = outlier_indices;
result->src[6] = outlier_codebook;
ggml_set_op_params_i32(result, 0, (int32_t)bits);
ggml_set_op_params_i32(result, 1, (int32_t)firstCell);
ggml_set_op_params_i32(result, 2, outlier_bits);
ggml_set_op_params_i32(result, 3, outlier_count);
return result;
}
struct ggml_tensor * ggml_tq_dequant_kv(
struct ggml_context * ctx,
struct ggml_tensor * k_encode_result,
struct ggml_tensor * k_scales,
struct ggml_tensor * k_codebook,
struct ggml_tensor * v_encode_result,
struct ggml_tensor * v_scales,
struct ggml_tensor * v_codebook,
struct ggml_tensor * v_rotation,
int headDim, int numKVHeads, int nCells, int firstCell,
int k_bits, int v_bits) {
// Output: [headDim, numKVHeads, nCells, 2] f16 — last dim separates K (0) and V (1).
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F16,
headDim, numKVHeads, nCells, 2);
result->op = GGML_OP_TQ_DEQUANT_KV;
result->src[0] = k_encode_result;
result->src[1] = k_scales;
result->src[2] = k_codebook;
result->src[3] = v_encode_result;
result->src[4] = v_scales;
result->src[5] = v_codebook;
result->src[6] = v_rotation; // NULL = no rotation fusion
ggml_set_op_params_i32(result, 0, (int32_t)k_bits);
ggml_set_op_params_i32(result, 1, (int32_t)v_bits);
ggml_set_op_params_i32(result, 2, (int32_t)firstCell);
return result;
}
struct ggml_tensor * ggml_tq_flash_attn_ext(
struct ggml_context * ctx,
struct ggml_tensor * q,
struct ggml_tensor * k_packed,
struct ggml_tensor * v,
struct ggml_tensor * mask,
struct ggml_tensor * scales,
struct ggml_tensor * codebook,
float scale, float logit_softcap,
int32_t bits, int32_t firstCell,
struct ggml_tensor * v_scales,
struct ggml_tensor * v_codebook,
int32_t v_bits) {
// Output: [D, nHeadsQ, nTokensQ, nSeq] f32 — same shape as standard flash_attn_ext.
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32,
q->ne[0], q->ne[2], q->ne[1], q->ne[3]);
result->op = GGML_OP_TQ_FLASH_ATTN_EXT;
result->src[0] = q;
result->src[1] = k_packed;
result->src[2] = v;
result->src[3] = mask;
result->src[4] = scales;
result->src[5] = codebook;
result->src[6] = v_scales; // NULL → V is f16 (K-only fused); non-NULL → V is packed
result->src[7] = v_codebook; // NULL when V is f16
ggml_set_op_params_f32(result, 0, scale);
ggml_set_op_params_f32(result, 1, logit_softcap);
ggml_set_op_params_i32(result, 2, bits);
ggml_set_op_params_i32(result, 3, firstCell);
ggml_set_op_params_i32(result, 4, v_bits);
return result;
}
struct ggml_tensor * ggml_tq_encode_v(
struct ggml_context * ctx,
struct ggml_tensor * packed,
struct ggml_tensor * scales,
struct ggml_tensor * v,
struct ggml_tensor * rotation,
int32_t firstCell,
struct ggml_tensor * boundaries,
int32_t bits) {
struct ggml_tensor * result = ggml_view_tensor(ctx, packed);
result->op = GGML_OP_TQ_ENCODE_V;
result->src[0] = v;
result->src[1] = rotation; // NULL = no rotation, non-NULL = R^T matrix
result->src[2] = NULL;
result->src[3] = scales;
result->src[4] = boundaries;
ggml_set_op_params_i32(result, 0, bits);
ggml_set_op_params_i32(result, 1, firstCell);
return result;
}
struct ggml_tensor * ggml_tq_encode_kv(
struct ggml_context * ctx,
struct ggml_tensor * k_packed,
struct ggml_tensor * k_scales,
struct ggml_tensor * k,
struct ggml_tensor * rotation,
struct ggml_tensor * k_boundaries,
struct ggml_tensor * v_packed,
struct ggml_tensor * v_scales,
struct ggml_tensor * v,
struct ggml_tensor * v_boundaries,
int32_t firstCell, int32_t k_bits, int32_t v_bits) {
struct ggml_tensor * result = ggml_view_tensor(ctx, k_packed);
result->op = GGML_OP_TQ_ENCODE_KV;
result->src[0] = k;
result->src[1] = rotation;
result->src[2] = v;
result->src[3] = k_scales;
result->src[4] = k_boundaries;
result->src[5] = v_packed;
result->src[6] = v_scales;
result->src[7] = v_boundaries;
ggml_set_op_params_i32(result, 0, k_bits);
ggml_set_op_params_i32(result, 1, v_bits);
ggml_set_op_params_i32(result, 2, firstCell);
return result;
}

View file

@ -0,0 +1,70 @@
package ggml
import "fmt"
// TurboQuant compute-capability thresholds.
//
// NVIDIA path: Pascal (cc 6.0) is the floor because the TQ codebook lookup
// relies on __shfl_sync, introduced on Kepler and universally stable from
// Pascal onwards; earlier archs (Maxwell and below) trip a compile-time
// assert inside tq-dequant.cu.
//
// AMD path: ggml-cuda encodes the gfx arch into props.compute_major as
// (cc - OFFSET_AMD) / 0x100 (see ml/backend/ggml/ggml/src/ggml-cuda/
// ggml-cuda.cu), so Vega/GCN/CDNA land at major=9 (0x9XX) and RDNA1+ lands
// at major>=16 (0x1010 and up). Every RDNA generation is wave32; every
// pre-RDNA AMD generation is wave64. That single boundary drives the gate:
// TQ's 32-lane __shfl_sync becomes __shfl(_, _, 32) under the HIP shim
// (vendors/hip.h), and on a 64-lane warp that sub-partitions into two
// independent 32-lane groups — lanes 32-63 never receive codebook data
// from the CUDA-tuned kernel and return garbage.
const (
tqMinNvidiaComputeMajor = 6 // Pascal
tqMinAmdComputeMajor = 16 // RDNA1 gfx1010
)
// tqDeviceAccepted returns whether TurboQuant kernels can safely run on a
// device identified by its backend library name and compute-capability
// major. The check is intentionally narrow: only wave32 GPUs are admitted.
// The TQ codebook lookup issues __shfl_sync(mask, val, lane, 32), which the
// HIP vendor shim at ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h rewrites
// as __shfl(val, lane, 32). Width is preserved, but on a wave64 warp the
// width=32 partitions the 64 physical lanes into two independent 32-lane
// sub-groups — and the CUDA-tuned kernel only seeds codebook data in the
// first 32, so lanes 32-63 shuffle against uninitialized values and return
// garbage. Rejecting all wave64 AMD (Vega/GCN/CDNA) cleanly sidesteps that.
//
// When rejected, the returned skipReason is non-empty and names the
// limitation in operator-visible terms so the accompanying slog.Warn can
// be read by somebody who doesn't have the ggml source tree open.
func tqDeviceAccepted(library string, ccMajor int) (accepted bool, skipReason string) {
switch library {
case "CUDA":
if ccMajor >= tqMinNvidiaComputeMajor {
return true, ""
}
return false, fmt.Sprintf(
"TurboQuant requires NVIDIA Pascal (cc 6.0+) or AMD RDNA1+ (gfx1010+, wave32); got CUDA cc major=%d",
ccMajor,
)
case "ROCm":
if ccMajor >= tqMinAmdComputeMajor {
return true, ""
}
return false, fmt.Sprintf(
"TurboQuant on ROCm requires RDNA1+ (gfx1010+); wave64 AMD GPUs (Vega/GCN/CDNA) are not supported "+
"because TQ's 32-lane __shfl_sync sub-partitions the 64-lane warp and returns garbage from "+
"lanes 32-63 (gfx major=%d)",
ccMajor,
)
case "Metal":
// Apple Silicon always has 32-wide SIMD groups — same as CUDA warp width.
// TQ's __shfl_sync(mask, val, lane, 32) maps to simd_shuffle(val, lane) 1:1.
return true, ""
default:
return false, fmt.Sprintf(
"TurboQuant requires a CUDA, ROCm, or Metal backend library; got library=%q",
library,
)
}
}

View file

@ -0,0 +1,88 @@
package ggml
import (
"strings"
"testing"
)
// TestTQDeviceAccepted pins the library-aware compute-capability gate that
// decides whether a given GPU can run TurboQuant kernels. The TQ codebook
// lookup uses __shfl_sync with a 32-lane mask; the HIP vendor shim at
// ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h silently lowers that to
// __shfl with width=warpSize, which reads unused lanes (and thus returns
// garbage) on wave64 GPUs. The gate's job is to keep wave64 AMD devices
// (Vega/GCN/CDNA) out of the TQ path while still accepting all wave32
// RDNA (gfx1010+) and unchanged NVIDIA Pascal+.
//
// This test runs in pure Go (no cgo, no GPU) so it executes in CI on every
// platform. It's the primary regression gate for the classifier rule; the
// scanTQDevices cgo plumbing and the slog.Warn copy rewrites are covered by
// code review and runtime verification on real hardware.
func TestTQDeviceAccepted(t *testing.T) {
cases := []struct {
name string
library string
ccMajor int
wantAccept bool
wantReasonHas string // substring that must appear in skipReason when rejected
}{
// NVIDIA path: gate unchanged from the original TurboQuant PR (Pascal+).
{"nvidia_pascal_p40", "CUDA", 6, true, ""},
{"nvidia_turing", "CUDA", 7, true, ""},
{"nvidia_ampere", "CUDA", 8, true, ""},
{"nvidia_hopper", "CUDA", 9, true, ""},
{"nvidia_maxwell", "CUDA", 5, false, "CUDA"},
{"nvidia_kepler", "CUDA", 3, false, "CUDA"},
{"nvidia_bogus_zero", "CUDA", 0, false, "CUDA"},
// AMD wave64 — must be rejected. props.compute_major = (cc - OFFSET_AMD) / 0x100,
// so Vega/GCN/CDNA all land at major=9.
{"amd_vega_gfx900", "ROCm", 9, false, "ROCm"},
{"amd_vega20_gfx906", "ROCm", 9, false, "ROCm"},
{"amd_cdna1_mi100", "ROCm", 9, false, "wave64"},
{"amd_cdna2_mi210", "ROCm", 9, false, "Vega"},
{"amd_cdna3_mi300", "ROCm", 9, false, "CDNA"},
{"amd_gcn4_polaris", "ROCm", 8, false, "ROCm"},
// AMD wave32 RDNA — must be accepted. RDNA1 gfx1010 is the minimum.
{"amd_rdna1_gfx1010", "ROCm", 16, true, ""},
{"amd_rdna2_gfx1030", "ROCm", 16, true, ""},
{"amd_rdna3_gfx1100", "ROCm", 17, true, ""},
{"amd_rdna3_5_gfx1150", "ROCm", 17, true, ""},
{"amd_rdna4_gfx1200", "ROCm", 18, true, ""},
// Metal is accepted — Apple Silicon SIMD groups are 32-wide, matching
// the TQ kernels' __shfl_sync(mask, val, lane, 32) width.
{"metal", "Metal", 7, true, ""},
// Non-CUDA/ROCm/Metal backends — reject with an informative reason.
{"vulkan", "Vulkan", 7, false, "Vulkan"},
{"sycl", "SYCL", 7, false, "SYCL"},
{"empty_library", "", 7, false, "library"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
gotAccept, gotReason := tqDeviceAccepted(tc.library, tc.ccMajor)
if gotAccept != tc.wantAccept {
t.Fatalf("tqDeviceAccepted(%q, %d) accept = %v, want %v (reason=%q)",
tc.library, tc.ccMajor, gotAccept, tc.wantAccept, gotReason)
}
if tc.wantAccept {
if gotReason != "" {
t.Errorf("tqDeviceAccepted(%q, %d) accepted but reason=%q; expected empty",
tc.library, tc.ccMajor, gotReason)
}
return
}
if gotReason == "" {
t.Fatalf("tqDeviceAccepted(%q, %d) rejected with empty reason; operators need a diagnosable message",
tc.library, tc.ccMajor)
}
if tc.wantReasonHas != "" && !strings.Contains(gotReason, tc.wantReasonHas) {
t.Errorf("tqDeviceAccepted(%q, %d) reason = %q, want substring %q",
tc.library, tc.ccMajor, gotReason, tc.wantReasonHas)
}
})
}
}

View file

@ -0,0 +1,184 @@
package ggml
import (
"math"
"testing"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/turboquant"
)
// TestOutlierEncodeDequantGPUCPUEquivalence runs tq_encode_kernel_outlier
// + tq_dequant_multihead_kernel_outlier on the GPU for a synthetic K
// batch and compares the decoded output against the CPU reference
// (EncodeKeyPerHeadOutlier + DequantKeyPerHeadOutlier in
// turboquant/encode.go). Catches kernel algorithmic drift and tests the
// exact multi-KV-head layout that broke llama3.2:3b / qwen2.5:7b under
// the __shfl_sync divergence bug.
//
// CURRENT LIMITATION: this test skips in CI / on plain `go test` runs
// because setup()'s synthetic GGUF has no tensors, so no GPU buffer
// types end up in b.schedBufts, so scanTQDevices() finds no TQ-capable
// GPU even on a machine with a P40. The test as written is correct and
// runnable under a test harness that loads a real model-backed backend
// (e.g. the tq_outlier_encode_test.go:setup function could be replaced
// with a helper that loads a tiny real .gguf with GPU-assigned layers).
// Until that harness lands, the CPU reference tests in
// turboquant/encode_test.go (TestOutlierSplitVsUniformHeavyTailed,
// TestOutlierPerHeadRoundTrip) cover algorithmic correctness and the
// full tqbench matrix covers real-model runtime verification.
//
// The test is kept in place because:
// 1. It's the correct scaffolding for a future GPU unit-test harness.
// 2. It documents the exact CPU↔GPU equivalence contract the kernels
// must satisfy.
// 3. Once schedBufts is populated (either by a better harness or by
// a future ggml change), the test becomes a regression gate
// automatically — no further wiring needed.
func TestOutlierEncodeDequantGPUCPUEquivalence(t *testing.T) {
cases := []struct {
name string
headDim int
numKVHeads int
bits int
outlierBits int
outlierCount int
preset turboquant.Preset
}{
{"d128_h8_tq3k", 128, 8, 3, 4, 32, turboquant.PresetTQ3K},
{"d128_h4_tq3k", 128, 4, 3, 4, 32, turboquant.PresetTQ3K},
{"d128_h1_tq3k", 128, 1, 3, 4, 32, turboquant.PresetTQ3K},
{"d256_h1_tq3k", 256, 1, 3, 4, 32, turboquant.PresetTQ3K},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
ctx := setup(t)
b := ctx.(*Context).b
mgrAny := b.NewTQCompressedKManager(
tc.headDim, tc.numKVHeads, tc.bits,
tc.preset.RotationSeed,
0, // vBits (K-only)
tc.outlierBits, tc.outlierCount,
)
if mgrAny == nil {
t.Skip("no TQ-capable GPU available (need NVIDIA Pascal cc 6.0+ or AMD RDNA1+)")
}
mgr, ok := mgrAny.(*ggmlTQCompressedK)
if !ok {
t.Fatalf("unexpected TQ manager type %T", mgrAny)
}
const nCells = 4
capacity := nCells
mgr.EnsureLayer(0, capacity)
// Build a deterministic synthetic K batch: Gaussian noise
// via the same splitmix64 as the CPU tests.
kData := make([]float32, tc.headDim*tc.numKVHeads*nCells)
var rngState uint64 = 0xface_feed_cafe_0000 | uint64(tc.headDim)
rng := &rngState
for i := range kData {
kData[i] = float32(testGaussian(rng))
}
// GPU path: create a K tensor, run EncodeK + DequantK, read
// back the dequanted output.
kTensor := ctx.FromFloats(kData, tc.headDim, tc.numKVHeads, nCells)
encodeResult := mgr.EncodeK(ctx, 0, kTensor, 0)
if encodeResult == nil {
t.Fatalf("EncodeK returned nil")
}
ctx.Forward(encodeResult)
dequant := mgr.DequantK(ctx, 0, encodeResult, 0, nCells)
if dequant == nil {
t.Fatalf("DequantK returned nil")
}
ctx.Forward(dequant).Compute(dequant)
gpuOut := dequant.Floats()
if len(gpuOut) != tc.headDim*tc.numKVHeads*nCells {
t.Fatalf("gpu output len = %d, want %d",
len(gpuOut), tc.headDim*tc.numKVHeads*nCells)
}
// CPU reference: encode + dequant each (cell, head) slab
// with the same preset, compute the expected rotated-space
// output, and compare elementwise.
//
// Note: the CPU reference uses float64 blockScale while the
// GPU kernel uses float32 reductions, so scales differ by
// round-off at the 1e-6 level. The tolerance accommodates
// that plus quantizer boundary ambiguity (when a rotated
// value sits exactly on a Lloyd-Max boundary, float32 vs
// float64 accumulation can pick adjacent codebook slots).
const tol float32 = 5e-2 // loose; tight enough to catch algo bugs
var maxErr float32
var mismatches int
for c := range nCells {
for h := range tc.numKVHeads {
slab := make([]float32, tc.headDim)
for d := range tc.headDim {
slab[d] = kData[(c*tc.numKVHeads+h)*tc.headDim+d]
}
cpuEnc, err := turboquant.EncodeKeyPerHeadOutlier(slab, tc.preset)
if err != nil {
t.Fatalf("cpu encode: %v", err)
}
cpuOut := turboquant.DequantKeyPerHeadOutlier(cpuEnc, tc.preset, tc.headDim)
for d := range tc.headDim {
gpuVal := gpuOut[(c*tc.numKVHeads+h)*tc.headDim+d]
cpuVal := cpuOut[d]
diff := gpuVal - cpuVal
if diff < 0 {
diff = -diff
}
if diff > maxErr {
maxErr = diff
}
if diff > tol {
mismatches++
if mismatches <= 5 {
t.Logf("mismatch c=%d h=%d d=%d gpu=%f cpu=%f diff=%f",
c, h, d, gpuVal, cpuVal, diff)
}
}
}
}
}
t.Logf("%s: max elementwise err = %f (%d mismatches > %.3f)", tc.name, maxErr, mismatches, tol)
if mismatches > 0 {
t.Fatalf("%s: %d elements differ beyond tolerance %.3f", tc.name, mismatches, tol)
}
})
}
}
// testGaussian is a local Box-Muller generator that doesn't depend on
// the unexported helpers in the turboquant package. Uses a tiny
// splitmix64 inlined to avoid adding a test dep.
func testGaussian(state *uint64) float64 {
u1 := testUniform(state)
u2 := testUniform(state)
// Guard against log(0).
if u1 < 1e-12 {
u1 = 1e-12
}
return math.Sqrt(-2*math.Log(u1)) * math.Cos(2*math.Pi*u2)
}
func testUniform(state *uint64) float64 {
*state += 0x9e3779b97f4a7c15
z := *state
z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9
z = (z ^ (z >> 27)) * 0x94d049bb133111eb
z = z ^ (z >> 31)
return float64(z>>11) / float64(1<<53)
}
// Compile-time check that ml.Tensor is what we expect.
var _ ml.Tensor = (*Tensor)(nil)

View file

@ -0,0 +1,574 @@
package ggml
import (
"log/slog"
"sync"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/turboquant"
)
// ggmlTQCompressedK implements ml.TQCompressedKManager using ggml tensors and
// GGML_OP_TQ_ENCODE / GGML_OP_TQ_DEQUANT graph ops. All buffers live in GPU
// memory; no CPU round-trips occur during the forward pass.
type ggmlTQCompressedK struct {
backend *Backend
headDim int
numKVHeads int
bits int
// Outlier-split config (post-rotation top-K channel split). When
// outlierCount > 0, EnsureLayer allocates additional tensors for an
// outlier sub-block encoded at outlierBits, and the encode/dequant
// kernels follow the outlier-aware path. When 0, uses pure uniform
// per-channel Lloyd-Max at `bits`.
outlierBits int
outlierCount int
mu sync.Mutex
// Per-layer ggml tensors, allocated lazily via EnsureLayer.
layerCtxs map[int]ml.Context
packedTensors map[int]*Tensor // regular sub-block: [regularPackedBytes*numKVHeads, capacity] i8
scalesTensors map[int]*Tensor // regular scales: [numKVHeads, capacity] f32
// Outlier sub-block per-layer tensors (populated only when outlierCount > 0).
outlierPackedTensors map[int]*Tensor // [outlierPackedBytes*numKVHeads, capacity] i8
outlierScalesTensors map[int]*Tensor // [numKVHeads, capacity] f32
outlierIndicesTensors map[int]*Tensor // [outlierCount*numKVHeads, capacity] i8 (channel idx)
// Rotation matrix R^T, shared across layers: [headDim, headDim] f32.
rotCtx ml.Context
rotTensor *Tensor // stores R^T row-major (used for K encode and Q rotate)
// Rotation matrix R (transpose of R^T), for undoing V rotation in SDPA.
// mul_mat(rotInverseTensor, x) = R @ x (recovers original from R^T @ x).
rotInverseTensor *Tensor // stores R row-major
// Codebook and boundaries tensors, shared across layers.
sharedCtx ml.Context
codebookTensor *Tensor // regular: [1<<bits] f32
boundariesTensor *Tensor // regular: [(1<<bits)-1] f32
// Outlier codebook and boundaries (populated only when outlierCount > 0).
outlierCodebookTensor *Tensor // [1<<outlierBits] f32
outlierBoundariesTensor *Tensor // [(1<<outlierBits)-1] f32
vBits int
// Per-layer V tensors, allocated lazily via EnsureVLayer.
vLayerCtxs map[int]ml.Context
vPackedTensors map[int]*Tensor // [packedBytes*numKVHeads, capacity] i8
vScalesTensors map[int]*Tensor // [numKVHeads, capacity] f32
// V codebook and boundaries (same bit width as K for tq2/tq3).
vCodebookTensor *Tensor // [1<<vBits] f32
vBoundariesTensor *Tensor // [(1<<vBits)-1] f32
// preferFusedAttention is true on Metal. The DequantKV → stock FA path
// writes a full f16 intermediate buffer before attention, doubling KV
// bandwidth vs reading packed data directly. On Metal at long context the
// fused kernel (kernel_tq_fattn_vec_packed) is dramatically faster because
// it reads packed K+V once and never materialises the f16 intermediate.
// On CUDA, DequantKV + stock FA is faster because cuDNN/cuBLAS flash
// attention is highly tuned and the intermediate buffer stays in L2.
preferFusedAttention bool
}
// hasOutliers reports whether outlier-split is active for this manager.
func (m *ggmlTQCompressedK) hasOutliers() bool {
return m.outlierCount > 0 && m.outlierBits > 0 && m.outlierCount < m.headDim
}
// PreferFusedAttention reports whether the fused flash-attention path
// (packed K+V decoded inline) should be tried before DequantKV + stock FA.
// True on Metal: the DequantKV path writes a full f16 intermediate buffer that
// doubles KV bandwidth at long context. False on CUDA/ROCm where DequantKV +
// stock FA is faster due to large L2 caches and highly-tuned flash attention.
func (m *ggmlTQCompressedK) PreferFusedAttention() bool {
return m.preferFusedAttention
}
// regularChannelCount is the number of non-outlier channels per head.
func (m *ggmlTQCompressedK) regularChannelCount() int {
if m.hasOutliers() {
return m.headDim - m.outlierCount
}
return m.headDim
}
// regularPackedBytes is the padded per-head byte count for the regular
// sub-block. The encode kernel uses atomicOr on 4-byte words to pack bits;
// for that to stay aligned, each head's region must start on a 4-byte
// boundary. Round the raw bit-count up to the next multiple of 4 so the
// per-head stride is naturally aligned. The padding bytes are never read
// during decode, and are zeroed by the encode kernel's init loop.
func (m *ggmlTQCompressedK) regularPackedBytes() int {
raw := (m.regularChannelCount()*m.bits + 7) / 8
return (raw + 3) &^ 3
}
// outlierPackedBytes is the padded per-head byte count for the outlier
// sub-block. Same 4-byte alignment as regularPackedBytes() for the same
// reason: atomicOr-on-word in the encode kernel.
func (m *ggmlTQCompressedK) outlierPackedBytes() int {
if !m.hasOutliers() {
return 0
}
raw := (m.outlierCount*m.outlierBits + 7) / 8
return (raw + 3) &^ 3
}
func (b *Backend) NewTQCompressedKManager(headDim, numKVHeads, bits int, rotationSeed uint64, vBits, outlierBits, outlierCount int) ml.TQCompressedKManager {
// TurboQuant ops run on CUDA (NVIDIA Pascal+), ROCm/HIP (AMD RDNA1+,
// gfx1010+), or Metal (Apple Silicon). The gate is wave32: the kernels
// hard-code a 32-lane shuffle for codebook lookup. On wave64 AMD (Vega/
// GCN/CDNA) the HIP shim's __shfl(…, 32) sub-partitions the 64-lane warp
// and the upper 32 lanes return garbage — those are rejected. Metal SIMD
// groups are always 32-wide on Apple Silicon, so Metal is unconditionally
// admitted. Scan the scheduler buffer types, pick the first TQ-capable
// GPU, and warn clearly if there's no suitable device.
scan := b.scanTQDevices()
if !scan.selectedOK {
if len(scan.Skipped) > 0 {
slog.Warn("turboquant: no TQ-capable GPU found; falling back to f16 KV cache. "+
"TurboQuant requires NVIDIA Pascal (cc 6.0+), AMD RDNA1+ (gfx1010+, wave32), or Apple Silicon (Metal).",
"skipped_gpus", scan.Skipped)
} else {
slog.Warn("turboquant: no GPU backend available, falling back to f16 KV cache")
}
return nil
}
if len(scan.Skipped) > 0 {
slog.Warn("turboquant: skipping unsupported GPU(s); TQ tensors will be placed on the "+
"first wave32 device (NVIDIA Pascal+, AMD RDNA1+, or Apple Silicon). To silence "+
"this warning, hide the unsupported cards with CUDA_VISIBLE_DEVICES / HIP_VISIBLE_DEVICES.",
"selected", scan.SelectedName+" (cc "+scan.SelectedCC+")",
"skipped", scan.Skipped)
}
if len(scan.Accepted) > 1 {
slog.Warn("turboquant: multi-GPU detected; TQ compressed buffers live on the "+
"primary GPU only. Layers scheduled to other GPUs will incur per-step "+
"cross-GPU transfers. On SWA models like gemma3/gemma4 this is per "+
"TQ-wrapped global sub-cache — the SWA sub-cache stays on its native "+
"GPU and is unaffected. To avoid: set num_gpu so the model fits on one "+
"GPU, or use the tq2k/tq3k (K-only) presets which let V stay on its "+
"native GPU.",
"selected", scan.SelectedName+" (cc "+scan.SelectedCC+")",
"tq_capable_gpus", scan.Accepted)
}
// Codebook and boundaries (same for all layers). Use headDim for the
// codebook dim parameter regardless of whether outlier split is active:
// the CPU path does the same (scalarCodebook(dim=headDim, bits)), and
// after per-sub-block RMS normalization the input distribution is
// approximately unit-variance Gaussian either way. Using sub-block dims
// here produced slightly different boundaries that caused observable
// quality regressions on multi-head configurations.
codebook := turboquant.ExportCodebook(headDim, bits)
boundaries := turboquant.ExportBoundaries(headDim, bits)
// All shared tensors (codebook, rotation) must be GPU-resident: TQ ops are
// CUDA-only. Using newTQContext() ensures GPU buffer type is used regardless
// of which model layers are on CPU vs GPU.
// Size shared context for up to 6 tensors (codebook+bounds ×2 regular, ×2
// outlier, ×2 V); newTQContext takes a hint count.
sharedCtx := b.newTQContext(8)
codebookT := sharedCtx.FromFloats(codebook, len(codebook)).(*Tensor)
boundariesT := sharedCtx.FromFloats(boundaries, len(boundaries)).(*Tensor)
// Outlier codebook/boundaries at a different bit width. Use headDim as
// the codebook dim (same reason as regular codebook above).
var outlierCodebookT, outlierBoundariesT *Tensor
if outlierCount > 0 && outlierBits > 0 && outlierCount < headDim {
oCodebook := turboquant.ExportCodebook(headDim, outlierBits)
oBoundaries := turboquant.ExportBoundaries(headDim, outlierBits)
outlierCodebookT = sharedCtx.FromFloats(oCodebook, len(oCodebook)).(*Tensor)
outlierBoundariesT = sharedCtx.FromFloats(oBoundaries, len(oBoundaries)).(*Tensor)
}
// V codebook and boundaries
vCodebook := turboquant.ExportCodebook(headDim, vBits)
vBoundaries := turboquant.ExportBoundaries(headDim, vBits)
vCodebookT := sharedCtx.FromFloats(vCodebook, len(vCodebook)).(*Tensor)
vBoundariesT := sharedCtx.FromFloats(vBoundaries, len(vBoundaries)).(*Tensor)
// Rotation matrix R^T: rotData[i*headDim+j] = R[j][i]. Built via
// Householder QR on a random Gaussian matrix per TurboQuant paper (arXiv
// 2504.19874) Algorithm 1.
rot := turboquant.BuildRotation(headDim, rotationSeed)
rotData := make([]float32, headDim*headDim)
for i := range headDim {
for j := range headDim {
rotData[i*headDim+j] = rot.Matrix[j*headDim+i]
}
}
rotInverseData := make([]float32, headDim*headDim)
copy(rotInverseData, rot.Matrix)
rotCtx := b.newTQContext(2)
rotTensor := rotCtx.FromFloats(rotData, headDim, headDim).(*Tensor)
rotInverseTensor := rotCtx.FromFloats(rotInverseData, headDim, headDim).(*Tensor)
m := &ggmlTQCompressedK{
backend: b,
headDim: headDim,
numKVHeads: numKVHeads,
bits: bits,
outlierBits: outlierBits,
outlierCount: outlierCount,
layerCtxs: make(map[int]ml.Context),
packedTensors: make(map[int]*Tensor),
scalesTensors: make(map[int]*Tensor),
outlierPackedTensors: make(map[int]*Tensor),
outlierScalesTensors: make(map[int]*Tensor),
outlierIndicesTensors: make(map[int]*Tensor),
rotCtx: rotCtx,
rotTensor: rotTensor,
rotInverseTensor: rotInverseTensor,
sharedCtx: sharedCtx,
codebookTensor: codebookT,
boundariesTensor: boundariesT,
outlierCodebookTensor: outlierCodebookT,
outlierBoundariesTensor: outlierBoundariesT,
vBits: vBits,
vLayerCtxs: make(map[int]ml.Context),
vPackedTensors: make(map[int]*Tensor),
vScalesTensors: make(map[int]*Tensor),
vCodebookTensor: vCodebookT,
vBoundariesTensor: vBoundariesT,
preferFusedAttention: scan.SelectedLibrary == "Metal",
}
if m.hasOutliers() {
slog.Info("turboquant: outlier split enabled",
"outlier_bits", outlierBits, "outlier_count", outlierCount,
"regular_bits", bits, "regular_channels", m.regularChannelCount(),
"effective_bits", float32(outlierCount*outlierBits+m.regularChannelCount()*bits)/float32(headDim))
}
return m
}
// EnsureLayer allocates per-layer packed and scales tensors on first use.
// When outlier split is active, also allocates the outlier sub-block tensors.
func (m *ggmlTQCompressedK) EnsureLayer(layer, capacity int) {
m.mu.Lock()
defer m.mu.Unlock()
if _, ok := m.packedTensors[layer]; ok {
return
}
packedBytes := m.regularPackedBytes()
// Size the per-layer context hint by how many tensors we're allocating.
// 2 (regular) + 3 (outlier) = 5 tensors when outlier split is on.
ctxHint := 2
if m.hasOutliers() {
ctxHint = 5
}
// TQ tensors must always be GPU-resident; use newTQContext, not Layer(layer),
// which would allocate CPU memory for layers assigned to CPU.
ctx := m.backend.newTQContext(ctxHint)
// Opt this layer's TQ persistent buffers into the scheduler's per-layer
// Cache accounting so the (packed K, scales, optional outlier sub-block)
// tensors below flow into btDeviceMemory.Cache[layer] via newTensor, not
// the anonymous Graph bucket. Without this, scanTQDevices' scheduler can't
// see TQ's real KV footprint and may mis-plan context-fit decisions.
ctx.layer = layer
// packed: interleaved as (cell*numKVHeads+head)*packedBytes — matches encode kernel layout.
packed := ctx.Zeros(ml.DTypeI8, packedBytes*m.numKVHeads, capacity).(*Tensor)
// scales: scales[cell*numKVHeads+head] — cell-major.
scales := ctx.Zeros(ml.DTypeF32, m.numKVHeads, capacity).(*Tensor)
m.layerCtxs[layer] = ctx
m.packedTensors[layer] = packed
m.scalesTensors[layer] = scales
if m.hasOutliers() {
oPackedBytes := m.outlierPackedBytes()
m.outlierPackedTensors[layer] = ctx.Zeros(ml.DTypeI8, oPackedBytes*m.numKVHeads, capacity).(*Tensor)
m.outlierScalesTensors[layer] = ctx.Zeros(ml.DTypeF32, m.numKVHeads, capacity).(*Tensor)
m.outlierIndicesTensors[layer] = ctx.Zeros(ml.DTypeI8, m.outlierCount*m.numKVHeads, capacity).(*Tensor)
}
}
// EncodeK creates a GGML_OP_TQ_ENCODE graph node.
// EnsureLayer must have been called for this layer before EncodeK;
// the forward pass is single-threaded per cache so the map reads below
// race only with concurrent EnsureLayer calls, which the contract forbids.
// firstCell is the index of the first cache slot being written
// (cells are sequential: firstCell+0, firstCell+1, ...).
// Returns a view of the packed buffer (use as encodeResult in DequantK).
func (m *ggmlTQCompressedK) EncodeK(ctx ml.Context, layer int, key ml.Tensor, firstCell int) ml.Tensor {
packed := m.packedTensors[layer]
if packed == nil {
return nil
}
scales := m.scalesTensors[layer]
if m.hasOutliers() {
if oPacked := m.outlierPackedTensors[layer]; oPacked != nil {
return packed.TQEncodeOutlier(ctx, scales, key, m.rotTensor, firstCell, m.boundariesTensor, m.bits,
oPacked, m.outlierScalesTensors[layer], m.outlierIndicesTensors[layer], m.outlierBoundariesTensor,
m.outlierBits, m.outlierCount)
}
}
return packed.TQEncode(ctx, scales, key, m.rotTensor, firstCell, m.boundariesTensor, m.bits)
}
// DequantK creates a GGML_OP_TQ_DEQUANT graph node. encodeResult is the
// view returned by EncodeK (establishes encode→dequant ordering). See
// EncodeK for the single-threaded access contract.
func (m *ggmlTQCompressedK) DequantK(ctx ml.Context, layer int, encodeResult ml.Tensor, firstCell, nCells int) ml.Tensor {
scales := m.scalesTensors[layer]
if scales == nil || encodeResult == nil || nCells <= 0 {
return nil
}
if m.hasOutliers() {
if oPacked := m.outlierPackedTensors[layer]; oPacked != nil {
return encodeResult.(*Tensor).TQDequantOutlier(ctx, scales, m.codebookTensor,
m.headDim, m.numKVHeads, nCells, firstCell, m.bits,
oPacked, m.outlierScalesTensors[layer], m.outlierIndicesTensors[layer], m.outlierCodebookTensor,
m.outlierBits, m.outlierCount)
}
}
return encodeResult.(*Tensor).TQDequant(ctx, scales, m.codebookTensor,
m.headDim, m.numKVHeads, nCells, firstCell, m.bits)
}
// fusedKernelSupports reports whether the fused TQ flash-attention kernel
// should be used.
//
// The fused path is the default for supported configurations (headDim=128,
// bits=2 or 3). It decodes packed K and V bits inline during flash attention
// as a fallback for configurations where DequantKV is unsupported. The
// inline-decode path is slower than DequantKV + stock FA on all measured
// hardware — DequantKV is always preferred when available.
func (m *ggmlTQCompressedK) fusedKernelSupports() bool {
// D=128 on all backends; D=256 only on Metal (kernel_tq_fattn_vec_*{,_d256}).
// CUDA still has only the D=128 kernel, so gemma3 (D=256) stays off the
// fused path on CUDA.
switch m.headDim {
case 128:
case 256:
if !m.preferFusedAttention {
return false
}
default:
return false
}
if m.bits != 2 && m.bits != 3 {
return false
}
// Outlier split changes the packed layout: the fused inline-decode FA
// kernel reads the packed buffer directly and doesn't know about outlier
// sub-blocks. Route to path 5 (separate dequant + stock FA) when outliers
// are active. Extending the fused kernel to handle outliers is deliberately
// NOT done — the fused inline-decode path is already documented as 17.6x
// slower than separate dequant + stock FA (feedback_cuda_kernel_optimization.md),
// so adding more ALU work (outlier scan / popcount / dual codebook shuffles)
// to that inner loop moves it further from the correct architecture.
if m.hasOutliers() {
return false
}
return true
}
// GetAsTQTensor wraps the packed K buffer for the given layer as a tqTensor
// so that ScaledDotProductAttention can dispatch to the fused kernel.
// Returns (nil, false) when the fused path is not supported.
func (m *ggmlTQCompressedK) GetAsTQTensor(ctx ml.Context, layer int, encodeResult ml.Tensor, firstCell, nCells int) (ml.Tensor, bool) {
if !m.fusedKernelSupports() {
return nil, false
}
scales := m.scalesTensors[layer]
if scales == nil || encodeResult == nil || nCells <= 0 {
return nil, false
}
return &tqTensor{
Tensor: encodeResult.(*Tensor),
scales: scales,
codebook: m.codebookTensor,
bits: m.bits,
headDim: m.headDim,
nKVHeads: m.numKVHeads,
nCells: nCells,
firstCell: firstCell,
}, true
}
// GetAsTQTensorKV wraps both packed K and packed V buffers as a tqTensor for
// the fully fused K+V TQ flash-attention path. Returns (nil, false) when
// fused is not supported or V compression is not yet active for this layer.
func (m *ggmlTQCompressedK) GetAsTQTensorKV(ctx ml.Context, layer int, kEncodeResult, vEncodeResult ml.Tensor, firstCell, nCells int) (ml.Tensor, bool) {
if !m.fusedKernelSupports() {
return nil, false
}
kScales := m.scalesTensors[layer]
vScales := m.vScalesTensors[layer]
if kScales == nil || kEncodeResult == nil || nCells <= 0 {
return nil, false
}
if vScales == nil || vEncodeResult == nil {
return nil, false
}
return &tqTensor{
Tensor: kEncodeResult.(*Tensor),
scales: kScales,
codebook: m.codebookTensor,
bits: m.bits,
headDim: m.headDim,
nKVHeads: m.numKVHeads,
nCells: nCells,
firstCell: firstCell,
vPacked: vEncodeResult.(*Tensor),
vScales: vScales,
vCodebook: m.vCodebookTensor,
vBits: m.vBits,
}, true
}
func (m *ggmlTQCompressedK) RotationMatrix(_ ml.Context, _ int) ml.Tensor {
return m.rotTensor
}
// RotationMatrixR returns R (not R^T) for use as the V rotation undo matrix.
// mul_mat(R, R^T @ v) = v (recovers original from rotated V).
func (m *ggmlTQCompressedK) RotationMatrixR() ml.Tensor {
return m.rotInverseTensor
}
// EnsureVLayer allocates per-layer V packed and scales tensors on first use.
func (m *ggmlTQCompressedK) EnsureVLayer(layer, capacity int) {
m.mu.Lock()
defer m.mu.Unlock()
if _, ok := m.vPackedTensors[layer]; ok {
return
}
// 4-byte alignment — matches regularPackedBytes() so the scheduler and the
// Go-side allocator agree on padded bytes per head. The encode/dequant
// kernels read the raw bits; padding is never touched.
raw := (m.headDim*m.vBits + 7) / 8
packedBytes := (raw + 3) &^ 3
ctx := m.backend.newTQContext(2)
// Same per-layer Cache accounting as EnsureLayer — see that comment for why.
ctx.layer = layer
packed := ctx.Zeros(ml.DTypeI8, packedBytes*m.numKVHeads, capacity).(*Tensor)
scales := ctx.Zeros(ml.DTypeF32, m.numKVHeads, capacity).(*Tensor)
m.vLayerCtxs[layer] = ctx
m.vPackedTensors[layer] = packed
m.vScalesTensors[layer] = scales
}
// EncodeV creates a GGML_OP_TQ_ENCODE_V graph node.
// EnsureVLayer must have been called for this layer before EncodeV.
func (m *ggmlTQCompressedK) EncodeV(ctx ml.Context, layer int, value ml.Tensor, firstCell int) ml.Tensor {
packed := m.vPackedTensors[layer]
if packed == nil {
return nil
}
// Pass the K rotation matrix so V outlier energy spreads evenly before
// quantization. SDPA's post-attention R @ output step undoes the rotation.
return packed.TQEncodeV(ctx, m.vScalesTensors[layer], value, m.rotTensor, firstCell, m.vBoundariesTensor, m.vBits)
}
// EncodeKV creates a single GGML_OP_TQ_ENCODE_KV graph node encoding both
// K and V, halving scheduler overhead vs separate EncodeK + EncodeV.
// Returns (kEncodeResult, vEncodeResult) — both reference the same op for
// graph dependency tracking.
//
// When outlier-split is active, the combined encode kernel is not used
// because it only understands the uniform packed layout. Falls back to
// separate EncodeK (outlier-aware) + EncodeV (uniform) calls.
func (m *ggmlTQCompressedK) EncodeKV(ctx ml.Context, layer int, key, value ml.Tensor, firstCell int) (ml.Tensor, ml.Tensor) {
if m.hasOutliers() {
return m.EncodeK(ctx, layer, key, firstCell), m.EncodeV(ctx, layer, value, firstCell)
}
kPacked := m.packedTensors[layer]
vPacked := m.vPackedTensors[layer]
if kPacked == nil || vPacked == nil {
return nil, nil
}
kResult := kPacked.TQEncodeKV(ctx,
m.scalesTensors[layer], key, m.rotTensor, m.boundariesTensor,
vPacked, m.vScalesTensors[layer], value, m.vBoundariesTensor,
firstCell, m.bits, m.vBits)
// kResult is the EncodeKV op output (K packed view); the scheduler uses it
// to order DequantKV after EncodeKV. V packed buffer was written as a side
// effect by the combined kernel. Return the V packed tensor directly —
// DequantKV reads its data pointer (which now contains the encoded V).
// Graph ordering is still correct: DequantKV depends on kResult (src[0]),
// and both kernels run on the same CUDA stream.
return kResult, vPacked
}
// DequantV creates a GGML_OP_TQ_DEQUANT graph node for V.
// encodeResult is the view returned by EncodeV (establishes encode→dequant ordering).
func (m *ggmlTQCompressedK) DequantV(ctx ml.Context, layer int, encodeResult ml.Tensor, firstCell, nCells int) ml.Tensor {
scales := m.vScalesTensors[layer]
if scales == nil || encodeResult == nil || nCells <= 0 {
return nil
}
return encodeResult.(*Tensor).TQDequant(ctx, scales, m.vCodebookTensor,
m.headDim, m.numKVHeads, nCells, firstCell, m.vBits)
}
// DequantKV creates a single GGML_OP_TQ_DEQUANT_KV graph node that dequants
// both K and V in one op, halving scheduler overhead vs separate DequantK+DequantV.
// Returns (kTensor, vTensor) as views into the combined output.
//
// When outlier-split is active, the combined kernel cannot be used because
// its K reader assumes the uniform packed layout. Returns (nil, nil) to
// force Get() to fall through to the separate DequantK + DequantV path.
func (m *ggmlTQCompressedK) DequantKV(ctx ml.Context, layer int, kEncodeResult, vEncodeResult ml.Tensor, firstCell, nCells int) (ml.Tensor, ml.Tensor) {
if m.hasOutliers() {
return nil, nil
}
kScales := m.scalesTensors[layer]
vScales := m.vScalesTensors[layer]
if kScales == nil || kEncodeResult == nil || nCells <= 0 {
return nil, nil
}
if vScales == nil || vEncodeResult == nil {
return nil, nil
}
combined := TQDequantKV(ctx, m.backend,
kEncodeResult.(*Tensor), kScales, m.codebookTensor,
vEncodeResult.(*Tensor), vScales, m.vCodebookTensor,
m.rotInverseTensor, // R matrix for fused V rotation undo
m.headDim, m.numKVHeads, nCells, firstCell, m.bits, m.vBits)
// Split the [headDim, numKVHeads, nCells, 2] output into K and V views.
planeBytes := m.headDim * m.numKVHeads * nCells * 2 // f16 = 2 bytes
kView := combined.View(ctx, 0, m.headDim, combined.Stride(1), m.numKVHeads, combined.Stride(2), nCells)
vView := combined.View(ctx, planeBytes, m.headDim, combined.Stride(1), m.numKVHeads, combined.Stride(2), nCells)
return kView, vView
}
func (m *ggmlTQCompressedK) Close() {
m.mu.Lock()
defer m.mu.Unlock()
for _, ctx := range m.layerCtxs {
ctx.Close()
}
for _, ctx := range m.vLayerCtxs {
ctx.Close()
}
if m.rotCtx != nil {
m.rotCtx.Close()
}
if m.sharedCtx != nil {
m.sharedCtx.Close()
}
m.packedTensors = nil
m.scalesTensors = nil
m.layerCtxs = nil
m.vPackedTensors = nil
m.vScalesTensors = nil
m.vLayerCtxs = nil
m.rotCtx = nil
m.sharedCtx = nil
}

View file

@ -0,0 +1,96 @@
package ggml
// #include "ggml/include/ggml.h"
import "C"
import ml "github.com/ollama/ollama/ml"
// tqTensor wraps a packed-K buffer with the metadata needed for the fused TQ
// flash attention kernel. The Tensor field holds the encode result (a view of
// the persistent packed-K buffer).
//
// When vPacked is non-nil, the K+V fused kernel is used: V is decoded inline
// from vPacked, bypassing the separate TQ_DEQUANT op for V.
type tqTensor struct {
*Tensor // packed K view ([packedBytes*nKVHeads, capacity] i8; encode result)
scales *Tensor // K scales [nKVHeads, capacity] f32
codebook *Tensor // K codebook [1<<bits] f32
bits int
headDim int
nKVHeads int
nCells int
firstCell int
// V packed fields (nil = V is f16 from inner cache; non-nil = K+V fused)
vPacked *Tensor // packed V view [v_packedBytes*nKVHeads, capacity] i8
vScales *Tensor // V scales [nKVHeads, capacity] f32
vCodebook *Tensor // V codebook [1<<vBits] f32
vBits int
}
// Permute propagates the tqTensor wrapper through the key permutation that
// ScaledDotProductAttention applies before the flash-attention dispatch.
// The packed-K layout is custom (not standard ggml strides), so we preserve
// the wrapper metadata and let the CUDA kernel ignore the permuted strides.
func (t *tqTensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
return &tqTensor{
Tensor: t.Tensor.Permute(ctx, shape...).(*Tensor),
scales: t.scales,
codebook: t.codebook,
bits: t.bits,
headDim: t.headDim,
nKVHeads: t.nKVHeads,
nCells: t.nCells,
firstCell: t.firstCell,
vPacked: t.vPacked,
vScales: t.vScales,
vCodebook: t.vCodebook,
vBits: t.vBits,
}
}
// TQFlashAttention creates a GGML_OP_TQ_FLASH_ATTN_EXT graph node.
// query: permuted+rotated Q [D, nTokensQ, nHeadsQ, nSeq] f32
// tqk: TQ packed-K wrapper (may also carry V packed fields for K+V fused)
// value: permuted f16 V for K-only fused, OR packed i8 V for K+V fused
func (b *Backend) tqFlashAttention(
ctx ml.Context,
query *Tensor,
tqk *tqTensor,
value *Tensor,
mask ml.Tensor,
scale float64,
logitSoftcap float64,
) ml.Tensor {
var maskT *C.struct_ggml_tensor
if mask != nil {
maskT = mask.(*Tensor).t
}
// K+V fused: pass V packed tensors to the C API.
// K-only fused: pass NULL for v_scales (backward compat).
var vScalesT, vCodebookT *C.struct_ggml_tensor
vBits := C.int32_t(0)
if tqk.vPacked != nil {
vScalesT = tqk.vScales.t
vCodebookT = tqk.vCodebook.t
vBits = C.int32_t(tqk.vBits)
}
t := C.ggml_tq_flash_attn_ext(
ctx.(*Context).ctx,
query.t,
tqk.Tensor.t,
value.t,
maskT,
tqk.scales.t,
tqk.codebook.t,
C.float(scale),
C.float(logitSoftcap),
C.int32_t(tqk.bits),
C.int32_t(tqk.firstCell),
vScalesT,
vCodebookT,
vBits,
)
return &Tensor{b: b, t: t}
}

View file

@ -38,6 +38,7 @@ type Model interface {
Backend() ml.Backend
Config() config
SetCache(kvcache.Cache)
}
// Validator is an optional interface that models can implement to perform
@ -104,6 +105,12 @@ func (m *Base) Config() config {
return m.config
}
// SetCache replaces the model's cache. Used by TurboQuant to wrap the
// Causal cache with compression.
func (m *Base) SetCache(cache kvcache.Cache) {
m.config.Cache = cache
}
var models = make(map[string]func(fs.Config) (Model, error))
// Register registers a model constructor for the given architecture

View file

@ -249,8 +249,8 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
wc := cache.(*kvcache.WrapperCache)
wc.SetLayerType(cacheType)
if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
if cc, ok := wc.UnderlyingCache().(kvcache.CausalConfigurable); ok {
cc.SetCausal(ctx, kvcache.CausalOptions{Except: except})
}
}

View file

@ -198,8 +198,8 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
wc := cache.(*kvcache.WrapperCache)
wc.SetLayerType(cacheType)
if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
if cc, ok := wc.UnderlyingCache().(kvcache.CausalConfigurable); ok {
cc.SetCausal(ctx, kvcache.CausalOptions{Except: except})
}
}

View file

@ -46,7 +46,33 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
cache := model.Config().Cache
if cache != nil {
cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), numSlots, int(numCtx), batchSize)
dtype := kvCacheTypeFromStr(kvCacheType)
if preset, ok := kvcache.PresetFromDType(dtype); ok {
wrapped, active := kvcache.WrapWithTurboQuant(cache, preset)
if active {
cache = wrapped
// Force f16 at Init. TurboQuantCache manages its preset
// internally and always passes f16 to its inner *Causal; any
// sibling sub-caches in a WrapperCache (e.g. the SWA side of
// gemma3/gemma4) need f16 too since they can't allocate the
// ggml-unknown TQ dtypes.
dtype = ml.DTypeF16
// For the top-level *Causal case, WrapWithTurboQuant returns
// a new *TurboQuantCache and the model must be re-pointed at
// it. For the *WrapperCache case, the same pointer is returned
// (mutated in place) and no SetCache is needed.
if tqc, ok := wrapped.(*kvcache.TurboQuantCache); ok {
model.SetCache(tqc)
}
slog.Info("using turboquant kv cache", "preset", preset.Name)
} else {
// Non-wrappable cache (recurrent, SWA-only, etc.); fall back
// to f16 so Init doesn't receive an unmapped dtype and panic.
dtype = ml.DTypeF16
slog.Warn("turboquant requested but cache is not wrappable, falling back to f16")
}
}
cache.Init(model.Backend(), dtype, numSlots, int(numCtx), batchSize)
}
return &InputCache{
@ -64,6 +90,14 @@ func kvCacheTypeFromStr(s string) ml.DType {
return ml.DTypeQ80
case "q4_0":
return ml.DTypeQ40
case "tq2":
return ml.DTypeTQ2
case "tq3":
return ml.DTypeTQ3
case "tq3k":
return ml.DTypeTQ3K
case "tq2k":
return ml.DTypeTQ2K
default:
return ml.DTypeF16
}

178
turboquant/block.go Normal file
View file

@ -0,0 +1,178 @@
package turboquant
import (
"bytes"
"encoding/binary"
"fmt"
"io"
)
type Block struct {
Version uint8
PresetID uint8
Role uint8
Objective uint8
OriginalDim uint16
PaddedDim uint16
BlockDim uint16
RegularBits uint8
RotationSeed uint64
CodebookID uint16
QJLRows uint16
AuxLayoutID uint8
ChannelIndices []uint16
Scale float32
RegularIndices []byte
Residual ResidualSketch
}
func (b Block) MarshalBinary() ([]byte, error) {
var buf bytes.Buffer
fields := []any{
b.Version,
b.PresetID,
b.Role,
b.Objective,
b.OriginalDim,
b.PaddedDim,
b.BlockDim,
b.RegularBits,
b.RotationSeed,
b.CodebookID,
b.QJLRows,
b.AuxLayoutID,
uint16(len(b.ChannelIndices)),
b.Scale,
uint32(len(b.RegularIndices)),
b.Residual.Seed,
b.Residual.Scale,
b.Residual.SketchDim,
uint32(len(b.Residual.Signs)),
}
for _, field := range fields {
if err := binary.Write(&buf, binary.LittleEndian, field); err != nil {
return nil, err
}
}
for _, idx := range b.ChannelIndices {
if err := binary.Write(&buf, binary.LittleEndian, idx); err != nil {
return nil, err
}
}
buf.Write(b.RegularIndices)
buf.Write(b.Residual.Signs)
return buf.Bytes(), nil
}
func (b *Block) UnmarshalBinary(data []byte) error {
r := bytes.NewReader(data)
var channelCount uint16
var regularLen, residualLen uint32
fields := []any{
&b.Version,
&b.PresetID,
&b.Role,
&b.Objective,
&b.OriginalDim,
&b.PaddedDim,
&b.BlockDim,
&b.RegularBits,
&b.RotationSeed,
&b.CodebookID,
&b.QJLRows,
&b.AuxLayoutID,
&channelCount,
&b.Scale,
&regularLen,
&b.Residual.Seed,
&b.Residual.Scale,
&b.Residual.SketchDim,
&residualLen,
}
for _, field := range fields {
if err := binary.Read(r, binary.LittleEndian, field); err != nil {
return err
}
}
if b.Version != BlockVersion {
return fmt.Errorf("unsupported block version %d", b.Version)
}
if b.Objective != uint8(objectiveMSE) && b.Objective != uint8(objectiveProduct) {
return fmt.Errorf("unsupported block objective %d", b.Objective)
}
if b.OriginalDim == 0 || b.OriginalDim > b.PaddedDim || b.BlockDim != b.PaddedDim {
return fmt.Errorf("invalid block dims: original=%d padded=%d block=%d", b.OriginalDim, b.PaddedDim, b.BlockDim)
}
if channelCount > 0 {
b.ChannelIndices = make([]uint16, channelCount)
for i := range b.ChannelIndices {
if err := binary.Read(r, binary.LittleEndian, &b.ChannelIndices[i]); err != nil {
return err
}
}
}
b.RegularIndices = make([]byte, regularLen)
b.Residual.Signs = make([]byte, residualLen)
for _, dst := range [][]byte{b.RegularIndices, b.Residual.Signs} {
if _, err := io.ReadFull(r, dst); err != nil {
return err
}
}
if r.Len() != 0 {
return fmt.Errorf("unexpected trailing bytes in turboquant block: %d", r.Len())
}
return nil
}
func packBits(values []uint8, bitsPerValue int) []byte {
if bitsPerValue <= 0 || len(values) == 0 {
return nil
}
totalBits := len(values) * bitsPerValue
out := make([]byte, (totalBits+7)/8)
mask := uint8((1 << bitsPerValue) - 1)
bitPos := 0
for _, value := range values {
packed := value & mask
bytePos := bitPos / 8
shift := bitPos % 8
out[bytePos] |= packed << shift
if shift+bitsPerValue > 8 {
out[bytePos+1] |= packed >> (8 - shift)
}
bitPos += bitsPerValue
}
return out
}
func unpackBits(data []byte, bitsPerValue, count int) []uint8 {
out := make([]uint8, count)
if bitsPerValue <= 0 {
return out
}
mask := uint8((1 << bitsPerValue) - 1)
bitPos := 0
for i := range count {
bytePos := bitPos / 8
shift := bitPos % 8
if bytePos >= len(data) {
return out
}
value := data[bytePos] >> shift
if shift+bitsPerValue > 8 && bytePos+1 < len(data) {
value |= data[bytePos+1] << (8 - shift)
}
out[i] = value & mask
bitPos += bitsPerValue
}
return out
}
func expectedPackedBytes(count, bitsPerValue int) int {
if count <= 0 || bitsPerValue <= 0 {
return 0
}
return ((count * bitsPerValue) + 7) / 8
}

185
turboquant/block_test.go Normal file
View file

@ -0,0 +1,185 @@
package turboquant
import "testing"
func TestBlockMarshalRoundTrip(t *testing.T) {
block := Block{
Version: BlockVersion,
PresetID: PresetTQ3.ID,
Role: uint8(roleKey),
Objective: uint8(objectiveProduct),
OriginalDim: 8,
PaddedDim: 8,
BlockDim: 8,
RegularBits: 3,
RotationSeed: 77,
CodebookID: 3,
QJLRows: 4,
AuxLayoutID: 1,
Scale: 1,
RegularIndices: []byte{1, 2, 3},
Residual: ResidualSketch{
Seed: 88,
Scale: 0.5,
SketchDim: 4,
Signs: []byte{0x0f},
},
}
data, err := block.MarshalBinary()
if err != nil {
t.Fatal(err)
}
var decoded Block
if err := decoded.UnmarshalBinary(data); err != nil {
t.Fatal(err)
}
if decoded.Version != block.Version ||
decoded.PresetID != block.PresetID ||
decoded.Role != block.Role ||
decoded.Objective != block.Objective ||
decoded.OriginalDim != block.OriginalDim ||
decoded.PaddedDim != block.PaddedDim ||
decoded.BlockDim != block.BlockDim ||
decoded.RegularBits != block.RegularBits ||
decoded.RotationSeed != block.RotationSeed ||
decoded.CodebookID != block.CodebookID ||
decoded.QJLRows != block.QJLRows ||
decoded.AuxLayoutID != block.AuxLayoutID ||
string(decoded.RegularIndices) != string(block.RegularIndices) ||
decoded.Residual.Seed != block.Residual.Seed ||
decoded.Residual.Scale != block.Residual.Scale ||
decoded.Residual.SketchDim != block.Residual.SketchDim ||
string(decoded.Residual.Signs) != string(block.Residual.Signs) {
t.Fatalf("decoded block mismatch: %+v", decoded)
}
}
func TestBlockUnmarshalRejectsBadVersion(t *testing.T) {
block := Block{Version: BlockVersion, PresetID: PresetTQ2.ID, Role: uint8(roleValue), Objective: uint8(objectiveMSE), OriginalDim: 4, PaddedDim: 4, BlockDim: 4, RegularBits: 2}
data, err := block.MarshalBinary()
if err != nil {
t.Fatal(err)
}
data[0] = 99
var decoded Block
if err := decoded.UnmarshalBinary(data); err == nil {
t.Fatal("expected unsupported block version error")
}
}
func TestPackBitsRoundTripMixedWidths(t *testing.T) {
values2 := []uint8{1, 3, 0, 2}
roundTrip2 := unpackBits(packBits(values2, 2), 2, len(values2))
for i := range values2 {
if roundTrip2[i] != values2[i] {
t.Fatalf("2-bit round trip mismatch at %d: got %d want %d", i, roundTrip2[i], values2[i])
}
}
values3 := []uint8{3, 7, 1, 5}
roundTrip3 := unpackBits(packBits(values3, 3), 3, len(values3))
for i := range values3 {
if roundTrip3[i] != values3[i] {
t.Fatalf("3-bit round trip mismatch at %d: got %d want %d", i, roundTrip3[i], values3[i])
}
}
}
// TestEncodeOutlierSplitLayout verifies that vectors larger than OutlierCount
// are encoded as two blocks (outlier + regular) with the expected bit widths and
// ChannelIndices, and that vectors at or below OutlierCount stay as a single block.
//
// Uses an explicit outlier-enabled preset because PresetTQ3's shipped default
// is OutlierCount=0 (outlier split is opt-in infrastructure; see PresetTQ3
// comment in turboquant.go).
func TestEncodeOutlierSplitLayout(t *testing.T) {
preset := testOutlierPreset(PresetTQ3, 32)
// dim=70 > OutlierCount=32: two blocks expected.
encoded, err := EncodeVector(pseudoRandomVector(70, 0x55), preset)
if err != nil {
t.Fatal(err)
}
if len(encoded.Blocks) != 2 {
t.Fatalf("block count = %d, want 2", len(encoded.Blocks))
}
outlierBlock := encoded.Blocks[0]
regularBlock := encoded.Blocks[1]
if int(outlierBlock.OriginalDim) != preset.OutlierCount {
t.Errorf("outlier block dim = %d, want %d", outlierBlock.OriginalDim, preset.OutlierCount)
}
if outlierBlock.RegularBits != uint8(preset.OutlierBits) {
t.Errorf("outlier bits = %d, want %d", outlierBlock.RegularBits, preset.OutlierBits)
}
if len(outlierBlock.ChannelIndices) != preset.OutlierCount {
t.Errorf("outlier ChannelIndices len = %d, want %d", len(outlierBlock.ChannelIndices), preset.OutlierCount)
}
wantRegularDim := 70 - preset.OutlierCount
if int(regularBlock.OriginalDim) != wantRegularDim {
t.Errorf("regular block dim = %d, want %d", regularBlock.OriginalDim, wantRegularDim)
}
if regularBlock.RegularBits != uint8(preset.ValueBits) {
t.Errorf("regular bits = %d, want %d", regularBlock.RegularBits, preset.ValueBits)
}
if len(regularBlock.ChannelIndices) != wantRegularDim {
t.Errorf("regular ChannelIndices len = %d, want %d", len(regularBlock.ChannelIndices), wantRegularDim)
}
// ChannelIndices across both blocks must cover all 70 channels exactly once.
seen := make([]int, 70)
for _, idx := range outlierBlock.ChannelIndices {
seen[idx]++
}
for _, idx := range regularBlock.ChannelIndices {
seen[idx]++
}
for i, count := range seen {
if count != 1 {
t.Errorf("channel %d appears %d times across blocks", i, count)
}
}
}
// TestBlockMarshalWithChannelIndices verifies that ChannelIndices round-trips correctly.
func TestBlockMarshalWithChannelIndices(t *testing.T) {
block := Block{
Version: BlockVersion,
PresetID: PresetTQ2.ID,
Role: uint8(roleKey),
Objective: uint8(objectiveMSE),
OriginalDim: 4,
PaddedDim: 4,
BlockDim: 4,
RegularBits: 2,
RotationSeed: 42,
CodebookID: 2,
QJLRows: 0,
AuxLayoutID: 1,
ChannelIndices: []uint16{0, 3, 7, 12},
Scale: 0.5,
RegularIndices: []byte{0b10110001},
}
data, err := block.MarshalBinary()
if err != nil {
t.Fatal(err)
}
var got Block
if err := got.UnmarshalBinary(data); err != nil {
t.Fatal(err)
}
if len(got.ChannelIndices) != len(block.ChannelIndices) {
t.Fatalf("ChannelIndices len: got %d want %d", len(got.ChannelIndices), len(block.ChannelIndices))
}
for i := range block.ChannelIndices {
if got.ChannelIndices[i] != block.ChannelIndices[i] {
t.Errorf("ChannelIndices[%d]: got %d want %d", i, got.ChannelIndices[i], block.ChannelIndices[i])
}
}
}

216
turboquant/codebook.go Normal file
View file

@ -0,0 +1,216 @@
package turboquant
import (
"math"
"slices"
"sync"
)
type codebookCacheKey struct {
dim int
bits int
}
type scalarCodebookCacheValue struct {
codebook []float32
boundaries []float32
}
var scalarCodebookCache sync.Map
// ExportCodebook returns the Lloyd-Max codebook centroids for the given
// dim and bits. Used by the CUDA dequant kernel (loaded into GPU constant memory).
func ExportCodebook(dim, bits int) []float32 {
cb, _ := scalarCodebook(dim, bits)
return cb
}
// ExportBoundaries returns the Lloyd-Max decision boundaries for the given
// dim and bits. Used by the CUDA encode kernel for binary-search quantization.
// Boundaries are the midpoints between adjacent centroids; len = (1<<bits) - 1.
func ExportBoundaries(dim, bits int) []float32 {
_, boundaries := scalarCodebook(dim, bits)
return boundaries
}
func scalarCodebook(dim int, bits int) ([]float32, []float32) {
key := codebookCacheKey{dim: dim, bits: bits}
if cached, ok := scalarCodebookCache.Load(key); ok {
value := cached.(scalarCodebookCacheValue)
return append([]float32(nil), value.codebook...), append([]float32(nil), value.boundaries...)
}
codebook := buildLloydMaxCodebook(dim, bits)
value := scalarCodebookCacheValue{
codebook: codebook,
boundaries: codebookBoundaries(codebook),
}
actual, _ := scalarCodebookCache.LoadOrStore(key, value)
cached := actual.(scalarCodebookCacheValue)
return append([]float32(nil), cached.codebook...), append([]float32(nil), cached.boundaries...)
}
func buildLloydMaxCodebook(dim int, bits int) []float32 {
levels := 1 << bits
if levels <= 1 {
return []float32{0}
}
samples := unitVectorCoordSamples(dim, bits, 65536)
slices.Sort(samples)
centroids := make([]float64, levels)
for level := range levels {
begin := level * len(samples) / levels
end := (level + 1) * len(samples) / levels
if end <= begin {
end = begin + 1
}
centroids[level] = meanFloat64(samples[begin:end])
}
for range 48 {
slices.Sort(centroids)
bounds := make([]float64, levels-1)
for i := range bounds {
bounds[i] = (centroids[i] + centroids[i+1]) / 2
}
sums := make([]float64, levels)
counts := make([]int, levels)
for _, sample := range samples {
idx := quantizeScalarFloat64(sample, bounds)
sums[idx] += sample
counts[idx]++
}
maxDelta := 0.0
for i := range centroids {
var next float64
if counts[i] > 0 {
next = sums[i] / float64(counts[i])
} else if i == 0 {
next = bounds[0] - 0.25
} else if i == len(centroids)-1 {
next = bounds[len(bounds)-1] + 0.25
} else {
next = (bounds[i-1] + bounds[i]) / 2
}
maxDelta = math.Max(maxDelta, math.Abs(next-centroids[i]))
centroids[i] = next
}
if maxDelta < 1e-6 {
break
}
}
slices.Sort(centroids)
codebook := make([]float32, len(centroids))
for i := range centroids {
codebook[i] = float32(centroids[i])
}
return codebook
}
// unitVectorCoordSamples returns count samples from the exact marginal
// distribution of a single coordinate of a uniformly random unit vector in R^d:
//
// z_1 / ‖z‖ · √d, z ~ N(0, I_d)
//
// For large d this converges to N(0,1); for smaller d the heavier tails of the
// Beta((d-3)/2,(d-3)/2) coordinate distribution are preserved. Using this
// distribution — rather than pure N(0,1) — produces the optimal Lloyd-Max
// codebook for the actual coordinate distribution that arises after RMS
// normalization and random rotation (Paper §3.1, Eq. 4, Lemma 1).
func unitVectorCoordSamples(dim int, bits int, count int) []float64 {
rng := splitmix64(uint64(bits+1)<<48 ^ uint64(dim+1)<<16 ^ 0x4d595df4d0f33173)
out := make([]float64, count)
if dim <= 1 {
for i := range out {
out[i] = gaussianFloat64(&rng)
}
return out
}
sqrtDim := math.Sqrt(float64(dim))
for i := range out {
z0 := gaussianFloat64(&rng)
sumSq := z0 * z0
for k := 1; k < dim; k++ {
g := gaussianFloat64(&rng)
sumSq += g * g
}
norm := math.Sqrt(sumSq)
if norm < 1e-15 {
out[i] = 0
} else {
out[i] = z0 / norm * sqrtDim
}
}
return out
}
func meanFloat64(values []float64) float64 {
if len(values) == 0 {
return 0
}
total := 0.0
for _, value := range values {
total += value
}
return total / float64(len(values))
}
func codebookBoundaries(codebook []float32) []float32 {
if len(codebook) < 2 {
return nil
}
out := make([]float32, len(codebook)-1)
for i := range out {
out[i] = (codebook[i] + codebook[i+1]) / 2
}
return out
}
func quantizeScalarByBoundary(v float32, codebook []float32, boundaries []float32) uint8 {
if len(codebook) == 0 {
return 0
}
if len(boundaries) != len(codebook)-1 {
return quantizeScalarNearest(v, codebook)
}
idx := 0
for idx < len(boundaries) && v >= boundaries[idx] {
idx++
}
return uint8(idx)
}
func quantizeScalarFloat64(v float64, boundaries []float64) int {
idx := 0
for idx < len(boundaries) && v >= boundaries[idx] {
idx++
}
return idx
}
func quantizeScalarNearest(v float32, codebook []float32) uint8 {
best := 0
bestDist := float32(math.MaxFloat32)
for i, centroid := range codebook {
d := abs32(v - centroid)
if d < bestDist {
bestDist = d
best = i
}
}
return uint8(best)
}
func dequantizeScalar(idx uint8, codebook []float32) float32 {
if int(idx) >= len(codebook) {
return 0
}
return codebook[idx]
}

198
turboquant/codebook_test.go Normal file
View file

@ -0,0 +1,198 @@
package turboquant
import (
"fmt"
"math"
"testing"
)
// TestUnitVectorCoordSamplesVariance checks that unitVectorCoordSamples
// produces samples with variance ≈ 1.0 for the normalized coordinate
// distribution (z_1/‖z‖ · √d). The variance of this distribution is exactly
// d · Var[z_1/‖z‖] = d · (1/d) = 1, independent of d. For small d the
// distribution has heavier tails than N(0,1) but still variance=1.
func TestUnitVectorCoordSamplesVariance(t *testing.T) {
for _, dim := range []int{4, 8, 16, 32, 64, 128} {
samples := unitVectorCoordSamples(dim, 2, 32768)
var sum, sumSq float64
for _, s := range samples {
sum += s
sumSq += s * s
}
n := float64(len(samples))
mean := sum / n
variance := sumSq/n - mean*mean
// Mean should be ~0, variance ~1 for all d.
if math.Abs(mean) > 0.05 {
t.Errorf("dim=%d: mean=%.4f, want ~0", dim, mean)
}
if math.Abs(variance-1.0) > 0.05 {
t.Errorf("dim=%d: variance=%.4f, want ~1.0", dim, variance)
}
}
}
// TestUnitVectorCoordSamplesKurtosis checks that small-d samples have lower
// kurtosis than N(0,1) (kurtosis=3), confirming the bounded-support Beta tails.
// The kurtosis of the coordinate distribution is 3d/(d+2), which equals
// 2.0 at d=4 and approaches 3.0 as d→∞. So for small d the distribution is
// platykurtic (kurtosis < 3, bounded support), unlike the unbounded Gaussian.
func TestUnitVectorCoordSamplesKurtosis(t *testing.T) {
// For d=4: kurtosis = 3×4/(4+2) = 2.0 (platykurtic, well below Gaussian's 3).
// For d=128: kurtosis = 3×128/130 ≈ 2.95 (close to Gaussian's 3).
samplesSmall := unitVectorCoordSamples(4, 2, 65536)
var s2, s4 float64
for _, s := range samplesSmall {
s2 += s * s
s4 += s * s * s * s
}
n := float64(len(samplesSmall))
varSmall := s2 / n
kurtSmall := (s4 / n) / (varSmall * varSmall)
if kurtSmall >= 3.0 {
t.Errorf("dim=4: kurtosis=%.3f, want < 3.0 (platykurtic for bounded distribution)", kurtSmall)
}
samplesLarge := unitVectorCoordSamples(128, 2, 65536)
var l2, l4 float64
for _, s := range samplesLarge {
l2 += s * s
l4 += s * s * s * s
}
varLarge := l2 / n
kurtLarge := (l4 / n) / (varLarge * varLarge)
// For d=128 kurtosis should be close to Gaussian's 3.
if math.Abs(kurtLarge-3.0) > 0.3 {
t.Errorf("dim=128: kurtosis=%.3f, want ~3.0 (close to Gaussian)", kurtLarge)
}
}
func TestScalarCodebookDeterministic(t *testing.T) {
for _, bits := range []int{2, 3} {
codebookA, boundsA := scalarCodebook(128, bits)
codebookB, boundsB := scalarCodebook(128, bits)
if len(codebookA) != 1<<bits {
t.Fatalf("bits=%d codebook len=%d", bits, len(codebookA))
}
for i := range codebookA {
if codebookA[i] != codebookB[i] {
t.Fatalf("bits=%d centroid mismatch at %d", bits, i)
}
}
for i := range boundsA {
if boundsA[i] != boundsB[i] {
t.Fatalf("bits=%d boundary mismatch at %d", bits, i)
}
}
}
}
func TestCodebookBoundariesMonotonic(t *testing.T) {
for _, bits := range []int{2, 3} {
_, bounds := scalarCodebook(128, bits)
for i := 1; i < len(bounds); i++ {
if bounds[i] <= bounds[i-1] {
t.Fatalf("bits=%d boundaries are not monotonic", bits)
}
}
}
}
func TestQuantizeScalarByBoundaryDeterministic(t *testing.T) {
codebook, bounds := scalarCodebook(128, 3)
mid := bounds[2]
left := quantizeScalarByBoundary(mid-1e-6, codebook, bounds)
right := quantizeScalarByBoundary(mid+1e-6, codebook, bounds)
atBoundary := quantizeScalarByBoundary(mid, codebook, bounds)
if left != 2 {
t.Fatalf("left boundary bucket = %d, want 2", left)
}
if right != 3 {
t.Fatalf("right boundary bucket = %d, want 3", right)
}
if atBoundary != 3 {
t.Fatalf("exact boundary bucket = %d, want 3", atBoundary)
}
}
// TestPaperTheoremOneMSEBounds verifies that the Lloyd-Max codebook achieves
// the total-distortion bounds from Theorem 1 of arXiv:2504.19874
// for unit-norm vectors at bit widths 14.
//
// Paper Theorem 1 bounds (total distortion ||v - v_hat||^2 for unit-norm vectors):
//
// b=1 → 0.36, b=2 → 0.117, b=3 → 0.03, b=4 → 0.009
func TestPaperTheoremOneMSEBounds(t *testing.T) {
paperBounds := map[int]float64{
1: 0.36,
2: 0.117,
3: 0.03,
4: 0.009,
}
const dim = 128
const nTrials = 200
rot := BuildRotation(dim, 0x42c0ffee)
for bits := 1; bits <= 4; bits++ {
bits := bits // capture loop variable
t.Run(fmt.Sprintf("bits=%d", bits), func(t *testing.T) {
bound := paperBounds[bits]
codebook, boundaries := scalarCodebook(dim, bits)
var totalMSE float64
for trial := range nTrials {
vec := pseudoRandomVector(dim, uint64(trial)*0x9e3779b97f4a7c15+1)
// Normalize to unit norm.
var norm float64
for _, v := range vec {
norm += float64(v) * float64(v)
}
norm = math.Sqrt(norm)
if norm < 1e-10 {
continue
}
for i := range vec {
vec[i] /= float32(norm)
}
// Apply rotation (matches internal encode behavior).
rotated := ApplyRotation(vec, rot)
// Compute RMS scale (same as blockScale in encode.go).
var sumSq float64
for _, v := range rotated {
sumSq += float64(v) * float64(v)
}
scale := float32(math.Sqrt(sumSq / float64(dim)))
// Quantize each element and accumulate per-element MSE.
var mse float64
for _, v := range rotated {
normalized := float32(0)
if scale > 0 {
normalized = v / scale
}
idx := quantizeScalarByBoundary(normalized, codebook, boundaries)
recon := codebook[idx] * scale
diff := float64(v - recon)
mse += diff * diff
}
totalMSE += mse
}
avgMSE := totalMSE / float64(nTrials)
// Allow 50% headroom over the paper bound to account for finite
// sample size and finite dimension effects.
if avgMSE > bound*1.5 {
t.Errorf("bits=%d: avg MSE %.6f exceeds 1.5× paper bound %.6f",
bits, avgMSE, bound*1.5)
}
t.Logf("bits=%d: avg MSE=%.6f, paper bound=%.6f, ratio=%.2f",
bits, avgMSE, bound, avgMSE/bound)
})
}
}

113
turboquant/decode.go Normal file
View file

@ -0,0 +1,113 @@
package turboquant
import (
"bytes"
"encoding/binary"
"fmt"
"io"
)
// DecodeVector fully dequantizes an encoded vector back to float32 in
// original space. This applies inverse rotation and, for product-mode
// blocks, adds the reconstructed QJL residual.
func DecodeVector(data []byte) ([]float32, Preset, error) {
ev, err := UnmarshalEncodedVector(data)
if err != nil {
return nil, Preset{}, err
}
decoded := make([]float32, ev.Dim)
offset := 0
for _, block := range ev.Blocks {
blockDim := int(block.OriginalDim)
codebook, _ := scalarCodebook(blockDim, int(block.RegularBits))
indices := unpackBits(block.RegularIndices, int(block.RegularBits), blockDim)
rotated := make([]float32, blockDim)
for i, idx := range indices {
rotated[i] = dequantizeScalar(idx, codebook) * block.Scale
}
if vectorObjective(block.Objective) == objectiveProduct {
residual := reconstructResidual(blockDim, block.Residual)
for i := range rotated {
rotated[i] += residual[i]
}
}
original := ApplyInverseRotation(rotated, BuildRotation(blockDim, block.RotationSeed))
if len(block.ChannelIndices) == blockDim {
// Scatter to original channel positions.
for i, chIdx := range block.ChannelIndices {
decoded[chIdx] = original[i]
}
} else {
// Legacy single-block: fill contiguous range.
copy(decoded[offset:], original)
offset += blockDim
}
}
return decoded, ev.Preset, nil
}
// UnmarshalEncodedVector deserializes an EncodedVector from its binary form.
func UnmarshalEncodedVector(data []byte) (EncodedVector, error) {
r := bytes.NewReader(data)
var version uint8
var presetID uint8
var dim uint32
var blockCount uint32
for _, field := range []any{&version, &presetID, &dim, &blockCount} {
if err := binary.Read(r, binary.LittleEndian, field); err != nil {
return EncodedVector{}, err
}
}
if version != BlockVersion {
return EncodedVector{}, fmt.Errorf("unsupported encoded vector version %d", version)
}
preset, err := PresetByID(presetID)
if err != nil {
return EncodedVector{}, err
}
blocks := make([]Block, 0, blockCount)
totalDim := 0
for range int(blockCount) {
var blockLen uint32
if err := binary.Read(r, binary.LittleEndian, &blockLen); err != nil {
return EncodedVector{}, err
}
blockData := make([]byte, blockLen)
if _, err := io.ReadFull(r, blockData); err != nil {
return EncodedVector{}, err
}
var block Block
if err := block.UnmarshalBinary(blockData); err != nil {
return EncodedVector{}, err
}
if block.PresetID != preset.ID {
return EncodedVector{}, fmt.Errorf("block preset id %d does not match encoded preset %d", block.PresetID, preset.ID)
}
if len(block.RegularIndices) != expectedPackedBytes(int(block.OriginalDim), int(block.RegularBits)) {
return EncodedVector{}, fmt.Errorf("invalid primary index length %d for dim %d and bits %d", len(block.RegularIndices), block.OriginalDim, block.RegularBits)
}
if block.Residual.SketchDim != block.QJLRows {
return EncodedVector{}, fmt.Errorf("residual sketch dim %d does not match qjl rows %d", block.Residual.SketchDim, block.QJLRows)
}
totalDim += int(block.OriginalDim)
blocks = append(blocks, block)
}
if r.Len() != 0 {
return EncodedVector{}, fmt.Errorf("unexpected trailing bytes in encoded vector: %d", r.Len())
}
if totalDim != int(dim) {
return EncodedVector{}, fmt.Errorf("encoded vector dim mismatch: header=%d blocks=%d", dim, totalDim)
}
return EncodedVector{
Version: version,
Preset: preset,
Dim: int(dim),
Blocks: blocks,
}, nil
}

127
turboquant/decode_test.go Normal file
View file

@ -0,0 +1,127 @@
package turboquant
import (
"bytes"
"encoding/binary"
"testing"
)
func TestUnmarshalEncodedVectorRejectsBadHeader(t *testing.T) {
if _, err := UnmarshalEncodedVector([]byte{1, 2, 3}); err == nil {
t.Fatal("expected malformed header error")
}
}
func TestUnmarshalEncodedVectorRejectsWrongVersion(t *testing.T) {
encoded, err := EncodeVector([]float32{1, 2, 3, 4}, PresetTQ3)
if err != nil {
t.Fatal(err)
}
data, err := encoded.MarshalBinary()
if err != nil {
t.Fatal(err)
}
data[0] = 99
if _, err := UnmarshalEncodedVector(data); err == nil {
t.Fatal("expected unsupported version error")
}
}
func TestUnmarshalEncodedVectorRejectsBadPresetID(t *testing.T) {
encoded, err := EncodeVector([]float32{1, 2, 3, 4}, PresetTQ3)
if err != nil {
t.Fatal(err)
}
data, err := encoded.MarshalBinary()
if err != nil {
t.Fatal(err)
}
data[1] = 99
if _, err := UnmarshalEncodedVector(data); err == nil {
t.Fatal("expected bad preset id error")
}
}
func TestUnmarshalEncodedVectorRejectsTruncatedBlockPayload(t *testing.T) {
encoded, err := EncodeVector(pseudoRandomVector(16, 0x99), PresetTQ2)
if err != nil {
t.Fatal(err)
}
data, err := encoded.MarshalBinary()
if err != nil {
t.Fatal(err)
}
truncated := data[:len(data)-1]
if _, err := UnmarshalEncodedVector(truncated); err == nil {
t.Fatal("expected truncated block payload error")
}
}
func TestDecodeVectorRejectsInvalidIndexLengths(t *testing.T) {
encoded, err := EncodeVector(pseudoRandomVector(16, 0x77), PresetTQ3)
if err != nil {
t.Fatal(err)
}
data, err := encoded.MarshalBinary()
if err != nil {
t.Fatal(err)
}
ev, err := UnmarshalEncodedVector(data)
if err != nil {
t.Fatal(err)
}
ev.Blocks[0].RegularIndices = append(ev.Blocks[0].RegularIndices, 0)
corrupt, err := marshalTestEncodedVector(ev)
if err != nil {
t.Fatal(err)
}
if _, _, err := DecodeVector(corrupt); err == nil {
t.Fatal("expected invalid primary index length error")
}
}
func TestDecodeVectorPreservesOriginalLength(t *testing.T) {
values := pseudoRandomVector(130, 0x66)
encoded, err := EncodeVector(values, PresetTQ3)
if err != nil {
t.Fatal(err)
}
data, err := encoded.MarshalBinary()
if err != nil {
t.Fatal(err)
}
decoded, _, err := DecodeVector(data)
if err != nil {
t.Fatal(err)
}
if len(decoded) != len(values) {
t.Fatalf("decoded len = %d, want %d", len(decoded), len(values))
}
}
func marshalTestEncodedVector(ev EncodedVector) ([]byte, error) {
var buf bytes.Buffer
for _, field := range []any{ev.Version, ev.Preset.ID, uint32(ev.Dim), uint32(len(ev.Blocks))} {
if err := binary.Write(&buf, binary.LittleEndian, field); err != nil {
return nil, err
}
}
for _, block := range ev.Blocks {
blockData, err := block.MarshalBinary()
if err != nil {
return nil, err
}
if err := binary.Write(&buf, binary.LittleEndian, uint32(len(blockData))); err != nil {
return nil, err
}
buf.Write(blockData)
}
return buf.Bytes(), nil
}

386
turboquant/encode.go Normal file
View file

@ -0,0 +1,386 @@
package turboquant
import (
"bytes"
"encoding/binary"
"fmt"
"math"
)
// Rotation seed modifiers for outlier and regular sub-blocks.
// Using ASCII mnemonics: "OUTL1ER\0" and "REGULAR\0".
const (
outlierSeedXOR = uint64(0x4f55544c31455200)
regularSeedXOR = uint64(0x524547554c415200)
)
type EncodedVector struct {
Version uint8
Preset Preset
Dim int
Blocks []Block
}
func EncodeVector(values []float32, preset Preset) (EncodedVector, error) {
return encodeVector(values, preset, roleGeneric, objectiveMSE, preset.ValueBits)
}
func EncodeKeyVector(values []float32, preset Preset) (EncodedVector, error) {
return encodeVector(values, preset, roleKey, objectiveProduct, preset.KeyPrimaryBits)
}
func EncodeValueVector(values []float32, preset Preset) (EncodedVector, error) {
return encodeVector(values, preset, roleValue, objectiveMSE, preset.ValueBits)
}
func encodeVector(values []float32, preset Preset, role vectorRole, objective vectorObjective, bits int) (EncodedVector, error) {
dim := len(values)
if dim <= 0 {
return EncodedVector{}, fmt.Errorf("invalid turboquant vector dim %d", dim)
}
if bits <= 0 || bits >= 8 {
return EncodedVector{}, fmt.Errorf("invalid turboquant bit width %d", bits)
}
if preset.HasOutlierSplit() && dim > preset.OutlierCount {
return encodeVectorWithOutliers(values, preset, role, objective, bits)
}
block, err := encodeSubBlock(values, nil, preset, role, objective, bits, preset.RotationSeed)
if err != nil {
return EncodedVector{}, err
}
return EncodedVector{
Version: BlockVersion,
Preset: preset,
Dim: dim,
Blocks: []Block{block},
}, nil
}
// encodeVectorWithOutliers splits the vector into outlier and regular channel
// sub-blocks, each encoded independently with its own rotation. The outlier
// block is stored first in Blocks.
func encodeVectorWithOutliers(values []float32, preset Preset, role vectorRole, objective vectorObjective, regularBits int) (EncodedVector, error) {
split := SplitOutlierChannels(values, preset.OutlierCount)
outlierBits := preset.OutlierBits
if outlierBits <= 0 || outlierBits >= 8 {
return EncodedVector{}, fmt.Errorf("invalid outlier bit width %d", outlierBits)
}
if regularBits <= 0 || regularBits >= 8 {
return EncodedVector{}, fmt.Errorf("invalid regular bit width %d", regularBits)
}
outlierSeed := preset.RotationSeed ^ outlierSeedXOR
regularSeed := preset.RotationSeed ^ regularSeedXOR
outlierBlock, err := encodeSubBlock(split.OutlierValues, split.OutlierIndices, preset, role, objective, outlierBits, outlierSeed)
if err != nil {
return EncodedVector{}, fmt.Errorf("outlier block: %w", err)
}
// Regular block always uses MSE (no QJL sketch) regardless of the key/value
// role. QJL is only applied to the outlier block, which concentrates the
// residual correction budget on the highest-magnitude channels.
regularBlock, err := encodeSubBlock(split.RegularValues, split.RegularIndices, preset, role, objectiveMSE, regularBits, regularSeed)
if err != nil {
return EncodedVector{}, fmt.Errorf("regular block: %w", err)
}
return EncodedVector{
Version: BlockVersion,
Preset: preset,
Dim: len(values),
Blocks: []Block{outlierBlock, regularBlock},
}, nil
}
// encodeSubBlock encodes a sub-vector (identified by channelIndices into the
// original full-dim vector) as a single Block. If channelIndices is nil, the
// block covers all channels (single-block legacy path).
func encodeSubBlock(values []float32, channelIndices []uint16, preset Preset, role vectorRole, objective vectorObjective, bits int, rotationSeed uint64) (Block, error) {
dim := len(values)
if dim <= 0 {
return Block{}, fmt.Errorf("empty sub-block")
}
codebook, boundaries := scalarCodebook(dim, bits)
rotation := BuildRotation(dim, rotationSeed)
rotated := ApplyRotation(values, rotation)
scale := blockScale(rotated)
primaryCodes := make([]uint8, dim)
reconRotated := make([]float32, dim)
if scale == 0 {
for i := range primaryCodes {
primaryCodes[i] = quantizeScalarByBoundary(0, codebook, boundaries)
reconRotated[i] = 0
}
} else {
for i, value := range rotated {
normalized := value / scale
idx := quantizeScalarByBoundary(normalized, codebook, boundaries)
primaryCodes[i] = idx
reconRotated[i] = dequantizeScalar(idx, codebook) * scale
}
}
qjlRows := 0
if objective == objectiveProduct {
qjlRows = preset.KeyQJLRows(dim)
}
block := Block{
Version: BlockVersion,
PresetID: preset.ID,
Role: uint8(role),
Objective: uint8(objective),
OriginalDim: uint16(dim),
PaddedDim: uint16(dim),
BlockDim: uint16(dim),
RegularBits: uint8(bits),
RotationSeed: rotationSeed,
CodebookID: uint16(bits),
QJLRows: uint16(qjlRows),
AuxLayoutID: 1,
ChannelIndices: channelIndices,
Scale: scale,
RegularIndices: packBits(primaryCodes, bits),
Residual: encodeResidual(rotated, reconRotated, qjlRows, rotationSeed^0x9e3779b97f4a7c15),
}
return block, nil
}
// EncodeKeyPerHead quantizes a single attention head's key vector using
// rotation + Lloyd-Max without outlier split or QJL residual. Returns
// packed N-bit indices and an RMS scale. The quantized values are in
// rotated space — the caller must rotate Q to match at attention time.
//
// This produces a GPU-friendly flat representation: just packed bits + scale.
func EncodeKeyPerHead(values []float32, preset Preset) (packedIndices []byte, scale float32, err error) {
dim := len(values)
if dim <= 0 {
return nil, 0, fmt.Errorf("empty head vector")
}
bits := preset.KeyPrimaryBits
if bits <= 0 || bits >= 8 {
return nil, 0, fmt.Errorf("invalid bit width %d", bits)
}
codebook, boundaries := scalarCodebook(dim, bits)
rotation := BuildRotation(dim, preset.RotationSeed)
rotated := ApplyRotation(values, rotation)
scale = blockScale(rotated)
codes := make([]uint8, dim)
if scale > 0 {
for i, v := range rotated {
codes[i] = quantizeScalarByBoundary(v/scale, codebook, boundaries)
}
}
return packBits(codes, bits), scale, nil
}
// DequantKeyPerHead reconstructs f32 values from packed indices + scale in
// rotated space (no inverse rotation). Used for CPU-side testing/fallback.
func DequantKeyPerHead(packedIndices []byte, scale float32, headDim, bits int) []float32 {
codebook, _ := scalarCodebook(headDim, bits)
indices := unpackBits(packedIndices, bits, headDim)
out := make([]float32, headDim)
for i, idx := range indices {
out[i] = dequantizeScalar(idx, codebook) * scale
}
return out
}
// OutlierPerHead is the CPU reference for TurboQuant paper Algorithm 1
// Sec 4.3's outlier-split K encoding. Mirrors the GPU kernel
// tq_encode_kernel_outlier / tq_dequant_multihead_kernel_outlier in
// ml/backend/ggml/ggml/src/ggml-cuda/tq-*.cu:
//
// 1. Rotate K by Householder QR of a random Gaussian matrix (shared
// rotation, same as EncodeKeyPerHead).
// 2. Select the top-K channels by absolute rotated magnitude as
// outliers.
// 3. Compute independent RMS scales for regular and outlier sub-blocks.
// 4. Quantize each sub-block with its own Lloyd-Max codebook at the
// preset's primary bits (regular) and outlier bits (outlier).
// 5. Return packed regular + packed outlier streams, per-sub-block
// scales, and the channel index list.
//
// Unlike EncodeKeyVector / encodeVectorWithOutliers, the split happens
// in ROTATED space (single rotation matmul) because that is what the
// GPU can do cheaply and what the paper's symmetric-rotation formulation
// assumes. The Go reference exists to validate the GPU kernel
// bit-exactly, not to reproduce the block-protocol outlier path.
type OutlierPerHead struct {
RegularPacked []byte
RegularScale float32
OutlierPacked []byte
OutlierScale float32
OutlierIndices []int
}
func EncodeKeyPerHeadOutlier(values []float32, preset Preset) (OutlierPerHead, error) {
dim := len(values)
if dim <= 0 {
return OutlierPerHead{}, fmt.Errorf("empty head vector")
}
bits := preset.KeyPrimaryBits
if bits <= 0 || bits >= 8 {
return OutlierPerHead{}, fmt.Errorf("invalid regular bit width %d", bits)
}
outlierBits := preset.OutlierBits
outlierCount := preset.OutlierCount
if outlierBits <= 0 || outlierBits >= 8 || outlierCount <= 0 || outlierCount >= dim {
return OutlierPerHead{}, fmt.Errorf("invalid outlier split %d@%dbits over dim=%d", outlierCount, outlierBits, dim)
}
regularCount := dim - outlierCount
rotation := BuildRotation(dim, preset.RotationSeed)
rotated := ApplyRotation(values, rotation)
// Step 2: top-K outlier select by abs(rotated). Mirrors the serial
// thread-0 selection in tq_encode_kernel_outlier step 4. We use the
// same "mark-and-scan" algorithm so the outlier set order matches
// the GPU kernel exactly for a given input.
isOutlier := make([]bool, dim)
outlierPos := make([]int, 0, outlierCount)
outlierVal := make([]float32, 0, outlierCount)
for range outlierCount {
bestVal := float32(-1.0)
bestIdx := 0
for i := range dim {
if isOutlier[i] {
continue
}
a := rotated[i]
if a < 0 {
a = -a
}
if a > bestVal {
bestVal = a
bestIdx = i
}
}
isOutlier[bestIdx] = true
outlierPos = append(outlierPos, bestIdx)
outlierVal = append(outlierVal, rotated[bestIdx])
}
// Build the regular channel position list in ascending order. The
// GPU kernel walks i=0..dim-1 and appends non-outlier positions to
// s_reg_pos in order; the CPU reference does the same so packed
// regular slot r maps to the same original channel on both sides.
regularRotated := make([]float32, 0, regularCount)
for i := range dim {
if !isOutlier[i] {
regularRotated = append(regularRotated, rotated[i])
}
}
regularScale := blockScale(regularRotated)
outlierScale := blockScale(outlierVal)
regularCodebook, regularBoundaries := scalarCodebook(dim, bits)
outlierCodebook, outlierBoundaries := scalarCodebook(dim, outlierBits)
regularCodes := make([]uint8, regularCount)
if regularScale > 0 {
for r, v := range regularRotated {
regularCodes[r] = quantizeScalarByBoundary(v/regularScale, regularCodebook, regularBoundaries)
}
}
outlierCodes := make([]uint8, outlierCount)
if outlierScale > 0 {
for r, v := range outlierVal {
outlierCodes[r] = quantizeScalarByBoundary(v/outlierScale, outlierCodebook, outlierBoundaries)
}
}
return OutlierPerHead{
RegularPacked: packBits(regularCodes, bits),
RegularScale: regularScale,
OutlierPacked: packBits(outlierCodes, outlierBits),
OutlierScale: outlierScale,
OutlierIndices: outlierPos,
}, nil
}
// DequantKeyPerHeadOutlier is the CPU reference for the outlier-split
// dequant. Returns the reconstructed rotated-space vector that the
// caller can compare against ApplyRotation(original, rotation).
func DequantKeyPerHeadOutlier(enc OutlierPerHead, preset Preset, headDim int) []float32 {
outlierCount := len(enc.OutlierIndices)
regularCount := headDim - outlierCount
bits := preset.KeyPrimaryBits
outlierBits := preset.OutlierBits
regularCodebook, _ := scalarCodebook(headDim, bits)
outlierCodebook, _ := scalarCodebook(headDim, outlierBits)
regularIdx := unpackBits(enc.RegularPacked, bits, regularCount)
outlierIdx := unpackBits(enc.OutlierPacked, outlierBits, outlierCount)
// Build position-to-slot lookups. outlier_slot_for[i] is the index
// in OutlierIndices that equals i (or -1 if i is regular).
outlierSlotFor := make([]int, headDim)
for i := range outlierSlotFor {
outlierSlotFor[i] = -1
}
for slot, pos := range enc.OutlierIndices {
outlierSlotFor[pos] = slot
}
out := make([]float32, headDim)
regPos := 0
for i := range headDim {
if outlierSlotFor[i] >= 0 {
out[i] = dequantizeScalar(outlierIdx[outlierSlotFor[i]], outlierCodebook) * enc.OutlierScale
} else {
out[i] = dequantizeScalar(regularIdx[regPos], regularCodebook) * enc.RegularScale
regPos++
}
}
return out
}
func blockScale(values []float32) float32 {
if len(values) == 0 {
return 0
}
var sumSquares float64
for _, value := range values {
sumSquares += float64(value * value)
}
if sumSquares < 1e-12 {
return 0
}
return float32(math.Sqrt(sumSquares / float64(len(values))))
}
func (e EncodedVector) MarshalBinary() ([]byte, error) {
var buf bytes.Buffer
header := []any{
e.Version,
e.Preset.ID,
uint32(e.Dim),
uint32(len(e.Blocks)),
}
for _, field := range header {
if err := binary.Write(&buf, binary.LittleEndian, field); err != nil {
return nil, err
}
}
for _, block := range e.Blocks {
blockData, err := block.MarshalBinary()
if err != nil {
return nil, err
}
if err := binary.Write(&buf, binary.LittleEndian, uint32(len(blockData))); err != nil {
return nil, err
}
buf.Write(blockData)
}
return buf.Bytes(), nil
}

964
turboquant/encode_test.go Normal file
View file

@ -0,0 +1,964 @@
package turboquant
import (
"fmt"
"math"
"testing"
)
// testOutlierPreset returns a copy of base with outlier split forced on
// for testing. The shipped PresetTQ3 / PresetTQ3K default to
// OutlierCount=0 (see comments on those presets in turboquant.go): the
// paper-grounded outlier-split path is available in the kernels and Go
// helpers but disabled in the default presets because on the models
// this fork ships against it hurts decode throughput and PPL without
// the Phase 2A asymmetric-quantization follow-up. Tests that exercise
// the outlier-split code path explicitly opt in via this helper.
func testOutlierPreset(base Preset, count int) Preset {
out := base
out.OutlierCount = count
return out
}
// TestOutlierSplitBoundaries pins the single-block vs two-block boundary.
// dim == OutlierCount must produce a single block (no split).
// dim == OutlierCount+1 is the smallest two-block encoding.
// TestMemoryFormulaMatchesMarshalSize verifies that the two-block byte-count
// formula used in fs/ggml/ggml.go GraphSize matches the actual MarshalBinary
// output size. This catches formula drift whenever Block layout changes.
func TestMemoryFormulaMatchesMarshalSize(t *testing.T) {
// The GraphSize formula replicated below assumes the two-block
// outlier-split layout. All shipped tq* presets default to uniform
// (OutlierCount=0) so force outlier variants here; the formula itself
// is what's being pinned, not the default preset state.
tq2Split := testOutlierPreset(PresetTQ2, 32)
tq3Split := testOutlierPreset(PresetTQ3, 32)
cases := []struct {
preset Preset
dim int
}{
{tq2Split, 128},
{tq3Split, 128},
{tq3Split, 64},
}
for _, tc := range cases {
vec := make([]float32, tc.dim)
keyEncoded, err := EncodeKeyVector(vec, tc.preset)
if err != nil {
t.Fatalf("%s dim=%d key: %v", tc.preset.Name, tc.dim, err)
}
keyData, err := keyEncoded.MarshalBinary()
if err != nil {
t.Fatalf("%s dim=%d key marshal: %v", tc.preset.Name, tc.dim, err)
}
valueEncoded, err := EncodeValueVector(vec, tc.preset)
if err != nil {
t.Fatalf("%s dim=%d value: %v", tc.preset.Name, tc.dim, err)
}
valueData, err := valueEncoded.MarshalBinary()
if err != nil {
t.Fatalf("%s dim=%d value marshal: %v", tc.preset.Name, tc.dim, err)
}
// Replicate the fs/ggml/ggml.go GraphSize formula.
const outlierCount = uint64(32)
outlierBits := uint64(tc.preset.OutlierBits)
regularKeyBits := uint64(tc.preset.KeyPrimaryBits)
regularValueBits := uint64(tc.preset.ValueBits)
dim := uint64(tc.dim)
outlierData := (outlierCount*outlierBits + 7) / 8
qjlData := (outlierCount + 7) / 8
wantKey := 122 + 2*dim + outlierData + ((dim-outlierCount)*regularKeyBits+7)/8 + qjlData
wantValue := 122 + 2*dim + outlierData + ((dim-outlierCount)*regularValueBits+7)/8
if uint64(len(keyData)) != wantKey {
t.Errorf("%s dim=%d: key MarshalBinary=%d bytes, formula=%d",
tc.preset.Name, tc.dim, len(keyData), wantKey)
}
if uint64(len(valueData)) != wantValue {
t.Errorf("%s dim=%d: value MarshalBinary=%d bytes, formula=%d",
tc.preset.Name, tc.dim, len(valueData), wantValue)
}
}
}
func TestOutlierSplitBoundaries(t *testing.T) {
for _, preset := range []Preset{testOutlierPreset(PresetTQ2, 32), testOutlierPreset(PresetTQ3, 32)} {
atBoundary := pseudoRandomVector(preset.OutlierCount, 0xbabe)
encoded, err := EncodeKeyVector(atBoundary, preset)
if err != nil {
t.Fatalf("%s dim=OutlierCount: %v", preset.Name, err)
}
if len(encoded.Blocks) != 1 {
t.Errorf("%s dim=%d: got %d blocks, want 1 (no split at exact boundary)",
preset.Name, preset.OutlierCount, len(encoded.Blocks))
}
if len(encoded.Blocks[0].ChannelIndices) != 0 {
t.Errorf("%s dim=%d: single-block should have no ChannelIndices", preset.Name, preset.OutlierCount)
}
minSplit := pseudoRandomVector(preset.OutlierCount+1, 0xbabe)
encoded2, err := EncodeKeyVector(minSplit, preset)
if err != nil {
t.Fatalf("%s dim=OutlierCount+1: %v", preset.Name, err)
}
if len(encoded2.Blocks) != 2 {
t.Errorf("%s dim=%d: got %d blocks, want 2 (minimum outlier split)",
preset.Name, preset.OutlierCount+1, len(encoded2.Blocks))
}
// Regular block has dim=1; verify it round-trips cleanly.
data, err := encoded2.MarshalBinary()
if err != nil {
t.Fatalf("%s dim=OutlierCount+1 marshal: %v", preset.Name, err)
}
decoded, _, err := DecodeVector(data)
if err != nil {
t.Fatalf("%s dim=OutlierCount+1 decode: %v", preset.Name, err)
}
if len(decoded) != preset.OutlierCount+1 {
t.Errorf("%s dim=OutlierCount+1: decoded len=%d want %d",
preset.Name, len(decoded), preset.OutlierCount+1)
}
}
}
func TestEncodeDecodeRoundTripAcrossShapes(t *testing.T) {
testCases := []struct {
name string
values []float32
preset Preset
maxMSE float32
}{
{name: "small", values: []float32{0.25, -1.5, 3.25, 0.75, -0.5, 2.0, 1.0}, preset: PresetTQ3, maxMSE: 1.5},
{name: "non-power-of-two", values: pseudoRandomVector(70, 2), preset: PresetTQ3, maxMSE: 3.0},
{name: "multi-head-like", values: pseudoRandomVector(128, 3), preset: PresetTQ2, maxMSE: 5.0},
{name: "constant", values: filledVector(31, 1.5), preset: PresetTQ2, maxMSE: 1.0},
{name: "zero", values: filledVector(64, 0), preset: PresetTQ3, maxMSE: 0.01},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
encoded, err := EncodeVector(tc.values, tc.preset)
if err != nil {
t.Fatal(err)
}
data, err := encoded.MarshalBinary()
if err != nil {
t.Fatal(err)
}
decoded, preset, err := DecodeVector(data)
if err != nil {
t.Fatal(err)
}
if preset.Name != tc.preset.Name {
t.Fatalf("preset = %q, want %q", preset.Name, tc.preset.Name)
}
if len(decoded) != len(tc.values) {
t.Fatalf("decoded len = %d, want %d", len(decoded), len(tc.values))
}
stats := Compare(tc.values, decoded)
if stats.MSE > tc.maxMSE {
t.Fatalf("MSE = %v, want <= %v", stats.MSE, tc.maxMSE)
}
})
}
}
func TestEncodeVectorDeterministicBytes(t *testing.T) {
values := pseudoRandomVector(96, 0x4242)
encodedA, err := EncodeVector(values, PresetTQ3)
if err != nil {
t.Fatal(err)
}
encodedB, err := EncodeVector(values, PresetTQ3)
if err != nil {
t.Fatal(err)
}
dataA, err := encodedA.MarshalBinary()
if err != nil {
t.Fatal(err)
}
dataB, err := encodedB.MarshalBinary()
if err != nil {
t.Fatal(err)
}
if string(dataA) != string(dataB) {
t.Fatal("expected byte-identical encoding output")
}
}
func TestEncodeKeyAndValueUseDifferentObjectives(t *testing.T) {
values := pseudoRandomVector(32, 0x77)
keyEncoded, err := EncodeKeyVector(values, PresetTQ3)
if err != nil {
t.Fatal(err)
}
valueEncoded, err := EncodeValueVector(values, PresetTQ3)
if err != nil {
t.Fatal(err)
}
if keyEncoded.Blocks[0].Objective != uint8(objectiveProduct) {
t.Fatalf("key objective = %d, want %d", keyEncoded.Blocks[0].Objective, objectiveProduct)
}
if valueEncoded.Blocks[0].Objective != uint8(objectiveMSE) {
t.Fatalf("value objective = %d, want %d", valueEncoded.Blocks[0].Objective, objectiveMSE)
}
if keyEncoded.Blocks[0].QJLRows == 0 {
t.Fatal("expected product-mode key rows to carry a residual sketch")
}
if valueEncoded.Blocks[0].QJLRows != 0 {
t.Fatal("expected MSE value rows to omit a residual sketch")
}
}
func TestPresetNames(t *testing.T) {
for _, name := range []string{"tq2", "tq3", "tq3k", "tq2k"} {
preset, err := PresetByName(name)
if err != nil {
t.Fatal(err)
}
if preset.Name != name {
t.Fatalf("preset %q resolved to %q", name, preset.Name)
}
}
}
// TestQJLDimMatchesPaperSpec verifies that the QJL sketch uses d random
// projections (one per dimension), matching the paper's specification in
// arXiv 2504.19874. With QJLRowsDivisor=1 this ensures the estimator variance
// matches the paper's theoretical analysis and the bit accounting is exact:
// tq2 = 2.5 bits/elem avg, tq3 = 3.5 bits/elem avg.
func TestQJLDimMatchesPaperSpec(t *testing.T) {
cases := []struct {
preset Preset
dim int
}{
{PresetTQ2, 64},
{PresetTQ2, 128},
{PresetTQ3, 128},
{PresetTQ3, 256},
}
for _, tc := range cases {
got := tc.preset.KeyQJLRows(tc.dim)
if got != tc.dim {
t.Errorf("%s: KeyQJLRows(%d) = %d, want %d (paper spec: d projections per d-dim vector)",
tc.preset.Name, tc.dim, got, tc.dim)
}
}
}
// TestPaperMSEDistortionBound verifies that the MSE quantizer operates within
// the information-theoretic bounds from Theorem 3 of arXiv 2504.19874:
//
// lower: D_mse >= 1/4^b
// upper: D_mse <= (√3π/2) / 4^b ≈ 2.72 / 4^b
//
// Tested on 1000 random unit vectors at dim=128 (the primary validated head_dim).
func TestPaperMSEDistortionBound(t *testing.T) {
const dim = 128
const trials = 1000
cases := []struct {
preset Preset
bits int
}{
{PresetTQ2, PresetTQ2.ValueBits},
{PresetTQ3, PresetTQ3.ValueBits},
}
for _, tc := range cases {
t.Run(tc.preset.Name, func(t *testing.T) {
lowerBound := 1.0 / math.Pow(4, float64(tc.bits))
paperUpper := 2.72 / math.Pow(4, float64(tc.bits))
var totalDistortion float64
rng := splitmix64(0x1234567890abcdef)
for range trials {
// Random unit vector drawn from the uniform distribution on S^{d-1}.
vec := make([]float32, dim)
var norm2 float64
for j := range vec {
v := gaussianFloat64(&rng)
vec[j] = float32(v)
norm2 += v * v
}
norm := math.Sqrt(norm2)
for j := range vec {
vec[j] /= float32(norm)
}
encoded, err := EncodeValueVector(vec, tc.preset)
if err != nil {
t.Fatal(err)
}
data, err := encoded.MarshalBinary()
if err != nil {
t.Fatal(err)
}
decoded, _, err := DecodeVector(data)
if err != nil {
t.Fatal(err)
}
// D_mse = ||x - x_hat||^2 / ||x||^2 = ||x - x_hat||^2 (||x||=1)
var distortion float64
for j := range vec {
d := float64(vec[j] - decoded[j])
distortion += d * d
}
totalDistortion += distortion
}
avgDistortion := totalDistortion / float64(trials)
t.Logf("%s D_mse = %.6f, paper bounds [%.6f, %.6f]",
tc.preset.Name, avgDistortion, lowerBound, paperUpper)
// Allow 1.5× headroom over the paper's upper bound to account for
// finite-d effects and the Cartesian (non-PolarQuant) encoding path.
if avgDistortion > paperUpper*1.5 {
t.Fatalf("D_mse = %.6f exceeds paper upper bound × 1.5 (%.6f)",
avgDistortion, paperUpper*1.5)
}
})
}
}
// TestPaperProductUnbiasedness verifies that the product-objective estimator
// is near-unbiased: E[score(q, encoded_k) - dot(q, k)] ≈ 0. This is the
// central claim of Q_prod in arXiv 2504.19874.
func TestPaperProductUnbiasedness(t *testing.T) {
const dim = 128
const trials = 500
// Allow 5% signed relative bias averaged over 500 trials.
const maxRelBias = 0.05
rng := splitmix64(0xdeadbeefcafe1234)
var signedBias, rmsTrue float64
for range trials {
query := make([]float32, dim)
key := make([]float32, dim)
for j := range query {
query[j] = float32(gaussianFloat64(&rng))
key[j] = float32(gaussianFloat64(&rng))
}
encoded, err := EncodeKeyVector(key, PresetTQ3)
if err != nil {
t.Fatal(err)
}
data, err := encoded.MarshalBinary()
if err != nil {
t.Fatal(err)
}
decoded, _, err := DecodeVector(data)
if err != nil {
t.Fatal(err)
}
var estimated float32
for j := range query {
estimated += query[j] * decoded[j]
}
var trueDot float32
for j := range query {
trueDot += query[j] * key[j]
}
signedBias += float64(estimated - trueDot)
rmsTrue += float64(trueDot * trueDot)
}
avgSignedBias := signedBias / float64(trials)
rmsTrue = math.Sqrt(rmsTrue / float64(trials))
relativeBias := math.Abs(avgSignedBias) / rmsTrue
t.Logf("avg signed bias = %.6f, rms true dot = %.6f, relative bias = %.4f",
avgSignedBias, rmsTrue, relativeBias)
if relativeBias > maxRelBias {
t.Fatalf("relative bias = %.4f, want <= %.4f (product estimator should be near-unbiased)",
relativeBias, maxRelBias)
}
}
// TestOutlierSplitMSEImproves verifies that encoding with the outlier-split
// strategy achieves lower MSE than uniform quantization at the same average bit
// rate. This is the core quality claim of §4.3 of arXiv 2504.19874.
func TestOutlierSplitMSEImproves(t *testing.T) {
const dim = 128
const trials = 200
rng := splitmix64(0xfeedbabe12345678)
for _, preset := range []Preset{testOutlierPreset(PresetTQ2, 32), testOutlierPreset(PresetTQ3, 32)} {
t.Run(preset.Name, func(t *testing.T) {
var splitMSE, uniformMSE float64
for range trials {
vec := make([]float32, dim)
for j := range vec {
vec[j] = float32(gaussianFloat64(&rng))
}
// Outlier-split encoding (2-block, current path).
splitEncoded, err := EncodeValueVector(vec, preset)
if err != nil {
t.Fatal(err)
}
splitData, err := splitEncoded.MarshalBinary()
if err != nil {
t.Fatal(err)
}
splitDecoded, _, err := DecodeVector(splitData)
if err != nil {
t.Fatal(err)
}
for j := range vec {
d := float64(vec[j] - splitDecoded[j])
splitMSE += d * d
}
// Uniform encoding: both sub-block sizes at regular bits, same total bits.
// Encode the whole vector at regular bits to match the average bit rate.
uniformBits := preset.ValueBits
unifBlock, err := encodeSubBlock(vec, nil, preset, roleValue, objectiveMSE, uniformBits, preset.RotationSeed)
if err != nil {
t.Fatal(err)
}
codebook, _ := scalarCodebook(dim, uniformBits)
rot := BuildRotation(dim, preset.RotationSeed)
uIndices := unpackBits(unifBlock.RegularIndices, uniformBits, dim)
unifRotated := make([]float32, dim)
for j, idx := range uIndices {
unifRotated[j] = dequantizeScalar(idx, codebook) * unifBlock.Scale
}
unifDecoded := ApplyInverseRotation(unifRotated, rot)
for j := range vec {
d := float64(vec[j] - unifDecoded[j])
uniformMSE += d * d
}
}
splitMSE /= float64(trials * dim)
uniformMSE /= float64(trials * dim)
t.Logf("%s: split MSE=%.6f uniform MSE=%.6f", preset.Name, splitMSE, uniformMSE)
if splitMSE >= uniformMSE {
t.Errorf("outlier split MSE (%.6f) not lower than uniform MSE (%.6f)", splitMSE, uniformMSE)
}
})
}
}
// TestOutlierSplitProductUnbiasedness verifies that the multi-block product
// estimator remains near-unbiased after the outlier split is applied.
func TestOutlierSplitProductUnbiasedness(t *testing.T) {
const dim = 128
const trials = 300
const maxRelBias = 0.07
outlierPreset := testOutlierPreset(PresetTQ3, 32)
rng := splitmix64(0xabcdef0123456789)
var signedBias, rmsTrue float64
for range trials {
query := make([]float32, dim)
key := make([]float32, dim)
for j := range query {
query[j] = float32(gaussianFloat64(&rng))
key[j] = float32(gaussianFloat64(&rng))
}
encoded, err := EncodeKeyVector(key, outlierPreset)
if err != nil {
t.Fatal(err)
}
data, err := encoded.MarshalBinary()
if err != nil {
t.Fatal(err)
}
decoded, _, err := DecodeVector(data)
if err != nil {
t.Fatal(err)
}
var estimated float32
for j := range query {
estimated += query[j] * decoded[j]
}
var trueDot float32
for j := range query {
trueDot += query[j] * key[j]
}
signedBias += float64(estimated - trueDot)
rmsTrue += float64(trueDot * trueDot)
}
avgSignedBias := signedBias / float64(trials)
rmsTrue = math.Sqrt(rmsTrue / float64(trials))
relativeBias := math.Abs(avgSignedBias) / rmsTrue
t.Logf("outlier-split avg signed bias = %.6f, rms true dot = %.6f, relative bias = %.4f",
avgSignedBias, rmsTrue, relativeBias)
if relativeBias > maxRelBias {
t.Fatalf("relative bias = %.4f, want <= %.4f (multi-block estimator should be near-unbiased)",
relativeBias, maxRelBias)
}
}
func TestDistortionThresholds(t *testing.T) {
tq2Mean, err := meanMSEForPreset(PresetTQ2)
if err != nil {
t.Fatal(err)
}
tq3Mean, err := meanMSEForPreset(PresetTQ3)
if err != nil {
t.Fatal(err)
}
if tq2Mean > 5.5 {
t.Fatalf("tq2 mean MSE = %v, want <= 5.5", tq2Mean)
}
if tq3Mean > 3.5 {
t.Fatalf("tq3 mean MSE = %v, want <= 3.5", tq3Mean)
}
if tq3Mean > tq2Mean {
t.Fatalf("tq3 mean MSE = %v, want <= tq2 mean MSE %v", tq3Mean, tq2Mean)
}
}
// ── per-head uniform encoding ──────────────────────────────────────────────
func TestEncodeKeyPerHeadRoundTrip(t *testing.T) {
const dim = 128
const trials = 200
for _, preset := range []Preset{PresetTQ2, PresetTQ3} {
t.Run(preset.Name, func(t *testing.T) {
rng := splitmix64(0xbeefcafe)
var totalMSE float64
var totalAbsErr, totalAbsDot float64
for range trials {
values := make([]float32, dim)
query := make([]float32, dim)
for j := range values {
values[j] = float32(gaussianFloat64(&rng))
query[j] = float32(gaussianFloat64(&rng))
}
packed, scale, err := EncodeKeyPerHead(values, preset)
if err != nil {
t.Fatal(err)
}
expectedBytes := (dim*preset.KeyPrimaryBits + 7) / 8
if len(packed) != expectedBytes {
t.Fatalf("packed len = %d, want %d", len(packed), expectedBytes)
}
dequantRot := DequantKeyPerHead(packed, scale, dim, preset.KeyPrimaryBits)
// Verify MSE in rotated space (reconstruction quality).
rotation := BuildRotation(dim, preset.RotationSeed)
valuesRot := ApplyRotation(values, rotation)
var mse float64
for j := range valuesRot {
d := float64(valuesRot[j] - dequantRot[j])
mse += d * d
}
mse /= float64(dim)
totalMSE += mse
// Verify dot product in rotated space matches true dot product.
queryRot := ApplyRotation(query, rotation)
var estDot, trueDot float32
for j := range queryRot {
estDot += queryRot[j] * dequantRot[j]
}
for j := range query {
trueDot += query[j] * values[j]
}
totalAbsErr += math.Abs(float64(estDot - trueDot))
totalAbsDot += math.Abs(float64(trueDot))
_ = scale
}
avgMSE := totalMSE / trials
avgRelErr := totalAbsErr / (totalAbsDot + 1e-8)
t.Logf("avg MSE = %.6f, avg relative dot error = %.4f", avgMSE, avgRelErr)
// Uniform encoding (no outlier split, no QJL) has higher error than
// full TQ. These thresholds validate the codec works, not paper quality.
maxRelErr := 0.45
if preset.KeyPrimaryBits >= 3 {
maxRelErr = 0.25
}
if avgRelErr > maxRelErr {
t.Fatalf("avg relative dot error = %.4f, want <= %.4f", avgRelErr, maxRelErr)
}
})
}
}
// TestEncodeKeyPerHeadRoundTripLargeDims checks CPU reference round-trip
// quality at the head dims used by models outside the llama/qwen D=128 norm.
// These are regression targets for the non-128 code path: gemma3 global uses
// D=256, gemma4 global uses D=512. If any sub-test passes but the CUDA path
// produces different output at the same D, the bug is kernel-side; if a
// sub-test fails, the core TQ math (rotation + codebook) has a D-dependent
// bug.
func TestEncodeKeyPerHeadRoundTripLargeDims(t *testing.T) {
const trials = 200
for _, dim := range []int{256, 512} {
for _, preset := range []Preset{PresetTQ2, PresetTQ3} {
t.Run(fmt.Sprintf("dim%d_%s", dim, preset.Name), func(t *testing.T) {
rng := splitmix64(0xbeefcafe ^ uint64(dim))
var totalMSE, totalAbsErr, totalAbsDot float64
for range trials {
values := make([]float32, dim)
query := make([]float32, dim)
for j := range values {
values[j] = float32(gaussianFloat64(&rng))
query[j] = float32(gaussianFloat64(&rng))
}
packed, scale, err := EncodeKeyPerHead(values, preset)
if err != nil {
t.Fatal(err)
}
dequantRot := DequantKeyPerHead(packed, scale, dim, preset.KeyPrimaryBits)
rotation := BuildRotation(dim, preset.RotationSeed)
valuesRot := ApplyRotation(values, rotation)
var mse float64
for j := range valuesRot {
d := float64(valuesRot[j] - dequantRot[j])
mse += d * d
}
mse /= float64(dim)
totalMSE += mse
queryRot := ApplyRotation(query, rotation)
var estDot, trueDot float32
for j := range queryRot {
estDot += queryRot[j] * dequantRot[j]
}
for j := range query {
trueDot += query[j] * values[j]
}
totalAbsErr += math.Abs(float64(estDot - trueDot))
totalAbsDot += math.Abs(float64(trueDot))
}
avgMSE := totalMSE / trials
avgRelErr := totalAbsErr / (totalAbsDot + 1e-8)
t.Logf("D=%d %s: avg MSE = %.6f, avg rel dot error = %.4f", dim, preset.Name, avgMSE, avgRelErr)
maxRelErr := 0.45
if preset.KeyPrimaryBits >= 3 {
maxRelErr = 0.25
}
if avgRelErr > maxRelErr {
t.Fatalf("D=%d %s: avg relative dot error = %.4f, want <= %.4f", dim, preset.Name, avgRelErr, maxRelErr)
}
})
}
}
}
// TestEncodeKeyPerHeadWithDCOffset checks whether TQ handles K vectors with
// a DC bias component. Models like qwen2/qwen2.5 have learned Q/K/V bias
// tensors; their K vectors have a non-zero mean that TQ's RMS-based scale
// normalization doesn't center out. This test adds a DC offset to random
// Gaussian samples and measures round-trip quality vs the non-biased case.
func TestEncodeKeyPerHeadWithDCOffset(t *testing.T) {
const dim = 128
const trials = 200
offsets := []float32{0.0, 0.3, 1.0, 3.0}
for _, offset := range offsets {
for _, preset := range []Preset{PresetTQ3} {
t.Run(fmt.Sprintf("offset%.1f_%s", offset, preset.Name), func(t *testing.T) {
rng := splitmix64(0xbeefcafe)
var totalAbsErr, totalAbsDot float64
for range trials {
values := make([]float32, dim)
query := make([]float32, dim)
for j := range values {
values[j] = float32(gaussianFloat64(&rng)) + offset
query[j] = float32(gaussianFloat64(&rng))
}
packed, scale, err := EncodeKeyPerHead(values, preset)
if err != nil {
t.Fatal(err)
}
dequantRot := DequantKeyPerHead(packed, scale, dim, preset.KeyPrimaryBits)
rotation := BuildRotation(dim, preset.RotationSeed)
queryRot := ApplyRotation(query, rotation)
var estDot, trueDot float32
for j := range queryRot {
estDot += queryRot[j] * dequantRot[j]
}
for j := range query {
trueDot += query[j] * values[j]
}
totalAbsErr += math.Abs(float64(estDot - trueDot))
totalAbsDot += math.Abs(float64(trueDot))
}
avgRelErr := totalAbsErr / (totalAbsDot + 1e-8)
t.Logf("offset=%.1f %s: avg rel dot error = %.4f", offset, preset.Name, avgRelErr)
})
}
}
}
func TestEncodeKeyPerHeadZeroVector(t *testing.T) {
values := make([]float32, 64)
packed, scale, err := EncodeKeyPerHead(values, PresetTQ3)
if err != nil {
t.Fatal(err)
}
if scale != 0 {
t.Fatalf("expected zero scale for zero vector, got %v", scale)
}
dequant := DequantKeyPerHead(packed, scale, 64, PresetTQ3.KeyPrimaryBits)
for i, v := range dequant {
if v != 0 {
t.Fatalf("dequant[%d] = %v, want 0", i, v)
}
}
}
func TestEncodeKeyPerHeadDimSizes(t *testing.T) {
for _, dim := range []int{32, 64, 96, 128, 256} {
t.Run(fmt.Sprintf("dim%d", dim), func(t *testing.T) {
values := pseudoRandomVector(dim, uint64(dim))
packed, scale, err := EncodeKeyPerHead(values, PresetTQ3)
if err != nil {
t.Fatal(err)
}
if scale <= 0 {
t.Fatalf("expected positive scale for dim=%d", dim)
}
expectedBytes := (dim*PresetTQ3.KeyPrimaryBits + 7) / 8
if len(packed) != expectedBytes {
t.Fatalf("packed len = %d, want %d for dim=%d", len(packed), expectedBytes, dim)
}
})
}
}
func TestEncodeKeyPerHeadDeterministic(t *testing.T) {
for _, preset := range []Preset{PresetTQ2, PresetTQ3} {
t.Run(preset.Name, func(t *testing.T) {
vec := pseudoRandomVector(128, 0xcafe)
packed1, scale1, err := EncodeKeyPerHead(vec, preset)
if err != nil {
t.Fatalf("first EncodeKeyPerHead: %v", err)
}
packed2, scale2, err := EncodeKeyPerHead(vec, preset)
if err != nil {
t.Fatalf("second EncodeKeyPerHead: %v", err)
}
if scale1 != scale2 {
t.Fatalf("scale not deterministic: %f vs %f", scale1, scale2)
}
for i := range packed1 {
if packed1[i] != packed2[i] {
t.Fatalf("packed byte %d differs: %d vs %d", i, packed1[i], packed2[i])
}
}
})
}
}
// studentT3Float64 samples from Student-t with df=3 via Z / sqrt(ChiSq3/3).
// Student-t df=3 has finite mean/variance but heavy tails (kurtosis → ∞),
// which stresses quantizers that assume Gaussian inputs.
func studentT3Float64(rng *splitmix64) float64 {
z := gaussianFloat64(rng)
// ChiSq(3) = sum of 3 squared standard normals.
var chisq float64
for range 3 {
n := gaussianFloat64(rng)
chisq += n * n
}
return z / math.Sqrt(chisq/3.0)
}
// TestOutlierSplitVsUniformHeavyTailed verifies that the CPU reference
// outlier-split encoder beats the uniform encoder on heavy-tailed K
// vectors (Student-t df=3). This is the statistical regression test
// for the outlier-split algorithm: if someone breaks the split logic
// (e.g. misattributes channels to sub-blocks, uses the wrong scale),
// this test starts failing because outlier no longer helps.
//
// Acceptance: on Student-t df=3 inputs, outlier-split relative dot
// error must be <= 1.5x the uniform relative dot error. In practice
// outlier-split is substantially BETTER on heavy tails — the 1.5x
// bound is a floor to catch "outlier split broken" regressions, not
// the expected performance.
func TestOutlierSplitVsUniformHeavyTailed(t *testing.T) {
const dim = 128
const trials = 200
presets := []Preset{
testOutlierPreset(PresetTQ3, 32),
testOutlierPreset(PresetTQ3K, 32),
}
for _, preset := range presets {
t.Run(preset.Name, func(t *testing.T) {
rng := splitmix64(0xc0de_babe_dead_beef)
var uniformAbsErr, uniformAbsDot float64
var outlierAbsErr, outlierAbsDot float64
for trial := range trials {
values := make([]float32, dim)
query := make([]float32, dim)
for j := range values {
values[j] = float32(studentT3Float64(&rng))
query[j] = float32(gaussianFloat64(&rng))
}
rotation := BuildRotation(dim, preset.RotationSeed)
valuesRot := ApplyRotation(values, rotation)
queryRot := ApplyRotation(query, rotation)
var trueDot float32
for j := range query {
trueDot += query[j] * values[j]
}
// Uniform path.
uPacked, uScale, err := EncodeKeyPerHead(values, preset)
if err != nil {
t.Fatalf("uniform encode: %v", err)
}
uDeq := DequantKeyPerHead(uPacked, uScale, dim, preset.KeyPrimaryBits)
var uEstDot float32
for j := range uDeq {
uEstDot += queryRot[j] * uDeq[j]
}
uniformAbsErr += math.Abs(float64(uEstDot - trueDot))
uniformAbsDot += math.Abs(float64(trueDot))
// Outlier-split path.
oEnc, err := EncodeKeyPerHeadOutlier(values, preset)
if err != nil {
t.Fatalf("outlier encode: %v", err)
}
oDeq := DequantKeyPerHeadOutlier(oEnc, preset, dim)
var oEstDot float32
for j := range oDeq {
oEstDot += queryRot[j] * oDeq[j]
}
outlierAbsErr += math.Abs(float64(oEstDot - trueDot))
outlierAbsDot += math.Abs(float64(trueDot))
// Sanity: the outlier-split decode should still be close
// enough to the rotated input that a round-trip MSE is
// sensible. Catches frame-shifted reconstructions where
// the dot happens to land right but the vector is wrong.
var mse float64
for j := range oDeq {
d := float64(oDeq[j] - valuesRot[j])
mse += d * d
}
mse /= float64(dim)
if mse > 4.0 {
t.Fatalf("trial %d: outlier-split rotated-space MSE = %.4f, reconstruction is garbage", trial, mse)
}
}
uniformRelErr := uniformAbsErr / (uniformAbsDot + 1e-8)
outlierRelErr := outlierAbsErr / (outlierAbsDot + 1e-8)
t.Logf("%s Student-t df=3: uniform rel-dot-err=%.4f, outlier rel-dot-err=%.4f, ratio=%.3f",
preset.Name, uniformRelErr, outlierRelErr, outlierRelErr/uniformRelErr)
// Outlier-split must not be worse than 1.5x uniform. In
// practice it should be substantially better.
if outlierRelErr > 1.5*uniformRelErr {
t.Fatalf("%s: outlier rel-dot-err %.4f > 1.5x uniform %.4f — outlier split broken",
preset.Name, outlierRelErr, uniformRelErr)
}
})
}
}
// TestOutlierPerHeadRoundTrip pins the CPU reference bit-exact round-trip
// at the shapes actually used by GPU kernels (headDim=128, 256). Encode
// → dequant → compare against ApplyRotation(values, rotation). This
// catches algorithmic drift (wrong codebook selection, off-by-one in
// slot mapping, index packing errors) even when the GPU path is broken.
func TestOutlierPerHeadRoundTrip(t *testing.T) {
cases := []struct {
name string
dim int
preset Preset
}{
{"d128_tq3", 128, testOutlierPreset(PresetTQ3, 32)},
{"d128_tq3k", 128, testOutlierPreset(PresetTQ3K, 32)},
{"d256_tq3k", 256, testOutlierPreset(PresetTQ3K, 32)}, // gemma3:1b global-layer headDim
{"d128_tq2k", 128, testOutlierPreset(PresetTQ2K, 32)},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
const trials = 50
rng := splitmix64(0xfeed_face_0000_0000 | uint64(tc.dim))
var totalMSE float64
for range trials {
values := make([]float32, tc.dim)
for j := range values {
values[j] = float32(gaussianFloat64(&rng))
}
enc, err := EncodeKeyPerHeadOutlier(values, tc.preset)
if err != nil {
t.Fatalf("encode: %v", err)
}
if len(enc.OutlierIndices) != tc.preset.OutlierCount {
t.Fatalf("outlier count = %d, want %d", len(enc.OutlierIndices), tc.preset.OutlierCount)
}
expRegularBytes := ((tc.dim-tc.preset.OutlierCount)*tc.preset.KeyPrimaryBits + 7) / 8
if len(enc.RegularPacked) != expRegularBytes {
t.Fatalf("regular packed len = %d, want %d", len(enc.RegularPacked), expRegularBytes)
}
expOutlierBytes := (tc.preset.OutlierCount*tc.preset.OutlierBits + 7) / 8
if len(enc.OutlierPacked) != expOutlierBytes {
t.Fatalf("outlier packed len = %d, want %d", len(enc.OutlierPacked), expOutlierBytes)
}
dec := DequantKeyPerHeadOutlier(enc, tc.preset, tc.dim)
rotation := BuildRotation(tc.dim, tc.preset.RotationSeed)
valuesRot := ApplyRotation(values, rotation)
var mse float64
for j := range valuesRot {
d := float64(valuesRot[j] - dec[j])
mse += d * d
}
mse /= float64(tc.dim)
totalMSE += mse
}
avgMSE := totalMSE / trials
t.Logf("%s: avg rotated-space round-trip MSE = %.4f", tc.name, avgMSE)
// Lloyd-Max 3-bit on unit-variance Gaussian has MSE ~0.04 per
// channel in rotated space. 0.15 is a loose bound that still
// catches format bugs.
if avgMSE > 0.15 {
t.Fatalf("%s: round-trip MSE %.4f too large — encoder/dequant mismatch", tc.name, avgMSE)
}
})
}
}

88
turboquant/outlier.go Normal file
View file

@ -0,0 +1,88 @@
package turboquant
import (
"cmp"
"slices"
)
// OutlierSplit partitions a vector's channels into outlier and regular sets.
// OutlierIndices and RegularIndices are sorted in ascending order.
// OutlierValues and RegularValues contain the corresponding channel values.
type OutlierSplit struct {
OutlierIndices []uint16
OutlierValues []float32
RegularIndices []uint16
RegularValues []float32
}
// SplitOutlierChannels identifies the top-outlierCount channels by absolute
// magnitude and returns them separately from the remaining regular channels.
// Tie-breaking is by index ascending for determinism.
// If outlierCount <= 0, all channels are returned as regular (empty outlier).
// If outlierCount >= len(values), all channels are returned as outlier.
func SplitOutlierChannels(values []float32, outlierCount int) OutlierSplit {
dim := len(values)
if outlierCount <= 0 || dim == 0 {
indices := make([]uint16, dim)
vals := make([]float32, dim)
for i := range values {
indices[i] = uint16(i)
vals[i] = values[i]
}
return OutlierSplit{RegularIndices: indices, RegularValues: vals}
}
if outlierCount >= dim {
indices := make([]uint16, dim)
vals := make([]float32, dim)
for i := range values {
indices[i] = uint16(i)
vals[i] = values[i]
}
return OutlierSplit{OutlierIndices: indices, OutlierValues: vals}
}
// Sort channel indices by abs value descending, tie-break by index ascending.
order := make([]int, dim)
for i := range order {
order[i] = i
}
slices.SortStableFunc(order, func(a, b int) int {
va := abs32(values[a])
vb := abs32(values[b])
if va != vb {
return cmp.Compare(vb, va)
}
return cmp.Compare(a, b)
})
outlierSet := make([]int, outlierCount)
regularSet := make([]int, dim-outlierCount)
copy(outlierSet, order[:outlierCount])
copy(regularSet, order[outlierCount:])
// Each set is sorted by index ascending so ChannelIndices is deterministic.
slices.Sort(outlierSet)
slices.Sort(regularSet)
outlierIndices := make([]uint16, outlierCount)
outlierValues := make([]float32, outlierCount)
for i, idx := range outlierSet {
outlierIndices[i] = uint16(idx)
outlierValues[i] = values[idx]
}
regularCount := dim - outlierCount
regularIndices := make([]uint16, regularCount)
regularValues := make([]float32, regularCount)
for i, idx := range regularSet {
regularIndices[i] = uint16(idx)
regularValues[i] = values[idx]
}
return OutlierSplit{
OutlierIndices: outlierIndices,
OutlierValues: outlierValues,
RegularIndices: regularIndices,
RegularValues: regularValues,
}
}

133
turboquant/outlier_test.go Normal file
View file

@ -0,0 +1,133 @@
package turboquant
import (
"testing"
)
func TestSplitOutlierChannelsBasic(t *testing.T) {
values := []float32{0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6}
split := SplitOutlierChannels(values, 3)
// Top-3 by abs magnitude: indices 1(0.9), 3(0.8), 5(0.7)
wantOutlierIdx := []uint16{1, 3, 5}
wantRegularIdx := []uint16{0, 2, 4, 6, 7}
if len(split.OutlierIndices) != len(wantOutlierIdx) {
t.Fatalf("outlier count: got %d want %d", len(split.OutlierIndices), len(wantOutlierIdx))
}
for i, idx := range wantOutlierIdx {
if split.OutlierIndices[i] != idx {
t.Errorf("outlier[%d]: got %d want %d", i, split.OutlierIndices[i], idx)
}
if split.OutlierValues[i] != values[idx] {
t.Errorf("outlierVal[%d]: got %v want %v", i, split.OutlierValues[i], values[idx])
}
}
if len(split.RegularIndices) != len(wantRegularIdx) {
t.Fatalf("regular count: got %d want %d", len(split.RegularIndices), len(wantRegularIdx))
}
for i, idx := range wantRegularIdx {
if split.RegularIndices[i] != idx {
t.Errorf("regular[%d]: got %d want %d", i, split.RegularIndices[i], idx)
}
if split.RegularValues[i] != values[idx] {
t.Errorf("regularVal[%d]: got %v want %v", i, split.RegularValues[i], values[idx])
}
}
}
func TestSplitOutlierChannelsCoversAll(t *testing.T) {
// Every channel index must appear exactly once across outlier + regular.
values := []float32{3, 1, 4, 1, 5, 9, 2, 6}
for outlierCount := 0; outlierCount <= len(values); outlierCount++ {
split := SplitOutlierChannels(values, outlierCount)
seen := make(map[uint16]int)
for _, idx := range split.OutlierIndices {
seen[idx]++
}
for _, idx := range split.RegularIndices {
seen[idx]++
}
for i := range values {
if seen[uint16(i)] != 1 {
t.Errorf("outlierCount=%d: channel %d appears %d times", outlierCount, i, seen[uint16(i)])
}
}
if len(split.OutlierIndices)+len(split.RegularIndices) != len(values) {
t.Errorf("outlierCount=%d: total channels %d != %d", outlierCount, len(split.OutlierIndices)+len(split.RegularIndices), len(values))
}
}
}
func TestSplitOutlierChannelsZeroOutliers(t *testing.T) {
values := []float32{1, 2, 3, 4}
split := SplitOutlierChannels(values, 0)
if len(split.OutlierIndices) != 0 {
t.Errorf("expected 0 outliers, got %d", len(split.OutlierIndices))
}
if len(split.RegularIndices) != 4 {
t.Errorf("expected 4 regular, got %d", len(split.RegularIndices))
}
}
func TestSplitOutlierChannelsAllOutliers(t *testing.T) {
values := []float32{1, 2, 3, 4}
split := SplitOutlierChannels(values, len(values))
if len(split.OutlierIndices) != 4 {
t.Errorf("expected 4 outliers, got %d", len(split.OutlierIndices))
}
if len(split.RegularIndices) != 0 {
t.Errorf("expected 0 regular, got %d", len(split.RegularIndices))
}
}
func TestSplitOutlierChannelsNegativeValues(t *testing.T) {
// Negative values: magnitude matters, not sign.
values := []float32{-5, 1, -3, 2}
split := SplitOutlierChannels(values, 2)
// Top-2 by abs: index 0 (5.0), index 2 (3.0)
if split.OutlierIndices[0] != 0 || split.OutlierIndices[1] != 2 {
t.Errorf("expected outlier indices [0,2], got %v", split.OutlierIndices)
}
}
func TestSplitOutlierChannelsTieBreakerByIndex(t *testing.T) {
// Equal magnitudes: lower index wins as outlier (tie-break index ascending).
values := []float32{1, 1, 1, 1}
split := SplitOutlierChannels(values, 2)
// Should pick indices 0 and 1 as outliers.
if split.OutlierIndices[0] != 0 || split.OutlierIndices[1] != 1 {
t.Errorf("expected outlier indices [0,1], got %v", split.OutlierIndices)
}
}
func TestSplitOutlierChannelsIndicesSorted(t *testing.T) {
// Output indices must be sorted ascending in both slices.
values := pseudoRandomVector(16, 42)
split := SplitOutlierChannels(values, 5)
for i := 1; i < len(split.OutlierIndices); i++ {
if split.OutlierIndices[i] <= split.OutlierIndices[i-1] {
t.Errorf("outlier indices not sorted at %d: %d <= %d", i, split.OutlierIndices[i], split.OutlierIndices[i-1])
}
}
for i := 1; i < len(split.RegularIndices); i++ {
if split.RegularIndices[i] <= split.RegularIndices[i-1] {
t.Errorf("regular indices not sorted at %d: %d <= %d", i, split.RegularIndices[i], split.RegularIndices[i-1])
}
}
}
func TestSplitOutlierChannelsDeterministic(t *testing.T) {
values := pseudoRandomVector(32, 99)
a := SplitOutlierChannels(values, 8)
b := SplitOutlierChannels(values, 8)
if len(a.OutlierIndices) != len(b.OutlierIndices) {
t.Fatal("non-deterministic outlier count")
}
for i := range a.OutlierIndices {
if a.OutlierIndices[i] != b.OutlierIndices[i] {
t.Fatalf("non-deterministic at outlier index %d", i)
}
}
}

158
turboquant/residual_qjl.go Normal file
View file

@ -0,0 +1,158 @@
package turboquant
import "math"
const qjlUnbiasScale = 1.2533141373155001 // sqrt(pi / 2)
type ResidualSketch struct {
Seed uint64
Scale float32 // residual L2 norm; retained name keeps the old struct shape stable
SketchDim uint16
Signs []byte
}
func encodeResidual(rotated, approx []float32, sketchSpec any, seed uint64) ResidualSketch {
var sketchRows int
switch v := sketchSpec.(type) {
case int:
sketchRows = v
case Preset:
sketchRows = v.KeyQJLRows(len(rotated))
default:
return ResidualSketch{}
}
if sketchRows <= 0 || len(rotated) == 0 {
return ResidualSketch{}
}
residual := make([]float32, len(rotated))
var l2 float64
for i := range rotated {
delta := rotated[i] - approx[i]
residual[i] = delta
l2 += float64(delta * delta)
}
if l2 == 0 {
return ResidualSketch{
Seed: seed,
SketchDim: uint16(sketchRows),
Signs: make([]byte, expectedPackedBytes(sketchRows, 1)),
}
}
signBits := make([]uint8, sketchRows)
for row := range sketchRows {
if gaussianProjectionDot(residual, seed, row) >= 0 {
signBits[row] = 1
}
}
return ResidualSketch{
Seed: seed,
Scale: float32(math.Sqrt(l2)),
SketchDim: uint16(sketchRows),
Signs: packBits(signBits, 1),
}
}
func reconstructResidual(dim int, sketch ResidualSketch) []float32 {
out := make([]float32, dim)
if dim == 0 || sketch.SketchDim == 0 || sketch.Scale == 0 {
return out
}
signBits := unpackBits(sketch.Signs, 1, int(sketch.SketchDim))
scale := float32(qjlUnbiasScale) * sketch.Scale / float32(sketch.SketchDim)
for row, bit := range signBits {
sign := float32(-1)
if bit == 1 {
sign = 1
}
// Residual reconstruction stays on float32 accumulation today; if a backend-specific half path is introduced, it needs an explicit FP32-accumulate audit before rollout.
for col := range dim {
out[col] += sign * gaussianProjectionEntry(sketch.Seed, row, col) * scale
}
}
return out
}
func residualDotCorrection(queryRot []float32, sketch ResidualSketch) float32 {
if len(queryRot) == 0 || sketch.SketchDim == 0 || sketch.Scale == 0 {
return 0
}
signBits := unpackBits(sketch.Signs, 1, int(sketch.SketchDim))
var total float32
for row, bit := range signBits {
sign := float32(-1)
if bit == 1 {
sign = 1
}
total += sign * gaussianProjectionDot(queryRot, sketch.Seed, row)
}
correction := float32(qjlUnbiasScale) * sketch.Scale * (total / float32(sketch.SketchDim))
queryNorm := float32(math.Sqrt(float64(dotSelf(queryRot))))
if queryNorm == 0 {
return 0
}
maxCorrection := sketch.Scale * queryNorm
if correction > maxCorrection {
return maxCorrection
}
if correction < -maxCorrection {
return -maxCorrection
}
if sketch.Scale < 1e-6 {
return 0
}
return correction
}
// PrecomputeCorrectionVec builds the vector w = (√π/2 · residualNorm / sketchDim) · Σ_j sign_j · G_j,
// where G_j is row j of the random Gaussian projection matrix and sign_j is the stored QJL sign bit.
// Scoring then reduces to dot(queryRotated, w), replacing the per-query O(dim²) Gaussian projection
// loop with a single O(dim) dot product.
//
// The returned slice has length dim and is nil when the sketch carries no correction (Scale==0).
func PrecomputeCorrectionVec(sketch ResidualSketch, dim int) []float32 {
if sketch.SketchDim == 0 || sketch.Scale == 0 {
return nil
}
out := make([]float32, dim)
signBits := unpackBits(sketch.Signs, 1, int(sketch.SketchDim))
scale := float32(qjlUnbiasScale) * sketch.Scale / float32(sketch.SketchDim)
for row, bit := range signBits {
sign := float32(-1)
if bit == 1 {
sign = 1
}
sv := sign * scale
for col := range dim {
out[col] += sv * gaussianProjectionEntry(sketch.Seed, row, col)
}
}
return out
}
func gaussianProjectionDot(values []float32, seed uint64, row int) float32 {
var out float32
for col, value := range values {
out += value * gaussianProjectionEntry(seed, row, col)
}
return out
}
func gaussianProjectionEntry(seed uint64, row, col int) float32 {
local := splitmix64(seed ^ uint64(row+1)*0x9e3779b97f4a7c15 ^ uint64(col+1)*0xbf58476d1ce4e5b9)
return float32(gaussianFloat64(&local))
}
func dotSelf(values []float32) float32 {
var out float32
for _, value := range values {
out += value * value
}
return out
}

View file

@ -0,0 +1,58 @@
package turboquant
import "testing"
func TestResidualSketchDeterministic(t *testing.T) {
a := encodeResidual([]float32{1, 2, 3, 4}, []float32{0, 0, 0, 0}, PresetTQ2, 99)
b := encodeResidual([]float32{1, 2, 3, 4}, []float32{0, 0, 0, 0}, PresetTQ2, 99)
if a.Scale != b.Scale || a.Seed != b.Seed || string(a.Signs) != string(b.Signs) {
t.Fatal("residual sketch is not deterministic")
}
}
func TestReconstructResidualLength(t *testing.T) {
sketch := encodeResidual(
[]float32{1, 2, 3, 4, 5, 6, 7, 8},
[]float32{0, 0, 0, 0, 0, 0, 0, 0},
PresetTQ3,
123,
)
reconstructed := reconstructResidual(8, sketch)
if len(reconstructed) != 8 {
t.Fatalf("reconstructed length = %d, want 8", len(reconstructed))
}
if len(sketch.Signs) != expectedPackedBytes(int(sketch.SketchDim), 1) {
t.Fatalf("sign sketch bytes = %d, want %d", len(sketch.Signs), expectedPackedBytes(int(sketch.SketchDim), 1))
}
}
func TestZeroResidualProducesZeroReconstruction(t *testing.T) {
sketch := encodeResidual(
[]float32{1, 1, 1, 1},
[]float32{1, 1, 1, 1},
PresetTQ3,
456,
)
reconstructed := reconstructResidual(4, sketch)
for i, value := range reconstructed {
if abs32(value) > 1e-6 {
t.Fatalf("reconstructed[%d] = %v, want 0", i, value)
}
}
}
func TestResidualDotCorrectionDeterministicAndFinite(t *testing.T) {
rotated := []float32{1.5, -2.5, 0.5, 3.0, -1.25, 2.25, 0.75, -0.5}
approx := []float32{1.0, -2.0, 0.25, 2.5, -1.0, 2.0, 0.5, -0.25}
sketch := encodeResidual(rotated, approx, PresetTQ3, 789)
query := []float32{0.2, -0.1, 0.3, 0.4, -0.2, 0.5, -0.6, 0.7}
a := residualDotCorrection(query, sketch)
b := residualDotCorrection(query, sketch)
if a != b {
t.Fatalf("dot correction = %v and %v, want deterministic output", a, b)
}
if a != a {
t.Fatal("dot correction returned NaN")
}
}

211
turboquant/rotation.go Normal file
View file

@ -0,0 +1,211 @@
package turboquant
import (
"math"
"sync"
)
type Rotation struct {
Dim int
Seed uint64
Matrix []float32 // row-major orthogonal matrix
}
type rotationCacheKey struct {
dim int
seed uint64
}
var rotationCache sync.Map
func BuildRotation(dim int, seed uint64) Rotation {
key := rotationCacheKey{dim: dim, seed: seed}
if cached, ok := rotationCache.Load(key); ok {
return cached.(Rotation)
}
rot := Rotation{
Dim: dim,
Seed: seed,
Matrix: buildOrthogonalMatrix(dim, seed),
}
actual, _ := rotationCache.LoadOrStore(key, rot)
return actual.(Rotation)
}
func ApplyRotation(x []float32, rot Rotation) []float32 {
if len(x) != rot.Dim {
panic("turboquant: vector length does not match rotation dimension")
}
out := make([]float32, rot.Dim)
for row := range rot.Dim {
base := row * rot.Dim
var sum float32
for col, value := range x {
sum += rot.Matrix[base+col] * value
}
out[row] = sum
}
return out
}
func ApplyInverseRotation(y []float32, rot Rotation) []float32 {
if len(y) != rot.Dim {
panic("turboquant: vector length does not match rotation dimension")
}
out := make([]float32, rot.Dim)
for row := range rot.Dim {
yVal := y[row]
base := row * rot.Dim
// The inverse rotation accumulates in float32 on the reconstruction path; keep this FP32-accumulate behavior explicit while long-generation corruption audits remain active.
for col := range rot.Dim {
out[col] += rot.Matrix[base+col] * yVal
}
}
return out
}
// buildOrthogonalMatrix returns a dim×dim orthogonal matrix derived from the
// given seed using Householder QR factorisation. The input is the same seeded
// Gaussian matrix used by the previous Gram-Schmidt path, but the QR Q-factor
// is numerically unconditionally orthogonal. This algorithm replaced the
// classical Gram-Schmidt used through BlockVersion 3; the Householder path was
// introduced at BlockVersion 4 (current: BlockVersion 6).
//
// Algorithm: apply dim-1 Householder reflectors H_1…H_{dim-1} from the left
// to reduce A to upper triangular form R. Simultaneously accumulate
// Q = H_1 * H_2 * … * H_{dim-1} by right-multiplying each reflector into Q,
// starting from the identity. The resulting Q is the orthogonal factor in
// A = QR and has orthonormal rows and columns.
func buildOrthogonalMatrix(dim int, seed uint64) []float32 {
if dim <= 0 {
return nil
}
// Initialise A with the same seeded Gaussian rows as before.
a := make([][]float64, dim)
for row := range dim {
a[row] = make([]float64, dim)
rng := splitmix64(seed ^ uint64(dim)<<32 ^ uint64(row+1)*0x9e3779b97f4a7c15)
for col := range dim {
a[row][col] = gaussianFloat64(&rng)
}
}
// Q starts as the identity; we accumulate Q = H_1 * … * H_{dim-1}.
q := make([][]float64, dim)
for i := range q {
q[i] = make([]float64, dim)
q[i][i] = 1.0
}
v := make([]float64, dim) // scratch buffer for the Householder vector
for k := range dim - 1 {
n := dim - k
// Build Householder vector for column k, rows k..dim-1.
// Sign chosen to avoid cancellation: sigma = sign(a[k][k]) * ||x||.
for i := range n {
v[i] = a[k+i][k]
}
sigma := vectorNorm64(v[:n])
if v[0] >= 0 {
sigma = -sigma
}
v[0] -= sigma
vnorm2 := dotFloat64(v[:n], v[:n])
if vnorm2 < 1e-28 {
continue // column already zeroed — no reflector needed
}
beta := 2.0 / vnorm2
// Apply H_k to a[k:, k:] from the left.
for j := k; j < dim; j++ {
var dot float64
for i := range n {
dot += v[i] * a[k+i][j]
}
dot *= beta
for i := range n {
a[k+i][j] -= dot * v[i]
}
}
// Apply H_k to q[:, k:] from the right: q ← q * H_k.
for i := range dim {
var dot float64
for j := range n {
dot += q[i][k+j] * v[j]
}
dot *= beta
for j := range n {
q[i][k+j] -= dot * v[j]
}
}
}
// Sign-normalise each row so the first non-negligible element is positive.
// This convention matches the previous Gram-Schmidt path and makes the
// output deterministic despite the reflector sign ambiguity.
for row := range dim {
for _, value := range q[row] {
if math.Abs(value) <= 1e-12 {
continue
}
if value < 0 {
for col := range dim {
q[row][col] = -q[row][col]
}
}
break
}
}
out := make([]float32, dim*dim)
for row := range dim {
for col := range dim {
out[row*dim+col] = float32(q[row][col])
}
}
return out
}
func dotFloat64(a, b []float64) float64 {
var out float64
for i := range a {
out += a[i] * b[i]
}
return out
}
func vectorNorm64(values []float64) float64 {
var sum float64
for _, value := range values {
sum += value * value
}
return math.Sqrt(sum)
}
func gaussianFloat64(rng *splitmix64) float64 {
u1 := unitUniform64(rng)
u2 := unitUniform64(rng)
return math.Sqrt(-2*math.Log(u1)) * math.Cos(2*math.Pi*u2)
}
func unitUniform64(rng *splitmix64) float64 {
const scale = 1.0 / (1 << 53)
return (float64(rng.next()>>11) + 0.5) * scale
}
type splitmix64 uint64
func (s *splitmix64) next() uint64 {
*s += 0x9e3779b97f4a7c15
z := uint64(*s)
z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9
z = (z ^ (z >> 27)) * 0x94d049bb133111eb
return z ^ (z >> 31)
}

189
turboquant/rotation_test.go Normal file
View file

@ -0,0 +1,189 @@
package turboquant
import (
"fmt"
"math"
"testing"
)
// testDims covers small, medium, and the primary production dims.
var testDims = []int{4, 8, 16, 32, 64, 128}
func TestBuildRotationDeterministic(t *testing.T) {
for _, dim := range testDims {
a := BuildRotation(dim, 123)
b := BuildRotation(dim, 123)
if a.Dim != b.Dim || a.Seed != b.Seed || len(a.Matrix) != len(b.Matrix) {
t.Fatalf("rotation metadata mismatch for dim %d", dim)
}
for i := range a.Matrix {
if a.Matrix[i] != b.Matrix[i] {
t.Fatalf("rotation mismatch at dim=%d idx=%d", dim, i)
}
}
}
}
func TestBuildRotationDifferentSeedsDiffer(t *testing.T) {
a := BuildRotation(16, 111)
b := BuildRotation(16, 222)
different := false
for i := range a.Matrix {
if a.Matrix[i] != b.Matrix[i] {
different = true
break
}
}
if !different {
t.Fatal("different seeds produced the same orthogonal matrix")
}
}
func TestApplyInverseRotation(t *testing.T) {
for _, dim := range testDims {
values := pseudoRandomVector(dim, uint64(dim)*17)
rot := BuildRotation(dim, uint64(dim)*19)
got := ApplyInverseRotation(ApplyRotation(values, rot), rot)
for i := range values {
if abs32(values[i]-got[i]) > 1e-4 {
t.Fatalf("dim=%d idx=%d got=%v want=%v", dim, i, got[i], values[i])
}
}
}
}
func TestRotationPreservesNorm(t *testing.T) {
for _, dim := range testDims {
values := pseudoRandomVector(dim, uint64(dim)*23)
rot := BuildRotation(dim, uint64(dim)*29)
before := vectorNorm(values)
after := vectorNorm(ApplyRotation(values, rot))
if abs32(before-after) > 1e-3 {
t.Fatalf("dim=%d norm drift=%v", dim, abs32(before-after))
}
}
}
// TestAttentionScoreRotationInvariance verifies the exact mathematical
// invariant: dot(Q,K) == dot(R@Q, R@K) for an orthogonal matrix R. Because R
// is orthogonal, R^T R = I, so the inner product is preserved exactly (up to
// floating-point rounding). This is the property that makes rotating K before
// quantization safe when Q is also rotated at attention time.
func TestAttentionScoreRotationInvariance(t *testing.T) {
for _, dim := range []int{64, 128, 256} {
t.Run(fmt.Sprintf("dim=%d", dim), func(t *testing.T) {
seed := uint64(0x42c0ffee)
q := pseudoRandomVector(dim, seed)
k := pseudoRandomVector(dim, seed^0xbeef)
rot := BuildRotation(dim, seed+1)
qRot := ApplyRotation(q, rot)
kRot := ApplyRotation(k, rot)
var dotOrig float64
for i := range q {
dotOrig += float64(q[i]) * float64(k[i])
}
var dotRot float64
for i := range qRot {
dotRot += float64(qRot[i]) * float64(kRot[i])
}
relErr := math.Abs(dotOrig-dotRot) / (math.Abs(dotOrig) + 1e-10)
if relErr > 1e-4 {
t.Errorf("dim=%d: dot(Q,K)=%.6f dot(RQ,RK)=%.6f relErr=%.2e",
dim, dotOrig, dotRot, relErr)
}
})
}
}
// TestQuantizedAttentionScorePreservation verifies the practical end-to-end
// path: encode K per-head (which stores R@k), rotate Q (giving R@q), then
// compute (R@q)·(R@k_quant) and compare against the true dot(Q,K). Single
// trials can have high per-sample error from quantization noise, so we average
// over many trials and check the mean relative error, matching the pattern used
// by TestEncodeKeyPerHeadRoundTrip.
func TestQuantizedAttentionScorePreservation(t *testing.T) {
const dim = 128
const trials = 100
for _, preset := range []Preset{PresetTQ2, PresetTQ3} {
t.Run(preset.Name, func(t *testing.T) {
rng := splitmix64(0x1111feed)
rot := BuildRotation(dim, preset.RotationSeed)
var totalAbsErr, totalAbsDot float64
for trial := range trials {
q := make([]float32, dim)
k := make([]float32, dim)
for j := range q {
q[j] = float32(gaussianFloat64(&rng))
k[j] = float32(gaussianFloat64(&rng))
}
// True attention score.
var trueScore float64
for i := range q {
trueScore += float64(q[i]) * float64(k[i])
}
// Quantized path: encode K in rotated space, rotate Q, compute score.
packed, scale, err := EncodeKeyPerHead(k, preset)
if err != nil {
t.Fatalf("trial %d EncodeKeyPerHead: %v", trial, err)
}
kRecon := DequantKeyPerHead(packed, scale, dim, preset.KeyPrimaryBits)
qRot := ApplyRotation(q, rot)
var quantScore float64
for i := range qRot {
quantScore += float64(qRot[i]) * float64(kRecon[i])
}
totalAbsErr += math.Abs(trueScore - quantScore)
totalAbsDot += math.Abs(trueScore)
}
avgRelErr := totalAbsErr / (totalAbsDot + 1e-10)
t.Logf("preset=%s avg relative dot error = %.4f over %d trials", preset.Name, avgRelErr, trials)
// tq3 (3-bit) is tighter than tq2 (2-bit); these thresholds match the
// codec quality validated by TestEncodeKeyPerHeadRoundTrip.
maxRelErr := 0.45
if preset.KeyPrimaryBits >= 3 {
maxRelErr = 0.25
}
if avgRelErr > maxRelErr {
t.Errorf("preset=%s avg relative dot error = %.4f, want <= %.4f",
preset.Name, avgRelErr, maxRelErr)
}
})
}
}
// TestBuildRotationIsOrthogonal verifies that Q satisfies Q*Q^T = I (rows are
// orthonormal). This is the core invariant required by the TurboQuant encoding
// and is guaranteed unconditionally by the Householder QR algorithm.
func TestBuildRotationIsOrthogonal(t *testing.T) {
// Include dim=256 to exercise the algorithm well beyond typical head_dim.
for _, dim := range append(testDims, 256) {
rot := BuildRotation(dim, uint64(dim)*31)
for i := range dim {
for j := i; j < dim; j++ {
var dot float32
for k := range dim {
dot += rot.Matrix[i*dim+k] * rot.Matrix[j*dim+k]
}
if i == j {
if math.Abs(float64(dot-1)) > 5e-5 {
t.Fatalf("dim=%d row %d: self-dot=%.6f, want 1.0", dim, i, dot)
}
} else if math.Abs(float64(dot)) > 5e-5 {
t.Fatalf("dim=%d rows %d,%d: cross-dot=%.6f, want 0.0", dim, i, j, dot)
}
}
}
}
}

36
turboquant/stats.go Normal file
View file

@ -0,0 +1,36 @@
package turboquant
import "math"
type Stats struct {
MSE float32
RMSE float32
MeanAbsErr float32
MaxAbsErr float32
}
func Compare(reference, approx []float32) Stats {
var s Stats
if len(reference) == 0 || len(reference) != len(approx) {
return s
}
for i := range reference {
err := reference[i] - approx[i]
s.MSE += err * err
s.MeanAbsErr += abs32(err)
if abs32(err) > s.MaxAbsErr {
s.MaxAbsErr = abs32(err)
}
}
s.MSE /= float32(len(reference))
s.RMSE = float32(math.Sqrt(float64(s.MSE)))
s.MeanAbsErr /= float32(len(reference))
return s
}
func abs32(v float32) float32 {
if v < 0 {
return -v
}
return v
}

41
turboquant/stats_test.go Normal file
View file

@ -0,0 +1,41 @@
package turboquant
import (
"math"
"testing"
)
func TestCompareIdenticalVectors(t *testing.T) {
values := []float32{1, -2, 3, -4}
stats := Compare(values, values)
if stats.MSE != 0 || stats.RMSE != 0 || stats.MeanAbsErr != 0 || stats.MaxAbsErr != 0 {
t.Fatalf("unexpected non-zero stats: %+v", stats)
}
}
func TestCompareKnownExample(t *testing.T) {
reference := []float32{1, 2, 3}
approx := []float32{2, 0, 3}
stats := Compare(reference, approx)
if abs32(stats.MSE-float32(5.0/3.0)) > 1e-6 {
t.Fatalf("MSE = %v, want %v", stats.MSE, float32(5.0/3.0))
}
if abs32(stats.MeanAbsErr-1) > 1e-6 {
t.Fatalf("MeanAbsErr = %v, want 1", stats.MeanAbsErr)
}
if abs32(stats.MaxAbsErr-2) > 1e-6 {
t.Fatalf("MaxAbsErr = %v, want 2", stats.MaxAbsErr)
}
}
func TestCompareRMSEMatchesSqrtMSE(t *testing.T) {
reference := []float32{1, 2, 3}
approx := []float32{2, 0, 3}
stats := Compare(reference, approx)
want := float32(math.Sqrt(float64(stats.MSE)))
if abs32(stats.RMSE-want) > 1e-6 {
t.Fatalf("RMSE = %v, want %v", stats.RMSE, want)
}
}

View file

@ -0,0 +1,66 @@
package turboquant
import "math"
type namedVector struct {
name string
values []float32
}
func deterministicCorpus() []namedVector {
return []namedVector{
{name: "ramp-16", values: []float32{-4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}},
{name: "alternating-16", values: []float32{1, -1, 2, -2, 3, -3, 4, -4, 5, -5, 6, -6, 7, -7, 8, -8}},
{name: "sparse-17", values: []float32{0, 0, 0, 7, 0, 0, -5, 0, 0, 0, 9, 0, 0, 0, 0, -3, 0}},
{name: "constant-31", values: filledVector(31, 1.5)},
{name: "zero-64", values: filledVector(64, 0)},
{name: "random-65", values: pseudoRandomVector(65, 0x1234abcd)},
}
}
func pseudoRandomVector(n int, seed uint64) []float32 {
rng := splitmix64(seed)
out := make([]float32, n)
for i := range out {
u := float32(rng.next()&0xffff) / 65535
out[i] = (u * 10) - 5
}
return out
}
func filledVector(n int, value float32) []float32 {
out := make([]float32, n)
for i := range out {
out[i] = value
}
return out
}
func vectorNorm(values []float32) float32 {
var sum float64
for _, v := range values {
sum += float64(v) * float64(v)
}
return float32(math.Sqrt(sum))
}
func meanMSEForPreset(preset Preset) (float32, error) {
var total float32
corpus := deterministicCorpus()
for _, tc := range corpus {
encoded, err := EncodeVector(tc.values, preset)
if err != nil {
return 0, err
}
data, err := encoded.MarshalBinary()
if err != nil {
return 0, err
}
decoded, _, err := DecodeVector(data)
if err != nil {
return 0, err
}
total += Compare(tc.values, decoded).MSE
}
return total / float32(len(corpus)), nil
}

136
turboquant/turboquant.go Normal file
View file

@ -0,0 +1,136 @@
package turboquant
import (
"fmt"
)
const BlockVersion = 6
type vectorRole uint8
const (
roleGeneric vectorRole = iota
roleKey
roleValue
)
type vectorObjective uint8
const (
objectiveMSE vectorObjective = iota + 1
objectiveProduct
)
type Preset struct {
ID uint8
Name string
RotationSeed uint64
KeyPrimaryBits int
ValueBits int
QJLRowsDivisor int
OutlierBits int
OutlierCount int
}
var (
// All four tq* presets ship with OutlierCount=0 (pure uniform Lloyd-Max
// after Householder QR rotation, i.e. the core of TurboQuant Algorithm 1
// §3.1 without the optional outlier split from §4.3 or the QJL residual
// sketch from Algorithm 2). The uniform defaults were chosen after
// measuring that on the models this fork ships against (llama, gemma3,
// qwen3-coder), outlier split hurts both decode throughput and PPL — the
// paper's split targets heavy-tailed rotated K distributions, which these
// models don't exhibit, and the extra metadata (92 vs 52 bytes/head/cell
// at oc=32) translates to ~25% decode regression on 3B-class models at
// short context. Keeping the defaults symmetric across tq2 / tq3 / tq2k /
// tq3k means the digit in the preset name maps directly to effective
// bits/elem: "tq3" is exactly 3 bits, not the 3.25 bits you'd get under
// outlier split with oc=32.
//
// The outlier-split kernel path remains in the code (encode/dequant
// dispatchers check op_params[2..3] and route to the outlier variant when
// both are non-zero). A future dynamic-dispatch PR can enable it per
// model / per env var — e.g. for the qwen2 family once Phase 2A
// asymmetric quantization lands, or for models with larger headDim where
// the metadata overhead amortizes better. See project_tq_backlog.md for
// the planned dynamic-oc dispatch work.
// tq2: 2-bit K + 2-bit V, both rotated and Lloyd-Max quantized. Highest
// compression tier — effective 2 bits/elem both sides.
PresetTQ2 = newPreset(1, "tq2", 2, 2, 1, 0x25c0ffee, 3, 0)
// tq3: 3-bit K + 3-bit V, both rotated and Lloyd-Max quantized. Default
// "balanced" tier — effective 3 bits/elem both sides.
PresetTQ3 = newPreset(2, "tq3", 3, 3, 1, 0x35c0ffee, 4, 0)
// tq3k: 3-bit K only, V stays as f16. ~40% KV VRAM savings with near-f16
// decode (no V dequant at all). ValueBits=0 signals K-only mode to the
// kvcache layer.
PresetTQ3K = newPreset(3, "tq3k", 3, 0, 1, 0x35c0ffee, 4, 0)
// tq2k: 2-bit K only, V stays as f16. Maximum K compression with f16 V;
// smallest K footprint before PPL degrades too much. ValueBits=0 signals
// K-only mode to the kvcache layer.
PresetTQ2K = newPreset(4, "tq2k", 2, 0, 1, 0x25c0ffee, 3, 0)
)
func newPreset(id uint8, name string, keyBits int, valueBits int, qjlRowsDivisor int, seed uint64, outlierBits int, outlierCount int) Preset {
return Preset{
ID: id,
Name: name,
RotationSeed: seed,
KeyPrimaryBits: keyBits,
ValueBits: valueBits,
QJLRowsDivisor: qjlRowsDivisor,
OutlierBits: outlierBits,
OutlierCount: outlierCount,
}
}
func (p Preset) HasOutlierSplit() bool {
return p.OutlierBits > 0 && p.OutlierCount > 0
}
func PresetByName(name string) (Preset, error) {
switch name {
case "tq2":
return PresetTQ2, nil
case "tq3":
return PresetTQ3, nil
case "tq3k":
return PresetTQ3K, nil
case "tq2k":
return PresetTQ2K, nil
default:
return Preset{}, fmt.Errorf("unknown turboquant preset %q", name)
}
}
func PresetByID(id uint8) (Preset, error) {
switch id {
case PresetTQ2.ID:
return PresetTQ2, nil
case PresetTQ3.ID:
return PresetTQ3, nil
case PresetTQ3K.ID:
return PresetTQ3K, nil
case PresetTQ2K.ID:
return PresetTQ2K, nil
default:
return Preset{}, fmt.Errorf("unknown turboquant preset id %d", id)
}
}
func (p Preset) KeyQJLRows(dim int) int {
if dim <= 0 {
return 0
}
if p.QJLRowsDivisor <= 0 {
return 0
}
rows := dim / p.QJLRowsDivisor
if rows < 1 {
return 1
}
return rows
}