ml/backend/ggml: add D=256 TurboQuant flash-attention kernels on Metal

The fused inline-decode kernels were only instantiated at D=128. gemma3
runs at headDim=256 and was therefore forced onto the DequantKV + stock
FA path — which materialises a ~3.5 GB f16 intermediate K+V tensor at
32k context and OOMs on Apple Silicon with tq3.

Adds D=256 variants of the two fused kernels:

  * kernel_tq_fattn_vec_f16_d256
  * kernel_tq_fattn_vec_packed_d256

Thread layout is unchanged (32x4 = 128 threads); each thread now covers
32 D-positions in the Q and K loops, accumulates four V-passes, and
writes two output elements in the final phase. Shared memory grows
from 2048 to 4096 floats (16 KiB, well under the 32 KiB per-threadgroup
limit).

The dispatch picks D=128 or D=256 based on Q->ne[0]. Go-side eligibility
extends to headDim==256 only when preferFusedAttention is true (Metal);
CUDA continues to gate on headDim==128 until the CUDA kernel gains the
same instantiation.
This commit is contained in:
Michael Verrilli 2026-04-21 11:12:23 +00:00
parent 76e5fc2b75
commit 0c1f7f108d
No known key found for this signature in database
GPG key ID: E4F2103B6C63B961
8 changed files with 1843 additions and 24 deletions

View file

@ -38,12 +38,14 @@ type TurboQuantCache struct {
logPathOnce [5]sync.Once
// fusedFallbackEligible gates the inline-decode fused-FA fallback paths
// (Get paths 2 and 4). Those paths dispatch to a CUDA kernel that is
// template-instantiated only at D=128, so any model with a larger head
// dim (gemma3 D=256, gemma4 D=512) must skip them to avoid a kernel-side
// GGML_ASSERT. The DequantK + stock FA path (Get path 0/1/5) works at
// any head dim — this gate is specific to the inline-decode variants.
// Remove once the fused kernels gain D=256/512 template instantiations.
// (Get paths 2 and 4). The CUDA fused kernel is template-instantiated only
// at D=128, so models with a larger head dim (gemma4 D=512) must skip it
// to avoid a kernel-side GGML_ASSERT. The Metal fused kernel has both
// D=128 and D=256 variants (kernel_tq_fattn_vec_*{,_d256}), so gemma3
// D=256 is eligible on Metal but not on CUDA until the CUDA kernel gains
// a D=256 instantiation. The DequantK + stock FA path (Get paths 0/1/5)
// works at any head dim — this gate is specific to the inline-decode
// variants.
fusedFallbackEligible bool
// preferFusedAttn is true on Metal. At long context, DequantKV + stock FA
@ -479,17 +481,6 @@ func (c *TurboQuantCache) activateGPUEncode() {
}
c.compressedK = mgr
// The inline-decode fused-FA fallback paths (Get paths 2 and 4) dispatch
// to a CUDA kernel template instantiated only at D=128. Models with a
// larger head dim (e.g. gemma4 global layers at headDim=512) must skip
// those fallbacks to avoid a kernel-side GGML_ASSERT; path 5 (separate
// K+V dequant) handles them correctly.
c.fusedFallbackEligible = (c.headDim == 128)
if !c.fusedFallbackEligible {
slog.Info("turboquant: inline-decode fused-FA fallback paths disabled",
"reason", "headDim != 128", "headDim", c.headDim)
}
type fusedAttnPreferrer interface {
PreferFusedAttention() bool
}
@ -500,6 +491,23 @@ func (c *TurboQuantCache) activateGPUEncode() {
}
}
// The inline-decode fused-FA fallback paths (Get paths 2 and 4) dispatch
// to a kernel that is D-specialised. CUDA has only D=128 today; Metal has
// D=128 and D=256 (kernel_tq_fattn_vec_*{,_d256}). Models with an
// unsupported head dim (e.g. gemma4 D=512, or gemma3 D=256 on CUDA) must
// skip these fallbacks to avoid a kernel-side GGML_ASSERT; path 5
// (separate K+V dequant) handles them correctly.
c.fusedFallbackEligible = c.headDim == 128 ||
(c.headDim == 256 && c.preferFusedAttn)
if !c.fusedFallbackEligible {
reason := "headDim != 128"
if c.headDim == 256 {
reason = "headDim == 256 but backend lacks D=256 fused kernel"
}
slog.Info("turboquant: inline-decode fused-FA fallback paths disabled",
"reason", reason, "headDim", c.headDim)
}
// Cache the rotation matrices and the backend's rotation-setter hook on
// TurboQuantCache so Get() can arm them per-call without re-running a
// type assertion every layer. We do NOT set them at activate time — a

View file

