mirror of
https://github.com/ollama/ollama
synced 2026-04-23 08:45:14 +00:00
Merge c56d85ee5e into 21883571b7
This commit is contained in:
commit
a31703aa73
|
|
@ -41,6 +41,7 @@ set(GGML_LLAMAFILE ON)
|
|||
set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128)
|
||||
set(GGML_CUDA_GRAPHS ON)
|
||||
set(GGML_CUDA_FA ON)
|
||||
set(GGML_CUDA_FA_ALL_QUANTS OFF)
|
||||
set(GGML_CUDA_COMPRESSION_MODE default)
|
||||
|
||||
if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
|
||||
|
|
|
|||
10
docs/faq.mdx
10
docs/faq.mdx
|
|
@ -363,6 +363,16 @@ The currently available K/V cache quantization types are:
|
|||
- `f16` - high precision and memory usage (default).
|
||||
- `q8_0` - 8-bit quantization, uses approximately 1/2 the memory of `f16` with a very small loss in precision, this usually has no noticeable impact on the model's quality (recommended if not using f16).
|
||||
- `q4_0` - 4-bit quantization, uses approximately 1/4 the memory of `f16` with a small-medium loss in precision that may be more noticeable at higher context sizes.
|
||||
- `tq3k` - TurboQuant 3-bit K-only quantization, uses approximately 3/5 the memory of `f16` with a negligible quality impact, and unlike `q4_0`/`q8_0` does not require Flash Attention to be enabled.
|
||||
- `tq2k` - TurboQuant 2-bit K-only quantization, uses approximately 3/5 the memory of `f16` with a small quality impact, and unlike `q4_0`/`q8_0` does not require Flash Attention to be enabled.
|
||||
- `tq3` - TurboQuant 3-bit K and V quantization, uses approximately 1/5 the memory of `f16` with a negligible quality impact. Requires Flash Attention.
|
||||
- `tq2` - TurboQuant 2-bit K and V quantization, uses approximately 1/7 the memory of `f16` with a small quality impact. Requires Flash Attention.
|
||||
|
||||
<Note>
|
||||
The `tq*` (TurboQuant) cache types require an NVIDIA GPU with compute capability 6.0 or newer (Pascal or later) and run only through Ollama's native engine. TurboQuant applies to full-context attention layers; sliding-window layers (e.g. in Gemma) continue to use `f16` storage.
|
||||
|
||||
TurboQuant has been validated on Llama and Gemma 3 architectures. Models in the Qwen 2 family (including Qwen 2.5) are a known weak spot: their learned K bias gives them a per-channel asymmetric distribution that the paper-grounded rotation does not model well, so output quality under `tq3k`/`tq2k` is currently noticeably worse than `q8_0` or `q4_0`. For stable long-context use on Qwen 2, prefer `q8_0` or `q4_0`.
|
||||
</Note>
|
||||
|
||||
How much the cache quantization impacts the model's response quality will depend on the model and the task. Models that have a high GQA count (e.g. Qwen2) may see a larger impact on precision from quantization than models with a low GQA count.
|
||||
|
||||
|
|
|
|||
|
|
@ -852,7 +852,7 @@ func (f GGML) SupportsKVCacheType(cacheType string) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
return slices.Contains([]string{"q8_0", "q4_0"}, cacheType)
|
||||
return slices.Contains([]string{"q8_0", "q4_0", "tq2", "tq3", "tq3k", "tq2k"}, cacheType)
|
||||
}
|
||||
|
||||
// KVCacheTypeIsQuantized checks if the requested cache type is a quantized type
|
||||
|
|
@ -863,6 +863,20 @@ func (f GGML) KVCacheTypeIsQuantized(cacheType string) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
// KVCacheTypeRequiresFlashAttention reports whether a KV cache type can only be
|
||||
// used when flash attention is enabled. Ggml-native quantized types (q8_0,
|
||||
// q4_0) store K/V as quantized tensors that the non-FA softmax+matmul attention
|
||||
// path cannot consume directly. TurboQuant K-only presets (tq2k, tq3k) dequant
|
||||
// the packed K buffer to f16 before the attention op runs, so they work with
|
||||
// either FA or the standard attention path.
|
||||
func (f GGML) KVCacheTypeRequiresFlashAttention(cacheType string) bool {
|
||||
switch cacheType {
|
||||
case "tq2k", "tq3k":
|
||||
return false
|
||||
}
|
||||
return f.KVCacheTypeIsQuantized(cacheType)
|
||||
}
|
||||
|
||||
// SupportsFlashAttention checks if the model supports flash attention
|
||||
func (f GGML) SupportsFlashAttention() bool {
|
||||
_, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]
|
||||
|
|
@ -913,6 +927,18 @@ func kvCacheBytesPerElement(cacheType string) float64 {
|
|||
return 1 // 1/2 of fp16
|
||||
case "q4_0":
|
||||
return 0.5 // 1/4 of fp16
|
||||
case "tq3":
|
||||
// 3-bit TQ K (~0.41 B/elem) + 3-bit TQ V (~0.41 B/elem), averaged over K+V
|
||||
return 0.41
|
||||
case "tq3k":
|
||||
// 3-bit TQ K (~0.41 B/elem) + f16 V (2 B/elem), averaged over K+V
|
||||
return 1.205
|
||||
case "tq2":
|
||||
// 2-bit TQ K (~0.28 B/elem) + 2-bit TQ V (~0.28 B/elem), averaged over K+V
|
||||
return 0.28
|
||||
case "tq2k":
|
||||
// 2-bit TQ K (~0.28 B/elem) + f16 V (2 B/elem), averaged over K+V
|
||||
return 1.14
|
||||
case "f32":
|
||||
return 4 // f32 (default for recurrent)
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -82,3 +82,12 @@ type Cache interface {
|
|||
type CheckpointCache interface {
|
||||
PrepareRestore(seq int, targetPos int32) (int32, bool)
|
||||
}
|
||||
|
||||
// CausalConfigurable is implemented by caches that accept CausalOptions
|
||||
// (the concrete *Causal type and any wrappers that delegate to it, e.g.,
|
||||
// *TurboQuantCache). Model code that needs to set per-layer masking options
|
||||
// should assert against this interface rather than the concrete *Causal so
|
||||
// that TurboQuant-wrapped global sub-caches still receive SetCausal calls.
|
||||
type CausalConfigurable interface {
|
||||
SetCausal(ctx ml.Context, opts CausalOptions)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,7 +18,21 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
|
|||
// The tensors are of shape embed dim, kv heads, batch size
|
||||
// The mask is of shape history size, batch size
|
||||
type Causal struct {
|
||||
DType ml.DType
|
||||
// DTypeK and DTypeV are the storage dtypes for K and V respectively.
|
||||
// Init takes a single dtype and assigns it to both unless a caller has
|
||||
// explicitly set one or both before Init runs (e.g., TurboQuantCache
|
||||
// which forces its inner Causal to f16 on both sides while routing
|
||||
// compressed K through a separate manager via SkipK).
|
||||
DTypeK ml.DType
|
||||
DTypeV ml.DType
|
||||
|
||||
// SkipK suppresses K tensor allocation and writes. When true, an external
|
||||
// manager (e.g., TurboQuantCache) handles K storage and returns K from Get.
|
||||
SkipK bool
|
||||
|
||||
// SkipV suppresses V tensor allocation and writes. When true, an external
|
||||
// manager (e.g., TurboQuantCache) handles V storage and returns V from Get.
|
||||
SkipV bool
|
||||
|
||||
// swaWindowSize is the number of tokens that will be included in the mask
|
||||
// during attention operations. swaMemorySize is the number of tokens that
|
||||
|
|
@ -44,6 +58,10 @@ type Causal struct {
|
|||
|
||||
// locations for data storage for this batch
|
||||
curLoc ml.Tensor
|
||||
// curLocs mirrors curLoc as a raw int slice for TurboQuant's per-cell
|
||||
// byte-map indexing. Only populated when SkipK or SkipV is set — non-TQ
|
||||
// users never pay the allocation cost.
|
||||
curLocs []int
|
||||
|
||||
// mask of the cache as used by this batch
|
||||
curMask ml.Tensor
|
||||
|
|
@ -175,7 +193,15 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
|
|||
cacheSize = roundUp(cacheSize, c.config.CachePadding)
|
||||
c.cells = make([]cacheCell, cacheSize)
|
||||
|
||||
c.DType = dtype
|
||||
// Resolve effective K/V dtypes. If a caller (e.g. TurboQuantCache) hasn't
|
||||
// already set DTypeK or DTypeV before Init runs, fall back to the single
|
||||
// dtype parameter for both — the historical single-dtype behaviour.
|
||||
if c.DTypeK == ml.DTypeOther {
|
||||
c.DTypeK = dtype
|
||||
}
|
||||
if c.DTypeV == ml.DTypeOther {
|
||||
c.DTypeV = dtype
|
||||
}
|
||||
c.cellRanges = make(map[int]cellRange)
|
||||
c.backend = backend
|
||||
c.maxBatch = maxBatch
|
||||
|
|
@ -242,6 +268,17 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
|
|||
}
|
||||
|
||||
c.curLoc = ctx.Input().FromInts(locs, len(locs))
|
||||
// curLocs is only needed by TurboQuantCache's per-cell byte-map indexing.
|
||||
// Skip the allocation and conversion loop on the non-TQ path so vanilla
|
||||
// users don't pay per-batch overhead for an unused slice.
|
||||
if c.SkipK || c.SkipV {
|
||||
c.curLocs = make([]int, len(locs))
|
||||
for i, l := range locs {
|
||||
c.curLocs[i] = int(l)
|
||||
}
|
||||
} else {
|
||||
c.curLocs = nil
|
||||
}
|
||||
c.curMask = c.buildMask(ctx)
|
||||
|
||||
return nil
|
||||
|
|
@ -254,8 +291,39 @@ func newRange() cellRange {
|
|||
}
|
||||
}
|
||||
|
||||
// Returns a slice of locations where each token in the batch should be stored
|
||||
// Returns a slice of locations where each token in the batch should be stored.
|
||||
//
|
||||
// When SkipK/SkipV is set (TurboQuant's compressed path), the backing TQ
|
||||
// encode kernels write a contiguous run of cells starting at loc[0]; a
|
||||
// fragmented allocation would desynchronize the compressed K/V buffers
|
||||
// from the per-cell metadata. Require a contiguous empty run of the
|
||||
// required length and surface ErrKvCacheFull otherwise — the runner treats
|
||||
// that the same as an out-of-space condition, which lets a fragmented
|
||||
// cache recover as other sequences complete rather than crashing the
|
||||
// process on a gapped write.
|
||||
func (c *Causal) findLocs() ([]int32, error) {
|
||||
if c.SkipK || c.SkipV {
|
||||
runStart, runLen := -1, 0
|
||||
for i := range c.cells {
|
||||
if len(c.cells[i].sequences) == 0 {
|
||||
if runLen == 0 {
|
||||
runStart = i
|
||||
}
|
||||
runLen++
|
||||
if runLen >= c.curBatchSize {
|
||||
loc := make([]int32, c.curBatchSize)
|
||||
for j := range loc {
|
||||
loc[j] = int32(runStart + j)
|
||||
}
|
||||
return loc, nil
|
||||
}
|
||||
} else {
|
||||
runLen = 0
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("%w (cache: %v batch: %v, no contiguous run)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
|
||||
}
|
||||
|
||||
loc := make([]int32, 0, c.curBatchSize)
|
||||
|
||||
for i := range c.cells {
|
||||
|
|
@ -409,40 +477,47 @@ func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
|
|||
}
|
||||
|
||||
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
key := c.keys[c.curLayer]
|
||||
value := c.values[c.curLayer]
|
||||
|
||||
kHeadDim := key.Dim(0)
|
||||
numKVHeads := key.Dim(1)
|
||||
rowSize := key.Stride(2)
|
||||
cachedSize := c.curMask.Dim(0)
|
||||
|
||||
key = key.View(ctx, rowSize*c.curCellRange.min,
|
||||
kHeadDim, key.Stride(1),
|
||||
numKVHeads, key.Stride(2),
|
||||
cachedSize,
|
||||
)
|
||||
|
||||
if c.config.PermutedV {
|
||||
vHeadDim := value.Dim(1)
|
||||
elemSize := value.Stride(0)
|
||||
|
||||
value = value.View(ctx, elemSize*c.curCellRange.min,
|
||||
cachedSize, value.Stride(1),
|
||||
vHeadDim, value.Stride(2),
|
||||
numKVHeads,
|
||||
)
|
||||
} else {
|
||||
vHeadDim := value.Dim(0)
|
||||
rowSize := value.Stride(2)
|
||||
|
||||
value = value.View(ctx, rowSize*c.curCellRange.min,
|
||||
vHeadDim, value.Stride(1),
|
||||
numKVHeads, value.Stride(2),
|
||||
var key ml.Tensor
|
||||
if !c.SkipK {
|
||||
k := c.keys[c.curLayer]
|
||||
kHeadDim := k.Dim(0)
|
||||
numKVHeads := k.Dim(1)
|
||||
rowSize := k.Stride(2)
|
||||
key = k.View(ctx, rowSize*c.curCellRange.min,
|
||||
kHeadDim, k.Stride(1),
|
||||
numKVHeads, k.Stride(2),
|
||||
cachedSize,
|
||||
)
|
||||
}
|
||||
|
||||
var value ml.Tensor
|
||||
if !c.SkipV {
|
||||
v := c.values[c.curLayer]
|
||||
if c.config.PermutedV {
|
||||
vHeadDim := v.Dim(1)
|
||||
vKVHeads := v.Dim(2)
|
||||
elemSize := v.Stride(0)
|
||||
|
||||
value = v.View(ctx, elemSize*c.curCellRange.min,
|
||||
cachedSize, v.Stride(1),
|
||||
vHeadDim, v.Stride(2),
|
||||
vKVHeads,
|
||||
)
|
||||
} else {
|
||||
vHeadDim := v.Dim(0)
|
||||
vKVHeads := v.Dim(1)
|
||||
rowSize := v.Stride(2)
|
||||
|
||||
value = v.View(ctx, rowSize*c.curCellRange.min,
|
||||
vHeadDim, v.Stride(1),
|
||||
vKVHeads, v.Stride(2),
|
||||
cachedSize,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return key, value, c.curMask
|
||||
}
|
||||
|
||||
|
|
@ -460,37 +535,45 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
|||
c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
|
||||
}
|
||||
|
||||
if _, ok := c.keys[c.curLayer]; !ok {
|
||||
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, len(c.cells))
|
||||
}
|
||||
|
||||
if _, ok := c.values[c.curLayer]; !ok {
|
||||
if c.config.PermutedV {
|
||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vHeadDim, numKVHeads)
|
||||
} else {
|
||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, len(c.cells))
|
||||
if !c.SkipK {
|
||||
if _, ok := c.keys[c.curLayer]; !ok {
|
||||
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DTypeK, kHeadDim, numKVHeads, len(c.cells))
|
||||
}
|
||||
}
|
||||
|
||||
key = key.Reshape(ctx, kHeadDim*numKVHeads, batchSize)
|
||||
keyCache := c.keys[c.curLayer]
|
||||
keyCache = keyCache.Reshape(ctx, kHeadDim*numKVHeads, len(c.cells))
|
||||
ctx.Forward(keyCache.SetRows(ctx, key, c.curLoc))
|
||||
if !c.SkipV {
|
||||
if _, ok := c.values[c.curLayer]; !ok {
|
||||
if c.config.PermutedV {
|
||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DTypeV, len(c.cells), vHeadDim, numKVHeads)
|
||||
} else {
|
||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DTypeV, vHeadDim, numKVHeads, len(c.cells))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if c.config.PermutedV {
|
||||
value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
|
||||
value = value.Permute(ctx, 2, 0, 1, 3)
|
||||
if !c.SkipK {
|
||||
key = key.Reshape(ctx, kHeadDim*numKVHeads, batchSize)
|
||||
keyCache := c.keys[c.curLayer]
|
||||
keyCache = keyCache.Reshape(ctx, kHeadDim*numKVHeads, len(c.cells))
|
||||
ctx.Forward(keyCache.SetRows(ctx, key, c.curLoc))
|
||||
}
|
||||
|
||||
valueCache := c.values[c.curLayer]
|
||||
valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)
|
||||
if !c.SkipV {
|
||||
if c.config.PermutedV {
|
||||
value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
|
||||
value = value.Permute(ctx, 2, 0, 1, 3)
|
||||
|
||||
ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
|
||||
} else {
|
||||
value = value.Reshape(ctx, vHeadDim*numKVHeads, batchSize)
|
||||
valueCache := c.values[c.curLayer]
|
||||
valueCache = valueCache.Reshape(ctx, vHeadDim*numKVHeads, len(c.cells))
|
||||
valueCache := c.values[c.curLayer]
|
||||
valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)
|
||||
|
||||
ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
|
||||
ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
|
||||
} else {
|
||||
value = value.Reshape(ctx, vHeadDim*numKVHeads, batchSize)
|
||||
valueCache := c.values[c.curLayer]
|
||||
valueCache = valueCache.Reshape(ctx, vHeadDim*numKVHeads, len(c.cells))
|
||||
|
||||
ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ var (
|
|||
// Conv state shape (per layer, per sequence): [convDim, convChannels]
|
||||
// Recurrent state shape (per layer, per sequence): [recurrentStateSize]
|
||||
type Recurrent struct {
|
||||
kv *Causal
|
||||
kv Cache
|
||||
|
||||
backend ml.Backend
|
||||
dtype ml.DType
|
||||
|
|
@ -95,6 +95,19 @@ type Recurrent struct {
|
|||
writableError error
|
||||
}
|
||||
|
||||
// AttentionKV returns the inner attention KV cache when it is still the
|
||||
// original *Causal. Returns nil after the cache has been replaced (e.g. by
|
||||
// *TurboQuantCache), making WrapWithTurboQuant idempotently skip a second wrap.
|
||||
func (c *Recurrent) AttentionKV() *Causal {
|
||||
causal, _ := c.kv.(*Causal)
|
||||
return causal
|
||||
}
|
||||
|
||||
// SetAttentionKV replaces the inner attention KV cache. Used by
|
||||
// WrapWithTurboQuant to inject compression into the attention path
|
||||
// without disturbing the embedded conv/recurrent state tensors.
|
||||
func (c *Recurrent) SetAttentionKV(kv Cache) { c.kv = kv }
|
||||
|
||||
func NewRecurrentCache(config RecurrentConfig) *Recurrent {
|
||||
return &Recurrent{
|
||||
kv: NewCausalCache(config.Shift),
|
||||
|
|
|
|||
619
kvcache/turboquant.go
Normal file
619
kvcache/turboquant.go
Normal file
|
|
@ -0,0 +1,619 @@
|
|||
package kvcache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/turboquant"
|
||||
)
|
||||
|
||||
type TurboQuantCache struct {
|
||||
meta *Causal
|
||||
preset turboquant.Preset
|
||||
isReserve bool
|
||||
|
||||
compressedK ml.TQCompressedKManager
|
||||
|
||||
// phase2Checked ensures GPU encode is activated at most once.
|
||||
phase2Checked bool
|
||||
|
||||
headDim int
|
||||
numKVHeads int
|
||||
|
||||
// encodeResults stores per-layer EncodeK result tensors for the current
|
||||
// forward pass. DequantK uses them as src[0] to establish the graph
|
||||
// dependency (encode before dequant in the ggml scheduler).
|
||||
encodeResults map[int]ml.Tensor
|
||||
|
||||
// vEncodeResults stores per-layer EncodeV result tensors for the current
|
||||
// forward pass. DequantV uses them to establish the encode→dequant ordering.
|
||||
vEncodeResults map[int]ml.Tensor
|
||||
|
||||
// logPathOnce ensures each active Get() path is logged at most once per
|
||||
// cache instance (avoids log spam: Get() is called every layer every step).
|
||||
logPathOnce [5]sync.Once
|
||||
|
||||
// fusedFallbackEligible gates the inline-decode fused-FA fallback paths
|
||||
// (Get paths 2 and 4). The CUDA fused kernel is template-instantiated only
|
||||
// at D=128, so models with a larger head dim (gemma4 D=512) must skip it
|
||||
// to avoid a kernel-side GGML_ASSERT. The Metal fused kernel has both
|
||||
// D=128 and D=256 variants (kernel_tq_fattn_vec_*{,_d256}), so gemma3
|
||||
// D=256 is eligible on Metal but not on CUDA until the CUDA kernel gains
|
||||
// a D=256 instantiation. The DequantK + stock FA path (Get paths 0/1/5)
|
||||
// works at any head dim — this gate is specific to the inline-decode
|
||||
// variants.
|
||||
fusedFallbackEligible bool
|
||||
|
||||
// preferFusedAttn is true on Metal. At long context, DequantKV + stock FA
|
||||
// writes a full f16 intermediate buffer (doubling KV bandwidth) before
|
||||
// attention reads it. The fused kernel (kernel_tq_fattn_vec_packed) reads
|
||||
// packed K+V directly and skips the intermediate write — dramatically
|
||||
// faster on Metal for *decode* (Q=1). On CUDA the DequantKV path is
|
||||
// preferred because the intermediate buffer stays in L2 and stock flash
|
||||
// attention is highly tuned.
|
||||
//
|
||||
// For prefill (Q>1, see curQueryLen) DequantKV wins on every backend
|
||||
// because it decodes each packed cell once then runs stock FA with full
|
||||
// batch amortisation, whereas the fused kernel re-decodes each cell per
|
||||
// Q-token. Get() uses the combination of preferFusedAttn and curQueryLen
|
||||
// to route prefill vs decode independently.
|
||||
preferFusedAttn bool
|
||||
|
||||
// curQueryLen is the number of query tokens in the current forward pass,
|
||||
// captured at StartForward from len(batch.Positions). curQueryLen==1
|
||||
// means decode; curQueryLen>1 means prefill (or a batched prompt chunk).
|
||||
curQueryLen int
|
||||
|
||||
// rotMatrix is the R^T rotation matrix sized for this cache's headDim.
|
||||
// Set in activateGPUEncode. Get() sets the backend's tqRotationMatrix to
|
||||
// this value per-call (consume-once) right before returning rotated K.
|
||||
rotMatrix ml.Tensor
|
||||
|
||||
// vRotMatrix is the R (inverse) matrix used to undo V rotation after
|
||||
// attention in K+V presets (tq3/tq2). Nil for K-only presets. Get() sets
|
||||
// the backend's tqVRotationMatrix to this value per-call along with
|
||||
// rotMatrix so SDPA applies R @ attn_out after flash attention.
|
||||
vRotMatrix ml.Tensor
|
||||
|
||||
// rotSetter is the cached type assertion of c.meta.backend onto the TQ
|
||||
// rotation setter interface, populated once in activateGPUEncode. nil
|
||||
// when the backend doesn't support TQ rotation hooks (the fallback case).
|
||||
rotSetter tqRotSetter
|
||||
}
|
||||
|
||||
// tqRotSetter is the backend hook TurboQuantCache uses to arm the per-call
|
||||
// rotation matrices SDPA consumes. Implemented by ml/backend/ggml.Backend.
|
||||
type tqRotSetter interface {
|
||||
SetTQRotationMatrix(ml.Tensor)
|
||||
SetTQVRotationMatrix(ml.Tensor)
|
||||
}
|
||||
|
||||
// isSWACausal reports whether a *Causal has sliding-window attention
|
||||
// active. Plain Causal caches have swaWindowSize either 0 (before Init
|
||||
// normalizes the default) or math.MaxInt32 (after); SWA constructors set
|
||||
// it to the actual window size.
|
||||
func isSWACausal(c *Causal) bool {
|
||||
return c.swaWindowSize > 0 && c.swaWindowSize != math.MaxInt32
|
||||
}
|
||||
|
||||
// AttentionKVWrapper is implemented by caches that embed *Recurrent and
|
||||
// expose the attention half of a hybrid (SSM/recurrent + attention) cache.
|
||||
// WrapWithTurboQuant uses it to inject TurboQuant compression into the
|
||||
// attention KV path without disturbing conv/recurrent state buffers.
|
||||
// *kvcache.Recurrent implements this interface, so any model HybridCache that
|
||||
// embeds *Recurrent satisfies it automatically via Go method promotion.
|
||||
type AttentionKVWrapper interface {
|
||||
AttentionKV() *Causal
|
||||
SetAttentionKV(Cache)
|
||||
}
|
||||
|
||||
// WrapWithTurboQuant returns a cache that applies TurboQuant compression to
|
||||
// global-attention Causal layers and a bool reporting whether any wrapping
|
||||
// took effect. For a top-level *Causal (non-SWA), it returns a new
|
||||
// *TurboQuantCache. For a *WrapperCache, it mutates the caches slice in
|
||||
// place, replacing every non-SWA *Causal sub-cache with a *TurboQuantCache,
|
||||
// and returns the same *WrapperCache pointer. This enables TQ on SWA models
|
||||
// like gemma3/gemma4 where the global attention layers dominate KV memory
|
||||
// at long context. Returns (cache, false) if no eligible sub-caches were
|
||||
// found.
|
||||
func WrapWithTurboQuant(cache Cache, preset turboquant.Preset) (Cache, bool) {
|
||||
switch c := cache.(type) {
|
||||
case *Causal:
|
||||
// Reject SWA caches. Plain NewCausalCache leaves swaWindowSize=0
|
||||
// until Init() normalizes it to math.MaxInt32; SWA constructors set
|
||||
// it to the actual window size. "Plain causal" means the field is
|
||||
// either 0 (uninitialized default) or math.MaxInt32 (post-Init).
|
||||
if isSWACausal(c) {
|
||||
slog.Warn("turboquant: top-level Causal is sliding-window, cannot wrap")
|
||||
return cache, false
|
||||
}
|
||||
return &TurboQuantCache{
|
||||
meta: c,
|
||||
preset: preset,
|
||||
encodeResults: make(map[int]ml.Tensor),
|
||||
vEncodeResults: make(map[int]ml.Tensor),
|
||||
}, true
|
||||
|
||||
case *WrapperCache:
|
||||
// Mutate sub-caches in place: replace every non-SWA *Causal with a
|
||||
// *TurboQuantCache wrapping it. SWA sub-caches (SWACache, SWAMemCache)
|
||||
// are left untouched — they still allocate f16 K/V as before.
|
||||
wrapped := 0
|
||||
for i, sub := range c.caches {
|
||||
inner, ok := sub.(*Causal)
|
||||
if !ok || isSWACausal(inner) {
|
||||
continue
|
||||
}
|
||||
c.caches[i] = &TurboQuantCache{
|
||||
meta: inner,
|
||||
preset: preset,
|
||||
encodeResults: make(map[int]ml.Tensor),
|
||||
vEncodeResults: make(map[int]ml.Tensor),
|
||||
}
|
||||
wrapped++
|
||||
}
|
||||
if wrapped == 0 {
|
||||
slog.Warn("turboquant: no eligible Causal sub-caches in WrapperCache, falling back to unwrapped cache")
|
||||
return cache, false
|
||||
}
|
||||
slog.Info("turboquant: wrapped Causal sub-caches inside WrapperCache",
|
||||
"count", wrapped, "preset", preset.Name)
|
||||
return cache, true
|
||||
|
||||
case AttentionKVWrapper:
|
||||
inner := c.AttentionKV()
|
||||
if inner == nil {
|
||||
slog.Warn("turboquant: hybrid cache inner kv is not *Causal (already wrapped?), leaving as-is")
|
||||
return cache, false
|
||||
}
|
||||
if isSWACausal(inner) {
|
||||
slog.Warn("turboquant: hybrid cache inner *Causal is sliding-window, cannot wrap")
|
||||
return cache, false
|
||||
}
|
||||
c.SetAttentionKV(&TurboQuantCache{
|
||||
meta: inner,
|
||||
preset: preset,
|
||||
encodeResults: make(map[int]ml.Tensor),
|
||||
vEncodeResults: make(map[int]ml.Tensor),
|
||||
})
|
||||
slog.Info("turboquant: wrapped attention KV in hybrid recurrent cache",
|
||||
"preset", preset.Name)
|
||||
return cache, true
|
||||
|
||||
default:
|
||||
slog.Warn("turboquant: underlying cache is not *Causal or *WrapperCache, falling back to unwrapped cache")
|
||||
return cache, false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *TurboQuantCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
// K is always compressed; suppress inner Causal from allocating it.
|
||||
c.meta.SkipK = true
|
||||
// V is compressed only when ValueBits > 0 (tq2/tq3). K-only presets
|
||||
// (tq2k/tq3k) have ValueBits=0 and keep V in the f16 Causal cache.
|
||||
if c.preset.ValueBits > 0 {
|
||||
c.meta.SkipV = true
|
||||
}
|
||||
c.meta.Init(backend, ml.DTypeF16, maxSequences, capacity, maxBatch)
|
||||
slog.Info("turboquant cache initialized", "preset", c.preset.Name,
|
||||
"K_bits", c.preset.KeyPrimaryBits, "V_bits", c.preset.ValueBits)
|
||||
}
|
||||
|
||||
func (c *TurboQuantCache) Close() {
|
||||
if c.compressedK != nil {
|
||||
c.compressedK.Close()
|
||||
c.compressedK = nil
|
||||
}
|
||||
c.meta.Close()
|
||||
}
|
||||
|
||||
func (c *TurboQuantCache) SetLayer(layer int) { c.meta.SetLayer(layer) }
|
||||
func (c *TurboQuantCache) SetConfig(config ml.CacheConfig) { c.meta.SetConfig(config) }
|
||||
|
||||
func (c *TurboQuantCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
c.isReserve = reserve
|
||||
c.curQueryLen = len(batch.Positions)
|
||||
clear(c.encodeResults)
|
||||
clear(c.vEncodeResults)
|
||||
return c.meta.StartForward(ctx, batch, reserve)
|
||||
}
|
||||
|
||||
func (c *TurboQuantCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
// Capture headDim early — needed even during reserve to size placeholders.
|
||||
if c.headDim == 0 && key != nil {
|
||||
c.headDim = key.Dim(0)
|
||||
c.numKVHeads = key.Dim(1)
|
||||
}
|
||||
|
||||
// Activate the GPU encode path on the first Put (reserve or not) once we
|
||||
// know headDim/numKVHeads. Doing this during reserve lets EnsureLayer run
|
||||
// on the reserve pass, which books TQ's per-layer persistent K/V buffers
|
||||
// into btDeviceMemory.Cache[layer] (via newTensor's ctx.layer accounting
|
||||
// in EnsureLayer). Without this, the scheduler's fit/alloc probe sees a
|
||||
// 0-byte Cache for tq2/tq3 and only the f16 V half for tq2k/tq3k, which
|
||||
// is wrong for both headline-footprint reporting and long-context fit math.
|
||||
if !c.phase2Checked && c.headDim > 0 {
|
||||
c.phase2Checked = true
|
||||
c.activateGPUEncode()
|
||||
}
|
||||
|
||||
if c.isReserve {
|
||||
c.meta.Put(ctx, key, value)
|
||||
// Eagerly allocate TQ persistent buffers during reserve so the scheduler's
|
||||
// per-layer Cache totals reflect the real post-compression footprint. No
|
||||
// encode kernels run on reserve — we only need the tensors in the layer
|
||||
// context so newTensor's c.b.btDeviceMemory[...].Cache[layer] bump fires.
|
||||
if c.compressedK != nil {
|
||||
layer := c.meta.curLayer
|
||||
capacity := len(c.meta.cells)
|
||||
c.compressedK.EnsureLayer(layer, capacity)
|
||||
if c.preset.ValueBits > 0 {
|
||||
c.compressedK.EnsureVLayer(layer, capacity)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if c.compressedK != nil {
|
||||
layer := c.meta.curLayer
|
||||
capacity := len(c.meta.cells)
|
||||
|
||||
c.compressedK.EnsureLayer(layer, capacity)
|
||||
if c.preset.ValueBits > 0 {
|
||||
c.compressedK.EnsureVLayer(layer, capacity)
|
||||
}
|
||||
|
||||
// The TQ encode kernels (CUDA + Metal) write a contiguous run of
|
||||
// cells starting at firstCell. Causal.findLocs guarantees a
|
||||
// contiguous run when SkipK/SkipV is set (otherwise it returns
|
||||
// ErrKvCacheFull before we reach here); this loop is a defensive
|
||||
// invariant check that should never fire in practice.
|
||||
firstCell := 0
|
||||
if len(c.meta.curLocs) > 0 {
|
||||
firstCell = c.meta.curLocs[0]
|
||||
for i := 1; i < len(c.meta.curLocs); i++ {
|
||||
if c.meta.curLocs[i] != firstCell+i {
|
||||
panic(fmt.Sprintf("turboquant: non-contiguous cache slots %v — findLocs invariant violated", c.meta.curLocs))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if c.preset.ValueBits > 0 {
|
||||
// Combined K+V encode: single GGML op, two back-to-back kernels.
|
||||
kResult, vResult := c.compressedK.EncodeKV(ctx, layer, key, value, firstCell)
|
||||
if kResult != nil {
|
||||
ctx.Forward(kResult)
|
||||
c.encodeResults[layer] = kResult
|
||||
c.vEncodeResults[layer] = vResult
|
||||
}
|
||||
} else {
|
||||
// K-only presets (tq2k/tq3k): V stays as f16 in the Causal cache.
|
||||
encodeResult := c.compressedK.EncodeK(ctx, layer, key, firstCell)
|
||||
if encodeResult != nil {
|
||||
ctx.Forward(encodeResult)
|
||||
c.encodeResults[layer] = encodeResult
|
||||
}
|
||||
}
|
||||
|
||||
// Inner Causal.Put() tracks cell metadata (positions, masks).
|
||||
// SkipK and SkipV suppress the actual K/V tensor writes.
|
||||
c.meta.Put(ctx, key, value)
|
||||
return
|
||||
}
|
||||
|
||||
c.meta.Put(ctx, key, value)
|
||||
}
|
||||
|
||||
func (c *TurboQuantCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
if c.isReserve {
|
||||
key, value, mask := c.meta.Get(ctx)
|
||||
// SkipK: synthesize zero f16 K placeholder for graph sizing.
|
||||
if key == nil && c.headDim > 0 {
|
||||
nCells := 1
|
||||
if c.meta.curMask != nil {
|
||||
nCells = c.meta.curMask.Dim(0)
|
||||
}
|
||||
key = ctx.Input().Zeros(ml.DTypeF16, c.headDim, c.numKVHeads, nCells)
|
||||
}
|
||||
// SkipV: synthesize zero f16 V placeholder for graph sizing.
|
||||
if value == nil && c.headDim > 0 {
|
||||
nCells := 1
|
||||
if c.meta.curMask != nil {
|
||||
nCells = c.meta.curMask.Dim(0)
|
||||
}
|
||||
value = ctx.Input().Zeros(ml.DTypeF16, c.headDim, c.numKVHeads, nCells)
|
||||
}
|
||||
return key, value, mask
|
||||
}
|
||||
|
||||
if c.compressedK != nil {
|
||||
layer := c.meta.curLayer
|
||||
firstCell := c.meta.curCellRange.min
|
||||
nCells := c.meta.curMask.Dim(0)
|
||||
|
||||
encodeResult := c.encodeResults[layer]
|
||||
vEncodeResult := c.vEncodeResults[layer]
|
||||
|
||||
// 0. K-only presets (tq2k/tq3k): DequantK + f16 V from Causal → stock FA.
|
||||
// Skips the fused FA kernel (slower than stock FA on Pascal).
|
||||
if c.preset.ValueBits == 0 && encodeResult != nil {
|
||||
key := c.compressedK.DequantK(ctx, layer, encodeResult, firstCell, nCells)
|
||||
if key != nil {
|
||||
c.logPathOnce[0].Do(func() {
|
||||
slog.Info("turboquant: using K-only DequantK + f16 V path")
|
||||
})
|
||||
_, value, mask := c.meta.Get(ctx)
|
||||
c.armRotationForNextSDPA()
|
||||
return key, value, mask
|
||||
}
|
||||
}
|
||||
|
||||
// Prefill vs decode routing on Metal: the fused inline-decode kernel
|
||||
// re-decodes each packed K/V cell per Q-token, so its cost scales
|
||||
// O(nCells × nTokensQ). For prefill (nTokensQ ≫ 1) DequantKV + stock
|
||||
// FA wins by decoding each cell once and then letting stock FA
|
||||
// amortise the read across all Q-tokens. For decode (nTokensQ=1) the
|
||||
// fused path wins by avoiding the f16 intermediate write. This flag
|
||||
// routes independently of preferFusedAttn so CUDA (preferFusedAttn=
|
||||
// false) continues to take path 1 for everything.
|
||||
useDequantKVForPrefill := c.preferFusedAttn && c.curQueryLen > 1
|
||||
|
||||
// 1. Combined K+V dequant → stock FA.
|
||||
// * CUDA/ROCm: always (preferFusedAttn=false).
|
||||
// * Metal: prefill only (batched Q). Metal decode drops to path 2.
|
||||
if vEncodeResult != nil && (!c.preferFusedAttn || useDequantKVForPrefill) {
|
||||
key, value := c.compressedK.DequantKV(ctx, layer, encodeResult, vEncodeResult, firstCell, nCells)
|
||||
if key != nil && value != nil {
|
||||
c.logPathOnce[1].Do(func() {
|
||||
if useDequantKVForPrefill {
|
||||
slog.Info("turboquant: using combined DequantKV + stock FA path (Metal prefill)")
|
||||
} else {
|
||||
slog.Info("turboquant: using combined DequantKV + stock FA path")
|
||||
}
|
||||
})
|
||||
_, _, mask := c.meta.Get(ctx)
|
||||
c.armRotationForNextSDPA()
|
||||
return key, value, mask
|
||||
}
|
||||
}
|
||||
|
||||
// 2. K+V fused inline-decode: reads packed K+V directly, no f16 intermediate.
|
||||
// Primary path on Metal decode (Q=1); fallback on CUDA/ROCm.
|
||||
// Instantiated at D=128 always; D=256 on Metal only.
|
||||
if vEncodeResult != nil && c.fusedFallbackEligible {
|
||||
if tqkv, ok := c.compressedK.GetAsTQTensorKV(ctx, layer, encodeResult, vEncodeResult, firstCell, nCells); ok {
|
||||
c.logPathOnce[2].Do(func() {
|
||||
if c.preferFusedAttn {
|
||||
slog.Info("turboquant: using K+V fused inline-decode path (Metal decode: avoids f16 intermediate)")
|
||||
} else {
|
||||
slog.Warn("turboquant: falling back to K+V inline-decode fused kernel (slower)")
|
||||
}
|
||||
})
|
||||
_, _, mask := c.meta.Get(ctx)
|
||||
c.armRotationForNextSDPA()
|
||||
return tqkv, nil, mask
|
||||
}
|
||||
}
|
||||
|
||||
// 1b. DequantKV fallback when Metal decode tried path 2 and the fused
|
||||
// path was unavailable (e.g. headDim outside 128/256).
|
||||
if vEncodeResult != nil && c.preferFusedAttn {
|
||||
key, value := c.compressedK.DequantKV(ctx, layer, encodeResult, vEncodeResult, firstCell, nCells)
|
||||
if key != nil && value != nil {
|
||||
c.logPathOnce[1].Do(func() {
|
||||
slog.Info("turboquant: using combined DequantKV + stock FA path (fused unavailable)")
|
||||
})
|
||||
_, _, mask := c.meta.Get(ctx)
|
||||
c.armRotationForNextSDPA()
|
||||
return key, value, mask
|
||||
}
|
||||
}
|
||||
|
||||
// 3. V-only dequant for K-only fused or separate K dequant fallback.
|
||||
var value ml.Tensor
|
||||
if vEncodeResult != nil {
|
||||
value = c.compressedK.DequantV(ctx, layer, vEncodeResult, firstCell, nCells)
|
||||
}
|
||||
|
||||
// 4. Try K-only fused: K decoded inline, V is dequanted f16. Gated on
|
||||
// fusedFallbackEligible for the same D=128 reason as path 2.
|
||||
if c.fusedFallbackEligible {
|
||||
if tqk, ok := c.compressedK.GetAsTQTensor(ctx, layer, encodeResult, firstCell, nCells); ok {
|
||||
c.logPathOnce[3].Do(func() {
|
||||
slog.Warn("turboquant: falling back to K-only inline-decode fused kernel")
|
||||
})
|
||||
_, metaValue, mask := c.meta.Get(ctx)
|
||||
if value == nil {
|
||||
value = metaValue
|
||||
}
|
||||
c.armRotationForNextSDPA()
|
||||
return tqk, value, mask
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Separate K + V dequant fallback (last resort).
|
||||
c.logPathOnce[4].Do(func() {
|
||||
slog.Warn("turboquant: falling back to separate K + V dequant path")
|
||||
})
|
||||
key := c.compressedK.DequantK(ctx, layer, encodeResult, firstCell, nCells)
|
||||
_, metaValue, mask := c.meta.Get(ctx)
|
||||
if value == nil {
|
||||
value = metaValue
|
||||
}
|
||||
c.armRotationForNextSDPA()
|
||||
return key, value, mask
|
||||
}
|
||||
|
||||
return c.meta.Get(ctx)
|
||||
}
|
||||
|
||||
// armRotationForNextSDPA sets the backend's tqRotationMatrix (and
|
||||
// tqVRotationMatrix for K+V presets) so the next SDPA call — which happens
|
||||
// immediately after this Get returns — rotates Q to match the TQ-rotated K
|
||||
// and applies the V rotation undo after attention. SDPA reads-and-clears the
|
||||
// flags, so non-TQ sub-cache layers in a WrapperCache (e.g. gemma3 SWA
|
||||
// layers) are unaffected.
|
||||
func (c *TurboQuantCache) armRotationForNextSDPA() {
|
||||
if c.rotMatrix == nil || c.rotSetter == nil {
|
||||
return
|
||||
}
|
||||
c.rotSetter.SetTQRotationMatrix(c.rotMatrix)
|
||||
// For K+V presets (tq3/tq2), also arm the V rotation undo on the next
|
||||
// SDPA call. For K-only presets (tq3k/tq2k), c.vRotMatrix is nil so the
|
||||
// V rotation is not armed and SDPA's consumed vRot stays nil.
|
||||
if c.vRotMatrix != nil {
|
||||
c.rotSetter.SetTQVRotationMatrix(c.vRotMatrix)
|
||||
}
|
||||
}
|
||||
|
||||
// activateGPUEncode initialises the TQ compressed-K manager if the backend
|
||||
// supports it and re-enables Q rotation (stored K is in rotated space).
|
||||
func (c *TurboQuantCache) activateGPUEncode() {
|
||||
// fallbackToF16 un-skips K/V on the inner Causal so subsequent Put/Get
|
||||
// on this cache store and read f16 tensors like an ordinary non-TQ cache.
|
||||
// Init() sets SkipK/SkipV unconditionally, so any failure to activate GPU
|
||||
// encode must reverse those flags before the current Put() continues into
|
||||
// c.meta.Put — otherwise SDPA receives a nil K/V from Get and segfaults.
|
||||
// K allocation in Causal.Put is lazy (keyed on presence of c.keys[layer])
|
||||
// so flipping the flag on the first Put is sufficient.
|
||||
fallbackToF16 := func() {
|
||||
c.meta.SkipK = false
|
||||
c.meta.SkipV = false
|
||||
}
|
||||
|
||||
tqb, ok := c.meta.backend.(ml.TQCompressedKBackend)
|
||||
if !ok {
|
||||
fallbackToF16()
|
||||
return
|
||||
}
|
||||
|
||||
// Pass the preset's outlier config so the manager can enable post-rotation
|
||||
// outlier split on the GPU encode path. This is required for correct
|
||||
// output on models with learned K bias (e.g. qwen2 family) and matches
|
||||
// the TurboQuant paper's validated experimental setup.
|
||||
mgr := tqb.NewTQCompressedKManager(
|
||||
c.headDim, c.numKVHeads, c.preset.KeyPrimaryBits, c.preset.RotationSeed,
|
||||
c.preset.ValueBits, c.preset.OutlierBits, c.preset.OutlierCount,
|
||||
)
|
||||
if mgr == nil {
|
||||
slog.Info("turboquant: GPU encode not available, using f16 K fallback")
|
||||
fallbackToF16()
|
||||
return
|
||||
}
|
||||
c.compressedK = mgr
|
||||
|
||||
type fusedAttnPreferrer interface {
|
||||
PreferFusedAttention() bool
|
||||
}
|
||||
if ff, ok := mgr.(fusedAttnPreferrer); ok {
|
||||
c.preferFusedAttn = ff.PreferFusedAttention()
|
||||
if c.preferFusedAttn {
|
||||
slog.Info("turboquant: preferring fused flash-attention path (Metal: avoids f16 intermediate buffer)")
|
||||
}
|
||||
}
|
||||
|
||||
// The inline-decode fused-FA fallback paths (Get paths 2 and 4) dispatch
|
||||
// to a kernel that is D-specialised. CUDA has only D=128 today; Metal has
|
||||
// D=128 and D=256 (kernel_tq_fattn_vec_*{,_d256}). Models with an
|
||||
// unsupported head dim (e.g. gemma4 D=512, or gemma3 D=256 on CUDA) must
|
||||
// skip these fallbacks to avoid a kernel-side GGML_ASSERT; path 5
|
||||
// (separate K+V dequant) handles them correctly.
|
||||
c.fusedFallbackEligible = c.headDim == 128 ||
|
||||
(c.headDim == 256 && c.preferFusedAttn)
|
||||
if !c.fusedFallbackEligible {
|
||||
reason := "headDim != 128"
|
||||
if c.headDim == 256 {
|
||||
reason = "headDim == 256 but backend lacks D=256 fused kernel"
|
||||
}
|
||||
slog.Info("turboquant: inline-decode fused-FA fallback paths disabled",
|
||||
"reason", reason, "headDim", c.headDim)
|
||||
}
|
||||
|
||||
// Cache the rotation matrices and the backend's rotation-setter hook on
|
||||
// TurboQuantCache so Get() can arm them per-call without re-running a
|
||||
// type assertion every layer. We do NOT set them at activate time — a
|
||||
// sticky backend-global rotation would corrupt attention on unwrapped
|
||||
// SWA layers in mixed-head-dim models like gemma3.
|
||||
c.rotMatrix = mgr.RotationMatrix(nil, 0)
|
||||
if rs, ok := c.meta.backend.(tqRotSetter); ok {
|
||||
c.rotSetter = rs
|
||||
}
|
||||
|
||||
type vRotFusedSetter interface {
|
||||
SetTQVRotFusedInDequant(bool)
|
||||
}
|
||||
type vRotProvider interface {
|
||||
RotationMatrixR() ml.Tensor
|
||||
}
|
||||
if c.preset.ValueBits > 0 {
|
||||
if vp, ok := mgr.(vRotProvider); ok {
|
||||
c.vRotMatrix = vp.RotationMatrixR()
|
||||
}
|
||||
// DequantKV outputs R^T @ v (still rotated). SDPA applies the rotation
|
||||
// undo as R @ attn_out via mulmat, which is dramatically faster than
|
||||
// the per-cell matmul of the fused dequant kernel.
|
||||
if rs, ok := c.meta.backend.(vRotFusedSetter); ok {
|
||||
rs.SetTQVRotFusedInDequant(false)
|
||||
}
|
||||
}
|
||||
|
||||
slog.Info("turboquant: GPU-native encode active",
|
||||
"headDim", c.headDim, "numKVHeads", c.numKVHeads,
|
||||
"K_bits", c.preset.KeyPrimaryBits, "V_bits", c.preset.ValueBits)
|
||||
}
|
||||
|
||||
func (c *TurboQuantCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
|
||||
c.meta.CopyPrefix(srcSeq, dstSeq, prefixLen)
|
||||
}
|
||||
|
||||
func (c *TurboQuantCache) CanResume(seq int, pos int32) bool {
|
||||
return c.meta.CanResume(seq, pos)
|
||||
}
|
||||
|
||||
// Remove returns ErrNotSupported for any partial eviction when GPU-compressed
|
||||
// K is active: the compressed buffer cannot be RoPE-shifted in-place, and the
|
||||
// inner Causal.Remove path silently skips its shiftFn loop because SkipK
|
||||
// leaves c.keys empty — survivors would keep their old RoPE embeddings while
|
||||
// their positions are decremented, producing wrong attention scores. Only a
|
||||
// full-sequence eviction (Remove(seq, 0, MaxInt32)) avoids the shift path;
|
||||
// other callers must fall back to full reprocessing via ErrReprocessInputs.
|
||||
func (c *TurboQuantCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
if c.compressedK != nil && !(beginIndex == 0 && endIndex == math.MaxInt32) {
|
||||
return ErrNotSupported
|
||||
}
|
||||
return c.meta.Remove(seq, beginIndex, endIndex)
|
||||
}
|
||||
|
||||
func (c *TurboQuantCache) SetCausal(ctx ml.Context, opts CausalOptions) {
|
||||
c.meta.SetCausal(ctx, opts)
|
||||
}
|
||||
|
||||
func PresetFromDType(dtype ml.DType) (turboquant.Preset, bool) {
|
||||
switch dtype {
|
||||
case ml.DTypeTQ2:
|
||||
return turboquant.PresetTQ2, true
|
||||
case ml.DTypeTQ3:
|
||||
return turboquant.PresetTQ3, true
|
||||
case ml.DTypeTQ3K:
|
||||
return turboquant.PresetTQ3K, true
|
||||
case ml.DTypeTQ2K:
|
||||
return turboquant.PresetTQ2K, true
|
||||
default:
|
||||
return turboquant.Preset{}, false
|
||||
}
|
||||
}
|
||||
|
||||
var _ Cache = (*TurboQuantCache)(nil)
|
||||
|
||||
// Note: TurboQuantCache intentionally does NOT implement CheckpointCache.
|
||||
// CheckpointCache is for recurrent caches that need special per-sequence
|
||||
// state restoration. The inner *Causal cache supports plain CanResume, which
|
||||
// is the right semantics for TQ-wrapped caches too — the runner will fall
|
||||
// through to the CanResume branch when TurboQuantCache is in use (see
|
||||
// runner/ollamarunner/cache.go:163-170). An earlier implementation had a
|
||||
// stub PrepareRestore that always returned (0, false), which forced a full
|
||||
// prompt reprocess on every resume for long-context gemma/llama runs.
|
||||
144
kvcache/turboquant_test.go
Normal file
144
kvcache/turboquant_test.go
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
package kvcache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/turboquant"
|
||||
)
|
||||
|
||||
// TestWrapWithTurboQuantWrapperCache verifies Path C: when wrapping a
|
||||
// WrapperCache that contains a SWA Causal + a plain Causal, TurboQuant
|
||||
// replaces only the plain Causal sub-cache and leaves the SWA sub-cache
|
||||
// untouched. This is the gemma3/gemma4 pattern.
|
||||
func TestWrapWithTurboQuantWrapperCache(t *testing.T) {
|
||||
swa := NewSWAMemCache(1024, 4096, nil)
|
||||
global := NewCausalCache(nil)
|
||||
wc := NewWrapperCache(swa, global)
|
||||
|
||||
wrapped, active := WrapWithTurboQuant(wc, turboquant.PresetTQ3K)
|
||||
|
||||
if !active {
|
||||
t.Fatalf("expected active=true when wrapping a WrapperCache with a plain Causal sub-cache")
|
||||
}
|
||||
// The same WrapperCache pointer should be returned (mutated in place).
|
||||
if wrapped != wc {
|
||||
t.Fatalf("expected the same WrapperCache pointer to be returned")
|
||||
}
|
||||
if len(wc.caches) != 2 {
|
||||
t.Fatalf("expected 2 sub-caches, got %d", len(wc.caches))
|
||||
}
|
||||
// SWA sub-cache untouched.
|
||||
if swaSub, ok := wc.caches[0].(*Causal); !ok || swaSub != swa {
|
||||
t.Fatalf("expected wc.caches[0] to remain the original SWA Causal, got %T", wc.caches[0])
|
||||
}
|
||||
// Global sub-cache replaced with *TurboQuantCache wrapping the original.
|
||||
tqc, ok := wc.caches[1].(*TurboQuantCache)
|
||||
if !ok {
|
||||
t.Fatalf("expected wc.caches[1] to be *TurboQuantCache, got %T", wc.caches[1])
|
||||
}
|
||||
if tqc.meta != global {
|
||||
t.Fatalf("expected TurboQuantCache.meta to point at the original inner Causal")
|
||||
}
|
||||
if tqc.preset.Name != turboquant.PresetTQ3K.Name {
|
||||
t.Fatalf("expected preset %q, got %q", turboquant.PresetTQ3K.Name, tqc.preset.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrapWithTurboQuantSWAOnly verifies that wrapping a WrapperCache
|
||||
// containing only SWA sub-caches yields active=false and leaves the
|
||||
// cache unchanged.
|
||||
func TestWrapWithTurboQuantSWAOnly(t *testing.T) {
|
||||
swa := NewSWAMemCache(1024, 4096, nil)
|
||||
wc := NewWrapperCache(swa)
|
||||
|
||||
wrapped, active := WrapWithTurboQuant(wc, turboquant.PresetTQ3K)
|
||||
|
||||
if active {
|
||||
t.Fatalf("expected active=false for a SWA-only WrapperCache")
|
||||
}
|
||||
if wrapped != wc {
|
||||
t.Fatalf("expected the same WrapperCache pointer to be returned")
|
||||
}
|
||||
// Sub-cache must remain the original SWA Causal.
|
||||
if sub, ok := wc.caches[0].(*Causal); !ok || sub != swa {
|
||||
t.Fatalf("expected wc.caches[0] unchanged, got %T", wc.caches[0])
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrapWithTurboQuantTopLevelSWA verifies that a top-level SWA Causal
|
||||
// cannot be wrapped — TurboQuant needs full-context Causal semantics.
|
||||
func TestWrapWithTurboQuantTopLevelSWA(t *testing.T) {
|
||||
swa := NewSWAMemCache(1024, 4096, nil)
|
||||
|
||||
wrapped, active := WrapWithTurboQuant(swa, turboquant.PresetTQ3K)
|
||||
|
||||
if active {
|
||||
t.Fatalf("expected active=false for a top-level SWA Causal")
|
||||
}
|
||||
if wrapped != swa {
|
||||
t.Fatalf("expected the SWA Causal to be returned unchanged")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrapWithTurboQuantTopLevelCausal verifies the existing top-level
|
||||
// non-SWA Causal case still works: a new *TurboQuantCache is returned
|
||||
// wrapping the input.
|
||||
func TestWrapWithTurboQuantTopLevelCausal(t *testing.T) {
|
||||
c := NewCausalCache(nil)
|
||||
|
||||
wrapped, active := WrapWithTurboQuant(c, turboquant.PresetTQ3K)
|
||||
|
||||
if !active {
|
||||
t.Fatalf("expected active=true for a top-level plain Causal")
|
||||
}
|
||||
tqc, ok := wrapped.(*TurboQuantCache)
|
||||
if !ok {
|
||||
t.Fatalf("expected *TurboQuantCache, got %T", wrapped)
|
||||
}
|
||||
if tqc.meta != c {
|
||||
t.Fatalf("expected TurboQuantCache.meta to point at the input Causal")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrapWithTurboQuantHybridCache verifies that a cache embedding *Recurrent
|
||||
// (matching the model/models/*/HybridCache shape) has its inner *Causal swapped
|
||||
// for a *TurboQuantCache in place, and the outer pointer is returned unchanged.
|
||||
func TestWrapWithTurboQuantHybridCache(t *testing.T) {
|
||||
type HybridCache struct{ *Recurrent }
|
||||
r := NewRecurrentCache(RecurrentConfig{ConvDim: 4, ConvChannels: 2, RecurrentStateSize: 4})
|
||||
hc := &HybridCache{Recurrent: r}
|
||||
|
||||
wrapped, active := WrapWithTurboQuant(hc, turboquant.PresetTQ3K)
|
||||
|
||||
if !active {
|
||||
t.Fatal("expected active=true for AttentionKVWrapper")
|
||||
}
|
||||
if wrapped != hc {
|
||||
t.Fatalf("expected same pointer returned, got %T", wrapped)
|
||||
}
|
||||
tqc, ok := r.kv.(*TurboQuantCache)
|
||||
if !ok {
|
||||
t.Fatalf("expected inner kv to be *TurboQuantCache, got %T", r.kv)
|
||||
}
|
||||
if tqc.preset.Name != turboquant.PresetTQ3K.Name {
|
||||
t.Fatalf("preset mismatch: got %q", tqc.preset.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrapWithTurboQuantHybridCacheIdempotent verifies that a second wrap attempt
|
||||
// on an already-wrapped hybrid cache returns (cache, false) rather than double-wrapping.
|
||||
func TestWrapWithTurboQuantHybridCacheIdempotent(t *testing.T) {
|
||||
type HybridCache struct{ *Recurrent }
|
||||
r := NewRecurrentCache(RecurrentConfig{ConvDim: 4, ConvChannels: 2, RecurrentStateSize: 4})
|
||||
hc := &HybridCache{Recurrent: r}
|
||||
|
||||
_, _ = WrapWithTurboQuant(hc, turboquant.PresetTQ3K)
|
||||
wrapped, active := WrapWithTurboQuant(hc, turboquant.PresetTQ3K)
|
||||
|
||||
if active {
|
||||
t.Fatal("expected active=false on second wrap (already wrapped)")
|
||||
}
|
||||
if wrapped != hc {
|
||||
t.Fatalf("expected same pointer returned, got %T", wrapped)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,98 @@
|
|||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Michael Verrilli <msv@pobox.com>
|
||||
Date: Sun, 19 Apr 2026 23:28:06 -0400
|
||||
Subject: [PATCH] =?UTF-8?q?turboquant:=20fix=20HIP=20compile=20=E2=80=94?=
|
||||
=?UTF-8?q?=20drop=20cuda=5Ffp16.h,=20pass=20=5F=5Fshfl=5Fsync=20width?=
|
||||
MIME-Version: 1.0
|
||||
Content-Type: text/plain; charset=UTF-8
|
||||
Content-Transfer-Encoding: 8bit
|
||||
|
||||
Two minimal, behavior-neutral edits so the TQ kernels compile under
|
||||
hipcc on the ROCm 7 preset:
|
||||
|
||||
1. Remove direct `#include <cuda_fp16.h>` from tq-encode.cu,
|
||||
tq-encode-v.cu, and tq-dequant.cu. Header is transitively pulled
|
||||
in via common.cuh → vendors/{cuda,hip,musa}.h; direct include
|
||||
is redundant on CUDA and fatal on HIP.
|
||||
|
||||
2. Pass width=32 explicitly on four __shfl_sync sites in tq-dequant.cu.
|
||||
HIP's shim is a 4-arg macro with no defaults; CUDA's 3-arg overload
|
||||
defaults to width=warpSize=32, so passing 32 is a compile fix on
|
||||
HIP and a no-op on CUDA.
|
||||
---
|
||||
ggml/src/ggml-cuda/tq-dequant.cu | 14 ++++++++------
|
||||
ggml/src/ggml-cuda/tq-encode-v.cu | 1 -
|
||||
ggml/src/ggml-cuda/tq-encode.cu | 1 -
|
||||
3 files changed, 8 insertions(+), 8 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/tq-dequant.cu b/ggml/src/ggml-cuda/tq-dequant.cu
|
||||
index 03ed64543..4a4cc4d64 100644
|
||||
--- a/ggml/src/ggml-cuda/tq-dequant.cu
|
||||
+++ b/ggml/src/ggml-cuda/tq-dequant.cu
|
||||
@@ -1,5 +1,4 @@
|
||||
#include "tq-dequant.cuh"
|
||||
-#include <cuda_fp16.h>
|
||||
|
||||
// Optimized TQ dequant kernel: warp-shuffle codebook + hardcoded bit extraction.
|
||||
//
|
||||
@@ -52,8 +51,11 @@ __global__ void tq_dequant_multihead_kernel(
|
||||
}
|
||||
|
||||
// Codebook lookup via warp shuffle: zero global memory latency.
|
||||
- // Width = warpSize (32) works because cb_lane is periodic with period (1<<bits).
|
||||
- float val = __shfl_sync(0xFFFFFFFF, cb_lane, idx) * scale;
|
||||
+ // Width = 32 works because cb_lane is periodic with period (1<<bits).
|
||||
+ // Pass width explicitly — the HIP __shfl_sync shim is a 4-arg macro
|
||||
+ // that doesn't default, and CUDA's 3-arg overload is width=warpSize=32
|
||||
+ // anyway, so this is behavior-neutral on NVIDIA.
|
||||
+ float val = __shfl_sync(0xFFFFFFFF, cb_lane, idx, 32) * scale;
|
||||
cell_out[elem] = __float2half_rn(val);
|
||||
}
|
||||
}
|
||||
@@ -180,7 +182,7 @@ __global__ void tq_dequant_multihead_kernel_outlier(
|
||||
if (reg_shift + bits > 8) {
|
||||
reg_idx |= (cell_reg[reg_byte_idx + 1] << (8 - reg_shift)) & cb_mask;
|
||||
}
|
||||
- float reg_val = __shfl_sync(0xFFFFFFFF, cb_lane_reg, reg_idx) * regScale;
|
||||
+ float reg_val = __shfl_sync(0xFFFFFFFF, cb_lane_reg, reg_idx, 32) * regScale;
|
||||
|
||||
int out_slot_safe = (outlier_slot >= 0) ? outlier_slot : 0;
|
||||
int out_bit_offset = out_slot_safe * outlier_bits;
|
||||
@@ -190,7 +192,7 @@ __global__ void tq_dequant_multihead_kernel_outlier(
|
||||
if (out_shift + outlier_bits > 8) {
|
||||
out_idx |= (cell_outl[out_byte_idx + 1] << (8 - out_shift)) & ocb_mask;
|
||||
}
|
||||
- float out_val = __shfl_sync(0xFFFFFFFF, cb_lane_out, out_idx) * outScale;
|
||||
+ float out_val = __shfl_sync(0xFFFFFFFF, cb_lane_out, out_idx, 32) * outScale;
|
||||
|
||||
float val = (outlier_slot >= 0) ? out_val : reg_val;
|
||||
cell_out[elem] = __float2half_rn(val);
|
||||
@@ -308,7 +310,7 @@ __global__ void tq_dequant_v_rotated_kernel(
|
||||
if (shift + bits > 8) {
|
||||
idx |= (cell_packed[byte_idx + 1] << (8 - shift)) & cb_mask;
|
||||
}
|
||||
- s_rotV[elem] = __shfl_sync(0xFFFFFFFF, cb_lane, idx) * scale;
|
||||
+ s_rotV[elem] = __shfl_sync(0xFFFFFFFF, cb_lane, idx, 32) * scale;
|
||||
__syncthreads();
|
||||
|
||||
// Phase 2: each thread computes one output element = dot(R[elem,:], s_rotV).
|
||||
diff --git a/ggml/src/ggml-cuda/tq-encode-v.cu b/ggml/src/ggml-cuda/tq-encode-v.cu
|
||||
index f9673861d..09518dccd 100644
|
||||
--- a/ggml/src/ggml-cuda/tq-encode-v.cu
|
||||
+++ b/ggml/src/ggml-cuda/tq-encode-v.cu
|
||||
@@ -1,5 +1,4 @@
|
||||
#include "tq-encode-v.cuh"
|
||||
-#include <cuda_fp16.h>
|
||||
#include <math.h>
|
||||
|
||||
#define TQ_ENCODE_V_BLOCK_SIZE 128
|
||||
diff --git a/ggml/src/ggml-cuda/tq-encode.cu b/ggml/src/ggml-cuda/tq-encode.cu
|
||||
index 12e7f7ba7..9c9eafa39 100644
|
||||
--- a/ggml/src/ggml-cuda/tq-encode.cu
|
||||
+++ b/ggml/src/ggml-cuda/tq-encode.cu
|
||||
@@ -1,5 +1,4 @@
|
||||
#include "tq-encode.cuh"
|
||||
-#include <cuda_fp16.h>
|
||||
#include <math.h>
|
||||
|
||||
#define TQ_ENCODE_BLOCK_SIZE 128
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,667 @@
|
|||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Michael Verrilli <msv@pobox.com>
|
||||
Date: Tue, 21 Apr 2026 21:02:21 +0000
|
||||
Subject: [PATCH] ml/backend/ggml: add D=256 TurboQuant flash-attention kernels
|
||||
on Metal
|
||||
MIME-Version: 1.0
|
||||
Content-Type: text/plain; charset=UTF-8
|
||||
Content-Transfer-Encoding: 8bit
|
||||
|
||||
The fused inline-decode kernels were only instantiated at D=128. gemma3
|
||||
runs at headDim=256 and was therefore forced onto the DequantKV + stock
|
||||
FA path — which materialises a ~3.5 GB f16 intermediate K+V tensor at
|
||||
32k context and OOMs on Apple Silicon with tq3.
|
||||
|
||||
Adds D=256 variants of the two fused kernels:
|
||||
|
||||
* kernel_tq_fattn_vec_f16_d256
|
||||
* kernel_tq_fattn_vec_packed_d256
|
||||
|
||||
Thread layout is unchanged (32x4 = 128 threads); each thread now covers
|
||||
32 D-positions in the Q and K loops, accumulates four V-passes, and
|
||||
writes two output elements in the final phase. Shared memory grows
|
||||
from 2048 to 4096 floats (16 KiB, well under the 32 KiB per-threadgroup
|
||||
limit).
|
||||
|
||||
The dispatch picks D=128 or D=256 based on Q->ne[0]. Go-side eligibility
|
||||
extends to headDim==256 only when preferFusedAttention is true (Metal);
|
||||
CUDA continues to gate on headDim==128 until the CUDA kernel gains the
|
||||
same instantiation.
|
||||
---
|
||||
ggml/src/ggml-metal/ggml-metal-device.cpp | 6 +-
|
||||
ggml/src/ggml-metal/ggml-metal-device.h | 6 +-
|
||||
ggml/src/ggml-metal/ggml-metal-ops.cpp | 11 +-
|
||||
ggml/src/ggml-metal/ggml-metal.metal | 562 ++++++++++++++++++++++
|
||||
4 files changed, 579 insertions(+), 6 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
|
||||
index 2686d2d30..c99fafe00 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
|
||||
@@ -1720,5 +1720,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_dequan
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_encode"); }
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_v (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_encode_v"); }
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_outlier(ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_encode_outlier"); }
|
||||
-struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16 (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_f16"); }
|
||||
-struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed(ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_packed"); }
|
||||
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16 (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_f16"); }
|
||||
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_packed"); }
|
||||
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16_d256 (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_f16_d256"); }
|
||||
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed_d256(ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_packed_d256"); }
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h
|
||||
index aadd82659..fb45cbbfd 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-device.h
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-device.h
|
||||
@@ -194,8 +194,10 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_dequan
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode (ggml_metal_library_t lib);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_v (ggml_metal_library_t lib);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_outlier(ggml_metal_library_t lib);
|
||||
-struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16 (ggml_metal_library_t lib);
|
||||
-struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed(ggml_metal_library_t lib);
|
||||
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16 (ggml_metal_library_t lib);
|
||||
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed (ggml_metal_library_t lib);
|
||||
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16_d256 (ggml_metal_library_t lib);
|
||||
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed_d256(ggml_metal_library_t lib);
|
||||
|
||||
// MTLResidencySet wrapper
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
index 7abc0eaac..b5ab1c14e 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
@@ -4602,9 +4602,16 @@ int ggml_metal_op_tq_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb31 =*/ mask ? mask->nb[1] : 0,
|
||||
};
|
||||
|
||||
+ // Select D=128 vs D=256 pipeline. Gemma3 runs at headDim=256; everything
|
||||
+ // else supported so far is D=128.
|
||||
+ GGML_ASSERT(D == 128 || D == 256);
|
||||
auto pipeline = v_packed
|
||||
- ? ggml_metal_library_get_pipeline_tq_fattn_vec_packed(lib)
|
||||
- : ggml_metal_library_get_pipeline_tq_fattn_vec_f16(lib);
|
||||
+ ? (D == 256
|
||||
+ ? ggml_metal_library_get_pipeline_tq_fattn_vec_packed_d256(lib)
|
||||
+ : ggml_metal_library_get_pipeline_tq_fattn_vec_packed(lib))
|
||||
+ : (D == 256
|
||||
+ ? ggml_metal_library_get_pipeline_tq_fattn_vec_f16_d256(lib)
|
||||
+ : ggml_metal_library_get_pipeline_tq_fattn_vec_f16(lib));
|
||||
|
||||
ggml_metal_buffer_id bid_mask = hasMask ? ggml_metal_get_buffer_id(mask) : ggml_metal_get_buffer_id(op);
|
||||
ggml_metal_buffer_id bid_v_scales = v_packed ? ggml_metal_get_buffer_id(op->src[6]) : ggml_metal_get_buffer_id(op);
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
index d0bc6bf9e..2718e8bb1 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
@@ -11184,6 +11184,293 @@ kernel void kernel_tq_fattn_vec_f16(
|
||||
}
|
||||
}
|
||||
|
||||
+// ─────────────────────────────────────────────────────────────────────────────
|
||||
+// kernel_tq_fattn_vec_f16_d256
|
||||
+// TQ fused flash-attention at head dim 256: K packed i8, V f16.
|
||||
+// Thread layout identical to the D=128 variant (32×4 = 128 threads) but each
|
||||
+// thread now produces 2 output elements (D/nthreads = 2) and covers twice as
|
||||
+// many D positions in the Q/K/V loops.
|
||||
+// Grid: (ntiles_x, 1, nHeadsQ*nSeq)
|
||||
+// ─────────────────────────────────────────────────────────────────────────────
|
||||
+kernel void kernel_tq_fattn_vec_f16_d256(
|
||||
+ constant ggml_metal_kargs_tq_fattn_vec & args,
|
||||
+ device const char * Q_data [[buffer(1)]],
|
||||
+ device const uint8_t * K_packed [[buffer(2)]],
|
||||
+ device const half * V_data [[buffer(3)]],
|
||||
+ device const half * mask_data [[buffer(4)]],
|
||||
+ device const float * K_scales [[buffer(5)]],
|
||||
+ device const float * K_cb [[buffer(6)]],
|
||||
+ device const float * dummy_vs [[buffer(7)]],
|
||||
+ device const float * dummy_vc [[buffer(8)]],
|
||||
+ device float * dst [[buffer(9)]],
|
||||
+ uint3 tgpig [[threadgroup_position_in_grid]],
|
||||
+ uint tiisg [[thread_index_in_simdgroup]],
|
||||
+ uint sgitg [[simdgroup_index_in_threadgroup]])
|
||||
+{
|
||||
+ constexpr int D = 256;
|
||||
+ constexpr int nthreads = 128;
|
||||
+ constexpr int nthreads_KQ = 8;
|
||||
+ constexpr int nthreads_V = 8;
|
||||
+ constexpr int V_cols_per_iter = 4;
|
||||
+ constexpr int nwarps = 4;
|
||||
+
|
||||
+ const int ic0 = (int)tgpig.x * args.ncols;
|
||||
+ const int blk_z = (int)tgpig.z;
|
||||
+ const int sequence = blk_z / args.nHeadsQ;
|
||||
+ const int head = blk_z % args.nHeadsQ;
|
||||
+ const int gqa_ratio = args.nHeadsQ / args.nKVHeads;
|
||||
+ const int head_kv = head / gqa_ratio;
|
||||
+
|
||||
+ const int tid = (int)sgitg * 32 + (int)tiisg;
|
||||
+
|
||||
+ device const float * Q = (device const float *)Q_data
|
||||
+ + (long)sequence * (args.nb03 / sizeof(float))
|
||||
+ + (long)head * (args.nb02 / sizeof(float))
|
||||
+ + (long)ic0 * (args.nb01 / sizeof(float));
|
||||
+
|
||||
+ device const uint8_t * K_p = K_packed
|
||||
+ + (long)args.firstCell * args.nKVHeads * args.packedBytes
|
||||
+ + (long)head_kv * args.packedBytes;
|
||||
+ device const float * K_sc = K_scales
|
||||
+ + (long)args.firstCell * args.nKVHeads + head_kv;
|
||||
+
|
||||
+ device const half * V = V_data
|
||||
+ + (long)sequence * (args.nb23 / sizeof(half))
|
||||
+ + (long)head_kv * (args.nb22 / sizeof(half));
|
||||
+
|
||||
+ device const half * maskh = args.hasMask
|
||||
+ ? (mask_data + (long)ic0 * (args.nb31 / sizeof(half)))
|
||||
+ : nullptr;
|
||||
+
|
||||
+ const int k_cb_mask = (1 << args.bits) - 1;
|
||||
+ const float k_cb_lane = K_cb[tiisg & k_cb_mask];
|
||||
+
|
||||
+ const int tid_kq = (int)tiisg % nthreads_KQ;
|
||||
+
|
||||
+ // D=256: Q_reg holds 16 float2 per thread per query slot (D/(2*nthreads_KQ)).
|
||||
+ float2 Q_reg[2][16];
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ device const float2 * Q_j = (device const float2 *)(Q + (long)j * (args.nb01 / sizeof(float)));
|
||||
+ for (int i = 0; i < 16; i++) {
|
||||
+ const int elem = tid_kq * 16 + i; // float2 index within [0, 127]
|
||||
+ Q_reg[j][i] = (elem < D/2) ? Q_j[elem] : float2(0.0f, 0.0f);
|
||||
+ }
|
||||
+ for (int i = 0; i < 16; i++) {
|
||||
+ Q_reg[j][i].x *= args.scale;
|
||||
+ Q_reg[j][i].y *= args.scale;
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
+ // D=256: VKQ holds 4 passes × 4 float2 = 16 float2 per query slot.
|
||||
+ float2 VKQ[2][16];
|
||||
+ for (int j = 0; j < 2; j++)
|
||||
+ for (int i = 0; i < 16; i++)
|
||||
+ VKQ[j][i] = float2(0.0f, 0.0f);
|
||||
+
|
||||
+ float KQ_max[2] = { -FLT_MAX/2.0f, -FLT_MAX/2.0f };
|
||||
+ float KQ_sum[2] = { 0.0f, 0.0f };
|
||||
+
|
||||
+ // D=256: KQ_tg sized nwarps*V_cols_per_iter*D = 4*4*256 = 4096 floats (16 KiB).
|
||||
+ threadgroup float KQ_tg[4096];
|
||||
+ threadgroup float KQ_max_tg[2][32];
|
||||
+ threadgroup float KQ_sum_tg[2][32];
|
||||
+
|
||||
+ for (int k_VKQ_0 = 0; k_VKQ_0 < args.nCells; k_VKQ_0 += nthreads) {
|
||||
+
|
||||
+ float KQ_max_new[2] = { KQ_max[0], KQ_max[1] };
|
||||
+
|
||||
+ for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; i_KQ_0++) {
|
||||
+ const int kq_grp_start = ((int)tiisg & ~(nthreads_KQ - 1));
|
||||
+ const int i_KQ = (int)sgitg * 32 + kq_grp_start + i_KQ_0;
|
||||
+ const int cell_rel = k_VKQ_0 + i_KQ;
|
||||
+ const bool in_range = (cell_rel < args.nCells);
|
||||
+
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ device const uint8_t * packed_row = K_p + (long)cell_rel * args.nKVHeads * args.packedBytes;
|
||||
+ const float rms_scale = in_range ? K_sc[cell_rel * args.nKVHeads] : 0.0f;
|
||||
+
|
||||
+ // D=256: 16 k-iterations × 2 elements each = 32 D-positions per thread.
|
||||
+ float sum = 0.0f;
|
||||
+ for (int k = 0; k < 16; k++) {
|
||||
+ const int start_elem = tid_kq * 32 + k * 2; // float index [0..254]
|
||||
+ float k_dec[2];
|
||||
+ if (args.bits == 3) {
|
||||
+ const int bit_pos0 = start_elem * 3;
|
||||
+ const int byte0 = bit_pos0 >> 3, sh0 = bit_pos0 & 7;
|
||||
+ const uint w0 = (uint)packed_row[byte0] | ((uint)packed_row[byte0+1] << 8);
|
||||
+ int idx0 = (int)((w0 >> sh0) & 7);
|
||||
+ k_dec[0] = simd_shuffle(k_cb_lane, (ushort)idx0) * rms_scale;
|
||||
+ const int bit_pos1 = (start_elem + 1) * 3;
|
||||
+ const int byte1 = bit_pos1 >> 3, sh1 = bit_pos1 & 7;
|
||||
+ const uint w1 = (uint)packed_row[byte1] | ((uint)packed_row[byte1+1] << 8);
|
||||
+ int idx1 = (int)((w1 >> sh1) & 7);
|
||||
+ k_dec[1] = simd_shuffle(k_cb_lane, (ushort)idx1) * rms_scale;
|
||||
+ } else {
|
||||
+ const int byte0 = start_elem >> 2, sh0 = (start_elem & 3) * 2;
|
||||
+ k_dec[0] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte0] >> sh0) & 3)) * rms_scale;
|
||||
+ const int byte1 = (start_elem + 1) >> 2, sh1 = ((start_elem + 1) & 3) * 2;
|
||||
+ k_dec[1] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte1] >> sh1) & 3)) * rms_scale;
|
||||
+ }
|
||||
+ sum += Q_reg[j][k].x * k_dec[0] + Q_reg[j][k].y * k_dec[1];
|
||||
+ }
|
||||
+ sum += simd_shuffle_xor(sum, 4);
|
||||
+ sum += simd_shuffle_xor(sum, 2);
|
||||
+ sum += simd_shuffle_xor(sum, 1);
|
||||
+
|
||||
+ if (args.logit_softcap != 0.0f) {
|
||||
+ sum = args.logit_softcap * tanh(sum);
|
||||
+ }
|
||||
+
|
||||
+ if (maskh && (args.ncols == 1 || ic0 + j < args.nTokensQ)) {
|
||||
+ sum += float(maskh[(long)j * args.ne31 + i_KQ]);
|
||||
+ }
|
||||
+
|
||||
+ if (!in_range) sum = -FLT_MAX/2.0f;
|
||||
+
|
||||
+ KQ_max_new[j] = max(KQ_max_new[j], sum + 0.6931f);
|
||||
+
|
||||
+ if (tid_kq == (uint)i_KQ_0) {
|
||||
+ KQ_tg[j * nthreads + tid] = sum;
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ KQ_max_new[j] = simd_max(KQ_max_new[j]);
|
||||
+
|
||||
+ const float KQ_max_scale = exp(KQ_max[j] - KQ_max_new[j]);
|
||||
+ KQ_max[j] = KQ_max_new[j];
|
||||
+
|
||||
+ const float kq_val = KQ_tg[j * nthreads + tid];
|
||||
+ const float kq_exp = exp(kq_val - KQ_max[j]);
|
||||
+ KQ_sum[j] = KQ_sum[j] * KQ_max_scale + kq_exp;
|
||||
+ KQ_tg[j * nthreads + tid] = kq_exp;
|
||||
+
|
||||
+ for (int i = 0; i < 16; i++) {
|
||||
+ VKQ[j][i].x *= KQ_max_scale;
|
||||
+ VKQ[j][i].y *= KQ_max_scale;
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
+ // D=256: 4 passes × 64 V-elements each = full 256-element coverage.
|
||||
+ // pass 0 → V[ 0.. 63] → VKQ[ 0.. 3]
|
||||
+ // pass 1 → V[ 64..127] → VKQ[ 4.. 7]
|
||||
+ // pass 2 → V[128..191] → VKQ[ 8..11]
|
||||
+ // pass 3 → V[192..255] → VKQ[12..15]
|
||||
+ for (int k0 = 0; k0 < 32; k0 += V_cols_per_iter) {
|
||||
+ const int k = (int)sgitg * 32 + k0 + (int)tiisg / nthreads_V;
|
||||
+ const int cell_rel = k_VKQ_0 + k;
|
||||
+
|
||||
+ float KQ_k[2];
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ KQ_k[j] = KQ_tg[j * nthreads + k];
|
||||
+ }
|
||||
+
|
||||
+ device const half * V_cell = (cell_rel < args.nCells)
|
||||
+ ? V + (long)cell_rel * (args.nb21 / sizeof(half))
|
||||
+ : nullptr;
|
||||
+
|
||||
+ const int v_tid = (int)tiisg % nthreads_V;
|
||||
+ for (int pass = 0; pass < 4; pass++) {
|
||||
+ for (int i = 0; i < 8; i++) {
|
||||
+ const int elem = pass * 64 + v_tid * 8 + i;
|
||||
+ float v_val = (V_cell && elem < D) ? float(V_cell[elem]) : 0.0f;
|
||||
+ const int vkq_idx = pass * 4 + i / 2;
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ if (i % 2 == 0) VKQ[j][vkq_idx].x += v_val * KQ_k[j];
|
||||
+ else VKQ[j][vkq_idx].y += v_val * KQ_k[j];
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
+ } // end KV loop
|
||||
+
|
||||
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
+
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ if (sgitg == 0) {
|
||||
+ KQ_max_tg[j][tiisg] = -FLT_MAX/2.0f;
|
||||
+ KQ_sum_tg[j][tiisg] = 0.0f;
|
||||
+ }
|
||||
+ }
|
||||
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
+
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ if (tiisg == 0) {
|
||||
+ KQ_max_tg[j][sgitg] = KQ_max[j];
|
||||
+ }
|
||||
+ }
|
||||
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
+
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ if (args.ncols > 1 && ic0 + j >= args.nTokensQ) break;
|
||||
+
|
||||
+ float kqmax_new = KQ_max_tg[j][tiisg];
|
||||
+ kqmax_new = simd_max(kqmax_new);
|
||||
+ const float kqmax_scale = exp(KQ_max[j] - kqmax_new);
|
||||
+ KQ_max[j] = kqmax_new;
|
||||
+
|
||||
+ for (int i = 0; i < 16; i++) {
|
||||
+ VKQ[j][i].x *= kqmax_scale;
|
||||
+ VKQ[j][i].y *= kqmax_scale;
|
||||
+ }
|
||||
+
|
||||
+ // D=256: VKQ_tg layout per (sgitg,v-group) region of D/2 = 128 float2:
|
||||
+ // VKQ[ 0.. 3] → VKQ_tg[v_tid*4 + 0..3] (pass 0, float2 slots 0..31)
|
||||
+ // VKQ[ 4.. 7] → VKQ_tg[32 + v_tid*4 + 0..3] (pass 1, float2 slots 32..63)
|
||||
+ // VKQ[ 8..11] → VKQ_tg[64 + v_tid*4 + 0..3] (pass 2, float2 slots 64..95)
|
||||
+ // VKQ[12..15] → VKQ_tg[96 + v_tid*4 + 0..3] (pass 3, float2 slots 96..127)
|
||||
+ const int v_tid = (int)tiisg % nthreads_V;
|
||||
+ threadgroup float2 * VKQ_tg = (threadgroup float2 *)KQ_tg
|
||||
+ + (long)sgitg * (V_cols_per_iter * D/2)
|
||||
+ + (long)((int)tiisg / nthreads_V) * (D/2);
|
||||
+ VKQ_tg[v_tid * 4 + 0] = VKQ[j][0];
|
||||
+ VKQ_tg[v_tid * 4 + 1] = VKQ[j][1];
|
||||
+ VKQ_tg[v_tid * 4 + 2] = VKQ[j][2];
|
||||
+ VKQ_tg[v_tid * 4 + 3] = VKQ[j][3];
|
||||
+ VKQ_tg[32 + v_tid * 4 + 0] = VKQ[j][4];
|
||||
+ VKQ_tg[32 + v_tid * 4 + 1] = VKQ[j][5];
|
||||
+ VKQ_tg[32 + v_tid * 4 + 2] = VKQ[j][6];
|
||||
+ VKQ_tg[32 + v_tid * 4 + 3] = VKQ[j][7];
|
||||
+ VKQ_tg[64 + v_tid * 4 + 0] = VKQ[j][8];
|
||||
+ VKQ_tg[64 + v_tid * 4 + 1] = VKQ[j][9];
|
||||
+ VKQ_tg[64 + v_tid * 4 + 2] = VKQ[j][10];
|
||||
+ VKQ_tg[64 + v_tid * 4 + 3] = VKQ[j][11];
|
||||
+ VKQ_tg[96 + v_tid * 4 + 0] = VKQ[j][12];
|
||||
+ VKQ_tg[96 + v_tid * 4 + 1] = VKQ[j][13];
|
||||
+ VKQ_tg[96 + v_tid * 4 + 2] = VKQ[j][14];
|
||||
+ VKQ_tg[96 + v_tid * 4 + 3] = VKQ[j][15];
|
||||
+
|
||||
+ KQ_sum[j] *= kqmax_scale;
|
||||
+ KQ_sum[j] = simd_sum(KQ_sum[j]);
|
||||
+ if (tiisg == 0) {
|
||||
+ KQ_sum_tg[j][sgitg] = KQ_sum[j];
|
||||
+ }
|
||||
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
+
|
||||
+ // D=256: each thread writes 2 output positions (tid, tid+128). Compute
|
||||
+ // KQ_sum with ALL lanes participating in the simd_sum (required for
|
||||
+ // convergence), then two writes per thread.
|
||||
+ KQ_sum[j] = KQ_sum_tg[j][tiisg];
|
||||
+ KQ_sum[j] = simd_sum(KQ_sum[j]);
|
||||
+
|
||||
+ const long out_idx = ((long)sequence * args.nTokensQ + ic0 + j) * args.nHeadsQ + head;
|
||||
+ for (int out_offset = 0; out_offset < D; out_offset += nthreads) {
|
||||
+ const int out_elem = out_offset + tid;
|
||||
+ float dst_val = 0.0f;
|
||||
+ for (int w = 0; w < nwarps; w++) {
|
||||
+ for (int v = 0; v < V_cols_per_iter; v++) {
|
||||
+ dst_val += ((threadgroup float *)KQ_tg)[w * V_cols_per_iter * D + v * D + out_elem];
|
||||
+ }
|
||||
+ }
|
||||
+ dst_val /= KQ_sum[j];
|
||||
+ dst[out_idx * D + out_elem] = dst_val;
|
||||
+ }
|
||||
+
|
||||
+ if (j < args.ncols - 1) threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// kernel_tq_fattn_vec_packed
|
||||
// TQ fused flash-attention: K packed i8, V packed i8.
|
||||
@@ -11451,3 +11738,278 @@ kernel void kernel_tq_fattn_vec_packed(
|
||||
if (j < args.ncols - 1) threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
}
|
||||
+
|
||||
+// ─────────────────────────────────────────────────────────────────────────────
|
||||
+// kernel_tq_fattn_vec_packed_d256
|
||||
+// TQ fused flash-attention at head dim 256: K packed i8, V packed i8.
|
||||
+// Mirrors kernel_tq_fattn_vec_packed with 4 V-passes and 2 outputs per thread.
|
||||
+// ─────────────────────────────────────────────────────────────────────────────
|
||||
+kernel void kernel_tq_fattn_vec_packed_d256(
|
||||
+ constant ggml_metal_kargs_tq_fattn_vec & args,
|
||||
+ device const char * Q_data [[buffer(1)]],
|
||||
+ device const uint8_t * K_packed [[buffer(2)]],
|
||||
+ device const uint8_t * V_packed [[buffer(3)]],
|
||||
+ device const half * mask_data [[buffer(4)]],
|
||||
+ device const float * K_scales [[buffer(5)]],
|
||||
+ device const float * K_cb [[buffer(6)]],
|
||||
+ device const float * V_scales [[buffer(7)]],
|
||||
+ device const float * V_cb [[buffer(8)]],
|
||||
+ device float * dst [[buffer(9)]],
|
||||
+ uint3 tgpig [[threadgroup_position_in_grid]],
|
||||
+ uint tiisg [[thread_index_in_simdgroup]],
|
||||
+ uint sgitg [[simdgroup_index_in_threadgroup]])
|
||||
+{
|
||||
+ constexpr int D = 256;
|
||||
+ constexpr int nthreads = 128;
|
||||
+ constexpr int nthreads_KQ = 8;
|
||||
+ constexpr int nthreads_V = 8;
|
||||
+ constexpr int V_cols_per_iter = 4;
|
||||
+ constexpr int nwarps = 4;
|
||||
+
|
||||
+ const int ic0 = (int)tgpig.x * args.ncols;
|
||||
+ const int blk_z = (int)tgpig.z;
|
||||
+ const int sequence = blk_z / args.nHeadsQ;
|
||||
+ const int head = blk_z % args.nHeadsQ;
|
||||
+ const int gqa_ratio = args.nHeadsQ / args.nKVHeads;
|
||||
+ const int head_kv = head / gqa_ratio;
|
||||
+
|
||||
+ const int tid = (int)sgitg * 32 + (int)tiisg;
|
||||
+
|
||||
+ device const float * Q = (device const float *)Q_data
|
||||
+ + (long)sequence * (args.nb03 / sizeof(float))
|
||||
+ + (long)head * (args.nb02 / sizeof(float))
|
||||
+ + (long)ic0 * (args.nb01 / sizeof(float));
|
||||
+
|
||||
+ device const uint8_t * K_p = K_packed
|
||||
+ + (long)args.firstCell * args.nKVHeads * args.packedBytes
|
||||
+ + (long)head_kv * args.packedBytes;
|
||||
+ device const float * K_sc = K_scales
|
||||
+ + (long)args.firstCell * args.nKVHeads + head_kv;
|
||||
+
|
||||
+ device const uint8_t * V_p = V_packed
|
||||
+ + (long)args.firstCell * args.nKVHeads * args.v_packedBytes
|
||||
+ + (long)head_kv * args.v_packedBytes;
|
||||
+ device const float * V_sc = V_scales
|
||||
+ + (long)args.firstCell * args.nKVHeads + head_kv;
|
||||
+
|
||||
+ device const half * maskh = args.hasMask
|
||||
+ ? (mask_data + (long)ic0 * (args.nb31 / sizeof(half)))
|
||||
+ : nullptr;
|
||||
+
|
||||
+ const int k_cb_mask = (1 << args.bits) - 1;
|
||||
+ const float k_cb_lane = K_cb[tiisg & k_cb_mask];
|
||||
+ const int v_cb_mask = (1 << args.v_bits) - 1;
|
||||
+ const float v_cb_lane = V_cb[tiisg & v_cb_mask];
|
||||
+
|
||||
+ const int tid_kq = (int)tiisg % nthreads_KQ;
|
||||
+
|
||||
+ float2 Q_reg[2][16];
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ device const float2 * Q_j = (device const float2 *)(Q + (long)j * (args.nb01 / sizeof(float)));
|
||||
+ for (int i = 0; i < 16; i++) {
|
||||
+ const int elem = tid_kq * 16 + i;
|
||||
+ Q_reg[j][i] = (elem < D/2) ? Q_j[elem] : float2(0.0f, 0.0f);
|
||||
+ }
|
||||
+ for (int i = 0; i < 16; i++) {
|
||||
+ Q_reg[j][i].x *= args.scale;
|
||||
+ Q_reg[j][i].y *= args.scale;
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
+ float2 VKQ[2][16];
|
||||
+ for (int j = 0; j < 2; j++)
|
||||
+ for (int i = 0; i < 16; i++)
|
||||
+ VKQ[j][i] = float2(0.0f, 0.0f);
|
||||
+
|
||||
+ float KQ_max[2] = { -FLT_MAX/2.0f, -FLT_MAX/2.0f };
|
||||
+ float KQ_sum[2] = { 0.0f, 0.0f };
|
||||
+
|
||||
+ threadgroup float KQ_tg[4096];
|
||||
+ threadgroup float KQ_max_tg[2][32];
|
||||
+ threadgroup float KQ_sum_tg[2][32];
|
||||
+
|
||||
+ for (int k_VKQ_0 = 0; k_VKQ_0 < args.nCells; k_VKQ_0 += nthreads) {
|
||||
+
|
||||
+ float KQ_max_new[2] = { KQ_max[0], KQ_max[1] };
|
||||
+
|
||||
+ for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; i_KQ_0++) {
|
||||
+ const int kq_grp_start = ((int)tiisg & ~(nthreads_KQ - 1));
|
||||
+ const int i_KQ = (int)sgitg * 32 + kq_grp_start + i_KQ_0;
|
||||
+ const int cell_rel = k_VKQ_0 + i_KQ;
|
||||
+ const bool in_range = (cell_rel < args.nCells);
|
||||
+
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ device const uint8_t * packed_row = K_p + (long)cell_rel * args.nKVHeads * args.packedBytes;
|
||||
+ const float rms_scale = in_range ? K_sc[cell_rel * args.nKVHeads] : 0.0f;
|
||||
+
|
||||
+ float sum = 0.0f;
|
||||
+ for (int k = 0; k < 16; k++) {
|
||||
+ const int start_elem = tid_kq * 32 + k * 2;
|
||||
+ float k_dec[2];
|
||||
+ if (args.bits == 3) {
|
||||
+ const int bit_pos0 = start_elem * 3;
|
||||
+ const int byte0 = bit_pos0 >> 3, sh0 = bit_pos0 & 7;
|
||||
+ const uint w0 = (uint)packed_row[byte0] | ((uint)packed_row[byte0+1] << 8);
|
||||
+ k_dec[0] = simd_shuffle(k_cb_lane, (ushort)((w0 >> sh0) & 7)) * rms_scale;
|
||||
+ const int bit_pos1 = (start_elem + 1) * 3;
|
||||
+ const int byte1 = bit_pos1 >> 3, sh1 = bit_pos1 & 7;
|
||||
+ const uint w1 = (uint)packed_row[byte1] | ((uint)packed_row[byte1+1] << 8);
|
||||
+ k_dec[1] = simd_shuffle(k_cb_lane, (ushort)((w1 >> sh1) & 7)) * rms_scale;
|
||||
+ } else {
|
||||
+ const int byte0 = start_elem >> 2, sh0 = (start_elem & 3) * 2;
|
||||
+ k_dec[0] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte0] >> sh0) & 3)) * rms_scale;
|
||||
+ const int byte1 = (start_elem + 1) >> 2, sh1 = ((start_elem + 1) & 3) * 2;
|
||||
+ k_dec[1] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte1] >> sh1) & 3)) * rms_scale;
|
||||
+ }
|
||||
+ sum += Q_reg[j][k].x * k_dec[0] + Q_reg[j][k].y * k_dec[1];
|
||||
+ }
|
||||
+ sum += simd_shuffle_xor(sum, 4);
|
||||
+ sum += simd_shuffle_xor(sum, 2);
|
||||
+ sum += simd_shuffle_xor(sum, 1);
|
||||
+
|
||||
+ if (args.logit_softcap != 0.0f) {
|
||||
+ sum = args.logit_softcap * tanh(sum);
|
||||
+ }
|
||||
+ if (maskh && (args.ncols == 1 || ic0 + j < args.nTokensQ)) {
|
||||
+ sum += float(maskh[(long)j * args.ne31 + i_KQ]);
|
||||
+ }
|
||||
+ if (!in_range) sum = -FLT_MAX/2.0f;
|
||||
+
|
||||
+ KQ_max_new[j] = max(KQ_max_new[j], sum + 0.6931f);
|
||||
+
|
||||
+ if (tid_kq == (uint)i_KQ_0) {
|
||||
+ KQ_tg[j * nthreads + tid] = sum;
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ KQ_max_new[j] = simd_max(KQ_max_new[j]);
|
||||
+
|
||||
+ const float KQ_max_scale = exp(KQ_max[j] - KQ_max_new[j]);
|
||||
+ KQ_max[j] = KQ_max_new[j];
|
||||
+
|
||||
+ const float kq_val = KQ_tg[j * nthreads + tid];
|
||||
+ const float kq_exp = exp(kq_val - KQ_max[j]);
|
||||
+ KQ_sum[j] = KQ_sum[j] * KQ_max_scale + kq_exp;
|
||||
+ KQ_tg[j * nthreads + tid] = kq_exp;
|
||||
+
|
||||
+ for (int i = 0; i < 16; i++) {
|
||||
+ VKQ[j][i].x *= KQ_max_scale;
|
||||
+ VKQ[j][i].y *= KQ_max_scale;
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
+ // D=256: 4 passes cover V[0..63], V[64..127], V[128..191], V[192..255].
|
||||
+ for (int k0 = 0; k0 < 32; k0 += V_cols_per_iter) {
|
||||
+ const int k = (int)sgitg * 32 + k0 + (int)tiisg / nthreads_V;
|
||||
+ const int cell_rel = k_VKQ_0 + k;
|
||||
+ const bool in_range_v = (cell_rel < args.nCells);
|
||||
+
|
||||
+ float KQ_k[2];
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ KQ_k[j] = KQ_tg[j * nthreads + k];
|
||||
+ }
|
||||
+
|
||||
+ device const uint8_t * v_row = in_range_v
|
||||
+ ? V_p + (long)cell_rel * args.nKVHeads * args.v_packedBytes
|
||||
+ : nullptr;
|
||||
+ const float v_rms = in_range_v ? V_sc[cell_rel * args.nKVHeads] : 0.0f;
|
||||
+
|
||||
+ const int v_tid = (int)tiisg % nthreads_V;
|
||||
+ for (int pass = 0; pass < 4; pass++) {
|
||||
+ float v_dec[8];
|
||||
+ const int start_elem = pass * 64 + v_tid * 8;
|
||||
+ if (v_row && start_elem < D) {
|
||||
+ tq_decode_8_shfl(v_row, v_cb_lane, v_rms, start_elem, args.v_bits, v_dec);
|
||||
+ } else {
|
||||
+ tq_decode_8_shfl(K_p, v_cb_lane, 0.0f, 0, args.v_bits, v_dec);
|
||||
+ }
|
||||
+ for (int i = 0; i < 8; i++) {
|
||||
+ const int vkq_idx = pass * 4 + i / 2;
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ if (i % 2 == 0) VKQ[j][vkq_idx].x += v_dec[i] * KQ_k[j];
|
||||
+ else VKQ[j][vkq_idx].y += v_dec[i] * KQ_k[j];
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
+ } // end KV loop
|
||||
+
|
||||
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
+
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ if (sgitg == 0) {
|
||||
+ KQ_max_tg[j][tiisg] = -FLT_MAX/2.0f;
|
||||
+ KQ_sum_tg[j][tiisg] = 0.0f;
|
||||
+ }
|
||||
+ }
|
||||
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
+
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ if (tiisg == 0) {
|
||||
+ KQ_max_tg[j][sgitg] = KQ_max[j];
|
||||
+ }
|
||||
+ }
|
||||
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
+
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ if (args.ncols > 1 && ic0 + j >= args.nTokensQ) break;
|
||||
+
|
||||
+ float kqmax_new = KQ_max_tg[j][tiisg];
|
||||
+ kqmax_new = simd_max(kqmax_new);
|
||||
+ const float kqmax_scale = exp(KQ_max[j] - kqmax_new);
|
||||
+ KQ_max[j] = kqmax_new;
|
||||
+
|
||||
+ for (int i = 0; i < 16; i++) {
|
||||
+ VKQ[j][i].x *= kqmax_scale;
|
||||
+ VKQ[j][i].y *= kqmax_scale;
|
||||
+ }
|
||||
+
|
||||
+ const int v_tid = (int)tiisg % nthreads_V;
|
||||
+ threadgroup float2 * VKQ_tg = (threadgroup float2 *)KQ_tg
|
||||
+ + (long)sgitg * (V_cols_per_iter * D/2)
|
||||
+ + (long)((int)tiisg / nthreads_V) * (D/2);
|
||||
+ VKQ_tg[v_tid * 4 + 0] = VKQ[j][0];
|
||||
+ VKQ_tg[v_tid * 4 + 1] = VKQ[j][1];
|
||||
+ VKQ_tg[v_tid * 4 + 2] = VKQ[j][2];
|
||||
+ VKQ_tg[v_tid * 4 + 3] = VKQ[j][3];
|
||||
+ VKQ_tg[32 + v_tid * 4 + 0] = VKQ[j][4];
|
||||
+ VKQ_tg[32 + v_tid * 4 + 1] = VKQ[j][5];
|
||||
+ VKQ_tg[32 + v_tid * 4 + 2] = VKQ[j][6];
|
||||
+ VKQ_tg[32 + v_tid * 4 + 3] = VKQ[j][7];
|
||||
+ VKQ_tg[64 + v_tid * 4 + 0] = VKQ[j][8];
|
||||
+ VKQ_tg[64 + v_tid * 4 + 1] = VKQ[j][9];
|
||||
+ VKQ_tg[64 + v_tid * 4 + 2] = VKQ[j][10];
|
||||
+ VKQ_tg[64 + v_tid * 4 + 3] = VKQ[j][11];
|
||||
+ VKQ_tg[96 + v_tid * 4 + 0] = VKQ[j][12];
|
||||
+ VKQ_tg[96 + v_tid * 4 + 1] = VKQ[j][13];
|
||||
+ VKQ_tg[96 + v_tid * 4 + 2] = VKQ[j][14];
|
||||
+ VKQ_tg[96 + v_tid * 4 + 3] = VKQ[j][15];
|
||||
+
|
||||
+ KQ_sum[j] *= kqmax_scale;
|
||||
+ KQ_sum[j] = simd_sum(KQ_sum[j]);
|
||||
+ if (tiisg == 0) {
|
||||
+ KQ_sum_tg[j][sgitg] = KQ_sum[j];
|
||||
+ }
|
||||
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
+
|
||||
+ KQ_sum[j] = KQ_sum_tg[j][tiisg];
|
||||
+ KQ_sum[j] = simd_sum(KQ_sum[j]);
|
||||
+
|
||||
+ const long out_idx = ((long)sequence * args.nTokensQ + ic0 + j) * args.nHeadsQ + head;
|
||||
+ for (int out_offset = 0; out_offset < D; out_offset += nthreads) {
|
||||
+ const int out_elem = out_offset + tid;
|
||||
+ float dst_val = 0.0f;
|
||||
+ for (int w = 0; w < nwarps; w++) {
|
||||
+ for (int v = 0; v < V_cols_per_iter; v++) {
|
||||
+ dst_val += ((threadgroup float *)KQ_tg)[w * V_cols_per_iter * D + v * D + out_elem];
|
||||
+ }
|
||||
+ }
|
||||
+ dst_val /= KQ_sum[j];
|
||||
+ dst[out_idx * D + out_elem] = dst_val;
|
||||
+ }
|
||||
+
|
||||
+ if (j < args.ncols - 1) threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
+ }
|
||||
+}
|
||||
|
|
@ -0,0 +1,135 @@
|
|||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Michael Verrilli <msv@pobox.com>
|
||||
Date: Tue, 21 Apr 2026 21:02:22 +0000
|
||||
Subject: [PATCH] ml/backend/ggml: optimize the Metal TurboQuant dequant kernel
|
||||
|
||||
Two independent improvements to kernel_tq_dequant.
|
||||
|
||||
Dispatch: the kernel uses only [[thread_index_in_simdgroup]] (tiisg,
|
||||
0..31 per SIMDgroup), has no sgitg stride, no threadgroup barriers, and
|
||||
no atomics. It was nonetheless dispatched with 128-thread threadgroups
|
||||
(four SIMDgroups x 32), so all four SIMDgroups ran the outer loop
|
||||
identically and wrote the same f16 bytes four times. Drop non-outlier
|
||||
dispatches to 32-thread threadgroups. The outlier kernel still
|
||||
dispatches at 128 - it uses s_mask atomics and a popcount reduction
|
||||
that legitimately need the full threadgroup.
|
||||
|
||||
Inner loop: replace the 1-element-per-iteration scalar path with a
|
||||
4-elements-per-iteration vectorised path that issues a single half4
|
||||
store per iteration. For bits=2 the 4 elements fit in one byte; for
|
||||
bits=3 they fit in a 16-bit window (shift0 in {0,4}). The per-cell
|
||||
scale is pre-multiplied into the codebook lane at kernel entry so the
|
||||
decode path drops one fmul per element. The scalar fallback is
|
||||
preserved for head dims that aren't a multiple of 128.
|
||||
|
||||
Decode throughput on llama3.2:3b tq3 at 32k context improves ~9% on
|
||||
Apple Silicon (~42 -> ~46 tok/s); the K-only DequantK path used by
|
||||
tq2k/tq3k benefits from the same kernel.
|
||||
---
|
||||
ggml/src/ggml-metal/ggml-metal-ops.cpp | 15 ++++++++---
|
||||
ggml/src/ggml-metal/ggml-metal.metal | 37 +++++++++++++++++++++++---
|
||||
2 files changed, 45 insertions(+), 7 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
index b5ab1c14e..ea3580b0b 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
@@ -4197,7 +4197,12 @@ int ggml_metal_op_tq_dequant(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
- const int block_size = std::min(128, headDim);
|
||||
+ // Outlier kernel uses a 128-thread TG: threadgroup barriers + atomics on
|
||||
+ // s_mask require all threads. Non-outlier kernel uses a single simdgroup
|
||||
+ // (32 threads): it only reads tiisg and has no barriers, so a larger TG
|
||||
+ // just replicates work across idle simdgroups.
|
||||
+ const int outlier_block_size = 128;
|
||||
+ const int nonoutlier_block_size = 32;
|
||||
|
||||
if (outlierCount > 0 && outlierBits > 0 && outlierCount < headDim) {
|
||||
const int regular_count = headDim - outlierCount;
|
||||
@@ -4234,7 +4239,7 @@ int ggml_metal_op_tq_dequant(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 8);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
- ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, block_size, 1, 1);
|
||||
+ ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, outlier_block_size, 1, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -4259,7 +4264,7 @@ int ggml_metal_op_tq_dequant(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); // codebook
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 4);
|
||||
|
||||
- ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, block_size, 1, 1);
|
||||
+ ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, nonoutlier_block_size, 1, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -4283,7 +4288,9 @@ int ggml_metal_op_tq_dequant_kv(ggml_metal_op_t ctx, int idx) {
|
||||
const int k_codebook_len = (int)op->src[2]->ne[0];
|
||||
const int v_codebook_len = (int)op->src[5]->ne[0];
|
||||
|
||||
- const int block_size = std::min(128, headDim);
|
||||
+ // kernel_tq_dequant is single-simdgroup (uses only tiisg, no barriers,
|
||||
+ // no atomics) — 32-thread TGs eliminate 4× redundant work vs 128-thread.
|
||||
+ const int block_size = 32;
|
||||
const size_t plane_size = (size_t)headDim * numKVHeads * nCells;
|
||||
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
index 2718e8bb1..b61f755aa 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
@@ -10246,8 +10246,40 @@ kernel void kernel_tq_dequant(
|
||||
|
||||
const int cb_mask = (1 << args.bits) - 1;
|
||||
// Load one codebook entry per lane; period ≤ 8 divides 32, so simd_shuffle is exact.
|
||||
- const float cb_lane = codebook[tiisg & cb_mask];
|
||||
+ // Pre-multiply by the per-cell scale so the decode path drops 1 fmul per element.
|
||||
+ const float scaled_cb_lane = codebook[tiisg & cb_mask] * scale;
|
||||
+
|
||||
+ // Fast path: when headDim is a multiple of 128 (=32 lanes × 4 elements per thread),
|
||||
+ // each thread decodes 4 consecutive D-positions per iter and writes a single half4.
|
||||
+ // bits=2: 4 elems = 8 bits, always byte-aligned (shift0=0 since elem_base mod 4 == 0).
|
||||
+ // bits=3: 4 elems = 12 bits, shift0 ∈ {0,4}, always fits in a 16-bit window.
|
||||
+ // A 16-bit window (2 packed bytes) suffices for both. The scalar fallback
|
||||
+ // covers non-multiple-of-128 head dims.
|
||||
+ if ((args.headDim & 127) == 0) {
|
||||
+ const int iters = args.headDim >> 7;
|
||||
+ for (int iter = 0; iter < iters; iter++) {
|
||||
+ const int elem_base = iter * 128 + (int)tiisg * 4;
|
||||
+ const int bit_offset = elem_base * args.bits;
|
||||
+ const int byte_base = bit_offset >> 3;
|
||||
+ const int shift0 = bit_offset & 7;
|
||||
+
|
||||
+ uint w = (uint)cell_packed[byte_base];
|
||||
+ if (args.bits == 3) {
|
||||
+ w |= ((uint)cell_packed[byte_base + 1] << 8);
|
||||
+ }
|
||||
+
|
||||
+ half4 v4;
|
||||
+ v4[0] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> shift0 ) & cb_mask)));
|
||||
+ v4[1] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> (shift0 + args.bits)) & cb_mask)));
|
||||
+ v4[2] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> (shift0 + 2 * args.bits)) & cb_mask)));
|
||||
+ v4[3] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> (shift0 + 3 * args.bits)) & cb_mask)));
|
||||
+
|
||||
+ *((device half4 *)(cell_out + elem_base)) = v4;
|
||||
+ }
|
||||
+ return;
|
||||
+ }
|
||||
|
||||
+ // Scalar fallback for head dims that aren't a multiple of 128.
|
||||
for (uint elem = tiisg; elem < (uint)args.headDim; elem += 32) {
|
||||
const int bit_offset = (int)elem * args.bits;
|
||||
const int byte_idx = bit_offset >> 3;
|
||||
@@ -10256,8 +10288,7 @@ kernel void kernel_tq_dequant(
|
||||
if (shift + args.bits > 8) {
|
||||
idx |= ((int)(cell_packed[byte_idx + 1] << (8 - shift))) & cb_mask;
|
||||
}
|
||||
- const float val = simd_shuffle(cb_lane, (ushort)idx) * scale;
|
||||
- cell_out[elem] = half(val);
|
||||
+ cell_out[elem] = half(simd_shuffle(scaled_cb_lane, (ushort)idx));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -258,16 +258,20 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
|||
if fa {
|
||||
slog.Info("enabling flash attention")
|
||||
loadRequest.FlashAttention = ml.FlashAttentionEnabled
|
||||
}
|
||||
|
||||
// Flash Attention also supports kv cache quantization
|
||||
// Enable if the requested and kv cache type is supported by the model
|
||||
if f.SupportsKVCacheType(kvct) {
|
||||
loadRequest.KvCacheType = kvct
|
||||
} else {
|
||||
// Most quantized KV cache types require flash attention, but TurboQuant
|
||||
// K-only presets (tq2k/tq3k) dequant to f16 before attention and work
|
||||
// with either FA or the standard softmax+matmul attention path.
|
||||
if kvct != "" {
|
||||
switch {
|
||||
case f.KVCacheTypeRequiresFlashAttention(kvct) && !fa:
|
||||
slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
|
||||
case !f.SupportsKVCacheType(kvct):
|
||||
slog.Warn("kv cache type not supported by model", "type", kvct)
|
||||
default:
|
||||
loadRequest.KvCacheType = kvct
|
||||
}
|
||||
} else if kvct != "" && kvct != "f16" {
|
||||
slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -365,7 +365,7 @@ func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string)
|
|||
prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
|
||||
sb.WriteString("[")
|
||||
defer func() { sb.WriteString("]") }()
|
||||
for i := 0; i < dims[0]; i++ {
|
||||
for i := range dims[0] {
|
||||
if i >= items && i < dims[0]-items {
|
||||
sb.WriteString("..., ")
|
||||
// skip to next printable element
|
||||
|
|
@ -408,9 +408,82 @@ const (
|
|||
DTypeQ80
|
||||
DTypeQ40
|
||||
DTypeI32
|
||||
DTypeI8
|
||||
DTypeMXFP4
|
||||
DTypeTQ2
|
||||
DTypeTQ3
|
||||
DTypeTQ3K
|
||||
DTypeTQ2K
|
||||
)
|
||||
|
||||
// TQCompressedKManager manages GPU-resident packed N-bit key indices for
|
||||
// TurboQuant VRAM compression. Implemented by the ggml backend.
|
||||
type TQCompressedKManager interface {
|
||||
// EnsureLayer allocates per-layer GPU tensors (packed + scales) on first use.
|
||||
// capacity = total cache cell count for this layer.
|
||||
EnsureLayer(layer, capacity int)
|
||||
|
||||
// EncodeK creates a GGML_OP_TQ_ENCODE graph node that encodes key vectors
|
||||
// into the persistent compressed buffer. Returns a view of the packed buffer;
|
||||
// pass this result to DequantK to establish the graph ordering dependency.
|
||||
// key: [headDim, numKVHeads, batchSize] f16
|
||||
// firstCell: first cache slot index; cells are sequential (firstCell+0, +1, ...)
|
||||
EncodeK(ctx Context, layer int, key Tensor, firstCell int) Tensor
|
||||
|
||||
// DequantK creates a GGML_OP_TQ_DEQUANT graph node returning
|
||||
// [headDim, numKVHeads, nCells] f16 ready for flash attention.
|
||||
// encodeResult is the tensor returned by EncodeK for this layer+step,
|
||||
// establishing encode→dequant ordering in the ggml graph.
|
||||
DequantK(ctx Context, layer int, encodeResult Tensor, firstCell, nCells int) Tensor
|
||||
|
||||
// GetAsTQTensor returns a tqTensor wrapping the packed K buffer for the
|
||||
// fused TQ flash-attention path. Returns (nil, false) when the config is
|
||||
// not supported by the fused kernel — callers must fall back to DequantK in that case.
|
||||
GetAsTQTensor(ctx Context, layer int, encodeResult Tensor, firstCell, nCells int) (Tensor, bool)
|
||||
|
||||
// GetAsTQTensorKV returns a tqTensor wrapping both packed K and packed V
|
||||
// buffers for the fully fused K+V TQ flash-attention path. Returns (nil, false)
|
||||
// when fused is not supported or V compression is not active.
|
||||
GetAsTQTensorKV(ctx Context, layer int, kEncodeResult, vEncodeResult Tensor, firstCell, nCells int) (Tensor, bool)
|
||||
|
||||
// RotationMatrix returns the persistent [headDim, headDim] f32 R^T tensor.
|
||||
RotationMatrix(ctx Context, layer int) Tensor
|
||||
|
||||
// EnsureVLayer allocates per-layer V packed and scales tensors on first use.
|
||||
EnsureVLayer(layer, capacity int)
|
||||
|
||||
// EncodeV creates a GGML_OP_TQ_ENCODE_V graph node encoding value vectors
|
||||
// into the persistent compressed V buffer. Returns a view of the V packed buffer.
|
||||
EncodeV(ctx Context, layer int, value Tensor, firstCell int) Tensor
|
||||
|
||||
// DequantV creates a GGML_OP_TQ_DEQUANT graph node returning
|
||||
// [headDim, numKVHeads, nCells] f16 for V, ready for flash attention.
|
||||
DequantV(ctx Context, layer int, encodeResult Tensor, firstCell, nCells int) Tensor
|
||||
|
||||
// DequantKV creates a single GGML_OP_TQ_DEQUANT_KV graph node that dequants
|
||||
// both K and V in one op, halving GGML scheduler overhead.
|
||||
// Returns (key, value) as [headDim, numKVHeads, nCells] f16 views.
|
||||
DequantKV(ctx Context, layer int, kEncodeResult, vEncodeResult Tensor, firstCell, nCells int) (Tensor, Tensor)
|
||||
|
||||
// EncodeKV creates a single GGML_OP_TQ_ENCODE_KV graph node encoding both
|
||||
// K and V, halving scheduler overhead vs separate EncodeK + EncodeV.
|
||||
EncodeKV(ctx Context, layer int, key, value Tensor, firstCell int) (Tensor, Tensor)
|
||||
|
||||
// Close frees all GPU buffers.
|
||||
Close()
|
||||
}
|
||||
|
||||
// TQCompressedKBackend is implemented by backends that support TQ compressed K.
|
||||
//
|
||||
// outlierBits/outlierCount: optional post-rotation outlier split. When
|
||||
// outlierCount > 0, the manager allocates additional per-layer tensors for an
|
||||
// outlier sub-block at the specified bit width and the encode/dequant kernels
|
||||
// route through the outlier-aware path. Set outlierCount = 0 for pure uniform
|
||||
// per-channel Lloyd-Max at `bits` (historical behavior).
|
||||
type TQCompressedKBackend interface {
|
||||
NewTQCompressedKManager(headDim, numKVHeads, bits int, rotationSeed uint64, vBits, outlierBits, outlierCount int) TQCompressedKManager
|
||||
}
|
||||
|
||||
type SamplingMode int
|
||||
|
||||
const (
|
||||
|
|
|
|||
|
|
@ -111,6 +111,30 @@ type Backend struct {
|
|||
|
||||
flashAttention ml.FlashAttentionType
|
||||
|
||||
// tqRotationMatrix is a per-call flag set by TurboQuantCache.Get() right
|
||||
// before it returns rotated K. SDPA reads and clears it; if non-nil, SDPA
|
||||
// applies R^T @ Q to match the K rotation. This per-call (not sticky)
|
||||
// semantics is required for mixed-head-dim models like gemma3, where only
|
||||
// the global sub-cache of a WrapperCache is TQ-wrapped and the SWA sub-
|
||||
// cache passes through plain f16 K. A sticky rotation would be applied to
|
||||
// every SDPA call and corrupt attention on the unwrapped SWA layers.
|
||||
tqRotationMatrix ml.Tensor
|
||||
|
||||
// tqVRotationMatrix is a per-call flag set by TurboQuantCache.Get() right
|
||||
// before returning K+V (tq3/tq2) when V was encoded with Hadamard rotation
|
||||
// (R^T @ v). SDPA reads and clears it; if non-nil, SDPA applies R @
|
||||
// attn_out after the attention op to undo the V rotation. Per-call (not
|
||||
// sticky) semantics is required for mixed-head-dim models like gemma3,
|
||||
// where only the global sub-cache of a WrapperCache is TQ-wrapped and the
|
||||
// SWA sub-cache's V is plain f16. A sticky rotation would corrupt
|
||||
// attention on unwrapped SWA layers.
|
||||
tqVRotationMatrix ml.Tensor
|
||||
|
||||
// tqVRotFusedInDequant is true when DequantKV applies the V rotation undo
|
||||
// internally (via the rotated V kernel). When true, the stock FA path in
|
||||
// SDPA skips the mulmat undo — V is already in the unrotated domain.
|
||||
tqVRotFusedInDequant bool
|
||||
|
||||
// maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
|
||||
maxGraphNodes int
|
||||
|
||||
|
|
@ -664,6 +688,97 @@ func (b *Backend) NewContext() ml.Context {
|
|||
return b.NewContextSize(b.maxGraphNodes)
|
||||
}
|
||||
|
||||
// TQDeviceScan describes the GPU devices discovered in the scheduler from the
|
||||
// perspective of TurboQuant: which one TQ will use, plus the names of any GPUs
|
||||
// that were skipped because they're not wave32-capable (NVIDIA < Pascal, or
|
||||
// AMD wave64 Vega/GCN/CDNA). Used to emit actionable warnings and to avoid
|
||||
// dispatching TQ kernels to an unsupported card — either one that would hit
|
||||
// the compute-capability assert in tq-dequant.cu or one whose HIP __shfl_sync
|
||||
// shim would silently produce garbage on 64-lane warps.
|
||||
type TQDeviceScan struct {
|
||||
// selected is the buffer type TQ will place its tensors on. Zero-valued
|
||||
// if no TQ-capable GPU is present.
|
||||
selected C.ggml_backend_buffer_type_t
|
||||
selectedOK bool
|
||||
SelectedName string // e.g. "NVIDIA Tesla P40"
|
||||
SelectedCC string // e.g. "6.1"
|
||||
SelectedLibrary string // e.g. "Metal", "CUDA", "ROCm"
|
||||
// Accepted lists "<name> (cc X.Y)" for every TQ-capable GPU in schedBufts.
|
||||
Accepted []string
|
||||
// Skipped lists "<name> (cc X.Y, <library>): <reason>" for every non-host GPU
|
||||
// in schedBufts that fails the wave32 gate (CUDA < Pascal, ROCm wave64, or a
|
||||
// non-CUDA/ROCm backend). The reason is included so operators can diagnose
|
||||
// without reading the source tree.
|
||||
Skipped []string
|
||||
}
|
||||
|
||||
// scanTQDevices walks the scheduler buffer types and classifies each GPU via
|
||||
// tqDeviceAccepted: accepted GPUs (NVIDIA Pascal+, AMD RDNA1+) are eligible to
|
||||
// host TQ tensors; others are skipped with a diagnosable reason. The first
|
||||
// accepted buffer type is marked as selected; TQ tensors will be placed there
|
||||
// regardless of which scheduler index it occupies.
|
||||
func (b *Backend) scanTQDevices() TQDeviceScan {
|
||||
var scan TQDeviceScan
|
||||
for _, buft := range b.schedBufts {
|
||||
if C.ggml_backend_buft_is_host(buft) {
|
||||
continue
|
||||
}
|
||||
dev := C.ggml_backend_buft_get_device(buft)
|
||||
if dev == nil {
|
||||
continue
|
||||
}
|
||||
var props C.struct_ggml_backend_dev_props
|
||||
C.ggml_backend_dev_get_props(dev, &props)
|
||||
name := C.GoString(props.name)
|
||||
var library string
|
||||
if props.library != nil {
|
||||
library = C.GoString(props.library)
|
||||
}
|
||||
cc := fmt.Sprintf("%d.%d", int(props.compute_major), int(props.compute_minor))
|
||||
accepted, skipReason := tqDeviceAccepted(library, int(props.compute_major))
|
||||
if !accepted {
|
||||
scan.Skipped = append(scan.Skipped,
|
||||
fmt.Sprintf("%s (cc %s, %s): %s", name, cc, library, skipReason))
|
||||
continue
|
||||
}
|
||||
scan.Accepted = append(scan.Accepted, fmt.Sprintf("%s (cc %s)", name, cc))
|
||||
if !scan.selectedOK {
|
||||
scan.selected = buft
|
||||
scan.selectedOK = true
|
||||
scan.SelectedName = name
|
||||
scan.SelectedCC = cc
|
||||
scan.SelectedLibrary = library
|
||||
}
|
||||
}
|
||||
return scan
|
||||
}
|
||||
|
||||
// newTQContext creates a GGML context whose tensors are allocated in GPU
|
||||
// memory (CUDA, HIP, or Metal). Used by the TQ compressed KV cache manager:
|
||||
// TQ encode/decode ops require their tensors (packed buffers, scales,
|
||||
// codebook, rotation matrix) to reside on the GPU regardless of which model
|
||||
// layers are on CPU vs GPU. TQ tensors always land on the first TQ-capable
|
||||
// GPU — NVIDIA Pascal (cc 6.0)+, AMD RDNA1 (gfx1010)+, or Apple Silicon
|
||||
// (Metal, always wave32) — in the scheduler. In a mixed rig, unsupported
|
||||
// cards are skipped: older NVIDIA would hit the compute-capability assert in
|
||||
// tq-dequant.cu, and wave64 AMD (Vega/CDNA) would silently corrupt through
|
||||
// the HIP __shfl_sync shim.
|
||||
func (b *Backend) newTQContext(n int) *Context {
|
||||
var allocatedBuffers []C.ggml_backend_buffer_t
|
||||
scan := b.scanTQDevices()
|
||||
return &Context{
|
||||
b: b,
|
||||
ctx: C.ggml_init(C.struct_ggml_init_params{
|
||||
mem_size: C.size_t(n)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(n), false),
|
||||
no_alloc: true,
|
||||
}),
|
||||
buft: scan.selected,
|
||||
allocatedBuffers: &allocatedBuffers,
|
||||
maxGraphNodes: n,
|
||||
layer: -1,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Backend) NewContextSize(n int) ml.Context {
|
||||
if n > b.maxGraphNodes {
|
||||
panic(fmt.Errorf("requested number of graph nodes (%v) for new context exceeds maximum (%v)", n, b.maxGraphNodes))
|
||||
|
|
@ -683,6 +798,23 @@ func (b *Backend) NewContextSize(n int) ml.Context {
|
|||
}
|
||||
}
|
||||
|
||||
// SetTQRotationMatrix registers the TQ rotation matrix for Q rotation in SDPA.
|
||||
// Called by TurboQuantCache when Phase 2 CUDA dequant activates.
|
||||
func (b *Backend) SetTQRotationMatrix(m ml.Tensor) {
|
||||
b.tqRotationMatrix = m
|
||||
}
|
||||
|
||||
// SetTQVRotationMatrix registers the rotation matrix used for V encoding.
|
||||
// When non-nil, SDPA applies R @ attn_out after the TQ fused flash attention
|
||||
// to undo the V rotation (V was stored as R^T @ v).
|
||||
func (b *Backend) SetTQVRotationMatrix(m ml.Tensor) {
|
||||
b.tqVRotationMatrix = m
|
||||
}
|
||||
|
||||
func (b *Backend) SetTQVRotFusedInDequant(fused bool) {
|
||||
b.tqVRotFusedInDequant = fused
|
||||
}
|
||||
|
||||
func (b *Backend) CacheConfig() ml.CacheConfig {
|
||||
if b.flashAttention == ml.FlashAttentionEnabled {
|
||||
return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16}
|
||||
|
|
@ -1118,6 +1250,8 @@ func (t *Tensor) DType() ml.DType {
|
|||
return ml.DTypeQ40
|
||||
case C.GGML_TYPE_I32:
|
||||
return ml.DTypeI32
|
||||
case C.GGML_TYPE_I8:
|
||||
return ml.DTypeI8
|
||||
case C.GGML_TYPE_MXFP4:
|
||||
return ml.DTypeMXFP4
|
||||
default:
|
||||
|
|
@ -1137,6 +1271,8 @@ func ggmlDType(dtype ml.DType) uint32 {
|
|||
return C.GGML_TYPE_Q4_0
|
||||
case ml.DTypeI32:
|
||||
return C.GGML_TYPE_I32
|
||||
case ml.DTypeI8:
|
||||
return C.GGML_TYPE_I8
|
||||
case ml.DTypeMXFP4:
|
||||
return C.GGML_TYPE_MXFP4
|
||||
default:
|
||||
|
|
@ -1367,6 +1503,191 @@ func (t *Tensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tenso
|
|||
}
|
||||
}
|
||||
|
||||
// TQEncode creates a GGML_OP_TQ_ENCODE graph node.
|
||||
// t = packed buffer [packedBytes*numKVHeads, capacity] i8 (dst, view returned)
|
||||
// scales = [numKVHeads, capacity] f32 (written as side output via src[3])
|
||||
// k = [headDim, numKVHeads, batchSize] f16
|
||||
// rot = [headDim, headDim] f32 (R^T row-major)
|
||||
// cidx = [batchSize] i32
|
||||
// bounds = [(1<<bits)-1] f32
|
||||
func (t *Tensor) TQEncode(ctx ml.Context, scales, k, rot ml.Tensor, firstCell int, bounds ml.Tensor, bits int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_tq_encode(
|
||||
ctx.(*Context).ctx,
|
||||
t.t,
|
||||
scales.(*Tensor).t,
|
||||
k.(*Tensor).t,
|
||||
rot.(*Tensor).t,
|
||||
C.int32_t(firstCell),
|
||||
bounds.(*Tensor).t,
|
||||
C.int32_t(bits),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// TQEncodeV creates a GGML_OP_TQ_ENCODE_V graph node.
|
||||
// t = packed buffer [packedBytes*numKVHeads, capacity] i8 (dst, view returned)
|
||||
// scales = [numKVHeads, capacity] f32 (written as side output via src[3])
|
||||
// v = [headDim, numKVHeads, batchSize] f16 or f32
|
||||
// rot = [headDim, headDim] f32 R^T row-major, or nil (no rotation)
|
||||
// bounds = [(1<<bits)-1] f32
|
||||
func (t *Tensor) TQEncodeV(ctx ml.Context, scales, v ml.Tensor, rot ml.Tensor, firstCell int, bounds ml.Tensor, bits int) ml.Tensor {
|
||||
var rotT *C.struct_ggml_tensor
|
||||
if rot != nil {
|
||||
rotT = rot.(*Tensor).t
|
||||
}
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_tq_encode_v(
|
||||
ctx.(*Context).ctx,
|
||||
t.t,
|
||||
scales.(*Tensor).t,
|
||||
v.(*Tensor).t,
|
||||
rotT,
|
||||
C.int32_t(firstCell),
|
||||
bounds.(*Tensor).t,
|
||||
C.int32_t(bits),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// TQEncodeKV creates a GGML_OP_TQ_ENCODE_KV graph node encoding both K and V
|
||||
// in a single GGML op. t = K packed buffer (view returned). V packed buffer
|
||||
// is written as a side effect via src[5].
|
||||
func (t *Tensor) TQEncodeKV(ctx ml.Context,
|
||||
kScales, k, rot, kBounds ml.Tensor,
|
||||
vPacked, vScales, v, vBounds ml.Tensor,
|
||||
firstCell, kBits, vBits int,
|
||||
) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_tq_encode_kv(
|
||||
ctx.(*Context).ctx,
|
||||
t.t,
|
||||
kScales.(*Tensor).t,
|
||||
k.(*Tensor).t,
|
||||
rot.(*Tensor).t,
|
||||
kBounds.(*Tensor).t,
|
||||
vPacked.(*Tensor).t,
|
||||
vScales.(*Tensor).t,
|
||||
v.(*Tensor).t,
|
||||
vBounds.(*Tensor).t,
|
||||
C.int32_t(firstCell),
|
||||
C.int32_t(kBits),
|
||||
C.int32_t(vBits),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// TQDequant creates a GGML_OP_TQ_DEQUANT graph node.
|
||||
// t = encode result (view of packed, establishes graph dependency)
|
||||
// scales = [numKVHeads, capacity] f32
|
||||
// codebook = [1<<bits] f32
|
||||
// Returns [headDim, numKVHeads, nCells] f16.
|
||||
func (t *Tensor) TQDequant(ctx ml.Context, scales, codebook ml.Tensor, headDim, numKVHeads, nCells, firstCell, bits int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_tq_dequant(
|
||||
ctx.(*Context).ctx,
|
||||
t.t,
|
||||
scales.(*Tensor).t,
|
||||
codebook.(*Tensor).t,
|
||||
C.int(headDim),
|
||||
C.int(numKVHeads),
|
||||
C.int(nCells),
|
||||
C.int(firstCell),
|
||||
C.int(bits),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// TQEncodeOutlier creates a GGML_OP_TQ_ENCODE graph node with the outlier
|
||||
// sub-block extension. op_params[3] (outlier_count) > 0 signals to the CUDA
|
||||
// dispatcher that the outlier kernel should run.
|
||||
func (t *Tensor) TQEncodeOutlier(ctx ml.Context, scales, k, rot ml.Tensor, firstCell int, bounds ml.Tensor, bits int,
|
||||
outlierPacked, outlierScales, outlierIndices, outlierBounds ml.Tensor, outlierBits, outlierCount int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_tq_encode_outlier(
|
||||
ctx.(*Context).ctx,
|
||||
t.t,
|
||||
scales.(*Tensor).t,
|
||||
k.(*Tensor).t,
|
||||
rot.(*Tensor).t,
|
||||
C.int32_t(firstCell),
|
||||
bounds.(*Tensor).t,
|
||||
C.int32_t(bits),
|
||||
outlierPacked.(*Tensor).t,
|
||||
outlierScales.(*Tensor).t,
|
||||
outlierIndices.(*Tensor).t,
|
||||
outlierBounds.(*Tensor).t,
|
||||
C.int32_t(outlierBits),
|
||||
C.int32_t(outlierCount),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// TQDequantOutlier creates a GGML_OP_TQ_DEQUANT graph node with the outlier
|
||||
// overwrite pass. op_params[3] (outlier_count) > 0 signals the dispatcher.
|
||||
func (t *Tensor) TQDequantOutlier(ctx ml.Context, scales, codebook ml.Tensor, headDim, numKVHeads, nCells, firstCell, bits int,
|
||||
outlierPacked, outlierScales, outlierIndices, outlierCodebook ml.Tensor, outlierBits, outlierCount int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_tq_dequant_outlier(
|
||||
ctx.(*Context).ctx,
|
||||
t.t,
|
||||
scales.(*Tensor).t,
|
||||
codebook.(*Tensor).t,
|
||||
C.int(headDim),
|
||||
C.int(numKVHeads),
|
||||
C.int(nCells),
|
||||
C.int(firstCell),
|
||||
C.int(bits),
|
||||
outlierPacked.(*Tensor).t,
|
||||
outlierScales.(*Tensor).t,
|
||||
outlierIndices.(*Tensor).t,
|
||||
outlierCodebook.(*Tensor).t,
|
||||
C.int32_t(outlierBits),
|
||||
C.int32_t(outlierCount),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// TQDequantKV creates a GGML_OP_TQ_DEQUANT_KV graph node that dequants both
|
||||
// K and V in a single GGML op. Returns a [headDim, numKVHeads, nCells, 2] f16
|
||||
// tensor; the caller splits it into K (ne[3]=0) and V (ne[3]=1) views.
|
||||
func TQDequantKV(ctx ml.Context, b *Backend,
|
||||
kEncode, kScales, kCodebook *Tensor,
|
||||
vEncode, vScales, vCodebook *Tensor,
|
||||
vRotation *Tensor,
|
||||
headDim, numKVHeads, nCells, firstCell, kBits, vBits int,
|
||||
) *Tensor {
|
||||
var vRotT *C.struct_ggml_tensor
|
||||
if vRotation != nil {
|
||||
vRotT = vRotation.t
|
||||
}
|
||||
return &Tensor{
|
||||
b: b,
|
||||
t: C.ggml_tq_dequant_kv(
|
||||
ctx.(*Context).ctx,
|
||||
kEncode.t,
|
||||
kScales.t,
|
||||
kCodebook.t,
|
||||
vEncode.t,
|
||||
vScales.t,
|
||||
vCodebook.t,
|
||||
vRotT,
|
||||
C.int(headDim),
|
||||
C.int(numKVHeads),
|
||||
C.int(nCells),
|
||||
C.int(firstCell),
|
||||
C.int(kBits),
|
||||
C.int(vBits),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) SetInplace(ctx ml.Context, src ml.Tensor, nb1, nb2, nb3, offset int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
|
|
@ -1744,9 +2065,60 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
|
|||
}
|
||||
|
||||
query := t.Permute(ctx, 0, 2, 1, 3)
|
||||
|
||||
// TQ consume-once state: TurboQuantCache.Get() sets tqRotationMatrix and
|
||||
// tqVRotationMatrix right before returning rotated K/V for a TQ-wrapped
|
||||
// layer. Read and clear here so the NEXT SDPA call (potentially for a
|
||||
// non-TQ sub-cache of a WrapperCache, e.g. the SWA side of gemma3) doesn't
|
||||
// see a stale rotation and corrupt attention on unrotated tensors. The
|
||||
// nil pre-check keeps non-TQ workloads from dirtying the backend cache
|
||||
// line on every attention op.
|
||||
var rot, vRot ml.Tensor
|
||||
if t.b.tqRotationMatrix != nil || t.b.tqVRotationMatrix != nil {
|
||||
rot = t.b.tqRotationMatrix
|
||||
vRot = t.b.tqVRotationMatrix
|
||||
t.b.tqRotationMatrix = nil
|
||||
t.b.tqVRotationMatrix = nil
|
||||
}
|
||||
|
||||
// TQ: K is stored in rotated space (R^T @ k). Rotate Q to match so
|
||||
// attention = (R^T q)^T (R^T k) = q^T k.
|
||||
// rotTensor stores R^T row-major; ggml_mul_mat(rotTensor, Q) = R^T @ Q.
|
||||
if rot != nil && query.Dim(0) == rot.Dim(0) {
|
||||
// Make query contiguous before mul_mat; permuted (non-contiguous) tensors
|
||||
// may cause incorrect results with cuBLAS batched matmul.
|
||||
query = query.Contiguous(ctx)
|
||||
query = rot.Mulmat(ctx, query)
|
||||
}
|
||||
|
||||
key = key.Permute(ctx, 0, 2, 1, 3)
|
||||
|
||||
if t.b.flashAttention == ml.FlashAttentionEnabled {
|
||||
// TQ fused flash attention: check for tqTensor BEFORE permuting value,
|
||||
// because the K+V fused path passes packed V directly (no permute needed).
|
||||
if tqk, ok := key.(*tqTensor); ok {
|
||||
if sinks != nil || vmla != nil {
|
||||
panic("ggml: TQ compressed K does not support sinks or vmla attention")
|
||||
}
|
||||
var attnOut ml.Tensor
|
||||
if tqk.vPacked != nil {
|
||||
// K+V fused: V is packed i8 inside tqTensor; pass it directly.
|
||||
attnOut = t.b.tqFlashAttention(ctx, query.(*Tensor), tqk, tqk.vPacked, mask, scale, 0)
|
||||
} else {
|
||||
// K-only fused: V is f16, permute normally.
|
||||
value = value.Permute(ctx, 0, 2, 1, 3)
|
||||
attnOut = t.b.tqFlashAttention(ctx, query.(*Tensor), tqk, value.(*Tensor), mask, scale, 0)
|
||||
}
|
||||
// If V was encoded with Hadamard rotation (R^T @ v), the FA output is
|
||||
// R^T @ output_orig. Recover output_orig by applying R.
|
||||
// tqVRotationMatrix stores R (not R^T); mul_mat(R, x) = R @ x = output_orig.
|
||||
// Uses the consumed vRot local from the top of SDPA.
|
||||
if vRot != nil && attnOut.Dim(0) == vRot.Dim(0) {
|
||||
attnOut = vRot.(*Tensor).Mulmat(ctx, attnOut)
|
||||
}
|
||||
return attnOut
|
||||
}
|
||||
|
||||
value = value.Permute(ctx, 0, 2, 1, 3)
|
||||
|
||||
kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
|
||||
|
|
@ -1764,7 +2136,17 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
|
|||
kqv = cur.(*Tensor).t
|
||||
}
|
||||
|
||||
return &Tensor{b: t.b, t: kqv}
|
||||
attnOut := ml.Tensor(&Tensor{b: t.b, t: kqv})
|
||||
|
||||
// Two-pass TQ path: if DequantKV fused the V rotation undo into the
|
||||
// dequant kernel, V is already unrotated — skip the mulmat.
|
||||
// Otherwise (no rotation fusion), apply R @ attn_out to undo rotation.
|
||||
// Uses the consumed vRot local from the top of SDPA.
|
||||
if vRot != nil && !t.b.tqVRotFusedInDequant && attnOut.Dim(0) == vRot.Dim(0) {
|
||||
attnOut = vRot.(*Tensor).Mulmat(ctx, attnOut)
|
||||
}
|
||||
|
||||
return attnOut
|
||||
} else {
|
||||
kq := key.MulmatFullPrec(ctx, query)
|
||||
kq = &Tensor{
|
||||
|
|
|
|||
155
ml/backend/ggml/ggml/include/ggml.h
vendored
155
ml/backend/ggml/ggml/include/ggml.h
vendored
|
|
@ -567,6 +567,13 @@ extern "C" {
|
|||
|
||||
GGML_OP_GLU,
|
||||
|
||||
GGML_OP_TQ_ENCODE,
|
||||
GGML_OP_TQ_DEQUANT,
|
||||
GGML_OP_TQ_DEQUANT_KV,
|
||||
GGML_OP_TQ_FLASH_ATTN_EXT,
|
||||
GGML_OP_TQ_ENCODE_V,
|
||||
GGML_OP_TQ_ENCODE_KV,
|
||||
|
||||
GGML_OP_COUNT,
|
||||
};
|
||||
|
||||
|
|
@ -2714,6 +2721,154 @@ extern "C" {
|
|||
GGML_API void ggml_threadpool_params_init (struct ggml_threadpool_params * p, int n_threads);
|
||||
GGML_API bool ggml_threadpool_params_match (const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1);
|
||||
|
||||
// TurboQuant GPU-native key encoding: f16 K + rotation → packed N-bit + scales.
|
||||
// packed: [packedBytes*numKVHeads, capacity] GGML_TYPE_I8
|
||||
// scales: [numKVHeads, capacity] f32 (written as side output via src[3])
|
||||
// k: [headDim, numKVHeads, batchSize] f16
|
||||
// rotation: [headDim, headDim] f32 (R^T row-major)
|
||||
// cell_idx: [batchSize] i32
|
||||
// boundaries: [(1<<bits)-1] f32
|
||||
GGML_API struct ggml_tensor * ggml_tq_encode(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * packed,
|
||||
struct ggml_tensor * scales,
|
||||
struct ggml_tensor * k,
|
||||
struct ggml_tensor * rotation,
|
||||
int32_t firstCell,
|
||||
struct ggml_tensor * boundaries,
|
||||
int32_t bits);
|
||||
|
||||
// TurboQuant GPU-native key dequant: packed N-bit → f16 [headDim, numKVHeads, nCells].
|
||||
// encode_result: view of packed buffer returned by ggml_tq_encode (establishes graph
|
||||
// dependency: encode must run before dequant)
|
||||
// scales: [numKVHeads, capacity] f32
|
||||
// codebook: [1<<bits] f32
|
||||
GGML_API struct ggml_tensor * ggml_tq_dequant(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * encode_result,
|
||||
struct ggml_tensor * scales,
|
||||
struct ggml_tensor * codebook,
|
||||
int headDim,
|
||||
int numKVHeads,
|
||||
int nCells,
|
||||
int firstCell,
|
||||
int bits);
|
||||
|
||||
// TurboQuant GPU-native key encode with outlier split: same op family as
|
||||
// ggml_tq_encode but adds top-K outlier sub-block. After rotation, the top
|
||||
// outlier_count channels by absolute magnitude go into a higher-bit sub-block;
|
||||
// the remaining channels go into a lower-bit regular sub-block. Matches the
|
||||
// TurboQuant paper (arXiv 2504.19874, Sec 4.3) experimental setup. outlier_count
|
||||
// is stored in op_params[3] so the CUDA backend dispatcher can branch on it.
|
||||
// packed, scales, outlier_packed, outlier_scales, outlier_indices are written
|
||||
// as side effects (persistent GPU buffers, not graph intermediates).
|
||||
GGML_API struct ggml_tensor * ggml_tq_encode_outlier(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * packed,
|
||||
struct ggml_tensor * scales,
|
||||
struct ggml_tensor * k,
|
||||
struct ggml_tensor * rotation,
|
||||
int32_t firstCell,
|
||||
struct ggml_tensor * boundaries,
|
||||
int32_t bits,
|
||||
struct ggml_tensor * outlier_packed,
|
||||
struct ggml_tensor * outlier_scales,
|
||||
struct ggml_tensor * outlier_indices,
|
||||
struct ggml_tensor * outlier_boundaries,
|
||||
int32_t outlier_bits,
|
||||
int32_t outlier_count);
|
||||
|
||||
// TurboQuant GPU-native key dequant with outlier overwrite pass: reconstructs
|
||||
// [headDim, numKVHeads, nCells] f16 by decoding the regular sub-block for all
|
||||
// positions and then overwriting the outlier channels from the outlier
|
||||
// sub-block. Paired with ggml_tq_encode_outlier.
|
||||
GGML_API struct ggml_tensor * ggml_tq_dequant_outlier(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * encode_result,
|
||||
struct ggml_tensor * scales,
|
||||
struct ggml_tensor * codebook,
|
||||
int headDim,
|
||||
int numKVHeads,
|
||||
int nCells,
|
||||
int firstCell,
|
||||
int bits,
|
||||
struct ggml_tensor * outlier_packed,
|
||||
struct ggml_tensor * outlier_scales,
|
||||
struct ggml_tensor * outlier_indices,
|
||||
struct ggml_tensor * outlier_codebook,
|
||||
int32_t outlier_bits,
|
||||
int32_t outlier_count);
|
||||
|
||||
// Combined K+V dequant: single op dequants both K and V packed buffers.
|
||||
// Output: [headDim, numKVHeads, nCells, 2] f16 where ne[3]=0 is K, ne[3]=1 is V.
|
||||
// Halves scheduler overhead vs separate DequantK + DequantV.
|
||||
GGML_API struct ggml_tensor * ggml_tq_dequant_kv(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * k_encode_result,
|
||||
struct ggml_tensor * k_scales,
|
||||
struct ggml_tensor * k_codebook,
|
||||
struct ggml_tensor * v_encode_result,
|
||||
struct ggml_tensor * v_scales,
|
||||
struct ggml_tensor * v_codebook,
|
||||
struct ggml_tensor * v_rotation, // R [headDim, headDim] f32 — undo V rotation during dequant; NULL = no rotation
|
||||
int headDim,
|
||||
int numKVHeads,
|
||||
int nCells,
|
||||
int firstCell,
|
||||
int k_bits,
|
||||
int v_bits);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_tq_flash_attn_ext(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * q,
|
||||
struct ggml_tensor * k_packed,
|
||||
struct ggml_tensor * v, // f16 when v_scales==NULL; packed i8 when v_scales!=NULL
|
||||
struct ggml_tensor * mask,
|
||||
struct ggml_tensor * scales, // K scales [nKVHeads, capacity] f32
|
||||
struct ggml_tensor * codebook, // K codebook [1<<bits] f32
|
||||
float scale, float logit_softcap,
|
||||
int32_t bits, int32_t firstCell,
|
||||
struct ggml_tensor * v_scales, // NULL → V is f16; non-NULL → V is packed i8
|
||||
struct ggml_tensor * v_codebook, // V codebook (required when v_scales != NULL)
|
||||
int32_t v_bits); // V bit width (ignored when v_scales == NULL)
|
||||
|
||||
// TurboQuant GPU-native value encode: V → RMS scale → Lloyd-Max quantize → packed bits.
|
||||
// Like ggml_tq_encode but for V. rotation may be NULL (no rotation) or a [headDim, headDim]
|
||||
// R^T matrix to apply before quantization. Using the same rotation as K spreads outlier
|
||||
// energy evenly so the Lloyd-Max codebook applies correctly.
|
||||
// packed: [packedBytes*numKVHeads, capacity] i8 (destination buffer, view returned)
|
||||
// scales: [numKVHeads, capacity] f32 (scale per head per cell)
|
||||
// v: [headDim, numKVHeads, batchSize] f16 or f32
|
||||
// rotation: [headDim, headDim] f32 R^T row-major, or NULL
|
||||
// boundaries: [(1<<bits)-1] f32
|
||||
GGML_API struct ggml_tensor * ggml_tq_encode_v(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * packed,
|
||||
struct ggml_tensor * scales,
|
||||
struct ggml_tensor * v,
|
||||
struct ggml_tensor * rotation,
|
||||
int32_t firstCell,
|
||||
struct ggml_tensor * boundaries,
|
||||
int32_t bits);
|
||||
|
||||
// Combined K+V encode: single GGML op encoding both K and V, halving scheduler overhead.
|
||||
// k_packed: [packedBytes*numKVHeads, capacity] i8 (K destination, view returned)
|
||||
// v_packed: [packedBytes*numKVHeads, capacity] i8 (V destination, side-effect write)
|
||||
GGML_API struct ggml_tensor * ggml_tq_encode_kv(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * k_packed,
|
||||
struct ggml_tensor * k_scales,
|
||||
struct ggml_tensor * k,
|
||||
struct ggml_tensor * rotation,
|
||||
struct ggml_tensor * k_boundaries,
|
||||
struct ggml_tensor * v_packed,
|
||||
struct ggml_tensor * v_scales,
|
||||
struct ggml_tensor * v,
|
||||
struct ggml_tensor * v_boundaries,
|
||||
int32_t firstCell,
|
||||
int32_t k_bits,
|
||||
int32_t v_bits);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
|||
7
ml/backend/ggml/ggml/src/ggml-backend.cpp
vendored
7
ml/backend/ggml/ggml/src/ggml-backend.cpp
vendored
|
|
@ -692,7 +692,12 @@ static bool ggml_is_view_op(enum ggml_op op) {
|
|||
#endif
|
||||
|
||||
#ifndef GGML_SCHED_MAX_SPLIT_INPUTS
|
||||
#define GGML_SCHED_MAX_SPLIT_INPUTS 30
|
||||
// Increased from 30 to 128 to support TurboQuant K+V compression on large MoE
|
||||
// models (e.g. qwen3-coder:30b/48-layer) where per-layer TQ encode ops and
|
||||
// MoE expert routing create more cross-backend split inputs than the original
|
||||
// limit allows. Upstream GGML has a FIXME here: the check only fires when the
|
||||
// split is exactly full, so multi-input ops can overshoot the limit.
|
||||
#define GGML_SCHED_MAX_SPLIT_INPUTS 128
|
||||
#endif
|
||||
|
||||
#ifndef GGML_SCHED_MAX_COPIES
|
||||
|
|
|
|||
23
ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c
vendored
23
ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c
vendored
|
|
@ -2080,6 +2080,20 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|||
{
|
||||
// nop
|
||||
} break;
|
||||
case GGML_OP_TQ_ENCODE:
|
||||
case GGML_OP_TQ_DEQUANT:
|
||||
case GGML_OP_TQ_DEQUANT_KV:
|
||||
case GGML_OP_TQ_FLASH_ATTN_EXT:
|
||||
case GGML_OP_TQ_ENCODE_V:
|
||||
case GGML_OP_TQ_ENCODE_KV:
|
||||
{
|
||||
// CUDA-only ops. If these reach CPU, it means the scheduler
|
||||
// incorrectly assigned them to CPU — abort to diagnose.
|
||||
fprintf(stderr, "[FATAL] TQ op %s reached CPU backend! This is a scheduler bug.\n",
|
||||
ggml_op_name(tensor->op));
|
||||
fflush(stderr);
|
||||
GGML_ABORT("TQ op reached CPU backend");
|
||||
} break;
|
||||
case GGML_OP_COUNT:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
|
|
@ -2401,6 +2415,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|||
{
|
||||
n_tasks = 1;
|
||||
} break;
|
||||
case GGML_OP_TQ_ENCODE:
|
||||
case GGML_OP_TQ_DEQUANT:
|
||||
case GGML_OP_TQ_DEQUANT_KV:
|
||||
case GGML_OP_TQ_FLASH_ATTN_EXT:
|
||||
case GGML_OP_TQ_ENCODE_V:
|
||||
case GGML_OP_TQ_ENCODE_KV:
|
||||
{
|
||||
n_tasks = 1; // CUDA-only; handled as no-op in compute_forward
|
||||
} break;
|
||||
case GGML_OP_COUNT:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
|
|
|
|||
|
|
@ -426,7 +426,16 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
|
|||
}
|
||||
}
|
||||
|
||||
// TQ ops are CUDA-only; CPU has no implementation.
|
||||
// Returning false here prevents the scheduler from routing them to CPU in pass 3.
|
||||
switch (op->op) {
|
||||
case GGML_OP_TQ_ENCODE:
|
||||
case GGML_OP_TQ_DEQUANT:
|
||||
case GGML_OP_TQ_DEQUANT_KV:
|
||||
case GGML_OP_TQ_FLASH_ATTN_EXT:
|
||||
case GGML_OP_TQ_ENCODE_V:
|
||||
case GGML_OP_TQ_ENCODE_KV:
|
||||
return false;
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_SET_ROWS:
|
||||
return
|
||||
|
|
|
|||
28
ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
vendored
28
ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
vendored
|
|
@ -58,6 +58,10 @@
|
|||
#include "ggml-cuda/tri.cuh"
|
||||
#include "ggml-cuda/cumsum.cuh"
|
||||
#include "ggml-cuda/fill.cuh"
|
||||
#include "ggml-cuda/tq-encode.cuh"
|
||||
#include "ggml-cuda/tq-dequant.cuh"
|
||||
#include "ggml-cuda/tq-fattn.cuh"
|
||||
#include "ggml-cuda/tq-encode-v.cuh"
|
||||
#include "ggml.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
|
@ -2872,6 +2876,24 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||
case GGML_OP_FILL:
|
||||
ggml_cuda_op_fill(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_TQ_ENCODE:
|
||||
ggml_cuda_tq_encode(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_TQ_DEQUANT:
|
||||
ggml_cuda_tq_dequant(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_TQ_DEQUANT_KV:
|
||||
ggml_cuda_tq_dequant_kv(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_TQ_FLASH_ATTN_EXT:
|
||||
ggml_cuda_tq_flash_attn_ext(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_TQ_ENCODE_V:
|
||||
ggml_cuda_tq_encode_v(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_TQ_ENCODE_KV:
|
||||
ggml_cuda_tq_encode_kv(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
|
@ -4910,6 +4932,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_OP_TRI:
|
||||
case GGML_OP_DIAG:
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
case GGML_OP_TQ_ENCODE:
|
||||
case GGML_OP_TQ_DEQUANT:
|
||||
case GGML_OP_TQ_DEQUANT_KV:
|
||||
case GGML_OP_TQ_FLASH_ATTN_EXT:
|
||||
case GGML_OP_TQ_ENCODE_V:
|
||||
case GGML_OP_TQ_ENCODE_KV:
|
||||
return true;
|
||||
|
||||
default:
|
||||
|
|
|
|||
394
ml/backend/ggml/ggml/src/ggml-cuda/tq-dequant.cu
vendored
Normal file
394
ml/backend/ggml/ggml/src/ggml-cuda/tq-dequant.cu
vendored
Normal file
|
|
@ -0,0 +1,394 @@
|
|||
#include "tq-dequant.cuh"
|
||||
|
||||
// Optimized TQ dequant kernel: warp-shuffle codebook + hardcoded bit extraction.
|
||||
//
|
||||
// Grid: (nCells, numKVHeads). Block: 128 threads.
|
||||
// For D=128 and 128 threads: each thread decodes exactly 1 element.
|
||||
// Codebook lookup via __shfl_sync eliminates global memory reads.
|
||||
// Output is written as f16.
|
||||
//
|
||||
// This kernel is the "separate dequant" path — paired with the stock f16
|
||||
// flash attention kernel, it avoids injecting decode ALU into the
|
||||
// bandwidth-bound FA loop.
|
||||
|
||||
#if __CUDA_ARCH__ >= 600 || !defined(__CUDA_ARCH__)
|
||||
__global__ void tq_dequant_multihead_kernel(
|
||||
const uint8_t *packed, // [(firstCell+c)*numKVHeads+h]*packed_bytes
|
||||
const float *scales, // [(firstCell+c)*numKVHeads+h]
|
||||
const float *codebook, // [codebook_len]
|
||||
uint16_t *output, // [nCells * numKVHeads * headDim] f16
|
||||
int headDim,
|
||||
int numKVHeads,
|
||||
int bits,
|
||||
int packed_bytes,
|
||||
int codebook_len,
|
||||
int firstCell
|
||||
) {
|
||||
int c = blockIdx.x; // cell index within [0, nCells)
|
||||
int h = blockIdx.y; // head index within [0, numKVHeads)
|
||||
int cell = firstCell + c;
|
||||
|
||||
int slot = cell * numKVHeads + h;
|
||||
float scale = scales[slot];
|
||||
const uint8_t *cell_packed = packed + (size_t)slot * packed_bytes;
|
||||
__half *cell_out = (__half *)(output + ((size_t)c * numKVHeads + h) * headDim);
|
||||
|
||||
// Load one codebook entry per lane for warp-shuffle lookup.
|
||||
// For 3-bit (8 entries): lanes 0-7 hold codebook[0-7], repeated every 8 lanes.
|
||||
// For 2-bit (4 entries): lanes 0-3 hold codebook[0-3], repeated.
|
||||
const int cb_mask = (1 << bits) - 1;
|
||||
const float cb_lane = codebook[threadIdx.x & cb_mask];
|
||||
|
||||
for (int elem = threadIdx.x; elem < headDim; elem += blockDim.x) {
|
||||
// Generic bit extraction (handles any alignment).
|
||||
const int bit_offset = elem * bits;
|
||||
const int byte_idx = bit_offset >> 3;
|
||||
const int shift = bit_offset & 7;
|
||||
|
||||
int idx = (cell_packed[byte_idx] >> shift) & cb_mask;
|
||||
if (shift + bits > 8) {
|
||||
idx |= (cell_packed[byte_idx + 1] << (8 - shift)) & cb_mask;
|
||||
}
|
||||
|
||||
// Codebook lookup via warp shuffle: zero global memory latency.
|
||||
// Width = 32 works because cb_lane is periodic with period (1<<bits).
|
||||
// Pass width explicitly — the HIP __shfl_sync shim is a 4-arg macro
|
||||
// that doesn't default, and CUDA's 3-arg overload is width=warpSize=32
|
||||
// anyway, so this is behavior-neutral on NVIDIA.
|
||||
float val = __shfl_sync(0xFFFFFFFF, cb_lane, idx, 32) * scale;
|
||||
cell_out[elem] = __float2half_rn(val);
|
||||
}
|
||||
}
|
||||
#else
|
||||
// Stub for sm < 600 (no __shfl_sync). Never executed — the launcher
|
||||
// asserts compute capability >= 6.0. Only exists so the kernel launch
|
||||
// in ggml_cuda_tq_dequant compiles for Maxwell targets.
|
||||
__global__ void tq_dequant_multihead_kernel(
|
||||
const uint8_t *, const float *, const float *, uint16_t *,
|
||||
int, int, int, int, int, int) {}
|
||||
#endif
|
||||
|
||||
// ── outlier-split dequant ──────────────────────────────────────────────────
|
||||
// Paired with tq_encode_kernel_outlier. Reads the regular packed sub-block and
|
||||
// the outlier packed sub-block from two separate per-layer tensors, decodes
|
||||
// each with its own codebook/scale, and writes a single [headDim] f16 K
|
||||
// vector per (cell, head) to the output tensor.
|
||||
//
|
||||
// The outlier_indices tensor holds the top-K channel positions the encoder
|
||||
// selected. For each output position elem, the kernel scans outlier_indices
|
||||
// to determine whether elem is an outlier (and if so, which outlier slot k)
|
||||
// or a regular channel (and if so, its contiguous regular slot r = elem minus
|
||||
// the number of outlier indices less than elem). This scan is O(outlier_count)
|
||||
// per thread, which is cheap: at outlier_count=32 it's 32 compares per thread,
|
||||
// all in registers.
|
||||
#if __CUDA_ARCH__ >= 600 || !defined(__CUDA_ARCH__)
|
||||
__global__ void tq_dequant_multihead_kernel_outlier(
|
||||
const uint8_t *reg_packed,
|
||||
const float *reg_scales,
|
||||
const float *reg_codebook,
|
||||
const uint8_t *out_packed,
|
||||
const float *out_scales,
|
||||
const uint8_t *out_indices,
|
||||
const float *out_codebook,
|
||||
uint16_t *output,
|
||||
int headDim,
|
||||
int numKVHeads,
|
||||
int bits,
|
||||
int reg_packed_bytes,
|
||||
int outlier_bits,
|
||||
int outlier_count,
|
||||
int out_packed_bytes,
|
||||
int firstCell
|
||||
) {
|
||||
int c = blockIdx.x;
|
||||
int h = blockIdx.y;
|
||||
int cell = firstCell + c;
|
||||
int slot = cell * numKVHeads + h;
|
||||
|
||||
float regScale = reg_scales[slot];
|
||||
float outScale = out_scales[slot];
|
||||
const uint8_t *cell_reg = reg_packed + (size_t)slot * reg_packed_bytes;
|
||||
const uint8_t *cell_outl = out_packed + (size_t)slot * out_packed_bytes;
|
||||
const uint8_t *cell_idx = out_indices + (size_t)slot * outlier_count;
|
||||
__half *cell_out = (__half *)(output + ((size_t)c * numKVHeads + h) * headDim);
|
||||
|
||||
// Shared memory layout:
|
||||
// s_outl_slot[headDim] int8_t — outlier slot k or -1 if not outlier
|
||||
// s_mask[headDim / 32 (min 4)] uint32 — outlier bitmap (bit i = channel i outlier)
|
||||
//
|
||||
// Per-element `regular_slot` is computed at decode time via popcount over the
|
||||
// mask bits below `elem` — O(1) for headDim <= 256 (up to 8 popc ops). This
|
||||
// replaces the per-element O(outlier_count) classification scan.
|
||||
// Setup is fully parallel: no serial prefix sum, no idle threads.
|
||||
const int mask_words = (headDim + 31) >> 5; // e.g. 4 for headDim=128, 8 for 256
|
||||
extern __shared__ char s_mem_dq[];
|
||||
int8_t *s_outl_slot = (int8_t *)s_mem_dq;
|
||||
uint32_t *s_mask = (uint32_t *)(s_outl_slot + headDim);
|
||||
|
||||
// Step A: init s_outl_slot to -1 and s_mask to 0 (parallel over all threads).
|
||||
for (int i = threadIdx.x; i < headDim; i += blockDim.x) {
|
||||
s_outl_slot[i] = -1;
|
||||
}
|
||||
for (int w = threadIdx.x; w < mask_words; w += blockDim.x) {
|
||||
s_mask[w] = 0u;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Step B: threads 0..outlier_count-1 each register one outlier — write its
|
||||
// slot into s_outl_slot AND set its bit in s_mask. atomicOr needed because
|
||||
// two outliers can land in the same word.
|
||||
if ((int)threadIdx.x < outlier_count) {
|
||||
int pos = (int)cell_idx[threadIdx.x];
|
||||
s_outl_slot[pos] = (int8_t)threadIdx.x;
|
||||
atomicOr(&s_mask[pos >> 5], 1u << (pos & 31));
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Warp-shuffle register-resident codebooks.
|
||||
const int cb_mask = (1 << bits) - 1;
|
||||
const int ocb_mask = (1 << outlier_bits) - 1;
|
||||
const float cb_lane_reg = reg_codebook[threadIdx.x & cb_mask];
|
||||
const float cb_lane_out = out_codebook[threadIdx.x & ocb_mask];
|
||||
|
||||
for (int elem = threadIdx.x; elem < headDim; elem += blockDim.x) {
|
||||
int outlier_slot = (int)s_outl_slot[elem];
|
||||
|
||||
// regular_slot = elem - popcount(outlier_mask bits below elem).
|
||||
// Sum popcount of fully-covered 32-bit chunks, then partial chunk.
|
||||
int outliers_below = 0;
|
||||
const int full_words = elem >> 5;
|
||||
#pragma unroll
|
||||
for (int w = 0; w < 8; w++) { // hard-unrolled; loop body is a no-op past mask_words
|
||||
if (w < full_words && w < mask_words) {
|
||||
outliers_below += __popc(s_mask[w]);
|
||||
}
|
||||
}
|
||||
if (full_words < mask_words) {
|
||||
uint32_t partial_bits = (1u << (elem & 31)) - 1u;
|
||||
outliers_below += __popc(s_mask[full_words] & partial_bits);
|
||||
}
|
||||
int regular_slot = elem - outliers_below;
|
||||
|
||||
// Both sub-block decodes must execute unconditionally because
|
||||
// __shfl_sync with mask 0xFFFFFFFF requires every lane in the warp
|
||||
// to be at the same instruction. Putting a shuffle inside a divergent
|
||||
// if-branch is undefined behavior (observed as all-zero decodes for
|
||||
// some (cell, head) blocks on multi-head models). Compute both values
|
||||
// up front, then select.
|
||||
int reg_bit_offset = regular_slot * bits;
|
||||
int reg_byte_idx = reg_bit_offset >> 3;
|
||||
int reg_shift = reg_bit_offset & 7;
|
||||
int reg_idx = (cell_reg[reg_byte_idx] >> reg_shift) & cb_mask;
|
||||
if (reg_shift + bits > 8) {
|
||||
reg_idx |= (cell_reg[reg_byte_idx + 1] << (8 - reg_shift)) & cb_mask;
|
||||
}
|
||||
float reg_val = __shfl_sync(0xFFFFFFFF, cb_lane_reg, reg_idx, 32) * regScale;
|
||||
|
||||
int out_slot_safe = (outlier_slot >= 0) ? outlier_slot : 0;
|
||||
int out_bit_offset = out_slot_safe * outlier_bits;
|
||||
int out_byte_idx = out_bit_offset >> 3;
|
||||
int out_shift = out_bit_offset & 7;
|
||||
int out_idx = (cell_outl[out_byte_idx] >> out_shift) & ocb_mask;
|
||||
if (out_shift + outlier_bits > 8) {
|
||||
out_idx |= (cell_outl[out_byte_idx + 1] << (8 - out_shift)) & ocb_mask;
|
||||
}
|
||||
float out_val = __shfl_sync(0xFFFFFFFF, cb_lane_out, out_idx, 32) * outScale;
|
||||
|
||||
float val = (outlier_slot >= 0) ? out_val : reg_val;
|
||||
cell_out[elem] = __float2half_rn(val);
|
||||
}
|
||||
}
|
||||
#else
|
||||
__global__ void tq_dequant_multihead_kernel_outlier(
|
||||
const uint8_t *, const float *, const float *,
|
||||
const uint8_t *, const float *, const uint8_t *, const float *,
|
||||
uint16_t *, int, int, int, int, int, int, int, int) {}
|
||||
#endif
|
||||
|
||||
void ggml_cuda_tq_dequant(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
|
||||
GGML_ASSERT(ggml_cuda_info().devices[ctx.device].cc >= 600 &&
|
||||
"TurboQuant dequant requires compute capability 6.0+ (Pascal or newer)");
|
||||
const struct ggml_tensor * encode_result = dst->src[0]; // view of packed
|
||||
const struct ggml_tensor * scales = dst->src[1];
|
||||
const struct ggml_tensor * codebook = dst->src[2];
|
||||
|
||||
const int headDim = (int)dst->ne[0];
|
||||
const int numKVHeads = (int)dst->ne[1];
|
||||
const int nCells = (int)dst->ne[2];
|
||||
const int bits = (int)((const int32_t *)dst->op_params)[0];
|
||||
const int firstCell = (int)((const int32_t *)dst->op_params)[1];
|
||||
const int outlierBits = (int)((const int32_t *)dst->op_params)[2];
|
||||
const int outlierCount = (int)((const int32_t *)dst->op_params)[3];
|
||||
|
||||
dim3 grid(nCells, numKVHeads);
|
||||
int block_size = 128;
|
||||
if (headDim < block_size) block_size = headDim;
|
||||
|
||||
if (outlierCount > 0 && outlierBits > 0 && outlierCount < headDim) {
|
||||
// Outlier-split dequant: read both sub-blocks.
|
||||
const struct ggml_tensor * outlier_packed = dst->src[3];
|
||||
const struct ggml_tensor * outlier_scales = dst->src[4];
|
||||
const struct ggml_tensor * outlier_indices = dst->src[5];
|
||||
const struct ggml_tensor * outlier_codebook = dst->src[6];
|
||||
|
||||
const int regular_count = headDim - outlierCount;
|
||||
// Per-head stride for both packed tensors is padded up to a 4-byte
|
||||
// multiple so atomicOr-on-word stays aligned in the encode kernel.
|
||||
// The Go-side ggmlTQCompressedK.regularPackedBytes() applies the
|
||||
// same padding, so the kernel-visible layout matches the allocator.
|
||||
const int reg_packed_raw = (regular_count * bits + 7) / 8;
|
||||
const int reg_packed_bytes = (reg_packed_raw + 3) & ~3;
|
||||
const int out_packed_raw = (outlierCount * outlierBits + 7) / 8;
|
||||
const int out_packed_bytes = (out_packed_raw + 3) & ~3;
|
||||
|
||||
// Shared memory: s_outl_slot (headDim * int8) + s_mask (ceil(headDim/32) * u32).
|
||||
const int mask_words = (headDim + 31) >> 5;
|
||||
size_t smem = (size_t)headDim * sizeof(int8_t)
|
||||
+ (size_t)mask_words * sizeof(uint32_t);
|
||||
|
||||
tq_dequant_multihead_kernel_outlier<<<grid, block_size, smem, ctx.stream()>>>(
|
||||
(const uint8_t *)encode_result->data,
|
||||
(const float *)scales->data,
|
||||
(const float *)codebook->data,
|
||||
(const uint8_t *)outlier_packed->data,
|
||||
(const float *)outlier_scales->data,
|
||||
(const uint8_t *)outlier_indices->data,
|
||||
(const float *)outlier_codebook->data,
|
||||
(uint16_t *)dst->data,
|
||||
headDim, numKVHeads, bits, reg_packed_bytes,
|
||||
outlierBits, outlierCount, out_packed_bytes, firstCell
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const int packed_bytes = (headDim * bits + 7) / 8;
|
||||
const int codebook_len = (int)codebook->ne[0];
|
||||
|
||||
tq_dequant_multihead_kernel<<<grid, block_size, 0, ctx.stream()>>>(
|
||||
(const uint8_t *)encode_result->data,
|
||||
(const float *)scales->data,
|
||||
(const float *)codebook->data,
|
||||
(uint16_t *)dst->data,
|
||||
headDim, numKVHeads, bits, packed_bytes, codebook_len, firstCell
|
||||
);
|
||||
}
|
||||
|
||||
// V dequant with fused rotation undo: decode packed V to shared memory
|
||||
// (rotated domain), then multiply by R [headDim × headDim] to produce
|
||||
// unrotated f16 output. Eliminates the per-layer mulmat op in SDPA.
|
||||
//
|
||||
// Grid: (nCells, numKVHeads). Block: 128 (= headDim for D=128).
|
||||
// Shared memory: headDim floats = 512 bytes.
|
||||
#if __CUDA_ARCH__ >= 600 || !defined(__CUDA_ARCH__)
|
||||
__global__ void tq_dequant_v_rotated_kernel(
|
||||
const uint8_t * __restrict__ packed,
|
||||
const float * __restrict__ scales,
|
||||
const float * __restrict__ codebook,
|
||||
const float * __restrict__ rotation, // R [headDim, headDim] row-major
|
||||
uint16_t * __restrict__ output,
|
||||
int headDim, int numKVHeads, int bits, int packed_bytes,
|
||||
int codebook_len, int firstCell)
|
||||
{
|
||||
extern __shared__ float s_rotV[]; // headDim floats
|
||||
|
||||
int c = blockIdx.x;
|
||||
int h = blockIdx.y;
|
||||
int cell = firstCell + c;
|
||||
int slot = cell * numKVHeads + h;
|
||||
float scale = scales[slot];
|
||||
const uint8_t *cell_packed = packed + (size_t)slot * packed_bytes;
|
||||
|
||||
// Phase 1: decode one element per thread into shared memory (rotated domain).
|
||||
const int cb_mask = (1 << bits) - 1;
|
||||
const float cb_lane = codebook[threadIdx.x & cb_mask];
|
||||
|
||||
int elem = threadIdx.x;
|
||||
int bit_offset = elem * bits;
|
||||
int byte_idx = bit_offset >> 3;
|
||||
int shift = bit_offset & 7;
|
||||
int idx = (cell_packed[byte_idx] >> shift) & cb_mask;
|
||||
if (shift + bits > 8) {
|
||||
idx |= (cell_packed[byte_idx + 1] << (8 - shift)) & cb_mask;
|
||||
}
|
||||
s_rotV[elem] = __shfl_sync(0xFFFFFFFF, cb_lane, idx, 32) * scale;
|
||||
__syncthreads();
|
||||
|
||||
// Phase 2: each thread computes one output element = dot(R[elem,:], s_rotV).
|
||||
// R is in L2 (64KB, fits in P40's 3MB L2; read-only, broadcast across blocks).
|
||||
const float *R_row = rotation + elem * headDim;
|
||||
float sum = 0.0f;
|
||||
for (int j = 0; j < headDim; j++) {
|
||||
sum += R_row[j] * s_rotV[j];
|
||||
}
|
||||
|
||||
__half *cell_out = (__half *)(output + ((size_t)c * numKVHeads + h) * headDim);
|
||||
cell_out[elem] = __float2half_rn(sum);
|
||||
}
|
||||
#else
|
||||
// Stub for sm < 600 (no __shfl_sync). Not currently launched, but kept
|
||||
// so future call sites compile for Maxwell targets without special casing.
|
||||
__global__ void tq_dequant_v_rotated_kernel(
|
||||
const uint8_t * __restrict__, const float * __restrict__,
|
||||
const float * __restrict__, const float * __restrict__,
|
||||
uint16_t * __restrict__,
|
||||
int, int, int, int, int, int) {}
|
||||
#endif
|
||||
|
||||
// Combined K+V dequant: two back-to-back kernel launches in a single GGML op.
|
||||
// Output: [headDim, numKVHeads, nCells, 2] f16 — K at ne[3]=0, V at ne[3]=1.
|
||||
// When src[6] (v_rotation) is non-NULL, V is dequanted with rotation fused.
|
||||
void ggml_cuda_tq_dequant_kv(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
|
||||
GGML_ASSERT(ggml_cuda_info().devices[ctx.device].cc >= 600 &&
|
||||
"TurboQuant dequant requires compute capability 6.0+ (Pascal or newer)");
|
||||
const struct ggml_tensor * k_encode = dst->src[0];
|
||||
const struct ggml_tensor * k_scales = dst->src[1];
|
||||
const struct ggml_tensor * k_cb = dst->src[2];
|
||||
const struct ggml_tensor * v_encode = dst->src[3];
|
||||
const struct ggml_tensor * v_scales = dst->src[4];
|
||||
const struct ggml_tensor * v_cb = dst->src[5];
|
||||
const struct ggml_tensor * v_rotation = dst->src[6]; // NULL = no rotation fusion
|
||||
|
||||
const int headDim = (int)dst->ne[0];
|
||||
const int numKVHeads = (int)dst->ne[1];
|
||||
const int nCells = (int)dst->ne[2];
|
||||
|
||||
int32_t k_bits, v_bits, firstCell;
|
||||
memcpy(&k_bits, (const int32_t *)dst->op_params + 0, sizeof(int32_t));
|
||||
memcpy(&v_bits, (const int32_t *)dst->op_params + 1, sizeof(int32_t));
|
||||
memcpy(&firstCell, (const int32_t *)dst->op_params + 2, sizeof(int32_t));
|
||||
|
||||
const int k_packed_bytes = (headDim * k_bits + 7) / 8;
|
||||
const int v_packed_bytes = (headDim * v_bits + 7) / 8;
|
||||
|
||||
dim3 grid(nCells, numKVHeads);
|
||||
int block_size = 128;
|
||||
if (headDim < block_size) {
|
||||
block_size = headDim;
|
||||
}
|
||||
|
||||
const size_t plane_size = (size_t)headDim * numKVHeads * nCells;
|
||||
uint16_t * out_base = (uint16_t *)dst->data;
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
// K dequant → first plane (offset 0) — always unrotated
|
||||
tq_dequant_multihead_kernel<<<grid, block_size, 0, stream>>>(
|
||||
(const uint8_t *)k_encode->data,
|
||||
(const float *)k_scales->data,
|
||||
(const float *)k_cb->data,
|
||||
out_base,
|
||||
headDim, numKVHeads, k_bits, k_packed_bytes, (int)k_cb->ne[0], firstCell
|
||||
);
|
||||
|
||||
// V dequant → second plane (offset plane_size). Plain dequant only — the
|
||||
// rotation undo (R @ attn_out) is handled by SDPA via mulmat, which is
|
||||
// dramatically faster than the per-cell matmul the fused kernel did.
|
||||
(void)v_rotation;
|
||||
tq_dequant_multihead_kernel<<<grid, block_size, 0, stream>>>(
|
||||
(const uint8_t *)v_encode->data,
|
||||
(const float *)v_scales->data,
|
||||
(const float *)v_cb->data,
|
||||
out_base + plane_size,
|
||||
headDim, numKVHeads, v_bits, v_packed_bytes, (int)v_cb->ne[0], firstCell
|
||||
);
|
||||
}
|
||||
6
ml/backend/ggml/ggml/src/ggml-cuda/tq-dequant.cuh
vendored
Normal file
6
ml/backend/ggml/ggml/src/ggml-cuda/tq-dequant.cuh
vendored
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
#pragma once
|
||||
|
||||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_tq_dequant(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst);
|
||||
void ggml_cuda_tq_dequant_kv(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst);
|
||||
154
ml/backend/ggml/ggml/src/ggml-cuda/tq-encode-v.cu
vendored
Normal file
154
ml/backend/ggml/ggml/src/ggml-cuda/tq-encode-v.cu
vendored
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
#include "tq-encode-v.cuh"
|
||||
#include <math.h>
|
||||
|
||||
#define TQ_ENCODE_V_BLOCK_SIZE 128
|
||||
|
||||
__global__ void tq_encode_v_kernel(
|
||||
const void *v,
|
||||
const float *rotation, // [headDim, headDim] R^T row-major, or NULL for no rotation
|
||||
uint8_t *packed_out,
|
||||
float *scales_out,
|
||||
int firstCell,
|
||||
const float *boundaries,
|
||||
int headDim,
|
||||
int numKVHeads,
|
||||
int bits,
|
||||
int numBoundaries,
|
||||
int vIsF32
|
||||
) {
|
||||
int batch = blockIdx.x;
|
||||
int head = blockIdx.y;
|
||||
int cell = firstCell + batch;
|
||||
|
||||
extern __shared__ char s_mem[];
|
||||
float *s_v = (float *)s_mem;
|
||||
float *s_rot = s_v + headDim;
|
||||
float *s_reduce = s_rot + headDim;
|
||||
uint8_t *s_idx = (uint8_t *)(s_reduce + blockDim.x);
|
||||
|
||||
// Step 1: Load V[batch, head] into shared memory as f32
|
||||
int base_v = batch * numKVHeads * headDim + head * headDim;
|
||||
for (int d = threadIdx.x; d < headDim; d += blockDim.x) {
|
||||
if (vIsF32) {
|
||||
s_v[d] = ((const float *)v)[base_v + d];
|
||||
} else {
|
||||
s_v[d] = __half2float(__ushort_as_half(((const uint16_t *)v)[base_v + d]));
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Step 2: Apply Hadamard rotation if provided (R^T @ v spreads outlier energy evenly)
|
||||
float *s_input = s_v;
|
||||
if (rotation != NULL) {
|
||||
for (int i = threadIdx.x; i < headDim; i += blockDim.x) {
|
||||
float sum = 0.0f;
|
||||
for (int j = 0; j < headDim; j++) {
|
||||
sum += rotation[i * headDim + j] * s_v[j];
|
||||
}
|
||||
s_rot[i] = sum;
|
||||
}
|
||||
__syncthreads();
|
||||
s_input = s_rot;
|
||||
}
|
||||
|
||||
// Step 3: RMS scale = sqrt(mean(v^2))
|
||||
float local_sq = 0.0f;
|
||||
for (int i = threadIdx.x; i < headDim; i += blockDim.x) {
|
||||
local_sq += s_input[i] * s_input[i];
|
||||
}
|
||||
s_reduce[threadIdx.x] = local_sq;
|
||||
__syncthreads();
|
||||
|
||||
for (int stride = blockDim.x >> 1; stride > 0; stride >>= 1) {
|
||||
if (threadIdx.x < stride)
|
||||
s_reduce[threadIdx.x] += s_reduce[threadIdx.x + stride];
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
float scale = 0.0f;
|
||||
if (threadIdx.x == 0) {
|
||||
float sum_sq = s_reduce[0];
|
||||
if (sum_sq > 1e-12f)
|
||||
scale = sqrtf(sum_sq / (float)headDim);
|
||||
scales_out[cell * numKVHeads + head] = scale;
|
||||
s_reduce[0] = scale;
|
||||
}
|
||||
__syncthreads();
|
||||
scale = s_reduce[0];
|
||||
|
||||
// Step 4: Quantize via boundary binary search
|
||||
for (int i = threadIdx.x; i < headDim; i += blockDim.x) {
|
||||
float val = (scale > 0.0f) ? (s_input[i] / scale) : 0.0f;
|
||||
int idx = 0;
|
||||
for (int b = 0; b < numBoundaries; b++) {
|
||||
if (val >= boundaries[b]) idx++;
|
||||
}
|
||||
s_idx[i] = (uint8_t)idx;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Step 5: Pack bits LSB-first
|
||||
{
|
||||
int packed_bytes = (headDim * bits + 7) / 8;
|
||||
uint8_t *out = packed_out + (cell * numKVHeads + head) * packed_bytes;
|
||||
uint8_t bitmask = (uint8_t)((1 << bits) - 1);
|
||||
|
||||
for (int p = threadIdx.x; p < packed_bytes; p += blockDim.x) {
|
||||
out[p] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int elem = threadIdx.x; elem < headDim; elem += blockDim.x) {
|
||||
int bit_offset = elem * bits;
|
||||
int byte_idx = bit_offset >> 3;
|
||||
int shift = bit_offset & 7;
|
||||
uint8_t val = s_idx[elem] & bitmask;
|
||||
atomicOr((unsigned int *)(out + (byte_idx & ~3)),
|
||||
(unsigned int)(val << shift) << ((byte_idx & 3) * 8));
|
||||
if (shift + bits > 8) {
|
||||
int byte_idx2 = byte_idx + 1;
|
||||
atomicOr((unsigned int *)(out + (byte_idx2 & ~3)),
|
||||
(unsigned int)(val >> (8 - shift)) << ((byte_idx2 & 3) * 8));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_tq_encode_v(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
|
||||
const struct ggml_tensor * v = dst->src[0];
|
||||
const struct ggml_tensor * rotation = dst->src[1]; // NULL when no rotation
|
||||
const struct ggml_tensor * scales = dst->src[3];
|
||||
const struct ggml_tensor * boundaries = dst->src[4];
|
||||
|
||||
const int headDim = (int)v->ne[0];
|
||||
const int numKVHeads = (int)v->ne[1];
|
||||
const int batchSize = (int)v->ne[2];
|
||||
const int bits = (int)((const int32_t *)dst->op_params)[0];
|
||||
const int firstCell = (int)((const int32_t *)dst->op_params)[1];
|
||||
const int numBoundaries = (1 << bits) - 1;
|
||||
const int vIsF32 = (v->type == GGML_TYPE_F32) ? 1 : 0;
|
||||
|
||||
const float * rotation_ptr = rotation ? (const float *)rotation->data : nullptr;
|
||||
|
||||
dim3 grid(batchSize, numKVHeads);
|
||||
int block_size = (headDim < TQ_ENCODE_V_BLOCK_SIZE) ? headDim : TQ_ENCODE_V_BLOCK_SIZE;
|
||||
int bs = 1;
|
||||
while (bs < block_size) bs <<= 1;
|
||||
block_size = bs;
|
||||
|
||||
size_t smem = (size_t)headDim * 2 * sizeof(float) // s_v + s_rot
|
||||
+ (size_t)block_size * sizeof(float)
|
||||
+ (size_t)headDim * sizeof(uint8_t);
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
tq_encode_v_kernel<<<grid, block_size, smem, stream>>>(
|
||||
v->data,
|
||||
rotation_ptr,
|
||||
(uint8_t *)dst->data,
|
||||
(float *)scales->data,
|
||||
firstCell,
|
||||
(const float *)boundaries->data,
|
||||
headDim, numKVHeads, bits, numBoundaries, vIsF32
|
||||
);
|
||||
}
|
||||
5
ml/backend/ggml/ggml/src/ggml-cuda/tq-encode-v.cuh
vendored
Normal file
5
ml/backend/ggml/ggml/src/ggml-cuda/tq-encode-v.cuh
vendored
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
#pragma once
|
||||
|
||||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_tq_encode_v(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst);
|
||||
531
ml/backend/ggml/ggml/src/ggml-cuda/tq-encode.cu
vendored
Normal file
531
ml/backend/ggml/ggml/src/ggml-cuda/tq-encode.cu
vendored
Normal file
|
|
@ -0,0 +1,531 @@
|
|||
#include "tq-encode.cuh"
|
||||
#include <math.h>
|
||||
|
||||
#define TQ_ENCODE_BLOCK_SIZE 128
|
||||
|
||||
// ── kernel ──────────────────────────────────────────────────────────────────
|
||||
|
||||
__global__ void tq_encode_kernel(
|
||||
const void *k, // f16 or f32, ggml layout [headDim, numKVHeads, batchSize]
|
||||
const float *rotation, // [headDim, headDim] R^T stored row-major
|
||||
uint8_t *packed_out, // [(c*numKVHeads+h)*packedBytes] interleaved
|
||||
float *scales_out, // [c*numKVHeads+h] interleaved
|
||||
int firstCell, // first cache cell index (cell = firstCell + batch)
|
||||
const float *boundaries, // [numLevels-1]
|
||||
int headDim,
|
||||
int numKVHeads,
|
||||
int bits,
|
||||
int numBoundaries, // = (1<<bits) - 1
|
||||
int kIsF32 // non-zero when k is float32 (vs float16)
|
||||
) {
|
||||
int batch = blockIdx.x;
|
||||
int head = blockIdx.y;
|
||||
int cell = firstCell + batch;
|
||||
|
||||
// Shared memory layout:
|
||||
// s_k[headDim] – K values as f32
|
||||
// s_rot[headDim] – rotated values
|
||||
// s_reduce[BLOCK] – warp reduction scratch
|
||||
// s_idx[headDim] – quantized indices (uint8)
|
||||
extern __shared__ char s_mem[];
|
||||
float *s_k = (float *)s_mem;
|
||||
float *s_rot = s_k + headDim;
|
||||
float *s_reduce = s_rot + headDim;
|
||||
uint8_t *s_idx = (uint8_t *)(s_reduce + blockDim.x);
|
||||
|
||||
// ── Step 1: Load K[batch, head] into shared memory as f32 ────────────────
|
||||
// K layout: element (dim=d, head=h, batch=b) at b*numKVHeads*headDim + h*headDim + d
|
||||
int base_k = batch * numKVHeads * headDim + head * headDim;
|
||||
for (int d = threadIdx.x; d < headDim; d += blockDim.x) {
|
||||
if (kIsF32) {
|
||||
s_k[d] = ((const float *)k)[base_k + d];
|
||||
} else {
|
||||
s_k[d] = __half2float(__ushort_as_half(((const uint16_t *)k)[base_k + d]));
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ── Step 2: Rotation matmul: rotated[i] = Σ_j rotation[i*headDim+j] * s_k[j] ──
|
||||
// rotation stores Q^T row-major (rotTensor[i][j] = Q^T[i][j]).
|
||||
// This computes rotated = Q^T @ k.
|
||||
// Q is also rotated as Q^T @ q (via ggml_mul_mat(rotTensor, q)),
|
||||
// so attention = (Q^T q)^T (Q^T k) = q^T Q Q^T k = q^T k.
|
||||
for (int i = threadIdx.x; i < headDim; i += blockDim.x) {
|
||||
float sum = 0.0f;
|
||||
for (int j = 0; j < headDim; j++) {
|
||||
sum += rotation[i * headDim + j] * s_k[j];
|
||||
}
|
||||
s_rot[i] = sum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ── Step 3: RMS scale = sqrt(mean(rotated²)) ─────────────────────────────
|
||||
float local_sq = 0.0f;
|
||||
for (int i = threadIdx.x; i < headDim; i += blockDim.x) {
|
||||
local_sq += s_rot[i] * s_rot[i];
|
||||
}
|
||||
s_reduce[threadIdx.x] = local_sq;
|
||||
__syncthreads();
|
||||
|
||||
for (int stride = blockDim.x >> 1; stride > 0; stride >>= 1) {
|
||||
if (threadIdx.x < stride)
|
||||
s_reduce[threadIdx.x] += s_reduce[threadIdx.x + stride];
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
float scale = 0.0f;
|
||||
if (threadIdx.x == 0) {
|
||||
float sum_sq = s_reduce[0];
|
||||
if (sum_sq > 1e-12f)
|
||||
scale = sqrtf(sum_sq / (float)headDim);
|
||||
scales_out[cell * numKVHeads + head] = scale;
|
||||
s_reduce[0] = scale; // broadcast via shared mem
|
||||
}
|
||||
__syncthreads();
|
||||
scale = s_reduce[0];
|
||||
|
||||
// ── Step 4: Quantize each element via boundary binary search ─────────────
|
||||
for (int i = threadIdx.x; i < headDim; i += blockDim.x) {
|
||||
float v = (scale > 0.0f) ? (s_rot[i] / scale) : 0.0f;
|
||||
int idx = 0;
|
||||
for (int b = 0; b < numBoundaries; b++) {
|
||||
if (v >= boundaries[b]) idx++;
|
||||
}
|
||||
s_idx[i] = (uint8_t)idx;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ── Step 5: Pack bits LSB-first into output (parallel via atomicOr) ─────
|
||||
{
|
||||
int packed_bytes = (headDim * bits + 7) / 8;
|
||||
uint8_t *out = packed_out + (cell * numKVHeads + head) * packed_bytes;
|
||||
uint8_t bitmask = (uint8_t)((1 << bits) - 1);
|
||||
|
||||
// Zero output buffer in parallel.
|
||||
for (int p = threadIdx.x; p < packed_bytes; p += blockDim.x) {
|
||||
out[p] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Each thread packs its own elements using atomicOr on 4-byte aligned words.
|
||||
for (int elem = threadIdx.x; elem < headDim; elem += blockDim.x) {
|
||||
int bit_offset = elem * bits;
|
||||
int byte_idx = bit_offset >> 3;
|
||||
int shift = bit_offset & 7;
|
||||
uint8_t v = s_idx[elem] & bitmask;
|
||||
// Pack into 4-byte aligned word: position byte within 32-bit word.
|
||||
atomicOr((unsigned int *)(out + (byte_idx & ~3)),
|
||||
(unsigned int)(v << shift) << ((byte_idx & 3) * 8));
|
||||
if (shift + bits > 8) {
|
||||
int byte_idx2 = byte_idx + 1;
|
||||
atomicOr((unsigned int *)(out + (byte_idx2 & ~3)),
|
||||
(unsigned int)(v >> (8 - shift)) << ((byte_idx2 & 3) * 8));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── outlier-split kernel ─────────────────────────────────────────────────────
|
||||
//
|
||||
// Implements the TurboQuant paper's actual experimental configuration
|
||||
// (arXiv 2504.19874 Sec 4.3): split channels into a top-K outlier set and a
|
||||
// regular set, each encoded with its own RMS scale and codebook at different
|
||||
// bit widths. Outlier selection happens in ROTATED space (single rotation
|
||||
// matmul shared between both sub-blocks), not in the original space like the
|
||||
// CPU reference — this is the only paper-realistic algorithm we can run
|
||||
// cheaply on the GPU without a second rotation pass. For the near-orthogonal
|
||||
// rotations produced by QR on a Gaussian matrix, the top-K in rotated space
|
||||
// is a close proxy for the top-K in original space.
|
||||
|
||||
__global__ void tq_encode_kernel_outlier(
|
||||
const void *k, // f16 or f32, [headDim, numKVHeads, batchSize]
|
||||
const float *rotation, // [headDim, headDim] R^T row-major
|
||||
uint8_t *packed_out, // regular packed [regularPackedBytes*numKVHeads*cells]
|
||||
float *scales_out, // regular scales [numKVHeads*cells]
|
||||
uint8_t *outlier_packed, // outlier packed [outlierPackedBytes*numKVHeads*cells]
|
||||
float *outlier_scales, // outlier scales [numKVHeads*cells]
|
||||
uint8_t *outlier_indices, // outlier channel idx [outlierCount*numKVHeads*cells] (interpreted as uint8 so positions 0..255 fit)
|
||||
int firstCell,
|
||||
const float *boundaries, // regular boundaries [(1<<bits)-1]
|
||||
const float *outlier_boundaries, // outlier boundaries [(1<<outlierBits)-1]
|
||||
int headDim,
|
||||
int numKVHeads,
|
||||
int bits,
|
||||
int numBoundaries,
|
||||
int outlierBits,
|
||||
int outlierCount,
|
||||
int numOutlierBoundaries,
|
||||
int kIsF32
|
||||
) {
|
||||
int batch = blockIdx.x;
|
||||
int head = blockIdx.y;
|
||||
int cell = firstCell + batch;
|
||||
|
||||
// Shared memory layout (laid out contiguously — launch passes total size):
|
||||
// s_k[headDim] - f32 K values
|
||||
// s_rot[headDim] - f32 rotated values
|
||||
// s_reduce[blockDim.x] - reduction scratch (float)
|
||||
// s_idx[headDim] - regular quantized indices (uint8)
|
||||
// s_is_outlier[headDim] - 1 = outlier, 0 = regular (uint8)
|
||||
// s_reg_pos[headDim] - regular position index map (int, up to headDim entries)
|
||||
// s_outl_pos[outlierCount] - outlier channel positions (int)
|
||||
// s_outl_val[outlierCount] - outlier rotated values (float)
|
||||
// s_outl_idx[outlierCount] - outlier quantized indices (uint8)
|
||||
|
||||
extern __shared__ char s_mem[];
|
||||
float *s_k = (float *)s_mem;
|
||||
float *s_rot = s_k + headDim;
|
||||
float *s_reduce = s_rot + headDim;
|
||||
uint8_t *s_idx = (uint8_t *)(s_reduce + blockDim.x);
|
||||
uint8_t *s_is_outlier = s_idx + headDim;
|
||||
int *s_reg_pos = (int *)(s_is_outlier + headDim);
|
||||
int *s_outl_pos = s_reg_pos + headDim;
|
||||
float *s_outl_val = (float *)(s_outl_pos + outlierCount);
|
||||
uint8_t *s_outl_idx = (uint8_t *)(s_outl_val + outlierCount);
|
||||
|
||||
// Step 1: Load K into s_k as f32.
|
||||
int base_k = batch * numKVHeads * headDim + head * headDim;
|
||||
for (int d = threadIdx.x; d < headDim; d += blockDim.x) {
|
||||
if (kIsF32) {
|
||||
s_k[d] = ((const float *)k)[base_k + d];
|
||||
} else {
|
||||
s_k[d] = __half2float(__ushort_as_half(((const uint16_t *)k)[base_k + d]));
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Step 2: Rotate: s_rot[i] = sum_j rotation[i*headDim+j] * s_k[j] = (R^T @ k)[i].
|
||||
for (int i = threadIdx.x; i < headDim; i += blockDim.x) {
|
||||
float sum = 0.0f;
|
||||
for (int j = 0; j < headDim; j++) {
|
||||
sum += rotation[i * headDim + j] * s_k[j];
|
||||
}
|
||||
s_rot[i] = sum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Step 3: Clear outlier mask.
|
||||
for (int i = threadIdx.x; i < headDim; i += blockDim.x) {
|
||||
s_is_outlier[i] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Step 4: Top-K outlier selection (serial on thread 0). O(K * headDim)
|
||||
// comparisons. At outlierCount=32, headDim=128 that's ~4k ops on one
|
||||
// thread — negligible next to the rotation matmul (16k ops) that ran
|
||||
// in parallel above. Also builds s_reg_pos as the ordered list of
|
||||
// regular (non-outlier) channel positions, and s_outl_pos for outliers.
|
||||
if (threadIdx.x == 0) {
|
||||
for (int r = 0; r < outlierCount; r++) {
|
||||
float best_val = -1.0f;
|
||||
int best_idx = 0;
|
||||
for (int i = 0; i < headDim; i++) {
|
||||
if (s_is_outlier[i]) continue;
|
||||
float a = fabsf(s_rot[i]);
|
||||
if (a > best_val) {
|
||||
best_val = a;
|
||||
best_idx = i;
|
||||
}
|
||||
}
|
||||
s_is_outlier[best_idx] = 1;
|
||||
s_outl_pos[r] = best_idx;
|
||||
s_outl_val[r] = s_rot[best_idx];
|
||||
}
|
||||
// Build regular position map after outlier selection is complete.
|
||||
int pos = 0;
|
||||
for (int i = 0; i < headDim; i++) {
|
||||
if (!s_is_outlier[i]) {
|
||||
s_reg_pos[pos++] = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int regularCount = headDim - outlierCount;
|
||||
|
||||
// Step 5: Per-sub-block RMS scales via parallel reduction.
|
||||
// Each thread accumulates its dims' squared contributions into two
|
||||
// locals; we reduce regular and outlier sums back to back using
|
||||
// s_reduce as scratch.
|
||||
float local_sq_reg = 0.0f;
|
||||
float local_sq_out = 0.0f;
|
||||
for (int i = threadIdx.x; i < headDim; i += blockDim.x) {
|
||||
float v = s_rot[i];
|
||||
float sq = v * v;
|
||||
if (s_is_outlier[i]) {
|
||||
local_sq_out += sq;
|
||||
} else {
|
||||
local_sq_reg += sq;
|
||||
}
|
||||
}
|
||||
s_reduce[threadIdx.x] = local_sq_reg;
|
||||
__syncthreads();
|
||||
for (int stride = blockDim.x >> 1; stride > 0; stride >>= 1) {
|
||||
if (threadIdx.x < stride) s_reduce[threadIdx.x] += s_reduce[threadIdx.x + stride];
|
||||
__syncthreads();
|
||||
}
|
||||
float regScale = 0.0f;
|
||||
if (threadIdx.x == 0) {
|
||||
float sum_sq = s_reduce[0];
|
||||
if (sum_sq > 1e-12f && regularCount > 0) {
|
||||
regScale = sqrtf(sum_sq / (float)regularCount);
|
||||
}
|
||||
scales_out[cell * numKVHeads + head] = regScale;
|
||||
s_reduce[0] = regScale;
|
||||
}
|
||||
__syncthreads();
|
||||
regScale = s_reduce[0];
|
||||
__syncthreads();
|
||||
|
||||
s_reduce[threadIdx.x] = local_sq_out;
|
||||
__syncthreads();
|
||||
for (int stride = blockDim.x >> 1; stride > 0; stride >>= 1) {
|
||||
if (threadIdx.x < stride) s_reduce[threadIdx.x] += s_reduce[threadIdx.x + stride];
|
||||
__syncthreads();
|
||||
}
|
||||
float outScale = 0.0f;
|
||||
if (threadIdx.x == 0) {
|
||||
float sum_sq = s_reduce[0];
|
||||
if (sum_sq > 1e-12f && outlierCount > 0) {
|
||||
outScale = sqrtf(sum_sq / (float)outlierCount);
|
||||
}
|
||||
outlier_scales[cell * numKVHeads + head] = outScale;
|
||||
s_reduce[0] = outScale;
|
||||
}
|
||||
__syncthreads();
|
||||
outScale = s_reduce[0];
|
||||
|
||||
// Step 6: Quantize regular channels. s_idx[r] stores the code at the
|
||||
// CONTIGUOUS regular position r, not the original channel index, so the
|
||||
// dequant kernel can read them directly into the packed bit stream.
|
||||
for (int r = threadIdx.x; r < regularCount; r += blockDim.x) {
|
||||
int orig = s_reg_pos[r];
|
||||
float v = (regScale > 0.0f) ? (s_rot[orig] / regScale) : 0.0f;
|
||||
int idx = 0;
|
||||
for (int b = 0; b < numBoundaries; b++) {
|
||||
if (v >= boundaries[b]) idx++;
|
||||
}
|
||||
s_idx[r] = (uint8_t)idx;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Step 7: Pack regular bits into packed_out.
|
||||
// Per-head stride is padded up to a 4-byte multiple so atomicOr on
|
||||
// unsigned-int words stays aligned for every head, regardless of
|
||||
// how many regular channels (bit count) are in play. The Go-side
|
||||
// allocator uses the same padded value (regularPackedBytes()).
|
||||
const int regular_packed_bytes_raw = (regularCount * bits + 7) / 8;
|
||||
const int regular_packed_bytes = (regular_packed_bytes_raw + 3) & ~3;
|
||||
uint8_t *reg_out = packed_out + (cell * numKVHeads + head) * regular_packed_bytes;
|
||||
for (int p = threadIdx.x; p < regular_packed_bytes; p += blockDim.x) {
|
||||
reg_out[p] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
{
|
||||
uint8_t bitmask = (uint8_t)((1 << bits) - 1);
|
||||
for (int r = threadIdx.x; r < regularCount; r += blockDim.x) {
|
||||
int bit_offset = r * bits;
|
||||
int byte_idx = bit_offset >> 3;
|
||||
int shift = bit_offset & 7;
|
||||
uint8_t v = s_idx[r] & bitmask;
|
||||
atomicOr((unsigned int *)(reg_out + (byte_idx & ~3)),
|
||||
(unsigned int)(v << shift) << ((byte_idx & 3) * 8));
|
||||
if (shift + bits > 8) {
|
||||
int byte_idx2 = byte_idx + 1;
|
||||
atomicOr((unsigned int *)(reg_out + (byte_idx2 & ~3)),
|
||||
(unsigned int)(v >> (8 - shift)) << ((byte_idx2 & 3) * 8));
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Step 8: Quantize outlier channels with their own codebook.
|
||||
for (int r = threadIdx.x; r < outlierCount; r += blockDim.x) {
|
||||
float v = (outScale > 0.0f) ? (s_outl_val[r] / outScale) : 0.0f;
|
||||
int idx = 0;
|
||||
for (int b = 0; b < numOutlierBoundaries; b++) {
|
||||
if (v >= outlier_boundaries[b]) idx++;
|
||||
}
|
||||
s_outl_idx[r] = (uint8_t)idx;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Step 9: Pack outlier bits. Same 4-byte alignment as regular packing.
|
||||
const int outlier_packed_bytes_raw = (outlierCount * outlierBits + 7) / 8;
|
||||
const int outlier_packed_bytes = (outlier_packed_bytes_raw + 3) & ~3;
|
||||
uint8_t *out_out = outlier_packed + (cell * numKVHeads + head) * outlier_packed_bytes;
|
||||
for (int p = threadIdx.x; p < outlier_packed_bytes; p += blockDim.x) {
|
||||
out_out[p] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
{
|
||||
uint8_t obmask = (uint8_t)((1 << outlierBits) - 1);
|
||||
for (int r = threadIdx.x; r < outlierCount; r += blockDim.x) {
|
||||
int bit_offset = r * outlierBits;
|
||||
int byte_idx = bit_offset >> 3;
|
||||
int shift = bit_offset & 7;
|
||||
uint8_t v = s_outl_idx[r] & obmask;
|
||||
atomicOr((unsigned int *)(out_out + (byte_idx & ~3)),
|
||||
(unsigned int)(v << shift) << ((byte_idx & 3) * 8));
|
||||
if (shift + outlierBits > 8) {
|
||||
int byte_idx2 = byte_idx + 1;
|
||||
atomicOr((unsigned int *)(out_out + (byte_idx2 & ~3)),
|
||||
(unsigned int)(v >> (8 - shift)) << ((byte_idx2 & 3) * 8));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 10: Write outlier channel indices. uint8_t covers 0..255 safely.
|
||||
uint8_t *idx_out = outlier_indices + (cell * numKVHeads + head) * outlierCount;
|
||||
for (int r = threadIdx.x; r < outlierCount; r += blockDim.x) {
|
||||
idx_out[r] = (uint8_t)s_outl_pos[r];
|
||||
}
|
||||
}
|
||||
|
||||
// ── ggml dispatch ─────────────────────────────────────────────────────────────
|
||||
|
||||
// Extern declaration for the V encode kernel (defined in tq-encode-v.cu).
|
||||
extern __global__ void tq_encode_v_kernel(
|
||||
const void *v, const float *rotation, uint8_t *packed_out, float *scales_out,
|
||||
int firstCell, const float *boundaries,
|
||||
int headDim, int numKVHeads, int bits, int numBoundaries, int vIsF32);
|
||||
|
||||
void ggml_cuda_tq_encode(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
|
||||
const struct ggml_tensor * k = dst->src[0];
|
||||
const struct ggml_tensor * rotation = dst->src[1];
|
||||
// src[2] unused (was cell_idx; now firstCell is in op_params[1])
|
||||
const struct ggml_tensor * scales = dst->src[3];
|
||||
const struct ggml_tensor * boundaries = dst->src[4];
|
||||
|
||||
const int headDim = (int)k->ne[0];
|
||||
const int numKVHeads = (int)k->ne[1];
|
||||
const int batchSize = (int)k->ne[2];
|
||||
const int bits = (int)((const int32_t *)dst->op_params)[0];
|
||||
const int firstCell = (int)((const int32_t *)dst->op_params)[1];
|
||||
const int outlierBits = (int)((const int32_t *)dst->op_params)[2];
|
||||
const int outlierCount = (int)((const int32_t *)dst->op_params)[3];
|
||||
const int numBoundaries = (1 << bits) - 1;
|
||||
const int kIsF32 = (k->type == GGML_TYPE_F32) ? 1 : 0;
|
||||
|
||||
dim3 grid(batchSize, numKVHeads);
|
||||
int block_size = (headDim < TQ_ENCODE_BLOCK_SIZE) ? headDim : TQ_ENCODE_BLOCK_SIZE;
|
||||
int bs = 1;
|
||||
while (bs < block_size) bs <<= 1;
|
||||
block_size = bs;
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
if (outlierCount > 0 && outlierBits > 0 && outlierCount < headDim) {
|
||||
const struct ggml_tensor * outlier_packed = dst->src[5];
|
||||
const struct ggml_tensor * outlier_scales = dst->src[6];
|
||||
const struct ggml_tensor * outlier_indices = dst->src[7];
|
||||
const struct ggml_tensor * outlier_boundaries = dst->src[8];
|
||||
const int numOutlierBoundaries = (1 << outlierBits) - 1;
|
||||
|
||||
// Shared memory layout for outlier kernel:
|
||||
// s_k[headDim] f32
|
||||
// s_rot[headDim] f32
|
||||
// s_reduce[block_size] f32
|
||||
// s_idx[headDim] u8
|
||||
// s_is_outlier[headDim] u8
|
||||
// s_reg_pos[headDim] i32 (over-sized to headDim; only regularCount used)
|
||||
// s_outl_pos[outlierCount] i32
|
||||
// s_outl_val[outlierCount] f32
|
||||
// s_outl_idx[outlierCount] u8
|
||||
size_t smem = (size_t)headDim * 2 * sizeof(float) // s_k + s_rot
|
||||
+ (size_t)block_size * sizeof(float) // s_reduce
|
||||
+ (size_t)headDim * 2 * sizeof(uint8_t) // s_idx + s_is_outlier
|
||||
+ (size_t)headDim * sizeof(int) // s_reg_pos
|
||||
+ (size_t)outlierCount * (sizeof(int) + sizeof(float) + sizeof(uint8_t));
|
||||
|
||||
tq_encode_kernel_outlier<<<grid, block_size, smem, stream>>>(
|
||||
k->data,
|
||||
(const float *)rotation->data,
|
||||
(uint8_t *)dst->data,
|
||||
(float *)scales->data,
|
||||
(uint8_t *)outlier_packed->data,
|
||||
(float *)outlier_scales->data,
|
||||
(uint8_t *)outlier_indices->data,
|
||||
firstCell,
|
||||
(const float *)boundaries->data,
|
||||
(const float *)outlier_boundaries->data,
|
||||
headDim, numKVHeads, bits, numBoundaries,
|
||||
outlierBits, outlierCount, numOutlierBoundaries, kIsF32
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
size_t smem = (size_t)headDim * 2 * sizeof(float)
|
||||
+ (size_t)block_size * sizeof(float)
|
||||
+ (size_t)headDim * sizeof(uint8_t);
|
||||
|
||||
tq_encode_kernel<<<grid, block_size, smem, stream>>>(
|
||||
k->data,
|
||||
(const float *)rotation->data,
|
||||
(uint8_t *)dst->data,
|
||||
(float *)scales->data,
|
||||
firstCell,
|
||||
(const float *)boundaries->data,
|
||||
headDim, numKVHeads, bits, numBoundaries, kIsF32
|
||||
);
|
||||
}
|
||||
|
||||
// Combined K+V encode: two back-to-back kernel launches in a single GGML op.
|
||||
void ggml_cuda_tq_encode_kv(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
|
||||
// src layout: [0]=K, [1]=rotation, [2]=V, [3]=K_scales, [4]=K_bounds,
|
||||
// [5]=V_packed, [6]=V_scales, [7]=V_bounds
|
||||
const struct ggml_tensor * k = dst->src[0];
|
||||
const struct ggml_tensor * rotation = dst->src[1];
|
||||
const struct ggml_tensor * v = dst->src[2];
|
||||
const struct ggml_tensor * k_scales = dst->src[3];
|
||||
const struct ggml_tensor * k_bounds = dst->src[4];
|
||||
const struct ggml_tensor * v_packed = dst->src[5];
|
||||
const struct ggml_tensor * v_scales = dst->src[6];
|
||||
const struct ggml_tensor * v_bounds = dst->src[7];
|
||||
|
||||
int32_t k_bits, v_bits, firstCell;
|
||||
memcpy(&k_bits, (const int32_t *)dst->op_params + 0, sizeof(int32_t));
|
||||
memcpy(&v_bits, (const int32_t *)dst->op_params + 1, sizeof(int32_t));
|
||||
memcpy(&firstCell, (const int32_t *)dst->op_params + 2, sizeof(int32_t));
|
||||
|
||||
const int headDim = (int)k->ne[0];
|
||||
const int numKVHeads = (int)k->ne[1];
|
||||
const int batchSize = (int)k->ne[2];
|
||||
const int kIsF32 = (k->type == GGML_TYPE_F32) ? 1 : 0;
|
||||
const int vIsF32 = (v->type == GGML_TYPE_F32) ? 1 : 0;
|
||||
|
||||
dim3 grid(batchSize, numKVHeads);
|
||||
int block_size = (headDim < TQ_ENCODE_BLOCK_SIZE) ? headDim : TQ_ENCODE_BLOCK_SIZE;
|
||||
int bs = 1;
|
||||
while (bs < block_size) bs <<= 1;
|
||||
block_size = bs;
|
||||
|
||||
size_t smem = (size_t)headDim * 2 * sizeof(float)
|
||||
+ (size_t)block_size * sizeof(float)
|
||||
+ (size_t)headDim * sizeof(uint8_t);
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
// K encode
|
||||
tq_encode_kernel<<<grid, block_size, smem, stream>>>(
|
||||
k->data,
|
||||
(const float *)rotation->data,
|
||||
(uint8_t *)dst->data,
|
||||
(float *)k_scales->data,
|
||||
firstCell,
|
||||
(const float *)k_bounds->data,
|
||||
headDim, numKVHeads, k_bits, (1 << k_bits) - 1, kIsF32
|
||||
);
|
||||
|
||||
// V encode
|
||||
const float * rotation_ptr = rotation ? (const float *)rotation->data : nullptr;
|
||||
tq_encode_v_kernel<<<grid, block_size, smem, stream>>>(
|
||||
v->data,
|
||||
rotation_ptr,
|
||||
(uint8_t *)v_packed->data,
|
||||
(float *)v_scales->data,
|
||||
firstCell,
|
||||
(const float *)v_bounds->data,
|
||||
headDim, numKVHeads, v_bits, (1 << v_bits) - 1, vIsF32
|
||||
);
|
||||
}
|
||||
6
ml/backend/ggml/ggml/src/ggml-cuda/tq-encode.cuh
vendored
Normal file
6
ml/backend/ggml/ggml/src/ggml-cuda/tq-encode.cuh
vendored
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
#pragma once
|
||||
|
||||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_tq_encode(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst);
|
||||
void ggml_cuda_tq_encode_kv(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst);
|
||||
552
ml/backend/ggml/ggml/src/ggml-cuda/tq-fattn-vec.cuh
vendored
Normal file
552
ml/backend/ggml/ggml/src/ggml-cuda/tq-fattn-vec.cuh
vendored
Normal file
|
|
@ -0,0 +1,552 @@
|
|||
#pragma once
|
||||
|
||||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
|
||||
// TurboQuant inline decode: extract a N-bit Lloyd-Max index from a packed byte row
|
||||
// and return codebook[idx] * rms_scale. Handles cross-byte boundaries for bits=2,3.
|
||||
static __device__ __forceinline__ float tq_decode_elem(
|
||||
const uint8_t * packed_row, const float * codebook, float rms_scale, int elem, int bits)
|
||||
{
|
||||
const int bit_pos = elem * bits;
|
||||
const int byte_idx = bit_pos >> 3;
|
||||
const int shift = bit_pos & 7;
|
||||
const int mask_val = (1 << bits) - 1;
|
||||
int idx = ((int)(packed_row[byte_idx] >> shift)) & mask_val;
|
||||
if (shift + bits > 8) {
|
||||
idx |= ((int)(packed_row[byte_idx + 1] << (8 - shift))) & mask_val;
|
||||
}
|
||||
return codebook[idx] * rms_scale;
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// Hardcoded warp-shuffle TQ decode: combine packed bytes into a single
|
||||
// integer, shift to align, then extract N indices with compile-time
|
||||
// offsets. Eliminates per-element multiply, byte_idx computation, and
|
||||
// boundary-crossing branches.
|
||||
//
|
||||
// cb_lane: codebook[threadIdx.x % (1<<bits)], loaded once per kernel.
|
||||
// shfl_w: shuffle width (power of 2, >= 1<<bits).
|
||||
// --------------------------------------------------------------------
|
||||
|
||||
#if __CUDA_ARCH__ >= 600 || !defined(__CUDA_ARCH__)
|
||||
// 3-bit, 4 elements. start_elem is a multiple of 4.
|
||||
// 12 bits needed from 2 bytes; bit offset within the first byte is 0 or 4.
|
||||
static __device__ __forceinline__ void tq_decode_4_3bit(
|
||||
const uint8_t * __restrict__ packed_row,
|
||||
float cb_lane, float rms, int start_elem, int shfl_w,
|
||||
float * __restrict__ out)
|
||||
{
|
||||
const int byte_off = (start_elem * 3) >> 3;
|
||||
const int bit_off = (start_elem * 3) & 7; // 0 or 4
|
||||
const uint32_t w = (uint32_t)packed_row[byte_off] | ((uint32_t)packed_row[byte_off + 1] << 8);
|
||||
const uint32_t s = w >> bit_off; // align element 0 to bit 0
|
||||
out[0] = __shfl_sync(0xFFFFFFFF, cb_lane, (s >> 0) & 7, shfl_w) * rms;
|
||||
out[1] = __shfl_sync(0xFFFFFFFF, cb_lane, (s >> 3) & 7, shfl_w) * rms;
|
||||
out[2] = __shfl_sync(0xFFFFFFFF, cb_lane, (s >> 6) & 7, shfl_w) * rms;
|
||||
out[3] = __shfl_sync(0xFFFFFFFF, cb_lane, (s >> 9) & 7, shfl_w) * rms;
|
||||
}
|
||||
|
||||
// 3-bit, 8 elements. start_elem is a multiple of 8 (Volta+ path).
|
||||
// 24 bits from 3 bytes; bit offset is always 0.
|
||||
static __device__ __forceinline__ void tq_decode_8_3bit(
|
||||
const uint8_t * __restrict__ packed_row,
|
||||
float cb_lane, float rms, int start_elem, int shfl_w,
|
||||
float * __restrict__ out)
|
||||
{
|
||||
const int byte_off = (start_elem * 3) >> 3;
|
||||
const uint32_t w = (uint32_t)packed_row[byte_off]
|
||||
| ((uint32_t)packed_row[byte_off + 1] << 8)
|
||||
| ((uint32_t)packed_row[byte_off + 2] << 16);
|
||||
out[0] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 0) & 7, shfl_w) * rms;
|
||||
out[1] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 3) & 7, shfl_w) * rms;
|
||||
out[2] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 6) & 7, shfl_w) * rms;
|
||||
out[3] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 9) & 7, shfl_w) * rms;
|
||||
out[4] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 12) & 7, shfl_w) * rms;
|
||||
out[5] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 15) & 7, shfl_w) * rms;
|
||||
out[6] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 18) & 7, shfl_w) * rms;
|
||||
out[7] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 21) & 7, shfl_w) * rms;
|
||||
}
|
||||
|
||||
// 2-bit, 4 elements. start_elem is a multiple of 4.
|
||||
// 8 bits = 1 byte, always byte-aligned.
|
||||
static __device__ __forceinline__ void tq_decode_4_2bit(
|
||||
const uint8_t * __restrict__ packed_row,
|
||||
float cb_lane, float rms, int start_elem, int shfl_w,
|
||||
float * __restrict__ out)
|
||||
{
|
||||
const uint32_t b = packed_row[start_elem >> 2];
|
||||
out[0] = __shfl_sync(0xFFFFFFFF, cb_lane, (b >> 0) & 3, shfl_w) * rms;
|
||||
out[1] = __shfl_sync(0xFFFFFFFF, cb_lane, (b >> 2) & 3, shfl_w) * rms;
|
||||
out[2] = __shfl_sync(0xFFFFFFFF, cb_lane, (b >> 4) & 3, shfl_w) * rms;
|
||||
out[3] = __shfl_sync(0xFFFFFFFF, cb_lane, (b >> 6) & 3, shfl_w) * rms;
|
||||
}
|
||||
|
||||
// 2-bit, 8 elements. start_elem is a multiple of 8.
|
||||
// 16 bits = 2 bytes.
|
||||
static __device__ __forceinline__ void tq_decode_8_2bit(
|
||||
const uint8_t * __restrict__ packed_row,
|
||||
float cb_lane, float rms, int start_elem, int shfl_w,
|
||||
float * __restrict__ out)
|
||||
{
|
||||
const int byte_off = start_elem >> 2;
|
||||
const uint32_t w = (uint32_t)packed_row[byte_off] | ((uint32_t)packed_row[byte_off + 1] << 8);
|
||||
out[0] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 0) & 3, shfl_w) * rms;
|
||||
out[1] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 2) & 3, shfl_w) * rms;
|
||||
out[2] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 4) & 3, shfl_w) * rms;
|
||||
out[3] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 6) & 3, shfl_w) * rms;
|
||||
out[4] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 8) & 3, shfl_w) * rms;
|
||||
out[5] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 10) & 3, shfl_w) * rms;
|
||||
out[6] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 12) & 3, shfl_w) * rms;
|
||||
out[7] = __shfl_sync(0xFFFFFFFF, cb_lane, (w >> 14) & 3, shfl_w) * rms;
|
||||
}
|
||||
#else
|
||||
// Stubs for sm < 600 (no __shfl_sync). Never executed — the launcher
|
||||
// asserts compute capability >= 6.0. These only exist so that the
|
||||
// __device__ dispatch template below can compile for Maxwell targets.
|
||||
static __device__ __forceinline__ void tq_decode_4_3bit(
|
||||
const uint8_t * __restrict__, float, float, int, int,
|
||||
float * __restrict__ out)
|
||||
{ out[0] = out[1] = out[2] = out[3] = 0.0f; }
|
||||
static __device__ __forceinline__ void tq_decode_8_3bit(
|
||||
const uint8_t * __restrict__, float, float, int, int,
|
||||
float * __restrict__ out)
|
||||
{ out[0] = out[1] = out[2] = out[3] = out[4] = out[5] = out[6] = out[7] = 0.0f; }
|
||||
static __device__ __forceinline__ void tq_decode_4_2bit(
|
||||
const uint8_t * __restrict__, float, float, int, int,
|
||||
float * __restrict__ out)
|
||||
{ out[0] = out[1] = out[2] = out[3] = 0.0f; }
|
||||
static __device__ __forceinline__ void tq_decode_8_2bit(
|
||||
const uint8_t * __restrict__, float, float, int, int,
|
||||
float * __restrict__ out)
|
||||
{ out[0] = out[1] = out[2] = out[3] = out[4] = out[5] = out[6] = out[7] = 0.0f; }
|
||||
#endif
|
||||
|
||||
// Dispatch: decode N elements (N=4 on Pascal, N=8 on Volta+).
|
||||
template<int N>
|
||||
static __device__ __forceinline__ void tq_decode_N_shfl(
|
||||
const uint8_t * __restrict__ packed_row,
|
||||
float cb_lane, float rms,
|
||||
int start_elem, int bits, int shfl_w,
|
||||
float * __restrict__ out)
|
||||
{
|
||||
if constexpr (N == 4) {
|
||||
if (bits == 3) { tq_decode_4_3bit(packed_row, cb_lane, rms, start_elem, shfl_w, out); }
|
||||
else { tq_decode_4_2bit(packed_row, cb_lane, rms, start_elem, shfl_w, out); }
|
||||
} else {
|
||||
if (bits == 3) { tq_decode_8_3bit(packed_row, cb_lane, rms, start_elem, shfl_w, out); }
|
||||
else { tq_decode_8_2bit(packed_row, cb_lane, rms, start_elem, shfl_w, out); }
|
||||
}
|
||||
}
|
||||
|
||||
// Compute Q·K dot product where K is TQ-compressed.
|
||||
// Uses warp-shuffle for codebook lookup (1-cycle register-to-register).
|
||||
template<int D, int nthreads, int cpy_ne>
|
||||
static __device__ __forceinline__ float tq_vec_dot_KQ(
|
||||
const uint8_t * packed_row,
|
||||
float cb_lane,
|
||||
float rms_scale,
|
||||
const float2 * Q_f,
|
||||
int bits)
|
||||
{
|
||||
float sum = 0.0f;
|
||||
const int tid_kq = threadIdx.x % nthreads;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
|
||||
// Decode 2*cpy_ne K elements at once.
|
||||
float k_dec[2*cpy_ne];
|
||||
tq_decode_N_shfl<2*cpy_ne>(packed_row, cb_lane, rms_scale,
|
||||
2*(k_KQ_0 + tid_kq*cpy_ne), bits, nthreads, k_dec);
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
|
||||
const float2 q_pair = Q_f[k_KQ_0/nthreads + k_KQ_1];
|
||||
sum += q_pair.x * k_dec[2*k_KQ_1] + q_pair.y * k_dec[2*k_KQ_1 + 1];
|
||||
}
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
// TurboQuant fused flash-attention kernel.
|
||||
//
|
||||
// Q: [D, nTokensQ, nHeadsQ, nSeq] f32 (after SDPA Permute(0,2,1,3))
|
||||
// K_packed: [packedBytes*nKVHeads, capacity] i8 (encode result; base = full buffer)
|
||||
// V: [D, nCells, nKVHeads, nSeq] f16 when !V_PACKED (after FA-branch Permute(0,2,1,3))
|
||||
// [v_packedBytes*nKVHeads, capacity] i8 when V_PACKED (packed buffer, no permute)
|
||||
// mask: [nCells, nTokensQ] f16 or NULL
|
||||
// scales: [nKVHeads, capacity] f32 (K scales; base = full buffer)
|
||||
// codebook: [1<<bits] f32 (K codebook)
|
||||
//
|
||||
// V_PACKED == false: existing K-only fused path (V is f16, src[6] == NULL)
|
||||
// V_PACKED == true: new K+V fused path (V is packed i8, src[6] = v_scales, src[7] = v_codebook)
|
||||
//
|
||||
// Phase 1: D == 128, bits runtime param, gridDim.y == 1 (no multi-block).
|
||||
template<int D, int ncols, bool use_logit_softcap, bool V_PACKED>
|
||||
__launch_bounds__(128, 2)
|
||||
static __global__ void tq_flash_attn_ext_vec(
|
||||
const char * __restrict__ Q,
|
||||
const uint8_t * __restrict__ K_packed,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
float * __restrict__ dst,
|
||||
const float * __restrict__ scales,
|
||||
const float * __restrict__ codebook,
|
||||
float scale,
|
||||
float logit_softcap,
|
||||
int bits,
|
||||
int firstCell,
|
||||
int nCells,
|
||||
int nKVHeads,
|
||||
int packedBytes,
|
||||
// Q geometry
|
||||
int32_t ne00,
|
||||
uint3 ne01, // init_fastdiv_values(nTokensQ); .z == nTokensQ
|
||||
int32_t ne02, // nHeadsQ
|
||||
int32_t ne03, // nSeq
|
||||
int32_t nb01, // Q stride: bytes between consecutive tokens
|
||||
int32_t nb02, // Q stride: bytes between consecutive heads
|
||||
int64_t nb03, // Q stride: bytes between consecutive sequences
|
||||
// V geometry (after permute: [D, nCells, nKVHeads]) — only used when !V_PACKED
|
||||
int32_t nb21, // bytes between V cells (= D*nKVHeads*sizeof(half))
|
||||
int32_t nb22, // bytes between V heads (= D*sizeof(half))
|
||||
int64_t nb23, // bytes between V seqs
|
||||
// mask geometry
|
||||
int32_t ne31, // nCells (mask row width)
|
||||
int32_t nb31, // mask stride: bytes between token rows
|
||||
// V packed params (only used when V_PACKED == true)
|
||||
const float * __restrict__ v_scales, // [nKVHeads, capacity] f32
|
||||
const float * __restrict__ v_codebook, // [1<<v_bits] f32
|
||||
int v_bits,
|
||||
int v_packedBytes
|
||||
)
|
||||
{
|
||||
#ifdef FLASH_ATTN_AVAILABLE
|
||||
// Skip logit_softcap variants for unsupported D values (mirrors original kernel guard).
|
||||
if (use_logit_softcap && D != 128 && D != 256) {
|
||||
GGML_UNUSED_VARS(Q, K_packed, V, mask, dst, scales, codebook,
|
||||
scale, logit_softcap, bits, firstCell, nCells, nKVHeads, packedBytes,
|
||||
ne00, ne01, ne02, ne03, nb01, nb02, nb03,
|
||||
nb21, nb22, nb23, ne31, nb31,
|
||||
v_scales, v_codebook, v_bits, v_packedBytes);
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); // 16 on Volta+
|
||||
constexpr int cpy_ne = cpy_nb / 4; // 4
|
||||
|
||||
constexpr int nthreads = 128;
|
||||
constexpr int nthreads_KQ = nthreads / cpy_nb; // 8
|
||||
constexpr int nthreads_V = nthreads / cpy_nb; // 8
|
||||
|
||||
static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_KQ");
|
||||
static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V");
|
||||
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 64");
|
||||
|
||||
constexpr int V_rows_per_thread = 2 * cpy_ne; // 8
|
||||
constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V; // 4
|
||||
|
||||
// dequantize_V is only needed for the f16 V path
|
||||
[[maybe_unused]] constexpr dequantize_V_t dequantize_V =
|
||||
V_PACKED ? (dequantize_V_t)nullptr
|
||||
: get_dequantize_V<GGML_TYPE_F16, float, V_rows_per_thread>();
|
||||
|
||||
const int ic0 = blockIdx.x * ncols;
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z % ne02;
|
||||
const int gqa_ratio = ne02 / nKVHeads;
|
||||
const int head_kv = head / gqa_ratio;
|
||||
|
||||
// Advance base pointers.
|
||||
Q += (int64_t)nb03*sequence + nb02*head + nb01*ic0;
|
||||
K_packed += (int64_t)firstCell * nKVHeads * packedBytes + (int64_t)head_kv * packedBytes;
|
||||
scales += (int64_t)firstCell * nKVHeads + head_kv;
|
||||
|
||||
// V pointer setup: f16 path uses stride-based addressing; packed path uses cell-index addressing.
|
||||
const uint8_t * V_packed_base = nullptr;
|
||||
const float * v_scales_base = nullptr;
|
||||
if constexpr (V_PACKED) {
|
||||
V_packed_base = (const uint8_t *)V
|
||||
+ (int64_t)firstCell * nKVHeads * v_packedBytes
|
||||
+ (int64_t)head_kv * v_packedBytes;
|
||||
v_scales_base = v_scales
|
||||
+ (int64_t)firstCell * nKVHeads + head_kv;
|
||||
} else {
|
||||
V += (int64_t)nb23*sequence + (int64_t)nb22*head_kv;
|
||||
}
|
||||
|
||||
const half * maskh = mask ? (const half *)(mask + (int64_t)nb31*ic0) : nullptr;
|
||||
|
||||
// Load one codebook entry per lane for warp-shuffle lookups.
|
||||
// For 3-bit (8 entries): lanes 0-7 hold codebook[0-7], lanes 8-15 repeat, etc.
|
||||
// __shfl_sync with width=8 handles the wrap-around.
|
||||
const float k_cb_lane = codebook[threadIdx.x & ((1 << bits) - 1)];
|
||||
[[maybe_unused]] const float v_cb_lane = V_PACKED
|
||||
? v_codebook[threadIdx.x & ((1 << v_bits) - 1)]
|
||||
: 0.0f;
|
||||
|
||||
constexpr int nwarps = nthreads / WARP_SIZE;
|
||||
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
||||
__builtin_assume(tid < nthreads);
|
||||
|
||||
constexpr int ne_KQ = ncols * D;
|
||||
constexpr int ne_combine = nwarps * V_cols_per_iter * D;
|
||||
|
||||
float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
|
||||
|
||||
__shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
|
||||
|
||||
float KQ_max[ncols];
|
||||
float KQ_sum[ncols];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
KQ_max[j] = -FLT_MAX/2.0f;
|
||||
KQ_sum[j] = 0.0f;
|
||||
}
|
||||
|
||||
// Load Q into registers (float2 per half-pair, scale applied).
|
||||
float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}};
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
const float2 * Q_j = (const float2 *)(Q + j*nb01);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
|
||||
const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
|
||||
if (ncols == 1 || ic0 + j < (int)ne01.z) {
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ], &Q_j[i]);
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]);
|
||||
}
|
||||
}
|
||||
// Apply attention scale.
|
||||
#pragma unroll
|
||||
for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {
|
||||
Q_reg[j][k].x *= scale;
|
||||
Q_reg[j][k].y *= scale;
|
||||
}
|
||||
}
|
||||
|
||||
// Main KV loop — single block (gridDim.y == 1).
|
||||
for (int k_VKQ_0 = 0; k_VKQ_0 < nCells; k_VKQ_0 += nthreads,
|
||||
V += (V_PACKED ? 0 : (int64_t)nthreads * nb21),
|
||||
maskh += (maskh ? nthreads : 0)) {
|
||||
|
||||
float KQ_reg[ncols];
|
||||
float KQ_max_new[ncols];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
KQ_max_new[j] = KQ_max[j];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; ++i_KQ_0) {
|
||||
const int i_KQ = threadIdx.y*WARP_SIZE
|
||||
+ (nthreads_KQ == WARP_SIZE ? 0 : (threadIdx.x & ~(nthreads_KQ-1)))
|
||||
+ i_KQ_0;
|
||||
const int cell_rel = k_VKQ_0 + i_KQ;
|
||||
const bool in_range = (cell_rel < nCells);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
// Always execute the dot product to keep all lanes convergent
|
||||
// for warp shuffles. rms_scale=0 for out-of-range cells.
|
||||
const uint8_t * packed_row = K_packed + (int64_t)cell_rel * nKVHeads * packedBytes;
|
||||
const float rms_scale = in_range ? scales[cell_rel * nKVHeads] : 0.0f;
|
||||
float sum = tq_vec_dot_KQ<D, nthreads_KQ, cpy_ne>(
|
||||
packed_row, k_cb_lane, rms_scale, Q_reg[j], bits);
|
||||
sum = warp_reduce_sum<nthreads_KQ>(sum);
|
||||
|
||||
if (use_logit_softcap) {
|
||||
sum = logit_softcap * tanhf(sum);
|
||||
}
|
||||
|
||||
if (maskh && (ncols == 1 || ic0 + j < (int)ne01.z)) {
|
||||
sum += __half2float(maskh[j*ne31 + i_KQ]);
|
||||
}
|
||||
|
||||
if (!in_range) {
|
||||
sum = -FLT_MAX/2.0f;
|
||||
}
|
||||
|
||||
KQ_max_new[j] = fmaxf(KQ_max_new[j], sum + FATTN_KQ_MAX_OFFSET);
|
||||
|
||||
if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == (uint32_t)i_KQ_0) {
|
||||
KQ_reg[j] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
#pragma unroll
|
||||
for (int offset = nthreads_KQ; offset < WARP_SIZE; offset <<= 1) {
|
||||
KQ_max_new[j] = fmaxf(KQ_max_new[j], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[j], offset, WARP_SIZE));
|
||||
}
|
||||
const float KQ_max_scale = expf(KQ_max[j] - KQ_max_new[j]);
|
||||
KQ_max[j] = KQ_max_new[j];
|
||||
|
||||
KQ_reg[j] = expf(KQ_reg[j] - KQ_max[j]);
|
||||
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
|
||||
KQ[j*nthreads + tid] = KQ_reg[j];
|
||||
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
|
||||
VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
|
||||
VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef GGML_USE_HIP
|
||||
__syncwarp();
|
||||
#endif
|
||||
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) {
|
||||
const int k = threadIdx.y*WARP_SIZE + k0
|
||||
+ (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V);
|
||||
|
||||
float KQ_k[ncols];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
KQ_k[j] = KQ[j*nthreads + k];
|
||||
}
|
||||
|
||||
if constexpr (V_PACKED) {
|
||||
// Decode V from packed buffer inline — warp-shuffle batch decode.
|
||||
// Always execute decodes to keep all lanes convergent for shuffles.
|
||||
// v_rms=0 for out-of-range cells produces zero contributions.
|
||||
const int cell_rel = k_VKQ_0 + k;
|
||||
const uint8_t * v_row = V_packed_base
|
||||
+ (int64_t)cell_rel * nKVHeads * v_packedBytes;
|
||||
const float v_rms = (cell_rel < nCells)
|
||||
? v_scales_base[cell_rel * nKVHeads] : 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
|
||||
const int base_elem = 2*i_VKQ_0
|
||||
+ (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)
|
||||
* V_rows_per_thread;
|
||||
|
||||
// Batch-decode V elements using shuffle codebook.
|
||||
float v_dec[V_rows_per_thread];
|
||||
tq_decode_N_shfl<V_rows_per_thread>(v_row, v_cb_lane, v_rms,
|
||||
base_elem, v_bits, nthreads_V, v_dec);
|
||||
|
||||
#pragma unroll
|
||||
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x += v_dec[2*i_VKQ_1] * KQ_k[j];
|
||||
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y += v_dec[2*i_VKQ_1+1] * KQ_k[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Original f16 V path (unchanged).
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
|
||||
float2 tmp[V_rows_per_thread/2];
|
||||
dequantize_V(V + k*nb21, tmp,
|
||||
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
|
||||
#pragma unroll
|
||||
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x += tmp[i_VKQ_1].x * KQ_k[j];
|
||||
VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y += tmp[i_VKQ_1].y * KQ_k[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // end KV loop
|
||||
|
||||
// --- Reduce across warps and write output ---
|
||||
|
||||
__shared__ float KQ_max_shared[ncols][WARP_SIZE];
|
||||
__shared__ float KQ_sum_shared[ncols][WARP_SIZE];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
if (threadIdx.y == 0) {
|
||||
KQ_max_shared[j][threadIdx.x] = -FLT_MAX/2.0f;
|
||||
KQ_sum_shared[j][threadIdx.x] = 0.0f;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
if (threadIdx.x == 0) {
|
||||
KQ_max_shared[j][threadIdx.y] = KQ_max[j];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
|
||||
if (ncols > 1 && ic0 + j_VKQ >= (int)ne01.z) {
|
||||
break;
|
||||
}
|
||||
|
||||
float kqmax_new = KQ_max_shared[j_VKQ][threadIdx.x];
|
||||
kqmax_new = warp_reduce_max(kqmax_new);
|
||||
const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new);
|
||||
KQ_max[j_VKQ] = kqmax_new;
|
||||
|
||||
float2 * VKQ_tmp = (float2 *)KQ + threadIdx.y*(V_cols_per_iter*D/2)
|
||||
+ (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);
|
||||
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
|
||||
VKQ[j_VKQ][i_VKQ_0/nthreads_V].x *= kqmax_scale;
|
||||
VKQ[j_VKQ][i_VKQ_0/nthreads_V].y *= kqmax_scale;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
|
||||
const int i_VKQ = i_VKQ_0
|
||||
+ (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*(V_rows_per_thread/2);
|
||||
ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ,
|
||||
&VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
|
||||
ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4,
|
||||
&VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);
|
||||
}
|
||||
|
||||
KQ_sum[j_VKQ] *= kqmax_scale;
|
||||
KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);
|
||||
if (threadIdx.x == 0) {
|
||||
KQ_sum_shared[j_VKQ][threadIdx.y] = KQ_sum[j_VKQ];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (nthreads <= D || tid < D) {
|
||||
KQ_sum[j_VKQ] = KQ_sum_shared[j_VKQ][threadIdx.x];
|
||||
KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += nthreads) {
|
||||
float dst_val = 0;
|
||||
#pragma unroll
|
||||
for (int w = 0; w < nwarps; ++w) {
|
||||
#pragma unroll
|
||||
for (int v = 0; v < V_cols_per_iter; ++v) {
|
||||
dst_val += float(KQ[w*V_cols_per_iter*D + v*D + i0 + tid]);
|
||||
}
|
||||
}
|
||||
dst_val /= KQ_sum[j_VKQ];
|
||||
// Output layout: [D, nHeadsQ, nTokensQ, nSeq] — matches ggml_flash_attn_ext layout.
|
||||
dst[(((int64_t)sequence*(int)ne01.z + ic0 + j_VKQ)*ne02 + head)*D + i0 + tid] = dst_val;
|
||||
}
|
||||
}
|
||||
|
||||
if (j_VKQ < ncols-1) {
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(Q, K_packed, V, mask, dst, scales, codebook,
|
||||
scale, logit_softcap, bits, firstCell, nCells, nKVHeads, packedBytes,
|
||||
ne00, ne01, ne02, ne03, nb01, nb02, nb03,
|
||||
nb21, nb22, nb23, ne31, nb31,
|
||||
v_scales, v_codebook, v_bits, v_packedBytes);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
|
||||
139
ml/backend/ggml/ggml/src/ggml-cuda/tq-fattn.cu
vendored
Normal file
139
ml/backend/ggml/ggml/src/ggml-cuda/tq-fattn.cu
vendored
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
#include "tq-fattn.cuh"
|
||||
#include "tq-fattn-vec.cuh"
|
||||
|
||||
// Launch the TQ fused flash-attention kernel for a given (D, ncols, use_logit_softcap, V_PACKED).
|
||||
template<int D, int ncols, bool use_logit_softcap, bool V_PACKED>
|
||||
static void tq_fattn_vec_launch(ggml_backend_cuda_context & ctx, ggml_tensor * dst,
|
||||
float scale, float logit_softcap,
|
||||
int bits, int firstCell, int nCells, int nKVHeads, int packedBytes,
|
||||
int v_bits, int v_packedBytes,
|
||||
const float * v_scales_ptr, const float * v_codebook_ptr)
|
||||
{
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K_p = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
const ggml_tensor * scales = dst->src[4];
|
||||
const ggml_tensor * codebook = dst->src[5];
|
||||
|
||||
GGML_ASSERT(Q->ne[0] == D);
|
||||
GGML_ASSERT(Q->type == GGML_TYPE_F32);
|
||||
if constexpr (!V_PACKED) {
|
||||
GGML_ASSERT(V->type == GGML_TYPE_F16);
|
||||
} else {
|
||||
GGML_ASSERT(V->type == GGML_TYPE_I8);
|
||||
}
|
||||
|
||||
const int nTokensQ = (int)Q->ne[1];
|
||||
const int nHeadsQ = (int)Q->ne[2];
|
||||
const int nSeq = (int)Q->ne[3];
|
||||
|
||||
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
|
||||
|
||||
const uint3 ne01 = init_fastdiv_values((uint64_t)nTokensQ);
|
||||
|
||||
const int ntiles_x = (nTokensQ + ncols - 1) / ncols;
|
||||
dim3 blocks(ntiles_x, 1, nHeadsQ * nSeq);
|
||||
dim3 threads(WARP_SIZE, 4);
|
||||
|
||||
// V strides: only used by !V_PACKED path; pass V strides for the f16 case.
|
||||
// For the V_PACKED case these are passed but ignored by the kernel.
|
||||
tq_flash_attn_ext_vec<D, ncols, use_logit_softcap, V_PACKED><<<blocks, threads, 0, ctx.stream()>>>(
|
||||
(const char *)Q->data,
|
||||
(const uint8_t *)K_p->data,
|
||||
(const char *)V->data,
|
||||
mask ? (const char *)mask->data : nullptr,
|
||||
(float *)dst->data,
|
||||
(const float *)scales->data,
|
||||
(const float *)codebook->data,
|
||||
scale, logit_softcap, bits, firstCell, nCells, nKVHeads, packedBytes,
|
||||
(int32_t)Q->ne[0],
|
||||
ne01,
|
||||
(int32_t)Q->ne[2],
|
||||
(int32_t)Q->ne[3],
|
||||
(int32_t)Q->nb[1],
|
||||
(int32_t)Q->nb[2],
|
||||
(int64_t)Q->nb[3],
|
||||
(int32_t)V->nb[1],
|
||||
(int32_t)V->nb[2],
|
||||
(int64_t)V->nb[3],
|
||||
mask ? (int32_t)mask->ne[0] : 0, // nCells (mask row width)
|
||||
mask ? (int32_t)mask->nb[1] : 0, // bytes between token rows
|
||||
v_scales_ptr, v_codebook_ptr, v_bits, v_packedBytes
|
||||
);
|
||||
}
|
||||
|
||||
void ggml_cuda_tq_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
GGML_ASSERT(ggml_cuda_info().devices[ctx.device].cc >= 600 &&
|
||||
"TurboQuant fused flash attention requires compute capability 6.0+");
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
const ggml_tensor * v_scales_t = dst->src[6]; // NULL for K-only fused
|
||||
const ggml_tensor * v_codebook_t = dst->src[7];
|
||||
|
||||
float scale;
|
||||
float logit_softcap;
|
||||
int32_t bits;
|
||||
int32_t firstCell;
|
||||
int32_t v_bits;
|
||||
memcpy(&scale, (const float *)dst->op_params + 0, sizeof(float));
|
||||
memcpy(&logit_softcap, (const float *)dst->op_params + 1, sizeof(float));
|
||||
memcpy(&bits, (const int32_t *)dst->op_params + 2, sizeof(int32_t));
|
||||
memcpy(&firstCell, (const int32_t *)dst->op_params + 3, sizeof(int32_t));
|
||||
memcpy(&v_bits, (const int32_t *)dst->op_params + 4, sizeof(int32_t));
|
||||
|
||||
const int D = (int)Q->ne[0];
|
||||
const int nTokensQ = (int)Q->ne[1];
|
||||
// nCells: when V_PACKED, V is the raw buffer so ne[1] is capacity not nCells.
|
||||
// Compute nCells from the mask if available, else from V geometry.
|
||||
const bool v_packed = (v_scales_t != nullptr);
|
||||
|
||||
int nCells;
|
||||
if (v_packed) {
|
||||
// V is packed i8 buffer: nCells is not directly in V->ne[].
|
||||
// The mask always carries nCells; require mask when V is packed.
|
||||
GGML_ASSERT(dst->src[3] != nullptr && "TQ K+V fused requires mask to determine nCells");
|
||||
nCells = (int)dst->src[3]->ne[0];
|
||||
} else {
|
||||
nCells = (int)V->ne[1]; // after SDPA permute: [D, nCells, nKVHeads]
|
||||
}
|
||||
|
||||
// packedBytes for K:
|
||||
const int packedBytes = (D * bits + 7) / 8;
|
||||
// packedBytes for V (computed before nKVHeads so we can derive nKVHeads from packed V dims):
|
||||
const int v_packedBytes = v_packed ? ((D * v_bits + 7) / 8) : 0;
|
||||
// nKVHeads: for K-only fused, V is a permuted f16 tensor with ne[2]=nKVHeads.
|
||||
// For K+V fused, V is the raw packed i8 buffer with ne[0]=v_packedBytes*nKVHeads.
|
||||
const int nKVHeads = v_packed ? (int)V->ne[0] / v_packedBytes : (int)V->ne[2];
|
||||
|
||||
const float * v_scales_ptr = v_scales_t ? (const float *)v_scales_t->data : nullptr;
|
||||
const float * v_codebook_ptr = v_codebook_t ? (const float *)v_codebook_t->data : nullptr;
|
||||
|
||||
GGML_ASSERT(D == 128); // Phase 1: head_dim=128 only
|
||||
|
||||
if (logit_softcap != 0.0f) { scale /= logit_softcap; }
|
||||
|
||||
const int ncols = (nTokensQ == 1) ? 1 : 2;
|
||||
|
||||
#define DISPATCH(NCOLS, SOFTCAP) \
|
||||
if (v_packed) { \
|
||||
tq_fattn_vec_launch<128, NCOLS, SOFTCAP, true>(ctx, dst, scale, logit_softcap, \
|
||||
bits, firstCell, nCells, nKVHeads, packedBytes, \
|
||||
v_bits, v_packedBytes, v_scales_ptr, v_codebook_ptr); \
|
||||
} else { \
|
||||
tq_fattn_vec_launch<128, NCOLS, SOFTCAP, false>(ctx, dst, scale, logit_softcap, \
|
||||
bits, firstCell, nCells, nKVHeads, packedBytes, \
|
||||
0, 0, nullptr, nullptr); \
|
||||
}
|
||||
|
||||
if (ncols == 1) {
|
||||
if (logit_softcap == 0.0f) { DISPATCH(1, false); }
|
||||
else { DISPATCH(1, true); }
|
||||
} else {
|
||||
if (logit_softcap == 0.0f) { DISPATCH(2, false); }
|
||||
else { DISPATCH(2, true); }
|
||||
}
|
||||
#undef DISPATCH
|
||||
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
4
ml/backend/ggml/ggml/src/ggml-cuda/tq-fattn.cuh
vendored
Normal file
4
ml/backend/ggml/ggml/src/ggml-cuda/tq-fattn.cuh
vendored
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
#pragma once
|
||||
#include "ggml-cuda/common.cuh"
|
||||
|
||||
void ggml_cuda_tq_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
|
@ -1704,3 +1704,23 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd(ggm
|
|||
|
||||
return res;
|
||||
}
|
||||
|
||||
// TurboQuant pipeline getters — fixed-name kernels, no variants
|
||||
|
||||
static struct ggml_metal_pipeline_with_params tq_get_pipeline(ggml_metal_library_t lib, const char * name) {
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
res = ggml_metal_library_compile_pipeline(lib, name, name, nullptr);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_dequant (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_dequant"); }
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_dequant_outlier(ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_dequant_outlier"); }
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_encode"); }
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_v (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_encode_v"); }
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_outlier(ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_encode_outlier"); }
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16 (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_f16"); }
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_packed"); }
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16_d256 (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_f16_d256"); }
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed_d256(ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_packed_d256"); }
|
||||
|
|
|
|||
|
|
@ -188,6 +188,17 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_att
|
|||
int32_t dv,
|
||||
int32_t nwg);
|
||||
|
||||
// TurboQuant pipeline getters
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_dequant (ggml_metal_library_t lib);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_dequant_outlier(ggml_metal_library_t lib);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode (ggml_metal_library_t lib);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_v (ggml_metal_library_t lib);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_outlier(ggml_metal_library_t lib);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16 (ggml_metal_library_t lib);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed (ggml_metal_library_t lib);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16_d256 (ggml_metal_library_t lib);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed_d256(ggml_metal_library_t lib);
|
||||
|
||||
// MTLResidencySet wrapper
|
||||
|
||||
typedef void * ggml_metal_rset_t;
|
||||
|
|
|
|||
|
|
@ -1181,6 +1181,13 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
|||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
return has_simdgroup_reduction;
|
||||
case GGML_OP_TQ_ENCODE:
|
||||
case GGML_OP_TQ_ENCODE_V:
|
||||
case GGML_OP_TQ_ENCODE_KV:
|
||||
case GGML_OP_TQ_DEQUANT:
|
||||
case GGML_OP_TQ_DEQUANT_KV:
|
||||
case GGML_OP_TQ_FLASH_ATTN_EXT:
|
||||
return has_simdgroup_reduction;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -942,4 +942,82 @@ typedef struct {
|
|||
int64_t np;
|
||||
} ggml_metal_kargs_opt_step_sgd;
|
||||
|
||||
// TurboQuant (TQ) — dynamic KV-cache compression at inference time
|
||||
|
||||
typedef struct {
|
||||
int32_t headDim;
|
||||
int32_t numKVHeads;
|
||||
int32_t bits;
|
||||
int32_t firstCell;
|
||||
int32_t packed_bytes; // (headDim * bits + 7) / 8
|
||||
int32_t codebook_len;
|
||||
} ggml_metal_kargs_tq_dequant;
|
||||
|
||||
typedef struct {
|
||||
int32_t headDim;
|
||||
int32_t numKVHeads;
|
||||
int32_t bits;
|
||||
int32_t firstCell;
|
||||
int32_t reg_packed_bytes; // padded: (reg_count * bits + 7) / 8, aligned to 4
|
||||
int32_t outlier_bits;
|
||||
int32_t outlier_count;
|
||||
int32_t out_packed_bytes; // padded: (outlierCount * outlierBits + 7) / 8, aligned to 4
|
||||
} ggml_metal_kargs_tq_dequant_outlier;
|
||||
|
||||
typedef struct {
|
||||
int32_t headDim;
|
||||
int32_t numKVHeads;
|
||||
int32_t k_bits;
|
||||
int32_t v_bits;
|
||||
int32_t firstCell;
|
||||
int32_t k_packed_bytes;
|
||||
int32_t v_packed_bytes;
|
||||
int32_t k_codebook_len;
|
||||
int32_t v_codebook_len;
|
||||
} ggml_metal_kargs_tq_dequant_kv;
|
||||
|
||||
typedef struct {
|
||||
int32_t headDim;
|
||||
int32_t numKVHeads;
|
||||
int32_t bits;
|
||||
int32_t firstCell;
|
||||
int32_t kIsF32; // 1 = f32 input, 0 = f16
|
||||
int32_t hasRotation; // 1 = apply rotation matrix, 0 = skip
|
||||
} ggml_metal_kargs_tq_encode;
|
||||
|
||||
typedef struct {
|
||||
int32_t headDim;
|
||||
int32_t numKVHeads;
|
||||
int32_t bits;
|
||||
int32_t firstCell;
|
||||
int32_t kIsF32;
|
||||
int32_t outlierBits;
|
||||
int32_t outlierCount;
|
||||
} ggml_metal_kargs_tq_encode_outlier;
|
||||
|
||||
typedef struct {
|
||||
int32_t ncols; // 1 or 2 (nTokensQ == 1 ? 1 : 2)
|
||||
int32_t nTokensQ;
|
||||
int32_t nHeadsQ;
|
||||
int32_t nSeq;
|
||||
int32_t nCells;
|
||||
int32_t nKVHeads;
|
||||
int32_t bits; // K bits
|
||||
int32_t firstCell;
|
||||
int32_t packedBytes; // K packed bytes per head
|
||||
int32_t v_bits; // 0 when V is f16
|
||||
int32_t v_packedBytes;
|
||||
int32_t hasMask; // 1 = mask buffer valid
|
||||
int32_t ne31; // mask row width (= nCells)
|
||||
float scale;
|
||||
float logit_softcap;
|
||||
uint64_t nb01; // Q stride: bytes between consecutive tokens
|
||||
uint64_t nb02; // Q stride: bytes between consecutive heads
|
||||
uint64_t nb03; // Q stride: bytes between consecutive sequences
|
||||
uint64_t nb21; // V stride: bytes between cells (f16 path only)
|
||||
uint64_t nb22; // V stride: bytes between heads (f16 path only)
|
||||
uint64_t nb23; // V stride: bytes between seqs (f16 path only)
|
||||
uint64_t nb31; // mask stride: bytes between token rows
|
||||
} ggml_metal_kargs_tq_fattn_vec;
|
||||
|
||||
#endif // GGML_METAL_IMPL
|
||||
|
|
|
|||
|
|
@ -452,6 +452,30 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|||
{
|
||||
n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_TQ_ENCODE:
|
||||
{
|
||||
n_fuse = ggml_metal_op_tq_encode(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_TQ_ENCODE_V:
|
||||
{
|
||||
n_fuse = ggml_metal_op_tq_encode_v(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_TQ_ENCODE_KV:
|
||||
{
|
||||
n_fuse = ggml_metal_op_tq_encode_kv(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_TQ_DEQUANT:
|
||||
{
|
||||
n_fuse = ggml_metal_op_tq_dequant(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_TQ_DEQUANT_KV:
|
||||
{
|
||||
n_fuse = ggml_metal_op_tq_dequant_kv(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_TQ_FLASH_ATTN_EXT:
|
||||
{
|
||||
n_fuse = ggml_metal_op_tq_flash_attn_ext(ctx, idx);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
|
||||
|
|
@ -4154,3 +4178,465 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
|
|||
|
||||
return 1;
|
||||
}
|
||||
|
||||
// ── TurboQuant Metal ops ───────────────────────────────────────────────────
|
||||
|
||||
int ggml_metal_op_tq_dequant(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
const int headDim = (int)op->ne[0];
|
||||
const int numKVHeads = (int)op->ne[1];
|
||||
const int nCells = (int)op->ne[2];
|
||||
const int bits = (int)((const int32_t *)op->op_params)[0];
|
||||
const int firstCell = (int)((const int32_t *)op->op_params)[1];
|
||||
const int outlierBits = (int)((const int32_t *)op->op_params)[2];
|
||||
const int outlierCount = (int)((const int32_t *)op->op_params)[3];
|
||||
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
// Outlier kernel uses a 128-thread TG: threadgroup barriers + atomics on
|
||||
// s_mask require all threads. Non-outlier kernel uses a single simdgroup
|
||||
// (32 threads): it only reads tiisg and has no barriers, so a larger TG
|
||||
// just replicates work across idle simdgroups.
|
||||
const int outlier_block_size = 128;
|
||||
const int nonoutlier_block_size = 32;
|
||||
|
||||
if (outlierCount > 0 && outlierBits > 0 && outlierCount < headDim) {
|
||||
const int regular_count = headDim - outlierCount;
|
||||
const int reg_packed_raw = (regular_count * bits + 7) / 8;
|
||||
const int reg_packed_bytes = (reg_packed_raw + 3) & ~3;
|
||||
const int out_packed_raw = (outlierCount * outlierBits + 7) / 8;
|
||||
const int out_packed_bytes = (out_packed_raw + 3) & ~3;
|
||||
|
||||
ggml_metal_kargs_tq_dequant_outlier args = {
|
||||
/*.headDim =*/ headDim,
|
||||
/*.numKVHeads =*/ numKVHeads,
|
||||
/*.bits =*/ bits,
|
||||
/*.firstCell =*/ firstCell,
|
||||
/*.reg_packed_bytes=*/ reg_packed_bytes,
|
||||
/*.outlier_bits =*/ outlierBits,
|
||||
/*.outlier_count =*/ outlierCount,
|
||||
/*.out_packed_bytes=*/ out_packed_bytes,
|
||||
};
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_tq_dequant_outlier(lib);
|
||||
|
||||
const int mask_words = (headDim + 31) >> 5;
|
||||
const size_t smem = GGML_PAD((size_t)headDim * sizeof(int8_t) + (size_t)mask_words * sizeof(uint32_t), 16);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); // reg packed
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); // reg scales
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); // reg codebook
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), 4); // outlier packed
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), 5); // outlier scales
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), 6); // outlier indices
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 7); // outlier codebook
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 8);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, outlier_block_size, 1, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
const int packed_bytes = (headDim * bits + 7) / 8;
|
||||
const int codebook_len = (int)op->src[2]->ne[0];
|
||||
|
||||
ggml_metal_kargs_tq_dequant args = {
|
||||
/*.headDim =*/ headDim,
|
||||
/*.numKVHeads =*/ numKVHeads,
|
||||
/*.bits =*/ bits,
|
||||
/*.firstCell =*/ firstCell,
|
||||
/*.packed_bytes =*/ packed_bytes,
|
||||
/*.codebook_len =*/ codebook_len,
|
||||
};
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_tq_dequant(lib);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); // packed
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); // scales
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); // codebook
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 4);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, nonoutlier_block_size, 1, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_tq_dequant_kv(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
const int headDim = (int)op->ne[0];
|
||||
const int numKVHeads = (int)op->ne[1];
|
||||
const int nCells = (int)op->ne[2];
|
||||
|
||||
int32_t k_bits, v_bits, firstCell;
|
||||
memcpy(&k_bits, (const int32_t *)op->op_params + 0, sizeof(int32_t));
|
||||
memcpy(&v_bits, (const int32_t *)op->op_params + 1, sizeof(int32_t));
|
||||
memcpy(&firstCell, (const int32_t *)op->op_params + 2, sizeof(int32_t));
|
||||
|
||||
const int k_packed_bytes = (headDim * k_bits + 7) / 8;
|
||||
const int v_packed_bytes = (headDim * v_bits + 7) / 8;
|
||||
const int k_codebook_len = (int)op->src[2]->ne[0];
|
||||
const int v_codebook_len = (int)op->src[5]->ne[0];
|
||||
|
||||
// kernel_tq_dequant is single-simdgroup (uses only tiisg, no barriers,
|
||||
// no atomics) — 32-thread TGs eliminate 4× redundant work vs 128-thread.
|
||||
const int block_size = 32;
|
||||
const size_t plane_size = (size_t)headDim * numKVHeads * nCells;
|
||||
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
ggml_metal_buffer_id bid_dst_v = bid_dst;
|
||||
bid_dst_v.offs += plane_size * sizeof(uint16_t);
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_tq_dequant(lib);
|
||||
|
||||
// K dequant → first plane
|
||||
{
|
||||
ggml_metal_kargs_tq_dequant args_k = {
|
||||
/*.headDim =*/ headDim,
|
||||
/*.numKVHeads =*/ numKVHeads,
|
||||
/*.bits =*/ k_bits,
|
||||
/*.firstCell =*/ firstCell,
|
||||
/*.packed_bytes =*/ k_packed_bytes,
|
||||
/*.codebook_len =*/ k_codebook_len,
|
||||
};
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args_k, sizeof(args_k), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 4);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, block_size, 1, 1);
|
||||
}
|
||||
|
||||
// V dequant → second plane
|
||||
{
|
||||
ggml_metal_kargs_tq_dequant args_v = {
|
||||
/*.headDim =*/ headDim,
|
||||
/*.numKVHeads =*/ numKVHeads,
|
||||
/*.bits =*/ v_bits,
|
||||
/*.firstCell =*/ firstCell,
|
||||
/*.packed_bytes =*/ v_packed_bytes,
|
||||
/*.codebook_len =*/ v_codebook_len,
|
||||
};
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args_v, sizeof(args_v), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), 3);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst_v, 4);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, block_size, 1, 1);
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int tq_encode_block_size(int headDim) {
|
||||
int bs = 1;
|
||||
int target = std::min(headDim, 128);
|
||||
while (bs < target) bs <<= 1;
|
||||
return bs;
|
||||
}
|
||||
|
||||
int ggml_metal_op_tq_encode(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
const int headDim = (int)op->src[0]->ne[0];
|
||||
const int numKVHeads = (int)op->src[0]->ne[1];
|
||||
const int batchSize = (int)op->src[0]->ne[2];
|
||||
const int bits = (int)((const int32_t *)op->op_params)[0];
|
||||
const int firstCell = (int)((const int32_t *)op->op_params)[1];
|
||||
const int outlierBits = (int)((const int32_t *)op->op_params)[2];
|
||||
const int outlierCount = (int)((const int32_t *)op->op_params)[3];
|
||||
const int kIsF32 = (op->src[0]->type == GGML_TYPE_F32) ? 1 : 0;
|
||||
|
||||
const int block_size = tq_encode_block_size(headDim);
|
||||
|
||||
if (outlierCount > 0 && outlierBits > 0 && outlierCount < headDim) {
|
||||
ggml_metal_kargs_tq_encode_outlier args = {
|
||||
/*.headDim =*/ headDim,
|
||||
/*.numKVHeads =*/ numKVHeads,
|
||||
/*.bits =*/ bits,
|
||||
/*.firstCell =*/ firstCell,
|
||||
/*.kIsF32 =*/ kIsF32,
|
||||
/*.outlierBits =*/ outlierBits,
|
||||
/*.outlierCount =*/ outlierCount,
|
||||
};
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_tq_encode_outlier(lib);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); // k
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); // rotation
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); // packed_out (dst)
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), 4); // scales
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), 5); // boundaries
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), 6); // outlier_packed
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 7); // outlier_scales
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[7]), 8); // outlier_indices
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[8]), 9); // outlier_boundaries
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, batchSize, numKVHeads, 1, block_size, 1, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
const int hasRotation = (op->src[1] != nullptr) ? 1 : 0;
|
||||
|
||||
ggml_metal_kargs_tq_encode args = {
|
||||
/*.headDim =*/ headDim,
|
||||
/*.numKVHeads =*/ numKVHeads,
|
||||
/*.bits =*/ bits,
|
||||
/*.firstCell =*/ firstCell,
|
||||
/*.kIsF32 =*/ kIsF32,
|
||||
/*.hasRotation =*/ hasRotation,
|
||||
};
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_tq_encode(lib);
|
||||
|
||||
ggml_metal_buffer_id bid_rot = hasRotation
|
||||
? ggml_metal_get_buffer_id(op->src[1])
|
||||
: ggml_metal_get_buffer_id(op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); // k
|
||||
ggml_metal_encoder_set_buffer (enc, bid_rot, 2); // rotation (or dummy)
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); // packed_out (dst)
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), 4); // scales
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), 5); // boundaries
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, batchSize, numKVHeads, 1, block_size, 1, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_tq_encode_v(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
const int headDim = (int)op->src[0]->ne[0];
|
||||
const int numKVHeads = (int)op->src[0]->ne[1];
|
||||
const int batchSize = (int)op->src[0]->ne[2];
|
||||
const int bits = (int)((const int32_t *)op->op_params)[0];
|
||||
const int firstCell = (int)((const int32_t *)op->op_params)[1];
|
||||
const int vIsF32 = (op->src[0]->type == GGML_TYPE_F32) ? 1 : 0;
|
||||
const int hasRotation = (op->src[1] != nullptr) ? 1 : 0;
|
||||
|
||||
const int block_size = tq_encode_block_size(headDim);
|
||||
|
||||
ggml_metal_kargs_tq_encode args = {
|
||||
/*.headDim =*/ headDim,
|
||||
/*.numKVHeads =*/ numKVHeads,
|
||||
/*.bits =*/ bits,
|
||||
/*.firstCell =*/ firstCell,
|
||||
/*.kIsF32 =*/ vIsF32,
|
||||
/*.hasRotation =*/ hasRotation,
|
||||
};
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_tq_encode_v(lib);
|
||||
|
||||
ggml_metal_buffer_id bid_rot = hasRotation
|
||||
? ggml_metal_get_buffer_id(op->src[1])
|
||||
: ggml_metal_get_buffer_id(op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); // v
|
||||
ggml_metal_encoder_set_buffer (enc, bid_rot, 2); // rotation (or dummy)
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); // packed_out (dst)
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), 4); // scales
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), 5); // boundaries
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, batchSize, numKVHeads, 1, block_size, 1, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_tq_encode_kv(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
// src layout: [0]=K, [1]=rotation, [2]=V, [3]=K_scales, [4]=K_bounds,
|
||||
// [5]=V_packed, [6]=V_scales, [7]=V_bounds
|
||||
const int headDim = (int)op->src[0]->ne[0];
|
||||
const int numKVHeads = (int)op->src[0]->ne[1];
|
||||
const int batchSize = (int)op->src[0]->ne[2];
|
||||
const int kIsF32 = (op->src[0]->type == GGML_TYPE_F32) ? 1 : 0;
|
||||
const int vIsF32 = (op->src[2]->type == GGML_TYPE_F32) ? 1 : 0;
|
||||
const int hasRotation = (op->src[1] != nullptr) ? 1 : 0;
|
||||
|
||||
int32_t k_bits, v_bits, firstCell;
|
||||
memcpy(&k_bits, (const int32_t *)op->op_params + 0, sizeof(int32_t));
|
||||
memcpy(&v_bits, (const int32_t *)op->op_params + 1, sizeof(int32_t));
|
||||
memcpy(&firstCell, (const int32_t *)op->op_params + 2, sizeof(int32_t));
|
||||
|
||||
const int block_size = tq_encode_block_size(headDim);
|
||||
|
||||
ggml_metal_buffer_id bid_rot = hasRotation
|
||||
? ggml_metal_get_buffer_id(op->src[1])
|
||||
: ggml_metal_get_buffer_id(op);
|
||||
|
||||
// K encode
|
||||
{
|
||||
ggml_metal_kargs_tq_encode args_k = {
|
||||
/*.headDim =*/ headDim,
|
||||
/*.numKVHeads =*/ numKVHeads,
|
||||
/*.bits =*/ k_bits,
|
||||
/*.firstCell =*/ firstCell,
|
||||
/*.kIsF32 =*/ kIsF32,
|
||||
/*.hasRotation =*/ hasRotation,
|
||||
};
|
||||
auto pipeline_k = ggml_metal_library_get_pipeline_tq_encode(lib);
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline_k);
|
||||
ggml_metal_encoder_set_bytes (enc, &args_k, sizeof(args_k), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); // K
|
||||
ggml_metal_encoder_set_buffer (enc, bid_rot, 2);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); // dst = K packed
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), 4); // K scales
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), 5); // K bounds
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, batchSize, numKVHeads, 1, block_size, 1, 1);
|
||||
}
|
||||
|
||||
// V encode
|
||||
{
|
||||
ggml_metal_kargs_tq_encode args_v = {
|
||||
/*.headDim =*/ headDim,
|
||||
/*.numKVHeads =*/ numKVHeads,
|
||||
/*.bits =*/ v_bits,
|
||||
/*.firstCell =*/ firstCell,
|
||||
/*.kIsF32 =*/ vIsF32,
|
||||
/*.hasRotation =*/ hasRotation,
|
||||
};
|
||||
auto pipeline_v = ggml_metal_library_get_pipeline_tq_encode_v(lib);
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline_v);
|
||||
ggml_metal_encoder_set_bytes (enc, &args_v, sizeof(args_v), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 1); // V
|
||||
ggml_metal_encoder_set_buffer (enc, bid_rot, 2);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), 3); // V packed
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 4); // V scales
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[7]), 5); // V bounds
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, batchSize, numKVHeads, 1, block_size, 1, 1);
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_tq_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
const ggml_tensor * Q = op->src[0];
|
||||
const ggml_tensor * K_p = op->src[1];
|
||||
const ggml_tensor * V = op->src[2];
|
||||
const ggml_tensor * mask = op->src[3];
|
||||
// src[4]=K_scales, src[5]=K_codebook, src[6]=V_scales (NULL for K-only), src[7]=V_codebook
|
||||
const bool v_packed = (op->src[6] != nullptr);
|
||||
|
||||
float scale;
|
||||
float logit_softcap;
|
||||
int32_t bits;
|
||||
int32_t firstCell;
|
||||
int32_t v_bits;
|
||||
memcpy(&scale, (const float *)op->op_params + 0, sizeof(float));
|
||||
memcpy(&logit_softcap, (const float *)op->op_params + 1, sizeof(float));
|
||||
memcpy(&bits, (const int32_t *)op->op_params + 2, sizeof(int32_t));
|
||||
memcpy(&firstCell, (const int32_t *)op->op_params + 3, sizeof(int32_t));
|
||||
memcpy(&v_bits, (const int32_t *)op->op_params + 4, sizeof(int32_t));
|
||||
|
||||
if (logit_softcap != 0.0f) { scale /= logit_softcap; }
|
||||
|
||||
const int D = (int)Q->ne[0];
|
||||
const int nTokensQ = (int)Q->ne[1];
|
||||
const int nHeadsQ = (int)Q->ne[2];
|
||||
const int nSeq = (int)Q->ne[3];
|
||||
|
||||
const int packedBytes = (D * bits + 7) / 8;
|
||||
const int v_packedBytes = v_packed ? ((D * v_bits + 7) / 8) : 0;
|
||||
const int nKVHeads = v_packed ? ((int)V->ne[0] / v_packedBytes) : (int)V->ne[2];
|
||||
|
||||
int nCells;
|
||||
if (v_packed) {
|
||||
GGML_ASSERT(mask != nullptr);
|
||||
nCells = (int)mask->ne[0];
|
||||
} else {
|
||||
nCells = (int)V->ne[1];
|
||||
}
|
||||
|
||||
const int ncols = (nTokensQ == 1) ? 1 : 2;
|
||||
const int ntiles_x = (nTokensQ + ncols - 1) / ncols;
|
||||
const int hasMask = mask ? 1 : 0;
|
||||
const int ne31 = mask ? (int)mask->ne[0] : 0;
|
||||
|
||||
ggml_metal_kargs_tq_fattn_vec args = {
|
||||
/*.ncols =*/ ncols,
|
||||
/*.nTokensQ =*/ nTokensQ,
|
||||
/*.nHeadsQ =*/ nHeadsQ,
|
||||
/*.nSeq =*/ nSeq,
|
||||
/*.nCells =*/ nCells,
|
||||
/*.nKVHeads =*/ nKVHeads,
|
||||
/*.bits =*/ bits,
|
||||
/*.firstCell =*/ firstCell,
|
||||
/*.packedBytes =*/ packedBytes,
|
||||
/*.v_bits =*/ v_bits,
|
||||
/*.v_packedBytes=*/ v_packedBytes,
|
||||
/*.hasMask =*/ hasMask,
|
||||
/*.ne31 =*/ ne31,
|
||||
/*.scale =*/ scale,
|
||||
/*.logit_softcap=*/ logit_softcap,
|
||||
/*.nb01 =*/ Q->nb[1],
|
||||
/*.nb02 =*/ Q->nb[2],
|
||||
/*.nb03 =*/ Q->nb[3],
|
||||
/*.nb21 =*/ V->nb[1],
|
||||
/*.nb22 =*/ V->nb[2],
|
||||
/*.nb23 =*/ V->nb[3],
|
||||
/*.nb31 =*/ mask ? mask->nb[1] : 0,
|
||||
};
|
||||
|
||||
// Select D=128 vs D=256 pipeline. Gemma3 runs at headDim=256; everything
|
||||
// else supported so far is D=128.
|
||||
GGML_ASSERT(D == 128 || D == 256);
|
||||
auto pipeline = v_packed
|
||||
? (D == 256
|
||||
? ggml_metal_library_get_pipeline_tq_fattn_vec_packed_d256(lib)
|
||||
: ggml_metal_library_get_pipeline_tq_fattn_vec_packed(lib))
|
||||
: (D == 256
|
||||
? ggml_metal_library_get_pipeline_tq_fattn_vec_f16_d256(lib)
|
||||
: ggml_metal_library_get_pipeline_tq_fattn_vec_f16(lib));
|
||||
|
||||
ggml_metal_buffer_id bid_mask = hasMask ? ggml_metal_get_buffer_id(mask) : ggml_metal_get_buffer_id(op);
|
||||
ggml_metal_buffer_id bid_v_scales = v_packed ? ggml_metal_get_buffer_id(op->src[6]) : ggml_metal_get_buffer_id(op);
|
||||
ggml_metal_buffer_id bid_v_codebook = v_packed ? ggml_metal_get_buffer_id(op->src[7]) : ggml_metal_get_buffer_id(op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(Q), 1); // Q
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(K_p), 2); // K packed
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(V), 3); // V (f16 or packed i8)
|
||||
ggml_metal_encoder_set_buffer (enc, bid_mask, 4); // mask
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), 5); // K scales
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), 6); // K codebook
|
||||
ggml_metal_encoder_set_buffer (enc, bid_v_scales, 7); // V scales (or dummy)
|
||||
ggml_metal_encoder_set_buffer (enc, bid_v_codebook, 8); // V codebook (or dummy)
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 9); // dst
|
||||
|
||||
// Block: (32, 4, 1) = 128 threads; Grid: (ntiles_x, 1, nHeadsQ*nSeq)
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ntiles_x, 1, nHeadsQ * nSeq, 32, 4, 1);
|
||||
return 1;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -89,6 +89,14 @@ int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx);
|
|||
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
|
||||
|
||||
// TurboQuant ops
|
||||
int ggml_metal_op_tq_encode (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_tq_encode_v (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_tq_encode_kv (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_tq_dequant (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_tq_dequant_kv (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_tq_flash_attn_ext (ggml_metal_op_t ctx, int idx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
|||
1828
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
vendored
1828
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
vendored
File diff suppressed because it is too large
Load diff
239
ml/backend/ggml/ggml/src/ggml.c
vendored
239
ml/backend/ggml/ggml/src/ggml.c
vendored
|
|
@ -1048,9 +1048,15 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|||
"OPT_STEP_SGD",
|
||||
|
||||
"GLU",
|
||||
"TQ_ENCODE",
|
||||
"TQ_DEQUANT",
|
||||
"TQ_DEQUANT_KV",
|
||||
"TQ_FLASH_ATTN_EXT",
|
||||
"TQ_ENCODE_V",
|
||||
"TQ_ENCODE_KV",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
|
||||
static_assert(GGML_OP_COUNT == 101, "GGML_OP_COUNT != 101");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
|
|
@ -1157,9 +1163,14 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||
"sgd(x)",
|
||||
|
||||
"glu(x)",
|
||||
"tq_encode(k,rot,idx)->packed",
|
||||
"tq_dequant(packed,scales)->f16",
|
||||
"tq_flash_attn_ext(q,k_packed,v)->f32",
|
||||
"tq_encode_v(v)->packed",
|
||||
"tq_encode_kv(k,v)->packed",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
|
||||
static_assert(GGML_OP_COUNT == 101, "GGML_OP_COUNT != 101");
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
|
|
@ -7603,3 +7614,227 @@ bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, cons
|
|||
if (p0->strict_cpu != p1->strict_cpu ) return false;
|
||||
return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_tq_encode(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * packed,
|
||||
struct ggml_tensor * scales,
|
||||
struct ggml_tensor * k,
|
||||
struct ggml_tensor * rotation,
|
||||
int32_t firstCell,
|
||||
struct ggml_tensor * boundaries,
|
||||
int32_t bits) {
|
||||
struct ggml_tensor * result = ggml_view_tensor(ctx, packed);
|
||||
result->op = GGML_OP_TQ_ENCODE;
|
||||
result->src[0] = k;
|
||||
result->src[1] = rotation;
|
||||
result->src[2] = NULL; // cell_idx removed; kernel computes firstCell + batch
|
||||
result->src[3] = scales;
|
||||
result->src[4] = boundaries;
|
||||
ggml_set_op_params_i32(result, 0, bits);
|
||||
ggml_set_op_params_i32(result, 1, firstCell);
|
||||
ggml_set_op_params_i32(result, 2, 0); // outlier_bits (0 = uniform)
|
||||
ggml_set_op_params_i32(result, 3, 0); // outlier_count (0 = uniform)
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_tq_dequant(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * encode_result,
|
||||
struct ggml_tensor * scales,
|
||||
struct ggml_tensor * codebook,
|
||||
int headDim, int numKVHeads, int nCells, int firstCell, int bits) {
|
||||
struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F16,
|
||||
headDim, numKVHeads, nCells);
|
||||
result->op = GGML_OP_TQ_DEQUANT;
|
||||
result->src[0] = encode_result;
|
||||
result->src[1] = scales;
|
||||
result->src[2] = codebook;
|
||||
ggml_set_op_params_i32(result, 0, (int32_t)bits);
|
||||
ggml_set_op_params_i32(result, 1, (int32_t)firstCell);
|
||||
ggml_set_op_params_i32(result, 2, 0); // outlier_bits (0 = uniform)
|
||||
ggml_set_op_params_i32(result, 3, 0); // outlier_count (0 = uniform)
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_tq_encode_outlier: extends ggml_tq_encode with an outlier sub-block.
|
||||
// Same op (GGML_OP_TQ_ENCODE) with outlier_count > 0 in op_params[3]; the
|
||||
// CUDA backend dispatches to the outlier-aware kernel when it sees a non-zero
|
||||
// outlier_count in op_params. The regular packed buffer is the dst (view);
|
||||
// the outlier packed, scales, and indices are written as side effects via
|
||||
// src[5..8], same pattern as the regular scales src[3].
|
||||
struct ggml_tensor * ggml_tq_encode_outlier(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * packed,
|
||||
struct ggml_tensor * scales,
|
||||
struct ggml_tensor * k,
|
||||
struct ggml_tensor * rotation,
|
||||
int32_t firstCell,
|
||||
struct ggml_tensor * boundaries,
|
||||
int32_t bits,
|
||||
struct ggml_tensor * outlier_packed,
|
||||
struct ggml_tensor * outlier_scales,
|
||||
struct ggml_tensor * outlier_indices,
|
||||
struct ggml_tensor * outlier_boundaries,
|
||||
int32_t outlier_bits,
|
||||
int32_t outlier_count) {
|
||||
struct ggml_tensor * result = ggml_view_tensor(ctx, packed);
|
||||
result->op = GGML_OP_TQ_ENCODE;
|
||||
result->src[0] = k;
|
||||
result->src[1] = rotation;
|
||||
result->src[2] = NULL;
|
||||
result->src[3] = scales;
|
||||
result->src[4] = boundaries;
|
||||
result->src[5] = outlier_packed;
|
||||
result->src[6] = outlier_scales;
|
||||
result->src[7] = outlier_indices;
|
||||
result->src[8] = outlier_boundaries;
|
||||
ggml_set_op_params_i32(result, 0, bits);
|
||||
ggml_set_op_params_i32(result, 1, firstCell);
|
||||
ggml_set_op_params_i32(result, 2, outlier_bits);
|
||||
ggml_set_op_params_i32(result, 3, outlier_count);
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_tq_dequant_outlier: extends ggml_tq_dequant with an outlier overwrite
|
||||
// pass. Reconstructs [headDim, numKVHeads, nCells] f16 by decoding the
|
||||
// regular sub-block for all 128 positions, then overwriting the outlier
|
||||
// channel positions from the outlier sub-block.
|
||||
struct ggml_tensor * ggml_tq_dequant_outlier(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * encode_result,
|
||||
struct ggml_tensor * scales,
|
||||
struct ggml_tensor * codebook,
|
||||
int headDim, int numKVHeads, int nCells, int firstCell, int bits,
|
||||
struct ggml_tensor * outlier_packed,
|
||||
struct ggml_tensor * outlier_scales,
|
||||
struct ggml_tensor * outlier_indices,
|
||||
struct ggml_tensor * outlier_codebook,
|
||||
int32_t outlier_bits,
|
||||
int32_t outlier_count) {
|
||||
struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F16,
|
||||
headDim, numKVHeads, nCells);
|
||||
result->op = GGML_OP_TQ_DEQUANT;
|
||||
result->src[0] = encode_result;
|
||||
result->src[1] = scales;
|
||||
result->src[2] = codebook;
|
||||
result->src[3] = outlier_packed;
|
||||
result->src[4] = outlier_scales;
|
||||
result->src[5] = outlier_indices;
|
||||
result->src[6] = outlier_codebook;
|
||||
ggml_set_op_params_i32(result, 0, (int32_t)bits);
|
||||
ggml_set_op_params_i32(result, 1, (int32_t)firstCell);
|
||||
ggml_set_op_params_i32(result, 2, outlier_bits);
|
||||
ggml_set_op_params_i32(result, 3, outlier_count);
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_tq_dequant_kv(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * k_encode_result,
|
||||
struct ggml_tensor * k_scales,
|
||||
struct ggml_tensor * k_codebook,
|
||||
struct ggml_tensor * v_encode_result,
|
||||
struct ggml_tensor * v_scales,
|
||||
struct ggml_tensor * v_codebook,
|
||||
struct ggml_tensor * v_rotation,
|
||||
int headDim, int numKVHeads, int nCells, int firstCell,
|
||||
int k_bits, int v_bits) {
|
||||
// Output: [headDim, numKVHeads, nCells, 2] f16 — last dim separates K (0) and V (1).
|
||||
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F16,
|
||||
headDim, numKVHeads, nCells, 2);
|
||||
result->op = GGML_OP_TQ_DEQUANT_KV;
|
||||
result->src[0] = k_encode_result;
|
||||
result->src[1] = k_scales;
|
||||
result->src[2] = k_codebook;
|
||||
result->src[3] = v_encode_result;
|
||||
result->src[4] = v_scales;
|
||||
result->src[5] = v_codebook;
|
||||
result->src[6] = v_rotation; // NULL = no rotation fusion
|
||||
ggml_set_op_params_i32(result, 0, (int32_t)k_bits);
|
||||
ggml_set_op_params_i32(result, 1, (int32_t)v_bits);
|
||||
ggml_set_op_params_i32(result, 2, (int32_t)firstCell);
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_tq_flash_attn_ext(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * q,
|
||||
struct ggml_tensor * k_packed,
|
||||
struct ggml_tensor * v,
|
||||
struct ggml_tensor * mask,
|
||||
struct ggml_tensor * scales,
|
||||
struct ggml_tensor * codebook,
|
||||
float scale, float logit_softcap,
|
||||
int32_t bits, int32_t firstCell,
|
||||
struct ggml_tensor * v_scales,
|
||||
struct ggml_tensor * v_codebook,
|
||||
int32_t v_bits) {
|
||||
// Output: [D, nHeadsQ, nTokensQ, nSeq] f32 — same shape as standard flash_attn_ext.
|
||||
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32,
|
||||
q->ne[0], q->ne[2], q->ne[1], q->ne[3]);
|
||||
result->op = GGML_OP_TQ_FLASH_ATTN_EXT;
|
||||
result->src[0] = q;
|
||||
result->src[1] = k_packed;
|
||||
result->src[2] = v;
|
||||
result->src[3] = mask;
|
||||
result->src[4] = scales;
|
||||
result->src[5] = codebook;
|
||||
result->src[6] = v_scales; // NULL → V is f16 (K-only fused); non-NULL → V is packed
|
||||
result->src[7] = v_codebook; // NULL when V is f16
|
||||
ggml_set_op_params_f32(result, 0, scale);
|
||||
ggml_set_op_params_f32(result, 1, logit_softcap);
|
||||
ggml_set_op_params_i32(result, 2, bits);
|
||||
ggml_set_op_params_i32(result, 3, firstCell);
|
||||
ggml_set_op_params_i32(result, 4, v_bits);
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_tq_encode_v(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * packed,
|
||||
struct ggml_tensor * scales,
|
||||
struct ggml_tensor * v,
|
||||
struct ggml_tensor * rotation,
|
||||
int32_t firstCell,
|
||||
struct ggml_tensor * boundaries,
|
||||
int32_t bits) {
|
||||
struct ggml_tensor * result = ggml_view_tensor(ctx, packed);
|
||||
result->op = GGML_OP_TQ_ENCODE_V;
|
||||
result->src[0] = v;
|
||||
result->src[1] = rotation; // NULL = no rotation, non-NULL = R^T matrix
|
||||
result->src[2] = NULL;
|
||||
result->src[3] = scales;
|
||||
result->src[4] = boundaries;
|
||||
ggml_set_op_params_i32(result, 0, bits);
|
||||
ggml_set_op_params_i32(result, 1, firstCell);
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_tq_encode_kv(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * k_packed,
|
||||
struct ggml_tensor * k_scales,
|
||||
struct ggml_tensor * k,
|
||||
struct ggml_tensor * rotation,
|
||||
struct ggml_tensor * k_boundaries,
|
||||
struct ggml_tensor * v_packed,
|
||||
struct ggml_tensor * v_scales,
|
||||
struct ggml_tensor * v,
|
||||
struct ggml_tensor * v_boundaries,
|
||||
int32_t firstCell, int32_t k_bits, int32_t v_bits) {
|
||||
struct ggml_tensor * result = ggml_view_tensor(ctx, k_packed);
|
||||
result->op = GGML_OP_TQ_ENCODE_KV;
|
||||
result->src[0] = k;
|
||||
result->src[1] = rotation;
|
||||
result->src[2] = v;
|
||||
result->src[3] = k_scales;
|
||||
result->src[4] = k_boundaries;
|
||||
result->src[5] = v_packed;
|
||||
result->src[6] = v_scales;
|
||||
result->src[7] = v_boundaries;
|
||||
ggml_set_op_params_i32(result, 0, k_bits);
|
||||
ggml_set_op_params_i32(result, 1, v_bits);
|
||||
ggml_set_op_params_i32(result, 2, firstCell);
|
||||
return result;
|
||||
}
|
||||
|
|
|
|||
70
ml/backend/ggml/tq_device_scan.go
Normal file
70
ml/backend/ggml/tq_device_scan.go
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
package ggml
|
||||
|
||||
import "fmt"
|
||||
|
||||
// TurboQuant compute-capability thresholds.
|
||||
//
|
||||
// NVIDIA path: Pascal (cc 6.0) is the floor because the TQ codebook lookup
|
||||
// relies on __shfl_sync, introduced on Kepler and universally stable from
|
||||
// Pascal onwards; earlier archs (Maxwell and below) trip a compile-time
|
||||
// assert inside tq-dequant.cu.
|
||||
//
|
||||
// AMD path: ggml-cuda encodes the gfx arch into props.compute_major as
|
||||
// (cc - OFFSET_AMD) / 0x100 (see ml/backend/ggml/ggml/src/ggml-cuda/
|
||||
// ggml-cuda.cu), so Vega/GCN/CDNA land at major=9 (0x9XX) and RDNA1+ lands
|
||||
// at major>=16 (0x1010 and up). Every RDNA generation is wave32; every
|
||||
// pre-RDNA AMD generation is wave64. That single boundary drives the gate:
|
||||
// TQ's 32-lane __shfl_sync becomes __shfl(_, _, 32) under the HIP shim
|
||||
// (vendors/hip.h), and on a 64-lane warp that sub-partitions into two
|
||||
// independent 32-lane groups — lanes 32-63 never receive codebook data
|
||||
// from the CUDA-tuned kernel and return garbage.
|
||||
const (
|
||||
tqMinNvidiaComputeMajor = 6 // Pascal
|
||||
tqMinAmdComputeMajor = 16 // RDNA1 gfx1010
|
||||
)
|
||||
|
||||
// tqDeviceAccepted returns whether TurboQuant kernels can safely run on a
|
||||
// device identified by its backend library name and compute-capability
|
||||
// major. The check is intentionally narrow: only wave32 GPUs are admitted.
|
||||
// The TQ codebook lookup issues __shfl_sync(mask, val, lane, 32), which the
|
||||
// HIP vendor shim at ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h rewrites
|
||||
// as __shfl(val, lane, 32). Width is preserved, but on a wave64 warp the
|
||||
// width=32 partitions the 64 physical lanes into two independent 32-lane
|
||||
// sub-groups — and the CUDA-tuned kernel only seeds codebook data in the
|
||||
// first 32, so lanes 32-63 shuffle against uninitialized values and return
|
||||
// garbage. Rejecting all wave64 AMD (Vega/GCN/CDNA) cleanly sidesteps that.
|
||||
//
|
||||
// When rejected, the returned skipReason is non-empty and names the
|
||||
// limitation in operator-visible terms so the accompanying slog.Warn can
|
||||
// be read by somebody who doesn't have the ggml source tree open.
|
||||
func tqDeviceAccepted(library string, ccMajor int) (accepted bool, skipReason string) {
|
||||
switch library {
|
||||
case "CUDA":
|
||||
if ccMajor >= tqMinNvidiaComputeMajor {
|
||||
return true, ""
|
||||
}
|
||||
return false, fmt.Sprintf(
|
||||
"TurboQuant requires NVIDIA Pascal (cc 6.0+) or AMD RDNA1+ (gfx1010+, wave32); got CUDA cc major=%d",
|
||||
ccMajor,
|
||||
)
|
||||
case "ROCm":
|
||||
if ccMajor >= tqMinAmdComputeMajor {
|
||||
return true, ""
|
||||
}
|
||||
return false, fmt.Sprintf(
|
||||
"TurboQuant on ROCm requires RDNA1+ (gfx1010+); wave64 AMD GPUs (Vega/GCN/CDNA) are not supported "+
|
||||
"because TQ's 32-lane __shfl_sync sub-partitions the 64-lane warp and returns garbage from "+
|
||||
"lanes 32-63 (gfx major=%d)",
|
||||
ccMajor,
|
||||
)
|
||||
case "Metal":
|
||||
// Apple Silicon always has 32-wide SIMD groups — same as CUDA warp width.
|
||||
// TQ's __shfl_sync(mask, val, lane, 32) maps to simd_shuffle(val, lane) 1:1.
|
||||
return true, ""
|
||||
default:
|
||||
return false, fmt.Sprintf(
|
||||
"TurboQuant requires a CUDA, ROCm, or Metal backend library; got library=%q",
|
||||
library,
|
||||
)
|
||||
}
|
||||
}
|
||||
88
ml/backend/ggml/tq_device_scan_test.go
Normal file
88
ml/backend/ggml/tq_device_scan_test.go
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
package ggml
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestTQDeviceAccepted pins the library-aware compute-capability gate that
|
||||
// decides whether a given GPU can run TurboQuant kernels. The TQ codebook
|
||||
// lookup uses __shfl_sync with a 32-lane mask; the HIP vendor shim at
|
||||
// ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h silently lowers that to
|
||||
// __shfl with width=warpSize, which reads unused lanes (and thus returns
|
||||
// garbage) on wave64 GPUs. The gate's job is to keep wave64 AMD devices
|
||||
// (Vega/GCN/CDNA) out of the TQ path while still accepting all wave32
|
||||
// RDNA (gfx1010+) and unchanged NVIDIA Pascal+.
|
||||
//
|
||||
// This test runs in pure Go (no cgo, no GPU) so it executes in CI on every
|
||||
// platform. It's the primary regression gate for the classifier rule; the
|
||||
// scanTQDevices cgo plumbing and the slog.Warn copy rewrites are covered by
|
||||
// code review and runtime verification on real hardware.
|
||||
func TestTQDeviceAccepted(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
library string
|
||||
ccMajor int
|
||||
wantAccept bool
|
||||
wantReasonHas string // substring that must appear in skipReason when rejected
|
||||
}{
|
||||
// NVIDIA path: gate unchanged from the original TurboQuant PR (Pascal+).
|
||||
{"nvidia_pascal_p40", "CUDA", 6, true, ""},
|
||||
{"nvidia_turing", "CUDA", 7, true, ""},
|
||||
{"nvidia_ampere", "CUDA", 8, true, ""},
|
||||
{"nvidia_hopper", "CUDA", 9, true, ""},
|
||||
{"nvidia_maxwell", "CUDA", 5, false, "CUDA"},
|
||||
{"nvidia_kepler", "CUDA", 3, false, "CUDA"},
|
||||
{"nvidia_bogus_zero", "CUDA", 0, false, "CUDA"},
|
||||
|
||||
// AMD wave64 — must be rejected. props.compute_major = (cc - OFFSET_AMD) / 0x100,
|
||||
// so Vega/GCN/CDNA all land at major=9.
|
||||
{"amd_vega_gfx900", "ROCm", 9, false, "ROCm"},
|
||||
{"amd_vega20_gfx906", "ROCm", 9, false, "ROCm"},
|
||||
{"amd_cdna1_mi100", "ROCm", 9, false, "wave64"},
|
||||
{"amd_cdna2_mi210", "ROCm", 9, false, "Vega"},
|
||||
{"amd_cdna3_mi300", "ROCm", 9, false, "CDNA"},
|
||||
{"amd_gcn4_polaris", "ROCm", 8, false, "ROCm"},
|
||||
|
||||
// AMD wave32 RDNA — must be accepted. RDNA1 gfx1010 is the minimum.
|
||||
{"amd_rdna1_gfx1010", "ROCm", 16, true, ""},
|
||||
{"amd_rdna2_gfx1030", "ROCm", 16, true, ""},
|
||||
{"amd_rdna3_gfx1100", "ROCm", 17, true, ""},
|
||||
{"amd_rdna3_5_gfx1150", "ROCm", 17, true, ""},
|
||||
{"amd_rdna4_gfx1200", "ROCm", 18, true, ""},
|
||||
|
||||
// Metal is accepted — Apple Silicon SIMD groups are 32-wide, matching
|
||||
// the TQ kernels' __shfl_sync(mask, val, lane, 32) width.
|
||||
{"metal", "Metal", 7, true, ""},
|
||||
|
||||
// Non-CUDA/ROCm/Metal backends — reject with an informative reason.
|
||||
{"vulkan", "Vulkan", 7, false, "Vulkan"},
|
||||
{"sycl", "SYCL", 7, false, "SYCL"},
|
||||
{"empty_library", "", 7, false, "library"},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gotAccept, gotReason := tqDeviceAccepted(tc.library, tc.ccMajor)
|
||||
if gotAccept != tc.wantAccept {
|
||||
t.Fatalf("tqDeviceAccepted(%q, %d) accept = %v, want %v (reason=%q)",
|
||||
tc.library, tc.ccMajor, gotAccept, tc.wantAccept, gotReason)
|
||||
}
|
||||
if tc.wantAccept {
|
||||
if gotReason != "" {
|
||||
t.Errorf("tqDeviceAccepted(%q, %d) accepted but reason=%q; expected empty",
|
||||
tc.library, tc.ccMajor, gotReason)
|
||||
}
|
||||
return
|
||||
}
|
||||
if gotReason == "" {
|
||||
t.Fatalf("tqDeviceAccepted(%q, %d) rejected with empty reason; operators need a diagnosable message",
|
||||
tc.library, tc.ccMajor)
|
||||
}
|
||||
if tc.wantReasonHas != "" && !strings.Contains(gotReason, tc.wantReasonHas) {
|
||||
t.Errorf("tqDeviceAccepted(%q, %d) reason = %q, want substring %q",
|
||||
tc.library, tc.ccMajor, gotReason, tc.wantReasonHas)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
184
ml/backend/ggml/tq_outlier_encode_test.go
Normal file
184
ml/backend/ggml/tq_outlier_encode_test.go
Normal file
|
|
@ -0,0 +1,184 @@
|
|||
package ggml
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/turboquant"
|
||||
)
|
||||
|
||||
// TestOutlierEncodeDequantGPUCPUEquivalence runs tq_encode_kernel_outlier
|
||||
// + tq_dequant_multihead_kernel_outlier on the GPU for a synthetic K
|
||||
// batch and compares the decoded output against the CPU reference
|
||||
// (EncodeKeyPerHeadOutlier + DequantKeyPerHeadOutlier in
|
||||
// turboquant/encode.go). Catches kernel algorithmic drift and tests the
|
||||
// exact multi-KV-head layout that broke llama3.2:3b / qwen2.5:7b under
|
||||
// the __shfl_sync divergence bug.
|
||||
//
|
||||
// CURRENT LIMITATION: this test skips in CI / on plain `go test` runs
|
||||
// because setup()'s synthetic GGUF has no tensors, so no GPU buffer
|
||||
// types end up in b.schedBufts, so scanTQDevices() finds no TQ-capable
|
||||
// GPU even on a machine with a P40. The test as written is correct and
|
||||
// runnable under a test harness that loads a real model-backed backend
|
||||
// (e.g. the tq_outlier_encode_test.go:setup function could be replaced
|
||||
// with a helper that loads a tiny real .gguf with GPU-assigned layers).
|
||||
// Until that harness lands, the CPU reference tests in
|
||||
// turboquant/encode_test.go (TestOutlierSplitVsUniformHeavyTailed,
|
||||
// TestOutlierPerHeadRoundTrip) cover algorithmic correctness and the
|
||||
// full tqbench matrix covers real-model runtime verification.
|
||||
//
|
||||
// The test is kept in place because:
|
||||
// 1. It's the correct scaffolding for a future GPU unit-test harness.
|
||||
// 2. It documents the exact CPU↔GPU equivalence contract the kernels
|
||||
// must satisfy.
|
||||
// 3. Once schedBufts is populated (either by a better harness or by
|
||||
// a future ggml change), the test becomes a regression gate
|
||||
// automatically — no further wiring needed.
|
||||
func TestOutlierEncodeDequantGPUCPUEquivalence(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
headDim int
|
||||
numKVHeads int
|
||||
bits int
|
||||
outlierBits int
|
||||
outlierCount int
|
||||
preset turboquant.Preset
|
||||
}{
|
||||
{"d128_h8_tq3k", 128, 8, 3, 4, 32, turboquant.PresetTQ3K},
|
||||
{"d128_h4_tq3k", 128, 4, 3, 4, 32, turboquant.PresetTQ3K},
|
||||
{"d128_h1_tq3k", 128, 1, 3, 4, 32, turboquant.PresetTQ3K},
|
||||
{"d256_h1_tq3k", 256, 1, 3, 4, 32, turboquant.PresetTQ3K},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx := setup(t)
|
||||
b := ctx.(*Context).b
|
||||
|
||||
mgrAny := b.NewTQCompressedKManager(
|
||||
tc.headDim, tc.numKVHeads, tc.bits,
|
||||
tc.preset.RotationSeed,
|
||||
0, // vBits (K-only)
|
||||
tc.outlierBits, tc.outlierCount,
|
||||
)
|
||||
if mgrAny == nil {
|
||||
t.Skip("no TQ-capable GPU available (need NVIDIA Pascal cc 6.0+ or AMD RDNA1+)")
|
||||
}
|
||||
mgr, ok := mgrAny.(*ggmlTQCompressedK)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected TQ manager type %T", mgrAny)
|
||||
}
|
||||
|
||||
const nCells = 4
|
||||
capacity := nCells
|
||||
mgr.EnsureLayer(0, capacity)
|
||||
|
||||
// Build a deterministic synthetic K batch: Gaussian noise
|
||||
// via the same splitmix64 as the CPU tests.
|
||||
kData := make([]float32, tc.headDim*tc.numKVHeads*nCells)
|
||||
var rngState uint64 = 0xface_feed_cafe_0000 | uint64(tc.headDim)
|
||||
rng := &rngState
|
||||
for i := range kData {
|
||||
kData[i] = float32(testGaussian(rng))
|
||||
}
|
||||
|
||||
// GPU path: create a K tensor, run EncodeK + DequantK, read
|
||||
// back the dequanted output.
|
||||
kTensor := ctx.FromFloats(kData, tc.headDim, tc.numKVHeads, nCells)
|
||||
encodeResult := mgr.EncodeK(ctx, 0, kTensor, 0)
|
||||
if encodeResult == nil {
|
||||
t.Fatalf("EncodeK returned nil")
|
||||
}
|
||||
ctx.Forward(encodeResult)
|
||||
|
||||
dequant := mgr.DequantK(ctx, 0, encodeResult, 0, nCells)
|
||||
if dequant == nil {
|
||||
t.Fatalf("DequantK returned nil")
|
||||
}
|
||||
ctx.Forward(dequant).Compute(dequant)
|
||||
|
||||
gpuOut := dequant.Floats()
|
||||
if len(gpuOut) != tc.headDim*tc.numKVHeads*nCells {
|
||||
t.Fatalf("gpu output len = %d, want %d",
|
||||
len(gpuOut), tc.headDim*tc.numKVHeads*nCells)
|
||||
}
|
||||
|
||||
// CPU reference: encode + dequant each (cell, head) slab
|
||||
// with the same preset, compute the expected rotated-space
|
||||
// output, and compare elementwise.
|
||||
//
|
||||
// Note: the CPU reference uses float64 blockScale while the
|
||||
// GPU kernel uses float32 reductions, so scales differ by
|
||||
// round-off at the 1e-6 level. The tolerance accommodates
|
||||
// that plus quantizer boundary ambiguity (when a rotated
|
||||
// value sits exactly on a Lloyd-Max boundary, float32 vs
|
||||
// float64 accumulation can pick adjacent codebook slots).
|
||||
const tol float32 = 5e-2 // loose; tight enough to catch algo bugs
|
||||
var maxErr float32
|
||||
var mismatches int
|
||||
for c := range nCells {
|
||||
for h := range tc.numKVHeads {
|
||||
slab := make([]float32, tc.headDim)
|
||||
for d := range tc.headDim {
|
||||
slab[d] = kData[(c*tc.numKVHeads+h)*tc.headDim+d]
|
||||
}
|
||||
|
||||
cpuEnc, err := turboquant.EncodeKeyPerHeadOutlier(slab, tc.preset)
|
||||
if err != nil {
|
||||
t.Fatalf("cpu encode: %v", err)
|
||||
}
|
||||
cpuOut := turboquant.DequantKeyPerHeadOutlier(cpuEnc, tc.preset, tc.headDim)
|
||||
|
||||
for d := range tc.headDim {
|
||||
gpuVal := gpuOut[(c*tc.numKVHeads+h)*tc.headDim+d]
|
||||
cpuVal := cpuOut[d]
|
||||
diff := gpuVal - cpuVal
|
||||
if diff < 0 {
|
||||
diff = -diff
|
||||
}
|
||||
if diff > maxErr {
|
||||
maxErr = diff
|
||||
}
|
||||
if diff > tol {
|
||||
mismatches++
|
||||
if mismatches <= 5 {
|
||||
t.Logf("mismatch c=%d h=%d d=%d gpu=%f cpu=%f diff=%f",
|
||||
c, h, d, gpuVal, cpuVal, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
t.Logf("%s: max elementwise err = %f (%d mismatches > %.3f)", tc.name, maxErr, mismatches, tol)
|
||||
if mismatches > 0 {
|
||||
t.Fatalf("%s: %d elements differ beyond tolerance %.3f", tc.name, mismatches, tol)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testGaussian is a local Box-Muller generator that doesn't depend on
|
||||
// the unexported helpers in the turboquant package. Uses a tiny
|
||||
// splitmix64 inlined to avoid adding a test dep.
|
||||
func testGaussian(state *uint64) float64 {
|
||||
u1 := testUniform(state)
|
||||
u2 := testUniform(state)
|
||||
// Guard against log(0).
|
||||
if u1 < 1e-12 {
|
||||
u1 = 1e-12
|
||||
}
|
||||
return math.Sqrt(-2*math.Log(u1)) * math.Cos(2*math.Pi*u2)
|
||||
}
|
||||
|
||||
func testUniform(state *uint64) float64 {
|
||||
*state += 0x9e3779b97f4a7c15
|
||||
z := *state
|
||||
z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9
|
||||
z = (z ^ (z >> 27)) * 0x94d049bb133111eb
|
||||
z = z ^ (z >> 31)
|
||||
return float64(z>>11) / float64(1<<53)
|
||||
}
|
||||
|
||||
// Compile-time check that ml.Tensor is what we expect.
|
||||
var _ ml.Tensor = (*Tensor)(nil)
|
||||
574
ml/backend/ggml/turboquant_compressed.go
Normal file
574
ml/backend/ggml/turboquant_compressed.go
Normal file
|
|
@ -0,0 +1,574 @@
|
|||
package ggml
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/turboquant"
|
||||
)
|
||||
|
||||
// ggmlTQCompressedK implements ml.TQCompressedKManager using ggml tensors and
|
||||
// GGML_OP_TQ_ENCODE / GGML_OP_TQ_DEQUANT graph ops. All buffers live in GPU
|
||||
// memory; no CPU round-trips occur during the forward pass.
|
||||
type ggmlTQCompressedK struct {
|
||||
backend *Backend
|
||||
headDim int
|
||||
numKVHeads int
|
||||
bits int
|
||||
|
||||
// Outlier-split config (post-rotation top-K channel split). When
|
||||
// outlierCount > 0, EnsureLayer allocates additional tensors for an
|
||||
// outlier sub-block encoded at outlierBits, and the encode/dequant
|
||||
// kernels follow the outlier-aware path. When 0, uses pure uniform
|
||||
// per-channel Lloyd-Max at `bits`.
|
||||
outlierBits int
|
||||
outlierCount int
|
||||
|
||||
mu sync.Mutex
|
||||
|
||||
// Per-layer ggml tensors, allocated lazily via EnsureLayer.
|
||||
layerCtxs map[int]ml.Context
|
||||
packedTensors map[int]*Tensor // regular sub-block: [regularPackedBytes*numKVHeads, capacity] i8
|
||||
scalesTensors map[int]*Tensor // regular scales: [numKVHeads, capacity] f32
|
||||
|
||||
// Outlier sub-block per-layer tensors (populated only when outlierCount > 0).
|
||||
outlierPackedTensors map[int]*Tensor // [outlierPackedBytes*numKVHeads, capacity] i8
|
||||
outlierScalesTensors map[int]*Tensor // [numKVHeads, capacity] f32
|
||||
outlierIndicesTensors map[int]*Tensor // [outlierCount*numKVHeads, capacity] i8 (channel idx)
|
||||
|
||||
// Rotation matrix R^T, shared across layers: [headDim, headDim] f32.
|
||||
rotCtx ml.Context
|
||||
rotTensor *Tensor // stores R^T row-major (used for K encode and Q rotate)
|
||||
|
||||
// Rotation matrix R (transpose of R^T), for undoing V rotation in SDPA.
|
||||
// mul_mat(rotInverseTensor, x) = R @ x (recovers original from R^T @ x).
|
||||
rotInverseTensor *Tensor // stores R row-major
|
||||
|
||||
// Codebook and boundaries tensors, shared across layers.
|
||||
sharedCtx ml.Context
|
||||
codebookTensor *Tensor // regular: [1<<bits] f32
|
||||
boundariesTensor *Tensor // regular: [(1<<bits)-1] f32
|
||||
|
||||
// Outlier codebook and boundaries (populated only when outlierCount > 0).
|
||||
outlierCodebookTensor *Tensor // [1<<outlierBits] f32
|
||||
outlierBoundariesTensor *Tensor // [(1<<outlierBits)-1] f32
|
||||
|
||||
vBits int
|
||||
|
||||
// Per-layer V tensors, allocated lazily via EnsureVLayer.
|
||||
vLayerCtxs map[int]ml.Context
|
||||
vPackedTensors map[int]*Tensor // [packedBytes*numKVHeads, capacity] i8
|
||||
vScalesTensors map[int]*Tensor // [numKVHeads, capacity] f32
|
||||
|
||||
// V codebook and boundaries (same bit width as K for tq2/tq3).
|
||||
vCodebookTensor *Tensor // [1<<vBits] f32
|
||||
vBoundariesTensor *Tensor // [(1<<vBits)-1] f32
|
||||
|
||||
// preferFusedAttention is true on Metal. The DequantKV → stock FA path
|
||||
// writes a full f16 intermediate buffer before attention, doubling KV
|
||||
// bandwidth vs reading packed data directly. On Metal at long context the
|
||||
// fused kernel (kernel_tq_fattn_vec_packed) is dramatically faster because
|
||||
// it reads packed K+V once and never materialises the f16 intermediate.
|
||||
// On CUDA, DequantKV + stock FA is faster because cuDNN/cuBLAS flash
|
||||
// attention is highly tuned and the intermediate buffer stays in L2.
|
||||
preferFusedAttention bool
|
||||
}
|
||||
|
||||
// hasOutliers reports whether outlier-split is active for this manager.
|
||||
func (m *ggmlTQCompressedK) hasOutliers() bool {
|
||||
return m.outlierCount > 0 && m.outlierBits > 0 && m.outlierCount < m.headDim
|
||||
}
|
||||
|
||||
// PreferFusedAttention reports whether the fused flash-attention path
|
||||
// (packed K+V decoded inline) should be tried before DequantKV + stock FA.
|
||||
// True on Metal: the DequantKV path writes a full f16 intermediate buffer that
|
||||
// doubles KV bandwidth at long context. False on CUDA/ROCm where DequantKV +
|
||||
// stock FA is faster due to large L2 caches and highly-tuned flash attention.
|
||||
func (m *ggmlTQCompressedK) PreferFusedAttention() bool {
|
||||
return m.preferFusedAttention
|
||||
}
|
||||
|
||||
// regularChannelCount is the number of non-outlier channels per head.
|
||||
func (m *ggmlTQCompressedK) regularChannelCount() int {
|
||||
if m.hasOutliers() {
|
||||
return m.headDim - m.outlierCount
|
||||
}
|
||||
return m.headDim
|
||||
}
|
||||
|
||||
// regularPackedBytes is the padded per-head byte count for the regular
|
||||
// sub-block. The encode kernel uses atomicOr on 4-byte words to pack bits;
|
||||
// for that to stay aligned, each head's region must start on a 4-byte
|
||||
// boundary. Round the raw bit-count up to the next multiple of 4 so the
|
||||
// per-head stride is naturally aligned. The padding bytes are never read
|
||||
// during decode, and are zeroed by the encode kernel's init loop.
|
||||
func (m *ggmlTQCompressedK) regularPackedBytes() int {
|
||||
raw := (m.regularChannelCount()*m.bits + 7) / 8
|
||||
return (raw + 3) &^ 3
|
||||
}
|
||||
|
||||
// outlierPackedBytes is the padded per-head byte count for the outlier
|
||||
// sub-block. Same 4-byte alignment as regularPackedBytes() for the same
|
||||
// reason: atomicOr-on-word in the encode kernel.
|
||||
func (m *ggmlTQCompressedK) outlierPackedBytes() int {
|
||||
if !m.hasOutliers() {
|
||||
return 0
|
||||
}
|
||||
raw := (m.outlierCount*m.outlierBits + 7) / 8
|
||||
return (raw + 3) &^ 3
|
||||
}
|
||||
|
||||
func (b *Backend) NewTQCompressedKManager(headDim, numKVHeads, bits int, rotationSeed uint64, vBits, outlierBits, outlierCount int) ml.TQCompressedKManager {
|
||||
// TurboQuant ops run on CUDA (NVIDIA Pascal+), ROCm/HIP (AMD RDNA1+,
|
||||
// gfx1010+), or Metal (Apple Silicon). The gate is wave32: the kernels
|
||||
// hard-code a 32-lane shuffle for codebook lookup. On wave64 AMD (Vega/
|
||||
// GCN/CDNA) the HIP shim's __shfl(…, 32) sub-partitions the 64-lane warp
|
||||
// and the upper 32 lanes return garbage — those are rejected. Metal SIMD
|
||||
// groups are always 32-wide on Apple Silicon, so Metal is unconditionally
|
||||
// admitted. Scan the scheduler buffer types, pick the first TQ-capable
|
||||
// GPU, and warn clearly if there's no suitable device.
|
||||
scan := b.scanTQDevices()
|
||||
if !scan.selectedOK {
|
||||
if len(scan.Skipped) > 0 {
|
||||
slog.Warn("turboquant: no TQ-capable GPU found; falling back to f16 KV cache. "+
|
||||
"TurboQuant requires NVIDIA Pascal (cc 6.0+), AMD RDNA1+ (gfx1010+, wave32), or Apple Silicon (Metal).",
|
||||
"skipped_gpus", scan.Skipped)
|
||||
} else {
|
||||
slog.Warn("turboquant: no GPU backend available, falling back to f16 KV cache")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if len(scan.Skipped) > 0 {
|
||||
slog.Warn("turboquant: skipping unsupported GPU(s); TQ tensors will be placed on the "+
|
||||
"first wave32 device (NVIDIA Pascal+, AMD RDNA1+, or Apple Silicon). To silence "+
|
||||
"this warning, hide the unsupported cards with CUDA_VISIBLE_DEVICES / HIP_VISIBLE_DEVICES.",
|
||||
"selected", scan.SelectedName+" (cc "+scan.SelectedCC+")",
|
||||
"skipped", scan.Skipped)
|
||||
}
|
||||
if len(scan.Accepted) > 1 {
|
||||
slog.Warn("turboquant: multi-GPU detected; TQ compressed buffers live on the "+
|
||||
"primary GPU only. Layers scheduled to other GPUs will incur per-step "+
|
||||
"cross-GPU transfers. On SWA models like gemma3/gemma4 this is per "+
|
||||
"TQ-wrapped global sub-cache — the SWA sub-cache stays on its native "+
|
||||
"GPU and is unaffected. To avoid: set num_gpu so the model fits on one "+
|
||||
"GPU, or use the tq2k/tq3k (K-only) presets which let V stay on its "+
|
||||
"native GPU.",
|
||||
"selected", scan.SelectedName+" (cc "+scan.SelectedCC+")",
|
||||
"tq_capable_gpus", scan.Accepted)
|
||||
}
|
||||
|
||||
// Codebook and boundaries (same for all layers). Use headDim for the
|
||||
// codebook dim parameter regardless of whether outlier split is active:
|
||||
// the CPU path does the same (scalarCodebook(dim=headDim, bits)), and
|
||||
// after per-sub-block RMS normalization the input distribution is
|
||||
// approximately unit-variance Gaussian either way. Using sub-block dims
|
||||
// here produced slightly different boundaries that caused observable
|
||||
// quality regressions on multi-head configurations.
|
||||
codebook := turboquant.ExportCodebook(headDim, bits)
|
||||
boundaries := turboquant.ExportBoundaries(headDim, bits)
|
||||
|
||||
// All shared tensors (codebook, rotation) must be GPU-resident: TQ ops are
|
||||
// CUDA-only. Using newTQContext() ensures GPU buffer type is used regardless
|
||||
// of which model layers are on CPU vs GPU.
|
||||
// Size shared context for up to 6 tensors (codebook+bounds ×2 regular, ×2
|
||||
// outlier, ×2 V); newTQContext takes a hint count.
|
||||
sharedCtx := b.newTQContext(8)
|
||||
codebookT := sharedCtx.FromFloats(codebook, len(codebook)).(*Tensor)
|
||||
boundariesT := sharedCtx.FromFloats(boundaries, len(boundaries)).(*Tensor)
|
||||
|
||||
// Outlier codebook/boundaries at a different bit width. Use headDim as
|
||||
// the codebook dim (same reason as regular codebook above).
|
||||
var outlierCodebookT, outlierBoundariesT *Tensor
|
||||
if outlierCount > 0 && outlierBits > 0 && outlierCount < headDim {
|
||||
oCodebook := turboquant.ExportCodebook(headDim, outlierBits)
|
||||
oBoundaries := turboquant.ExportBoundaries(headDim, outlierBits)
|
||||
outlierCodebookT = sharedCtx.FromFloats(oCodebook, len(oCodebook)).(*Tensor)
|
||||
outlierBoundariesT = sharedCtx.FromFloats(oBoundaries, len(oBoundaries)).(*Tensor)
|
||||
}
|
||||
|
||||
// V codebook and boundaries
|
||||
vCodebook := turboquant.ExportCodebook(headDim, vBits)
|
||||
vBoundaries := turboquant.ExportBoundaries(headDim, vBits)
|
||||
vCodebookT := sharedCtx.FromFloats(vCodebook, len(vCodebook)).(*Tensor)
|
||||
vBoundariesT := sharedCtx.FromFloats(vBoundaries, len(vBoundaries)).(*Tensor)
|
||||
|
||||
// Rotation matrix R^T: rotData[i*headDim+j] = R[j][i]. Built via
|
||||
// Householder QR on a random Gaussian matrix per TurboQuant paper (arXiv
|
||||
// 2504.19874) Algorithm 1.
|
||||
rot := turboquant.BuildRotation(headDim, rotationSeed)
|
||||
rotData := make([]float32, headDim*headDim)
|
||||
for i := range headDim {
|
||||
for j := range headDim {
|
||||
rotData[i*headDim+j] = rot.Matrix[j*headDim+i]
|
||||
}
|
||||
}
|
||||
rotInverseData := make([]float32, headDim*headDim)
|
||||
copy(rotInverseData, rot.Matrix)
|
||||
rotCtx := b.newTQContext(2)
|
||||
rotTensor := rotCtx.FromFloats(rotData, headDim, headDim).(*Tensor)
|
||||
rotInverseTensor := rotCtx.FromFloats(rotInverseData, headDim, headDim).(*Tensor)
|
||||
|
||||
m := &ggmlTQCompressedK{
|
||||
backend: b,
|
||||
headDim: headDim,
|
||||
numKVHeads: numKVHeads,
|
||||
bits: bits,
|
||||
outlierBits: outlierBits,
|
||||
outlierCount: outlierCount,
|
||||
layerCtxs: make(map[int]ml.Context),
|
||||
packedTensors: make(map[int]*Tensor),
|
||||
scalesTensors: make(map[int]*Tensor),
|
||||
outlierPackedTensors: make(map[int]*Tensor),
|
||||
outlierScalesTensors: make(map[int]*Tensor),
|
||||
outlierIndicesTensors: make(map[int]*Tensor),
|
||||
rotCtx: rotCtx,
|
||||
rotTensor: rotTensor,
|
||||
rotInverseTensor: rotInverseTensor,
|
||||
sharedCtx: sharedCtx,
|
||||
codebookTensor: codebookT,
|
||||
boundariesTensor: boundariesT,
|
||||
outlierCodebookTensor: outlierCodebookT,
|
||||
outlierBoundariesTensor: outlierBoundariesT,
|
||||
vBits: vBits,
|
||||
vLayerCtxs: make(map[int]ml.Context),
|
||||
vPackedTensors: make(map[int]*Tensor),
|
||||
vScalesTensors: make(map[int]*Tensor),
|
||||
vCodebookTensor: vCodebookT,
|
||||
vBoundariesTensor: vBoundariesT,
|
||||
preferFusedAttention: scan.SelectedLibrary == "Metal",
|
||||
}
|
||||
if m.hasOutliers() {
|
||||
slog.Info("turboquant: outlier split enabled",
|
||||
"outlier_bits", outlierBits, "outlier_count", outlierCount,
|
||||
"regular_bits", bits, "regular_channels", m.regularChannelCount(),
|
||||
"effective_bits", float32(outlierCount*outlierBits+m.regularChannelCount()*bits)/float32(headDim))
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// EnsureLayer allocates per-layer packed and scales tensors on first use.
|
||||
// When outlier split is active, also allocates the outlier sub-block tensors.
|
||||
func (m *ggmlTQCompressedK) EnsureLayer(layer, capacity int) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, ok := m.packedTensors[layer]; ok {
|
||||
return
|
||||
}
|
||||
packedBytes := m.regularPackedBytes()
|
||||
|
||||
// Size the per-layer context hint by how many tensors we're allocating.
|
||||
// 2 (regular) + 3 (outlier) = 5 tensors when outlier split is on.
|
||||
ctxHint := 2
|
||||
if m.hasOutliers() {
|
||||
ctxHint = 5
|
||||
}
|
||||
// TQ tensors must always be GPU-resident; use newTQContext, not Layer(layer),
|
||||
// which would allocate CPU memory for layers assigned to CPU.
|
||||
ctx := m.backend.newTQContext(ctxHint)
|
||||
// Opt this layer's TQ persistent buffers into the scheduler's per-layer
|
||||
// Cache accounting so the (packed K, scales, optional outlier sub-block)
|
||||
// tensors below flow into btDeviceMemory.Cache[layer] via newTensor, not
|
||||
// the anonymous Graph bucket. Without this, scanTQDevices' scheduler can't
|
||||
// see TQ's real KV footprint and may mis-plan context-fit decisions.
|
||||
ctx.layer = layer
|
||||
// packed: interleaved as (cell*numKVHeads+head)*packedBytes — matches encode kernel layout.
|
||||
packed := ctx.Zeros(ml.DTypeI8, packedBytes*m.numKVHeads, capacity).(*Tensor)
|
||||
// scales: scales[cell*numKVHeads+head] — cell-major.
|
||||
scales := ctx.Zeros(ml.DTypeF32, m.numKVHeads, capacity).(*Tensor)
|
||||
|
||||
m.layerCtxs[layer] = ctx
|
||||
m.packedTensors[layer] = packed
|
||||
m.scalesTensors[layer] = scales
|
||||
|
||||
if m.hasOutliers() {
|
||||
oPackedBytes := m.outlierPackedBytes()
|
||||
m.outlierPackedTensors[layer] = ctx.Zeros(ml.DTypeI8, oPackedBytes*m.numKVHeads, capacity).(*Tensor)
|
||||
m.outlierScalesTensors[layer] = ctx.Zeros(ml.DTypeF32, m.numKVHeads, capacity).(*Tensor)
|
||||
m.outlierIndicesTensors[layer] = ctx.Zeros(ml.DTypeI8, m.outlierCount*m.numKVHeads, capacity).(*Tensor)
|
||||
}
|
||||
}
|
||||
|
||||
// EncodeK creates a GGML_OP_TQ_ENCODE graph node.
|
||||
// EnsureLayer must have been called for this layer before EncodeK;
|
||||
// the forward pass is single-threaded per cache so the map reads below
|
||||
// race only with concurrent EnsureLayer calls, which the contract forbids.
|
||||
// firstCell is the index of the first cache slot being written
|
||||
// (cells are sequential: firstCell+0, firstCell+1, ...).
|
||||
// Returns a view of the packed buffer (use as encodeResult in DequantK).
|
||||
func (m *ggmlTQCompressedK) EncodeK(ctx ml.Context, layer int, key ml.Tensor, firstCell int) ml.Tensor {
|
||||
packed := m.packedTensors[layer]
|
||||
if packed == nil {
|
||||
return nil
|
||||
}
|
||||
scales := m.scalesTensors[layer]
|
||||
if m.hasOutliers() {
|
||||
if oPacked := m.outlierPackedTensors[layer]; oPacked != nil {
|
||||
return packed.TQEncodeOutlier(ctx, scales, key, m.rotTensor, firstCell, m.boundariesTensor, m.bits,
|
||||
oPacked, m.outlierScalesTensors[layer], m.outlierIndicesTensors[layer], m.outlierBoundariesTensor,
|
||||
m.outlierBits, m.outlierCount)
|
||||
}
|
||||
}
|
||||
return packed.TQEncode(ctx, scales, key, m.rotTensor, firstCell, m.boundariesTensor, m.bits)
|
||||
}
|
||||
|
||||
// DequantK creates a GGML_OP_TQ_DEQUANT graph node. encodeResult is the
|
||||
// view returned by EncodeK (establishes encode→dequant ordering). See
|
||||
// EncodeK for the single-threaded access contract.
|
||||
func (m *ggmlTQCompressedK) DequantK(ctx ml.Context, layer int, encodeResult ml.Tensor, firstCell, nCells int) ml.Tensor {
|
||||
scales := m.scalesTensors[layer]
|
||||
if scales == nil || encodeResult == nil || nCells <= 0 {
|
||||
return nil
|
||||
}
|
||||
if m.hasOutliers() {
|
||||
if oPacked := m.outlierPackedTensors[layer]; oPacked != nil {
|
||||
return encodeResult.(*Tensor).TQDequantOutlier(ctx, scales, m.codebookTensor,
|
||||
m.headDim, m.numKVHeads, nCells, firstCell, m.bits,
|
||||
oPacked, m.outlierScalesTensors[layer], m.outlierIndicesTensors[layer], m.outlierCodebookTensor,
|
||||
m.outlierBits, m.outlierCount)
|
||||
}
|
||||
}
|
||||
return encodeResult.(*Tensor).TQDequant(ctx, scales, m.codebookTensor,
|
||||
m.headDim, m.numKVHeads, nCells, firstCell, m.bits)
|
||||
}
|
||||
|
||||
// fusedKernelSupports reports whether the fused TQ flash-attention kernel
|
||||
// should be used.
|
||||
//
|
||||
// The fused path is the default for supported configurations (headDim=128,
|
||||
// bits=2 or 3). It decodes packed K and V bits inline during flash attention
|
||||
// as a fallback for configurations where DequantKV is unsupported. The
|
||||
// inline-decode path is slower than DequantKV + stock FA on all measured
|
||||
// hardware — DequantKV is always preferred when available.
|
||||
func (m *ggmlTQCompressedK) fusedKernelSupports() bool {
|
||||
// D=128 on all backends; D=256 only on Metal (kernel_tq_fattn_vec_*{,_d256}).
|
||||
// CUDA still has only the D=128 kernel, so gemma3 (D=256) stays off the
|
||||
// fused path on CUDA.
|
||||
switch m.headDim {
|
||||
case 128:
|
||||
case 256:
|
||||
if !m.preferFusedAttention {
|
||||
return false
|
||||
}
|
||||
default:
|
||||
return false
|
||||
}
|
||||
if m.bits != 2 && m.bits != 3 {
|
||||
return false
|
||||
}
|
||||
// Outlier split changes the packed layout: the fused inline-decode FA
|
||||
// kernel reads the packed buffer directly and doesn't know about outlier
|
||||
// sub-blocks. Route to path 5 (separate dequant + stock FA) when outliers
|
||||
// are active. Extending the fused kernel to handle outliers is deliberately
|
||||
// NOT done — the fused inline-decode path is already documented as 17.6x
|
||||
// slower than separate dequant + stock FA (feedback_cuda_kernel_optimization.md),
|
||||
// so adding more ALU work (outlier scan / popcount / dual codebook shuffles)
|
||||
// to that inner loop moves it further from the correct architecture.
|
||||
if m.hasOutliers() {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// GetAsTQTensor wraps the packed K buffer for the given layer as a tqTensor
|
||||
// so that ScaledDotProductAttention can dispatch to the fused kernel.
|
||||
// Returns (nil, false) when the fused path is not supported.
|
||||
func (m *ggmlTQCompressedK) GetAsTQTensor(ctx ml.Context, layer int, encodeResult ml.Tensor, firstCell, nCells int) (ml.Tensor, bool) {
|
||||
if !m.fusedKernelSupports() {
|
||||
return nil, false
|
||||
}
|
||||
scales := m.scalesTensors[layer]
|
||||
if scales == nil || encodeResult == nil || nCells <= 0 {
|
||||
return nil, false
|
||||
}
|
||||
return &tqTensor{
|
||||
Tensor: encodeResult.(*Tensor),
|
||||
scales: scales,
|
||||
codebook: m.codebookTensor,
|
||||
bits: m.bits,
|
||||
headDim: m.headDim,
|
||||
nKVHeads: m.numKVHeads,
|
||||
nCells: nCells,
|
||||
firstCell: firstCell,
|
||||
}, true
|
||||
}
|
||||
|
||||
// GetAsTQTensorKV wraps both packed K and packed V buffers as a tqTensor for
|
||||
// the fully fused K+V TQ flash-attention path. Returns (nil, false) when
|
||||
// fused is not supported or V compression is not yet active for this layer.
|
||||
func (m *ggmlTQCompressedK) GetAsTQTensorKV(ctx ml.Context, layer int, kEncodeResult, vEncodeResult ml.Tensor, firstCell, nCells int) (ml.Tensor, bool) {
|
||||
if !m.fusedKernelSupports() {
|
||||
return nil, false
|
||||
}
|
||||
kScales := m.scalesTensors[layer]
|
||||
vScales := m.vScalesTensors[layer]
|
||||
if kScales == nil || kEncodeResult == nil || nCells <= 0 {
|
||||
return nil, false
|
||||
}
|
||||
if vScales == nil || vEncodeResult == nil {
|
||||
return nil, false
|
||||
}
|
||||
return &tqTensor{
|
||||
Tensor: kEncodeResult.(*Tensor),
|
||||
scales: kScales,
|
||||
codebook: m.codebookTensor,
|
||||
bits: m.bits,
|
||||
headDim: m.headDim,
|
||||
nKVHeads: m.numKVHeads,
|
||||
nCells: nCells,
|
||||
firstCell: firstCell,
|
||||
vPacked: vEncodeResult.(*Tensor),
|
||||
vScales: vScales,
|
||||
vCodebook: m.vCodebookTensor,
|
||||
vBits: m.vBits,
|
||||
}, true
|
||||
}
|
||||
|
||||
func (m *ggmlTQCompressedK) RotationMatrix(_ ml.Context, _ int) ml.Tensor {
|
||||
return m.rotTensor
|
||||
}
|
||||
|
||||
// RotationMatrixR returns R (not R^T) for use as the V rotation undo matrix.
|
||||
// mul_mat(R, R^T @ v) = v (recovers original from rotated V).
|
||||
func (m *ggmlTQCompressedK) RotationMatrixR() ml.Tensor {
|
||||
return m.rotInverseTensor
|
||||
}
|
||||
|
||||
// EnsureVLayer allocates per-layer V packed and scales tensors on first use.
|
||||
func (m *ggmlTQCompressedK) EnsureVLayer(layer, capacity int) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, ok := m.vPackedTensors[layer]; ok {
|
||||
return
|
||||
}
|
||||
// 4-byte alignment — matches regularPackedBytes() so the scheduler and the
|
||||
// Go-side allocator agree on padded bytes per head. The encode/dequant
|
||||
// kernels read the raw bits; padding is never touched.
|
||||
raw := (m.headDim*m.vBits + 7) / 8
|
||||
packedBytes := (raw + 3) &^ 3
|
||||
|
||||
ctx := m.backend.newTQContext(2)
|
||||
// Same per-layer Cache accounting as EnsureLayer — see that comment for why.
|
||||
ctx.layer = layer
|
||||
packed := ctx.Zeros(ml.DTypeI8, packedBytes*m.numKVHeads, capacity).(*Tensor)
|
||||
scales := ctx.Zeros(ml.DTypeF32, m.numKVHeads, capacity).(*Tensor)
|
||||
|
||||
m.vLayerCtxs[layer] = ctx
|
||||
m.vPackedTensors[layer] = packed
|
||||
m.vScalesTensors[layer] = scales
|
||||
}
|
||||
|
||||
// EncodeV creates a GGML_OP_TQ_ENCODE_V graph node.
|
||||
// EnsureVLayer must have been called for this layer before EncodeV.
|
||||
func (m *ggmlTQCompressedK) EncodeV(ctx ml.Context, layer int, value ml.Tensor, firstCell int) ml.Tensor {
|
||||
packed := m.vPackedTensors[layer]
|
||||
if packed == nil {
|
||||
return nil
|
||||
}
|
||||
// Pass the K rotation matrix so V outlier energy spreads evenly before
|
||||
// quantization. SDPA's post-attention R @ output step undoes the rotation.
|
||||
return packed.TQEncodeV(ctx, m.vScalesTensors[layer], value, m.rotTensor, firstCell, m.vBoundariesTensor, m.vBits)
|
||||
}
|
||||
|
||||
// EncodeKV creates a single GGML_OP_TQ_ENCODE_KV graph node encoding both
|
||||
// K and V, halving scheduler overhead vs separate EncodeK + EncodeV.
|
||||
// Returns (kEncodeResult, vEncodeResult) — both reference the same op for
|
||||
// graph dependency tracking.
|
||||
//
|
||||
// When outlier-split is active, the combined encode kernel is not used
|
||||
// because it only understands the uniform packed layout. Falls back to
|
||||
// separate EncodeK (outlier-aware) + EncodeV (uniform) calls.
|
||||
func (m *ggmlTQCompressedK) EncodeKV(ctx ml.Context, layer int, key, value ml.Tensor, firstCell int) (ml.Tensor, ml.Tensor) {
|
||||
if m.hasOutliers() {
|
||||
return m.EncodeK(ctx, layer, key, firstCell), m.EncodeV(ctx, layer, value, firstCell)
|
||||
}
|
||||
kPacked := m.packedTensors[layer]
|
||||
vPacked := m.vPackedTensors[layer]
|
||||
if kPacked == nil || vPacked == nil {
|
||||
return nil, nil
|
||||
}
|
||||
kResult := kPacked.TQEncodeKV(ctx,
|
||||
m.scalesTensors[layer], key, m.rotTensor, m.boundariesTensor,
|
||||
vPacked, m.vScalesTensors[layer], value, m.vBoundariesTensor,
|
||||
firstCell, m.bits, m.vBits)
|
||||
|
||||
// kResult is the EncodeKV op output (K packed view); the scheduler uses it
|
||||
// to order DequantKV after EncodeKV. V packed buffer was written as a side
|
||||
// effect by the combined kernel. Return the V packed tensor directly —
|
||||
// DequantKV reads its data pointer (which now contains the encoded V).
|
||||
// Graph ordering is still correct: DequantKV depends on kResult (src[0]),
|
||||
// and both kernels run on the same CUDA stream.
|
||||
return kResult, vPacked
|
||||
}
|
||||
|
||||
// DequantV creates a GGML_OP_TQ_DEQUANT graph node for V.
|
||||
// encodeResult is the view returned by EncodeV (establishes encode→dequant ordering).
|
||||
func (m *ggmlTQCompressedK) DequantV(ctx ml.Context, layer int, encodeResult ml.Tensor, firstCell, nCells int) ml.Tensor {
|
||||
scales := m.vScalesTensors[layer]
|
||||
if scales == nil || encodeResult == nil || nCells <= 0 {
|
||||
return nil
|
||||
}
|
||||
return encodeResult.(*Tensor).TQDequant(ctx, scales, m.vCodebookTensor,
|
||||
m.headDim, m.numKVHeads, nCells, firstCell, m.vBits)
|
||||
}
|
||||
|
||||
// DequantKV creates a single GGML_OP_TQ_DEQUANT_KV graph node that dequants
|
||||
// both K and V in one op, halving scheduler overhead vs separate DequantK+DequantV.
|
||||
// Returns (kTensor, vTensor) as views into the combined output.
|
||||
//
|
||||
// When outlier-split is active, the combined kernel cannot be used because
|
||||
// its K reader assumes the uniform packed layout. Returns (nil, nil) to
|
||||
// force Get() to fall through to the separate DequantK + DequantV path.
|
||||
func (m *ggmlTQCompressedK) DequantKV(ctx ml.Context, layer int, kEncodeResult, vEncodeResult ml.Tensor, firstCell, nCells int) (ml.Tensor, ml.Tensor) {
|
||||
if m.hasOutliers() {
|
||||
return nil, nil
|
||||
}
|
||||
kScales := m.scalesTensors[layer]
|
||||
vScales := m.vScalesTensors[layer]
|
||||
if kScales == nil || kEncodeResult == nil || nCells <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if vScales == nil || vEncodeResult == nil {
|
||||
return nil, nil
|
||||
}
|
||||
combined := TQDequantKV(ctx, m.backend,
|
||||
kEncodeResult.(*Tensor), kScales, m.codebookTensor,
|
||||
vEncodeResult.(*Tensor), vScales, m.vCodebookTensor,
|
||||
m.rotInverseTensor, // R matrix for fused V rotation undo
|
||||
m.headDim, m.numKVHeads, nCells, firstCell, m.bits, m.vBits)
|
||||
|
||||
// Split the [headDim, numKVHeads, nCells, 2] output into K and V views.
|
||||
planeBytes := m.headDim * m.numKVHeads * nCells * 2 // f16 = 2 bytes
|
||||
kView := combined.View(ctx, 0, m.headDim, combined.Stride(1), m.numKVHeads, combined.Stride(2), nCells)
|
||||
vView := combined.View(ctx, planeBytes, m.headDim, combined.Stride(1), m.numKVHeads, combined.Stride(2), nCells)
|
||||
|
||||
return kView, vView
|
||||
}
|
||||
|
||||
func (m *ggmlTQCompressedK) Close() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for _, ctx := range m.layerCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
for _, ctx := range m.vLayerCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
if m.rotCtx != nil {
|
||||
m.rotCtx.Close()
|
||||
}
|
||||
if m.sharedCtx != nil {
|
||||
m.sharedCtx.Close()
|
||||
}
|
||||
m.packedTensors = nil
|
||||
m.scalesTensors = nil
|
||||
m.layerCtxs = nil
|
||||
m.vPackedTensors = nil
|
||||
m.vScalesTensors = nil
|
||||
m.vLayerCtxs = nil
|
||||
m.rotCtx = nil
|
||||
m.sharedCtx = nil
|
||||
}
|
||||
96
ml/backend/ggml/turboquant_fattn.go
Normal file
96
ml/backend/ggml/turboquant_fattn.go
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
package ggml
|
||||
|
||||
// #include "ggml/include/ggml.h"
|
||||
import "C"
|
||||
|
||||
import ml "github.com/ollama/ollama/ml"
|
||||
|
||||
// tqTensor wraps a packed-K buffer with the metadata needed for the fused TQ
|
||||
// flash attention kernel. The Tensor field holds the encode result (a view of
|
||||
// the persistent packed-K buffer).
|
||||
//
|
||||
// When vPacked is non-nil, the K+V fused kernel is used: V is decoded inline
|
||||
// from vPacked, bypassing the separate TQ_DEQUANT op for V.
|
||||
type tqTensor struct {
|
||||
*Tensor // packed K view ([packedBytes*nKVHeads, capacity] i8; encode result)
|
||||
scales *Tensor // K scales [nKVHeads, capacity] f32
|
||||
codebook *Tensor // K codebook [1<<bits] f32
|
||||
bits int
|
||||
headDim int
|
||||
nKVHeads int
|
||||
nCells int
|
||||
firstCell int
|
||||
// V packed fields (nil = V is f16 from inner cache; non-nil = K+V fused)
|
||||
vPacked *Tensor // packed V view [v_packedBytes*nKVHeads, capacity] i8
|
||||
vScales *Tensor // V scales [nKVHeads, capacity] f32
|
||||
vCodebook *Tensor // V codebook [1<<vBits] f32
|
||||
vBits int
|
||||
}
|
||||
|
||||
// Permute propagates the tqTensor wrapper through the key permutation that
|
||||
// ScaledDotProductAttention applies before the flash-attention dispatch.
|
||||
// The packed-K layout is custom (not standard ggml strides), so we preserve
|
||||
// the wrapper metadata and let the CUDA kernel ignore the permuted strides.
|
||||
func (t *tqTensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
return &tqTensor{
|
||||
Tensor: t.Tensor.Permute(ctx, shape...).(*Tensor),
|
||||
scales: t.scales,
|
||||
codebook: t.codebook,
|
||||
bits: t.bits,
|
||||
headDim: t.headDim,
|
||||
nKVHeads: t.nKVHeads,
|
||||
nCells: t.nCells,
|
||||
firstCell: t.firstCell,
|
||||
vPacked: t.vPacked,
|
||||
vScales: t.vScales,
|
||||
vCodebook: t.vCodebook,
|
||||
vBits: t.vBits,
|
||||
}
|
||||
}
|
||||
|
||||
// TQFlashAttention creates a GGML_OP_TQ_FLASH_ATTN_EXT graph node.
|
||||
// query: permuted+rotated Q [D, nTokensQ, nHeadsQ, nSeq] f32
|
||||
// tqk: TQ packed-K wrapper (may also carry V packed fields for K+V fused)
|
||||
// value: permuted f16 V for K-only fused, OR packed i8 V for K+V fused
|
||||
func (b *Backend) tqFlashAttention(
|
||||
ctx ml.Context,
|
||||
query *Tensor,
|
||||
tqk *tqTensor,
|
||||
value *Tensor,
|
||||
mask ml.Tensor,
|
||||
scale float64,
|
||||
logitSoftcap float64,
|
||||
) ml.Tensor {
|
||||
var maskT *C.struct_ggml_tensor
|
||||
if mask != nil {
|
||||
maskT = mask.(*Tensor).t
|
||||
}
|
||||
|
||||
// K+V fused: pass V packed tensors to the C API.
|
||||
// K-only fused: pass NULL for v_scales (backward compat).
|
||||
var vScalesT, vCodebookT *C.struct_ggml_tensor
|
||||
vBits := C.int32_t(0)
|
||||
if tqk.vPacked != nil {
|
||||
vScalesT = tqk.vScales.t
|
||||
vCodebookT = tqk.vCodebook.t
|
||||
vBits = C.int32_t(tqk.vBits)
|
||||
}
|
||||
|
||||
t := C.ggml_tq_flash_attn_ext(
|
||||
ctx.(*Context).ctx,
|
||||
query.t,
|
||||
tqk.Tensor.t,
|
||||
value.t,
|
||||
maskT,
|
||||
tqk.scales.t,
|
||||
tqk.codebook.t,
|
||||
C.float(scale),
|
||||
C.float(logitSoftcap),
|
||||
C.int32_t(tqk.bits),
|
||||
C.int32_t(tqk.firstCell),
|
||||
vScalesT,
|
||||
vCodebookT,
|
||||
vBits,
|
||||
)
|
||||
return &Tensor{b: b, t: t}
|
||||
}
|
||||
|
|
@ -38,6 +38,7 @@ type Model interface {
|
|||
|
||||
Backend() ml.Backend
|
||||
Config() config
|
||||
SetCache(kvcache.Cache)
|
||||
}
|
||||
|
||||
// Validator is an optional interface that models can implement to perform
|
||||
|
|
@ -104,6 +105,12 @@ func (m *Base) Config() config {
|
|||
return m.config
|
||||
}
|
||||
|
||||
// SetCache replaces the model's cache. Used by TurboQuant to wrap the
|
||||
// Causal cache with compression.
|
||||
func (m *Base) SetCache(cache kvcache.Cache) {
|
||||
m.config.Cache = cache
|
||||
}
|
||||
|
||||
var models = make(map[string]func(fs.Config) (Model, error))
|
||||
|
||||
// Register registers a model constructor for the given architecture
|
||||
|
|
|
|||
|
|
@ -249,8 +249,8 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
|
|||
wc := cache.(*kvcache.WrapperCache)
|
||||
wc.SetLayerType(cacheType)
|
||||
|
||||
if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
|
||||
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
|
||||
if cc, ok := wc.UnderlyingCache().(kvcache.CausalConfigurable); ok {
|
||||
cc.SetCausal(ctx, kvcache.CausalOptions{Except: except})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -198,8 +198,8 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
|
|||
wc := cache.(*kvcache.WrapperCache)
|
||||
wc.SetLayerType(cacheType)
|
||||
|
||||
if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
|
||||
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
|
||||
if cc, ok := wc.UnderlyingCache().(kvcache.CausalConfigurable); ok {
|
||||
cc.SetCausal(ctx, kvcache.CausalOptions{Except: except})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -46,7 +46,33 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
|
|||
|
||||
cache := model.Config().Cache
|
||||
if cache != nil {
|
||||
cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), numSlots, int(numCtx), batchSize)
|
||||
dtype := kvCacheTypeFromStr(kvCacheType)
|
||||
if preset, ok := kvcache.PresetFromDType(dtype); ok {
|
||||
wrapped, active := kvcache.WrapWithTurboQuant(cache, preset)
|
||||
if active {
|
||||
cache = wrapped
|
||||
// Force f16 at Init. TurboQuantCache manages its preset
|
||||
// internally and always passes f16 to its inner *Causal; any
|
||||
// sibling sub-caches in a WrapperCache (e.g. the SWA side of
|
||||
// gemma3/gemma4) need f16 too since they can't allocate the
|
||||
// ggml-unknown TQ dtypes.
|
||||
dtype = ml.DTypeF16
|
||||
// For the top-level *Causal case, WrapWithTurboQuant returns
|
||||
// a new *TurboQuantCache and the model must be re-pointed at
|
||||
// it. For the *WrapperCache case, the same pointer is returned
|
||||
// (mutated in place) and no SetCache is needed.
|
||||
if tqc, ok := wrapped.(*kvcache.TurboQuantCache); ok {
|
||||
model.SetCache(tqc)
|
||||
}
|
||||
slog.Info("using turboquant kv cache", "preset", preset.Name)
|
||||
} else {
|
||||
// Non-wrappable cache (recurrent, SWA-only, etc.); fall back
|
||||
// to f16 so Init doesn't receive an unmapped dtype and panic.
|
||||
dtype = ml.DTypeF16
|
||||
slog.Warn("turboquant requested but cache is not wrappable, falling back to f16")
|
||||
}
|
||||
}
|
||||
cache.Init(model.Backend(), dtype, numSlots, int(numCtx), batchSize)
|
||||
}
|
||||
|
||||
return &InputCache{
|
||||
|
|
@ -64,6 +90,14 @@ func kvCacheTypeFromStr(s string) ml.DType {
|
|||
return ml.DTypeQ80
|
||||
case "q4_0":
|
||||
return ml.DTypeQ40
|
||||
case "tq2":
|
||||
return ml.DTypeTQ2
|
||||
case "tq3":
|
||||
return ml.DTypeTQ3
|
||||
case "tq3k":
|
||||
return ml.DTypeTQ3K
|
||||
case "tq2k":
|
||||
return ml.DTypeTQ2K
|
||||
default:
|
||||
return ml.DTypeF16
|
||||
}
|
||||
|
|
|
|||
178
turboquant/block.go
Normal file
178
turboquant/block.go
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
package turboquant
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
type Block struct {
|
||||
Version uint8
|
||||
PresetID uint8
|
||||
Role uint8
|
||||
Objective uint8
|
||||
OriginalDim uint16
|
||||
PaddedDim uint16
|
||||
BlockDim uint16
|
||||
RegularBits uint8
|
||||
RotationSeed uint64
|
||||
CodebookID uint16
|
||||
QJLRows uint16
|
||||
AuxLayoutID uint8
|
||||
ChannelIndices []uint16
|
||||
Scale float32
|
||||
RegularIndices []byte
|
||||
Residual ResidualSketch
|
||||
}
|
||||
|
||||
func (b Block) MarshalBinary() ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
fields := []any{
|
||||
b.Version,
|
||||
b.PresetID,
|
||||
b.Role,
|
||||
b.Objective,
|
||||
b.OriginalDim,
|
||||
b.PaddedDim,
|
||||
b.BlockDim,
|
||||
b.RegularBits,
|
||||
b.RotationSeed,
|
||||
b.CodebookID,
|
||||
b.QJLRows,
|
||||
b.AuxLayoutID,
|
||||
uint16(len(b.ChannelIndices)),
|
||||
b.Scale,
|
||||
uint32(len(b.RegularIndices)),
|
||||
b.Residual.Seed,
|
||||
b.Residual.Scale,
|
||||
b.Residual.SketchDim,
|
||||
uint32(len(b.Residual.Signs)),
|
||||
}
|
||||
for _, field := range fields {
|
||||
if err := binary.Write(&buf, binary.LittleEndian, field); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
for _, idx := range b.ChannelIndices {
|
||||
if err := binary.Write(&buf, binary.LittleEndian, idx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
buf.Write(b.RegularIndices)
|
||||
buf.Write(b.Residual.Signs)
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (b *Block) UnmarshalBinary(data []byte) error {
|
||||
r := bytes.NewReader(data)
|
||||
var channelCount uint16
|
||||
var regularLen, residualLen uint32
|
||||
fields := []any{
|
||||
&b.Version,
|
||||
&b.PresetID,
|
||||
&b.Role,
|
||||
&b.Objective,
|
||||
&b.OriginalDim,
|
||||
&b.PaddedDim,
|
||||
&b.BlockDim,
|
||||
&b.RegularBits,
|
||||
&b.RotationSeed,
|
||||
&b.CodebookID,
|
||||
&b.QJLRows,
|
||||
&b.AuxLayoutID,
|
||||
&channelCount,
|
||||
&b.Scale,
|
||||
®ularLen,
|
||||
&b.Residual.Seed,
|
||||
&b.Residual.Scale,
|
||||
&b.Residual.SketchDim,
|
||||
&residualLen,
|
||||
}
|
||||
for _, field := range fields {
|
||||
if err := binary.Read(r, binary.LittleEndian, field); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if b.Version != BlockVersion {
|
||||
return fmt.Errorf("unsupported block version %d", b.Version)
|
||||
}
|
||||
if b.Objective != uint8(objectiveMSE) && b.Objective != uint8(objectiveProduct) {
|
||||
return fmt.Errorf("unsupported block objective %d", b.Objective)
|
||||
}
|
||||
if b.OriginalDim == 0 || b.OriginalDim > b.PaddedDim || b.BlockDim != b.PaddedDim {
|
||||
return fmt.Errorf("invalid block dims: original=%d padded=%d block=%d", b.OriginalDim, b.PaddedDim, b.BlockDim)
|
||||
}
|
||||
|
||||
if channelCount > 0 {
|
||||
b.ChannelIndices = make([]uint16, channelCount)
|
||||
for i := range b.ChannelIndices {
|
||||
if err := binary.Read(r, binary.LittleEndian, &b.ChannelIndices[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
b.RegularIndices = make([]byte, regularLen)
|
||||
b.Residual.Signs = make([]byte, residualLen)
|
||||
for _, dst := range [][]byte{b.RegularIndices, b.Residual.Signs} {
|
||||
if _, err := io.ReadFull(r, dst); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if r.Len() != 0 {
|
||||
return fmt.Errorf("unexpected trailing bytes in turboquant block: %d", r.Len())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func packBits(values []uint8, bitsPerValue int) []byte {
|
||||
if bitsPerValue <= 0 || len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
totalBits := len(values) * bitsPerValue
|
||||
out := make([]byte, (totalBits+7)/8)
|
||||
mask := uint8((1 << bitsPerValue) - 1)
|
||||
bitPos := 0
|
||||
for _, value := range values {
|
||||
packed := value & mask
|
||||
bytePos := bitPos / 8
|
||||
shift := bitPos % 8
|
||||
out[bytePos] |= packed << shift
|
||||
if shift+bitsPerValue > 8 {
|
||||
out[bytePos+1] |= packed >> (8 - shift)
|
||||
}
|
||||
bitPos += bitsPerValue
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func unpackBits(data []byte, bitsPerValue, count int) []uint8 {
|
||||
out := make([]uint8, count)
|
||||
if bitsPerValue <= 0 {
|
||||
return out
|
||||
}
|
||||
mask := uint8((1 << bitsPerValue) - 1)
|
||||
bitPos := 0
|
||||
for i := range count {
|
||||
bytePos := bitPos / 8
|
||||
shift := bitPos % 8
|
||||
if bytePos >= len(data) {
|
||||
return out
|
||||
}
|
||||
value := data[bytePos] >> shift
|
||||
if shift+bitsPerValue > 8 && bytePos+1 < len(data) {
|
||||
value |= data[bytePos+1] << (8 - shift)
|
||||
}
|
||||
out[i] = value & mask
|
||||
bitPos += bitsPerValue
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func expectedPackedBytes(count, bitsPerValue int) int {
|
||||
if count <= 0 || bitsPerValue <= 0 {
|
||||
return 0
|
||||
}
|
||||
return ((count * bitsPerValue) + 7) / 8
|
||||
}
|
||||
185
turboquant/block_test.go
Normal file
185
turboquant/block_test.go
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
package turboquant
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestBlockMarshalRoundTrip(t *testing.T) {
|
||||
block := Block{
|
||||
Version: BlockVersion,
|
||||
PresetID: PresetTQ3.ID,
|
||||
Role: uint8(roleKey),
|
||||
Objective: uint8(objectiveProduct),
|
||||
OriginalDim: 8,
|
||||
PaddedDim: 8,
|
||||
BlockDim: 8,
|
||||
RegularBits: 3,
|
||||
RotationSeed: 77,
|
||||
CodebookID: 3,
|
||||
QJLRows: 4,
|
||||
AuxLayoutID: 1,
|
||||
Scale: 1,
|
||||
RegularIndices: []byte{1, 2, 3},
|
||||
Residual: ResidualSketch{
|
||||
Seed: 88,
|
||||
Scale: 0.5,
|
||||
SketchDim: 4,
|
||||
Signs: []byte{0x0f},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := block.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var decoded Block
|
||||
if err := decoded.UnmarshalBinary(data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if decoded.Version != block.Version ||
|
||||
decoded.PresetID != block.PresetID ||
|
||||
decoded.Role != block.Role ||
|
||||
decoded.Objective != block.Objective ||
|
||||
decoded.OriginalDim != block.OriginalDim ||
|
||||
decoded.PaddedDim != block.PaddedDim ||
|
||||
decoded.BlockDim != block.BlockDim ||
|
||||
decoded.RegularBits != block.RegularBits ||
|
||||
decoded.RotationSeed != block.RotationSeed ||
|
||||
decoded.CodebookID != block.CodebookID ||
|
||||
decoded.QJLRows != block.QJLRows ||
|
||||
decoded.AuxLayoutID != block.AuxLayoutID ||
|
||||
string(decoded.RegularIndices) != string(block.RegularIndices) ||
|
||||
decoded.Residual.Seed != block.Residual.Seed ||
|
||||
decoded.Residual.Scale != block.Residual.Scale ||
|
||||
decoded.Residual.SketchDim != block.Residual.SketchDim ||
|
||||
string(decoded.Residual.Signs) != string(block.Residual.Signs) {
|
||||
t.Fatalf("decoded block mismatch: %+v", decoded)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBlockUnmarshalRejectsBadVersion(t *testing.T) {
|
||||
block := Block{Version: BlockVersion, PresetID: PresetTQ2.ID, Role: uint8(roleValue), Objective: uint8(objectiveMSE), OriginalDim: 4, PaddedDim: 4, BlockDim: 4, RegularBits: 2}
|
||||
data, err := block.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data[0] = 99
|
||||
|
||||
var decoded Block
|
||||
if err := decoded.UnmarshalBinary(data); err == nil {
|
||||
t.Fatal("expected unsupported block version error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPackBitsRoundTripMixedWidths(t *testing.T) {
|
||||
values2 := []uint8{1, 3, 0, 2}
|
||||
roundTrip2 := unpackBits(packBits(values2, 2), 2, len(values2))
|
||||
for i := range values2 {
|
||||
if roundTrip2[i] != values2[i] {
|
||||
t.Fatalf("2-bit round trip mismatch at %d: got %d want %d", i, roundTrip2[i], values2[i])
|
||||
}
|
||||
}
|
||||
|
||||
values3 := []uint8{3, 7, 1, 5}
|
||||
roundTrip3 := unpackBits(packBits(values3, 3), 3, len(values3))
|
||||
for i := range values3 {
|
||||
if roundTrip3[i] != values3[i] {
|
||||
t.Fatalf("3-bit round trip mismatch at %d: got %d want %d", i, roundTrip3[i], values3[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodeOutlierSplitLayout verifies that vectors larger than OutlierCount
|
||||
// are encoded as two blocks (outlier + regular) with the expected bit widths and
|
||||
// ChannelIndices, and that vectors at or below OutlierCount stay as a single block.
|
||||
//
|
||||
// Uses an explicit outlier-enabled preset because PresetTQ3's shipped default
|
||||
// is OutlierCount=0 (outlier split is opt-in infrastructure; see PresetTQ3
|
||||
// comment in turboquant.go).
|
||||
func TestEncodeOutlierSplitLayout(t *testing.T) {
|
||||
preset := testOutlierPreset(PresetTQ3, 32)
|
||||
|
||||
// dim=70 > OutlierCount=32: two blocks expected.
|
||||
encoded, err := EncodeVector(pseudoRandomVector(70, 0x55), preset)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(encoded.Blocks) != 2 {
|
||||
t.Fatalf("block count = %d, want 2", len(encoded.Blocks))
|
||||
}
|
||||
|
||||
outlierBlock := encoded.Blocks[0]
|
||||
regularBlock := encoded.Blocks[1]
|
||||
|
||||
if int(outlierBlock.OriginalDim) != preset.OutlierCount {
|
||||
t.Errorf("outlier block dim = %d, want %d", outlierBlock.OriginalDim, preset.OutlierCount)
|
||||
}
|
||||
if outlierBlock.RegularBits != uint8(preset.OutlierBits) {
|
||||
t.Errorf("outlier bits = %d, want %d", outlierBlock.RegularBits, preset.OutlierBits)
|
||||
}
|
||||
if len(outlierBlock.ChannelIndices) != preset.OutlierCount {
|
||||
t.Errorf("outlier ChannelIndices len = %d, want %d", len(outlierBlock.ChannelIndices), preset.OutlierCount)
|
||||
}
|
||||
|
||||
wantRegularDim := 70 - preset.OutlierCount
|
||||
if int(regularBlock.OriginalDim) != wantRegularDim {
|
||||
t.Errorf("regular block dim = %d, want %d", regularBlock.OriginalDim, wantRegularDim)
|
||||
}
|
||||
if regularBlock.RegularBits != uint8(preset.ValueBits) {
|
||||
t.Errorf("regular bits = %d, want %d", regularBlock.RegularBits, preset.ValueBits)
|
||||
}
|
||||
if len(regularBlock.ChannelIndices) != wantRegularDim {
|
||||
t.Errorf("regular ChannelIndices len = %d, want %d", len(regularBlock.ChannelIndices), wantRegularDim)
|
||||
}
|
||||
|
||||
// ChannelIndices across both blocks must cover all 70 channels exactly once.
|
||||
seen := make([]int, 70)
|
||||
for _, idx := range outlierBlock.ChannelIndices {
|
||||
seen[idx]++
|
||||
}
|
||||
for _, idx := range regularBlock.ChannelIndices {
|
||||
seen[idx]++
|
||||
}
|
||||
for i, count := range seen {
|
||||
if count != 1 {
|
||||
t.Errorf("channel %d appears %d times across blocks", i, count)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestBlockMarshalWithChannelIndices verifies that ChannelIndices round-trips correctly.
|
||||
func TestBlockMarshalWithChannelIndices(t *testing.T) {
|
||||
block := Block{
|
||||
Version: BlockVersion,
|
||||
PresetID: PresetTQ2.ID,
|
||||
Role: uint8(roleKey),
|
||||
Objective: uint8(objectiveMSE),
|
||||
OriginalDim: 4,
|
||||
PaddedDim: 4,
|
||||
BlockDim: 4,
|
||||
RegularBits: 2,
|
||||
RotationSeed: 42,
|
||||
CodebookID: 2,
|
||||
QJLRows: 0,
|
||||
AuxLayoutID: 1,
|
||||
ChannelIndices: []uint16{0, 3, 7, 12},
|
||||
Scale: 0.5,
|
||||
RegularIndices: []byte{0b10110001},
|
||||
}
|
||||
data, err := block.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var got Block
|
||||
if err := got.UnmarshalBinary(data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(got.ChannelIndices) != len(block.ChannelIndices) {
|
||||
t.Fatalf("ChannelIndices len: got %d want %d", len(got.ChannelIndices), len(block.ChannelIndices))
|
||||
}
|
||||
for i := range block.ChannelIndices {
|
||||
if got.ChannelIndices[i] != block.ChannelIndices[i] {
|
||||
t.Errorf("ChannelIndices[%d]: got %d want %d", i, got.ChannelIndices[i], block.ChannelIndices[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
216
turboquant/codebook.go
Normal file
216
turboquant/codebook.go
Normal file
|
|
@ -0,0 +1,216 @@
|
|||
package turboquant
|
||||
|
||||
import (
|
||||
"math"
|
||||
"slices"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type codebookCacheKey struct {
|
||||
dim int
|
||||
bits int
|
||||
}
|
||||
|
||||
type scalarCodebookCacheValue struct {
|
||||
codebook []float32
|
||||
boundaries []float32
|
||||
}
|
||||
|
||||
var scalarCodebookCache sync.Map
|
||||
|
||||
// ExportCodebook returns the Lloyd-Max codebook centroids for the given
|
||||
// dim and bits. Used by the CUDA dequant kernel (loaded into GPU constant memory).
|
||||
func ExportCodebook(dim, bits int) []float32 {
|
||||
cb, _ := scalarCodebook(dim, bits)
|
||||
return cb
|
||||
}
|
||||
|
||||
// ExportBoundaries returns the Lloyd-Max decision boundaries for the given
|
||||
// dim and bits. Used by the CUDA encode kernel for binary-search quantization.
|
||||
// Boundaries are the midpoints between adjacent centroids; len = (1<<bits) - 1.
|
||||
func ExportBoundaries(dim, bits int) []float32 {
|
||||
_, boundaries := scalarCodebook(dim, bits)
|
||||
return boundaries
|
||||
}
|
||||
|
||||
func scalarCodebook(dim int, bits int) ([]float32, []float32) {
|
||||
key := codebookCacheKey{dim: dim, bits: bits}
|
||||
if cached, ok := scalarCodebookCache.Load(key); ok {
|
||||
value := cached.(scalarCodebookCacheValue)
|
||||
return append([]float32(nil), value.codebook...), append([]float32(nil), value.boundaries...)
|
||||
}
|
||||
|
||||
codebook := buildLloydMaxCodebook(dim, bits)
|
||||
value := scalarCodebookCacheValue{
|
||||
codebook: codebook,
|
||||
boundaries: codebookBoundaries(codebook),
|
||||
}
|
||||
actual, _ := scalarCodebookCache.LoadOrStore(key, value)
|
||||
cached := actual.(scalarCodebookCacheValue)
|
||||
return append([]float32(nil), cached.codebook...), append([]float32(nil), cached.boundaries...)
|
||||
}
|
||||
|
||||
func buildLloydMaxCodebook(dim int, bits int) []float32 {
|
||||
levels := 1 << bits
|
||||
if levels <= 1 {
|
||||
return []float32{0}
|
||||
}
|
||||
|
||||
samples := unitVectorCoordSamples(dim, bits, 65536)
|
||||
slices.Sort(samples)
|
||||
|
||||
centroids := make([]float64, levels)
|
||||
for level := range levels {
|
||||
begin := level * len(samples) / levels
|
||||
end := (level + 1) * len(samples) / levels
|
||||
if end <= begin {
|
||||
end = begin + 1
|
||||
}
|
||||
centroids[level] = meanFloat64(samples[begin:end])
|
||||
}
|
||||
|
||||
for range 48 {
|
||||
slices.Sort(centroids)
|
||||
bounds := make([]float64, levels-1)
|
||||
for i := range bounds {
|
||||
bounds[i] = (centroids[i] + centroids[i+1]) / 2
|
||||
}
|
||||
|
||||
sums := make([]float64, levels)
|
||||
counts := make([]int, levels)
|
||||
for _, sample := range samples {
|
||||
idx := quantizeScalarFloat64(sample, bounds)
|
||||
sums[idx] += sample
|
||||
counts[idx]++
|
||||
}
|
||||
|
||||
maxDelta := 0.0
|
||||
for i := range centroids {
|
||||
var next float64
|
||||
if counts[i] > 0 {
|
||||
next = sums[i] / float64(counts[i])
|
||||
} else if i == 0 {
|
||||
next = bounds[0] - 0.25
|
||||
} else if i == len(centroids)-1 {
|
||||
next = bounds[len(bounds)-1] + 0.25
|
||||
} else {
|
||||
next = (bounds[i-1] + bounds[i]) / 2
|
||||
}
|
||||
maxDelta = math.Max(maxDelta, math.Abs(next-centroids[i]))
|
||||
centroids[i] = next
|
||||
}
|
||||
if maxDelta < 1e-6 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
slices.Sort(centroids)
|
||||
codebook := make([]float32, len(centroids))
|
||||
for i := range centroids {
|
||||
codebook[i] = float32(centroids[i])
|
||||
}
|
||||
return codebook
|
||||
}
|
||||
|
||||
// unitVectorCoordSamples returns count samples from the exact marginal
|
||||
// distribution of a single coordinate of a uniformly random unit vector in R^d:
|
||||
//
|
||||
// z_1 / ‖z‖ · √d, z ~ N(0, I_d)
|
||||
//
|
||||
// For large d this converges to N(0,1); for smaller d the heavier tails of the
|
||||
// Beta((d-3)/2,(d-3)/2) coordinate distribution are preserved. Using this
|
||||
// distribution — rather than pure N(0,1) — produces the optimal Lloyd-Max
|
||||
// codebook for the actual coordinate distribution that arises after RMS
|
||||
// normalization and random rotation (Paper §3.1, Eq. 4, Lemma 1).
|
||||
func unitVectorCoordSamples(dim int, bits int, count int) []float64 {
|
||||
rng := splitmix64(uint64(bits+1)<<48 ^ uint64(dim+1)<<16 ^ 0x4d595df4d0f33173)
|
||||
out := make([]float64, count)
|
||||
if dim <= 1 {
|
||||
for i := range out {
|
||||
out[i] = gaussianFloat64(&rng)
|
||||
}
|
||||
return out
|
||||
}
|
||||
sqrtDim := math.Sqrt(float64(dim))
|
||||
for i := range out {
|
||||
z0 := gaussianFloat64(&rng)
|
||||
sumSq := z0 * z0
|
||||
for k := 1; k < dim; k++ {
|
||||
g := gaussianFloat64(&rng)
|
||||
sumSq += g * g
|
||||
}
|
||||
norm := math.Sqrt(sumSq)
|
||||
if norm < 1e-15 {
|
||||
out[i] = 0
|
||||
} else {
|
||||
out[i] = z0 / norm * sqrtDim
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func meanFloat64(values []float64) float64 {
|
||||
if len(values) == 0 {
|
||||
return 0
|
||||
}
|
||||
total := 0.0
|
||||
for _, value := range values {
|
||||
total += value
|
||||
}
|
||||
return total / float64(len(values))
|
||||
}
|
||||
|
||||
func codebookBoundaries(codebook []float32) []float32 {
|
||||
if len(codebook) < 2 {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make([]float32, len(codebook)-1)
|
||||
for i := range out {
|
||||
out[i] = (codebook[i] + codebook[i+1]) / 2
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func quantizeScalarByBoundary(v float32, codebook []float32, boundaries []float32) uint8 {
|
||||
if len(codebook) == 0 {
|
||||
return 0
|
||||
}
|
||||
if len(boundaries) != len(codebook)-1 {
|
||||
return quantizeScalarNearest(v, codebook)
|
||||
}
|
||||
|
||||
idx := 0
|
||||
for idx < len(boundaries) && v >= boundaries[idx] {
|
||||
idx++
|
||||
}
|
||||
return uint8(idx)
|
||||
}
|
||||
|
||||
func quantizeScalarFloat64(v float64, boundaries []float64) int {
|
||||
idx := 0
|
||||
for idx < len(boundaries) && v >= boundaries[idx] {
|
||||
idx++
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
func quantizeScalarNearest(v float32, codebook []float32) uint8 {
|
||||
best := 0
|
||||
bestDist := float32(math.MaxFloat32)
|
||||
for i, centroid := range codebook {
|
||||
d := abs32(v - centroid)
|
||||
if d < bestDist {
|
||||
bestDist = d
|
||||
best = i
|
||||
}
|
||||
}
|
||||
return uint8(best)
|
||||
}
|
||||
|
||||
func dequantizeScalar(idx uint8, codebook []float32) float32 {
|
||||
if int(idx) >= len(codebook) {
|
||||
return 0
|
||||
}
|
||||
return codebook[idx]
|
||||
}
|
||||
198
turboquant/codebook_test.go
Normal file
198
turboquant/codebook_test.go
Normal file
|
|
@ -0,0 +1,198 @@
|
|||
package turboquant
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestUnitVectorCoordSamplesVariance checks that unitVectorCoordSamples
|
||||
// produces samples with variance ≈ 1.0 for the normalized coordinate
|
||||
// distribution (z_1/‖z‖ · √d). The variance of this distribution is exactly
|
||||
// d · Var[z_1/‖z‖] = d · (1/d) = 1, independent of d. For small d the
|
||||
// distribution has heavier tails than N(0,1) but still variance=1.
|
||||
func TestUnitVectorCoordSamplesVariance(t *testing.T) {
|
||||
for _, dim := range []int{4, 8, 16, 32, 64, 128} {
|
||||
samples := unitVectorCoordSamples(dim, 2, 32768)
|
||||
var sum, sumSq float64
|
||||
for _, s := range samples {
|
||||
sum += s
|
||||
sumSq += s * s
|
||||
}
|
||||
n := float64(len(samples))
|
||||
mean := sum / n
|
||||
variance := sumSq/n - mean*mean
|
||||
// Mean should be ~0, variance ~1 for all d.
|
||||
if math.Abs(mean) > 0.05 {
|
||||
t.Errorf("dim=%d: mean=%.4f, want ~0", dim, mean)
|
||||
}
|
||||
if math.Abs(variance-1.0) > 0.05 {
|
||||
t.Errorf("dim=%d: variance=%.4f, want ~1.0", dim, variance)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestUnitVectorCoordSamplesKurtosis checks that small-d samples have lower
|
||||
// kurtosis than N(0,1) (kurtosis=3), confirming the bounded-support Beta tails.
|
||||
// The kurtosis of the coordinate distribution is 3d/(d+2), which equals
|
||||
// 2.0 at d=4 and approaches 3.0 as d→∞. So for small d the distribution is
|
||||
// platykurtic (kurtosis < 3, bounded support), unlike the unbounded Gaussian.
|
||||
func TestUnitVectorCoordSamplesKurtosis(t *testing.T) {
|
||||
// For d=4: kurtosis = 3×4/(4+2) = 2.0 (platykurtic, well below Gaussian's 3).
|
||||
// For d=128: kurtosis = 3×128/130 ≈ 2.95 (close to Gaussian's 3).
|
||||
samplesSmall := unitVectorCoordSamples(4, 2, 65536)
|
||||
var s2, s4 float64
|
||||
for _, s := range samplesSmall {
|
||||
s2 += s * s
|
||||
s4 += s * s * s * s
|
||||
}
|
||||
n := float64(len(samplesSmall))
|
||||
varSmall := s2 / n
|
||||
kurtSmall := (s4 / n) / (varSmall * varSmall)
|
||||
if kurtSmall >= 3.0 {
|
||||
t.Errorf("dim=4: kurtosis=%.3f, want < 3.0 (platykurtic for bounded distribution)", kurtSmall)
|
||||
}
|
||||
|
||||
samplesLarge := unitVectorCoordSamples(128, 2, 65536)
|
||||
var l2, l4 float64
|
||||
for _, s := range samplesLarge {
|
||||
l2 += s * s
|
||||
l4 += s * s * s * s
|
||||
}
|
||||
varLarge := l2 / n
|
||||
kurtLarge := (l4 / n) / (varLarge * varLarge)
|
||||
// For d=128 kurtosis should be close to Gaussian's 3.
|
||||
if math.Abs(kurtLarge-3.0) > 0.3 {
|
||||
t.Errorf("dim=128: kurtosis=%.3f, want ~3.0 (close to Gaussian)", kurtLarge)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScalarCodebookDeterministic(t *testing.T) {
|
||||
for _, bits := range []int{2, 3} {
|
||||
codebookA, boundsA := scalarCodebook(128, bits)
|
||||
codebookB, boundsB := scalarCodebook(128, bits)
|
||||
if len(codebookA) != 1<<bits {
|
||||
t.Fatalf("bits=%d codebook len=%d", bits, len(codebookA))
|
||||
}
|
||||
for i := range codebookA {
|
||||
if codebookA[i] != codebookB[i] {
|
||||
t.Fatalf("bits=%d centroid mismatch at %d", bits, i)
|
||||
}
|
||||
}
|
||||
for i := range boundsA {
|
||||
if boundsA[i] != boundsB[i] {
|
||||
t.Fatalf("bits=%d boundary mismatch at %d", bits, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodebookBoundariesMonotonic(t *testing.T) {
|
||||
for _, bits := range []int{2, 3} {
|
||||
_, bounds := scalarCodebook(128, bits)
|
||||
for i := 1; i < len(bounds); i++ {
|
||||
if bounds[i] <= bounds[i-1] {
|
||||
t.Fatalf("bits=%d boundaries are not monotonic", bits)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuantizeScalarByBoundaryDeterministic(t *testing.T) {
|
||||
codebook, bounds := scalarCodebook(128, 3)
|
||||
mid := bounds[2]
|
||||
|
||||
left := quantizeScalarByBoundary(mid-1e-6, codebook, bounds)
|
||||
right := quantizeScalarByBoundary(mid+1e-6, codebook, bounds)
|
||||
atBoundary := quantizeScalarByBoundary(mid, codebook, bounds)
|
||||
|
||||
if left != 2 {
|
||||
t.Fatalf("left boundary bucket = %d, want 2", left)
|
||||
}
|
||||
if right != 3 {
|
||||
t.Fatalf("right boundary bucket = %d, want 3", right)
|
||||
}
|
||||
if atBoundary != 3 {
|
||||
t.Fatalf("exact boundary bucket = %d, want 3", atBoundary)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPaperTheoremOneMSEBounds verifies that the Lloyd-Max codebook achieves
|
||||
// the total-distortion bounds from Theorem 1 of arXiv:2504.19874
|
||||
// for unit-norm vectors at bit widths 1–4.
|
||||
//
|
||||
// Paper Theorem 1 bounds (total distortion ||v - v_hat||^2 for unit-norm vectors):
|
||||
//
|
||||
// b=1 → 0.36, b=2 → 0.117, b=3 → 0.03, b=4 → 0.009
|
||||
func TestPaperTheoremOneMSEBounds(t *testing.T) {
|
||||
paperBounds := map[int]float64{
|
||||
1: 0.36,
|
||||
2: 0.117,
|
||||
3: 0.03,
|
||||
4: 0.009,
|
||||
}
|
||||
|
||||
const dim = 128
|
||||
const nTrials = 200
|
||||
rot := BuildRotation(dim, 0x42c0ffee)
|
||||
|
||||
for bits := 1; bits <= 4; bits++ {
|
||||
bits := bits // capture loop variable
|
||||
t.Run(fmt.Sprintf("bits=%d", bits), func(t *testing.T) {
|
||||
bound := paperBounds[bits]
|
||||
codebook, boundaries := scalarCodebook(dim, bits)
|
||||
|
||||
var totalMSE float64
|
||||
for trial := range nTrials {
|
||||
vec := pseudoRandomVector(dim, uint64(trial)*0x9e3779b97f4a7c15+1)
|
||||
|
||||
// Normalize to unit norm.
|
||||
var norm float64
|
||||
for _, v := range vec {
|
||||
norm += float64(v) * float64(v)
|
||||
}
|
||||
norm = math.Sqrt(norm)
|
||||
if norm < 1e-10 {
|
||||
continue
|
||||
}
|
||||
for i := range vec {
|
||||
vec[i] /= float32(norm)
|
||||
}
|
||||
|
||||
// Apply rotation (matches internal encode behavior).
|
||||
rotated := ApplyRotation(vec, rot)
|
||||
|
||||
// Compute RMS scale (same as blockScale in encode.go).
|
||||
var sumSq float64
|
||||
for _, v := range rotated {
|
||||
sumSq += float64(v) * float64(v)
|
||||
}
|
||||
scale := float32(math.Sqrt(sumSq / float64(dim)))
|
||||
|
||||
// Quantize each element and accumulate per-element MSE.
|
||||
var mse float64
|
||||
for _, v := range rotated {
|
||||
normalized := float32(0)
|
||||
if scale > 0 {
|
||||
normalized = v / scale
|
||||
}
|
||||
idx := quantizeScalarByBoundary(normalized, codebook, boundaries)
|
||||
recon := codebook[idx] * scale
|
||||
diff := float64(v - recon)
|
||||
mse += diff * diff
|
||||
}
|
||||
totalMSE += mse
|
||||
}
|
||||
|
||||
avgMSE := totalMSE / float64(nTrials)
|
||||
// Allow 50% headroom over the paper bound to account for finite
|
||||
// sample size and finite dimension effects.
|
||||
if avgMSE > bound*1.5 {
|
||||
t.Errorf("bits=%d: avg MSE %.6f exceeds 1.5× paper bound %.6f",
|
||||
bits, avgMSE, bound*1.5)
|
||||
}
|
||||
t.Logf("bits=%d: avg MSE=%.6f, paper bound=%.6f, ratio=%.2f",
|
||||
bits, avgMSE, bound, avgMSE/bound)
|
||||
})
|
||||
}
|
||||
}
|
||||
113
turboquant/decode.go
Normal file
113
turboquant/decode.go
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
package turboquant
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// DecodeVector fully dequantizes an encoded vector back to float32 in
|
||||
// original space. This applies inverse rotation and, for product-mode
|
||||
// blocks, adds the reconstructed QJL residual.
|
||||
func DecodeVector(data []byte) ([]float32, Preset, error) {
|
||||
ev, err := UnmarshalEncodedVector(data)
|
||||
if err != nil {
|
||||
return nil, Preset{}, err
|
||||
}
|
||||
|
||||
decoded := make([]float32, ev.Dim)
|
||||
offset := 0
|
||||
for _, block := range ev.Blocks {
|
||||
blockDim := int(block.OriginalDim)
|
||||
codebook, _ := scalarCodebook(blockDim, int(block.RegularBits))
|
||||
indices := unpackBits(block.RegularIndices, int(block.RegularBits), blockDim)
|
||||
rotated := make([]float32, blockDim)
|
||||
for i, idx := range indices {
|
||||
rotated[i] = dequantizeScalar(idx, codebook) * block.Scale
|
||||
}
|
||||
if vectorObjective(block.Objective) == objectiveProduct {
|
||||
residual := reconstructResidual(blockDim, block.Residual)
|
||||
for i := range rotated {
|
||||
rotated[i] += residual[i]
|
||||
}
|
||||
}
|
||||
original := ApplyInverseRotation(rotated, BuildRotation(blockDim, block.RotationSeed))
|
||||
|
||||
if len(block.ChannelIndices) == blockDim {
|
||||
// Scatter to original channel positions.
|
||||
for i, chIdx := range block.ChannelIndices {
|
||||
decoded[chIdx] = original[i]
|
||||
}
|
||||
} else {
|
||||
// Legacy single-block: fill contiguous range.
|
||||
copy(decoded[offset:], original)
|
||||
offset += blockDim
|
||||
}
|
||||
}
|
||||
|
||||
return decoded, ev.Preset, nil
|
||||
}
|
||||
|
||||
// UnmarshalEncodedVector deserializes an EncodedVector from its binary form.
|
||||
func UnmarshalEncodedVector(data []byte) (EncodedVector, error) {
|
||||
r := bytes.NewReader(data)
|
||||
var version uint8
|
||||
var presetID uint8
|
||||
var dim uint32
|
||||
var blockCount uint32
|
||||
for _, field := range []any{&version, &presetID, &dim, &blockCount} {
|
||||
if err := binary.Read(r, binary.LittleEndian, field); err != nil {
|
||||
return EncodedVector{}, err
|
||||
}
|
||||
}
|
||||
if version != BlockVersion {
|
||||
return EncodedVector{}, fmt.Errorf("unsupported encoded vector version %d", version)
|
||||
}
|
||||
|
||||
preset, err := PresetByID(presetID)
|
||||
if err != nil {
|
||||
return EncodedVector{}, err
|
||||
}
|
||||
|
||||
blocks := make([]Block, 0, blockCount)
|
||||
totalDim := 0
|
||||
for range int(blockCount) {
|
||||
var blockLen uint32
|
||||
if err := binary.Read(r, binary.LittleEndian, &blockLen); err != nil {
|
||||
return EncodedVector{}, err
|
||||
}
|
||||
blockData := make([]byte, blockLen)
|
||||
if _, err := io.ReadFull(r, blockData); err != nil {
|
||||
return EncodedVector{}, err
|
||||
}
|
||||
var block Block
|
||||
if err := block.UnmarshalBinary(blockData); err != nil {
|
||||
return EncodedVector{}, err
|
||||
}
|
||||
if block.PresetID != preset.ID {
|
||||
return EncodedVector{}, fmt.Errorf("block preset id %d does not match encoded preset %d", block.PresetID, preset.ID)
|
||||
}
|
||||
if len(block.RegularIndices) != expectedPackedBytes(int(block.OriginalDim), int(block.RegularBits)) {
|
||||
return EncodedVector{}, fmt.Errorf("invalid primary index length %d for dim %d and bits %d", len(block.RegularIndices), block.OriginalDim, block.RegularBits)
|
||||
}
|
||||
if block.Residual.SketchDim != block.QJLRows {
|
||||
return EncodedVector{}, fmt.Errorf("residual sketch dim %d does not match qjl rows %d", block.Residual.SketchDim, block.QJLRows)
|
||||
}
|
||||
totalDim += int(block.OriginalDim)
|
||||
blocks = append(blocks, block)
|
||||
}
|
||||
if r.Len() != 0 {
|
||||
return EncodedVector{}, fmt.Errorf("unexpected trailing bytes in encoded vector: %d", r.Len())
|
||||
}
|
||||
if totalDim != int(dim) {
|
||||
return EncodedVector{}, fmt.Errorf("encoded vector dim mismatch: header=%d blocks=%d", dim, totalDim)
|
||||
}
|
||||
|
||||
return EncodedVector{
|
||||
Version: version,
|
||||
Preset: preset,
|
||||
Dim: int(dim),
|
||||
Blocks: blocks,
|
||||
}, nil
|
||||
}
|
||||
127
turboquant/decode_test.go
Normal file
127
turboquant/decode_test.go
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
package turboquant
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUnmarshalEncodedVectorRejectsBadHeader(t *testing.T) {
|
||||
if _, err := UnmarshalEncodedVector([]byte{1, 2, 3}); err == nil {
|
||||
t.Fatal("expected malformed header error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalEncodedVectorRejectsWrongVersion(t *testing.T) {
|
||||
encoded, err := EncodeVector([]float32{1, 2, 3, 4}, PresetTQ3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, err := encoded.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data[0] = 99
|
||||
|
||||
if _, err := UnmarshalEncodedVector(data); err == nil {
|
||||
t.Fatal("expected unsupported version error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalEncodedVectorRejectsBadPresetID(t *testing.T) {
|
||||
encoded, err := EncodeVector([]float32{1, 2, 3, 4}, PresetTQ3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, err := encoded.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data[1] = 99
|
||||
|
||||
if _, err := UnmarshalEncodedVector(data); err == nil {
|
||||
t.Fatal("expected bad preset id error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalEncodedVectorRejectsTruncatedBlockPayload(t *testing.T) {
|
||||
encoded, err := EncodeVector(pseudoRandomVector(16, 0x99), PresetTQ2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, err := encoded.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
truncated := data[:len(data)-1]
|
||||
if _, err := UnmarshalEncodedVector(truncated); err == nil {
|
||||
t.Fatal("expected truncated block payload error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeVectorRejectsInvalidIndexLengths(t *testing.T) {
|
||||
encoded, err := EncodeVector(pseudoRandomVector(16, 0x77), PresetTQ3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, err := encoded.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ev, err := UnmarshalEncodedVector(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ev.Blocks[0].RegularIndices = append(ev.Blocks[0].RegularIndices, 0)
|
||||
corrupt, err := marshalTestEncodedVector(ev)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, _, err := DecodeVector(corrupt); err == nil {
|
||||
t.Fatal("expected invalid primary index length error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeVectorPreservesOriginalLength(t *testing.T) {
|
||||
values := pseudoRandomVector(130, 0x66)
|
||||
encoded, err := EncodeVector(values, PresetTQ3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, err := encoded.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
decoded, _, err := DecodeVector(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(decoded) != len(values) {
|
||||
t.Fatalf("decoded len = %d, want %d", len(decoded), len(values))
|
||||
}
|
||||
}
|
||||
|
||||
func marshalTestEncodedVector(ev EncodedVector) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
for _, field := range []any{ev.Version, ev.Preset.ID, uint32(ev.Dim), uint32(len(ev.Blocks))} {
|
||||
if err := binary.Write(&buf, binary.LittleEndian, field); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
for _, block := range ev.Blocks {
|
||||
blockData, err := block.MarshalBinary()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Write(&buf, binary.LittleEndian, uint32(len(blockData))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf.Write(blockData)
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
386
turboquant/encode.go
Normal file
386
turboquant/encode.go
Normal file
|
|
@ -0,0 +1,386 @@
|
|||
package turboquant
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
)
|
||||
|
||||
// Rotation seed modifiers for outlier and regular sub-blocks.
|
||||
// Using ASCII mnemonics: "OUTL1ER\0" and "REGULAR\0".
|
||||
const (
|
||||
outlierSeedXOR = uint64(0x4f55544c31455200)
|
||||
regularSeedXOR = uint64(0x524547554c415200)
|
||||
)
|
||||
|
||||
type EncodedVector struct {
|
||||
Version uint8
|
||||
Preset Preset
|
||||
Dim int
|
||||
Blocks []Block
|
||||
}
|
||||
|
||||
func EncodeVector(values []float32, preset Preset) (EncodedVector, error) {
|
||||
return encodeVector(values, preset, roleGeneric, objectiveMSE, preset.ValueBits)
|
||||
}
|
||||
|
||||
func EncodeKeyVector(values []float32, preset Preset) (EncodedVector, error) {
|
||||
return encodeVector(values, preset, roleKey, objectiveProduct, preset.KeyPrimaryBits)
|
||||
}
|
||||
|
||||
func EncodeValueVector(values []float32, preset Preset) (EncodedVector, error) {
|
||||
return encodeVector(values, preset, roleValue, objectiveMSE, preset.ValueBits)
|
||||
}
|
||||
|
||||
func encodeVector(values []float32, preset Preset, role vectorRole, objective vectorObjective, bits int) (EncodedVector, error) {
|
||||
dim := len(values)
|
||||
if dim <= 0 {
|
||||
return EncodedVector{}, fmt.Errorf("invalid turboquant vector dim %d", dim)
|
||||
}
|
||||
if bits <= 0 || bits >= 8 {
|
||||
return EncodedVector{}, fmt.Errorf("invalid turboquant bit width %d", bits)
|
||||
}
|
||||
|
||||
if preset.HasOutlierSplit() && dim > preset.OutlierCount {
|
||||
return encodeVectorWithOutliers(values, preset, role, objective, bits)
|
||||
}
|
||||
|
||||
block, err := encodeSubBlock(values, nil, preset, role, objective, bits, preset.RotationSeed)
|
||||
if err != nil {
|
||||
return EncodedVector{}, err
|
||||
}
|
||||
return EncodedVector{
|
||||
Version: BlockVersion,
|
||||
Preset: preset,
|
||||
Dim: dim,
|
||||
Blocks: []Block{block},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// encodeVectorWithOutliers splits the vector into outlier and regular channel
|
||||
// sub-blocks, each encoded independently with its own rotation. The outlier
|
||||
// block is stored first in Blocks.
|
||||
func encodeVectorWithOutliers(values []float32, preset Preset, role vectorRole, objective vectorObjective, regularBits int) (EncodedVector, error) {
|
||||
split := SplitOutlierChannels(values, preset.OutlierCount)
|
||||
|
||||
outlierBits := preset.OutlierBits
|
||||
if outlierBits <= 0 || outlierBits >= 8 {
|
||||
return EncodedVector{}, fmt.Errorf("invalid outlier bit width %d", outlierBits)
|
||||
}
|
||||
if regularBits <= 0 || regularBits >= 8 {
|
||||
return EncodedVector{}, fmt.Errorf("invalid regular bit width %d", regularBits)
|
||||
}
|
||||
|
||||
outlierSeed := preset.RotationSeed ^ outlierSeedXOR
|
||||
regularSeed := preset.RotationSeed ^ regularSeedXOR
|
||||
|
||||
outlierBlock, err := encodeSubBlock(split.OutlierValues, split.OutlierIndices, preset, role, objective, outlierBits, outlierSeed)
|
||||
if err != nil {
|
||||
return EncodedVector{}, fmt.Errorf("outlier block: %w", err)
|
||||
}
|
||||
|
||||
// Regular block always uses MSE (no QJL sketch) regardless of the key/value
|
||||
// role. QJL is only applied to the outlier block, which concentrates the
|
||||
// residual correction budget on the highest-magnitude channels.
|
||||
regularBlock, err := encodeSubBlock(split.RegularValues, split.RegularIndices, preset, role, objectiveMSE, regularBits, regularSeed)
|
||||
if err != nil {
|
||||
return EncodedVector{}, fmt.Errorf("regular block: %w", err)
|
||||
}
|
||||
|
||||
return EncodedVector{
|
||||
Version: BlockVersion,
|
||||
Preset: preset,
|
||||
Dim: len(values),
|
||||
Blocks: []Block{outlierBlock, regularBlock},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// encodeSubBlock encodes a sub-vector (identified by channelIndices into the
|
||||
// original full-dim vector) as a single Block. If channelIndices is nil, the
|
||||
// block covers all channels (single-block legacy path).
|
||||
func encodeSubBlock(values []float32, channelIndices []uint16, preset Preset, role vectorRole, objective vectorObjective, bits int, rotationSeed uint64) (Block, error) {
|
||||
dim := len(values)
|
||||
if dim <= 0 {
|
||||
return Block{}, fmt.Errorf("empty sub-block")
|
||||
}
|
||||
|
||||
codebook, boundaries := scalarCodebook(dim, bits)
|
||||
rotation := BuildRotation(dim, rotationSeed)
|
||||
rotated := ApplyRotation(values, rotation)
|
||||
scale := blockScale(rotated)
|
||||
primaryCodes := make([]uint8, dim)
|
||||
reconRotated := make([]float32, dim)
|
||||
if scale == 0 {
|
||||
for i := range primaryCodes {
|
||||
primaryCodes[i] = quantizeScalarByBoundary(0, codebook, boundaries)
|
||||
reconRotated[i] = 0
|
||||
}
|
||||
} else {
|
||||
for i, value := range rotated {
|
||||
normalized := value / scale
|
||||
idx := quantizeScalarByBoundary(normalized, codebook, boundaries)
|
||||
primaryCodes[i] = idx
|
||||
reconRotated[i] = dequantizeScalar(idx, codebook) * scale
|
||||
}
|
||||
}
|
||||
|
||||
qjlRows := 0
|
||||
if objective == objectiveProduct {
|
||||
qjlRows = preset.KeyQJLRows(dim)
|
||||
}
|
||||
|
||||
block := Block{
|
||||
Version: BlockVersion,
|
||||
PresetID: preset.ID,
|
||||
Role: uint8(role),
|
||||
Objective: uint8(objective),
|
||||
OriginalDim: uint16(dim),
|
||||
PaddedDim: uint16(dim),
|
||||
BlockDim: uint16(dim),
|
||||
RegularBits: uint8(bits),
|
||||
RotationSeed: rotationSeed,
|
||||
CodebookID: uint16(bits),
|
||||
QJLRows: uint16(qjlRows),
|
||||
AuxLayoutID: 1,
|
||||
ChannelIndices: channelIndices,
|
||||
Scale: scale,
|
||||
RegularIndices: packBits(primaryCodes, bits),
|
||||
Residual: encodeResidual(rotated, reconRotated, qjlRows, rotationSeed^0x9e3779b97f4a7c15),
|
||||
}
|
||||
return block, nil
|
||||
}
|
||||
|
||||
// EncodeKeyPerHead quantizes a single attention head's key vector using
|
||||
// rotation + Lloyd-Max without outlier split or QJL residual. Returns
|
||||
// packed N-bit indices and an RMS scale. The quantized values are in
|
||||
// rotated space — the caller must rotate Q to match at attention time.
|
||||
//
|
||||
// This produces a GPU-friendly flat representation: just packed bits + scale.
|
||||
func EncodeKeyPerHead(values []float32, preset Preset) (packedIndices []byte, scale float32, err error) {
|
||||
dim := len(values)
|
||||
if dim <= 0 {
|
||||
return nil, 0, fmt.Errorf("empty head vector")
|
||||
}
|
||||
bits := preset.KeyPrimaryBits
|
||||
if bits <= 0 || bits >= 8 {
|
||||
return nil, 0, fmt.Errorf("invalid bit width %d", bits)
|
||||
}
|
||||
|
||||
codebook, boundaries := scalarCodebook(dim, bits)
|
||||
rotation := BuildRotation(dim, preset.RotationSeed)
|
||||
rotated := ApplyRotation(values, rotation)
|
||||
scale = blockScale(rotated)
|
||||
|
||||
codes := make([]uint8, dim)
|
||||
if scale > 0 {
|
||||
for i, v := range rotated {
|
||||
codes[i] = quantizeScalarByBoundary(v/scale, codebook, boundaries)
|
||||
}
|
||||
}
|
||||
return packBits(codes, bits), scale, nil
|
||||
}
|
||||
|
||||
// DequantKeyPerHead reconstructs f32 values from packed indices + scale in
|
||||
// rotated space (no inverse rotation). Used for CPU-side testing/fallback.
|
||||
func DequantKeyPerHead(packedIndices []byte, scale float32, headDim, bits int) []float32 {
|
||||
codebook, _ := scalarCodebook(headDim, bits)
|
||||
indices := unpackBits(packedIndices, bits, headDim)
|
||||
out := make([]float32, headDim)
|
||||
for i, idx := range indices {
|
||||
out[i] = dequantizeScalar(idx, codebook) * scale
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// OutlierPerHead is the CPU reference for TurboQuant paper Algorithm 1
|
||||
// Sec 4.3's outlier-split K encoding. Mirrors the GPU kernel
|
||||
// tq_encode_kernel_outlier / tq_dequant_multihead_kernel_outlier in
|
||||
// ml/backend/ggml/ggml/src/ggml-cuda/tq-*.cu:
|
||||
//
|
||||
// 1. Rotate K by Householder QR of a random Gaussian matrix (shared
|
||||
// rotation, same as EncodeKeyPerHead).
|
||||
// 2. Select the top-K channels by absolute rotated magnitude as
|
||||
// outliers.
|
||||
// 3. Compute independent RMS scales for regular and outlier sub-blocks.
|
||||
// 4. Quantize each sub-block with its own Lloyd-Max codebook at the
|
||||
// preset's primary bits (regular) and outlier bits (outlier).
|
||||
// 5. Return packed regular + packed outlier streams, per-sub-block
|
||||
// scales, and the channel index list.
|
||||
//
|
||||
// Unlike EncodeKeyVector / encodeVectorWithOutliers, the split happens
|
||||
// in ROTATED space (single rotation matmul) because that is what the
|
||||
// GPU can do cheaply and what the paper's symmetric-rotation formulation
|
||||
// assumes. The Go reference exists to validate the GPU kernel
|
||||
// bit-exactly, not to reproduce the block-protocol outlier path.
|
||||
type OutlierPerHead struct {
|
||||
RegularPacked []byte
|
||||
RegularScale float32
|
||||
OutlierPacked []byte
|
||||
OutlierScale float32
|
||||
OutlierIndices []int
|
||||
}
|
||||
|
||||
func EncodeKeyPerHeadOutlier(values []float32, preset Preset) (OutlierPerHead, error) {
|
||||
dim := len(values)
|
||||
if dim <= 0 {
|
||||
return OutlierPerHead{}, fmt.Errorf("empty head vector")
|
||||
}
|
||||
bits := preset.KeyPrimaryBits
|
||||
if bits <= 0 || bits >= 8 {
|
||||
return OutlierPerHead{}, fmt.Errorf("invalid regular bit width %d", bits)
|
||||
}
|
||||
outlierBits := preset.OutlierBits
|
||||
outlierCount := preset.OutlierCount
|
||||
if outlierBits <= 0 || outlierBits >= 8 || outlierCount <= 0 || outlierCount >= dim {
|
||||
return OutlierPerHead{}, fmt.Errorf("invalid outlier split %d@%dbits over dim=%d", outlierCount, outlierBits, dim)
|
||||
}
|
||||
regularCount := dim - outlierCount
|
||||
|
||||
rotation := BuildRotation(dim, preset.RotationSeed)
|
||||
rotated := ApplyRotation(values, rotation)
|
||||
|
||||
// Step 2: top-K outlier select by abs(rotated). Mirrors the serial
|
||||
// thread-0 selection in tq_encode_kernel_outlier step 4. We use the
|
||||
// same "mark-and-scan" algorithm so the outlier set order matches
|
||||
// the GPU kernel exactly for a given input.
|
||||
isOutlier := make([]bool, dim)
|
||||
outlierPos := make([]int, 0, outlierCount)
|
||||
outlierVal := make([]float32, 0, outlierCount)
|
||||
for range outlierCount {
|
||||
bestVal := float32(-1.0)
|
||||
bestIdx := 0
|
||||
for i := range dim {
|
||||
if isOutlier[i] {
|
||||
continue
|
||||
}
|
||||
a := rotated[i]
|
||||
if a < 0 {
|
||||
a = -a
|
||||
}
|
||||
if a > bestVal {
|
||||
bestVal = a
|
||||
bestIdx = i
|
||||
}
|
||||
}
|
||||
isOutlier[bestIdx] = true
|
||||
outlierPos = append(outlierPos, bestIdx)
|
||||
outlierVal = append(outlierVal, rotated[bestIdx])
|
||||
}
|
||||
|
||||
// Build the regular channel position list in ascending order. The
|
||||
// GPU kernel walks i=0..dim-1 and appends non-outlier positions to
|
||||
// s_reg_pos in order; the CPU reference does the same so packed
|
||||
// regular slot r maps to the same original channel on both sides.
|
||||
regularRotated := make([]float32, 0, regularCount)
|
||||
for i := range dim {
|
||||
if !isOutlier[i] {
|
||||
regularRotated = append(regularRotated, rotated[i])
|
||||
}
|
||||
}
|
||||
|
||||
regularScale := blockScale(regularRotated)
|
||||
outlierScale := blockScale(outlierVal)
|
||||
|
||||
regularCodebook, regularBoundaries := scalarCodebook(dim, bits)
|
||||
outlierCodebook, outlierBoundaries := scalarCodebook(dim, outlierBits)
|
||||
|
||||
regularCodes := make([]uint8, regularCount)
|
||||
if regularScale > 0 {
|
||||
for r, v := range regularRotated {
|
||||
regularCodes[r] = quantizeScalarByBoundary(v/regularScale, regularCodebook, regularBoundaries)
|
||||
}
|
||||
}
|
||||
outlierCodes := make([]uint8, outlierCount)
|
||||
if outlierScale > 0 {
|
||||
for r, v := range outlierVal {
|
||||
outlierCodes[r] = quantizeScalarByBoundary(v/outlierScale, outlierCodebook, outlierBoundaries)
|
||||
}
|
||||
}
|
||||
|
||||
return OutlierPerHead{
|
||||
RegularPacked: packBits(regularCodes, bits),
|
||||
RegularScale: regularScale,
|
||||
OutlierPacked: packBits(outlierCodes, outlierBits),
|
||||
OutlierScale: outlierScale,
|
||||
OutlierIndices: outlierPos,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DequantKeyPerHeadOutlier is the CPU reference for the outlier-split
|
||||
// dequant. Returns the reconstructed rotated-space vector that the
|
||||
// caller can compare against ApplyRotation(original, rotation).
|
||||
func DequantKeyPerHeadOutlier(enc OutlierPerHead, preset Preset, headDim int) []float32 {
|
||||
outlierCount := len(enc.OutlierIndices)
|
||||
regularCount := headDim - outlierCount
|
||||
bits := preset.KeyPrimaryBits
|
||||
outlierBits := preset.OutlierBits
|
||||
|
||||
regularCodebook, _ := scalarCodebook(headDim, bits)
|
||||
outlierCodebook, _ := scalarCodebook(headDim, outlierBits)
|
||||
|
||||
regularIdx := unpackBits(enc.RegularPacked, bits, regularCount)
|
||||
outlierIdx := unpackBits(enc.OutlierPacked, outlierBits, outlierCount)
|
||||
|
||||
// Build position-to-slot lookups. outlier_slot_for[i] is the index
|
||||
// in OutlierIndices that equals i (or -1 if i is regular).
|
||||
outlierSlotFor := make([]int, headDim)
|
||||
for i := range outlierSlotFor {
|
||||
outlierSlotFor[i] = -1
|
||||
}
|
||||
for slot, pos := range enc.OutlierIndices {
|
||||
outlierSlotFor[pos] = slot
|
||||
}
|
||||
|
||||
out := make([]float32, headDim)
|
||||
regPos := 0
|
||||
for i := range headDim {
|
||||
if outlierSlotFor[i] >= 0 {
|
||||
out[i] = dequantizeScalar(outlierIdx[outlierSlotFor[i]], outlierCodebook) * enc.OutlierScale
|
||||
} else {
|
||||
out[i] = dequantizeScalar(regularIdx[regPos], regularCodebook) * enc.RegularScale
|
||||
regPos++
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func blockScale(values []float32) float32 {
|
||||
if len(values) == 0 {
|
||||
return 0
|
||||
}
|
||||
var sumSquares float64
|
||||
for _, value := range values {
|
||||
sumSquares += float64(value * value)
|
||||
}
|
||||
if sumSquares < 1e-12 {
|
||||
return 0
|
||||
}
|
||||
return float32(math.Sqrt(sumSquares / float64(len(values))))
|
||||
}
|
||||
|
||||
func (e EncodedVector) MarshalBinary() ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
header := []any{
|
||||
e.Version,
|
||||
e.Preset.ID,
|
||||
uint32(e.Dim),
|
||||
uint32(len(e.Blocks)),
|
||||
}
|
||||
for _, field := range header {
|
||||
if err := binary.Write(&buf, binary.LittleEndian, field); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
for _, block := range e.Blocks {
|
||||
blockData, err := block.MarshalBinary()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Write(&buf, binary.LittleEndian, uint32(len(blockData))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf.Write(blockData)
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
964
turboquant/encode_test.go
Normal file
964
turboquant/encode_test.go
Normal file
|
|
@ -0,0 +1,964 @@
|
|||
package turboquant
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// testOutlierPreset returns a copy of base with outlier split forced on
|
||||
// for testing. The shipped PresetTQ3 / PresetTQ3K default to
|
||||
// OutlierCount=0 (see comments on those presets in turboquant.go): the
|
||||
// paper-grounded outlier-split path is available in the kernels and Go
|
||||
// helpers but disabled in the default presets because on the models
|
||||
// this fork ships against it hurts decode throughput and PPL without
|
||||
// the Phase 2A asymmetric-quantization follow-up. Tests that exercise
|
||||
// the outlier-split code path explicitly opt in via this helper.
|
||||
func testOutlierPreset(base Preset, count int) Preset {
|
||||
out := base
|
||||
out.OutlierCount = count
|
||||
return out
|
||||
}
|
||||
|
||||
// TestOutlierSplitBoundaries pins the single-block vs two-block boundary.
|
||||
// dim == OutlierCount must produce a single block (no split).
|
||||
// dim == OutlierCount+1 is the smallest two-block encoding.
|
||||
// TestMemoryFormulaMatchesMarshalSize verifies that the two-block byte-count
|
||||
// formula used in fs/ggml/ggml.go GraphSize matches the actual MarshalBinary
|
||||
// output size. This catches formula drift whenever Block layout changes.
|
||||
func TestMemoryFormulaMatchesMarshalSize(t *testing.T) {
|
||||
// The GraphSize formula replicated below assumes the two-block
|
||||
// outlier-split layout. All shipped tq* presets default to uniform
|
||||
// (OutlierCount=0) so force outlier variants here; the formula itself
|
||||
// is what's being pinned, not the default preset state.
|
||||
tq2Split := testOutlierPreset(PresetTQ2, 32)
|
||||
tq3Split := testOutlierPreset(PresetTQ3, 32)
|
||||
cases := []struct {
|
||||
preset Preset
|
||||
dim int
|
||||
}{
|
||||
{tq2Split, 128},
|
||||
{tq3Split, 128},
|
||||
{tq3Split, 64},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
vec := make([]float32, tc.dim)
|
||||
|
||||
keyEncoded, err := EncodeKeyVector(vec, tc.preset)
|
||||
if err != nil {
|
||||
t.Fatalf("%s dim=%d key: %v", tc.preset.Name, tc.dim, err)
|
||||
}
|
||||
keyData, err := keyEncoded.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatalf("%s dim=%d key marshal: %v", tc.preset.Name, tc.dim, err)
|
||||
}
|
||||
|
||||
valueEncoded, err := EncodeValueVector(vec, tc.preset)
|
||||
if err != nil {
|
||||
t.Fatalf("%s dim=%d value: %v", tc.preset.Name, tc.dim, err)
|
||||
}
|
||||
valueData, err := valueEncoded.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatalf("%s dim=%d value marshal: %v", tc.preset.Name, tc.dim, err)
|
||||
}
|
||||
|
||||
// Replicate the fs/ggml/ggml.go GraphSize formula.
|
||||
const outlierCount = uint64(32)
|
||||
outlierBits := uint64(tc.preset.OutlierBits)
|
||||
regularKeyBits := uint64(tc.preset.KeyPrimaryBits)
|
||||
regularValueBits := uint64(tc.preset.ValueBits)
|
||||
dim := uint64(tc.dim)
|
||||
outlierData := (outlierCount*outlierBits + 7) / 8
|
||||
qjlData := (outlierCount + 7) / 8
|
||||
wantKey := 122 + 2*dim + outlierData + ((dim-outlierCount)*regularKeyBits+7)/8 + qjlData
|
||||
wantValue := 122 + 2*dim + outlierData + ((dim-outlierCount)*regularValueBits+7)/8
|
||||
|
||||
if uint64(len(keyData)) != wantKey {
|
||||
t.Errorf("%s dim=%d: key MarshalBinary=%d bytes, formula=%d",
|
||||
tc.preset.Name, tc.dim, len(keyData), wantKey)
|
||||
}
|
||||
if uint64(len(valueData)) != wantValue {
|
||||
t.Errorf("%s dim=%d: value MarshalBinary=%d bytes, formula=%d",
|
||||
tc.preset.Name, tc.dim, len(valueData), wantValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestOutlierSplitBoundaries(t *testing.T) {
|
||||
for _, preset := range []Preset{testOutlierPreset(PresetTQ2, 32), testOutlierPreset(PresetTQ3, 32)} {
|
||||
atBoundary := pseudoRandomVector(preset.OutlierCount, 0xbabe)
|
||||
encoded, err := EncodeKeyVector(atBoundary, preset)
|
||||
if err != nil {
|
||||
t.Fatalf("%s dim=OutlierCount: %v", preset.Name, err)
|
||||
}
|
||||
if len(encoded.Blocks) != 1 {
|
||||
t.Errorf("%s dim=%d: got %d blocks, want 1 (no split at exact boundary)",
|
||||
preset.Name, preset.OutlierCount, len(encoded.Blocks))
|
||||
}
|
||||
if len(encoded.Blocks[0].ChannelIndices) != 0 {
|
||||
t.Errorf("%s dim=%d: single-block should have no ChannelIndices", preset.Name, preset.OutlierCount)
|
||||
}
|
||||
|
||||
minSplit := pseudoRandomVector(preset.OutlierCount+1, 0xbabe)
|
||||
encoded2, err := EncodeKeyVector(minSplit, preset)
|
||||
if err != nil {
|
||||
t.Fatalf("%s dim=OutlierCount+1: %v", preset.Name, err)
|
||||
}
|
||||
if len(encoded2.Blocks) != 2 {
|
||||
t.Errorf("%s dim=%d: got %d blocks, want 2 (minimum outlier split)",
|
||||
preset.Name, preset.OutlierCount+1, len(encoded2.Blocks))
|
||||
}
|
||||
// Regular block has dim=1; verify it round-trips cleanly.
|
||||
data, err := encoded2.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatalf("%s dim=OutlierCount+1 marshal: %v", preset.Name, err)
|
||||
}
|
||||
decoded, _, err := DecodeVector(data)
|
||||
if err != nil {
|
||||
t.Fatalf("%s dim=OutlierCount+1 decode: %v", preset.Name, err)
|
||||
}
|
||||
if len(decoded) != preset.OutlierCount+1 {
|
||||
t.Errorf("%s dim=OutlierCount+1: decoded len=%d want %d",
|
||||
preset.Name, len(decoded), preset.OutlierCount+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeDecodeRoundTripAcrossShapes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
values []float32
|
||||
preset Preset
|
||||
maxMSE float32
|
||||
}{
|
||||
{name: "small", values: []float32{0.25, -1.5, 3.25, 0.75, -0.5, 2.0, 1.0}, preset: PresetTQ3, maxMSE: 1.5},
|
||||
{name: "non-power-of-two", values: pseudoRandomVector(70, 2), preset: PresetTQ3, maxMSE: 3.0},
|
||||
{name: "multi-head-like", values: pseudoRandomVector(128, 3), preset: PresetTQ2, maxMSE: 5.0},
|
||||
{name: "constant", values: filledVector(31, 1.5), preset: PresetTQ2, maxMSE: 1.0},
|
||||
{name: "zero", values: filledVector(64, 0), preset: PresetTQ3, maxMSE: 0.01},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
encoded, err := EncodeVector(tc.values, tc.preset)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, err := encoded.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
decoded, preset, err := DecodeVector(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if preset.Name != tc.preset.Name {
|
||||
t.Fatalf("preset = %q, want %q", preset.Name, tc.preset.Name)
|
||||
}
|
||||
if len(decoded) != len(tc.values) {
|
||||
t.Fatalf("decoded len = %d, want %d", len(decoded), len(tc.values))
|
||||
}
|
||||
|
||||
stats := Compare(tc.values, decoded)
|
||||
if stats.MSE > tc.maxMSE {
|
||||
t.Fatalf("MSE = %v, want <= %v", stats.MSE, tc.maxMSE)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeVectorDeterministicBytes(t *testing.T) {
|
||||
values := pseudoRandomVector(96, 0x4242)
|
||||
|
||||
encodedA, err := EncodeVector(values, PresetTQ3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
encodedB, err := EncodeVector(values, PresetTQ3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
dataA, err := encodedA.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dataB, err := encodedB.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(dataA) != string(dataB) {
|
||||
t.Fatal("expected byte-identical encoding output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeKeyAndValueUseDifferentObjectives(t *testing.T) {
|
||||
values := pseudoRandomVector(32, 0x77)
|
||||
keyEncoded, err := EncodeKeyVector(values, PresetTQ3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
valueEncoded, err := EncodeValueVector(values, PresetTQ3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if keyEncoded.Blocks[0].Objective != uint8(objectiveProduct) {
|
||||
t.Fatalf("key objective = %d, want %d", keyEncoded.Blocks[0].Objective, objectiveProduct)
|
||||
}
|
||||
if valueEncoded.Blocks[0].Objective != uint8(objectiveMSE) {
|
||||
t.Fatalf("value objective = %d, want %d", valueEncoded.Blocks[0].Objective, objectiveMSE)
|
||||
}
|
||||
if keyEncoded.Blocks[0].QJLRows == 0 {
|
||||
t.Fatal("expected product-mode key rows to carry a residual sketch")
|
||||
}
|
||||
if valueEncoded.Blocks[0].QJLRows != 0 {
|
||||
t.Fatal("expected MSE value rows to omit a residual sketch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPresetNames(t *testing.T) {
|
||||
for _, name := range []string{"tq2", "tq3", "tq3k", "tq2k"} {
|
||||
preset, err := PresetByName(name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if preset.Name != name {
|
||||
t.Fatalf("preset %q resolved to %q", name, preset.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestQJLDimMatchesPaperSpec verifies that the QJL sketch uses d random
|
||||
// projections (one per dimension), matching the paper's specification in
|
||||
// arXiv 2504.19874. With QJLRowsDivisor=1 this ensures the estimator variance
|
||||
// matches the paper's theoretical analysis and the bit accounting is exact:
|
||||
// tq2 = 2.5 bits/elem avg, tq3 = 3.5 bits/elem avg.
|
||||
func TestQJLDimMatchesPaperSpec(t *testing.T) {
|
||||
cases := []struct {
|
||||
preset Preset
|
||||
dim int
|
||||
}{
|
||||
{PresetTQ2, 64},
|
||||
{PresetTQ2, 128},
|
||||
{PresetTQ3, 128},
|
||||
{PresetTQ3, 256},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
got := tc.preset.KeyQJLRows(tc.dim)
|
||||
if got != tc.dim {
|
||||
t.Errorf("%s: KeyQJLRows(%d) = %d, want %d (paper spec: d projections per d-dim vector)",
|
||||
tc.preset.Name, tc.dim, got, tc.dim)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestPaperMSEDistortionBound verifies that the MSE quantizer operates within
|
||||
// the information-theoretic bounds from Theorem 3 of arXiv 2504.19874:
|
||||
//
|
||||
// lower: D_mse >= 1/4^b
|
||||
// upper: D_mse <= (√3π/2) / 4^b ≈ 2.72 / 4^b
|
||||
//
|
||||
// Tested on 1000 random unit vectors at dim=128 (the primary validated head_dim).
|
||||
func TestPaperMSEDistortionBound(t *testing.T) {
|
||||
const dim = 128
|
||||
const trials = 1000
|
||||
|
||||
cases := []struct {
|
||||
preset Preset
|
||||
bits int
|
||||
}{
|
||||
{PresetTQ2, PresetTQ2.ValueBits},
|
||||
{PresetTQ3, PresetTQ3.ValueBits},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.preset.Name, func(t *testing.T) {
|
||||
lowerBound := 1.0 / math.Pow(4, float64(tc.bits))
|
||||
paperUpper := 2.72 / math.Pow(4, float64(tc.bits))
|
||||
|
||||
var totalDistortion float64
|
||||
rng := splitmix64(0x1234567890abcdef)
|
||||
for range trials {
|
||||
// Random unit vector drawn from the uniform distribution on S^{d-1}.
|
||||
vec := make([]float32, dim)
|
||||
var norm2 float64
|
||||
for j := range vec {
|
||||
v := gaussianFloat64(&rng)
|
||||
vec[j] = float32(v)
|
||||
norm2 += v * v
|
||||
}
|
||||
norm := math.Sqrt(norm2)
|
||||
for j := range vec {
|
||||
vec[j] /= float32(norm)
|
||||
}
|
||||
|
||||
encoded, err := EncodeValueVector(vec, tc.preset)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, err := encoded.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
decoded, _, err := DecodeVector(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// D_mse = ||x - x_hat||^2 / ||x||^2 = ||x - x_hat||^2 (||x||=1)
|
||||
var distortion float64
|
||||
for j := range vec {
|
||||
d := float64(vec[j] - decoded[j])
|
||||
distortion += d * d
|
||||
}
|
||||
totalDistortion += distortion
|
||||
}
|
||||
avgDistortion := totalDistortion / float64(trials)
|
||||
|
||||
t.Logf("%s D_mse = %.6f, paper bounds [%.6f, %.6f]",
|
||||
tc.preset.Name, avgDistortion, lowerBound, paperUpper)
|
||||
|
||||
// Allow 1.5× headroom over the paper's upper bound to account for
|
||||
// finite-d effects and the Cartesian (non-PolarQuant) encoding path.
|
||||
if avgDistortion > paperUpper*1.5 {
|
||||
t.Fatalf("D_mse = %.6f exceeds paper upper bound × 1.5 (%.6f)",
|
||||
avgDistortion, paperUpper*1.5)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPaperProductUnbiasedness verifies that the product-objective estimator
|
||||
// is near-unbiased: E[score(q, encoded_k) - dot(q, k)] ≈ 0. This is the
|
||||
// central claim of Q_prod in arXiv 2504.19874.
|
||||
func TestPaperProductUnbiasedness(t *testing.T) {
|
||||
const dim = 128
|
||||
const trials = 500
|
||||
// Allow 5% signed relative bias averaged over 500 trials.
|
||||
const maxRelBias = 0.05
|
||||
|
||||
rng := splitmix64(0xdeadbeefcafe1234)
|
||||
var signedBias, rmsTrue float64
|
||||
for range trials {
|
||||
query := make([]float32, dim)
|
||||
key := make([]float32, dim)
|
||||
for j := range query {
|
||||
query[j] = float32(gaussianFloat64(&rng))
|
||||
key[j] = float32(gaussianFloat64(&rng))
|
||||
}
|
||||
|
||||
encoded, err := EncodeKeyVector(key, PresetTQ3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, err := encoded.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
decoded, _, err := DecodeVector(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var estimated float32
|
||||
for j := range query {
|
||||
estimated += query[j] * decoded[j]
|
||||
}
|
||||
|
||||
var trueDot float32
|
||||
for j := range query {
|
||||
trueDot += query[j] * key[j]
|
||||
}
|
||||
signedBias += float64(estimated - trueDot)
|
||||
rmsTrue += float64(trueDot * trueDot)
|
||||
}
|
||||
|
||||
avgSignedBias := signedBias / float64(trials)
|
||||
rmsTrue = math.Sqrt(rmsTrue / float64(trials))
|
||||
relativeBias := math.Abs(avgSignedBias) / rmsTrue
|
||||
|
||||
t.Logf("avg signed bias = %.6f, rms true dot = %.6f, relative bias = %.4f",
|
||||
avgSignedBias, rmsTrue, relativeBias)
|
||||
|
||||
if relativeBias > maxRelBias {
|
||||
t.Fatalf("relative bias = %.4f, want <= %.4f (product estimator should be near-unbiased)",
|
||||
relativeBias, maxRelBias)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOutlierSplitMSEImproves verifies that encoding with the outlier-split
|
||||
// strategy achieves lower MSE than uniform quantization at the same average bit
|
||||
// rate. This is the core quality claim of §4.3 of arXiv 2504.19874.
|
||||
func TestOutlierSplitMSEImproves(t *testing.T) {
|
||||
const dim = 128
|
||||
const trials = 200
|
||||
rng := splitmix64(0xfeedbabe12345678)
|
||||
|
||||
for _, preset := range []Preset{testOutlierPreset(PresetTQ2, 32), testOutlierPreset(PresetTQ3, 32)} {
|
||||
t.Run(preset.Name, func(t *testing.T) {
|
||||
var splitMSE, uniformMSE float64
|
||||
for range trials {
|
||||
vec := make([]float32, dim)
|
||||
for j := range vec {
|
||||
vec[j] = float32(gaussianFloat64(&rng))
|
||||
}
|
||||
|
||||
// Outlier-split encoding (2-block, current path).
|
||||
splitEncoded, err := EncodeValueVector(vec, preset)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
splitData, err := splitEncoded.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
splitDecoded, _, err := DecodeVector(splitData)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for j := range vec {
|
||||
d := float64(vec[j] - splitDecoded[j])
|
||||
splitMSE += d * d
|
||||
}
|
||||
|
||||
// Uniform encoding: both sub-block sizes at regular bits, same total bits.
|
||||
// Encode the whole vector at regular bits to match the average bit rate.
|
||||
uniformBits := preset.ValueBits
|
||||
unifBlock, err := encodeSubBlock(vec, nil, preset, roleValue, objectiveMSE, uniformBits, preset.RotationSeed)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
codebook, _ := scalarCodebook(dim, uniformBits)
|
||||
rot := BuildRotation(dim, preset.RotationSeed)
|
||||
uIndices := unpackBits(unifBlock.RegularIndices, uniformBits, dim)
|
||||
unifRotated := make([]float32, dim)
|
||||
for j, idx := range uIndices {
|
||||
unifRotated[j] = dequantizeScalar(idx, codebook) * unifBlock.Scale
|
||||
}
|
||||
unifDecoded := ApplyInverseRotation(unifRotated, rot)
|
||||
for j := range vec {
|
||||
d := float64(vec[j] - unifDecoded[j])
|
||||
uniformMSE += d * d
|
||||
}
|
||||
}
|
||||
splitMSE /= float64(trials * dim)
|
||||
uniformMSE /= float64(trials * dim)
|
||||
t.Logf("%s: split MSE=%.6f uniform MSE=%.6f", preset.Name, splitMSE, uniformMSE)
|
||||
if splitMSE >= uniformMSE {
|
||||
t.Errorf("outlier split MSE (%.6f) not lower than uniform MSE (%.6f)", splitMSE, uniformMSE)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestOutlierSplitProductUnbiasedness verifies that the multi-block product
|
||||
// estimator remains near-unbiased after the outlier split is applied.
|
||||
func TestOutlierSplitProductUnbiasedness(t *testing.T) {
|
||||
const dim = 128
|
||||
const trials = 300
|
||||
const maxRelBias = 0.07
|
||||
|
||||
outlierPreset := testOutlierPreset(PresetTQ3, 32)
|
||||
rng := splitmix64(0xabcdef0123456789)
|
||||
var signedBias, rmsTrue float64
|
||||
for range trials {
|
||||
query := make([]float32, dim)
|
||||
key := make([]float32, dim)
|
||||
for j := range query {
|
||||
query[j] = float32(gaussianFloat64(&rng))
|
||||
key[j] = float32(gaussianFloat64(&rng))
|
||||
}
|
||||
|
||||
encoded, err := EncodeKeyVector(key, outlierPreset)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, err := encoded.MarshalBinary()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
decoded, _, err := DecodeVector(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var estimated float32
|
||||
for j := range query {
|
||||
estimated += query[j] * decoded[j]
|
||||
}
|
||||
|
||||
var trueDot float32
|
||||
for j := range query {
|
||||
trueDot += query[j] * key[j]
|
||||
}
|
||||
signedBias += float64(estimated - trueDot)
|
||||
rmsTrue += float64(trueDot * trueDot)
|
||||
}
|
||||
|
||||
avgSignedBias := signedBias / float64(trials)
|
||||
rmsTrue = math.Sqrt(rmsTrue / float64(trials))
|
||||
relativeBias := math.Abs(avgSignedBias) / rmsTrue
|
||||
|
||||
t.Logf("outlier-split avg signed bias = %.6f, rms true dot = %.6f, relative bias = %.4f",
|
||||
avgSignedBias, rmsTrue, relativeBias)
|
||||
if relativeBias > maxRelBias {
|
||||
t.Fatalf("relative bias = %.4f, want <= %.4f (multi-block estimator should be near-unbiased)",
|
||||
relativeBias, maxRelBias)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDistortionThresholds(t *testing.T) {
|
||||
tq2Mean, err := meanMSEForPreset(PresetTQ2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tq3Mean, err := meanMSEForPreset(PresetTQ3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if tq2Mean > 5.5 {
|
||||
t.Fatalf("tq2 mean MSE = %v, want <= 5.5", tq2Mean)
|
||||
}
|
||||
if tq3Mean > 3.5 {
|
||||
t.Fatalf("tq3 mean MSE = %v, want <= 3.5", tq3Mean)
|
||||
}
|
||||
if tq3Mean > tq2Mean {
|
||||
t.Fatalf("tq3 mean MSE = %v, want <= tq2 mean MSE %v", tq3Mean, tq2Mean)
|
||||
}
|
||||
}
|
||||
|
||||
// ── per-head uniform encoding ──────────────────────────────────────────────
|
||||
|
||||
func TestEncodeKeyPerHeadRoundTrip(t *testing.T) {
|
||||
const dim = 128
|
||||
const trials = 200
|
||||
|
||||
for _, preset := range []Preset{PresetTQ2, PresetTQ3} {
|
||||
t.Run(preset.Name, func(t *testing.T) {
|
||||
rng := splitmix64(0xbeefcafe)
|
||||
var totalMSE float64
|
||||
var totalAbsErr, totalAbsDot float64
|
||||
|
||||
for range trials {
|
||||
values := make([]float32, dim)
|
||||
query := make([]float32, dim)
|
||||
for j := range values {
|
||||
values[j] = float32(gaussianFloat64(&rng))
|
||||
query[j] = float32(gaussianFloat64(&rng))
|
||||
}
|
||||
|
||||
packed, scale, err := EncodeKeyPerHead(values, preset)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expectedBytes := (dim*preset.KeyPrimaryBits + 7) / 8
|
||||
if len(packed) != expectedBytes {
|
||||
t.Fatalf("packed len = %d, want %d", len(packed), expectedBytes)
|
||||
}
|
||||
|
||||
dequantRot := DequantKeyPerHead(packed, scale, dim, preset.KeyPrimaryBits)
|
||||
|
||||
// Verify MSE in rotated space (reconstruction quality).
|
||||
rotation := BuildRotation(dim, preset.RotationSeed)
|
||||
valuesRot := ApplyRotation(values, rotation)
|
||||
var mse float64
|
||||
for j := range valuesRot {
|
||||
d := float64(valuesRot[j] - dequantRot[j])
|
||||
mse += d * d
|
||||
}
|
||||
mse /= float64(dim)
|
||||
totalMSE += mse
|
||||
|
||||
// Verify dot product in rotated space matches true dot product.
|
||||
queryRot := ApplyRotation(query, rotation)
|
||||
var estDot, trueDot float32
|
||||
for j := range queryRot {
|
||||
estDot += queryRot[j] * dequantRot[j]
|
||||
}
|
||||
for j := range query {
|
||||
trueDot += query[j] * values[j]
|
||||
}
|
||||
totalAbsErr += math.Abs(float64(estDot - trueDot))
|
||||
totalAbsDot += math.Abs(float64(trueDot))
|
||||
|
||||
_ = scale
|
||||
}
|
||||
|
||||
avgMSE := totalMSE / trials
|
||||
avgRelErr := totalAbsErr / (totalAbsDot + 1e-8)
|
||||
|
||||
t.Logf("avg MSE = %.6f, avg relative dot error = %.4f", avgMSE, avgRelErr)
|
||||
|
||||
// Uniform encoding (no outlier split, no QJL) has higher error than
|
||||
// full TQ. These thresholds validate the codec works, not paper quality.
|
||||
maxRelErr := 0.45
|
||||
if preset.KeyPrimaryBits >= 3 {
|
||||
maxRelErr = 0.25
|
||||
}
|
||||
if avgRelErr > maxRelErr {
|
||||
t.Fatalf("avg relative dot error = %.4f, want <= %.4f", avgRelErr, maxRelErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodeKeyPerHeadRoundTripLargeDims checks CPU reference round-trip
|
||||
// quality at the head dims used by models outside the llama/qwen D=128 norm.
|
||||
// These are regression targets for the non-128 code path: gemma3 global uses
|
||||
// D=256, gemma4 global uses D=512. If any sub-test passes but the CUDA path
|
||||
// produces different output at the same D, the bug is kernel-side; if a
|
||||
// sub-test fails, the core TQ math (rotation + codebook) has a D-dependent
|
||||
// bug.
|
||||
func TestEncodeKeyPerHeadRoundTripLargeDims(t *testing.T) {
|
||||
const trials = 200
|
||||
|
||||
for _, dim := range []int{256, 512} {
|
||||
for _, preset := range []Preset{PresetTQ2, PresetTQ3} {
|
||||
t.Run(fmt.Sprintf("dim%d_%s", dim, preset.Name), func(t *testing.T) {
|
||||
rng := splitmix64(0xbeefcafe ^ uint64(dim))
|
||||
var totalMSE, totalAbsErr, totalAbsDot float64
|
||||
|
||||
for range trials {
|
||||
values := make([]float32, dim)
|
||||
query := make([]float32, dim)
|
||||
for j := range values {
|
||||
values[j] = float32(gaussianFloat64(&rng))
|
||||
query[j] = float32(gaussianFloat64(&rng))
|
||||
}
|
||||
|
||||
packed, scale, err := EncodeKeyPerHead(values, preset)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dequantRot := DequantKeyPerHead(packed, scale, dim, preset.KeyPrimaryBits)
|
||||
|
||||
rotation := BuildRotation(dim, preset.RotationSeed)
|
||||
valuesRot := ApplyRotation(values, rotation)
|
||||
var mse float64
|
||||
for j := range valuesRot {
|
||||
d := float64(valuesRot[j] - dequantRot[j])
|
||||
mse += d * d
|
||||
}
|
||||
mse /= float64(dim)
|
||||
totalMSE += mse
|
||||
|
||||
queryRot := ApplyRotation(query, rotation)
|
||||
var estDot, trueDot float32
|
||||
for j := range queryRot {
|
||||
estDot += queryRot[j] * dequantRot[j]
|
||||
}
|
||||
for j := range query {
|
||||
trueDot += query[j] * values[j]
|
||||
}
|
||||
totalAbsErr += math.Abs(float64(estDot - trueDot))
|
||||
totalAbsDot += math.Abs(float64(trueDot))
|
||||
}
|
||||
|
||||
avgMSE := totalMSE / trials
|
||||
avgRelErr := totalAbsErr / (totalAbsDot + 1e-8)
|
||||
t.Logf("D=%d %s: avg MSE = %.6f, avg rel dot error = %.4f", dim, preset.Name, avgMSE, avgRelErr)
|
||||
|
||||
maxRelErr := 0.45
|
||||
if preset.KeyPrimaryBits >= 3 {
|
||||
maxRelErr = 0.25
|
||||
}
|
||||
if avgRelErr > maxRelErr {
|
||||
t.Fatalf("D=%d %s: avg relative dot error = %.4f, want <= %.4f", dim, preset.Name, avgRelErr, maxRelErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodeKeyPerHeadWithDCOffset checks whether TQ handles K vectors with
|
||||
// a DC bias component. Models like qwen2/qwen2.5 have learned Q/K/V bias
|
||||
// tensors; their K vectors have a non-zero mean that TQ's RMS-based scale
|
||||
// normalization doesn't center out. This test adds a DC offset to random
|
||||
// Gaussian samples and measures round-trip quality vs the non-biased case.
|
||||
func TestEncodeKeyPerHeadWithDCOffset(t *testing.T) {
|
||||
const dim = 128
|
||||
const trials = 200
|
||||
|
||||
offsets := []float32{0.0, 0.3, 1.0, 3.0}
|
||||
for _, offset := range offsets {
|
||||
for _, preset := range []Preset{PresetTQ3} {
|
||||
t.Run(fmt.Sprintf("offset%.1f_%s", offset, preset.Name), func(t *testing.T) {
|
||||
rng := splitmix64(0xbeefcafe)
|
||||
var totalAbsErr, totalAbsDot float64
|
||||
|
||||
for range trials {
|
||||
values := make([]float32, dim)
|
||||
query := make([]float32, dim)
|
||||
for j := range values {
|
||||
values[j] = float32(gaussianFloat64(&rng)) + offset
|
||||
query[j] = float32(gaussianFloat64(&rng))
|
||||
}
|
||||
|
||||
packed, scale, err := EncodeKeyPerHead(values, preset)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dequantRot := DequantKeyPerHead(packed, scale, dim, preset.KeyPrimaryBits)
|
||||
|
||||
rotation := BuildRotation(dim, preset.RotationSeed)
|
||||
queryRot := ApplyRotation(query, rotation)
|
||||
var estDot, trueDot float32
|
||||
for j := range queryRot {
|
||||
estDot += queryRot[j] * dequantRot[j]
|
||||
}
|
||||
for j := range query {
|
||||
trueDot += query[j] * values[j]
|
||||
}
|
||||
totalAbsErr += math.Abs(float64(estDot - trueDot))
|
||||
totalAbsDot += math.Abs(float64(trueDot))
|
||||
}
|
||||
|
||||
avgRelErr := totalAbsErr / (totalAbsDot + 1e-8)
|
||||
t.Logf("offset=%.1f %s: avg rel dot error = %.4f", offset, preset.Name, avgRelErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeKeyPerHeadZeroVector(t *testing.T) {
|
||||
values := make([]float32, 64)
|
||||
packed, scale, err := EncodeKeyPerHead(values, PresetTQ3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if scale != 0 {
|
||||
t.Fatalf("expected zero scale for zero vector, got %v", scale)
|
||||
}
|
||||
dequant := DequantKeyPerHead(packed, scale, 64, PresetTQ3.KeyPrimaryBits)
|
||||
for i, v := range dequant {
|
||||
if v != 0 {
|
||||
t.Fatalf("dequant[%d] = %v, want 0", i, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeKeyPerHeadDimSizes(t *testing.T) {
|
||||
for _, dim := range []int{32, 64, 96, 128, 256} {
|
||||
t.Run(fmt.Sprintf("dim%d", dim), func(t *testing.T) {
|
||||
values := pseudoRandomVector(dim, uint64(dim))
|
||||
packed, scale, err := EncodeKeyPerHead(values, PresetTQ3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if scale <= 0 {
|
||||
t.Fatalf("expected positive scale for dim=%d", dim)
|
||||
}
|
||||
expectedBytes := (dim*PresetTQ3.KeyPrimaryBits + 7) / 8
|
||||
if len(packed) != expectedBytes {
|
||||
t.Fatalf("packed len = %d, want %d for dim=%d", len(packed), expectedBytes, dim)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeKeyPerHeadDeterministic(t *testing.T) {
|
||||
for _, preset := range []Preset{PresetTQ2, PresetTQ3} {
|
||||
t.Run(preset.Name, func(t *testing.T) {
|
||||
vec := pseudoRandomVector(128, 0xcafe)
|
||||
packed1, scale1, err := EncodeKeyPerHead(vec, preset)
|
||||
if err != nil {
|
||||
t.Fatalf("first EncodeKeyPerHead: %v", err)
|
||||
}
|
||||
packed2, scale2, err := EncodeKeyPerHead(vec, preset)
|
||||
if err != nil {
|
||||
t.Fatalf("second EncodeKeyPerHead: %v", err)
|
||||
}
|
||||
if scale1 != scale2 {
|
||||
t.Fatalf("scale not deterministic: %f vs %f", scale1, scale2)
|
||||
}
|
||||
for i := range packed1 {
|
||||
if packed1[i] != packed2[i] {
|
||||
t.Fatalf("packed byte %d differs: %d vs %d", i, packed1[i], packed2[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// studentT3Float64 samples from Student-t with df=3 via Z / sqrt(ChiSq3/3).
|
||||
// Student-t df=3 has finite mean/variance but heavy tails (kurtosis → ∞),
|
||||
// which stresses quantizers that assume Gaussian inputs.
|
||||
func studentT3Float64(rng *splitmix64) float64 {
|
||||
z := gaussianFloat64(rng)
|
||||
// ChiSq(3) = sum of 3 squared standard normals.
|
||||
var chisq float64
|
||||
for range 3 {
|
||||
n := gaussianFloat64(rng)
|
||||
chisq += n * n
|
||||
}
|
||||
return z / math.Sqrt(chisq/3.0)
|
||||
}
|
||||
|
||||
// TestOutlierSplitVsUniformHeavyTailed verifies that the CPU reference
|
||||
// outlier-split encoder beats the uniform encoder on heavy-tailed K
|
||||
// vectors (Student-t df=3). This is the statistical regression test
|
||||
// for the outlier-split algorithm: if someone breaks the split logic
|
||||
// (e.g. misattributes channels to sub-blocks, uses the wrong scale),
|
||||
// this test starts failing because outlier no longer helps.
|
||||
//
|
||||
// Acceptance: on Student-t df=3 inputs, outlier-split relative dot
|
||||
// error must be <= 1.5x the uniform relative dot error. In practice
|
||||
// outlier-split is substantially BETTER on heavy tails — the 1.5x
|
||||
// bound is a floor to catch "outlier split broken" regressions, not
|
||||
// the expected performance.
|
||||
func TestOutlierSplitVsUniformHeavyTailed(t *testing.T) {
|
||||
const dim = 128
|
||||
const trials = 200
|
||||
|
||||
presets := []Preset{
|
||||
testOutlierPreset(PresetTQ3, 32),
|
||||
testOutlierPreset(PresetTQ3K, 32),
|
||||
}
|
||||
for _, preset := range presets {
|
||||
t.Run(preset.Name, func(t *testing.T) {
|
||||
rng := splitmix64(0xc0de_babe_dead_beef)
|
||||
var uniformAbsErr, uniformAbsDot float64
|
||||
var outlierAbsErr, outlierAbsDot float64
|
||||
|
||||
for trial := range trials {
|
||||
values := make([]float32, dim)
|
||||
query := make([]float32, dim)
|
||||
for j := range values {
|
||||
values[j] = float32(studentT3Float64(&rng))
|
||||
query[j] = float32(gaussianFloat64(&rng))
|
||||
}
|
||||
|
||||
rotation := BuildRotation(dim, preset.RotationSeed)
|
||||
valuesRot := ApplyRotation(values, rotation)
|
||||
queryRot := ApplyRotation(query, rotation)
|
||||
|
||||
var trueDot float32
|
||||
for j := range query {
|
||||
trueDot += query[j] * values[j]
|
||||
}
|
||||
|
||||
// Uniform path.
|
||||
uPacked, uScale, err := EncodeKeyPerHead(values, preset)
|
||||
if err != nil {
|
||||
t.Fatalf("uniform encode: %v", err)
|
||||
}
|
||||
uDeq := DequantKeyPerHead(uPacked, uScale, dim, preset.KeyPrimaryBits)
|
||||
var uEstDot float32
|
||||
for j := range uDeq {
|
||||
uEstDot += queryRot[j] * uDeq[j]
|
||||
}
|
||||
uniformAbsErr += math.Abs(float64(uEstDot - trueDot))
|
||||
uniformAbsDot += math.Abs(float64(trueDot))
|
||||
|
||||
// Outlier-split path.
|
||||
oEnc, err := EncodeKeyPerHeadOutlier(values, preset)
|
||||
if err != nil {
|
||||
t.Fatalf("outlier encode: %v", err)
|
||||
}
|
||||
oDeq := DequantKeyPerHeadOutlier(oEnc, preset, dim)
|
||||
var oEstDot float32
|
||||
for j := range oDeq {
|
||||
oEstDot += queryRot[j] * oDeq[j]
|
||||
}
|
||||
outlierAbsErr += math.Abs(float64(oEstDot - trueDot))
|
||||
outlierAbsDot += math.Abs(float64(trueDot))
|
||||
|
||||
// Sanity: the outlier-split decode should still be close
|
||||
// enough to the rotated input that a round-trip MSE is
|
||||
// sensible. Catches frame-shifted reconstructions where
|
||||
// the dot happens to land right but the vector is wrong.
|
||||
var mse float64
|
||||
for j := range oDeq {
|
||||
d := float64(oDeq[j] - valuesRot[j])
|
||||
mse += d * d
|
||||
}
|
||||
mse /= float64(dim)
|
||||
if mse > 4.0 {
|
||||
t.Fatalf("trial %d: outlier-split rotated-space MSE = %.4f, reconstruction is garbage", trial, mse)
|
||||
}
|
||||
}
|
||||
|
||||
uniformRelErr := uniformAbsErr / (uniformAbsDot + 1e-8)
|
||||
outlierRelErr := outlierAbsErr / (outlierAbsDot + 1e-8)
|
||||
t.Logf("%s Student-t df=3: uniform rel-dot-err=%.4f, outlier rel-dot-err=%.4f, ratio=%.3f",
|
||||
preset.Name, uniformRelErr, outlierRelErr, outlierRelErr/uniformRelErr)
|
||||
|
||||
// Outlier-split must not be worse than 1.5x uniform. In
|
||||
// practice it should be substantially better.
|
||||
if outlierRelErr > 1.5*uniformRelErr {
|
||||
t.Fatalf("%s: outlier rel-dot-err %.4f > 1.5x uniform %.4f — outlier split broken",
|
||||
preset.Name, outlierRelErr, uniformRelErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestOutlierPerHeadRoundTrip pins the CPU reference bit-exact round-trip
|
||||
// at the shapes actually used by GPU kernels (headDim=128, 256). Encode
|
||||
// → dequant → compare against ApplyRotation(values, rotation). This
|
||||
// catches algorithmic drift (wrong codebook selection, off-by-one in
|
||||
// slot mapping, index packing errors) even when the GPU path is broken.
|
||||
func TestOutlierPerHeadRoundTrip(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
dim int
|
||||
preset Preset
|
||||
}{
|
||||
{"d128_tq3", 128, testOutlierPreset(PresetTQ3, 32)},
|
||||
{"d128_tq3k", 128, testOutlierPreset(PresetTQ3K, 32)},
|
||||
{"d256_tq3k", 256, testOutlierPreset(PresetTQ3K, 32)}, // gemma3:1b global-layer headDim
|
||||
{"d128_tq2k", 128, testOutlierPreset(PresetTQ2K, 32)},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
const trials = 50
|
||||
rng := splitmix64(0xfeed_face_0000_0000 | uint64(tc.dim))
|
||||
var totalMSE float64
|
||||
|
||||
for range trials {
|
||||
values := make([]float32, tc.dim)
|
||||
for j := range values {
|
||||
values[j] = float32(gaussianFloat64(&rng))
|
||||
}
|
||||
|
||||
enc, err := EncodeKeyPerHeadOutlier(values, tc.preset)
|
||||
if err != nil {
|
||||
t.Fatalf("encode: %v", err)
|
||||
}
|
||||
if len(enc.OutlierIndices) != tc.preset.OutlierCount {
|
||||
t.Fatalf("outlier count = %d, want %d", len(enc.OutlierIndices), tc.preset.OutlierCount)
|
||||
}
|
||||
expRegularBytes := ((tc.dim-tc.preset.OutlierCount)*tc.preset.KeyPrimaryBits + 7) / 8
|
||||
if len(enc.RegularPacked) != expRegularBytes {
|
||||
t.Fatalf("regular packed len = %d, want %d", len(enc.RegularPacked), expRegularBytes)
|
||||
}
|
||||
expOutlierBytes := (tc.preset.OutlierCount*tc.preset.OutlierBits + 7) / 8
|
||||
if len(enc.OutlierPacked) != expOutlierBytes {
|
||||
t.Fatalf("outlier packed len = %d, want %d", len(enc.OutlierPacked), expOutlierBytes)
|
||||
}
|
||||
|
||||
dec := DequantKeyPerHeadOutlier(enc, tc.preset, tc.dim)
|
||||
rotation := BuildRotation(tc.dim, tc.preset.RotationSeed)
|
||||
valuesRot := ApplyRotation(values, rotation)
|
||||
|
||||
var mse float64
|
||||
for j := range valuesRot {
|
||||
d := float64(valuesRot[j] - dec[j])
|
||||
mse += d * d
|
||||
}
|
||||
mse /= float64(tc.dim)
|
||||
totalMSE += mse
|
||||
}
|
||||
|
||||
avgMSE := totalMSE / trials
|
||||
t.Logf("%s: avg rotated-space round-trip MSE = %.4f", tc.name, avgMSE)
|
||||
// Lloyd-Max 3-bit on unit-variance Gaussian has MSE ~0.04 per
|
||||
// channel in rotated space. 0.15 is a loose bound that still
|
||||
// catches format bugs.
|
||||
if avgMSE > 0.15 {
|
||||
t.Fatalf("%s: round-trip MSE %.4f too large — encoder/dequant mismatch", tc.name, avgMSE)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
88
turboquant/outlier.go
Normal file
88
turboquant/outlier.go
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
package turboquant
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"slices"
|
||||
)
|
||||
|
||||
// OutlierSplit partitions a vector's channels into outlier and regular sets.
|
||||
// OutlierIndices and RegularIndices are sorted in ascending order.
|
||||
// OutlierValues and RegularValues contain the corresponding channel values.
|
||||
type OutlierSplit struct {
|
||||
OutlierIndices []uint16
|
||||
OutlierValues []float32
|
||||
RegularIndices []uint16
|
||||
RegularValues []float32
|
||||
}
|
||||
|
||||
// SplitOutlierChannels identifies the top-outlierCount channels by absolute
|
||||
// magnitude and returns them separately from the remaining regular channels.
|
||||
// Tie-breaking is by index ascending for determinism.
|
||||
// If outlierCount <= 0, all channels are returned as regular (empty outlier).
|
||||
// If outlierCount >= len(values), all channels are returned as outlier.
|
||||
func SplitOutlierChannels(values []float32, outlierCount int) OutlierSplit {
|
||||
dim := len(values)
|
||||
if outlierCount <= 0 || dim == 0 {
|
||||
indices := make([]uint16, dim)
|
||||
vals := make([]float32, dim)
|
||||
for i := range values {
|
||||
indices[i] = uint16(i)
|
||||
vals[i] = values[i]
|
||||
}
|
||||
return OutlierSplit{RegularIndices: indices, RegularValues: vals}
|
||||
}
|
||||
if outlierCount >= dim {
|
||||
indices := make([]uint16, dim)
|
||||
vals := make([]float32, dim)
|
||||
for i := range values {
|
||||
indices[i] = uint16(i)
|
||||
vals[i] = values[i]
|
||||
}
|
||||
return OutlierSplit{OutlierIndices: indices, OutlierValues: vals}
|
||||
}
|
||||
|
||||
// Sort channel indices by abs value descending, tie-break by index ascending.
|
||||
order := make([]int, dim)
|
||||
for i := range order {
|
||||
order[i] = i
|
||||
}
|
||||
slices.SortStableFunc(order, func(a, b int) int {
|
||||
va := abs32(values[a])
|
||||
vb := abs32(values[b])
|
||||
if va != vb {
|
||||
return cmp.Compare(vb, va)
|
||||
}
|
||||
return cmp.Compare(a, b)
|
||||
})
|
||||
|
||||
outlierSet := make([]int, outlierCount)
|
||||
regularSet := make([]int, dim-outlierCount)
|
||||
copy(outlierSet, order[:outlierCount])
|
||||
copy(regularSet, order[outlierCount:])
|
||||
|
||||
// Each set is sorted by index ascending so ChannelIndices is deterministic.
|
||||
slices.Sort(outlierSet)
|
||||
slices.Sort(regularSet)
|
||||
|
||||
outlierIndices := make([]uint16, outlierCount)
|
||||
outlierValues := make([]float32, outlierCount)
|
||||
for i, idx := range outlierSet {
|
||||
outlierIndices[i] = uint16(idx)
|
||||
outlierValues[i] = values[idx]
|
||||
}
|
||||
|
||||
regularCount := dim - outlierCount
|
||||
regularIndices := make([]uint16, regularCount)
|
||||
regularValues := make([]float32, regularCount)
|
||||
for i, idx := range regularSet {
|
||||
regularIndices[i] = uint16(idx)
|
||||
regularValues[i] = values[idx]
|
||||
}
|
||||
|
||||
return OutlierSplit{
|
||||
OutlierIndices: outlierIndices,
|
||||
OutlierValues: outlierValues,
|
||||
RegularIndices: regularIndices,
|
||||
RegularValues: regularValues,
|
||||
}
|
||||
}
|
||||
133
turboquant/outlier_test.go
Normal file
133
turboquant/outlier_test.go
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
package turboquant
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSplitOutlierChannelsBasic(t *testing.T) {
|
||||
values := []float32{0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6}
|
||||
split := SplitOutlierChannels(values, 3)
|
||||
|
||||
// Top-3 by abs magnitude: indices 1(0.9), 3(0.8), 5(0.7)
|
||||
wantOutlierIdx := []uint16{1, 3, 5}
|
||||
wantRegularIdx := []uint16{0, 2, 4, 6, 7}
|
||||
|
||||
if len(split.OutlierIndices) != len(wantOutlierIdx) {
|
||||
t.Fatalf("outlier count: got %d want %d", len(split.OutlierIndices), len(wantOutlierIdx))
|
||||
}
|
||||
for i, idx := range wantOutlierIdx {
|
||||
if split.OutlierIndices[i] != idx {
|
||||
t.Errorf("outlier[%d]: got %d want %d", i, split.OutlierIndices[i], idx)
|
||||
}
|
||||
if split.OutlierValues[i] != values[idx] {
|
||||
t.Errorf("outlierVal[%d]: got %v want %v", i, split.OutlierValues[i], values[idx])
|
||||
}
|
||||
}
|
||||
|
||||
if len(split.RegularIndices) != len(wantRegularIdx) {
|
||||
t.Fatalf("regular count: got %d want %d", len(split.RegularIndices), len(wantRegularIdx))
|
||||
}
|
||||
for i, idx := range wantRegularIdx {
|
||||
if split.RegularIndices[i] != idx {
|
||||
t.Errorf("regular[%d]: got %d want %d", i, split.RegularIndices[i], idx)
|
||||
}
|
||||
if split.RegularValues[i] != values[idx] {
|
||||
t.Errorf("regularVal[%d]: got %v want %v", i, split.RegularValues[i], values[idx])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitOutlierChannelsCoversAll(t *testing.T) {
|
||||
// Every channel index must appear exactly once across outlier + regular.
|
||||
values := []float32{3, 1, 4, 1, 5, 9, 2, 6}
|
||||
for outlierCount := 0; outlierCount <= len(values); outlierCount++ {
|
||||
split := SplitOutlierChannels(values, outlierCount)
|
||||
seen := make(map[uint16]int)
|
||||
for _, idx := range split.OutlierIndices {
|
||||
seen[idx]++
|
||||
}
|
||||
for _, idx := range split.RegularIndices {
|
||||
seen[idx]++
|
||||
}
|
||||
for i := range values {
|
||||
if seen[uint16(i)] != 1 {
|
||||
t.Errorf("outlierCount=%d: channel %d appears %d times", outlierCount, i, seen[uint16(i)])
|
||||
}
|
||||
}
|
||||
if len(split.OutlierIndices)+len(split.RegularIndices) != len(values) {
|
||||
t.Errorf("outlierCount=%d: total channels %d != %d", outlierCount, len(split.OutlierIndices)+len(split.RegularIndices), len(values))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitOutlierChannelsZeroOutliers(t *testing.T) {
|
||||
values := []float32{1, 2, 3, 4}
|
||||
split := SplitOutlierChannels(values, 0)
|
||||
if len(split.OutlierIndices) != 0 {
|
||||
t.Errorf("expected 0 outliers, got %d", len(split.OutlierIndices))
|
||||
}
|
||||
if len(split.RegularIndices) != 4 {
|
||||
t.Errorf("expected 4 regular, got %d", len(split.RegularIndices))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitOutlierChannelsAllOutliers(t *testing.T) {
|
||||
values := []float32{1, 2, 3, 4}
|
||||
split := SplitOutlierChannels(values, len(values))
|
||||
if len(split.OutlierIndices) != 4 {
|
||||
t.Errorf("expected 4 outliers, got %d", len(split.OutlierIndices))
|
||||
}
|
||||
if len(split.RegularIndices) != 0 {
|
||||
t.Errorf("expected 0 regular, got %d", len(split.RegularIndices))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitOutlierChannelsNegativeValues(t *testing.T) {
|
||||
// Negative values: magnitude matters, not sign.
|
||||
values := []float32{-5, 1, -3, 2}
|
||||
split := SplitOutlierChannels(values, 2)
|
||||
// Top-2 by abs: index 0 (5.0), index 2 (3.0)
|
||||
if split.OutlierIndices[0] != 0 || split.OutlierIndices[1] != 2 {
|
||||
t.Errorf("expected outlier indices [0,2], got %v", split.OutlierIndices)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitOutlierChannelsTieBreakerByIndex(t *testing.T) {
|
||||
// Equal magnitudes: lower index wins as outlier (tie-break index ascending).
|
||||
values := []float32{1, 1, 1, 1}
|
||||
split := SplitOutlierChannels(values, 2)
|
||||
// Should pick indices 0 and 1 as outliers.
|
||||
if split.OutlierIndices[0] != 0 || split.OutlierIndices[1] != 1 {
|
||||
t.Errorf("expected outlier indices [0,1], got %v", split.OutlierIndices)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitOutlierChannelsIndicesSorted(t *testing.T) {
|
||||
// Output indices must be sorted ascending in both slices.
|
||||
values := pseudoRandomVector(16, 42)
|
||||
split := SplitOutlierChannels(values, 5)
|
||||
for i := 1; i < len(split.OutlierIndices); i++ {
|
||||
if split.OutlierIndices[i] <= split.OutlierIndices[i-1] {
|
||||
t.Errorf("outlier indices not sorted at %d: %d <= %d", i, split.OutlierIndices[i], split.OutlierIndices[i-1])
|
||||
}
|
||||
}
|
||||
for i := 1; i < len(split.RegularIndices); i++ {
|
||||
if split.RegularIndices[i] <= split.RegularIndices[i-1] {
|
||||
t.Errorf("regular indices not sorted at %d: %d <= %d", i, split.RegularIndices[i], split.RegularIndices[i-1])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitOutlierChannelsDeterministic(t *testing.T) {
|
||||
values := pseudoRandomVector(32, 99)
|
||||
a := SplitOutlierChannels(values, 8)
|
||||
b := SplitOutlierChannels(values, 8)
|
||||
if len(a.OutlierIndices) != len(b.OutlierIndices) {
|
||||
t.Fatal("non-deterministic outlier count")
|
||||
}
|
||||
for i := range a.OutlierIndices {
|
||||
if a.OutlierIndices[i] != b.OutlierIndices[i] {
|
||||
t.Fatalf("non-deterministic at outlier index %d", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
158
turboquant/residual_qjl.go
Normal file
158
turboquant/residual_qjl.go
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
package turboquant
|
||||
|
||||
import "math"
|
||||
|
||||
const qjlUnbiasScale = 1.2533141373155001 // sqrt(pi / 2)
|
||||
|
||||
type ResidualSketch struct {
|
||||
Seed uint64
|
||||
Scale float32 // residual L2 norm; retained name keeps the old struct shape stable
|
||||
SketchDim uint16
|
||||
Signs []byte
|
||||
}
|
||||
|
||||
func encodeResidual(rotated, approx []float32, sketchSpec any, seed uint64) ResidualSketch {
|
||||
var sketchRows int
|
||||
switch v := sketchSpec.(type) {
|
||||
case int:
|
||||
sketchRows = v
|
||||
case Preset:
|
||||
sketchRows = v.KeyQJLRows(len(rotated))
|
||||
default:
|
||||
return ResidualSketch{}
|
||||
}
|
||||
if sketchRows <= 0 || len(rotated) == 0 {
|
||||
return ResidualSketch{}
|
||||
}
|
||||
|
||||
residual := make([]float32, len(rotated))
|
||||
var l2 float64
|
||||
for i := range rotated {
|
||||
delta := rotated[i] - approx[i]
|
||||
residual[i] = delta
|
||||
l2 += float64(delta * delta)
|
||||
}
|
||||
|
||||
if l2 == 0 {
|
||||
return ResidualSketch{
|
||||
Seed: seed,
|
||||
SketchDim: uint16(sketchRows),
|
||||
Signs: make([]byte, expectedPackedBytes(sketchRows, 1)),
|
||||
}
|
||||
}
|
||||
|
||||
signBits := make([]uint8, sketchRows)
|
||||
for row := range sketchRows {
|
||||
if gaussianProjectionDot(residual, seed, row) >= 0 {
|
||||
signBits[row] = 1
|
||||
}
|
||||
}
|
||||
|
||||
return ResidualSketch{
|
||||
Seed: seed,
|
||||
Scale: float32(math.Sqrt(l2)),
|
||||
SketchDim: uint16(sketchRows),
|
||||
Signs: packBits(signBits, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func reconstructResidual(dim int, sketch ResidualSketch) []float32 {
|
||||
out := make([]float32, dim)
|
||||
if dim == 0 || sketch.SketchDim == 0 || sketch.Scale == 0 {
|
||||
return out
|
||||
}
|
||||
|
||||
signBits := unpackBits(sketch.Signs, 1, int(sketch.SketchDim))
|
||||
scale := float32(qjlUnbiasScale) * sketch.Scale / float32(sketch.SketchDim)
|
||||
for row, bit := range signBits {
|
||||
sign := float32(-1)
|
||||
if bit == 1 {
|
||||
sign = 1
|
||||
}
|
||||
// Residual reconstruction stays on float32 accumulation today; if a backend-specific half path is introduced, it needs an explicit FP32-accumulate audit before rollout.
|
||||
for col := range dim {
|
||||
out[col] += sign * gaussianProjectionEntry(sketch.Seed, row, col) * scale
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func residualDotCorrection(queryRot []float32, sketch ResidualSketch) float32 {
|
||||
if len(queryRot) == 0 || sketch.SketchDim == 0 || sketch.Scale == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
signBits := unpackBits(sketch.Signs, 1, int(sketch.SketchDim))
|
||||
var total float32
|
||||
for row, bit := range signBits {
|
||||
sign := float32(-1)
|
||||
if bit == 1 {
|
||||
sign = 1
|
||||
}
|
||||
total += sign * gaussianProjectionDot(queryRot, sketch.Seed, row)
|
||||
}
|
||||
|
||||
correction := float32(qjlUnbiasScale) * sketch.Scale * (total / float32(sketch.SketchDim))
|
||||
queryNorm := float32(math.Sqrt(float64(dotSelf(queryRot))))
|
||||
if queryNorm == 0 {
|
||||
return 0
|
||||
}
|
||||
maxCorrection := sketch.Scale * queryNorm
|
||||
if correction > maxCorrection {
|
||||
return maxCorrection
|
||||
}
|
||||
if correction < -maxCorrection {
|
||||
return -maxCorrection
|
||||
}
|
||||
if sketch.Scale < 1e-6 {
|
||||
return 0
|
||||
}
|
||||
return correction
|
||||
}
|
||||
|
||||
// PrecomputeCorrectionVec builds the vector w = (√π/2 · residualNorm / sketchDim) · Σ_j sign_j · G_j,
|
||||
// where G_j is row j of the random Gaussian projection matrix and sign_j is the stored QJL sign bit.
|
||||
// Scoring then reduces to dot(queryRotated, w), replacing the per-query O(dim²) Gaussian projection
|
||||
// loop with a single O(dim) dot product.
|
||||
//
|
||||
// The returned slice has length dim and is nil when the sketch carries no correction (Scale==0).
|
||||
func PrecomputeCorrectionVec(sketch ResidualSketch, dim int) []float32 {
|
||||
if sketch.SketchDim == 0 || sketch.Scale == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]float32, dim)
|
||||
signBits := unpackBits(sketch.Signs, 1, int(sketch.SketchDim))
|
||||
scale := float32(qjlUnbiasScale) * sketch.Scale / float32(sketch.SketchDim)
|
||||
for row, bit := range signBits {
|
||||
sign := float32(-1)
|
||||
if bit == 1 {
|
||||
sign = 1
|
||||
}
|
||||
sv := sign * scale
|
||||
for col := range dim {
|
||||
out[col] += sv * gaussianProjectionEntry(sketch.Seed, row, col)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func gaussianProjectionDot(values []float32, seed uint64, row int) float32 {
|
||||
var out float32
|
||||
for col, value := range values {
|
||||
out += value * gaussianProjectionEntry(seed, row, col)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func gaussianProjectionEntry(seed uint64, row, col int) float32 {
|
||||
local := splitmix64(seed ^ uint64(row+1)*0x9e3779b97f4a7c15 ^ uint64(col+1)*0xbf58476d1ce4e5b9)
|
||||
return float32(gaussianFloat64(&local))
|
||||
}
|
||||
|
||||
func dotSelf(values []float32) float32 {
|
||||
var out float32
|
||||
for _, value := range values {
|
||||
out += value * value
|
||||
}
|
||||
return out
|
||||
}
|
||||
58
turboquant/residual_qjl_test.go
Normal file
58
turboquant/residual_qjl_test.go
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
package turboquant
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestResidualSketchDeterministic(t *testing.T) {
|
||||
a := encodeResidual([]float32{1, 2, 3, 4}, []float32{0, 0, 0, 0}, PresetTQ2, 99)
|
||||
b := encodeResidual([]float32{1, 2, 3, 4}, []float32{0, 0, 0, 0}, PresetTQ2, 99)
|
||||
if a.Scale != b.Scale || a.Seed != b.Seed || string(a.Signs) != string(b.Signs) {
|
||||
t.Fatal("residual sketch is not deterministic")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconstructResidualLength(t *testing.T) {
|
||||
sketch := encodeResidual(
|
||||
[]float32{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
[]float32{0, 0, 0, 0, 0, 0, 0, 0},
|
||||
PresetTQ3,
|
||||
123,
|
||||
)
|
||||
reconstructed := reconstructResidual(8, sketch)
|
||||
if len(reconstructed) != 8 {
|
||||
t.Fatalf("reconstructed length = %d, want 8", len(reconstructed))
|
||||
}
|
||||
if len(sketch.Signs) != expectedPackedBytes(int(sketch.SketchDim), 1) {
|
||||
t.Fatalf("sign sketch bytes = %d, want %d", len(sketch.Signs), expectedPackedBytes(int(sketch.SketchDim), 1))
|
||||
}
|
||||
}
|
||||
|
||||
func TestZeroResidualProducesZeroReconstruction(t *testing.T) {
|
||||
sketch := encodeResidual(
|
||||
[]float32{1, 1, 1, 1},
|
||||
[]float32{1, 1, 1, 1},
|
||||
PresetTQ3,
|
||||
456,
|
||||
)
|
||||
reconstructed := reconstructResidual(4, sketch)
|
||||
for i, value := range reconstructed {
|
||||
if abs32(value) > 1e-6 {
|
||||
t.Fatalf("reconstructed[%d] = %v, want 0", i, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResidualDotCorrectionDeterministicAndFinite(t *testing.T) {
|
||||
rotated := []float32{1.5, -2.5, 0.5, 3.0, -1.25, 2.25, 0.75, -0.5}
|
||||
approx := []float32{1.0, -2.0, 0.25, 2.5, -1.0, 2.0, 0.5, -0.25}
|
||||
sketch := encodeResidual(rotated, approx, PresetTQ3, 789)
|
||||
query := []float32{0.2, -0.1, 0.3, 0.4, -0.2, 0.5, -0.6, 0.7}
|
||||
|
||||
a := residualDotCorrection(query, sketch)
|
||||
b := residualDotCorrection(query, sketch)
|
||||
if a != b {
|
||||
t.Fatalf("dot correction = %v and %v, want deterministic output", a, b)
|
||||
}
|
||||
if a != a {
|
||||
t.Fatal("dot correction returned NaN")
|
||||
}
|
||||
}
|
||||
211
turboquant/rotation.go
Normal file
211
turboquant/rotation.go
Normal file
|
|
@ -0,0 +1,211 @@
|
|||
package turboquant
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Rotation struct {
|
||||
Dim int
|
||||
Seed uint64
|
||||
Matrix []float32 // row-major orthogonal matrix
|
||||
}
|
||||
|
||||
type rotationCacheKey struct {
|
||||
dim int
|
||||
seed uint64
|
||||
}
|
||||
|
||||
var rotationCache sync.Map
|
||||
|
||||
func BuildRotation(dim int, seed uint64) Rotation {
|
||||
key := rotationCacheKey{dim: dim, seed: seed}
|
||||
if cached, ok := rotationCache.Load(key); ok {
|
||||
return cached.(Rotation)
|
||||
}
|
||||
|
||||
rot := Rotation{
|
||||
Dim: dim,
|
||||
Seed: seed,
|
||||
Matrix: buildOrthogonalMatrix(dim, seed),
|
||||
}
|
||||
actual, _ := rotationCache.LoadOrStore(key, rot)
|
||||
return actual.(Rotation)
|
||||
}
|
||||
|
||||
func ApplyRotation(x []float32, rot Rotation) []float32 {
|
||||
if len(x) != rot.Dim {
|
||||
panic("turboquant: vector length does not match rotation dimension")
|
||||
}
|
||||
|
||||
out := make([]float32, rot.Dim)
|
||||
for row := range rot.Dim {
|
||||
base := row * rot.Dim
|
||||
var sum float32
|
||||
for col, value := range x {
|
||||
sum += rot.Matrix[base+col] * value
|
||||
}
|
||||
out[row] = sum
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func ApplyInverseRotation(y []float32, rot Rotation) []float32 {
|
||||
if len(y) != rot.Dim {
|
||||
panic("turboquant: vector length does not match rotation dimension")
|
||||
}
|
||||
|
||||
out := make([]float32, rot.Dim)
|
||||
for row := range rot.Dim {
|
||||
yVal := y[row]
|
||||
base := row * rot.Dim
|
||||
// The inverse rotation accumulates in float32 on the reconstruction path; keep this FP32-accumulate behavior explicit while long-generation corruption audits remain active.
|
||||
for col := range rot.Dim {
|
||||
out[col] += rot.Matrix[base+col] * yVal
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// buildOrthogonalMatrix returns a dim×dim orthogonal matrix derived from the
|
||||
// given seed using Householder QR factorisation. The input is the same seeded
|
||||
// Gaussian matrix used by the previous Gram-Schmidt path, but the QR Q-factor
|
||||
// is numerically unconditionally orthogonal. This algorithm replaced the
|
||||
// classical Gram-Schmidt used through BlockVersion 3; the Householder path was
|
||||
// introduced at BlockVersion 4 (current: BlockVersion 6).
|
||||
//
|
||||
// Algorithm: apply dim-1 Householder reflectors H_1…H_{dim-1} from the left
|
||||
// to reduce A to upper triangular form R. Simultaneously accumulate
|
||||
// Q = H_1 * H_2 * … * H_{dim-1} by right-multiplying each reflector into Q,
|
||||
// starting from the identity. The resulting Q is the orthogonal factor in
|
||||
// A = QR and has orthonormal rows and columns.
|
||||
func buildOrthogonalMatrix(dim int, seed uint64) []float32 {
|
||||
if dim <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initialise A with the same seeded Gaussian rows as before.
|
||||
a := make([][]float64, dim)
|
||||
for row := range dim {
|
||||
a[row] = make([]float64, dim)
|
||||
rng := splitmix64(seed ^ uint64(dim)<<32 ^ uint64(row+1)*0x9e3779b97f4a7c15)
|
||||
for col := range dim {
|
||||
a[row][col] = gaussianFloat64(&rng)
|
||||
}
|
||||
}
|
||||
|
||||
// Q starts as the identity; we accumulate Q = H_1 * … * H_{dim-1}.
|
||||
q := make([][]float64, dim)
|
||||
for i := range q {
|
||||
q[i] = make([]float64, dim)
|
||||
q[i][i] = 1.0
|
||||
}
|
||||
|
||||
v := make([]float64, dim) // scratch buffer for the Householder vector
|
||||
|
||||
for k := range dim - 1 {
|
||||
n := dim - k
|
||||
|
||||
// Build Householder vector for column k, rows k..dim-1.
|
||||
// Sign chosen to avoid cancellation: sigma = −sign(a[k][k]) * ||x||.
|
||||
for i := range n {
|
||||
v[i] = a[k+i][k]
|
||||
}
|
||||
sigma := vectorNorm64(v[:n])
|
||||
if v[0] >= 0 {
|
||||
sigma = -sigma
|
||||
}
|
||||
v[0] -= sigma
|
||||
vnorm2 := dotFloat64(v[:n], v[:n])
|
||||
if vnorm2 < 1e-28 {
|
||||
continue // column already zeroed — no reflector needed
|
||||
}
|
||||
beta := 2.0 / vnorm2
|
||||
|
||||
// Apply H_k to a[k:, k:] from the left.
|
||||
for j := k; j < dim; j++ {
|
||||
var dot float64
|
||||
for i := range n {
|
||||
dot += v[i] * a[k+i][j]
|
||||
}
|
||||
dot *= beta
|
||||
for i := range n {
|
||||
a[k+i][j] -= dot * v[i]
|
||||
}
|
||||
}
|
||||
|
||||
// Apply H_k to q[:, k:] from the right: q ← q * H_k.
|
||||
for i := range dim {
|
||||
var dot float64
|
||||
for j := range n {
|
||||
dot += q[i][k+j] * v[j]
|
||||
}
|
||||
dot *= beta
|
||||
for j := range n {
|
||||
q[i][k+j] -= dot * v[j]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sign-normalise each row so the first non-negligible element is positive.
|
||||
// This convention matches the previous Gram-Schmidt path and makes the
|
||||
// output deterministic despite the reflector sign ambiguity.
|
||||
for row := range dim {
|
||||
for _, value := range q[row] {
|
||||
if math.Abs(value) <= 1e-12 {
|
||||
continue
|
||||
}
|
||||
if value < 0 {
|
||||
for col := range dim {
|
||||
q[row][col] = -q[row][col]
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
out := make([]float32, dim*dim)
|
||||
for row := range dim {
|
||||
for col := range dim {
|
||||
out[row*dim+col] = float32(q[row][col])
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func dotFloat64(a, b []float64) float64 {
|
||||
var out float64
|
||||
for i := range a {
|
||||
out += a[i] * b[i]
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func vectorNorm64(values []float64) float64 {
|
||||
var sum float64
|
||||
for _, value := range values {
|
||||
sum += value * value
|
||||
}
|
||||
return math.Sqrt(sum)
|
||||
}
|
||||
|
||||
func gaussianFloat64(rng *splitmix64) float64 {
|
||||
u1 := unitUniform64(rng)
|
||||
u2 := unitUniform64(rng)
|
||||
return math.Sqrt(-2*math.Log(u1)) * math.Cos(2*math.Pi*u2)
|
||||
}
|
||||
|
||||
func unitUniform64(rng *splitmix64) float64 {
|
||||
const scale = 1.0 / (1 << 53)
|
||||
return (float64(rng.next()>>11) + 0.5) * scale
|
||||
}
|
||||
|
||||
type splitmix64 uint64
|
||||
|
||||
func (s *splitmix64) next() uint64 {
|
||||
*s += 0x9e3779b97f4a7c15
|
||||
z := uint64(*s)
|
||||
z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9
|
||||
z = (z ^ (z >> 27)) * 0x94d049bb133111eb
|
||||
return z ^ (z >> 31)
|
||||
}
|
||||
189
turboquant/rotation_test.go
Normal file
189
turboquant/rotation_test.go
Normal file
|
|
@ -0,0 +1,189 @@
|
|||
package turboquant
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// testDims covers small, medium, and the primary production dims.
|
||||
var testDims = []int{4, 8, 16, 32, 64, 128}
|
||||
|
||||
func TestBuildRotationDeterministic(t *testing.T) {
|
||||
for _, dim := range testDims {
|
||||
a := BuildRotation(dim, 123)
|
||||
b := BuildRotation(dim, 123)
|
||||
if a.Dim != b.Dim || a.Seed != b.Seed || len(a.Matrix) != len(b.Matrix) {
|
||||
t.Fatalf("rotation metadata mismatch for dim %d", dim)
|
||||
}
|
||||
for i := range a.Matrix {
|
||||
if a.Matrix[i] != b.Matrix[i] {
|
||||
t.Fatalf("rotation mismatch at dim=%d idx=%d", dim, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRotationDifferentSeedsDiffer(t *testing.T) {
|
||||
a := BuildRotation(16, 111)
|
||||
b := BuildRotation(16, 222)
|
||||
different := false
|
||||
for i := range a.Matrix {
|
||||
if a.Matrix[i] != b.Matrix[i] {
|
||||
different = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !different {
|
||||
t.Fatal("different seeds produced the same orthogonal matrix")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyInverseRotation(t *testing.T) {
|
||||
for _, dim := range testDims {
|
||||
values := pseudoRandomVector(dim, uint64(dim)*17)
|
||||
rot := BuildRotation(dim, uint64(dim)*19)
|
||||
got := ApplyInverseRotation(ApplyRotation(values, rot), rot)
|
||||
for i := range values {
|
||||
if abs32(values[i]-got[i]) > 1e-4 {
|
||||
t.Fatalf("dim=%d idx=%d got=%v want=%v", dim, i, got[i], values[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRotationPreservesNorm(t *testing.T) {
|
||||
for _, dim := range testDims {
|
||||
values := pseudoRandomVector(dim, uint64(dim)*23)
|
||||
rot := BuildRotation(dim, uint64(dim)*29)
|
||||
before := vectorNorm(values)
|
||||
after := vectorNorm(ApplyRotation(values, rot))
|
||||
if abs32(before-after) > 1e-3 {
|
||||
t.Fatalf("dim=%d norm drift=%v", dim, abs32(before-after))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestAttentionScoreRotationInvariance verifies the exact mathematical
|
||||
// invariant: dot(Q,K) == dot(R@Q, R@K) for an orthogonal matrix R. Because R
|
||||
// is orthogonal, R^T R = I, so the inner product is preserved exactly (up to
|
||||
// floating-point rounding). This is the property that makes rotating K before
|
||||
// quantization safe when Q is also rotated at attention time.
|
||||
func TestAttentionScoreRotationInvariance(t *testing.T) {
|
||||
for _, dim := range []int{64, 128, 256} {
|
||||
t.Run(fmt.Sprintf("dim=%d", dim), func(t *testing.T) {
|
||||
seed := uint64(0x42c0ffee)
|
||||
q := pseudoRandomVector(dim, seed)
|
||||
k := pseudoRandomVector(dim, seed^0xbeef)
|
||||
|
||||
rot := BuildRotation(dim, seed+1)
|
||||
qRot := ApplyRotation(q, rot)
|
||||
kRot := ApplyRotation(k, rot)
|
||||
|
||||
var dotOrig float64
|
||||
for i := range q {
|
||||
dotOrig += float64(q[i]) * float64(k[i])
|
||||
}
|
||||
|
||||
var dotRot float64
|
||||
for i := range qRot {
|
||||
dotRot += float64(qRot[i]) * float64(kRot[i])
|
||||
}
|
||||
|
||||
relErr := math.Abs(dotOrig-dotRot) / (math.Abs(dotOrig) + 1e-10)
|
||||
if relErr > 1e-4 {
|
||||
t.Errorf("dim=%d: dot(Q,K)=%.6f dot(RQ,RK)=%.6f relErr=%.2e",
|
||||
dim, dotOrig, dotRot, relErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestQuantizedAttentionScorePreservation verifies the practical end-to-end
|
||||
// path: encode K per-head (which stores R@k), rotate Q (giving R@q), then
|
||||
// compute (R@q)·(R@k_quant) and compare against the true dot(Q,K). Single
|
||||
// trials can have high per-sample error from quantization noise, so we average
|
||||
// over many trials and check the mean relative error, matching the pattern used
|
||||
// by TestEncodeKeyPerHeadRoundTrip.
|
||||
func TestQuantizedAttentionScorePreservation(t *testing.T) {
|
||||
const dim = 128
|
||||
const trials = 100
|
||||
|
||||
for _, preset := range []Preset{PresetTQ2, PresetTQ3} {
|
||||
t.Run(preset.Name, func(t *testing.T) {
|
||||
rng := splitmix64(0x1111feed)
|
||||
rot := BuildRotation(dim, preset.RotationSeed)
|
||||
var totalAbsErr, totalAbsDot float64
|
||||
|
||||
for trial := range trials {
|
||||
q := make([]float32, dim)
|
||||
k := make([]float32, dim)
|
||||
for j := range q {
|
||||
q[j] = float32(gaussianFloat64(&rng))
|
||||
k[j] = float32(gaussianFloat64(&rng))
|
||||
}
|
||||
|
||||
// True attention score.
|
||||
var trueScore float64
|
||||
for i := range q {
|
||||
trueScore += float64(q[i]) * float64(k[i])
|
||||
}
|
||||
|
||||
// Quantized path: encode K in rotated space, rotate Q, compute score.
|
||||
packed, scale, err := EncodeKeyPerHead(k, preset)
|
||||
if err != nil {
|
||||
t.Fatalf("trial %d EncodeKeyPerHead: %v", trial, err)
|
||||
}
|
||||
kRecon := DequantKeyPerHead(packed, scale, dim, preset.KeyPrimaryBits)
|
||||
qRot := ApplyRotation(q, rot)
|
||||
|
||||
var quantScore float64
|
||||
for i := range qRot {
|
||||
quantScore += float64(qRot[i]) * float64(kRecon[i])
|
||||
}
|
||||
|
||||
totalAbsErr += math.Abs(trueScore - quantScore)
|
||||
totalAbsDot += math.Abs(trueScore)
|
||||
}
|
||||
|
||||
avgRelErr := totalAbsErr / (totalAbsDot + 1e-10)
|
||||
t.Logf("preset=%s avg relative dot error = %.4f over %d trials", preset.Name, avgRelErr, trials)
|
||||
|
||||
// tq3 (3-bit) is tighter than tq2 (2-bit); these thresholds match the
|
||||
// codec quality validated by TestEncodeKeyPerHeadRoundTrip.
|
||||
maxRelErr := 0.45
|
||||
if preset.KeyPrimaryBits >= 3 {
|
||||
maxRelErr = 0.25
|
||||
}
|
||||
if avgRelErr > maxRelErr {
|
||||
t.Errorf("preset=%s avg relative dot error = %.4f, want <= %.4f",
|
||||
preset.Name, avgRelErr, maxRelErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildRotationIsOrthogonal verifies that Q satisfies Q*Q^T = I (rows are
|
||||
// orthonormal). This is the core invariant required by the TurboQuant encoding
|
||||
// and is guaranteed unconditionally by the Householder QR algorithm.
|
||||
func TestBuildRotationIsOrthogonal(t *testing.T) {
|
||||
// Include dim=256 to exercise the algorithm well beyond typical head_dim.
|
||||
for _, dim := range append(testDims, 256) {
|
||||
rot := BuildRotation(dim, uint64(dim)*31)
|
||||
for i := range dim {
|
||||
for j := i; j < dim; j++ {
|
||||
var dot float32
|
||||
for k := range dim {
|
||||
dot += rot.Matrix[i*dim+k] * rot.Matrix[j*dim+k]
|
||||
}
|
||||
if i == j {
|
||||
if math.Abs(float64(dot-1)) > 5e-5 {
|
||||
t.Fatalf("dim=%d row %d: self-dot=%.6f, want 1.0", dim, i, dot)
|
||||
}
|
||||
} else if math.Abs(float64(dot)) > 5e-5 {
|
||||
t.Fatalf("dim=%d rows %d,%d: cross-dot=%.6f, want 0.0", dim, i, j, dot)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
36
turboquant/stats.go
Normal file
36
turboquant/stats.go
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
package turboquant
|
||||
|
||||
import "math"
|
||||
|
||||
type Stats struct {
|
||||
MSE float32
|
||||
RMSE float32
|
||||
MeanAbsErr float32
|
||||
MaxAbsErr float32
|
||||
}
|
||||
|
||||
func Compare(reference, approx []float32) Stats {
|
||||
var s Stats
|
||||
if len(reference) == 0 || len(reference) != len(approx) {
|
||||
return s
|
||||
}
|
||||
for i := range reference {
|
||||
err := reference[i] - approx[i]
|
||||
s.MSE += err * err
|
||||
s.MeanAbsErr += abs32(err)
|
||||
if abs32(err) > s.MaxAbsErr {
|
||||
s.MaxAbsErr = abs32(err)
|
||||
}
|
||||
}
|
||||
s.MSE /= float32(len(reference))
|
||||
s.RMSE = float32(math.Sqrt(float64(s.MSE)))
|
||||
s.MeanAbsErr /= float32(len(reference))
|
||||
return s
|
||||
}
|
||||
|
||||
func abs32(v float32) float32 {
|
||||
if v < 0 {
|
||||
return -v
|
||||
}
|
||||
return v
|
||||
}
|
||||
41
turboquant/stats_test.go
Normal file
41
turboquant/stats_test.go
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
package turboquant
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCompareIdenticalVectors(t *testing.T) {
|
||||
values := []float32{1, -2, 3, -4}
|
||||
stats := Compare(values, values)
|
||||
if stats.MSE != 0 || stats.RMSE != 0 || stats.MeanAbsErr != 0 || stats.MaxAbsErr != 0 {
|
||||
t.Fatalf("unexpected non-zero stats: %+v", stats)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareKnownExample(t *testing.T) {
|
||||
reference := []float32{1, 2, 3}
|
||||
approx := []float32{2, 0, 3}
|
||||
stats := Compare(reference, approx)
|
||||
|
||||
if abs32(stats.MSE-float32(5.0/3.0)) > 1e-6 {
|
||||
t.Fatalf("MSE = %v, want %v", stats.MSE, float32(5.0/3.0))
|
||||
}
|
||||
if abs32(stats.MeanAbsErr-1) > 1e-6 {
|
||||
t.Fatalf("MeanAbsErr = %v, want 1", stats.MeanAbsErr)
|
||||
}
|
||||
if abs32(stats.MaxAbsErr-2) > 1e-6 {
|
||||
t.Fatalf("MaxAbsErr = %v, want 2", stats.MaxAbsErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareRMSEMatchesSqrtMSE(t *testing.T) {
|
||||
reference := []float32{1, 2, 3}
|
||||
approx := []float32{2, 0, 3}
|
||||
stats := Compare(reference, approx)
|
||||
|
||||
want := float32(math.Sqrt(float64(stats.MSE)))
|
||||
if abs32(stats.RMSE-want) > 1e-6 {
|
||||
t.Fatalf("RMSE = %v, want %v", stats.RMSE, want)
|
||||
}
|
||||
}
|
||||
66
turboquant/test_helpers_test.go
Normal file
66
turboquant/test_helpers_test.go
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
package turboquant
|
||||
|
||||
import "math"
|
||||
|
||||
type namedVector struct {
|
||||
name string
|
||||
values []float32
|
||||
}
|
||||
|
||||
func deterministicCorpus() []namedVector {
|
||||
return []namedVector{
|
||||
{name: "ramp-16", values: []float32{-4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}},
|
||||
{name: "alternating-16", values: []float32{1, -1, 2, -2, 3, -3, 4, -4, 5, -5, 6, -6, 7, -7, 8, -8}},
|
||||
{name: "sparse-17", values: []float32{0, 0, 0, 7, 0, 0, -5, 0, 0, 0, 9, 0, 0, 0, 0, -3, 0}},
|
||||
{name: "constant-31", values: filledVector(31, 1.5)},
|
||||
{name: "zero-64", values: filledVector(64, 0)},
|
||||
{name: "random-65", values: pseudoRandomVector(65, 0x1234abcd)},
|
||||
}
|
||||
}
|
||||
|
||||
func pseudoRandomVector(n int, seed uint64) []float32 {
|
||||
rng := splitmix64(seed)
|
||||
out := make([]float32, n)
|
||||
for i := range out {
|
||||
u := float32(rng.next()&0xffff) / 65535
|
||||
out[i] = (u * 10) - 5
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func filledVector(n int, value float32) []float32 {
|
||||
out := make([]float32, n)
|
||||
for i := range out {
|
||||
out[i] = value
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func vectorNorm(values []float32) float32 {
|
||||
var sum float64
|
||||
for _, v := range values {
|
||||
sum += float64(v) * float64(v)
|
||||
}
|
||||
return float32(math.Sqrt(sum))
|
||||
}
|
||||
|
||||
func meanMSEForPreset(preset Preset) (float32, error) {
|
||||
var total float32
|
||||
corpus := deterministicCorpus()
|
||||
for _, tc := range corpus {
|
||||
encoded, err := EncodeVector(tc.values, preset)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
data, err := encoded.MarshalBinary()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
decoded, _, err := DecodeVector(data)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
total += Compare(tc.values, decoded).MSE
|
||||
}
|
||||
return total / float32(len(corpus)), nil
|
||||
}
|
||||
136
turboquant/turboquant.go
Normal file
136
turboquant/turboquant.go
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
package turboquant
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const BlockVersion = 6
|
||||
|
||||
type vectorRole uint8
|
||||
|
||||
const (
|
||||
roleGeneric vectorRole = iota
|
||||
roleKey
|
||||
roleValue
|
||||
)
|
||||
|
||||
type vectorObjective uint8
|
||||
|
||||
const (
|
||||
objectiveMSE vectorObjective = iota + 1
|
||||
objectiveProduct
|
||||
)
|
||||
|
||||
type Preset struct {
|
||||
ID uint8
|
||||
Name string
|
||||
RotationSeed uint64
|
||||
KeyPrimaryBits int
|
||||
ValueBits int
|
||||
QJLRowsDivisor int
|
||||
OutlierBits int
|
||||
OutlierCount int
|
||||
}
|
||||
|
||||
var (
|
||||
// All four tq* presets ship with OutlierCount=0 (pure uniform Lloyd-Max
|
||||
// after Householder QR rotation, i.e. the core of TurboQuant Algorithm 1
|
||||
// §3.1 without the optional outlier split from §4.3 or the QJL residual
|
||||
// sketch from Algorithm 2). The uniform defaults were chosen after
|
||||
// measuring that on the models this fork ships against (llama, gemma3,
|
||||
// qwen3-coder), outlier split hurts both decode throughput and PPL — the
|
||||
// paper's split targets heavy-tailed rotated K distributions, which these
|
||||
// models don't exhibit, and the extra metadata (92 vs 52 bytes/head/cell
|
||||
// at oc=32) translates to ~25% decode regression on 3B-class models at
|
||||
// short context. Keeping the defaults symmetric across tq2 / tq3 / tq2k /
|
||||
// tq3k means the digit in the preset name maps directly to effective
|
||||
// bits/elem: "tq3" is exactly 3 bits, not the 3.25 bits you'd get under
|
||||
// outlier split with oc=32.
|
||||
//
|
||||
// The outlier-split kernel path remains in the code (encode/dequant
|
||||
// dispatchers check op_params[2..3] and route to the outlier variant when
|
||||
// both are non-zero). A future dynamic-dispatch PR can enable it per
|
||||
// model / per env var — e.g. for the qwen2 family once Phase 2A
|
||||
// asymmetric quantization lands, or for models with larger headDim where
|
||||
// the metadata overhead amortizes better. See project_tq_backlog.md for
|
||||
// the planned dynamic-oc dispatch work.
|
||||
|
||||
// tq2: 2-bit K + 2-bit V, both rotated and Lloyd-Max quantized. Highest
|
||||
// compression tier — effective 2 bits/elem both sides.
|
||||
PresetTQ2 = newPreset(1, "tq2", 2, 2, 1, 0x25c0ffee, 3, 0)
|
||||
|
||||
// tq3: 3-bit K + 3-bit V, both rotated and Lloyd-Max quantized. Default
|
||||
// "balanced" tier — effective 3 bits/elem both sides.
|
||||
PresetTQ3 = newPreset(2, "tq3", 3, 3, 1, 0x35c0ffee, 4, 0)
|
||||
|
||||
// tq3k: 3-bit K only, V stays as f16. ~40% KV VRAM savings with near-f16
|
||||
// decode (no V dequant at all). ValueBits=0 signals K-only mode to the
|
||||
// kvcache layer.
|
||||
PresetTQ3K = newPreset(3, "tq3k", 3, 0, 1, 0x35c0ffee, 4, 0)
|
||||
|
||||
// tq2k: 2-bit K only, V stays as f16. Maximum K compression with f16 V;
|
||||
// smallest K footprint before PPL degrades too much. ValueBits=0 signals
|
||||
// K-only mode to the kvcache layer.
|
||||
PresetTQ2K = newPreset(4, "tq2k", 2, 0, 1, 0x25c0ffee, 3, 0)
|
||||
)
|
||||
|
||||
func newPreset(id uint8, name string, keyBits int, valueBits int, qjlRowsDivisor int, seed uint64, outlierBits int, outlierCount int) Preset {
|
||||
return Preset{
|
||||
ID: id,
|
||||
Name: name,
|
||||
RotationSeed: seed,
|
||||
KeyPrimaryBits: keyBits,
|
||||
ValueBits: valueBits,
|
||||
QJLRowsDivisor: qjlRowsDivisor,
|
||||
OutlierBits: outlierBits,
|
||||
OutlierCount: outlierCount,
|
||||
}
|
||||
}
|
||||
|
||||
func (p Preset) HasOutlierSplit() bool {
|
||||
return p.OutlierBits > 0 && p.OutlierCount > 0
|
||||
}
|
||||
|
||||
func PresetByName(name string) (Preset, error) {
|
||||
switch name {
|
||||
case "tq2":
|
||||
return PresetTQ2, nil
|
||||
case "tq3":
|
||||
return PresetTQ3, nil
|
||||
case "tq3k":
|
||||
return PresetTQ3K, nil
|
||||
case "tq2k":
|
||||
return PresetTQ2K, nil
|
||||
default:
|
||||
return Preset{}, fmt.Errorf("unknown turboquant preset %q", name)
|
||||
}
|
||||
}
|
||||
|
||||
func PresetByID(id uint8) (Preset, error) {
|
||||
switch id {
|
||||
case PresetTQ2.ID:
|
||||
return PresetTQ2, nil
|
||||
case PresetTQ3.ID:
|
||||
return PresetTQ3, nil
|
||||
case PresetTQ3K.ID:
|
||||
return PresetTQ3K, nil
|
||||
case PresetTQ2K.ID:
|
||||
return PresetTQ2K, nil
|
||||
default:
|
||||
return Preset{}, fmt.Errorf("unknown turboquant preset id %d", id)
|
||||
}
|
||||
}
|
||||
|
||||
func (p Preset) KeyQJLRows(dim int) int {
|
||||
if dim <= 0 {
|
||||
return 0
|
||||
}
|
||||
if p.QJLRowsDivisor <= 0 {
|
||||
return 0
|
||||
}
|
||||
rows := dim / p.QJLRowsDivisor
|
||||
if rows < 1 {
|
||||
return 1
|
||||
}
|
||||
return rows
|
||||
}
|
||||
Loading…
Reference in a new issue