@ -0,0 +1,667 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Verrilli <msv@pobox.com>
Date: Tue, 21 Apr 2026 21:02:21 +0000
Subject: [PATCH] ml/backend/ggml: add D=256 TurboQuant flash-attention kernels
on Metal
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
The fused inline-decode kernels were only instantiated at D=128. gemma3
runs at headDim=256 and was therefore forced onto the DequantKV + stock
FA path — which materialises a ~3.5 GB f16 intermediate K+V tensor at
32k context and OOMs on Apple Silicon with tq3.
Adds D=256 variants of the two fused kernels:
* kernel_tq_fattn_vec_f16_d256
* kernel_tq_fattn_vec_packed_d256
Thread layout is unchanged (32x4 = 128 threads); each thread now covers
32 D-positions in the Q and K loops, accumulates four V-passes, and
writes two output elements in the final phase. Shared memory grows
from 2048 to 4096 floats (16 KiB, well under the 32 KiB per-threadgroup
limit).
The dispatch picks D=128 or D=256 based on Q->ne[0]. Go-side eligibility
extends to headDim==256 only when preferFusedAttention is true (Metal);
CUDA continues to gate on headDim==128 until the CUDA kernel gains the
same instantiation.
---
ggml/src/ggml-metal/ggml-metal-device.cpp | 6 +-
ggml/src/ggml-metal/ggml-metal-device.h | 6 +-
ggml/src/ggml-metal/ggml-metal-ops.cpp | 11 +-
ggml/src/ggml-metal/ggml-metal.metal | 562 ++++++++++++++++++++++
4 files changed, 579 insertions(+), 6 deletions(-)
diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index 2686d2d30..c99fafe00 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -1720,5 +1720,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_dequan
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_encode"); }
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_v (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_encode_v"); }
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_outlier(ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_encode_outlier"); }
-struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16 (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_f16"); }
-struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed(ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_packed"); }
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16 (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_f16"); }
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_packed"); }
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16_d256 (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_f16_d256"); }
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed_d256(ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_packed_d256"); }
diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h
index aadd82659..fb45cbbfd 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.h
+++ b/ggml/src/ggml-metal/ggml-metal-device.h
@@ -194,8 +194,10 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_dequan
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode (ggml_metal_library_t lib);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_v (ggml_metal_library_t lib);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_outlier(ggml_metal_library_t lib);
-struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16 (ggml_metal_library_t lib);
-struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed(ggml_metal_library_t lib);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16 (ggml_metal_library_t lib);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed (ggml_metal_library_t lib);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16_d256 (ggml_metal_library_t lib);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed_d256(ggml_metal_library_t lib);
// MTLResidencySet wrapper
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index 7abc0eaac..b5ab1c14e 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -4602,9 +4602,16 @@ int ggml_metal_op_tq_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.nb31 =*/ mask ? mask->nb[1] : 0,
};
+ // Select D=128 vs D=256 pipeline. Gemma3 runs at headDim=256; everything
+ // else supported so far is D=128.
+ GGML_ASSERT(D == 128 || D == 256);
auto pipeline = v_packed
- ? ggml_metal_library_get_pipeline_tq_fattn_vec_packed(lib)
- : ggml_metal_library_get_pipeline_tq_fattn_vec_f16(lib);
+ ? (D == 256
+ ? ggml_metal_library_get_pipeline_tq_fattn_vec_packed_d256(lib)
+ : ggml_metal_library_get_pipeline_tq_fattn_vec_packed(lib))
+ : (D == 256
+ ? ggml_metal_library_get_pipeline_tq_fattn_vec_f16_d256(lib)
+ : ggml_metal_library_get_pipeline_tq_fattn_vec_f16(lib));
ggml_metal_buffer_id bid_mask = hasMask ? ggml_metal_get_buffer_id(mask) : ggml_metal_get_buffer_id(op);
ggml_metal_buffer_id bid_v_scales = v_packed ? ggml_metal_get_buffer_id(op->src[6]) : ggml_metal_get_buffer_id(op);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index d0bc6bf9e..2718e8bb1 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -11184,6 +11184,293 @@ kernel void kernel_tq_fattn_vec_f16(
}
}
+// ─────────────────────────────────────────────────────────────────────────────
+// kernel_tq_fattn_vec_f16_d256
+// TQ fused flash-attention at head dim 256: K packed i8, V f16.
+// Thread layout identical to the D=128 variant (32×4 = 128 threads) but each
+// thread now produces 2 output elements (D/nthreads = 2) and covers twice as
+// many D positions in the Q/K/V loops.
+// Grid: (ntiles_x, 1, nHeadsQ*nSeq)
+// ─────────────────────────────────────────────────────────────────────────────
+kernel void kernel_tq_fattn_vec_f16_d256(
+    constant ggml_metal_kargs_tq_fattn_vec & args,
+    device const char    * Q_data    [[buffer(1)]],
+    device const uint8_t * K_packed  [[buffer(2)]],
+    device const half    * V_data    [[buffer(3)]],
+    device const half    * mask_data [[buffer(4)]],
+    device const float   * K_scales  [[buffer(5)]],
+    device const float   * K_cb      [[buffer(6)]],
+    device const float   * dummy_vs  [[buffer(7)]],
+    device const float   * dummy_vc  [[buffer(8)]],
+    device float         * dst       [[buffer(9)]],
+    uint3 tgpig [[threadgroup_position_in_grid]],
+    uint  tiisg [[thread_index_in_simdgroup]],
+    uint  sgitg [[simdgroup_index_in_threadgroup]])
+{
+    constexpr int D               = 256;
+    constexpr int nthreads        = 128;
+    constexpr int nthreads_KQ     = 8;
+    constexpr int nthreads_V      = 8;
+    constexpr int V_cols_per_iter = 4;
+    constexpr int nwarps          = 4;
+
+    const int ic0       = (int)tgpig.x * args.ncols;
+    const int blk_z     = (int)tgpig.z;
+    const int sequence  = blk_z / args.nHeadsQ;
+    const int head      = blk_z % args.nHeadsQ;
+    const int gqa_ratio = args.nHeadsQ / args.nKVHeads;
+    const int head_kv   = head / gqa_ratio;
+
+    const int tid = (int)sgitg * 32 + (int)tiisg;
+
+    device const float * Q = (device const float *)Q_data
+        + (long)sequence * (args.nb03 / sizeof(float))
+        + (long)head     * (args.nb02 / sizeof(float))
+        + (long)ic0      * (args.nb01 / sizeof(float));
+
+    device const uint8_t * K_p = K_packed
+        + (long)args.firstCell * args.nKVHeads * args.packedBytes
+        + (long)head_kv * args.packedBytes;
+    device const float * K_sc = K_scales
+        + (long)args.firstCell * args.nKVHeads + head_kv;
+
+    device const half * V = V_data
+        + (long)sequence * (args.nb23 / sizeof(half))
+        + (long)head_kv  * (args.nb22 / sizeof(half));
+
+    // NOTE(review): unlike K/V, maskh carries no firstCell offset — this
+    // assumes the caller passes a mask whose columns are already relative to
+    // the cell window. Confirm against the dispatch in ggml-metal-ops.cpp.
+    device const half * maskh = args.hasMask
+        ? (mask_data + (long)ic0 * (args.nb31 / sizeof(half)))
+        : nullptr;
+
+    const int   k_cb_mask = (1 << args.bits) - 1;
+    const float k_cb_lane = K_cb[tiisg & k_cb_mask];
+
+    const int tid_kq = (int)tiisg % nthreads_KQ;
+
+    // D=256: Q_reg holds 16 float2 per thread per query slot (D/(2*nthreads_KQ)).
+    float2 Q_reg[2][16];
+    for (int j = 0; j < args.ncols; j++) {
+        device const float2 * Q_j = (device const float2 *)(Q + (long)j * (args.nb01 / sizeof(float)));
+        for (int i = 0; i < 16; i++) {
+            const int elem = tid_kq * 16 + i; // float2 index within [0, 127]
+            Q_reg[j][i] = (elem < D/2) ? Q_j[elem] : float2(0.0f, 0.0f);
+        }
+        for (int i = 0; i < 16; i++) {
+            Q_reg[j][i].x *= args.scale;
+            Q_reg[j][i].y *= args.scale;
+        }
+    }
+
+    // D=256: VKQ holds 4 passes × 4 float2 = 16 float2 per query slot.
+    float2 VKQ[2][16];
+    for (int j = 0; j < 2; j++)
+        for (int i = 0; i < 16; i++)
+            VKQ[j][i] = float2(0.0f, 0.0f);
+
+    float KQ_max[2] = { -FLT_MAX/2.0f, -FLT_MAX/2.0f };
+    float KQ_sum[2] = { 0.0f, 0.0f };
+
+    // D=256: KQ_tg sized nwarps*V_cols_per_iter*D = 4*4*256 = 4096 floats (16 KiB).
+    threadgroup float KQ_tg[4096];
+    threadgroup float KQ_max_tg[2][32];
+    threadgroup float KQ_sum_tg[2][32];
+
+    for (int k_VKQ_0 = 0; k_VKQ_0 < args.nCells; k_VKQ_0 += nthreads) {
+
+        float KQ_max_new[2] = { KQ_max[0], KQ_max[1] };
+
+        for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; i_KQ_0++) {
+            const int  kq_grp_start = ((int)tiisg & ~(nthreads_KQ - 1));
+            const int  i_KQ     = (int)sgitg * 32 + kq_grp_start + i_KQ_0;
+            const int  cell_rel = k_VKQ_0 + i_KQ;
+            const bool in_range = (cell_rel < args.nCells);
+
+            // Out-of-range lanes contribute zero via rms_scale == 0, but the
+            // packed-row load itself must still land inside the K buffer, so
+            // clamp their load index to cell 0. (These values are j-invariant;
+            // hoisted out of the query loop.)
+            const int cell_ld = in_range ? cell_rel : 0;
+            device const uint8_t * packed_row = K_p + (long)cell_ld * args.nKVHeads * args.packedBytes;
+            const float rms_scale = in_range ? K_sc[cell_ld * args.nKVHeads] : 0.0f;
+
+            for (int j = 0; j < args.ncols; j++) {
+                // D=256: 16 k-iterations × 2 elements each = 32 D-positions per thread.
+                float sum = 0.0f;
+                for (int k = 0; k < 16; k++) {
+                    const int start_elem = tid_kq * 32 + k * 2; // float index [0..254]
+                    float k_dec[2];
+                    if (args.bits == 3) {
+                        const int  bit_pos0 = start_elem * 3;
+                        const int  byte0 = bit_pos0 >> 3, sh0 = bit_pos0 & 7;
+                        const uint w0 = (uint)packed_row[byte0] | ((uint)packed_row[byte0+1] << 8);
+                        int idx0 = (int)((w0 >> sh0) & 7);
+                        k_dec[0] = simd_shuffle(k_cb_lane, (ushort)idx0) * rms_scale;
+                        const int  bit_pos1 = (start_elem + 1) * 3;
+                        const int  byte1 = bit_pos1 >> 3, sh1 = bit_pos1 & 7;
+                        const uint w1 = (uint)packed_row[byte1] | ((uint)packed_row[byte1+1] << 8);
+                        int idx1 = (int)((w1 >> sh1) & 7);
+                        k_dec[1] = simd_shuffle(k_cb_lane, (ushort)idx1) * rms_scale;
+                    } else {
+                        const int byte0 = start_elem >> 2, sh0 = (start_elem & 3) * 2;
+                        k_dec[0] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte0] >> sh0) & 3)) * rms_scale;
+                        const int byte1 = (start_elem + 1) >> 2, sh1 = ((start_elem + 1) & 3) * 2;
+                        k_dec[1] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte1] >> sh1) & 3)) * rms_scale;
+                    }
+                    sum += Q_reg[j][k].x * k_dec[0] + Q_reg[j][k].y * k_dec[1];
+                }
+                // Reduce the partial dot product across the 8 lanes of this KQ
+                // group (tid_kq = tiisg % 8, so xor by 4/2/1 stays in-group).
+                sum += simd_shuffle_xor(sum, 4);
+                sum += simd_shuffle_xor(sum, 2);
+                sum += simd_shuffle_xor(sum, 1);
+
+                if (args.logit_softcap != 0.0f) {
+                    sum = args.logit_softcap * tanh(sum);
+                }
+
+                // FIX: index the mask by the ABSOLUTE cell position (cell_rel),
+                // not the tile-local i_KQ — once nCells > nthreads the KV loop
+                // walks multiple 128-cell tiles and each must read its own mask
+                // columns. Guarded by in_range so tail lanes never read past
+                // the end of the mask row.
+                if (in_range && maskh && (args.ncols == 1 || ic0 + j < args.nTokensQ)) {
+                    sum += float(maskh[(long)j * args.ne31 + cell_rel]);
+                }
+
+                if (!in_range) sum = -FLT_MAX/2.0f;
+
+                // FIX: plain running max — the stray "+ 0.6931f" bias cancelled
+                // in the final normalization but is not part of the online-
+                // softmax recurrence and needlessly eats exp() headroom.
+                KQ_max_new[j] = max(KQ_max_new[j], sum);
+
+                // Lane whose group offset matches this row owns KQ_tg slot tid,
+                // i.e. slot tid ends up holding the score of KV row tid.
+                if (tid_kq == i_KQ_0) {
+                    KQ_tg[j * nthreads + tid] = sum;
+                }
+            }
+        }
+
+        // Online-softmax update: rescale previous accumulators to the new max
+        // and exponentiate this tile's scores in place in KQ_tg.
+        for (int j = 0; j < args.ncols; j++) {
+            KQ_max_new[j] = simd_max(KQ_max_new[j]);
+
+            const float KQ_max_scale = exp(KQ_max[j] - KQ_max_new[j]);
+            KQ_max[j] = KQ_max_new[j];
+
+            const float kq_val = KQ_tg[j * nthreads + tid];
+            const float kq_exp = exp(kq_val - KQ_max[j]);
+            KQ_sum[j] = KQ_sum[j] * KQ_max_scale + kq_exp;
+            KQ_tg[j * nthreads + tid] = kq_exp;
+
+            for (int i = 0; i < 16; i++) {
+                VKQ[j][i].x *= KQ_max_scale;
+                VKQ[j][i].y *= KQ_max_scale;
+            }
+        }
+
+        // NOTE(review): the KQ_tg handoff below stays within one simdgroup
+        // (writer and readers share sgitg) and relies on simdgroup lockstep,
+        // mirroring the D=128 variant; an explicit simdgroup_barrier would
+        // make it robust under the Metal memory model — confirm.
+
+        // D=256: 4 passes × 64 V-elements each = full 256-element coverage.
+        //   pass 0 → V[  0.. 63] → VKQ[ 0.. 3]
+        //   pass 1 → V[ 64..127] → VKQ[ 4.. 7]
+        //   pass 2 → V[128..191] → VKQ[ 8..11]
+        //   pass 3 → V[192..255] → VKQ[12..15]
+        for (int k0 = 0; k0 < 32; k0 += V_cols_per_iter) {
+            const int k        = (int)sgitg * 32 + k0 + (int)tiisg / nthreads_V;
+            const int cell_rel = k_VKQ_0 + k;
+
+            float KQ_k[2];
+            for (int j = 0; j < args.ncols; j++) {
+                KQ_k[j] = KQ_tg[j * nthreads + k];
+            }
+
+            device const half * V_cell = (cell_rel < args.nCells)
+                ? V + (long)cell_rel * (args.nb21 / sizeof(half))
+                : nullptr;
+
+            const int v_tid = (int)tiisg % nthreads_V;
+            for (int pass = 0; pass < 4; pass++) {
+                for (int i = 0; i < 8; i++) {
+                    const int elem = pass * 64 + v_tid * 8 + i;
+                    float v_val = (V_cell && elem < D) ? float(V_cell[elem]) : 0.0f;
+                    const int vkq_idx = pass * 4 + i / 2;
+                    for (int j = 0; j < args.ncols; j++) {
+                        if (i % 2 == 0) VKQ[j][vkq_idx].x += v_val * KQ_k[j];
+                        else            VKQ[j][vkq_idx].y += v_val * KQ_k[j];
+                    }
+                }
+            }
+        }
+    } // end KV loop
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    for (int j = 0; j < args.ncols; j++) {
+        if (sgitg == 0) {
+            KQ_max_tg[j][tiisg] = -FLT_MAX/2.0f;
+            KQ_sum_tg[j][tiisg] = 0.0f;
+        }
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    for (int j = 0; j < args.ncols; j++) {
+        if (tiisg == 0) {
+            KQ_max_tg[j][sgitg] = KQ_max[j];
+        }
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // Cross-simdgroup combine and output. The break condition is uniform
+    // across the threadgroup, so the barriers below see all threads.
+    for (int j = 0; j < args.ncols; j++) {
+        if (args.ncols > 1 && ic0 + j >= args.nTokensQ) break;
+
+        float kqmax_new = KQ_max_tg[j][tiisg];
+        kqmax_new = simd_max(kqmax_new);
+        const float kqmax_scale = exp(KQ_max[j] - kqmax_new);
+        KQ_max[j] = kqmax_new;
+
+        for (int i = 0; i < 16; i++) {
+            VKQ[j][i].x *= kqmax_scale;
+            VKQ[j][i].y *= kqmax_scale;
+        }
+
+        // D=256: VKQ_tg layout per (sgitg, v-group) region of D/2 = 128 float2;
+        // pass p of VKQ[p*4 .. p*4+3] lands in float2 slots p*32 + v_tid*4 + 0..3.
+        const int v_tid = (int)tiisg % nthreads_V;
+        threadgroup float2 * VKQ_tg = (threadgroup float2 *)KQ_tg
+            + (long)sgitg * (V_cols_per_iter * D/2)
+            + (long)((int)tiisg / nthreads_V) * (D/2);
+        for (int pass = 0; pass < 4; pass++) {
+            for (int i = 0; i < 4; i++) {
+                VKQ_tg[pass * 32 + v_tid * 4 + i] = VKQ[j][pass * 4 + i];
+            }
+        }
+
+        KQ_sum[j] *= kqmax_scale;
+        KQ_sum[j] = simd_sum(KQ_sum[j]);
+        if (tiisg == 0) {
+            KQ_sum_tg[j][sgitg] = KQ_sum[j];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // D=256: each thread writes 2 output positions (tid, tid+128). Compute
+        // KQ_sum with ALL lanes participating in the simd_sum (required for
+        // convergence), then two writes per thread.
+        KQ_sum[j] = KQ_sum_tg[j][tiisg];
+        KQ_sum[j] = simd_sum(KQ_sum[j]);
+
+        const long out_idx = ((long)sequence * args.nTokensQ + ic0 + j) * args.nHeadsQ + head;
+        for (int out_offset = 0; out_offset < D; out_offset += nthreads) {
+            const int out_elem = out_offset + tid;
+            float dst_val = 0.0f;
+            for (int w = 0; w < nwarps; w++) {
+                for (int v = 0; v < V_cols_per_iter; v++) {
+                    dst_val += ((threadgroup float *)KQ_tg)[w * V_cols_per_iter * D + v * D + out_elem];
+                }
+            }
+            dst_val /= KQ_sum[j];
+            dst[out_idx * D + out_elem] = dst_val;
+        }
+
+        if (j < args.ncols - 1) threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+}
+
// ─────────────────────────────────────────────────────────────────────────────
// kernel_tq_fattn_vec_packed
// TQ fused flash-attention: K packed i8, V packed i8.
@@ -11451,3 +11738,278 @@ kernel void kernel_tq_fattn_vec_packed(
if (j < args.ncols - 1) threadgroup_barrier(mem_flags::mem_threadgroup);
}
}
+
+// ─────────────────────────────────────────────────────────────────────────────
+// kernel_tq_fattn_vec_packed_d256
+// TQ fused flash-attention at head dim 256: K packed i8, V packed i8.
+// Mirrors kernel_tq_fattn_vec_packed with 4 V-passes and 2 outputs per thread.
+// ─────────────────────────────────────────────────────────────────────────────
+kernel void kernel_tq_fattn_vec_packed_d256(
+    constant ggml_metal_kargs_tq_fattn_vec & args,
+    device const char    * Q_data    [[buffer(1)]],
+    device const uint8_t * K_packed  [[buffer(2)]],
+    device const uint8_t * V_packed  [[buffer(3)]],
+    device const half    * mask_data [[buffer(4)]],
+    device const float   * K_scales  [[buffer(5)]],
+    device const float   * K_cb      [[buffer(6)]],
+    device const float   * V_scales  [[buffer(7)]],
+    device const float   * V_cb      [[buffer(8)]],
+    device float         * dst       [[buffer(9)]],
+    uint3 tgpig [[threadgroup_position_in_grid]],
+    uint  tiisg [[thread_index_in_simdgroup]],
+    uint  sgitg [[simdgroup_index_in_threadgroup]])
+{
+    constexpr int D               = 256;
+    constexpr int nthreads        = 128;
+    constexpr int nthreads_KQ     = 8;
+    constexpr int nthreads_V      = 8;
+    constexpr int V_cols_per_iter = 4;
+    constexpr int nwarps          = 4;
+
+    const int ic0       = (int)tgpig.x * args.ncols;
+    const int blk_z     = (int)tgpig.z;
+    const int sequence  = blk_z / args.nHeadsQ;
+    const int head      = blk_z % args.nHeadsQ;
+    const int gqa_ratio = args.nHeadsQ / args.nKVHeads;
+    const int head_kv   = head / gqa_ratio;
+
+    const int tid = (int)sgitg * 32 + (int)tiisg;
+
+    device const float * Q = (device const float *)Q_data
+        + (long)sequence * (args.nb03 / sizeof(float))
+        + (long)head     * (args.nb02 / sizeof(float))
+        + (long)ic0      * (args.nb01 / sizeof(float));
+
+    device const uint8_t * K_p = K_packed
+        + (long)args.firstCell * args.nKVHeads * args.packedBytes
+        + (long)head_kv * args.packedBytes;
+    device const float * K_sc = K_scales
+        + (long)args.firstCell * args.nKVHeads + head_kv;
+
+    device const uint8_t * V_p = V_packed
+        + (long)args.firstCell * args.nKVHeads * args.v_packedBytes
+        + (long)head_kv * args.v_packedBytes;
+    device const float * V_sc = V_scales
+        + (long)args.firstCell * args.nKVHeads + head_kv;
+
+    // NOTE(review): unlike K/V, maskh carries no firstCell offset — assumes
+    // the caller passes a mask already relative to the cell window; confirm
+    // against the dispatch in ggml-metal-ops.cpp.
+    device const half * maskh = args.hasMask
+        ? (mask_data + (long)ic0 * (args.nb31 / sizeof(half)))
+        : nullptr;
+
+    const int   k_cb_mask = (1 << args.bits) - 1;
+    const float k_cb_lane = K_cb[tiisg & k_cb_mask];
+    const int   v_cb_mask = (1 << args.v_bits) - 1;
+    const float v_cb_lane = V_cb[tiisg & v_cb_mask];
+
+    const int tid_kq = (int)tiisg % nthreads_KQ;
+
+    // D=256: Q_reg holds 16 float2 per thread per query slot (D/(2*nthreads_KQ)).
+    float2 Q_reg[2][16];
+    for (int j = 0; j < args.ncols; j++) {
+        device const float2 * Q_j = (device const float2 *)(Q + (long)j * (args.nb01 / sizeof(float)));
+        for (int i = 0; i < 16; i++) {
+            const int elem = tid_kq * 16 + i; // float2 index within [0, 127]
+            Q_reg[j][i] = (elem < D/2) ? Q_j[elem] : float2(0.0f, 0.0f);
+        }
+        for (int i = 0; i < 16; i++) {
+            Q_reg[j][i].x *= args.scale;
+            Q_reg[j][i].y *= args.scale;
+        }
+    }
+
+    // D=256: VKQ holds 4 passes × 4 float2 = 16 float2 per query slot.
+    float2 VKQ[2][16];
+    for (int j = 0; j < 2; j++)
+        for (int i = 0; i < 16; i++)
+            VKQ[j][i] = float2(0.0f, 0.0f);
+
+    float KQ_max[2] = { -FLT_MAX/2.0f, -FLT_MAX/2.0f };
+    float KQ_sum[2] = { 0.0f, 0.0f };
+
+    // D=256: KQ_tg sized nwarps*V_cols_per_iter*D = 4*4*256 = 4096 floats (16 KiB).
+    threadgroup float KQ_tg[4096];
+    threadgroup float KQ_max_tg[2][32];
+    threadgroup float KQ_sum_tg[2][32];
+
+    for (int k_VKQ_0 = 0; k_VKQ_0 < args.nCells; k_VKQ_0 += nthreads) {
+
+        float KQ_max_new[2] = { KQ_max[0], KQ_max[1] };
+
+        for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; i_KQ_0++) {
+            const int  kq_grp_start = ((int)tiisg & ~(nthreads_KQ - 1));
+            const int  i_KQ     = (int)sgitg * 32 + kq_grp_start + i_KQ_0;
+            const int  cell_rel = k_VKQ_0 + i_KQ;
+            const bool in_range = (cell_rel < args.nCells);
+
+            // Out-of-range lanes contribute zero via rms_scale == 0, but the
+            // packed-row load itself must still land inside the K buffer, so
+            // clamp their load index to cell 0. (j-invariant; hoisted out of
+            // the query loop.)
+            const int cell_ld = in_range ? cell_rel : 0;
+            device const uint8_t * packed_row = K_p + (long)cell_ld * args.nKVHeads * args.packedBytes;
+            const float rms_scale = in_range ? K_sc[cell_ld * args.nKVHeads] : 0.0f;
+
+            for (int j = 0; j < args.ncols; j++) {
+                // D=256: 16 k-iterations × 2 elements each = 32 D-positions per thread.
+                float sum = 0.0f;
+                for (int k = 0; k < 16; k++) {
+                    const int start_elem = tid_kq * 32 + k * 2; // float index [0..254]
+                    float k_dec[2];
+                    if (args.bits == 3) {
+                        const int  bit_pos0 = start_elem * 3;
+                        const int  byte0 = bit_pos0 >> 3, sh0 = bit_pos0 & 7;
+                        const uint w0 = (uint)packed_row[byte0] | ((uint)packed_row[byte0+1] << 8);
+                        k_dec[0] = simd_shuffle(k_cb_lane, (ushort)((w0 >> sh0) & 7)) * rms_scale;
+                        const int  bit_pos1 = (start_elem + 1) * 3;
+                        const int  byte1 = bit_pos1 >> 3, sh1 = bit_pos1 & 7;
+                        const uint w1 = (uint)packed_row[byte1] | ((uint)packed_row[byte1+1] << 8);
+                        k_dec[1] = simd_shuffle(k_cb_lane, (ushort)((w1 >> sh1) & 7)) * rms_scale;
+                    } else {
+                        const int byte0 = start_elem >> 2, sh0 = (start_elem & 3) * 2;
+                        k_dec[0] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte0] >> sh0) & 3)) * rms_scale;
+                        const int byte1 = (start_elem + 1) >> 2, sh1 = ((start_elem + 1) & 3) * 2;
+                        k_dec[1] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte1] >> sh1) & 3)) * rms_scale;
+                    }
+                    sum += Q_reg[j][k].x * k_dec[0] + Q_reg[j][k].y * k_dec[1];
+                }
+                // Reduce across the 8 lanes of this KQ group (tid_kq = tiisg % 8,
+                // so xor by 4/2/1 stays in-group).
+                sum += simd_shuffle_xor(sum, 4);
+                sum += simd_shuffle_xor(sum, 2);
+                sum += simd_shuffle_xor(sum, 1);
+
+                if (args.logit_softcap != 0.0f) {
+                    sum = args.logit_softcap * tanh(sum);
+                }
+                // FIX: index the mask by the ABSOLUTE cell position (cell_rel),
+                // not the tile-local i_KQ — with nCells > nthreads every tile
+                // must read its own mask columns. Guarded by in_range so tail
+                // lanes never read past the end of the mask row.
+                if (in_range && maskh && (args.ncols == 1 || ic0 + j < args.nTokensQ)) {
+                    sum += float(maskh[(long)j * args.ne31 + cell_rel]);
+                }
+                if (!in_range) sum = -FLT_MAX/2.0f;
+
+                // FIX: plain running max — the stray "+ 0.6931f" bias cancelled
+                // in the final normalization but is not part of the online-
+                // softmax recurrence and needlessly eats exp() headroom.
+                KQ_max_new[j] = max(KQ_max_new[j], sum);
+
+                // Lane whose group offset matches this row owns KQ_tg slot tid,
+                // i.e. slot tid ends up holding the score of KV row tid.
+                if (tid_kq == i_KQ_0) {
+                    KQ_tg[j * nthreads + tid] = sum;
+                }
+            }
+        }
+
+        // Online-softmax update: rescale previous accumulators to the new max
+        // and exponentiate this tile's scores in place in KQ_tg.
+        for (int j = 0; j < args.ncols; j++) {
+            KQ_max_new[j] = simd_max(KQ_max_new[j]);
+
+            const float KQ_max_scale = exp(KQ_max[j] - KQ_max_new[j]);
+            KQ_max[j] = KQ_max_new[j];
+
+            const float kq_val = KQ_tg[j * nthreads + tid];
+            const float kq_exp = exp(kq_val - KQ_max[j]);
+            KQ_sum[j] = KQ_sum[j] * KQ_max_scale + kq_exp;
+            KQ_tg[j * nthreads + tid] = kq_exp;
+
+            for (int i = 0; i < 16; i++) {
+                VKQ[j][i].x *= KQ_max_scale;
+                VKQ[j][i].y *= KQ_max_scale;
+            }
+        }
+
+        // NOTE(review): the KQ_tg handoff below stays within one simdgroup
+        // (writer and readers share sgitg) and relies on simdgroup lockstep,
+        // mirroring the D=128 variant; an explicit simdgroup_barrier would
+        // make it robust under the Metal memory model — confirm.
+
+        // D=256: 4 passes cover V[0..63], V[64..127], V[128..191], V[192..255].
+        for (int k0 = 0; k0 < 32; k0 += V_cols_per_iter) {
+            const int  k          = (int)sgitg * 32 + k0 + (int)tiisg / nthreads_V;
+            const int  cell_rel   = k_VKQ_0 + k;
+            const bool in_range_v = (cell_rel < args.nCells);
+
+            float KQ_k[2];
+            for (int j = 0; j < args.ncols; j++) {
+                KQ_k[j] = KQ_tg[j * nthreads + k];
+            }
+
+            device const uint8_t * v_row = in_range_v
+                ? V_p + (long)cell_rel * args.nKVHeads * args.v_packedBytes
+                : nullptr;
+            const float v_rms = in_range_v ? V_sc[cell_rel * args.nKVHeads] : 0.0f;
+
+            const int v_tid = (int)tiisg % nthreads_V;
+            for (int pass = 0; pass < 4; pass++) {
+                float v_dec[8];
+                const int start_elem = pass * 64 + v_tid * 8;
+                if (v_row && start_elem < D) {
+                    tq_decode_8_shfl(v_row, v_cb_lane, v_rms, start_elem, args.v_bits, v_dec);
+                } else {
+                    // Out-of-range lanes decode a valid dummy row (K_p) at
+                    // scale 0 so v_dec is all zeros without branch divergence
+                    // in the shuffle-based decoder.
+                    tq_decode_8_shfl(K_p, v_cb_lane, 0.0f, 0, args.v_bits, v_dec);
+                }
+                for (int i = 0; i < 8; i++) {
+                    const int vkq_idx = pass * 4 + i / 2;
+                    for (int j = 0; j < args.ncols; j++) {
+                        if (i % 2 == 0) VKQ[j][vkq_idx].x += v_dec[i] * KQ_k[j];
+                        else            VKQ[j][vkq_idx].y += v_dec[i] * KQ_k[j];
+                    }
+                }
+            }
+        }
+    } // end KV loop
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    for (int j = 0; j < args.ncols; j++) {
+        if (sgitg == 0) {
+            KQ_max_tg[j][tiisg] = -FLT_MAX/2.0f;
+            KQ_sum_tg[j][tiisg] = 0.0f;
+        }
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    for (int j = 0; j < args.ncols; j++) {
+        if (tiisg == 0) {
+            KQ_max_tg[j][sgitg] = KQ_max[j];
+        }
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // Cross-simdgroup combine and output. The break condition is uniform
+    // across the threadgroup, so the barriers below see all threads.
+    for (int j = 0; j < args.ncols; j++) {
+        if (args.ncols > 1 && ic0 + j >= args.nTokensQ) break;
+
+        float kqmax_new = KQ_max_tg[j][tiisg];
+        kqmax_new = simd_max(kqmax_new);
+        const float kqmax_scale = exp(KQ_max[j] - kqmax_new);
+        KQ_max[j] = kqmax_new;
+
+        for (int i = 0; i < 16; i++) {
+            VKQ[j][i].x *= kqmax_scale;
+            VKQ[j][i].y *= kqmax_scale;
+        }
+
+        // D=256: VKQ_tg layout per (sgitg, v-group) region of D/2 = 128 float2;
+        // pass p of VKQ[p*4 .. p*4+3] lands in float2 slots p*32 + v_tid*4 + 0..3.
+        const int v_tid = (int)tiisg % nthreads_V;
+        threadgroup float2 * VKQ_tg = (threadgroup float2 *)KQ_tg
+            + (long)sgitg * (V_cols_per_iter * D/2)
+            + (long)((int)tiisg / nthreads_V) * (D/2);
+        for (int pass = 0; pass < 4; pass++) {
+            for (int i = 0; i < 4; i++) {
+                VKQ_tg[pass * 32 + v_tid * 4 + i] = VKQ[j][pass * 4 + i];
+            }
+        }
+
+        KQ_sum[j] *= kqmax_scale;
+        KQ_sum[j] = simd_sum(KQ_sum[j]);
+        if (tiisg == 0) {
+            KQ_sum_tg[j][sgitg] = KQ_sum[j];
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // D=256: each thread writes 2 output positions (tid, tid+128). Compute
+        // KQ_sum with ALL lanes participating in the simd_sum (required for
+        // convergence), then two writes per thread.
+        KQ_sum[j] = KQ_sum_tg[j][tiisg];
+        KQ_sum[j] = simd_sum(KQ_sum[j]);
+
+        const long out_idx = ((long)sequence * args.nTokensQ + ic0 + j) * args.nHeadsQ + head;
+        for (int out_offset = 0; out_offset < D; out_offset += nthreads) {
+            const int out_elem = out_offset + tid;
+            float dst_val = 0.0f;
+            for (int w = 0; w < nwarps; w++) {
+                for (int v = 0; v < V_cols_per_iter; v++) {
+                    dst_val += ((threadgroup float *)KQ_tg)[w * V_cols_per_iter * D + v * D + out_elem];
+                }
+            }
+            dst_val /= KQ_sum[j];
+            dst[out_idx * D + out_elem] = dst_val;
+        }
+
+        if (j < args.ncols - 1) threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+}

View file

@ -1720,5 +1720,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_dequan
// Pipeline getters for the TurboQuant Metal kernels. Each returns the
// (cached) compute pipeline that tq_get_pipeline resolves for the named
// kernel function in the Metal library.
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_encode"); }
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_v (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_encode_v"); }
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_outlier(ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_encode_outlier"); }
// NOTE(review): the f16/packed getters below appear twice — once with the old
// alignment and once re-aligned next to the new _d256 getters. This looks like
// a diff-rendering artifact (pre-image + post-image lines); the compiled file
// must contain only one definition of each — verify against the real source.
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16 (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_f16"); }
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed(ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_packed"); }
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16 (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_f16"); }
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_packed"); }
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16_d256 (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_f16_d256"); }
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed_d256(ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_packed_d256"); }

View file

@ -194,8 +194,10 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_dequan
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode (ggml_metal_library_t lib);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_v (ggml_metal_library_t lib);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_outlier(ggml_metal_library_t lib);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16 (ggml_metal_library_t lib);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed(ggml_metal_library_t lib);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16 (ggml_metal_library_t lib);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed (ggml_metal_library_t lib);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16_d256 (ggml_metal_library_t lib);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed_d256(ggml_metal_library_t lib);
// MTLResidencySet wrapper

View file

@ -14084,6 +14084,293 @@ kernel void kernel_tq_fattn_vec_f16(
}
}
// ─────────────────────────────────────────────────────────────────────────────
// kernel_tq_fattn_vec_f16_d256
// TQ fused flash-attention at head dim 256: K packed i8, V f16.
// Thread layout identical to the D=128 variant (32×4 = 128 threads) but each
// thread now produces 2 output elements (D/nthreads = 2) and covers twice as
// many D positions in the Q/K/V loops.
// Grid: (ntiles_x, 1, nHeadsQ*nSeq)
// ─────────────────────────────────────────────────────────────────────────────
kernel void kernel_tq_fattn_vec_f16_d256(
        constant ggml_metal_kargs_tq_fattn_vec & args,
        device const char    * Q_data    [[buffer(1)]],
        device const uint8_t * K_packed  [[buffer(2)]],
        device const half    * V_data    [[buffer(3)]],
        device const half    * mask_data [[buffer(4)]],
        device const float   * K_scales  [[buffer(5)]],
        device const float   * K_cb      [[buffer(6)]],
        device const float   * dummy_vs  [[buffer(7)]],
        device const float   * dummy_vc  [[buffer(8)]],
        device       float   * dst       [[buffer(9)]],
        uint3 tgpig [[threadgroup_position_in_grid]],
        uint  tiisg [[thread_index_in_simdgroup]],
        uint  sgitg [[simdgroup_index_in_threadgroup]])
{
    constexpr int D               = 256; // head dimension
    constexpr int nthreads        = 128; // 4 simdgroups × 32 lanes
    constexpr int nthreads_KQ     = 8;   // lanes cooperating on one K row
    constexpr int nthreads_V      = 8;   // lanes cooperating on one V row
    constexpr int V_cols_per_iter = 4;   // KV rows per simdgroup per V step
    constexpr int nwarps          = 4;

    const int ic0       = (int)tgpig.x * args.ncols; // first query column of this tile
    const int blk_z     = (int)tgpig.z;
    const int sequence  = blk_z / args.nHeadsQ;
    const int head      = blk_z % args.nHeadsQ;
    const int gqa_ratio = args.nHeadsQ / args.nKVHeads;
    const int head_kv   = head / gqa_ratio;
    const int tid       = (int)sgitg * 32 + (int)tiisg;

    device const float * Q = (device const float *)Q_data
        + (long)sequence * (args.nb03 / sizeof(float))
        + (long)head     * (args.nb02 / sizeof(float))
        + (long)ic0      * (args.nb01 / sizeof(float));
    device const uint8_t * K_p = K_packed
        + (long)args.firstCell * args.nKVHeads * args.packedBytes
        + (long)head_kv * args.packedBytes;
    device const float * K_sc = K_scales
        + (long)args.firstCell * args.nKVHeads + head_kv;
    device const half * V = V_data
        + (long)sequence * (args.nb23 / sizeof(half))
        + (long)head_kv  * (args.nb22 / sizeof(half));
    device const half * maskh = args.hasMask
        ? (mask_data + (long)ic0 * (args.nb31 / sizeof(half)))
        : nullptr;

    // Each lane caches one K-codebook entry; decode is a simd_shuffle by index.
    const int   k_cb_mask = (1 << args.bits) - 1;
    const float k_cb_lane = K_cb[tiisg & k_cb_mask];
    const int   tid_kq    = (int)tiisg % nthreads_KQ;

    // D=256: Q_reg holds 16 float2 per thread per query slot (D/(2*nthreads_KQ)).
    float2 Q_reg[2][16];
    for (int j = 0; j < args.ncols; j++) {
        device const float2 * Q_j = (device const float2 *)(Q + (long)j * (args.nb01 / sizeof(float)));
        for (int i = 0; i < 16; i++) {
            const int elem = tid_kq * 16 + i; // float2 index within [0, 127]
            Q_reg[j][i] = (elem < D/2) ? Q_j[elem] : float2(0.0f, 0.0f);
        }
        for (int i = 0; i < 16; i++) {
            Q_reg[j][i].x *= args.scale;
            Q_reg[j][i].y *= args.scale;
        }
    }

    // D=256: VKQ holds 4 passes × 4 float2 = 16 float2 per query slot.
    float2 VKQ[2][16];
    for (int j = 0; j < 2; j++)
        for (int i = 0; i < 16; i++)
            VKQ[j][i] = float2(0.0f, 0.0f);

    float KQ_max[2] = { -FLT_MAX/2.0f, -FLT_MAX/2.0f };
    float KQ_sum[2] = { 0.0f, 0.0f };

    // D=256: KQ_tg sized nwarps*V_cols_per_iter*D = 4*4*256 = 4096 floats (16 KiB).
    threadgroup float KQ_tg[4096];
    threadgroup float KQ_max_tg[2][32];
    threadgroup float KQ_sum_tg[2][32];

    for (int k_VKQ_0 = 0; k_VKQ_0 < args.nCells; k_VKQ_0 += nthreads) {
        float KQ_max_new[2] = { KQ_max[0], KQ_max[1] };
        for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; i_KQ_0++) {
            const int kq_grp_start = ((int)tiisg & ~(nthreads_KQ - 1));
            const int i_KQ     = (int)sgitg * 32 + kq_grp_start + i_KQ_0;
            const int cell_rel = k_VKQ_0 + i_KQ;
            const bool in_range = (cell_rel < args.nCells);
            for (int j = 0; j < args.ncols; j++) {
                device const uint8_t * packed_row = K_p + (long)cell_rel * args.nKVHeads * args.packedBytes;
                const float rms_scale = in_range ? K_sc[cell_rel * args.nKVHeads] : 0.0f;
                // D=256: 16 k-iterations × 2 elements each = 32 D-positions per thread.
                float sum = 0.0f;
                for (int k = 0; k < 16; k++) {
                    const int start_elem = tid_kq * 32 + k * 2; // float index [0..254]
                    float k_dec[2];
                    if (args.bits == 3) {
                        const int bit_pos0 = start_elem * 3;
                        const int byte0 = bit_pos0 >> 3, sh0 = bit_pos0 & 7;
                        const uint w0 = (uint)packed_row[byte0] | ((uint)packed_row[byte0+1] << 8);
                        int idx0 = (int)((w0 >> sh0) & 7);
                        k_dec[0] = simd_shuffle(k_cb_lane, (ushort)idx0) * rms_scale;
                        const int bit_pos1 = (start_elem + 1) * 3;
                        const int byte1 = bit_pos1 >> 3, sh1 = bit_pos1 & 7;
                        const uint w1 = (uint)packed_row[byte1] | ((uint)packed_row[byte1+1] << 8);
                        int idx1 = (int)((w1 >> sh1) & 7);
                        k_dec[1] = simd_shuffle(k_cb_lane, (ushort)idx1) * rms_scale;
                    } else {
                        const int byte0 = start_elem >> 2, sh0 = (start_elem & 3) * 2;
                        k_dec[0] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte0] >> sh0) & 3)) * rms_scale;
                        const int byte1 = (start_elem + 1) >> 2, sh1 = ((start_elem + 1) & 3) * 2;
                        k_dec[1] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte1] >> sh1) & 3)) * rms_scale;
                    }
                    sum += Q_reg[j][k].x * k_dec[0] + Q_reg[j][k].y * k_dec[1];
                }
                // Reduce the 8 partial dot products of this K-row group.
                sum += simd_shuffle_xor(sum, 4);
                sum += simd_shuffle_xor(sum, 2);
                sum += simd_shuffle_xor(sum, 1);
                if (args.logit_softcap != 0.0f) {
                    sum = args.logit_softcap * tanh(sum);
                }
                // FIX: the mask column is the ABSOLUTE cell index (cell_rel), not the
                // chunk-relative row i_KQ. The KV loop advances k_VKQ_0 by nthreads,
                // so indexing by i_KQ alone re-reads mask columns [0, nthreads) for
                // every chunk once nCells > nthreads. Also gate on in_range so the
                // mask row is never read past nCells (out-of-range scores are forced
                // to -FLT_MAX/2 below regardless).
                if (maskh && in_range && (args.ncols == 1 || ic0 + j < args.nTokensQ)) {
                    sum += float(maskh[(long)j * args.ne31 + cell_rel]);
                }
                if (!in_range) sum = -FLT_MAX/2.0f;
                // +ln2 of headroom keeps every exp() strictly below 1; a uniform
                // offset of the running max cancels between numerator and denominator.
                KQ_max_new[j] = max(KQ_max_new[j], sum + 0.6931f);
                if (tid_kq == (uint)i_KQ_0) {
                    KQ_tg[j * nthreads + tid] = sum; // lane's own scoreboard slot
                }
            }
        }
        // Online-softmax update (per simdgroup; cross-warp merge happens after the KV loop).
        for (int j = 0; j < args.ncols; j++) {
            KQ_max_new[j] = simd_max(KQ_max_new[j]);
            const float KQ_max_scale = exp(KQ_max[j] - KQ_max_new[j]);
            KQ_max[j] = KQ_max_new[j];
            const float kq_val = KQ_tg[j * nthreads + tid];
            const float kq_exp = exp(kq_val - KQ_max[j]);
            KQ_sum[j] = KQ_sum[j] * KQ_max_scale + kq_exp;
            KQ_tg[j * nthreads + tid] = kq_exp;
            for (int i = 0; i < 16; i++) {
                VKQ[j][i].x *= KQ_max_scale;
                VKQ[j][i].y *= KQ_max_scale;
            }
        }
        // D=256: 4 passes × 64 V-elements each = full 256-element coverage.
        //   pass 0 → V[  0.. 63] → VKQ[ 0.. 3]
        //   pass 1 → V[ 64..127] → VKQ[ 4.. 7]
        //   pass 2 → V[128..191] → VKQ[ 8..11]
        //   pass 3 → V[192..255] → VKQ[12..15]
        for (int k0 = 0; k0 < 32; k0 += V_cols_per_iter) {
            const int k        = (int)sgitg * 32 + k0 + (int)tiisg / nthreads_V;
            const int cell_rel = k_VKQ_0 + k;
            float KQ_k[2];
            for (int j = 0; j < args.ncols; j++) {
                KQ_k[j] = KQ_tg[j * nthreads + k];
            }
            device const half * V_cell = (cell_rel < args.nCells)
                ? V + (long)cell_rel * (args.nb21 / sizeof(half))
                : nullptr;
            const int v_tid = (int)tiisg % nthreads_V;
            for (int pass = 0; pass < 4; pass++) {
                for (int i = 0; i < 8; i++) {
                    const int elem = pass * 64 + v_tid * 8 + i;
                    float v_val = (V_cell && elem < D) ? float(V_cell[elem]) : 0.0f;
                    const int vkq_idx = pass * 4 + i / 2;
                    for (int j = 0; j < args.ncols; j++) {
                        if (i % 2 == 0) VKQ[j][vkq_idx].x += v_val * KQ_k[j];
                        else            VKQ[j][vkq_idx].y += v_val * KQ_k[j];
                    }
                }
            }
        }
    } // end KV loop

    // Cross-simdgroup merge: reconcile the 4 per-warp running maxima/sums,
    // rescale each warp's V accumulator, and reduce through threadgroup memory.
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (int j = 0; j < args.ncols; j++) {
        if (sgitg == 0) {
            // Lanes 4..31 stay at the identity values so the simd reductions
            // below can run over the full simdgroup.
            KQ_max_tg[j][tiisg] = -FLT_MAX/2.0f;
            KQ_sum_tg[j][tiisg] = 0.0f;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (int j = 0; j < args.ncols; j++) {
        if (tiisg == 0) {
            KQ_max_tg[j][sgitg] = KQ_max[j];
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (int j = 0; j < args.ncols; j++) {
        if (args.ncols > 1 && ic0 + j >= args.nTokensQ) break;
        float kqmax_new = KQ_max_tg[j][tiisg];
        kqmax_new = simd_max(kqmax_new);
        const float kqmax_scale = exp(KQ_max[j] - kqmax_new);
        KQ_max[j] = kqmax_new;
        for (int i = 0; i < 16; i++) {
            VKQ[j][i].x *= kqmax_scale;
            VKQ[j][i].y *= kqmax_scale;
        }
        // D=256: VKQ_tg layout per (sgitg, v-group) region of D/2 = 128 float2:
        //   VKQ[ 0.. 3] → VKQ_tg[     v_tid*4 + 0..3] (pass 0, float2 slots  0.. 31)
        //   VKQ[ 4.. 7] → VKQ_tg[32 + v_tid*4 + 0..3] (pass 1, float2 slots 32.. 63)
        //   VKQ[ 8..11] → VKQ_tg[64 + v_tid*4 + 0..3] (pass 2, float2 slots 64.. 95)
        //   VKQ[12..15] → VKQ_tg[96 + v_tid*4 + 0..3] (pass 3, float2 slots 96..127)
        const int v_tid = (int)tiisg % nthreads_V;
        threadgroup float2 * VKQ_tg = (threadgroup float2 *)KQ_tg
            + (long)sgitg * (V_cols_per_iter * D/2)
            + (long)((int)tiisg / nthreads_V) * (D/2);
        VKQ_tg[     v_tid * 4 + 0] = VKQ[j][0];
        VKQ_tg[     v_tid * 4 + 1] = VKQ[j][1];
        VKQ_tg[     v_tid * 4 + 2] = VKQ[j][2];
        VKQ_tg[     v_tid * 4 + 3] = VKQ[j][3];
        VKQ_tg[32 + v_tid * 4 + 0] = VKQ[j][4];
        VKQ_tg[32 + v_tid * 4 + 1] = VKQ[j][5];
        VKQ_tg[32 + v_tid * 4 + 2] = VKQ[j][6];
        VKQ_tg[32 + v_tid * 4 + 3] = VKQ[j][7];
        VKQ_tg[64 + v_tid * 4 + 0] = VKQ[j][8];
        VKQ_tg[64 + v_tid * 4 + 1] = VKQ[j][9];
        VKQ_tg[64 + v_tid * 4 + 2] = VKQ[j][10];
        VKQ_tg[64 + v_tid * 4 + 3] = VKQ[j][11];
        VKQ_tg[96 + v_tid * 4 + 0] = VKQ[j][12];
        VKQ_tg[96 + v_tid * 4 + 1] = VKQ[j][13];
        VKQ_tg[96 + v_tid * 4 + 2] = VKQ[j][14];
        VKQ_tg[96 + v_tid * 4 + 3] = VKQ[j][15];
        KQ_sum[j] *= kqmax_scale;
        KQ_sum[j] = simd_sum(KQ_sum[j]);
        if (tiisg == 0) {
            KQ_sum_tg[j][sgitg] = KQ_sum[j];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // D=256: each thread writes 2 output positions (tid, tid+128). Compute
        // KQ_sum with ALL lanes participating in the simd_sum (required for
        // convergence), then two writes per thread.
        KQ_sum[j] = KQ_sum_tg[j][tiisg];
        KQ_sum[j] = simd_sum(KQ_sum[j]);
        const long out_idx = ((long)sequence * args.nTokensQ + ic0 + j) * args.nHeadsQ + head;
        for (int out_offset = 0; out_offset < D; out_offset += nthreads) {
            const int out_elem = out_offset + tid;
            float dst_val = 0.0f;
            for (int w = 0; w < nwarps; w++) {
                for (int v = 0; v < V_cols_per_iter; v++) {
                    dst_val += ((threadgroup float *)KQ_tg)[w * V_cols_per_iter * D + v * D + out_elem];
                }
            }
            dst_val /= KQ_sum[j];
            dst[out_idx * D + out_elem] = dst_val;
        }
        if (j < args.ncols - 1) threadgroup_barrier(mem_flags::mem_threadgroup);
    }
}
// ─────────────────────────────────────────────────────────────────────────────
// kernel_tq_fattn_vec_packed
// TQ fused flash-attention: K packed i8, V packed i8.
@ -14351,3 +14638,278 @@ kernel void kernel_tq_fattn_vec_packed(
if (j < args.ncols - 1) threadgroup_barrier(mem_flags::mem_threadgroup);
}
}
// ─────────────────────────────────────────────────────────────────────────────
// kernel_tq_fattn_vec_packed_d256
// TQ fused flash-attention at head dim 256: K packed i8, V packed i8.
// Mirrors kernel_tq_fattn_vec_packed with 4 V-passes and 2 outputs per thread.
// ─────────────────────────────────────────────────────────────────────────────
kernel void kernel_tq_fattn_vec_packed_d256(
        constant ggml_metal_kargs_tq_fattn_vec & args,
        device const char    * Q_data    [[buffer(1)]],
        device const uint8_t * K_packed  [[buffer(2)]],
        device const uint8_t * V_packed  [[buffer(3)]],
        device const half    * mask_data [[buffer(4)]],
        device const float   * K_scales  [[buffer(5)]],
        device const float   * K_cb      [[buffer(6)]],
        device const float   * V_scales  [[buffer(7)]],
        device const float   * V_cb      [[buffer(8)]],
        device       float   * dst       [[buffer(9)]],
        uint3 tgpig [[threadgroup_position_in_grid]],
        uint  tiisg [[thread_index_in_simdgroup]],
        uint  sgitg [[simdgroup_index_in_threadgroup]])
{
    constexpr int D               = 256; // head dimension
    constexpr int nthreads        = 128; // 4 simdgroups × 32 lanes
    constexpr int nthreads_KQ     = 8;   // lanes cooperating on one K row
    constexpr int nthreads_V      = 8;   // lanes cooperating on one V row
    constexpr int V_cols_per_iter = 4;   // KV rows per simdgroup per V step
    constexpr int nwarps          = 4;

    const int ic0       = (int)tgpig.x * args.ncols; // first query column of this tile
    const int blk_z     = (int)tgpig.z;
    const int sequence  = blk_z / args.nHeadsQ;
    const int head      = blk_z % args.nHeadsQ;
    const int gqa_ratio = args.nHeadsQ / args.nKVHeads;
    const int head_kv   = head / gqa_ratio;
    const int tid       = (int)sgitg * 32 + (int)tiisg;

    device const float * Q = (device const float *)Q_data
        + (long)sequence * (args.nb03 / sizeof(float))
        + (long)head     * (args.nb02 / sizeof(float))
        + (long)ic0      * (args.nb01 / sizeof(float));
    device const uint8_t * K_p = K_packed
        + (long)args.firstCell * args.nKVHeads * args.packedBytes
        + (long)head_kv * args.packedBytes;
    device const float * K_sc = K_scales
        + (long)args.firstCell * args.nKVHeads + head_kv;
    device const uint8_t * V_p = V_packed
        + (long)args.firstCell * args.nKVHeads * args.v_packedBytes
        + (long)head_kv * args.v_packedBytes;
    device const float * V_sc = V_scales
        + (long)args.firstCell * args.nKVHeads + head_kv;
    device const half * maskh = args.hasMask
        ? (mask_data + (long)ic0 * (args.nb31 / sizeof(half)))
        : nullptr;

    // Each lane caches one codebook entry per table; decode is a simd_shuffle.
    const int   k_cb_mask = (1 << args.bits) - 1;
    const float k_cb_lane = K_cb[tiisg & k_cb_mask];
    const int   v_cb_mask = (1 << args.v_bits) - 1;
    const float v_cb_lane = V_cb[tiisg & v_cb_mask];
    const int   tid_kq    = (int)tiisg % nthreads_KQ;

    // D=256: Q_reg holds 16 float2 per thread per query slot (D/(2*nthreads_KQ)).
    float2 Q_reg[2][16];
    for (int j = 0; j < args.ncols; j++) {
        device const float2 * Q_j = (device const float2 *)(Q + (long)j * (args.nb01 / sizeof(float)));
        for (int i = 0; i < 16; i++) {
            const int elem = tid_kq * 16 + i; // float2 index within [0, 127]
            Q_reg[j][i] = (elem < D/2) ? Q_j[elem] : float2(0.0f, 0.0f);
        }
        for (int i = 0; i < 16; i++) {
            Q_reg[j][i].x *= args.scale;
            Q_reg[j][i].y *= args.scale;
        }
    }

    // D=256: VKQ holds 4 passes × 4 float2 = 16 float2 per query slot.
    float2 VKQ[2][16];
    for (int j = 0; j < 2; j++)
        for (int i = 0; i < 16; i++)
            VKQ[j][i] = float2(0.0f, 0.0f);

    float KQ_max[2] = { -FLT_MAX/2.0f, -FLT_MAX/2.0f };
    float KQ_sum[2] = { 0.0f, 0.0f };

    // D=256: KQ_tg sized nwarps*V_cols_per_iter*D = 4*4*256 = 4096 floats (16 KiB).
    threadgroup float KQ_tg[4096];
    threadgroup float KQ_max_tg[2][32];
    threadgroup float KQ_sum_tg[2][32];

    for (int k_VKQ_0 = 0; k_VKQ_0 < args.nCells; k_VKQ_0 += nthreads) {
        float KQ_max_new[2] = { KQ_max[0], KQ_max[1] };
        for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; i_KQ_0++) {
            const int kq_grp_start = ((int)tiisg & ~(nthreads_KQ - 1));
            const int i_KQ     = (int)sgitg * 32 + kq_grp_start + i_KQ_0;
            const int cell_rel = k_VKQ_0 + i_KQ;
            const bool in_range = (cell_rel < args.nCells);
            for (int j = 0; j < args.ncols; j++) {
                device const uint8_t * packed_row = K_p + (long)cell_rel * args.nKVHeads * args.packedBytes;
                const float rms_scale = in_range ? K_sc[cell_rel * args.nKVHeads] : 0.0f;
                // D=256: 16 k-iterations × 2 elements each = 32 D-positions per thread.
                float sum = 0.0f;
                for (int k = 0; k < 16; k++) {
                    const int start_elem = tid_kq * 32 + k * 2; // float index [0..254]
                    float k_dec[2];
                    if (args.bits == 3) {
                        const int bit_pos0 = start_elem * 3;
                        const int byte0 = bit_pos0 >> 3, sh0 = bit_pos0 & 7;
                        const uint w0 = (uint)packed_row[byte0] | ((uint)packed_row[byte0+1] << 8);
                        k_dec[0] = simd_shuffle(k_cb_lane, (ushort)((w0 >> sh0) & 7)) * rms_scale;
                        const int bit_pos1 = (start_elem + 1) * 3;
                        const int byte1 = bit_pos1 >> 3, sh1 = bit_pos1 & 7;
                        const uint w1 = (uint)packed_row[byte1] | ((uint)packed_row[byte1+1] << 8);
                        k_dec[1] = simd_shuffle(k_cb_lane, (ushort)((w1 >> sh1) & 7)) * rms_scale;
                    } else {
                        const int byte0 = start_elem >> 2, sh0 = (start_elem & 3) * 2;
                        k_dec[0] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte0] >> sh0) & 3)) * rms_scale;
                        const int byte1 = (start_elem + 1) >> 2, sh1 = ((start_elem + 1) & 3) * 2;
                        k_dec[1] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte1] >> sh1) & 3)) * rms_scale;
                    }
                    sum += Q_reg[j][k].x * k_dec[0] + Q_reg[j][k].y * k_dec[1];
                }
                // Reduce the 8 partial dot products of this K-row group.
                sum += simd_shuffle_xor(sum, 4);
                sum += simd_shuffle_xor(sum, 2);
                sum += simd_shuffle_xor(sum, 1);
                if (args.logit_softcap != 0.0f) {
                    sum = args.logit_softcap * tanh(sum);
                }
                // FIX: the mask column is the ABSOLUTE cell index (cell_rel), not the
                // chunk-relative row i_KQ. The KV loop advances k_VKQ_0 by nthreads,
                // so indexing by i_KQ alone re-reads mask columns [0, nthreads) for
                // every chunk once nCells > nthreads. Also gate on in_range so the
                // mask row is never read past nCells (out-of-range scores are forced
                // to -FLT_MAX/2 below regardless).
                if (maskh && in_range && (args.ncols == 1 || ic0 + j < args.nTokensQ)) {
                    sum += float(maskh[(long)j * args.ne31 + cell_rel]);
                }
                if (!in_range) sum = -FLT_MAX/2.0f;
                // +ln2 of headroom keeps every exp() strictly below 1; a uniform
                // offset of the running max cancels between numerator and denominator.
                KQ_max_new[j] = max(KQ_max_new[j], sum + 0.6931f);
                if (tid_kq == (uint)i_KQ_0) {
                    KQ_tg[j * nthreads + tid] = sum; // lane's own scoreboard slot
                }
            }
        }
        // Online-softmax update (per simdgroup; cross-warp merge happens after the KV loop).
        for (int j = 0; j < args.ncols; j++) {
            KQ_max_new[j] = simd_max(KQ_max_new[j]);
            const float KQ_max_scale = exp(KQ_max[j] - KQ_max_new[j]);
            KQ_max[j] = KQ_max_new[j];
            const float kq_val = KQ_tg[j * nthreads + tid];
            const float kq_exp = exp(kq_val - KQ_max[j]);
            KQ_sum[j] = KQ_sum[j] * KQ_max_scale + kq_exp;
            KQ_tg[j * nthreads + tid] = kq_exp;
            for (int i = 0; i < 16; i++) {
                VKQ[j][i].x *= KQ_max_scale;
                VKQ[j][i].y *= KQ_max_scale;
            }
        }
        // D=256: 4 passes cover V[0..63], V[64..127], V[128..191], V[192..255].
        for (int k0 = 0; k0 < 32; k0 += V_cols_per_iter) {
            const int k          = (int)sgitg * 32 + k0 + (int)tiisg / nthreads_V;
            const int cell_rel   = k_VKQ_0 + k;
            const bool in_range_v = (cell_rel < args.nCells);
            float KQ_k[2];
            for (int j = 0; j < args.ncols; j++) {
                KQ_k[j] = KQ_tg[j * nthreads + k];
            }
            device const uint8_t * v_row = in_range_v
                ? V_p + (long)cell_rel * args.nKVHeads * args.v_packedBytes
                : nullptr;
            const float v_rms = in_range_v ? V_sc[cell_rel * args.nKVHeads] : 0.0f;
            const int v_tid = (int)tiisg % nthreads_V;
            for (int pass = 0; pass < 4; pass++) {
                float v_dec[8];
                const int start_elem = pass * 64 + v_tid * 8;
                if (v_row && start_elem < D) {
                    tq_decode_8_shfl(v_row, v_cb_lane, v_rms, start_elem, args.v_bits, v_dec);
                } else {
                    // Out-of-range row: decode a valid address with scale 0 so all
                    // lanes stay converged for the simd shuffles; result is zeros.
                    tq_decode_8_shfl(K_p, v_cb_lane, 0.0f, 0, args.v_bits, v_dec);
                }
                for (int i = 0; i < 8; i++) {
                    const int vkq_idx = pass * 4 + i / 2;
                    for (int j = 0; j < args.ncols; j++) {
                        if (i % 2 == 0) VKQ[j][vkq_idx].x += v_dec[i] * KQ_k[j];
                        else            VKQ[j][vkq_idx].y += v_dec[i] * KQ_k[j];
                    }
                }
            }
        }
    } // end KV loop

    // Cross-simdgroup merge: reconcile the 4 per-warp running maxima/sums,
    // rescale each warp's V accumulator, and reduce through threadgroup memory.
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (int j = 0; j < args.ncols; j++) {
        if (sgitg == 0) {
            // Lanes 4..31 stay at the identity values so the simd reductions
            // below can run over the full simdgroup.
            KQ_max_tg[j][tiisg] = -FLT_MAX/2.0f;
            KQ_sum_tg[j][tiisg] = 0.0f;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (int j = 0; j < args.ncols; j++) {
        if (tiisg == 0) {
            KQ_max_tg[j][sgitg] = KQ_max[j];
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (int j = 0; j < args.ncols; j++) {
        if (args.ncols > 1 && ic0 + j >= args.nTokensQ) break;
        float kqmax_new = KQ_max_tg[j][tiisg];
        kqmax_new = simd_max(kqmax_new);
        const float kqmax_scale = exp(KQ_max[j] - kqmax_new);
        KQ_max[j] = kqmax_new;
        for (int i = 0; i < 16; i++) {
            VKQ[j][i].x *= kqmax_scale;
            VKQ[j][i].y *= kqmax_scale;
        }
        // D=256: VKQ_tg layout per (sgitg, v-group) region of D/2 = 128 float2:
        //   VKQ[ 0.. 3] → VKQ_tg[     v_tid*4 + 0..3] (pass 0, float2 slots  0.. 31)
        //   VKQ[ 4.. 7] → VKQ_tg[32 + v_tid*4 + 0..3] (pass 1, float2 slots 32.. 63)
        //   VKQ[ 8..11] → VKQ_tg[64 + v_tid*4 + 0..3] (pass 2, float2 slots 64.. 95)
        //   VKQ[12..15] → VKQ_tg[96 + v_tid*4 + 0..3] (pass 3, float2 slots 96..127)
        const int v_tid = (int)tiisg % nthreads_V;
        threadgroup float2 * VKQ_tg = (threadgroup float2 *)KQ_tg
            + (long)sgitg * (V_cols_per_iter * D/2)
            + (long)((int)tiisg / nthreads_V) * (D/2);
        VKQ_tg[     v_tid * 4 + 0] = VKQ[j][0];
        VKQ_tg[     v_tid * 4 + 1] = VKQ[j][1];
        VKQ_tg[     v_tid * 4 + 2] = VKQ[j][2];
        VKQ_tg[     v_tid * 4 + 3] = VKQ[j][3];
        VKQ_tg[32 + v_tid * 4 + 0] = VKQ[j][4];
        VKQ_tg[32 + v_tid * 4 + 1] = VKQ[j][5];
        VKQ_tg[32 + v_tid * 4 + 2] = VKQ[j][6];
        VKQ_tg[32 + v_tid * 4 + 3] = VKQ[j][7];
        VKQ_tg[64 + v_tid * 4 + 0] = VKQ[j][8];
        VKQ_tg[64 + v_tid * 4 + 1] = VKQ[j][9];
        VKQ_tg[64 + v_tid * 4 + 2] = VKQ[j][10];
        VKQ_tg[64 + v_tid * 4 + 3] = VKQ[j][11];
        VKQ_tg[96 + v_tid * 4 + 0] = VKQ[j][12];
        VKQ_tg[96 + v_tid * 4 + 1] = VKQ[j][13];
        VKQ_tg[96 + v_tid * 4 + 2] = VKQ[j][14];
        VKQ_tg[96 + v_tid * 4 + 3] = VKQ[j][15];
        KQ_sum[j] *= kqmax_scale;
        KQ_sum[j] = simd_sum(KQ_sum[j]);
        if (tiisg == 0) {
            KQ_sum_tg[j][sgitg] = KQ_sum[j];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // D=256: each thread writes 2 output positions (tid, tid+128). Compute
        // KQ_sum with ALL lanes participating in the simd_sum (required for
        // convergence), then two writes per thread.
        KQ_sum[j] = KQ_sum_tg[j][tiisg];
        KQ_sum[j] = simd_sum(KQ_sum[j]);
        const long out_idx = ((long)sequence * args.nTokensQ + ic0 + j) * args.nHeadsQ + head;
        for (int out_offset = 0; out_offset < D; out_offset += nthreads) {
            const int out_elem = out_offset + tid;
            float dst_val = 0.0f;
            for (int w = 0; w < nwarps; w++) {
                for (int v = 0; v < V_cols_per_iter; v++) {
                    dst_val += ((threadgroup float *)KQ_tg)[w * V_cols_per_iter * D + v * D + out_elem];
                }
            }
            dst_val /= KQ_sum[j];
            dst[out_idx * D + out_elem] = dst_val;
        }
        if (j < args.ncols - 1) threadgroup_barrier(mem_flags::mem_threadgroup);
    }
}

View file

@ -4602,9 +4602,16 @@ int ggml_metal_op_tq_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
/*.nb31 =*/ mask ? mask->nb[1] : 0,
};
// Select D=128 vs D=256 pipeline. Gemma3 runs at headDim=256; everything
// else supported so far is D=128.
GGML_ASSERT(D == 128 || D == 256);
auto pipeline = v_packed
? ggml_metal_library_get_pipeline_tq_fattn_vec_packed(lib)
: ggml_metal_library_get_pipeline_tq_fattn_vec_f16(lib);
? (D == 256
? ggml_metal_library_get_pipeline_tq_fattn_vec_packed_d256(lib)
: ggml_metal_library_get_pipeline_tq_fattn_vec_packed(lib))
: (D == 256
? ggml_metal_library_get_pipeline_tq_fattn_vec_f16_d256(lib)
: ggml_metal_library_get_pipeline_tq_fattn_vec_f16(lib));
ggml_metal_buffer_id bid_mask = hasMask ? ggml_metal_get_buffer_id(mask) : ggml_metal_get_buffer_id(op);
ggml_metal_buffer_id bid_v_scales = v_packed ? ggml_metal_get_buffer_id(op->src[6]) : ggml_metal_get_buffer_id(op);

View file

@ -11184,6 +11184,293 @@ kernel void kernel_tq_fattn_vec_f16(
}
}
// ─────────────────────────────────────────────────────────────────────────────
// kernel_tq_fattn_vec_f16_d256
// TQ fused flash-attention at head dim 256: K packed i8, V f16.
// Thread layout identical to the D=128 variant (32×4 = 128 threads) but each
// thread now produces 2 output elements (D/nthreads = 2) and covers twice as
// many D positions in the Q/K/V loops.
// Grid: (ntiles_x, 1, nHeadsQ*nSeq)
// ─────────────────────────────────────────────────────────────────────────────
kernel void kernel_tq_fattn_vec_f16_d256(
        constant ggml_metal_kargs_tq_fattn_vec & args,
        device const char    * Q_data    [[buffer(1)]],
        device const uint8_t * K_packed  [[buffer(2)]],
        device const half    * V_data    [[buffer(3)]],
        device const half    * mask_data [[buffer(4)]],
        device const float   * K_scales  [[buffer(5)]],
        device const float   * K_cb      [[buffer(6)]],
        device const float   * dummy_vs  [[buffer(7)]],
        device const float   * dummy_vc  [[buffer(8)]],
        device       float   * dst       [[buffer(9)]],
        uint3 tgpig [[threadgroup_position_in_grid]],
        uint  tiisg [[thread_index_in_simdgroup]],
        uint  sgitg [[simdgroup_index_in_threadgroup]])
{
    constexpr int D               = 256; // head dimension
    constexpr int nthreads        = 128; // 4 simdgroups × 32 lanes
    constexpr int nthreads_KQ     = 8;   // lanes cooperating on one K row
    constexpr int nthreads_V      = 8;   // lanes cooperating on one V row
    constexpr int V_cols_per_iter = 4;   // KV rows per simdgroup per V step
    constexpr int nwarps          = 4;

    const int ic0       = (int)tgpig.x * args.ncols; // first query column of this tile
    const int blk_z     = (int)tgpig.z;
    const int sequence  = blk_z / args.nHeadsQ;
    const int head      = blk_z % args.nHeadsQ;
    const int gqa_ratio = args.nHeadsQ / args.nKVHeads;
    const int head_kv   = head / gqa_ratio;
    const int tid       = (int)sgitg * 32 + (int)tiisg;

    device const float * Q = (device const float *)Q_data
        + (long)sequence * (args.nb03 / sizeof(float))
        + (long)head     * (args.nb02 / sizeof(float))
        + (long)ic0      * (args.nb01 / sizeof(float));
    device const uint8_t * K_p = K_packed
        + (long)args.firstCell * args.nKVHeads * args.packedBytes
        + (long)head_kv * args.packedBytes;
    device const float * K_sc = K_scales
        + (long)args.firstCell * args.nKVHeads + head_kv;
    device const half * V = V_data
        + (long)sequence * (args.nb23 / sizeof(half))
        + (long)head_kv  * (args.nb22 / sizeof(half));
    device const half * maskh = args.hasMask
        ? (mask_data + (long)ic0 * (args.nb31 / sizeof(half)))
        : nullptr;

    // Each lane caches one K-codebook entry; decode is a simd_shuffle by index.
    const int   k_cb_mask = (1 << args.bits) - 1;
    const float k_cb_lane = K_cb[tiisg & k_cb_mask];
    const int   tid_kq    = (int)tiisg % nthreads_KQ;

    // D=256: Q_reg holds 16 float2 per thread per query slot (D/(2*nthreads_KQ)).
    float2 Q_reg[2][16];
    for (int j = 0; j < args.ncols; j++) {
        device const float2 * Q_j = (device const float2 *)(Q + (long)j * (args.nb01 / sizeof(float)));
        for (int i = 0; i < 16; i++) {
            const int elem = tid_kq * 16 + i; // float2 index within [0, 127]
            Q_reg[j][i] = (elem < D/2) ? Q_j[elem] : float2(0.0f, 0.0f);
        }
        for (int i = 0; i < 16; i++) {
            Q_reg[j][i].x *= args.scale;
            Q_reg[j][i].y *= args.scale;
        }
    }

    // D=256: VKQ holds 4 passes × 4 float2 = 16 float2 per query slot.
    float2 VKQ[2][16];
    for (int j = 0; j < 2; j++)
        for (int i = 0; i < 16; i++)
            VKQ[j][i] = float2(0.0f, 0.0f);

    float KQ_max[2] = { -FLT_MAX/2.0f, -FLT_MAX/2.0f };
    float KQ_sum[2] = { 0.0f, 0.0f };

    // D=256: KQ_tg sized nwarps*V_cols_per_iter*D = 4*4*256 = 4096 floats (16 KiB).
    threadgroup float KQ_tg[4096];
    threadgroup float KQ_max_tg[2][32];
    threadgroup float KQ_sum_tg[2][32];

    for (int k_VKQ_0 = 0; k_VKQ_0 < args.nCells; k_VKQ_0 += nthreads) {
        float KQ_max_new[2] = { KQ_max[0], KQ_max[1] };
        for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; i_KQ_0++) {
            const int kq_grp_start = ((int)tiisg & ~(nthreads_KQ - 1));
            const int i_KQ     = (int)sgitg * 32 + kq_grp_start + i_KQ_0;
            const int cell_rel = k_VKQ_0 + i_KQ;
            const bool in_range = (cell_rel < args.nCells);
            for (int j = 0; j < args.ncols; j++) {
                device const uint8_t * packed_row = K_p + (long)cell_rel * args.nKVHeads * args.packedBytes;
                const float rms_scale = in_range ? K_sc[cell_rel * args.nKVHeads] : 0.0f;
                // D=256: 16 k-iterations × 2 elements each = 32 D-positions per thread.
                float sum = 0.0f;
                for (int k = 0; k < 16; k++) {
                    const int start_elem = tid_kq * 32 + k * 2; // float index [0..254]
                    float k_dec[2];
                    if (args.bits == 3) {
                        const int bit_pos0 = start_elem * 3;
                        const int byte0 = bit_pos0 >> 3, sh0 = bit_pos0 & 7;
                        const uint w0 = (uint)packed_row[byte0] | ((uint)packed_row[byte0+1] << 8);
                        int idx0 = (int)((w0 >> sh0) & 7);
                        k_dec[0] = simd_shuffle(k_cb_lane, (ushort)idx0) * rms_scale;
                        const int bit_pos1 = (start_elem + 1) * 3;
                        const int byte1 = bit_pos1 >> 3, sh1 = bit_pos1 & 7;
                        const uint w1 = (uint)packed_row[byte1] | ((uint)packed_row[byte1+1] << 8);
                        int idx1 = (int)((w1 >> sh1) & 7);
                        k_dec[1] = simd_shuffle(k_cb_lane, (ushort)idx1) * rms_scale;
                    } else {
                        const int byte0 = start_elem >> 2, sh0 = (start_elem & 3) * 2;
                        k_dec[0] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte0] >> sh0) & 3)) * rms_scale;
                        const int byte1 = (start_elem + 1) >> 2, sh1 = ((start_elem + 1) & 3) * 2;
                        k_dec[1] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte1] >> sh1) & 3)) * rms_scale;
                    }
                    sum += Q_reg[j][k].x * k_dec[0] + Q_reg[j][k].y * k_dec[1];
                }
                // Reduce the 8 partial dot products of this K-row group.
                sum += simd_shuffle_xor(sum, 4);
                sum += simd_shuffle_xor(sum, 2);
                sum += simd_shuffle_xor(sum, 1);
                if (args.logit_softcap != 0.0f) {
                    sum = args.logit_softcap * tanh(sum);
                }
                // FIX: the mask column is the ABSOLUTE cell index (cell_rel), not the
                // chunk-relative row i_KQ. The KV loop advances k_VKQ_0 by nthreads,
                // so indexing by i_KQ alone re-reads mask columns [0, nthreads) for
                // every chunk once nCells > nthreads. Also gate on in_range so the
                // mask row is never read past nCells (out-of-range scores are forced
                // to -FLT_MAX/2 below regardless).
                if (maskh && in_range && (args.ncols == 1 || ic0 + j < args.nTokensQ)) {
                    sum += float(maskh[(long)j * args.ne31 + cell_rel]);
                }
                if (!in_range) sum = -FLT_MAX/2.0f;
                // +ln2 of headroom keeps every exp() strictly below 1; a uniform
                // offset of the running max cancels between numerator and denominator.
                KQ_max_new[j] = max(KQ_max_new[j], sum + 0.6931f);
                if (tid_kq == (uint)i_KQ_0) {
                    KQ_tg[j * nthreads + tid] = sum; // lane's own scoreboard slot
                }
            }
        }
        // Online-softmax update (per simdgroup; cross-warp merge happens after the KV loop).
        for (int j = 0; j < args.ncols; j++) {
            KQ_max_new[j] = simd_max(KQ_max_new[j]);
            const float KQ_max_scale = exp(KQ_max[j] - KQ_max_new[j]);
            KQ_max[j] = KQ_max_new[j];
            const float kq_val = KQ_tg[j * nthreads + tid];
            const float kq_exp = exp(kq_val - KQ_max[j]);
            KQ_sum[j] = KQ_sum[j] * KQ_max_scale + kq_exp;
            KQ_tg[j * nthreads + tid] = kq_exp;
            for (int i = 0; i < 16; i++) {
                VKQ[j][i].x *= KQ_max_scale;
                VKQ[j][i].y *= KQ_max_scale;
            }
        }
        // D=256: 4 passes × 64 V-elements each = full 256-element coverage.
        //   pass 0 → V[  0.. 63] → VKQ[ 0.. 3]
        //   pass 1 → V[ 64..127] → VKQ[ 4.. 7]
        //   pass 2 → V[128..191] → VKQ[ 8..11]
        //   pass 3 → V[192..255] → VKQ[12..15]
        for (int k0 = 0; k0 < 32; k0 += V_cols_per_iter) {
            const int k        = (int)sgitg * 32 + k0 + (int)tiisg / nthreads_V;
            const int cell_rel = k_VKQ_0 + k;
            float KQ_k[2];
            for (int j = 0; j < args.ncols; j++) {
                KQ_k[j] = KQ_tg[j * nthreads + k];
            }
            device const half * V_cell = (cell_rel < args.nCells)
                ? V + (long)cell_rel * (args.nb21 / sizeof(half))
                : nullptr;
            const int v_tid = (int)tiisg % nthreads_V;
            for (int pass = 0; pass < 4; pass++) {
                for (int i = 0; i < 8; i++) {
                    const int elem = pass * 64 + v_tid * 8 + i;
                    float v_val = (V_cell && elem < D) ? float(V_cell[elem]) : 0.0f;
                    const int vkq_idx = pass * 4 + i / 2;
                    for (int j = 0; j < args.ncols; j++) {
                        if (i % 2 == 0) VKQ[j][vkq_idx].x += v_val * KQ_k[j];
                        else            VKQ[j][vkq_idx].y += v_val * KQ_k[j];
                    }
                }
            }
        }
    } // end KV loop

    // Cross-simdgroup merge: reconcile the 4 per-warp running maxima/sums,
    // rescale each warp's V accumulator, and reduce through threadgroup memory.
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (int j = 0; j < args.ncols; j++) {
        if (sgitg == 0) {
            // Lanes 4..31 stay at the identity values so the simd reductions
            // below can run over the full simdgroup.
            KQ_max_tg[j][tiisg] = -FLT_MAX/2.0f;
            KQ_sum_tg[j][tiisg] = 0.0f;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (int j = 0; j < args.ncols; j++) {
        if (tiisg == 0) {
            KQ_max_tg[j][sgitg] = KQ_max[j];
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (int j = 0; j < args.ncols; j++) {
        if (args.ncols > 1 && ic0 + j >= args.nTokensQ) break;
        float kqmax_new = KQ_max_tg[j][tiisg];
        kqmax_new = simd_max(kqmax_new);
        const float kqmax_scale = exp(KQ_max[j] - kqmax_new);
        KQ_max[j] = kqmax_new;
        for (int i = 0; i < 16; i++) {
            VKQ[j][i].x *= kqmax_scale;
            VKQ[j][i].y *= kqmax_scale;
        }
        // D=256: VKQ_tg layout per (sgitg, v-group) region of D/2 = 128 float2:
        //   VKQ[ 0.. 3] → VKQ_tg[     v_tid*4 + 0..3] (pass 0, float2 slots  0.. 31)
        //   VKQ[ 4.. 7] → VKQ_tg[32 + v_tid*4 + 0..3] (pass 1, float2 slots 32.. 63)
        //   VKQ[ 8..11] → VKQ_tg[64 + v_tid*4 + 0..3] (pass 2, float2 slots 64.. 95)
        //   VKQ[12..15] → VKQ_tg[96 + v_tid*4 + 0..3] (pass 3, float2 slots 96..127)
        const int v_tid = (int)tiisg % nthreads_V;
        threadgroup float2 * VKQ_tg = (threadgroup float2 *)KQ_tg
            + (long)sgitg * (V_cols_per_iter * D/2)
            + (long)((int)tiisg / nthreads_V) * (D/2);
        VKQ_tg[     v_tid * 4 + 0] = VKQ[j][0];
        VKQ_tg[     v_tid * 4 + 1] = VKQ[j][1];
        VKQ_tg[     v_tid * 4 + 2] = VKQ[j][2];
        VKQ_tg[     v_tid * 4 + 3] = VKQ[j][3];
        VKQ_tg[32 + v_tid * 4 + 0] = VKQ[j][4];
        VKQ_tg[32 + v_tid * 4 + 1] = VKQ[j][5];
        VKQ_tg[32 + v_tid * 4 + 2] = VKQ[j][6];
        VKQ_tg[32 + v_tid * 4 + 3] = VKQ[j][7];
        VKQ_tg[64 + v_tid * 4 + 0] = VKQ[j][8];
        VKQ_tg[64 + v_tid * 4 + 1] = VKQ[j][9];
        VKQ_tg[64 + v_tid * 4 + 2] = VKQ[j][10];
        VKQ_tg[64 + v_tid * 4 + 3] = VKQ[j][11];
        VKQ_tg[96 + v_tid * 4 + 0] = VKQ[j][12];
        VKQ_tg[96 + v_tid * 4 + 1] = VKQ[j][13];
        VKQ_tg[96 + v_tid * 4 + 2] = VKQ[j][14];
        VKQ_tg[96 + v_tid * 4 + 3] = VKQ[j][15];
        KQ_sum[j] *= kqmax_scale;
        KQ_sum[j] = simd_sum(KQ_sum[j]);
        if (tiisg == 0) {
            KQ_sum_tg[j][sgitg] = KQ_sum[j];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // D=256: each thread writes 2 output positions (tid, tid+128). Compute
        // KQ_sum with ALL lanes participating in the simd_sum (required for
        // convergence), then two writes per thread.
        KQ_sum[j] = KQ_sum_tg[j][tiisg];
        KQ_sum[j] = simd_sum(KQ_sum[j]);
        const long out_idx = ((long)sequence * args.nTokensQ + ic0 + j) * args.nHeadsQ + head;
        for (int out_offset = 0; out_offset < D; out_offset += nthreads) {
            const int out_elem = out_offset + tid;
            float dst_val = 0.0f;
            for (int w = 0; w < nwarps; w++) {
                for (int v = 0; v < V_cols_per_iter; v++) {
                    dst_val += ((threadgroup float *)KQ_tg)[w * V_cols_per_iter * D + v * D + out_elem];
                }
            }
            dst_val /= KQ_sum[j];
            dst[out_idx * D + out_elem] = dst_val;
        }
        if (j < args.ncols - 1) threadgroup_barrier(mem_flags::mem_threadgroup);
    }
}
// ─────────────────────────────────────────────────────────────────────────────
// kernel_tq_fattn_vec_packed
// TQ fused flash-attention: K packed i8, V packed i8.
@ -11451,3 +11738,278 @@ kernel void kernel_tq_fattn_vec_packed(
if (j < args.ncols - 1) threadgroup_barrier(mem_flags::mem_threadgroup);
}
}
// ─────────────────────────────────────────────────────────────────────────────
// kernel_tq_fattn_vec_packed_d256
// TQ fused flash-attention at head dim 256: K packed i8, V packed i8.
// Mirrors kernel_tq_fattn_vec_packed (D=128) with 4 V-passes and 2 output
// elements per thread. Thread layout is unchanged (4 simdgroups x 32 lanes =
// 128 threads); shared memory doubles to 4096 floats (16 KiB).
//
// Fix: the mask is indexed with the ABSOLUTE cell index (k_VKQ_0 + i_KQ),
// not the block-relative i_KQ. Indexing with i_KQ alone reuses the first
// 128 mask columns for every subsequent 128-cell block, breaking causal
// masking at contexts > 128 cells.
// ─────────────────────────────────────────────────────────────────────────────
kernel void kernel_tq_fattn_vec_packed_d256(
        constant ggml_metal_kargs_tq_fattn_vec & args,
        device const char    * Q_data    [[buffer(1)]],
        device const uint8_t * K_packed  [[buffer(2)]],
        device const uint8_t * V_packed  [[buffer(3)]],
        device const half    * mask_data [[buffer(4)]],
        device const float   * K_scales  [[buffer(5)]],
        device const float   * K_cb      [[buffer(6)]],
        device const float   * V_scales  [[buffer(7)]],
        device const float   * V_cb      [[buffer(8)]],
        device float         * dst       [[buffer(9)]],
        uint3 tgpig [[threadgroup_position_in_grid]],
        uint  tiisg [[thread_index_in_simdgroup]],
        uint  sgitg [[simdgroup_index_in_threadgroup]])
{
    constexpr int D               = 256; // head dim
    constexpr int nthreads        = 128; // 4 simdgroups x 32 lanes
    constexpr int nthreads_KQ     = 8;   // lanes cooperating on one K row
    constexpr int nthreads_V      = 8;   // lanes cooperating on one V row
    constexpr int V_cols_per_iter = 4;   // KV rows per simdgroup per V step
    constexpr int nwarps          = 4;

    const int ic0       = (int)tgpig.x * args.ncols; // first query row of this tile
    const int blk_z     = (int)tgpig.z;
    const int sequence  = blk_z / args.nHeadsQ;
    const int head      = blk_z % args.nHeadsQ;
    const int gqa_ratio = args.nHeadsQ / args.nKVHeads; // Q heads per KV head
    const int head_kv   = head / gqa_ratio;
    const int tid       = (int)sgitg * 32 + (int)tiisg; // 0..127 within threadgroup

    // Base pointers for this (sequence, head); KV pointers start at firstCell.
    device const float * Q = (device const float *)Q_data
        + (long)sequence * (args.nb03 / sizeof(float))
        + (long)head     * (args.nb02 / sizeof(float))
        + (long)ic0      * (args.nb01 / sizeof(float));
    device const uint8_t * K_p = K_packed
        + (long)args.firstCell * args.nKVHeads * args.packedBytes
        + (long)head_kv * args.packedBytes;
    device const float * K_sc = K_scales
        + (long)args.firstCell * args.nKVHeads + head_kv;
    device const uint8_t * V_p = V_packed
        + (long)args.firstCell * args.nKVHeads * args.v_packedBytes
        + (long)head_kv * args.v_packedBytes;
    device const float * V_sc = V_scales
        + (long)args.firstCell * args.nKVHeads + head_kv;
    device const half * maskh = args.hasMask
        ? (mask_data + (long)ic0 * (args.nb31 / sizeof(half)))
        : nullptr;

    // Each lane preloads one codebook entry; decode then gathers entries with
    // simd_shuffle instead of per-element memory loads.
    const int   k_cb_mask = (1 << args.bits) - 1;
    const float k_cb_lane = K_cb[tiisg & k_cb_mask];
    const int   v_cb_mask = (1 << args.v_bits) - 1;
    const float v_cb_lane = V_cb[tiisg & v_cb_mask];

    // D=256: each of the 8 KQ lanes owns 16 float2 = 32 scalar Q positions.
    const int tid_kq = (int)tiisg % nthreads_KQ;
    float2 Q_reg[2][16];
    for (int j = 0; j < args.ncols; j++) {
        device const float2 * Q_j = (device const float2 *)(Q + (long)j * (args.nb01 / sizeof(float)));
        for (int i = 0; i < 16; i++) {
            const int elem = tid_kq * 16 + i;
            Q_reg[j][i] = (elem < D/2) ? Q_j[elem] : float2(0.0f, 0.0f);
        }
        // Fold the softmax scale into Q once, up front.
        for (int i = 0; i < 16; i++) {
            Q_reg[j][i].x *= args.scale;
            Q_reg[j][i].y *= args.scale;
        }
    }

    // Online-softmax state: V*softmax accumulator, running max, running sum.
    float2 VKQ[2][16];
    for (int j = 0; j < 2; j++)
        for (int i = 0; i < 16; i++)
            VKQ[j][i] = float2(0.0f, 0.0f);
    float KQ_max[2] = { -FLT_MAX/2.0f, -FLT_MAX/2.0f };
    float KQ_sum[2] = { 0.0f, 0.0f };

    // KQ_tg plays two roles: per-block exp-weights during the KV loop
    // (simdgroup-local, first ncols*128 floats), then the VKQ staging area
    // (all 4096 floats = nwarps * V_cols_per_iter * D) after the final barrier.
    threadgroup float KQ_tg[4096];
    threadgroup float KQ_max_tg[2][32];
    threadgroup float KQ_sum_tg[2][32];

    // ── KV loop: 128 cells per step (4 simdgroups x 4 groups x 8 rows). ──
    for (int k_VKQ_0 = 0; k_VKQ_0 < args.nCells; k_VKQ_0 += nthreads) {
        float KQ_max_new[2] = { KQ_max[0], KQ_max[1] };

        for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; i_KQ_0++) {
            const int kq_grp_start = ((int)tiisg & ~(nthreads_KQ - 1));
            const int i_KQ     = (int)sgitg * 32 + kq_grp_start + i_KQ_0; // block-relative row
            const int cell_rel = k_VKQ_0 + i_KQ;                          // absolute cell index
            const bool in_range = (cell_rel < args.nCells);
            for (int j = 0; j < args.ncols; j++) {
                device const uint8_t * packed_row = K_p + (long)cell_rel * args.nKVHeads * args.packedBytes;
                const float rms_scale = in_range ? K_sc[cell_rel * args.nKVHeads] : 0.0f;
                float sum = 0.0f;
                for (int k = 0; k < 16; k++) {
                    const int start_elem = tid_kq * 32 + k * 2;
                    float k_dec[2];
                    if (args.bits == 3) {
                        // 3-bit codes straddle byte boundaries: read 16 bits, shift.
                        const int bit_pos0 = start_elem * 3;
                        const int byte0 = bit_pos0 >> 3, sh0 = bit_pos0 & 7;
                        const uint w0 = (uint)packed_row[byte0] | ((uint)packed_row[byte0+1] << 8);
                        k_dec[0] = simd_shuffle(k_cb_lane, (ushort)((w0 >> sh0) & 7)) * rms_scale;
                        const int bit_pos1 = (start_elem + 1) * 3;
                        const int byte1 = bit_pos1 >> 3, sh1 = bit_pos1 & 7;
                        const uint w1 = (uint)packed_row[byte1] | ((uint)packed_row[byte1+1] << 8);
                        k_dec[1] = simd_shuffle(k_cb_lane, (ushort)((w1 >> sh1) & 7)) * rms_scale;
                    } else {
                        // 2-bit codes: 4 per byte.
                        const int byte0 = start_elem >> 2, sh0 = (start_elem & 3) * 2;
                        k_dec[0] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte0] >> sh0) & 3)) * rms_scale;
                        const int byte1 = (start_elem + 1) >> 2, sh1 = ((start_elem + 1) & 3) * 2;
                        k_dec[1] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte1] >> sh1) & 3)) * rms_scale;
                    }
                    sum += Q_reg[j][k].x * k_dec[0] + Q_reg[j][k].y * k_dec[1];
                }
                // Reduce the partial dot product across the 8-lane KQ group.
                sum += simd_shuffle_xor(sum, 4);
                sum += simd_shuffle_xor(sum, 2);
                sum += simd_shuffle_xor(sum, 1);
                if (args.logit_softcap != 0.0f) {
                    sum = args.logit_softcap * tanh(sum);
                }
                // BUGFIX: index the mask by the absolute cell (cell_rel), not
                // the block-relative i_KQ — maskh is never advanced by k_VKQ_0,
                // so i_KQ would reuse the first 128 mask columns for every
                // block. Also guarded by in_range so out-of-range lanes never
                // read past the mask row (their score is clobbered below).
                if (in_range && maskh && (args.ncols == 1 || ic0 + j < args.nTokensQ)) {
                    sum += float(maskh[(long)j * args.ne31 + cell_rel]);
                }
                if (!in_range) sum = -FLT_MAX/2.0f;
                // Running max is biased up by ln 2 for exp() headroom; the bias
                // cancels between softmax numerator and denominator.
                KQ_max_new[j] = max(KQ_max_new[j], sum + 0.6931f);
                // The lane whose group slot matches this row owns its score.
                if (tid_kq == (uint)i_KQ_0) {
                    KQ_tg[j * nthreads + tid] = sum;
                }
            }
        }

        // Rescale running state to the new max and exponentiate the scores.
        for (int j = 0; j < args.ncols; j++) {
            KQ_max_new[j] = simd_max(KQ_max_new[j]);
            const float KQ_max_scale = exp(KQ_max[j] - KQ_max_new[j]);
            KQ_max[j] = KQ_max_new[j];
            const float kq_val = KQ_tg[j * nthreads + tid];
            const float kq_exp = exp(kq_val - KQ_max[j]);
            KQ_sum[j] = KQ_sum[j] * KQ_max_scale + kq_exp;
            KQ_tg[j * nthreads + tid] = kq_exp;
            for (int i = 0; i < 16; i++) {
                VKQ[j][i].x *= KQ_max_scale;
                VKQ[j][i].y *= KQ_max_scale;
            }
        }

        // The V pass below reads KQ_tg entries written by OTHER lanes of the
        // same simdgroup; all KQ_tg traffic in this loop is simdgroup-local,
        // so a simdgroup barrier (not a full threadgroup barrier) suffices.
        simdgroup_barrier(mem_flags::mem_threadgroup);

        // D=256: 4 passes cover V[0..63], V[64..127], V[128..191], V[192..255].
        for (int k0 = 0; k0 < 32; k0 += V_cols_per_iter) {
            const int k = (int)sgitg * 32 + k0 + (int)tiisg / nthreads_V;
            const int cell_rel = k_VKQ_0 + k;
            const bool in_range_v = (cell_rel < args.nCells);
            float KQ_k[2];
            for (int j = 0; j < args.ncols; j++) {
                KQ_k[j] = KQ_tg[j * nthreads + k];
            }
            device const uint8_t * v_row = in_range_v
                ? V_p + (long)cell_rel * args.nKVHeads * args.v_packedBytes
                : nullptr;
            const float v_rms = in_range_v ? V_sc[cell_rel * args.nKVHeads] : 0.0f;
            const int v_tid = (int)tiisg % nthreads_V;
            for (int pass = 0; pass < 4; pass++) {
                float v_dec[8];
                const int start_elem = pass * 64 + v_tid * 8;
                if (v_row && start_elem < D) {
                    tq_decode_8_shfl(v_row, v_cb_lane, v_rms, start_elem, args.v_bits, v_dec);
                } else {
                    // Dummy decode with scale 0 keeps the shuffle pattern
                    // uniform across the simdgroup; contributes nothing.
                    tq_decode_8_shfl(K_p, v_cb_lane, 0.0f, 0, args.v_bits, v_dec);
                }
                for (int i = 0; i < 8; i++) {
                    // Element start_elem+i lives in register float2 pass*4+i/2.
                    const int vkq_idx = pass * 4 + i / 2;
                    for (int j = 0; j < args.ncols; j++) {
                        if (i % 2 == 0) VKQ[j][vkq_idx].x += v_dec[i] * KQ_k[j];
                        else            VKQ[j][vkq_idx].y += v_dec[i] * KQ_k[j];
                    }
                }
            }
        }
    } // end KV loop

    // ── Cross-simdgroup reduction and output (2 elements per thread). ──
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Neutral-init the reduction slots; lanes 4..31 stay neutral so all 32
    // lanes can participate in the simd reductions below.
    for (int j = 0; j < args.ncols; j++) {
        if (sgitg == 0) {
            KQ_max_tg[j][tiisg] = -FLT_MAX/2.0f;
            KQ_sum_tg[j][tiisg] = 0.0f;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (int j = 0; j < args.ncols; j++) {
        if (tiisg == 0) {
            KQ_max_tg[j][sgitg] = KQ_max[j];
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (int j = 0; j < args.ncols; j++) {
        // Skip padding query rows; the condition is uniform across the
        // threadgroup, so the break keeps barrier usage convergent.
        if (args.ncols > 1 && ic0 + j >= args.nTokensQ) break;
        // Global max across the 4 simdgroups, then rescale local state to it.
        float kqmax_new = KQ_max_tg[j][tiisg];
        kqmax_new = simd_max(kqmax_new);
        const float kqmax_scale = exp(KQ_max[j] - kqmax_new);
        KQ_max[j] = kqmax_new;
        for (int i = 0; i < 16; i++) {
            VKQ[j][i].x *= kqmax_scale;
            VKQ[j][i].y *= kqmax_scale;
        }
        // Stage this warp's partial VKQ rows in shared memory. Layout: per
        // warp, V_cols_per_iter rows of D/2 float2; pass p of lane v_tid
        // lands at float2 offset p*32 + v_tid*4 within its row.
        const int v_tid = (int)tiisg % nthreads_V;
        threadgroup float2 * VKQ_tg = (threadgroup float2 *)KQ_tg
            + (long)sgitg * (V_cols_per_iter * D/2)
            + (long)((int)tiisg / nthreads_V) * (D/2);
        VKQ_tg[     v_tid * 4 + 0] = VKQ[j][0];
        VKQ_tg[     v_tid * 4 + 1] = VKQ[j][1];
        VKQ_tg[     v_tid * 4 + 2] = VKQ[j][2];
        VKQ_tg[     v_tid * 4 + 3] = VKQ[j][3];
        VKQ_tg[32 + v_tid * 4 + 0] = VKQ[j][4];
        VKQ_tg[32 + v_tid * 4 + 1] = VKQ[j][5];
        VKQ_tg[32 + v_tid * 4 + 2] = VKQ[j][6];
        VKQ_tg[32 + v_tid * 4 + 3] = VKQ[j][7];
        VKQ_tg[64 + v_tid * 4 + 0] = VKQ[j][8];
        VKQ_tg[64 + v_tid * 4 + 1] = VKQ[j][9];
        VKQ_tg[64 + v_tid * 4 + 2] = VKQ[j][10];
        VKQ_tg[64 + v_tid * 4 + 3] = VKQ[j][11];
        VKQ_tg[96 + v_tid * 4 + 0] = VKQ[j][12];
        VKQ_tg[96 + v_tid * 4 + 1] = VKQ[j][13];
        VKQ_tg[96 + v_tid * 4 + 2] = VKQ[j][14];
        VKQ_tg[96 + v_tid * 4 + 3] = VKQ[j][15];
        // Softmax denominator: warp-local sum -> one slot per warp -> full
        // 32-lane simd_sum (unused slots are 0). ALL lanes must participate
        // in the simd_sum calls for convergence.
        KQ_sum[j] *= kqmax_scale;
        KQ_sum[j] = simd_sum(KQ_sum[j]);
        if (tiisg == 0) {
            KQ_sum_tg[j][sgitg] = KQ_sum[j];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        KQ_sum[j] = KQ_sum_tg[j][tiisg];
        KQ_sum[j] = simd_sum(KQ_sum[j]);
        // D=256: each thread writes 2 output positions (tid, tid+128), summing
        // the 16 staged partial rows (4 warps x V_cols_per_iter rows).
        const long out_idx = ((long)sequence * args.nTokensQ + ic0 + j) * args.nHeadsQ + head;
        for (int out_offset = 0; out_offset < D; out_offset += nthreads) {
            const int out_elem = out_offset + tid;
            float dst_val = 0.0f;
            for (int w = 0; w < nwarps; w++) {
                for (int v = 0; v < V_cols_per_iter; v++) {
                    dst_val += ((threadgroup float *)KQ_tg)[w * V_cols_per_iter * D + v * D + out_elem];
                }
            }
            dst_val /= KQ_sum[j];
            dst[out_idx * D + out_elem] = dst_val;
        }
        // KQ_tg is reused for the next query row; resynchronize before it.
        if (j < args.ncols - 1) threadgroup_barrier(mem_flags::mem_threadgroup);
    }
}

View file

@ -342,7 +342,16 @@ func (m *ggmlTQCompressedK) DequantK(ctx ml.Context, layer int, encodeResult ml.
// inline-decode path is slower than DequantKV + stock FA on all measured
// hardware — DequantKV is always preferred when available.
func (m *ggmlTQCompressedK) fusedKernelSupports() bool {
if m.headDim != 128 {
// D=128 on all backends; D=256 only on Metal (kernel_tq_fattn_vec_*{,_d256}).
// CUDA still has only the D=128 kernel, so gemma3 (D=256) stays off the
// fused path on CUDA.
switch m.headDim {
case 128:
case 256:
if !m.preferFusedAttention {
return false
}
default:
return false
}
if m.bits != 2 && m.bits != 3 {