mirror of
https://github.com/ollama/ollama
synced 2026-04-23 08:45:14 +00:00
ml/backend/ggml: add D=256 TurboQuant flash-attention kernels on Metal
The fused inline-decode kernels were only instantiated at D=128. gemma3 runs at headDim=256 and was therefore forced onto the DequantKV + stock FA path — which materialises a ~3.5 GB f16 intermediate K+V tensor at 32k context and OOMs on Apple Silicon with tq3. Adds D=256 variants of the two fused kernels: * kernel_tq_fattn_vec_f16_d256 * kernel_tq_fattn_vec_packed_d256 Thread layout is unchanged (32x4 = 128 threads); each thread now covers 32 D-positions in the Q and K loops, accumulates four V-passes, and writes two output elements in the final phase. Shared memory grows from 2048 to 4096 floats (16 KiB, well under the 32 KiB per-threadgroup limit). The dispatch picks D=128 or D=256 based on Q->ne[0]. Go-side eligibility extends to headDim==256 only when preferFusedAttention is true (Metal); CUDA continues to gate on headDim==128 until the CUDA kernel gains the same instantiation.
This commit is contained in:
parent
76e5fc2b75
commit
0c1f7f108d
|
|
@ -38,12 +38,14 @@ type TurboQuantCache struct {
|
|||
logPathOnce [5]sync.Once
|
||||
|
||||
// fusedFallbackEligible gates the inline-decode fused-FA fallback paths
|
||||
// (Get paths 2 and 4). Those paths dispatch to a CUDA kernel that is
|
||||
// template-instantiated only at D=128, so any model with a larger head
|
||||
// dim (gemma3 D=256, gemma4 D=512) must skip them to avoid a kernel-side
|
||||
// GGML_ASSERT. The DequantK + stock FA path (Get path 0/1/5) works at
|
||||
// any head dim — this gate is specific to the inline-decode variants.
|
||||
// Remove once the fused kernels gain D=256/512 template instantiations.
|
||||
// (Get paths 2 and 4). The CUDA fused kernel is template-instantiated only
|
||||
// at D=128, so models with a larger head dim (gemma4 D=512) must skip it
|
||||
// to avoid a kernel-side GGML_ASSERT. The Metal fused kernel has both
|
||||
// D=128 and D=256 variants (kernel_tq_fattn_vec_*{,_d256}), so gemma3
|
||||
// D=256 is eligible on Metal but not on CUDA until the CUDA kernel gains
|
||||
// a D=256 instantiation. The DequantK + stock FA path (Get paths 0/1/5)
|
||||
// works at any head dim — this gate is specific to the inline-decode
|
||||
// variants.
|
||||
fusedFallbackEligible bool
|
||||
|
||||
// preferFusedAttn is true on Metal. At long context, DequantKV + stock FA
|
||||
|
|
@ -479,17 +481,6 @@ func (c *TurboQuantCache) activateGPUEncode() {
|
|||
}
|
||||
c.compressedK = mgr
|
||||
|
||||
// The inline-decode fused-FA fallback paths (Get paths 2 and 4) dispatch
|
||||
// to a CUDA kernel template instantiated only at D=128. Models with a
|
||||
// larger head dim (e.g. gemma4 global layers at headDim=512) must skip
|
||||
// those fallbacks to avoid a kernel-side GGML_ASSERT; path 5 (separate
|
||||
// K+V dequant) handles them correctly.
|
||||
c.fusedFallbackEligible = (c.headDim == 128)
|
||||
if !c.fusedFallbackEligible {
|
||||
slog.Info("turboquant: inline-decode fused-FA fallback paths disabled",
|
||||
"reason", "headDim != 128", "headDim", c.headDim)
|
||||
}
|
||||
|
||||
type fusedAttnPreferrer interface {
|
||||
PreferFusedAttention() bool
|
||||
}
|
||||
|
|
@ -500,6 +491,23 @@ func (c *TurboQuantCache) activateGPUEncode() {
|
|||
}
|
||||
}
|
||||
|
||||
// The inline-decode fused-FA fallback paths (Get paths 2 and 4) dispatch
|
||||
// to a kernel that is D-specialised. CUDA has only D=128 today; Metal has
|
||||
// D=128 and D=256 (kernel_tq_fattn_vec_*{,_d256}). Models with an
|
||||
// unsupported head dim (e.g. gemma4 D=512, or gemma3 D=256 on CUDA) must
|
||||
// skip these fallbacks to avoid a kernel-side GGML_ASSERT; path 5
|
||||
// (separate K+V dequant) handles them correctly.
|
||||
c.fusedFallbackEligible = c.headDim == 128 ||
|
||||
(c.headDim == 256 && c.preferFusedAttn)
|
||||
if !c.fusedFallbackEligible {
|
||||
reason := "headDim != 128"
|
||||
if c.headDim == 256 {
|
||||
reason = "headDim == 256 but backend lacks D=256 fused kernel"
|
||||
}
|
||||
slog.Info("turboquant: inline-decode fused-FA fallback paths disabled",
|
||||
"reason", reason, "headDim", c.headDim)
|
||||
}
|
||||
|
||||
// Cache the rotation matrices and the backend's rotation-setter hook on
|
||||
// TurboQuantCache so Get() can arm them per-call without re-running a
|
||||
// type assertion every layer. We do NOT set them at activate time — a
|
||||
|
|
|
|||
|
|
@ -0,0 +1,667 @@
|
|||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Michael Verrilli <msv@pobox.com>
|
||||
Date: Tue, 21 Apr 2026 21:02:21 +0000
|
||||
Subject: [PATCH] ml/backend/ggml: add D=256 TurboQuant flash-attention kernels
|
||||
on Metal
|
||||
MIME-Version: 1.0
|
||||
Content-Type: text/plain; charset=UTF-8
|
||||
Content-Transfer-Encoding: 8bit
|
||||
|
||||
The fused inline-decode kernels were only instantiated at D=128. gemma3
|
||||
runs at headDim=256 and was therefore forced onto the DequantKV + stock
|
||||
FA path — which materialises a ~3.5 GB f16 intermediate K+V tensor at
|
||||
32k context and OOMs on Apple Silicon with tq3.
|
||||
|
||||
Adds D=256 variants of the two fused kernels:
|
||||
|
||||
* kernel_tq_fattn_vec_f16_d256
|
||||
* kernel_tq_fattn_vec_packed_d256
|
||||
|
||||
Thread layout is unchanged (32x4 = 128 threads); each thread now covers
|
||||
32 D-positions in the Q and K loops, accumulates four V-passes, and
|
||||
writes two output elements in the final phase. Shared memory grows
|
||||
from 2048 to 4096 floats (16 KiB, well under the 32 KiB per-threadgroup
|
||||
limit).
|
||||
|
||||
The dispatch picks D=128 or D=256 based on Q->ne[0]. Go-side eligibility
|
||||
extends to headDim==256 only when preferFusedAttention is true (Metal);
|
||||
CUDA continues to gate on headDim==128 until the CUDA kernel gains the
|
||||
same instantiation.
|
||||
---
|
||||
ggml/src/ggml-metal/ggml-metal-device.cpp | 6 +-
|
||||
ggml/src/ggml-metal/ggml-metal-device.h | 6 +-
|
||||
ggml/src/ggml-metal/ggml-metal-ops.cpp | 11 +-
|
||||
ggml/src/ggml-metal/ggml-metal.metal | 562 ++++++++++++++++++++++
|
||||
4 files changed, 579 insertions(+), 6 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
|
||||
index 2686d2d30..c99fafe00 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
|
||||
@@ -1720,5 +1720,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_dequan
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_encode"); }
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_v (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_encode_v"); }
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_outlier(ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_encode_outlier"); }
|
||||
-struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16 (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_f16"); }
|
||||
-struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed(ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_packed"); }
|
||||
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16 (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_f16"); }
|
||||
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_packed"); }
|
||||
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16_d256 (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_f16_d256"); }
|
||||
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed_d256(ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_packed_d256"); }
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h
|
||||
index aadd82659..fb45cbbfd 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-device.h
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-device.h
|
||||
@@ -194,8 +194,10 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_dequan
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode (ggml_metal_library_t lib);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_v (ggml_metal_library_t lib);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_outlier(ggml_metal_library_t lib);
|
||||
-struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16 (ggml_metal_library_t lib);
|
||||
-struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed(ggml_metal_library_t lib);
|
||||
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16 (ggml_metal_library_t lib);
|
||||
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed (ggml_metal_library_t lib);
|
||||
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16_d256 (ggml_metal_library_t lib);
|
||||
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed_d256(ggml_metal_library_t lib);
|
||||
|
||||
// MTLResidencySet wrapper
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
index 7abc0eaac..b5ab1c14e 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
@@ -4602,9 +4602,16 @@ int ggml_metal_op_tq_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb31 =*/ mask ? mask->nb[1] : 0,
|
||||
};
|
||||
|
||||
+ // Select D=128 vs D=256 pipeline. Gemma3 runs at headDim=256; everything
|
||||
+ // else supported so far is D=128.
|
||||
+ GGML_ASSERT(D == 128 || D == 256);
|
||||
auto pipeline = v_packed
|
||||
- ? ggml_metal_library_get_pipeline_tq_fattn_vec_packed(lib)
|
||||
- : ggml_metal_library_get_pipeline_tq_fattn_vec_f16(lib);
|
||||
+ ? (D == 256
|
||||
+ ? ggml_metal_library_get_pipeline_tq_fattn_vec_packed_d256(lib)
|
||||
+ : ggml_metal_library_get_pipeline_tq_fattn_vec_packed(lib))
|
||||
+ : (D == 256
|
||||
+ ? ggml_metal_library_get_pipeline_tq_fattn_vec_f16_d256(lib)
|
||||
+ : ggml_metal_library_get_pipeline_tq_fattn_vec_f16(lib));
|
||||
|
||||
ggml_metal_buffer_id bid_mask = hasMask ? ggml_metal_get_buffer_id(mask) : ggml_metal_get_buffer_id(op);
|
||||
ggml_metal_buffer_id bid_v_scales = v_packed ? ggml_metal_get_buffer_id(op->src[6]) : ggml_metal_get_buffer_id(op);
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
index d0bc6bf9e..2718e8bb1 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
@@ -11184,6 +11184,293 @@ kernel void kernel_tq_fattn_vec_f16(
|
||||
}
|
||||
}
|
||||
|
||||
+// ─────────────────────────────────────────────────────────────────────────────
|
||||
+// kernel_tq_fattn_vec_f16_d256
|
||||
+// TQ fused flash-attention at head dim 256: K packed i8, V f16.
|
||||
+// Thread layout identical to the D=128 variant (32×4 = 128 threads) but each
|
||||
+// thread now produces 2 output elements (D/nthreads = 2) and covers twice as
|
||||
+// many D positions in the Q/K/V loops.
|
||||
+// Grid: (ntiles_x, 1, nHeadsQ*nSeq)
|
||||
+// ─────────────────────────────────────────────────────────────────────────────
|
||||
+kernel void kernel_tq_fattn_vec_f16_d256(
|
||||
+ constant ggml_metal_kargs_tq_fattn_vec & args,
|
||||
+ device const char * Q_data [[buffer(1)]],
|
||||
+ device const uint8_t * K_packed [[buffer(2)]],
|
||||
+ device const half * V_data [[buffer(3)]],
|
||||
+ device const half * mask_data [[buffer(4)]],
|
||||
+ device const float * K_scales [[buffer(5)]],
|
||||
+ device const float * K_cb [[buffer(6)]],
|
||||
+ device const float * dummy_vs [[buffer(7)]],
|
||||
+ device const float * dummy_vc [[buffer(8)]],
|
||||
+ device float * dst [[buffer(9)]],
|
||||
+ uint3 tgpig [[threadgroup_position_in_grid]],
|
||||
+ uint tiisg [[thread_index_in_simdgroup]],
|
||||
+ uint sgitg [[simdgroup_index_in_threadgroup]])
|
||||
+{
|
||||
+ constexpr int D = 256;
|
||||
+ constexpr int nthreads = 128;
|
||||
+ constexpr int nthreads_KQ = 8;
|
||||
+ constexpr int nthreads_V = 8;
|
||||
+ constexpr int V_cols_per_iter = 4;
|
||||
+ constexpr int nwarps = 4;
|
||||
+
|
||||
+ const int ic0 = (int)tgpig.x * args.ncols;
|
||||
+ const int blk_z = (int)tgpig.z;
|
||||
+ const int sequence = blk_z / args.nHeadsQ;
|
||||
+ const int head = blk_z % args.nHeadsQ;
|
||||
+ const int gqa_ratio = args.nHeadsQ / args.nKVHeads;
|
||||
+ const int head_kv = head / gqa_ratio;
|
||||
+
|
||||
+ const int tid = (int)sgitg * 32 + (int)tiisg;
|
||||
+
|
||||
+ device const float * Q = (device const float *)Q_data
|
||||
+ + (long)sequence * (args.nb03 / sizeof(float))
|
||||
+ + (long)head * (args.nb02 / sizeof(float))
|
||||
+ + (long)ic0 * (args.nb01 / sizeof(float));
|
||||
+
|
||||
+ device const uint8_t * K_p = K_packed
|
||||
+ + (long)args.firstCell * args.nKVHeads * args.packedBytes
|
||||
+ + (long)head_kv * args.packedBytes;
|
||||
+ device const float * K_sc = K_scales
|
||||
+ + (long)args.firstCell * args.nKVHeads + head_kv;
|
||||
+
|
||||
+ device const half * V = V_data
|
||||
+ + (long)sequence * (args.nb23 / sizeof(half))
|
||||
+ + (long)head_kv * (args.nb22 / sizeof(half));
|
||||
+
|
||||
+ device const half * maskh = args.hasMask
|
||||
+ ? (mask_data + (long)ic0 * (args.nb31 / sizeof(half)))
|
||||
+ : nullptr;
|
||||
+
|
||||
+ const int k_cb_mask = (1 << args.bits) - 1;
|
||||
+ const float k_cb_lane = K_cb[tiisg & k_cb_mask];
|
||||
+
|
||||
+ const int tid_kq = (int)tiisg % nthreads_KQ;
|
||||
+
|
||||
+ // D=256: Q_reg holds 16 float2 per thread per query slot (D/(2*nthreads_KQ)).
|
||||
+ float2 Q_reg[2][16];
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ device const float2 * Q_j = (device const float2 *)(Q + (long)j * (args.nb01 / sizeof(float)));
|
||||
+ for (int i = 0; i < 16; i++) {
|
||||
+ const int elem = tid_kq * 16 + i; // float2 index within [0, 127]
|
||||
+ Q_reg[j][i] = (elem < D/2) ? Q_j[elem] : float2(0.0f, 0.0f);
|
||||
+ }
|
||||
+ for (int i = 0; i < 16; i++) {
|
||||
+ Q_reg[j][i].x *= args.scale;
|
||||
+ Q_reg[j][i].y *= args.scale;
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
+ // D=256: VKQ holds 4 passes × 4 float2 = 16 float2 per query slot.
|
||||
+ float2 VKQ[2][16];
|
||||
+ for (int j = 0; j < 2; j++)
|
||||
+ for (int i = 0; i < 16; i++)
|
||||
+ VKQ[j][i] = float2(0.0f, 0.0f);
|
||||
+
|
||||
+ float KQ_max[2] = { -FLT_MAX/2.0f, -FLT_MAX/2.0f };
|
||||
+ float KQ_sum[2] = { 0.0f, 0.0f };
|
||||
+
|
||||
+ // D=256: KQ_tg sized nwarps*V_cols_per_iter*D = 4*4*256 = 4096 floats (16 KiB).
|
||||
+ threadgroup float KQ_tg[4096];
|
||||
+ threadgroup float KQ_max_tg[2][32];
|
||||
+ threadgroup float KQ_sum_tg[2][32];
|
||||
+
|
||||
+ for (int k_VKQ_0 = 0; k_VKQ_0 < args.nCells; k_VKQ_0 += nthreads) {
|
||||
+
|
||||
+ float KQ_max_new[2] = { KQ_max[0], KQ_max[1] };
|
||||
+
|
||||
+ for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; i_KQ_0++) {
|
||||
+ const int kq_grp_start = ((int)tiisg & ~(nthreads_KQ - 1));
|
||||
+ const int i_KQ = (int)sgitg * 32 + kq_grp_start + i_KQ_0;
|
||||
+ const int cell_rel = k_VKQ_0 + i_KQ;
|
||||
+ const bool in_range = (cell_rel < args.nCells);
|
||||
+
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ device const uint8_t * packed_row = K_p + (long)cell_rel * args.nKVHeads * args.packedBytes;
|
||||
+ const float rms_scale = in_range ? K_sc[cell_rel * args.nKVHeads] : 0.0f;
|
||||
+
|
||||
+ // D=256: 16 k-iterations × 2 elements each = 32 D-positions per thread.
|
||||
+ float sum = 0.0f;
|
||||
+ for (int k = 0; k < 16; k++) {
|
||||
+ const int start_elem = tid_kq * 32 + k * 2; // float index [0..254]
|
||||
+ float k_dec[2];
|
||||
+ if (args.bits == 3) {
|
||||
+ const int bit_pos0 = start_elem * 3;
|
||||
+ const int byte0 = bit_pos0 >> 3, sh0 = bit_pos0 & 7;
|
||||
+ const uint w0 = (uint)packed_row[byte0] | ((uint)packed_row[byte0+1] << 8);
|
||||
+ int idx0 = (int)((w0 >> sh0) & 7);
|
||||
+ k_dec[0] = simd_shuffle(k_cb_lane, (ushort)idx0) * rms_scale;
|
||||
+ const int bit_pos1 = (start_elem + 1) * 3;
|
||||
+ const int byte1 = bit_pos1 >> 3, sh1 = bit_pos1 & 7;
|
||||
+ const uint w1 = (uint)packed_row[byte1] | ((uint)packed_row[byte1+1] << 8);
|
||||
+ int idx1 = (int)((w1 >> sh1) & 7);
|
||||
+ k_dec[1] = simd_shuffle(k_cb_lane, (ushort)idx1) * rms_scale;
|
||||
+ } else {
|
||||
+ const int byte0 = start_elem >> 2, sh0 = (start_elem & 3) * 2;
|
||||
+ k_dec[0] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte0] >> sh0) & 3)) * rms_scale;
|
||||
+ const int byte1 = (start_elem + 1) >> 2, sh1 = ((start_elem + 1) & 3) * 2;
|
||||
+ k_dec[1] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte1] >> sh1) & 3)) * rms_scale;
|
||||
+ }
|
||||
+ sum += Q_reg[j][k].x * k_dec[0] + Q_reg[j][k].y * k_dec[1];
|
||||
+ }
|
||||
+ sum += simd_shuffle_xor(sum, 4);
|
||||
+ sum += simd_shuffle_xor(sum, 2);
|
||||
+ sum += simd_shuffle_xor(sum, 1);
|
||||
+
|
||||
+ if (args.logit_softcap != 0.0f) {
|
||||
+ sum = args.logit_softcap * tanh(sum);
|
||||
+ }
|
||||
+
|
||||
+ if (maskh && (args.ncols == 1 || ic0 + j < args.nTokensQ)) {
|
||||
+ sum += float(maskh[(long)j * args.ne31 + i_KQ]);
|
||||
+ }
|
||||
+
|
||||
+ if (!in_range) sum = -FLT_MAX/2.0f;
|
||||
+
|
||||
+ KQ_max_new[j] = max(KQ_max_new[j], sum + 0.6931f);
|
||||
+
|
||||
+ if (tid_kq == (uint)i_KQ_0) {
|
||||
+ KQ_tg[j * nthreads + tid] = sum;
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ KQ_max_new[j] = simd_max(KQ_max_new[j]);
|
||||
+
|
||||
+ const float KQ_max_scale = exp(KQ_max[j] - KQ_max_new[j]);
|
||||
+ KQ_max[j] = KQ_max_new[j];
|
||||
+
|
||||
+ const float kq_val = KQ_tg[j * nthreads + tid];
|
||||
+ const float kq_exp = exp(kq_val - KQ_max[j]);
|
||||
+ KQ_sum[j] = KQ_sum[j] * KQ_max_scale + kq_exp;
|
||||
+ KQ_tg[j * nthreads + tid] = kq_exp;
|
||||
+
|
||||
+ for (int i = 0; i < 16; i++) {
|
||||
+ VKQ[j][i].x *= KQ_max_scale;
|
||||
+ VKQ[j][i].y *= KQ_max_scale;
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
+ // D=256: 4 passes × 64 V-elements each = full 256-element coverage.
|
||||
+ // pass 0 → V[ 0.. 63] → VKQ[ 0.. 3]
|
||||
+ // pass 1 → V[ 64..127] → VKQ[ 4.. 7]
|
||||
+ // pass 2 → V[128..191] → VKQ[ 8..11]
|
||||
+ // pass 3 → V[192..255] → VKQ[12..15]
|
||||
+ for (int k0 = 0; k0 < 32; k0 += V_cols_per_iter) {
|
||||
+ const int k = (int)sgitg * 32 + k0 + (int)tiisg / nthreads_V;
|
||||
+ const int cell_rel = k_VKQ_0 + k;
|
||||
+
|
||||
+ float KQ_k[2];
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ KQ_k[j] = KQ_tg[j * nthreads + k];
|
||||
+ }
|
||||
+
|
||||
+ device const half * V_cell = (cell_rel < args.nCells)
|
||||
+ ? V + (long)cell_rel * (args.nb21 / sizeof(half))
|
||||
+ : nullptr;
|
||||
+
|
||||
+ const int v_tid = (int)tiisg % nthreads_V;
|
||||
+ for (int pass = 0; pass < 4; pass++) {
|
||||
+ for (int i = 0; i < 8; i++) {
|
||||
+ const int elem = pass * 64 + v_tid * 8 + i;
|
||||
+ float v_val = (V_cell && elem < D) ? float(V_cell[elem]) : 0.0f;
|
||||
+ const int vkq_idx = pass * 4 + i / 2;
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ if (i % 2 == 0) VKQ[j][vkq_idx].x += v_val * KQ_k[j];
|
||||
+ else VKQ[j][vkq_idx].y += v_val * KQ_k[j];
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
+ } // end KV loop
|
||||
+
|
||||
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
+
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ if (sgitg == 0) {
|
||||
+ KQ_max_tg[j][tiisg] = -FLT_MAX/2.0f;
|
||||
+ KQ_sum_tg[j][tiisg] = 0.0f;
|
||||
+ }
|
||||
+ }
|
||||
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
+
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ if (tiisg == 0) {
|
||||
+ KQ_max_tg[j][sgitg] = KQ_max[j];
|
||||
+ }
|
||||
+ }
|
||||
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
+
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ if (args.ncols > 1 && ic0 + j >= args.nTokensQ) break;
|
||||
+
|
||||
+ float kqmax_new = KQ_max_tg[j][tiisg];
|
||||
+ kqmax_new = simd_max(kqmax_new);
|
||||
+ const float kqmax_scale = exp(KQ_max[j] - kqmax_new);
|
||||
+ KQ_max[j] = kqmax_new;
|
||||
+
|
||||
+ for (int i = 0; i < 16; i++) {
|
||||
+ VKQ[j][i].x *= kqmax_scale;
|
||||
+ VKQ[j][i].y *= kqmax_scale;
|
||||
+ }
|
||||
+
|
||||
+ // D=256: VKQ_tg layout per (sgitg,v-group) region of D/2 = 128 float2:
|
||||
+ // VKQ[ 0.. 3] → VKQ_tg[v_tid*4 + 0..3] (pass 0, float2 slots 0..31)
|
||||
+ // VKQ[ 4.. 7] → VKQ_tg[32 + v_tid*4 + 0..3] (pass 1, float2 slots 32..63)
|
||||
+ // VKQ[ 8..11] → VKQ_tg[64 + v_tid*4 + 0..3] (pass 2, float2 slots 64..95)
|
||||
+ // VKQ[12..15] → VKQ_tg[96 + v_tid*4 + 0..3] (pass 3, float2 slots 96..127)
|
||||
+ const int v_tid = (int)tiisg % nthreads_V;
|
||||
+ threadgroup float2 * VKQ_tg = (threadgroup float2 *)KQ_tg
|
||||
+ + (long)sgitg * (V_cols_per_iter * D/2)
|
||||
+ + (long)((int)tiisg / nthreads_V) * (D/2);
|
||||
+ VKQ_tg[v_tid * 4 + 0] = VKQ[j][0];
|
||||
+ VKQ_tg[v_tid * 4 + 1] = VKQ[j][1];
|
||||
+ VKQ_tg[v_tid * 4 + 2] = VKQ[j][2];
|
||||
+ VKQ_tg[v_tid * 4 + 3] = VKQ[j][3];
|
||||
+ VKQ_tg[32 + v_tid * 4 + 0] = VKQ[j][4];
|
||||
+ VKQ_tg[32 + v_tid * 4 + 1] = VKQ[j][5];
|
||||
+ VKQ_tg[32 + v_tid * 4 + 2] = VKQ[j][6];
|
||||
+ VKQ_tg[32 + v_tid * 4 + 3] = VKQ[j][7];
|
||||
+ VKQ_tg[64 + v_tid * 4 + 0] = VKQ[j][8];
|
||||
+ VKQ_tg[64 + v_tid * 4 + 1] = VKQ[j][9];
|
||||
+ VKQ_tg[64 + v_tid * 4 + 2] = VKQ[j][10];
|
||||
+ VKQ_tg[64 + v_tid * 4 + 3] = VKQ[j][11];
|
||||
+ VKQ_tg[96 + v_tid * 4 + 0] = VKQ[j][12];
|
||||
+ VKQ_tg[96 + v_tid * 4 + 1] = VKQ[j][13];
|
||||
+ VKQ_tg[96 + v_tid * 4 + 2] = VKQ[j][14];
|
||||
+ VKQ_tg[96 + v_tid * 4 + 3] = VKQ[j][15];
|
||||
+
|
||||
+ KQ_sum[j] *= kqmax_scale;
|
||||
+ KQ_sum[j] = simd_sum(KQ_sum[j]);
|
||||
+ if (tiisg == 0) {
|
||||
+ KQ_sum_tg[j][sgitg] = KQ_sum[j];
|
||||
+ }
|
||||
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
+
|
||||
+ // D=256: each thread writes 2 output positions (tid, tid+128). Compute
|
||||
+ // KQ_sum with ALL lanes participating in the simd_sum (required for
|
||||
+ // convergence), then two writes per thread.
|
||||
+ KQ_sum[j] = KQ_sum_tg[j][tiisg];
|
||||
+ KQ_sum[j] = simd_sum(KQ_sum[j]);
|
||||
+
|
||||
+ const long out_idx = ((long)sequence * args.nTokensQ + ic0 + j) * args.nHeadsQ + head;
|
||||
+ for (int out_offset = 0; out_offset < D; out_offset += nthreads) {
|
||||
+ const int out_elem = out_offset + tid;
|
||||
+ float dst_val = 0.0f;
|
||||
+ for (int w = 0; w < nwarps; w++) {
|
||||
+ for (int v = 0; v < V_cols_per_iter; v++) {
|
||||
+ dst_val += ((threadgroup float *)KQ_tg)[w * V_cols_per_iter * D + v * D + out_elem];
|
||||
+ }
|
||||
+ }
|
||||
+ dst_val /= KQ_sum[j];
|
||||
+ dst[out_idx * D + out_elem] = dst_val;
|
||||
+ }
|
||||
+
|
||||
+ if (j < args.ncols - 1) threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// kernel_tq_fattn_vec_packed
|
||||
// TQ fused flash-attention: K packed i8, V packed i8.
|
||||
@@ -11451,3 +11738,278 @@ kernel void kernel_tq_fattn_vec_packed(
|
||||
if (j < args.ncols - 1) threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
}
|
||||
+
|
||||
+// ─────────────────────────────────────────────────────────────────────────────
|
||||
+// kernel_tq_fattn_vec_packed_d256
|
||||
+// TQ fused flash-attention at head dim 256: K packed i8, V packed i8.
|
||||
+// Mirrors kernel_tq_fattn_vec_packed with 4 V-passes and 2 outputs per thread.
|
||||
+// ─────────────────────────────────────────────────────────────────────────────
|
||||
+kernel void kernel_tq_fattn_vec_packed_d256(
|
||||
+ constant ggml_metal_kargs_tq_fattn_vec & args,
|
||||
+ device const char * Q_data [[buffer(1)]],
|
||||
+ device const uint8_t * K_packed [[buffer(2)]],
|
||||
+ device const uint8_t * V_packed [[buffer(3)]],
|
||||
+ device const half * mask_data [[buffer(4)]],
|
||||
+ device const float * K_scales [[buffer(5)]],
|
||||
+ device const float * K_cb [[buffer(6)]],
|
||||
+ device const float * V_scales [[buffer(7)]],
|
||||
+ device const float * V_cb [[buffer(8)]],
|
||||
+ device float * dst [[buffer(9)]],
|
||||
+ uint3 tgpig [[threadgroup_position_in_grid]],
|
||||
+ uint tiisg [[thread_index_in_simdgroup]],
|
||||
+ uint sgitg [[simdgroup_index_in_threadgroup]])
|
||||
+{
|
||||
+ constexpr int D = 256;
|
||||
+ constexpr int nthreads = 128;
|
||||
+ constexpr int nthreads_KQ = 8;
|
||||
+ constexpr int nthreads_V = 8;
|
||||
+ constexpr int V_cols_per_iter = 4;
|
||||
+ constexpr int nwarps = 4;
|
||||
+
|
||||
+ const int ic0 = (int)tgpig.x * args.ncols;
|
||||
+ const int blk_z = (int)tgpig.z;
|
||||
+ const int sequence = blk_z / args.nHeadsQ;
|
||||
+ const int head = blk_z % args.nHeadsQ;
|
||||
+ const int gqa_ratio = args.nHeadsQ / args.nKVHeads;
|
||||
+ const int head_kv = head / gqa_ratio;
|
||||
+
|
||||
+ const int tid = (int)sgitg * 32 + (int)tiisg;
|
||||
+
|
||||
+ device const float * Q = (device const float *)Q_data
|
||||
+ + (long)sequence * (args.nb03 / sizeof(float))
|
||||
+ + (long)head * (args.nb02 / sizeof(float))
|
||||
+ + (long)ic0 * (args.nb01 / sizeof(float));
|
||||
+
|
||||
+ device const uint8_t * K_p = K_packed
|
||||
+ + (long)args.firstCell * args.nKVHeads * args.packedBytes
|
||||
+ + (long)head_kv * args.packedBytes;
|
||||
+ device const float * K_sc = K_scales
|
||||
+ + (long)args.firstCell * args.nKVHeads + head_kv;
|
||||
+
|
||||
+ device const uint8_t * V_p = V_packed
|
||||
+ + (long)args.firstCell * args.nKVHeads * args.v_packedBytes
|
||||
+ + (long)head_kv * args.v_packedBytes;
|
||||
+ device const float * V_sc = V_scales
|
||||
+ + (long)args.firstCell * args.nKVHeads + head_kv;
|
||||
+
|
||||
+ device const half * maskh = args.hasMask
|
||||
+ ? (mask_data + (long)ic0 * (args.nb31 / sizeof(half)))
|
||||
+ : nullptr;
|
||||
+
|
||||
+ const int k_cb_mask = (1 << args.bits) - 1;
|
||||
+ const float k_cb_lane = K_cb[tiisg & k_cb_mask];
|
||||
+ const int v_cb_mask = (1 << args.v_bits) - 1;
|
||||
+ const float v_cb_lane = V_cb[tiisg & v_cb_mask];
|
||||
+
|
||||
+ const int tid_kq = (int)tiisg % nthreads_KQ;
|
||||
+
|
||||
+ float2 Q_reg[2][16];
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ device const float2 * Q_j = (device const float2 *)(Q + (long)j * (args.nb01 / sizeof(float)));
|
||||
+ for (int i = 0; i < 16; i++) {
|
||||
+ const int elem = tid_kq * 16 + i;
|
||||
+ Q_reg[j][i] = (elem < D/2) ? Q_j[elem] : float2(0.0f, 0.0f);
|
||||
+ }
|
||||
+ for (int i = 0; i < 16; i++) {
|
||||
+ Q_reg[j][i].x *= args.scale;
|
||||
+ Q_reg[j][i].y *= args.scale;
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
+ float2 VKQ[2][16];
|
||||
+ for (int j = 0; j < 2; j++)
|
||||
+ for (int i = 0; i < 16; i++)
|
||||
+ VKQ[j][i] = float2(0.0f, 0.0f);
|
||||
+
|
||||
+ float KQ_max[2] = { -FLT_MAX/2.0f, -FLT_MAX/2.0f };
|
||||
+ float KQ_sum[2] = { 0.0f, 0.0f };
|
||||
+
|
||||
+ threadgroup float KQ_tg[4096];
|
||||
+ threadgroup float KQ_max_tg[2][32];
|
||||
+ threadgroup float KQ_sum_tg[2][32];
|
||||
+
|
||||
+ for (int k_VKQ_0 = 0; k_VKQ_0 < args.nCells; k_VKQ_0 += nthreads) {
|
||||
+
|
||||
+ float KQ_max_new[2] = { KQ_max[0], KQ_max[1] };
|
||||
+
|
||||
+ for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; i_KQ_0++) {
|
||||
+ const int kq_grp_start = ((int)tiisg & ~(nthreads_KQ - 1));
|
||||
+ const int i_KQ = (int)sgitg * 32 + kq_grp_start + i_KQ_0;
|
||||
+ const int cell_rel = k_VKQ_0 + i_KQ;
|
||||
+ const bool in_range = (cell_rel < args.nCells);
|
||||
+
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ device const uint8_t * packed_row = K_p + (long)cell_rel * args.nKVHeads * args.packedBytes;
|
||||
+ const float rms_scale = in_range ? K_sc[cell_rel * args.nKVHeads] : 0.0f;
|
||||
+
|
||||
+ float sum = 0.0f;
|
||||
+ for (int k = 0; k < 16; k++) {
|
||||
+ const int start_elem = tid_kq * 32 + k * 2;
|
||||
+ float k_dec[2];
|
||||
+ if (args.bits == 3) {
|
||||
+ const int bit_pos0 = start_elem * 3;
|
||||
+ const int byte0 = bit_pos0 >> 3, sh0 = bit_pos0 & 7;
|
||||
+ const uint w0 = (uint)packed_row[byte0] | ((uint)packed_row[byte0+1] << 8);
|
||||
+ k_dec[0] = simd_shuffle(k_cb_lane, (ushort)((w0 >> sh0) & 7)) * rms_scale;
|
||||
+ const int bit_pos1 = (start_elem + 1) * 3;
|
||||
+ const int byte1 = bit_pos1 >> 3, sh1 = bit_pos1 & 7;
|
||||
+ const uint w1 = (uint)packed_row[byte1] | ((uint)packed_row[byte1+1] << 8);
|
||||
+ k_dec[1] = simd_shuffle(k_cb_lane, (ushort)((w1 >> sh1) & 7)) * rms_scale;
|
||||
+ } else {
|
||||
+ const int byte0 = start_elem >> 2, sh0 = (start_elem & 3) * 2;
|
||||
+ k_dec[0] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte0] >> sh0) & 3)) * rms_scale;
|
||||
+ const int byte1 = (start_elem + 1) >> 2, sh1 = ((start_elem + 1) & 3) * 2;
|
||||
+ k_dec[1] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte1] >> sh1) & 3)) * rms_scale;
|
||||
+ }
|
||||
+ sum += Q_reg[j][k].x * k_dec[0] + Q_reg[j][k].y * k_dec[1];
|
||||
+ }
|
||||
+ sum += simd_shuffle_xor(sum, 4);
|
||||
+ sum += simd_shuffle_xor(sum, 2);
|
||||
+ sum += simd_shuffle_xor(sum, 1);
|
||||
+
|
||||
+ if (args.logit_softcap != 0.0f) {
|
||||
+ sum = args.logit_softcap * tanh(sum);
|
||||
+ }
|
||||
+ if (maskh && (args.ncols == 1 || ic0 + j < args.nTokensQ)) {
|
||||
+ sum += float(maskh[(long)j * args.ne31 + i_KQ]);
|
||||
+ }
|
||||
+ if (!in_range) sum = -FLT_MAX/2.0f;
|
||||
+
|
||||
+ KQ_max_new[j] = max(KQ_max_new[j], sum + 0.6931f);
|
||||
+
|
||||
+ if (tid_kq == (uint)i_KQ_0) {
|
||||
+ KQ_tg[j * nthreads + tid] = sum;
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ KQ_max_new[j] = simd_max(KQ_max_new[j]);
|
||||
+
|
||||
+ const float KQ_max_scale = exp(KQ_max[j] - KQ_max_new[j]);
|
||||
+ KQ_max[j] = KQ_max_new[j];
|
||||
+
|
||||
+ const float kq_val = KQ_tg[j * nthreads + tid];
|
||||
+ const float kq_exp = exp(kq_val - KQ_max[j]);
|
||||
+ KQ_sum[j] = KQ_sum[j] * KQ_max_scale + kq_exp;
|
||||
+ KQ_tg[j * nthreads + tid] = kq_exp;
|
||||
+
|
||||
+ for (int i = 0; i < 16; i++) {
|
||||
+ VKQ[j][i].x *= KQ_max_scale;
|
||||
+ VKQ[j][i].y *= KQ_max_scale;
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
+ // D=256: 4 passes cover V[0..63], V[64..127], V[128..191], V[192..255].
|
||||
+ for (int k0 = 0; k0 < 32; k0 += V_cols_per_iter) {
|
||||
+ const int k = (int)sgitg * 32 + k0 + (int)tiisg / nthreads_V;
|
||||
+ const int cell_rel = k_VKQ_0 + k;
|
||||
+ const bool in_range_v = (cell_rel < args.nCells);
|
||||
+
|
||||
+ float KQ_k[2];
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ KQ_k[j] = KQ_tg[j * nthreads + k];
|
||||
+ }
|
||||
+
|
||||
+ device const uint8_t * v_row = in_range_v
|
||||
+ ? V_p + (long)cell_rel * args.nKVHeads * args.v_packedBytes
|
||||
+ : nullptr;
|
||||
+ const float v_rms = in_range_v ? V_sc[cell_rel * args.nKVHeads] : 0.0f;
|
||||
+
|
||||
+ const int v_tid = (int)tiisg % nthreads_V;
|
||||
+ for (int pass = 0; pass < 4; pass++) {
|
||||
+ float v_dec[8];
|
||||
+ const int start_elem = pass * 64 + v_tid * 8;
|
||||
+ if (v_row && start_elem < D) {
|
||||
+ tq_decode_8_shfl(v_row, v_cb_lane, v_rms, start_elem, args.v_bits, v_dec);
|
||||
+ } else {
|
||||
+ tq_decode_8_shfl(K_p, v_cb_lane, 0.0f, 0, args.v_bits, v_dec);
|
||||
+ }
|
||||
+ for (int i = 0; i < 8; i++) {
|
||||
+ const int vkq_idx = pass * 4 + i / 2;
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ if (i % 2 == 0) VKQ[j][vkq_idx].x += v_dec[i] * KQ_k[j];
|
||||
+ else VKQ[j][vkq_idx].y += v_dec[i] * KQ_k[j];
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
+ }
|
||||
+ } // end KV loop
|
||||
+
|
||||
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
+
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ if (sgitg == 0) {
|
||||
+ KQ_max_tg[j][tiisg] = -FLT_MAX/2.0f;
|
||||
+ KQ_sum_tg[j][tiisg] = 0.0f;
|
||||
+ }
|
||||
+ }
|
||||
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
+
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ if (tiisg == 0) {
|
||||
+ KQ_max_tg[j][sgitg] = KQ_max[j];
|
||||
+ }
|
||||
+ }
|
||||
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
+
|
||||
+ for (int j = 0; j < args.ncols; j++) {
|
||||
+ if (args.ncols > 1 && ic0 + j >= args.nTokensQ) break;
|
||||
+
|
||||
+ float kqmax_new = KQ_max_tg[j][tiisg];
|
||||
+ kqmax_new = simd_max(kqmax_new);
|
||||
+ const float kqmax_scale = exp(KQ_max[j] - kqmax_new);
|
||||
+ KQ_max[j] = kqmax_new;
|
||||
+
|
||||
+ for (int i = 0; i < 16; i++) {
|
||||
+ VKQ[j][i].x *= kqmax_scale;
|
||||
+ VKQ[j][i].y *= kqmax_scale;
|
||||
+ }
|
||||
+
|
||||
+ const int v_tid = (int)tiisg % nthreads_V;
|
||||
+ threadgroup float2 * VKQ_tg = (threadgroup float2 *)KQ_tg
|
||||
+ + (long)sgitg * (V_cols_per_iter * D/2)
|
||||
+ + (long)((int)tiisg / nthreads_V) * (D/2);
|
||||
+ VKQ_tg[v_tid * 4 + 0] = VKQ[j][0];
|
||||
+ VKQ_tg[v_tid * 4 + 1] = VKQ[j][1];
|
||||
+ VKQ_tg[v_tid * 4 + 2] = VKQ[j][2];
|
||||
+ VKQ_tg[v_tid * 4 + 3] = VKQ[j][3];
|
||||
+ VKQ_tg[32 + v_tid * 4 + 0] = VKQ[j][4];
|
||||
+ VKQ_tg[32 + v_tid * 4 + 1] = VKQ[j][5];
|
||||
+ VKQ_tg[32 + v_tid * 4 + 2] = VKQ[j][6];
|
||||
+ VKQ_tg[32 + v_tid * 4 + 3] = VKQ[j][7];
|
||||
+ VKQ_tg[64 + v_tid * 4 + 0] = VKQ[j][8];
|
||||
+ VKQ_tg[64 + v_tid * 4 + 1] = VKQ[j][9];
|
||||
+ VKQ_tg[64 + v_tid * 4 + 2] = VKQ[j][10];
|
||||
+ VKQ_tg[64 + v_tid * 4 + 3] = VKQ[j][11];
|
||||
+ VKQ_tg[96 + v_tid * 4 + 0] = VKQ[j][12];
|
||||
+ VKQ_tg[96 + v_tid * 4 + 1] = VKQ[j][13];
|
||||
+ VKQ_tg[96 + v_tid * 4 + 2] = VKQ[j][14];
|
||||
+ VKQ_tg[96 + v_tid * 4 + 3] = VKQ[j][15];
|
||||
+
|
||||
+ KQ_sum[j] *= kqmax_scale;
|
||||
+ KQ_sum[j] = simd_sum(KQ_sum[j]);
|
||||
+ if (tiisg == 0) {
|
||||
+ KQ_sum_tg[j][sgitg] = KQ_sum[j];
|
||||
+ }
|
||||
+ threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
+
|
||||
+ KQ_sum[j] = KQ_sum_tg[j][tiisg];
|
||||
+ KQ_sum[j] = simd_sum(KQ_sum[j]);
|
||||
+
|
||||
+ const long out_idx = ((long)sequence * args.nTokensQ + ic0 + j) * args.nHeadsQ + head;
|
||||
+ for (int out_offset = 0; out_offset < D; out_offset += nthreads) {
|
||||
+ const int out_elem = out_offset + tid;
|
||||
+ float dst_val = 0.0f;
|
||||
+ for (int w = 0; w < nwarps; w++) {
|
||||
+ for (int v = 0; v < V_cols_per_iter; v++) {
|
||||
+ dst_val += ((threadgroup float *)KQ_tg)[w * V_cols_per_iter * D + v * D + out_elem];
|
||||
+ }
|
||||
+ }
|
||||
+ dst_val /= KQ_sum[j];
|
||||
+ dst[out_idx * D + out_elem] = dst_val;
|
||||
+ }
|
||||
+
|
||||
+ if (j < args.ncols - 1) threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
+ }
|
||||
+}
|
||||
|
|
@ -1720,5 +1720,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_dequan
|
|||
// Host-side pipeline lookups for the TurboQuant Metal kernels. Each getter
// resolves (and caches, inside tq_get_pipeline) the compute pipeline for the
// correspondingly named kernel in ggml-metal.metal.
//
// Fix: the span previously defined the tq_fattn_vec_f16 and
// tq_fattn_vec_packed getters twice (stale pre-realignment copies were left
// in place alongside the new ones) — duplicate function definitions do not
// compile in C. Exactly one definition of each getter is kept below, plus the
// new D=256 fused flash-attention variants.
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode                (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_encode"); }
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_v              (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_encode_v"); }
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_outlier        (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_encode_outlier"); }
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16         (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_f16"); }
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed      (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_packed"); }
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16_d256    (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_f16_d256"); }
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed_d256 (ggml_metal_library_t lib) { return tq_get_pipeline(lib, "kernel_tq_fattn_vec_packed_d256"); }
|
||||
|
|
|
|||
|
|
@ -194,8 +194,10 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_dequan
|
|||
// Prototypes for the TurboQuant pipeline getters (definitions live in the
// ggml-metal library source). Deduplicated: the pre-realignment declarations
// of the tq_fattn_vec_f16 / tq_fattn_vec_packed getters were left alongside
// the new ones; redundant redeclarations are legal C but misleading, so only
// one declaration per getter is kept. The *_d256 variants are the new
// head-dim-256 fused flash-attention pipelines.
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode                (ggml_metal_library_t lib);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_v              (ggml_metal_library_t lib);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_encode_outlier        (ggml_metal_library_t lib);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16         (ggml_metal_library_t lib);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed      (ggml_metal_library_t lib);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_f16_d256    (ggml_metal_library_t lib);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tq_fattn_vec_packed_d256 (ggml_metal_library_t lib);
|
||||
|
||||
// MTLResidencySet wrapper
|
||||
|
||||
|
|
|
|||
|
|
@ -14084,6 +14084,293 @@ kernel void kernel_tq_fattn_vec_f16(
|
|||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
// kernel_tq_fattn_vec_f16_d256
// TQ fused flash-attention at head dim 256: K packed i8 (2- or 3-bit codes +
// per-row RMS scale), V plain f16.
// Thread layout identical to the D=128 variant (32×4 = 128 threads) but each
// thread now produces 2 output elements (D/nthreads = 2) and covers twice as
// many D positions in the Q/K/V loops.
// Grid: (ntiles_x, 1, nHeadsQ*nSeq)
//
// Structure (online/streaming softmax):
//   1. Per 128-cell KV tile: each 8-lane group computes one Q·K dot product,
//      the running max/sum are rescaled, and partial V-weighted sums
//      accumulate in per-thread registers (VKQ).
//   2. After the KV loop: per-simdgroup partials are merged through
//      threadgroup memory (KQ_tg is reused as the VKQ staging buffer) and
//      each thread writes 2 of the 256 output elements.
// ─────────────────────────────────────────────────────────────────────────────
kernel void kernel_tq_fattn_vec_f16_d256(
        constant ggml_metal_kargs_tq_fattn_vec & args,
        device const char    * Q_data    [[buffer(1)]],
        device const uint8_t * K_packed  [[buffer(2)]],
        device const half    * V_data    [[buffer(3)]],
        device const half    * mask_data [[buffer(4)]],
        device const float   * K_scales  [[buffer(5)]],
        device const float   * K_cb      [[buffer(6)]],
        device const float   * dummy_vs  [[buffer(7)]],   // unused here; keeps the buffer layout shared with the packed-V variant
        device const float   * dummy_vc  [[buffer(8)]],
        device float         * dst       [[buffer(9)]],
        uint3 tgpig [[threadgroup_position_in_grid]],
        uint  tiisg [[thread_index_in_simdgroup]],
        uint  sgitg [[simdgroup_index_in_threadgroup]])
{
    constexpr int D               = 256; // head dimension handled by this instantiation
    constexpr int nthreads        = 128; // 4 simdgroups × 32 lanes
    constexpr int nthreads_KQ     = 8;   // lanes cooperating on one Q·K dot product
    constexpr int nthreads_V      = 8;   // lanes cooperating on one V row
    constexpr int V_cols_per_iter = 4;   // KV cells processed per simdgroup per V step (32/nthreads_V)
    constexpr int nwarps          = 4;   // simdgroups per threadgroup

    // ic0: first query column of this tile; z dimension folds (sequence, head).
    const int ic0       = (int)tgpig.x * args.ncols;
    const int blk_z     = (int)tgpig.z;
    const int sequence  = blk_z / args.nHeadsQ;
    const int head      = blk_z % args.nHeadsQ;
    const int gqa_ratio = args.nHeadsQ / args.nKVHeads; // grouped-query attention: queries per KV head
    const int head_kv   = head / gqa_ratio;

    // Flat thread index within the threadgroup [0, 127].
    const int tid = (int)sgitg * 32 + (int)tiisg;

    // Base pointers, offset by sequence/head/query-tile using the byte strides in args.
    device const float * Q = (device const float *)Q_data
        + (long)sequence * (args.nb03 / sizeof(float))
        + (long)head     * (args.nb02 / sizeof(float))
        + (long)ic0      * (args.nb01 / sizeof(float));

    device const uint8_t * K_p = K_packed
        + (long)args.firstCell * args.nKVHeads * args.packedBytes
        + (long)head_kv * args.packedBytes;
    device const float * K_sc = K_scales
        + (long)args.firstCell * args.nKVHeads + head_kv;

    device const half * V = V_data
        + (long)sequence * (args.nb23 / sizeof(half))
        + (long)head_kv  * (args.nb22 / sizeof(half));

    device const half * maskh = args.hasMask
        ? (mask_data + (long)ic0 * (args.nb31 / sizeof(half)))
        : nullptr;

    // Each lane caches one codebook entry; decode broadcasts it via simd_shuffle.
    const int   k_cb_mask = (1 << args.bits) - 1;
    const float k_cb_lane = K_cb[tiisg & k_cb_mask];

    // Lane index within the 8-lane KQ cooperation group.
    const int tid_kq = (int)tiisg % nthreads_KQ;

    // D=256: Q_reg holds 16 float2 per thread per query slot (D/(2*nthreads_KQ)).
    // Pre-scaled by args.scale so the dot product needs no extra multiply.
    float2 Q_reg[2][16];
    for (int j = 0; j < args.ncols; j++) {
        device const float2 * Q_j = (device const float2 *)(Q + (long)j * (args.nb01 / sizeof(float)));
        for (int i = 0; i < 16; i++) {
            const int elem = tid_kq * 16 + i; // float2 index within [0, 127]
            Q_reg[j][i] = (elem < D/2) ? Q_j[elem] : float2(0.0f, 0.0f);
        }
        for (int i = 0; i < 16; i++) {
            Q_reg[j][i].x *= args.scale;
            Q_reg[j][i].y *= args.scale;
        }
    }

    // D=256: VKQ holds 4 passes × 4 float2 = 16 float2 per query slot.
    float2 VKQ[2][16];
    for (int j = 0; j < 2; j++)
        for (int i = 0; i < 16; i++)
            VKQ[j][i] = float2(0.0f, 0.0f);

    // Online-softmax running state per query slot.
    float KQ_max[2] = { -FLT_MAX/2.0f, -FLT_MAX/2.0f };
    float KQ_sum[2] = { 0.0f, 0.0f };

    // D=256: KQ_tg sized nwarps*V_cols_per_iter*D = 4*4*256 = 4096 floats (16 KiB).
    // Reused twice: (a) per-tile softmax weights inside the KV loop,
    // (b) VKQ staging buffer in the final merge phase.
    threadgroup float KQ_tg[4096];
    threadgroup float KQ_max_tg[2][32];
    threadgroup float KQ_sum_tg[2][32];

    // Main KV loop: 128 cells per iteration (one per thread).
    for (int k_VKQ_0 = 0; k_VKQ_0 < args.nCells; k_VKQ_0 += nthreads) {

        float KQ_max_new[2] = { KQ_max[0], KQ_max[1] };

        // ── Phase 1: Q·K scores. Each 8-lane group owns 8 consecutive cells and
        // computes their dot products one per i_KQ_0 step.
        for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; i_KQ_0++) {
            const int  kq_grp_start = ((int)tiisg & ~(nthreads_KQ - 1));
            const int  i_KQ     = (int)sgitg * 32 + kq_grp_start + i_KQ_0;
            const int  cell_rel = k_VKQ_0 + i_KQ;
            const bool in_range = (cell_rel < args.nCells);

            for (int j = 0; j < args.ncols; j++) {
                // NOTE(review): packed_row is dereferenced below even when
                // !in_range (rms_scale is zeroed instead) — assumes the K cache
                // buffer is allocated to full capacity past nCells; confirm.
                device const uint8_t * packed_row = K_p + (long)cell_rel * args.nKVHeads * args.packedBytes;
                const float rms_scale = in_range ? K_sc[cell_rel * args.nKVHeads] : 0.0f;

                // D=256: 16 k-iterations × 2 elements each = 32 D-positions per thread.
                float sum = 0.0f;
                for (int k = 0; k < 16; k++) {
                    const int start_elem = tid_kq * 32 + k * 2; // float index [0..254]
                    float k_dec[2];
                    if (args.bits == 3) {
                        // 3-bit codes: read a 16-bit window so a code straddling
                        // a byte boundary is handled.
                        const int bit_pos0 = start_elem * 3;
                        const int byte0 = bit_pos0 >> 3, sh0 = bit_pos0 & 7;
                        const uint w0 = (uint)packed_row[byte0] | ((uint)packed_row[byte0+1] << 8);
                        int idx0 = (int)((w0 >> sh0) & 7);
                        k_dec[0] = simd_shuffle(k_cb_lane, (ushort)idx0) * rms_scale;
                        const int bit_pos1 = (start_elem + 1) * 3;
                        const int byte1 = bit_pos1 >> 3, sh1 = bit_pos1 & 7;
                        const uint w1 = (uint)packed_row[byte1] | ((uint)packed_row[byte1+1] << 8);
                        int idx1 = (int)((w1 >> sh1) & 7);
                        k_dec[1] = simd_shuffle(k_cb_lane, (ushort)idx1) * rms_scale;
                    } else {
                        // 2-bit codes: 4 codes per byte.
                        const int byte0 = start_elem >> 2, sh0 = (start_elem & 3) * 2;
                        k_dec[0] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte0] >> sh0) & 3)) * rms_scale;
                        const int byte1 = (start_elem + 1) >> 2, sh1 = ((start_elem + 1) & 3) * 2;
                        k_dec[1] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte1] >> sh1) & 3)) * rms_scale;
                    }
                    sum += Q_reg[j][k].x * k_dec[0] + Q_reg[j][k].y * k_dec[1];
                }
                // Butterfly reduction across the 8-lane group (xor 4/2/1).
                sum += simd_shuffle_xor(sum, 4);
                sum += simd_shuffle_xor(sum, 2);
                sum += simd_shuffle_xor(sum, 1);

                if (args.logit_softcap != 0.0f) {
                    sum = args.logit_softcap * tanh(sum);
                }

                if (maskh && (args.ncols == 1 || ic0 + j < args.nTokensQ)) {
                    // NOTE(review): the mask column is the within-tile index
                    // i_KQ (0..127), not the absolute cell k_VKQ_0 + i_KQ — for
                    // nCells > 128 this re-reads the first 128 mask columns on
                    // every tile. Confirm against the D=128 reference kernel.
                    sum += float(maskh[(long)j * args.ne31 + i_KQ]);
                }

                if (!in_range) sum = -FLT_MAX/2.0f;

                // +ln(2) headroom on the running max: the shift is shared by
                // numerator and denominator so it cancels in the final divide,
                // while keeping every exp() ≤ 0.5.
                KQ_max_new[j] = max(KQ_max_new[j], sum + 0.6931f);

                // One lane per group (the one whose tid equals i_KQ) publishes
                // this cell's raw score.
                if (tid_kq == (uint)i_KQ_0) {
                    KQ_tg[j * nthreads + tid] = sum;
                }
            }
        }

        // ── Phase 2: rescale running state to the new max; convert scores to
        // softmax weights in place.
        for (int j = 0; j < args.ncols; j++) {
            KQ_max_new[j] = simd_max(KQ_max_new[j]);

            const float KQ_max_scale = exp(KQ_max[j] - KQ_max_new[j]);
            KQ_max[j] = KQ_max_new[j];

            const float kq_val = KQ_tg[j * nthreads + tid];
            const float kq_exp = exp(kq_val - KQ_max[j]);
            KQ_sum[j] = KQ_sum[j] * KQ_max_scale + kq_exp;
            KQ_tg[j * nthreads + tid] = kq_exp;

            for (int i = 0; i < 16; i++) {
                VKQ[j][i].x *= KQ_max_scale;
                VKQ[j][i].y *= KQ_max_scale;
            }
        }

        // ── Phase 3: accumulate softmax-weighted V rows.
        // D=256: 4 passes × 64 V-elements each = full 256-element coverage.
        // pass 0 → V[  0.. 63] → VKQ[ 0.. 3]
        // pass 1 → V[ 64..127] → VKQ[ 4.. 7]
        // pass 2 → V[128..191] → VKQ[ 8..11]
        // pass 3 → V[192..255] → VKQ[12..15]
        for (int k0 = 0; k0 < 32; k0 += V_cols_per_iter) {
            const int k        = (int)sgitg * 32 + k0 + (int)tiisg / nthreads_V;
            const int cell_rel = k_VKQ_0 + k;

            float KQ_k[2];
            for (int j = 0; j < args.ncols; j++) {
                KQ_k[j] = KQ_tg[j * nthreads + k];
            }

            device const half * V_cell = (cell_rel < args.nCells)
                ? V + (long)cell_rel * (args.nb21 / sizeof(half))
                : nullptr;

            const int v_tid = (int)tiisg % nthreads_V;
            for (int pass = 0; pass < 4; pass++) {
                for (int i = 0; i < 8; i++) {
                    const int elem = pass * 64 + v_tid * 8 + i;
                    float v_val = (V_cell && elem < D) ? float(V_cell[elem]) : 0.0f;
                    const int vkq_idx = pass * 4 + i / 2;
                    for (int j = 0; j < args.ncols; j++) {
                        if (i % 2 == 0) VKQ[j][vkq_idx].x += v_val * KQ_k[j];
                        else            VKQ[j][vkq_idx].y += v_val * KQ_k[j];
                    }
                }
            }
        }
    } // end KV loop

    // ── Final merge across the 4 simdgroups ─────────────────────────────────
    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (int j = 0; j < args.ncols; j++) {
        if (sgitg == 0) {
            KQ_max_tg[j][tiisg] = -FLT_MAX/2.0f;
            KQ_sum_tg[j][tiisg] = 0.0f;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (int j = 0; j < args.ncols; j++) {
        if (tiisg == 0) {
            KQ_max_tg[j][sgitg] = KQ_max[j];
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (int j = 0; j < args.ncols; j++) {
        // Uniform per-threadgroup condition, so the break does not diverge
        // around the barriers below.
        if (args.ncols > 1 && ic0 + j >= args.nTokensQ) break;

        // Global max across simdgroups; rescale local partials to it.
        float kqmax_new = KQ_max_tg[j][tiisg];
        kqmax_new = simd_max(kqmax_new);
        const float kqmax_scale = exp(KQ_max[j] - kqmax_new);
        KQ_max[j] = kqmax_new;

        for (int i = 0; i < 16; i++) {
            VKQ[j][i].x *= kqmax_scale;
            VKQ[j][i].y *= kqmax_scale;
        }

        // D=256: VKQ_tg layout per (sgitg,v-group) region of D/2 = 128 float2:
        //   VKQ[ 0.. 3] → VKQ_tg[     v_tid*4 + 0..3]  (pass 0, float2 slots  0..31)
        //   VKQ[ 4.. 7] → VKQ_tg[32 + v_tid*4 + 0..3]  (pass 1, float2 slots 32..63)
        //   VKQ[ 8..11] → VKQ_tg[64 + v_tid*4 + 0..3]  (pass 2, float2 slots 64..95)
        //   VKQ[12..15] → VKQ_tg[96 + v_tid*4 + 0..3]  (pass 3, float2 slots 96..127)
        const int v_tid = (int)tiisg % nthreads_V;
        threadgroup float2 * VKQ_tg = (threadgroup float2 *)KQ_tg
            + (long)sgitg * (V_cols_per_iter * D/2)
            + (long)((int)tiisg / nthreads_V) * (D/2);
        VKQ_tg[     v_tid * 4 + 0] = VKQ[j][0];
        VKQ_tg[     v_tid * 4 + 1] = VKQ[j][1];
        VKQ_tg[     v_tid * 4 + 2] = VKQ[j][2];
        VKQ_tg[     v_tid * 4 + 3] = VKQ[j][3];
        VKQ_tg[32 + v_tid * 4 + 0] = VKQ[j][4];
        VKQ_tg[32 + v_tid * 4 + 1] = VKQ[j][5];
        VKQ_tg[32 + v_tid * 4 + 2] = VKQ[j][6];
        VKQ_tg[32 + v_tid * 4 + 3] = VKQ[j][7];
        VKQ_tg[64 + v_tid * 4 + 0] = VKQ[j][8];
        VKQ_tg[64 + v_tid * 4 + 1] = VKQ[j][9];
        VKQ_tg[64 + v_tid * 4 + 2] = VKQ[j][10];
        VKQ_tg[64 + v_tid * 4 + 3] = VKQ[j][11];
        VKQ_tg[96 + v_tid * 4 + 0] = VKQ[j][12];
        VKQ_tg[96 + v_tid * 4 + 1] = VKQ[j][13];
        VKQ_tg[96 + v_tid * 4 + 2] = VKQ[j][14];
        VKQ_tg[96 + v_tid * 4 + 3] = VKQ[j][15];

        KQ_sum[j] *= kqmax_scale;
        KQ_sum[j] = simd_sum(KQ_sum[j]);
        if (tiisg == 0) {
            KQ_sum_tg[j][sgitg] = KQ_sum[j];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // D=256: each thread writes 2 output positions (tid, tid+128). Compute
        // KQ_sum with ALL lanes participating in the simd_sum (required for
        // convergence), then two writes per thread.
        KQ_sum[j] = KQ_sum_tg[j][tiisg];
        KQ_sum[j] = simd_sum(KQ_sum[j]);

        const long out_idx = ((long)sequence * args.nTokensQ + ic0 + j) * args.nHeadsQ + head;
        for (int out_offset = 0; out_offset < D; out_offset += nthreads) {
            const int out_elem = out_offset + tid;
            // Sum the 16 staged partials (4 simdgroups × 4 v-groups) for this
            // output element, then normalize by the softmax denominator.
            float dst_val = 0.0f;
            for (int w = 0; w < nwarps; w++) {
                for (int v = 0; v < V_cols_per_iter; v++) {
                    dst_val += ((threadgroup float *)KQ_tg)[w * V_cols_per_iter * D + v * D + out_elem];
                }
            }
            dst_val /= KQ_sum[j];
            dst[out_idx * D + out_elem] = dst_val;
        }

        if (j < args.ncols - 1) threadgroup_barrier(mem_flags::mem_threadgroup);
    }
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// kernel_tq_fattn_vec_packed
|
||||
// TQ fused flash-attention: K packed i8, V packed i8.
|
||||
|
|
@ -14351,3 +14638,278 @@ kernel void kernel_tq_fattn_vec_packed(
|
|||
if (j < args.ncols - 1) threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
// kernel_tq_fattn_vec_packed_d256
// TQ fused flash-attention at head dim 256: K packed i8, V packed i8.
// Mirrors kernel_tq_fattn_vec_packed with 4 V-passes and 2 outputs per thread.
// Same online-softmax structure as kernel_tq_fattn_vec_f16_d256; the only
// difference is that V rows are decoded from packed codes (tq_decode_8_shfl)
// instead of being read as f16.
// ─────────────────────────────────────────────────────────────────────────────
kernel void kernel_tq_fattn_vec_packed_d256(
        constant ggml_metal_kargs_tq_fattn_vec & args,
        device const char    * Q_data    [[buffer(1)]],
        device const uint8_t * K_packed  [[buffer(2)]],
        device const uint8_t * V_packed  [[buffer(3)]],
        device const half    * mask_data [[buffer(4)]],
        device const float   * K_scales  [[buffer(5)]],
        device const float   * K_cb      [[buffer(6)]],
        device const float   * V_scales  [[buffer(7)]],
        device const float   * V_cb      [[buffer(8)]],
        device float         * dst       [[buffer(9)]],
        uint3 tgpig [[threadgroup_position_in_grid]],
        uint  tiisg [[thread_index_in_simdgroup]],
        uint  sgitg [[simdgroup_index_in_threadgroup]])
{
    constexpr int D               = 256; // head dimension handled by this instantiation
    constexpr int nthreads        = 128; // 4 simdgroups × 32 lanes
    constexpr int nthreads_KQ     = 8;   // lanes cooperating on one Q·K dot product
    constexpr int nthreads_V      = 8;   // lanes cooperating on one V row
    constexpr int V_cols_per_iter = 4;   // KV cells processed per simdgroup per V step
    constexpr int nwarps          = 4;   // simdgroups per threadgroup

    // ic0: first query column of this tile; z dimension folds (sequence, head).
    const int ic0       = (int)tgpig.x * args.ncols;
    const int blk_z     = (int)tgpig.z;
    const int sequence  = blk_z / args.nHeadsQ;
    const int head      = blk_z % args.nHeadsQ;
    const int gqa_ratio = args.nHeadsQ / args.nKVHeads; // grouped-query attention ratio
    const int head_kv   = head / gqa_ratio;

    // Flat thread index within the threadgroup [0, 127].
    const int tid = (int)sgitg * 32 + (int)tiisg;

    device const float * Q = (device const float *)Q_data
        + (long)sequence * (args.nb03 / sizeof(float))
        + (long)head     * (args.nb02 / sizeof(float))
        + (long)ic0      * (args.nb01 / sizeof(float));

    device const uint8_t * K_p = K_packed
        + (long)args.firstCell * args.nKVHeads * args.packedBytes
        + (long)head_kv * args.packedBytes;
    device const float * K_sc = K_scales
        + (long)args.firstCell * args.nKVHeads + head_kv;

    device const uint8_t * V_p = V_packed
        + (long)args.firstCell * args.nKVHeads * args.v_packedBytes
        + (long)head_kv * args.v_packedBytes;
    device const float * V_sc = V_scales
        + (long)args.firstCell * args.nKVHeads + head_kv;

    device const half * maskh = args.hasMask
        ? (mask_data + (long)ic0 * (args.nb31 / sizeof(half)))
        : nullptr;

    // Each lane caches one K- and one V-codebook entry; decode broadcasts via
    // simd_shuffle.
    const int   k_cb_mask = (1 << args.bits) - 1;
    const float k_cb_lane = K_cb[tiisg & k_cb_mask];
    const int   v_cb_mask = (1 << args.v_bits) - 1;
    const float v_cb_lane = V_cb[tiisg & v_cb_mask];

    // Lane index within the 8-lane KQ cooperation group.
    const int tid_kq = (int)tiisg % nthreads_KQ;

    // Q_reg holds 16 float2 per thread per query slot (D/(2*nthreads_KQ)),
    // pre-scaled by args.scale.
    float2 Q_reg[2][16];
    for (int j = 0; j < args.ncols; j++) {
        device const float2 * Q_j = (device const float2 *)(Q + (long)j * (args.nb01 / sizeof(float)));
        for (int i = 0; i < 16; i++) {
            const int elem = tid_kq * 16 + i; // float2 index within [0, 127]
            Q_reg[j][i] = (elem < D/2) ? Q_j[elem] : float2(0.0f, 0.0f);
        }
        for (int i = 0; i < 16; i++) {
            Q_reg[j][i].x *= args.scale;
            Q_reg[j][i].y *= args.scale;
        }
    }

    // VKQ holds 4 passes × 4 float2 = 16 float2 per query slot.
    float2 VKQ[2][16];
    for (int j = 0; j < 2; j++)
        for (int i = 0; i < 16; i++)
            VKQ[j][i] = float2(0.0f, 0.0f);

    // Online-softmax running state per query slot.
    float KQ_max[2] = { -FLT_MAX/2.0f, -FLT_MAX/2.0f };
    float KQ_sum[2] = { 0.0f, 0.0f };

    // KQ_tg: nwarps*V_cols_per_iter*D = 4*4*256 = 4096 floats (16 KiB);
    // reused for softmax weights inside the KV loop and as the VKQ staging
    // buffer in the final merge phase.
    threadgroup float KQ_tg[4096];
    threadgroup float KQ_max_tg[2][32];
    threadgroup float KQ_sum_tg[2][32];

    // Main KV loop: 128 cells per iteration (one per thread).
    for (int k_VKQ_0 = 0; k_VKQ_0 < args.nCells; k_VKQ_0 += nthreads) {

        float KQ_max_new[2] = { KQ_max[0], KQ_max[1] };

        // ── Phase 1: Q·K scores, one key per 8-lane group per step.
        for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; i_KQ_0++) {
            const int  kq_grp_start = ((int)tiisg & ~(nthreads_KQ - 1));
            const int  i_KQ     = (int)sgitg * 32 + kq_grp_start + i_KQ_0;
            const int  cell_rel = k_VKQ_0 + i_KQ;
            const bool in_range = (cell_rel < args.nCells);

            for (int j = 0; j < args.ncols; j++) {
                // NOTE(review): packed_row is dereferenced even when !in_range
                // (rms_scale is zeroed instead) — assumes the K cache buffer is
                // allocated to full capacity past nCells; confirm.
                device const uint8_t * packed_row = K_p + (long)cell_rel * args.nKVHeads * args.packedBytes;
                const float rms_scale = in_range ? K_sc[cell_rel * args.nKVHeads] : 0.0f;

                // 16 k-iterations × 2 elements each = 32 D-positions per thread.
                float sum = 0.0f;
                for (int k = 0; k < 16; k++) {
                    const int start_elem = tid_kq * 32 + k * 2;
                    float k_dec[2];
                    if (args.bits == 3) {
                        // 3-bit codes: 16-bit window handles byte-straddling codes.
                        const int bit_pos0 = start_elem * 3;
                        const int byte0 = bit_pos0 >> 3, sh0 = bit_pos0 & 7;
                        const uint w0 = (uint)packed_row[byte0] | ((uint)packed_row[byte0+1] << 8);
                        k_dec[0] = simd_shuffle(k_cb_lane, (ushort)((w0 >> sh0) & 7)) * rms_scale;
                        const int bit_pos1 = (start_elem + 1) * 3;
                        const int byte1 = bit_pos1 >> 3, sh1 = bit_pos1 & 7;
                        const uint w1 = (uint)packed_row[byte1] | ((uint)packed_row[byte1+1] << 8);
                        k_dec[1] = simd_shuffle(k_cb_lane, (ushort)((w1 >> sh1) & 7)) * rms_scale;
                    } else {
                        // 2-bit codes: 4 codes per byte.
                        const int byte0 = start_elem >> 2, sh0 = (start_elem & 3) * 2;
                        k_dec[0] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte0] >> sh0) & 3)) * rms_scale;
                        const int byte1 = (start_elem + 1) >> 2, sh1 = ((start_elem + 1) & 3) * 2;
                        k_dec[1] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte1] >> sh1) & 3)) * rms_scale;
                    }
                    sum += Q_reg[j][k].x * k_dec[0] + Q_reg[j][k].y * k_dec[1];
                }
                // Butterfly reduction across the 8-lane group.
                sum += simd_shuffle_xor(sum, 4);
                sum += simd_shuffle_xor(sum, 2);
                sum += simd_shuffle_xor(sum, 1);

                if (args.logit_softcap != 0.0f) {
                    sum = args.logit_softcap * tanh(sum);
                }
                if (maskh && (args.ncols == 1 || ic0 + j < args.nTokensQ)) {
                    // NOTE(review): mask column is the within-tile index i_KQ
                    // (0..127), not the absolute cell k_VKQ_0 + i_KQ — for
                    // nCells > 128 this re-reads the first 128 mask columns.
                    // Confirm against the D=128 reference kernel.
                    sum += float(maskh[(long)j * args.ne31 + i_KQ]);
                }
                if (!in_range) sum = -FLT_MAX/2.0f;

                // +ln(2) headroom on the running max; the shift cancels in the
                // final normalization while keeping every exp() ≤ 0.5.
                KQ_max_new[j] = max(KQ_max_new[j], sum + 0.6931f);

                // One lane per group (tid == i_KQ) publishes the raw score.
                if (tid_kq == (uint)i_KQ_0) {
                    KQ_tg[j * nthreads + tid] = sum;
                }
            }
        }

        // ── Phase 2: rescale running state; convert scores to softmax weights.
        for (int j = 0; j < args.ncols; j++) {
            KQ_max_new[j] = simd_max(KQ_max_new[j]);

            const float KQ_max_scale = exp(KQ_max[j] - KQ_max_new[j]);
            KQ_max[j] = KQ_max_new[j];

            const float kq_val = KQ_tg[j * nthreads + tid];
            const float kq_exp = exp(kq_val - KQ_max[j]);
            KQ_sum[j] = KQ_sum[j] * KQ_max_scale + kq_exp;
            KQ_tg[j * nthreads + tid] = kq_exp;

            for (int i = 0; i < 16; i++) {
                VKQ[j][i].x *= KQ_max_scale;
                VKQ[j][i].y *= KQ_max_scale;
            }
        }

        // ── Phase 3: accumulate softmax-weighted, decoded V rows.
        // D=256: 4 passes cover V[0..63], V[64..127], V[128..191], V[192..255].
        for (int k0 = 0; k0 < 32; k0 += V_cols_per_iter) {
            const int  k          = (int)sgitg * 32 + k0 + (int)tiisg / nthreads_V;
            const int  cell_rel   = k_VKQ_0 + k;
            const bool in_range_v = (cell_rel < args.nCells);

            float KQ_k[2];
            for (int j = 0; j < args.ncols; j++) {
                KQ_k[j] = KQ_tg[j * nthreads + k];
            }

            device const uint8_t * v_row = in_range_v
                ? V_p + (long)cell_rel * args.nKVHeads * args.v_packedBytes
                : nullptr;
            const float v_rms = in_range_v ? V_sc[cell_rel * args.nKVHeads] : 0.0f;

            const int v_tid = (int)tiisg % nthreads_V;
            for (int pass = 0; pass < 4; pass++) {
                float v_dec[8];
                const int start_elem = pass * 64 + v_tid * 8;
                if (v_row && start_elem < D) {
                    tq_decode_8_shfl(v_row, v_cb_lane, v_rms, start_elem, args.v_bits, v_dec);
                } else {
                    // Out-of-range: decode a dummy row (K_p, zero scale) so all
                    // lanes still execute the simd_shuffle inside
                    // tq_decode_8_shfl convergently; the result is all zeros.
                    tq_decode_8_shfl(K_p, v_cb_lane, 0.0f, 0, args.v_bits, v_dec);
                }
                for (int i = 0; i < 8; i++) {
                    const int vkq_idx = pass * 4 + i / 2;
                    for (int j = 0; j < args.ncols; j++) {
                        if (i % 2 == 0) VKQ[j][vkq_idx].x += v_dec[i] * KQ_k[j];
                        else            VKQ[j][vkq_idx].y += v_dec[i] * KQ_k[j];
                    }
                }
            }
        }
    } // end KV loop

    // ── Final merge across the 4 simdgroups ─────────────────────────────────
    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (int j = 0; j < args.ncols; j++) {
        if (sgitg == 0) {
            KQ_max_tg[j][tiisg] = -FLT_MAX/2.0f;
            KQ_sum_tg[j][tiisg] = 0.0f;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (int j = 0; j < args.ncols; j++) {
        if (tiisg == 0) {
            KQ_max_tg[j][sgitg] = KQ_max[j];
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (int j = 0; j < args.ncols; j++) {
        // Uniform condition across the threadgroup — safe around the barriers.
        if (args.ncols > 1 && ic0 + j >= args.nTokensQ) break;

        float kqmax_new = KQ_max_tg[j][tiisg];
        kqmax_new = simd_max(kqmax_new);
        const float kqmax_scale = exp(KQ_max[j] - kqmax_new);
        KQ_max[j] = kqmax_new;

        for (int i = 0; i < 16; i++) {
            VKQ[j][i].x *= kqmax_scale;
            VKQ[j][i].y *= kqmax_scale;
        }

        // Stage VKQ into threadgroup memory: per (sgitg, v-group) region of
        // D/2 = 128 float2; pass p lands at float2 slots [p*32, p*32+31].
        const int v_tid = (int)tiisg % nthreads_V;
        threadgroup float2 * VKQ_tg = (threadgroup float2 *)KQ_tg
            + (long)sgitg * (V_cols_per_iter * D/2)
            + (long)((int)tiisg / nthreads_V) * (D/2);
        VKQ_tg[     v_tid * 4 + 0] = VKQ[j][0];
        VKQ_tg[     v_tid * 4 + 1] = VKQ[j][1];
        VKQ_tg[     v_tid * 4 + 2] = VKQ[j][2];
        VKQ_tg[     v_tid * 4 + 3] = VKQ[j][3];
        VKQ_tg[32 + v_tid * 4 + 0] = VKQ[j][4];
        VKQ_tg[32 + v_tid * 4 + 1] = VKQ[j][5];
        VKQ_tg[32 + v_tid * 4 + 2] = VKQ[j][6];
        VKQ_tg[32 + v_tid * 4 + 3] = VKQ[j][7];
        VKQ_tg[64 + v_tid * 4 + 0] = VKQ[j][8];
        VKQ_tg[64 + v_tid * 4 + 1] = VKQ[j][9];
        VKQ_tg[64 + v_tid * 4 + 2] = VKQ[j][10];
        VKQ_tg[64 + v_tid * 4 + 3] = VKQ[j][11];
        VKQ_tg[96 + v_tid * 4 + 0] = VKQ[j][12];
        VKQ_tg[96 + v_tid * 4 + 1] = VKQ[j][13];
        VKQ_tg[96 + v_tid * 4 + 2] = VKQ[j][14];
        VKQ_tg[96 + v_tid * 4 + 3] = VKQ[j][15];

        KQ_sum[j] *= kqmax_scale;
        KQ_sum[j] = simd_sum(KQ_sum[j]);
        if (tiisg == 0) {
            KQ_sum_tg[j][sgitg] = KQ_sum[j];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // All lanes participate in the simd_sum (required for convergence).
        KQ_sum[j] = KQ_sum_tg[j][tiisg];
        KQ_sum[j] = simd_sum(KQ_sum[j]);

        // Each thread writes 2 output positions: tid and tid + 128.
        const long out_idx = ((long)sequence * args.nTokensQ + ic0 + j) * args.nHeadsQ + head;
        for (int out_offset = 0; out_offset < D; out_offset += nthreads) {
            const int out_elem = out_offset + tid;
            // Sum the 16 staged partials (4 simdgroups × 4 v-groups), then
            // normalize by the softmax denominator.
            float dst_val = 0.0f;
            for (int w = 0; w < nwarps; w++) {
                for (int v = 0; v < V_cols_per_iter; v++) {
                    dst_val += ((threadgroup float *)KQ_tg)[w * V_cols_per_iter * D + v * D + out_elem];
                }
            }
            dst_val /= KQ_sum[j];
            dst[out_idx * D + out_elem] = dst_val;
        }

        if (j < args.ncols - 1) threadgroup_barrier(mem_flags::mem_threadgroup);
    }
}
|
||||
|
|
|
|||
|
|
@ -4602,9 +4602,16 @@ int ggml_metal_op_tq_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|||
/*.nb31 =*/ mask ? mask->nb[1] : 0,
|
||||
};
|
||||
|
||||
// Select D=128 vs D=256 pipeline. Gemma3 runs at headDim=256; everything
|
||||
// else supported so far is D=128.
|
||||
GGML_ASSERT(D == 128 || D == 256);
|
||||
auto pipeline = v_packed
|
||||
? ggml_metal_library_get_pipeline_tq_fattn_vec_packed(lib)
|
||||
: ggml_metal_library_get_pipeline_tq_fattn_vec_f16(lib);
|
||||
? (D == 256
|
||||
? ggml_metal_library_get_pipeline_tq_fattn_vec_packed_d256(lib)
|
||||
: ggml_metal_library_get_pipeline_tq_fattn_vec_packed(lib))
|
||||
: (D == 256
|
||||
? ggml_metal_library_get_pipeline_tq_fattn_vec_f16_d256(lib)
|
||||
: ggml_metal_library_get_pipeline_tq_fattn_vec_f16(lib));
|
||||
|
||||
ggml_metal_buffer_id bid_mask = hasMask ? ggml_metal_get_buffer_id(mask) : ggml_metal_get_buffer_id(op);
|
||||
ggml_metal_buffer_id bid_v_scales = v_packed ? ggml_metal_get_buffer_id(op->src[6]) : ggml_metal_get_buffer_id(op);
|
||||
|
|
|
|||
562
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
vendored
562
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
vendored
|
|
@ -11184,6 +11184,293 @@ kernel void kernel_tq_fattn_vec_f16(
|
|||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// kernel_tq_fattn_vec_f16_d256
|
||||
// TQ fused flash-attention at head dim 256: K packed i8, V f16.
|
||||
// Thread layout identical to the D=128 variant (32×4 = 128 threads) but each
|
||||
// thread now produces 2 output elements (D/nthreads = 2) and covers twice as
|
||||
// many D positions in the Q/K/V loops.
|
||||
// Grid: (ntiles_x, 1, nHeadsQ*nSeq)
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
kernel void kernel_tq_fattn_vec_f16_d256(
|
||||
constant ggml_metal_kargs_tq_fattn_vec & args,
|
||||
device const char * Q_data [[buffer(1)]],
|
||||
device const uint8_t * K_packed [[buffer(2)]],
|
||||
device const half * V_data [[buffer(3)]],
|
||||
device const half * mask_data [[buffer(4)]],
|
||||
device const float * K_scales [[buffer(5)]],
|
||||
device const float * K_cb [[buffer(6)]],
|
||||
device const float * dummy_vs [[buffer(7)]],
|
||||
device const float * dummy_vc [[buffer(8)]],
|
||||
device float * dst [[buffer(9)]],
|
||||
uint3 tgpig [[threadgroup_position_in_grid]],
|
||||
uint tiisg [[thread_index_in_simdgroup]],
|
||||
uint sgitg [[simdgroup_index_in_threadgroup]])
|
||||
{
|
||||
constexpr int D = 256;
|
||||
constexpr int nthreads = 128;
|
||||
constexpr int nthreads_KQ = 8;
|
||||
constexpr int nthreads_V = 8;
|
||||
constexpr int V_cols_per_iter = 4;
|
||||
constexpr int nwarps = 4;
|
||||
|
||||
const int ic0 = (int)tgpig.x * args.ncols;
|
||||
const int blk_z = (int)tgpig.z;
|
||||
const int sequence = blk_z / args.nHeadsQ;
|
||||
const int head = blk_z % args.nHeadsQ;
|
||||
const int gqa_ratio = args.nHeadsQ / args.nKVHeads;
|
||||
const int head_kv = head / gqa_ratio;
|
||||
|
||||
const int tid = (int)sgitg * 32 + (int)tiisg;
|
||||
|
||||
device const float * Q = (device const float *)Q_data
|
||||
+ (long)sequence * (args.nb03 / sizeof(float))
|
||||
+ (long)head * (args.nb02 / sizeof(float))
|
||||
+ (long)ic0 * (args.nb01 / sizeof(float));
|
||||
|
||||
device const uint8_t * K_p = K_packed
|
||||
+ (long)args.firstCell * args.nKVHeads * args.packedBytes
|
||||
+ (long)head_kv * args.packedBytes;
|
||||
device const float * K_sc = K_scales
|
||||
+ (long)args.firstCell * args.nKVHeads + head_kv;
|
||||
|
||||
device const half * V = V_data
|
||||
+ (long)sequence * (args.nb23 / sizeof(half))
|
||||
+ (long)head_kv * (args.nb22 / sizeof(half));
|
||||
|
||||
device const half * maskh = args.hasMask
|
||||
? (mask_data + (long)ic0 * (args.nb31 / sizeof(half)))
|
||||
: nullptr;
|
||||
|
||||
const int k_cb_mask = (1 << args.bits) - 1;
|
||||
const float k_cb_lane = K_cb[tiisg & k_cb_mask];
|
||||
|
||||
const int tid_kq = (int)tiisg % nthreads_KQ;
|
||||
|
||||
// D=256: Q_reg holds 16 float2 per thread per query slot (D/(2*nthreads_KQ)).
|
||||
float2 Q_reg[2][16];
|
||||
for (int j = 0; j < args.ncols; j++) {
|
||||
device const float2 * Q_j = (device const float2 *)(Q + (long)j * (args.nb01 / sizeof(float)));
|
||||
for (int i = 0; i < 16; i++) {
|
||||
const int elem = tid_kq * 16 + i; // float2 index within [0, 127]
|
||||
Q_reg[j][i] = (elem < D/2) ? Q_j[elem] : float2(0.0f, 0.0f);
|
||||
}
|
||||
for (int i = 0; i < 16; i++) {
|
||||
Q_reg[j][i].x *= args.scale;
|
||||
Q_reg[j][i].y *= args.scale;
|
||||
}
|
||||
}
|
||||
|
||||
// D=256: VKQ holds 4 passes × 4 float2 = 16 float2 per query slot.
|
||||
float2 VKQ[2][16];
|
||||
for (int j = 0; j < 2; j++)
|
||||
for (int i = 0; i < 16; i++)
|
||||
VKQ[j][i] = float2(0.0f, 0.0f);
|
||||
|
||||
float KQ_max[2] = { -FLT_MAX/2.0f, -FLT_MAX/2.0f };
|
||||
float KQ_sum[2] = { 0.0f, 0.0f };
|
||||
|
||||
// D=256: KQ_tg sized nwarps*V_cols_per_iter*D = 4*4*256 = 4096 floats (16 KiB).
|
||||
threadgroup float KQ_tg[4096];
|
||||
threadgroup float KQ_max_tg[2][32];
|
||||
threadgroup float KQ_sum_tg[2][32];
|
||||
|
||||
for (int k_VKQ_0 = 0; k_VKQ_0 < args.nCells; k_VKQ_0 += nthreads) {
|
||||
|
||||
float KQ_max_new[2] = { KQ_max[0], KQ_max[1] };
|
||||
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; i_KQ_0++) {
|
||||
const int kq_grp_start = ((int)tiisg & ~(nthreads_KQ - 1));
|
||||
const int i_KQ = (int)sgitg * 32 + kq_grp_start + i_KQ_0;
|
||||
const int cell_rel = k_VKQ_0 + i_KQ;
|
||||
const bool in_range = (cell_rel < args.nCells);
|
||||
|
||||
for (int j = 0; j < args.ncols; j++) {
|
||||
device const uint8_t * packed_row = K_p + (long)cell_rel * args.nKVHeads * args.packedBytes;
|
||||
const float rms_scale = in_range ? K_sc[cell_rel * args.nKVHeads] : 0.0f;
|
||||
|
||||
// D=256: 16 k-iterations × 2 elements each = 32 D-positions per thread.
|
||||
float sum = 0.0f;
|
||||
for (int k = 0; k < 16; k++) {
|
||||
const int start_elem = tid_kq * 32 + k * 2; // float index [0..254]
|
||||
float k_dec[2];
|
||||
if (args.bits == 3) {
|
||||
const int bit_pos0 = start_elem * 3;
|
||||
const int byte0 = bit_pos0 >> 3, sh0 = bit_pos0 & 7;
|
||||
const uint w0 = (uint)packed_row[byte0] | ((uint)packed_row[byte0+1] << 8);
|
||||
int idx0 = (int)((w0 >> sh0) & 7);
|
||||
k_dec[0] = simd_shuffle(k_cb_lane, (ushort)idx0) * rms_scale;
|
||||
const int bit_pos1 = (start_elem + 1) * 3;
|
||||
const int byte1 = bit_pos1 >> 3, sh1 = bit_pos1 & 7;
|
||||
const uint w1 = (uint)packed_row[byte1] | ((uint)packed_row[byte1+1] << 8);
|
||||
int idx1 = (int)((w1 >> sh1) & 7);
|
||||
k_dec[1] = simd_shuffle(k_cb_lane, (ushort)idx1) * rms_scale;
|
||||
} else {
|
||||
const int byte0 = start_elem >> 2, sh0 = (start_elem & 3) * 2;
|
||||
k_dec[0] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte0] >> sh0) & 3)) * rms_scale;
|
||||
const int byte1 = (start_elem + 1) >> 2, sh1 = ((start_elem + 1) & 3) * 2;
|
||||
k_dec[1] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte1] >> sh1) & 3)) * rms_scale;
|
||||
}
|
||||
sum += Q_reg[j][k].x * k_dec[0] + Q_reg[j][k].y * k_dec[1];
|
||||
}
|
||||
sum += simd_shuffle_xor(sum, 4);
|
||||
sum += simd_shuffle_xor(sum, 2);
|
||||
sum += simd_shuffle_xor(sum, 1);
|
||||
|
||||
if (args.logit_softcap != 0.0f) {
|
||||
sum = args.logit_softcap * tanh(sum);
|
||||
}
|
||||
|
||||
if (maskh && (args.ncols == 1 || ic0 + j < args.nTokensQ)) {
|
||||
sum += float(maskh[(long)j * args.ne31 + i_KQ]);
|
||||
}
|
||||
|
||||
if (!in_range) sum = -FLT_MAX/2.0f;
|
||||
|
||||
KQ_max_new[j] = max(KQ_max_new[j], sum + 0.6931f);
|
||||
|
||||
if (tid_kq == (uint)i_KQ_0) {
|
||||
KQ_tg[j * nthreads + tid] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int j = 0; j < args.ncols; j++) {
|
||||
KQ_max_new[j] = simd_max(KQ_max_new[j]);
|
||||
|
||||
const float KQ_max_scale = exp(KQ_max[j] - KQ_max_new[j]);
|
||||
KQ_max[j] = KQ_max_new[j];
|
||||
|
||||
const float kq_val = KQ_tg[j * nthreads + tid];
|
||||
const float kq_exp = exp(kq_val - KQ_max[j]);
|
||||
KQ_sum[j] = KQ_sum[j] * KQ_max_scale + kq_exp;
|
||||
KQ_tg[j * nthreads + tid] = kq_exp;
|
||||
|
||||
for (int i = 0; i < 16; i++) {
|
||||
VKQ[j][i].x *= KQ_max_scale;
|
||||
VKQ[j][i].y *= KQ_max_scale;
|
||||
}
|
||||
}
|
||||
|
||||
// D=256: 4 passes × 64 V-elements each = full 256-element coverage.
|
||||
// pass 0 → V[ 0.. 63] → VKQ[ 0.. 3]
|
||||
// pass 1 → V[ 64..127] → VKQ[ 4.. 7]
|
||||
// pass 2 → V[128..191] → VKQ[ 8..11]
|
||||
// pass 3 → V[192..255] → VKQ[12..15]
|
||||
for (int k0 = 0; k0 < 32; k0 += V_cols_per_iter) {
|
||||
const int k = (int)sgitg * 32 + k0 + (int)tiisg / nthreads_V;
|
||||
const int cell_rel = k_VKQ_0 + k;
|
||||
|
||||
float KQ_k[2];
|
||||
for (int j = 0; j < args.ncols; j++) {
|
||||
KQ_k[j] = KQ_tg[j * nthreads + k];
|
||||
}
|
||||
|
||||
device const half * V_cell = (cell_rel < args.nCells)
|
||||
? V + (long)cell_rel * (args.nb21 / sizeof(half))
|
||||
: nullptr;
|
||||
|
||||
const int v_tid = (int)tiisg % nthreads_V;
|
||||
for (int pass = 0; pass < 4; pass++) {
|
||||
for (int i = 0; i < 8; i++) {
|
||||
const int elem = pass * 64 + v_tid * 8 + i;
|
||||
float v_val = (V_cell && elem < D) ? float(V_cell[elem]) : 0.0f;
|
||||
const int vkq_idx = pass * 4 + i / 2;
|
||||
for (int j = 0; j < args.ncols; j++) {
|
||||
if (i % 2 == 0) VKQ[j][vkq_idx].x += v_val * KQ_k[j];
|
||||
else VKQ[j][vkq_idx].y += v_val * KQ_k[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // end KV loop
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (int j = 0; j < args.ncols; j++) {
|
||||
if (sgitg == 0) {
|
||||
KQ_max_tg[j][tiisg] = -FLT_MAX/2.0f;
|
||||
KQ_sum_tg[j][tiisg] = 0.0f;
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (int j = 0; j < args.ncols; j++) {
|
||||
if (tiisg == 0) {
|
||||
KQ_max_tg[j][sgitg] = KQ_max[j];
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (int j = 0; j < args.ncols; j++) {
|
||||
if (args.ncols > 1 && ic0 + j >= args.nTokensQ) break;
|
||||
|
||||
float kqmax_new = KQ_max_tg[j][tiisg];
|
||||
kqmax_new = simd_max(kqmax_new);
|
||||
const float kqmax_scale = exp(KQ_max[j] - kqmax_new);
|
||||
KQ_max[j] = kqmax_new;
|
||||
|
||||
for (int i = 0; i < 16; i++) {
|
||||
VKQ[j][i].x *= kqmax_scale;
|
||||
VKQ[j][i].y *= kqmax_scale;
|
||||
}
|
||||
|
||||
// D=256: VKQ_tg layout per (sgitg,v-group) region of D/2 = 128 float2:
|
||||
// VKQ[ 0.. 3] → VKQ_tg[v_tid*4 + 0..3] (pass 0, float2 slots 0..31)
|
||||
// VKQ[ 4.. 7] → VKQ_tg[32 + v_tid*4 + 0..3] (pass 1, float2 slots 32..63)
|
||||
// VKQ[ 8..11] → VKQ_tg[64 + v_tid*4 + 0..3] (pass 2, float2 slots 64..95)
|
||||
// VKQ[12..15] → VKQ_tg[96 + v_tid*4 + 0..3] (pass 3, float2 slots 96..127)
|
||||
const int v_tid = (int)tiisg % nthreads_V;
|
||||
threadgroup float2 * VKQ_tg = (threadgroup float2 *)KQ_tg
|
||||
+ (long)sgitg * (V_cols_per_iter * D/2)
|
||||
+ (long)((int)tiisg / nthreads_V) * (D/2);
|
||||
VKQ_tg[v_tid * 4 + 0] = VKQ[j][0];
|
||||
VKQ_tg[v_tid * 4 + 1] = VKQ[j][1];
|
||||
VKQ_tg[v_tid * 4 + 2] = VKQ[j][2];
|
||||
VKQ_tg[v_tid * 4 + 3] = VKQ[j][3];
|
||||
VKQ_tg[32 + v_tid * 4 + 0] = VKQ[j][4];
|
||||
VKQ_tg[32 + v_tid * 4 + 1] = VKQ[j][5];
|
||||
VKQ_tg[32 + v_tid * 4 + 2] = VKQ[j][6];
|
||||
VKQ_tg[32 + v_tid * 4 + 3] = VKQ[j][7];
|
||||
VKQ_tg[64 + v_tid * 4 + 0] = VKQ[j][8];
|
||||
VKQ_tg[64 + v_tid * 4 + 1] = VKQ[j][9];
|
||||
VKQ_tg[64 + v_tid * 4 + 2] = VKQ[j][10];
|
||||
VKQ_tg[64 + v_tid * 4 + 3] = VKQ[j][11];
|
||||
VKQ_tg[96 + v_tid * 4 + 0] = VKQ[j][12];
|
||||
VKQ_tg[96 + v_tid * 4 + 1] = VKQ[j][13];
|
||||
VKQ_tg[96 + v_tid * 4 + 2] = VKQ[j][14];
|
||||
VKQ_tg[96 + v_tid * 4 + 3] = VKQ[j][15];
|
||||
|
||||
KQ_sum[j] *= kqmax_scale;
|
||||
KQ_sum[j] = simd_sum(KQ_sum[j]);
|
||||
if (tiisg == 0) {
|
||||
KQ_sum_tg[j][sgitg] = KQ_sum[j];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// D=256: each thread writes 2 output positions (tid, tid+128). Compute
|
||||
// KQ_sum with ALL lanes participating in the simd_sum (required for
|
||||
// convergence), then two writes per thread.
|
||||
KQ_sum[j] = KQ_sum_tg[j][tiisg];
|
||||
KQ_sum[j] = simd_sum(KQ_sum[j]);
|
||||
|
||||
const long out_idx = ((long)sequence * args.nTokensQ + ic0 + j) * args.nHeadsQ + head;
|
||||
for (int out_offset = 0; out_offset < D; out_offset += nthreads) {
|
||||
const int out_elem = out_offset + tid;
|
||||
float dst_val = 0.0f;
|
||||
for (int w = 0; w < nwarps; w++) {
|
||||
for (int v = 0; v < V_cols_per_iter; v++) {
|
||||
dst_val += ((threadgroup float *)KQ_tg)[w * V_cols_per_iter * D + v * D + out_elem];
|
||||
}
|
||||
}
|
||||
dst_val /= KQ_sum[j];
|
||||
dst[out_idx * D + out_elem] = dst_val;
|
||||
}
|
||||
|
||||
if (j < args.ncols - 1) threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// kernel_tq_fattn_vec_packed
|
||||
// TQ fused flash-attention: K packed i8, V packed i8.
|
||||
|
|
@ -11451,3 +11738,278 @@ kernel void kernel_tq_fattn_vec_packed(
|
|||
if (j < args.ncols - 1) threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
// kernel_tq_fattn_vec_packed_d256
// TQ fused flash-attention at head dim 256: K packed i8, V packed i8.
// Mirrors kernel_tq_fattn_vec_packed with 4 V-passes and 2 outputs per thread.
//
// Thread layout: 4 simdgroups x 32 lanes = 128 threads. Groups of nthreads_KQ
// (8) lanes cooperate on one K row, each lane covering 32 of the 256
// D-positions; groups of nthreads_V (8) lanes cooperate on one V row across
// 4 passes of 64 elements. KQ_tg doubles as the per-thread softmax scratch
// during the KV loop and as the VKQ staging area in the output phase
// (4096 floats = 16 KiB of threadgroup memory).
// ─────────────────────────────────────────────────────────────────────────────
kernel void kernel_tq_fattn_vec_packed_d256(
        constant ggml_metal_kargs_tq_fattn_vec & args,
        device const char * Q_data [[buffer(1)]],
        device const uint8_t * K_packed [[buffer(2)]],
        device const uint8_t * V_packed [[buffer(3)]],
        device const half * mask_data [[buffer(4)]],
        device const float * K_scales [[buffer(5)]],
        device const float * K_cb [[buffer(6)]],
        device const float * V_scales [[buffer(7)]],
        device const float * V_cb [[buffer(8)]],
        device float * dst [[buffer(9)]],
        uint3 tgpig [[threadgroup_position_in_grid]],
        uint tiisg [[thread_index_in_simdgroup]],
        uint sgitg [[simdgroup_index_in_threadgroup]])
{
    constexpr int D = 256;              // head dimension
    constexpr int nthreads = 128;       // 4 simdgroups x 32 lanes
    constexpr int nthreads_KQ = 8;      // lanes cooperating on one K row
    constexpr int nthreads_V = 8;       // lanes cooperating on one V row
    constexpr int V_cols_per_iter = 4;  // V rows per simdgroup per inner iteration
    constexpr int nwarps = 4;           // simdgroups per threadgroup

    // Grid decode: x = query-column block, z = (sequence, Q head) fused.
    const int ic0 = (int)tgpig.x * args.ncols;
    const int blk_z = (int)tgpig.z;
    const int sequence = blk_z / args.nHeadsQ;
    const int head = blk_z % args.nHeadsQ;
    const int gqa_ratio = args.nHeadsQ / args.nKVHeads;
    const int head_kv = head / gqa_ratio;   // KV head shared by this Q head (GQA)

    const int tid = (int)sgitg * 32 + (int)tiisg;   // flat thread id in [0, nthreads)

    // Q rows for this (sequence, head, ic0..), addressed in float elements.
    device const float * Q = (device const float *)Q_data
        + (long)sequence * (args.nb03 / sizeof(float))
        + (long)head * (args.nb02 / sizeof(float))
        + (long)ic0 * (args.nb01 / sizeof(float));

    // Packed K rows + per-row RMS scales; layout is [cell][kv-head],
    // starting at firstCell, offset to this head's column.
    device const uint8_t * K_p = K_packed
        + (long)args.firstCell * args.nKVHeads * args.packedBytes
        + (long)head_kv * args.packedBytes;
    device const float * K_sc = K_scales
        + (long)args.firstCell * args.nKVHeads + head_kv;

    // Same addressing scheme for packed V.
    device const uint8_t * V_p = V_packed
        + (long)args.firstCell * args.nKVHeads * args.v_packedBytes
        + (long)head_kv * args.v_packedBytes;
    device const float * V_sc = V_scales
        + (long)args.firstCell * args.nKVHeads + head_kv;

    device const half * maskh = args.hasMask
        ? (mask_data + (long)ic0 * (args.nb31 / sizeof(half)))
        : nullptr;

    // Each lane caches one codebook entry; dequantization is then a
    // simd_shuffle by the packed index.
    const int k_cb_mask = (1 << args.bits) - 1;
    const float k_cb_lane = K_cb[tiisg & k_cb_mask];
    const int v_cb_mask = (1 << args.v_bits) - 1;
    const float v_cb_lane = V_cb[tiisg & v_cb_mask];

    const int tid_kq = (int)tiisg % nthreads_KQ;    // lane slot within its KQ group

    // D=256: Q_reg holds 16 float2 per thread per query slot (D/(2*nthreads_KQ)).
    float2 Q_reg[2][16];
    for (int j = 0; j < args.ncols; j++) {
        device const float2 * Q_j = (device const float2 *)(Q + (long)j * (args.nb01 / sizeof(float)));
        for (int i = 0; i < 16; i++) {
            const int elem = tid_kq * 16 + i;   // float2 index within [0, D/2)
            Q_reg[j][i] = (elem < D/2) ? Q_j[elem] : float2(0.0f, 0.0f);
        }
        // Pre-scale Q once so the per-cell dot product needs no extra multiply.
        for (int i = 0; i < 16; i++) {
            Q_reg[j][i].x *= args.scale;
            Q_reg[j][i].y *= args.scale;
        }
    }

    // D=256: VKQ holds 4 passes x 4 float2 = 16 float2 per query slot.
    float2 VKQ[2][16];
    for (int j = 0; j < 2; j++)
        for (int i = 0; i < 16; i++)
            VKQ[j][i] = float2(0.0f, 0.0f);

    // Online-softmax running state, one slot per query column.
    float KQ_max[2] = { -FLT_MAX/2.0f, -FLT_MAX/2.0f };
    float KQ_sum[2] = { 0.0f, 0.0f };

    // D=256: KQ_tg sized nwarps*V_cols_per_iter*D = 4*4*256 = 4096 floats (16 KiB).
    threadgroup float KQ_tg[4096];
    threadgroup float KQ_max_tg[2][32];
    threadgroup float KQ_sum_tg[2][32];

    // Main loop: one 128-cell KV tile per iteration (one cell per thread).
    for (int k_VKQ_0 = 0; k_VKQ_0 < args.nCells; k_VKQ_0 += nthreads) {

        float KQ_max_new[2] = { KQ_max[0], KQ_max[1] };

        // ── KQ phase: each 8-lane KQ group scores its 8 K rows in turn. ──
        for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; i_KQ_0++) {
            const int kq_grp_start = ((int)tiisg & ~(nthreads_KQ - 1));
            const int i_KQ = (int)sgitg * 32 + kq_grp_start + i_KQ_0;
            const int cell_rel = k_VKQ_0 + i_KQ;
            const bool in_range = (cell_rel < args.nCells);

            for (int j = 0; j < args.ncols; j++) {
                device const uint8_t * packed_row = K_p + (long)cell_rel * args.nKVHeads * args.packedBytes;
                // Zero scale for out-of-range lanes keeps the decode harmless.
                const float rms_scale = in_range ? K_sc[cell_rel * args.nKVHeads] : 0.0f;

                // D=256: 16 k-iterations x 2 elements each = 32 D-positions per thread.
                float sum = 0.0f;
                for (int k = 0; k < 16; k++) {
                    const int start_elem = tid_kq * 32 + k * 2;   // float index [0..254]
                    float k_dec[2];
                    if (args.bits == 3) {
                        // 3-bit indices straddle byte boundaries: read 16 bits,
                        // shift into place, mask to 3 bits.
                        const int bit_pos0 = start_elem * 3;
                        const int byte0 = bit_pos0 >> 3, sh0 = bit_pos0 & 7;
                        const uint w0 = (uint)packed_row[byte0] | ((uint)packed_row[byte0+1] << 8);
                        k_dec[0] = simd_shuffle(k_cb_lane, (ushort)((w0 >> sh0) & 7)) * rms_scale;
                        const int bit_pos1 = (start_elem + 1) * 3;
                        const int byte1 = bit_pos1 >> 3, sh1 = bit_pos1 & 7;
                        const uint w1 = (uint)packed_row[byte1] | ((uint)packed_row[byte1+1] << 8);
                        k_dec[1] = simd_shuffle(k_cb_lane, (ushort)((w1 >> sh1) & 7)) * rms_scale;
                    } else {
                        // 2-bit indices: four per byte, no straddling.
                        const int byte0 = start_elem >> 2, sh0 = (start_elem & 3) * 2;
                        k_dec[0] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte0] >> sh0) & 3)) * rms_scale;
                        const int byte1 = (start_elem + 1) >> 2, sh1 = ((start_elem + 1) & 3) * 2;
                        k_dec[1] = simd_shuffle(k_cb_lane, (ushort)((packed_row[byte1] >> sh1) & 3)) * rms_scale;
                    }
                    sum += Q_reg[j][k].x * k_dec[0] + Q_reg[j][k].y * k_dec[1];
                }
                // Butterfly-reduce the partial dot across the 8 lanes of the group.
                sum += simd_shuffle_xor(sum, 4);
                sum += simd_shuffle_xor(sum, 2);
                sum += simd_shuffle_xor(sum, 1);

                if (args.logit_softcap != 0.0f) {
                    sum = args.logit_softcap * tanh(sum);
                }
                // NOTE(review): maskh is indexed by i_KQ (position within the
                // current 128-cell tile), not by cell_rel; for nCells > 128
                // this re-reads the first tile's mask values. The sibling
                // kernels use the same indexing — confirm against the D=128
                // original before changing.
                if (maskh && (args.ncols == 1 || ic0 + j < args.nTokensQ)) {
                    sum += float(maskh[(long)j * args.ne31 + i_KQ]);
                }
                if (!in_range) sum = -FLT_MAX/2.0f;

                // NOTE(review): the running max is biased by +0.6931 (~ln 2),
                // matching the sibling f16 kernel — presumably exp-overflow
                // headroom; confirm intent against the D=128 original.
                KQ_max_new[j] = max(KQ_max_new[j], sum + 0.6931f);

                // The lane whose group slot matches this row owns the score.
                if (tid_kq == (uint)i_KQ_0) {
                    KQ_tg[j * nthreads + tid] = sum;
                }
            }
        }

        // ── Online-softmax update: fold the new max in, rescale state. ──
        for (int j = 0; j < args.ncols; j++) {
            KQ_max_new[j] = simd_max(KQ_max_new[j]);

            const float KQ_max_scale = exp(KQ_max[j] - KQ_max_new[j]);
            KQ_max[j] = KQ_max_new[j];

            // Each thread exponentiates its own score slot in place.
            const float kq_val = KQ_tg[j * nthreads + tid];
            const float kq_exp = exp(kq_val - KQ_max[j]);
            KQ_sum[j] = KQ_sum[j] * KQ_max_scale + kq_exp;
            KQ_tg[j * nthreads + tid] = kq_exp;

            for (int i = 0; i < 16; i++) {
                VKQ[j][i].x *= KQ_max_scale;
                VKQ[j][i].y *= KQ_max_scale;
            }
        }

        // D=256: 4 passes cover V[0..63], V[64..127], V[128..191], V[192..255].
        for (int k0 = 0; k0 < 32; k0 += V_cols_per_iter) {
            const int k = (int)sgitg * 32 + k0 + (int)tiisg / nthreads_V;
            const int cell_rel = k_VKQ_0 + k;
            const bool in_range_v = (cell_rel < args.nCells);

            float KQ_k[2];
            for (int j = 0; j < args.ncols; j++) {
                KQ_k[j] = KQ_tg[j * nthreads + k];
            }

            device const uint8_t * v_row = in_range_v
                ? V_p + (long)cell_rel * args.nKVHeads * args.v_packedBytes
                : nullptr;
            const float v_rms = in_range_v ? V_sc[cell_rel * args.nKVHeads] : 0.0f;

            const int v_tid = (int)tiisg % nthreads_V;
            for (int pass = 0; pass < 4; pass++) {
                float v_dec[8];
                const int start_elem = pass * 64 + v_tid * 8;
                if (v_row && start_elem < D) {
                    tq_decode_8_shfl(v_row, v_cb_lane, v_rms, start_elem, args.v_bits, v_dec);
                } else {
                    // Out-of-range: decode from a valid base pointer with zero
                    // scale so every lane still executes the shuffles;
                    // presumably tq_decode_8_shfl scales by rms, so the
                    // contribution is 0.
                    tq_decode_8_shfl(K_p, v_cb_lane, 0.0f, 0, args.v_bits, v_dec);
                }
                for (int i = 0; i < 8; i++) {
                    const int vkq_idx = pass * 4 + i / 2;
                    for (int j = 0; j < args.ncols; j++) {
                        if (i % 2 == 0) VKQ[j][vkq_idx].x += v_dec[i] * KQ_k[j];
                        else VKQ[j][vkq_idx].y += v_dec[i] * KQ_k[j];
                    }
                }
            }
        }
    } // end KV loop

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── Cross-simdgroup reduction of the softmax state. ──
    // First simdgroup clears the 32-wide staging rows.
    for (int j = 0; j < args.ncols; j++) {
        if (sgitg == 0) {
            KQ_max_tg[j][tiisg] = -FLT_MAX/2.0f;
            KQ_sum_tg[j][tiisg] = 0.0f;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Lane 0 of each simdgroup publishes its local running max.
    for (int j = 0; j < args.ncols; j++) {
        if (tiisg == 0) {
            KQ_max_tg[j][sgitg] = KQ_max[j];
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (int j = 0; j < args.ncols; j++) {
        if (args.ncols > 1 && ic0 + j >= args.nTokensQ) break;

        // Global max across the 4 simdgroups; rescale local accumulators to it.
        float kqmax_new = KQ_max_tg[j][tiisg];
        kqmax_new = simd_max(kqmax_new);
        const float kqmax_scale = exp(KQ_max[j] - kqmax_new);
        KQ_max[j] = kqmax_new;

        for (int i = 0; i < 16; i++) {
            VKQ[j][i].x *= kqmax_scale;
            VKQ[j][i].y *= kqmax_scale;
        }

        // D=256: VKQ_tg layout per (sgitg, v-group) region of D/2 = 128 float2:
        // VKQ[ 0.. 3] → VKQ_tg[     v_tid*4 + 0..3] (pass 0, float2 slots  0..31)
        // VKQ[ 4.. 7] → VKQ_tg[32 + v_tid*4 + 0..3] (pass 1, float2 slots 32..63)
        // VKQ[ 8..11] → VKQ_tg[64 + v_tid*4 + 0..3] (pass 2, float2 slots 64..95)
        // VKQ[12..15] → VKQ_tg[96 + v_tid*4 + 0..3] (pass 3, float2 slots 96..127)
        const int v_tid = (int)tiisg % nthreads_V;
        threadgroup float2 * VKQ_tg = (threadgroup float2 *)KQ_tg
            + (long)sgitg * (V_cols_per_iter * D/2)
            + (long)((int)tiisg / nthreads_V) * (D/2);
        VKQ_tg[v_tid * 4 + 0] = VKQ[j][0];
        VKQ_tg[v_tid * 4 + 1] = VKQ[j][1];
        VKQ_tg[v_tid * 4 + 2] = VKQ[j][2];
        VKQ_tg[v_tid * 4 + 3] = VKQ[j][3];
        VKQ_tg[32 + v_tid * 4 + 0] = VKQ[j][4];
        VKQ_tg[32 + v_tid * 4 + 1] = VKQ[j][5];
        VKQ_tg[32 + v_tid * 4 + 2] = VKQ[j][6];
        VKQ_tg[32 + v_tid * 4 + 3] = VKQ[j][7];
        VKQ_tg[64 + v_tid * 4 + 0] = VKQ[j][8];
        VKQ_tg[64 + v_tid * 4 + 1] = VKQ[j][9];
        VKQ_tg[64 + v_tid * 4 + 2] = VKQ[j][10];
        VKQ_tg[64 + v_tid * 4 + 3] = VKQ[j][11];
        VKQ_tg[96 + v_tid * 4 + 0] = VKQ[j][12];
        VKQ_tg[96 + v_tid * 4 + 1] = VKQ[j][13];
        VKQ_tg[96 + v_tid * 4 + 2] = VKQ[j][14];
        VKQ_tg[96 + v_tid * 4 + 3] = VKQ[j][15];

        KQ_sum[j] *= kqmax_scale;
        KQ_sum[j] = simd_sum(KQ_sum[j]);
        if (tiisg == 0) {
            KQ_sum_tg[j][sgitg] = KQ_sum[j];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // D=256: each thread writes 2 output positions (tid, tid+128). Compute
        // KQ_sum with ALL lanes participating in the simd_sum (required for
        // convergence), then two writes per thread.
        KQ_sum[j] = KQ_sum_tg[j][tiisg];
        KQ_sum[j] = simd_sum(KQ_sum[j]);

        const long out_idx = ((long)sequence * args.nTokensQ + ic0 + j) * args.nHeadsQ + head;
        for (int out_offset = 0; out_offset < D; out_offset += nthreads) {
            const int out_elem = out_offset + tid;
            float dst_val = 0.0f;
            // Sum the 16 staged partial accumulators (4 simdgroups x 4 v-groups).
            for (int w = 0; w < nwarps; w++) {
                for (int v = 0; v < V_cols_per_iter; v++) {
                    dst_val += ((threadgroup float *)KQ_tg)[w * V_cols_per_iter * D + v * D + out_elem];
                }
            }
            dst_val /= KQ_sum[j];
            dst[out_idx * D + out_elem] = dst_val;
        }

        // Re-synchronize before the next query column reuses KQ_tg.
        if (j < args.ncols - 1) threadgroup_barrier(mem_flags::mem_threadgroup);
    }
}
|
||||
|
|
|
|||
|
|
@ -342,7 +342,16 @@ func (m *ggmlTQCompressedK) DequantK(ctx ml.Context, layer int, encodeResult ml.
|
|||
// inline-decode path is slower than DequantKV + stock FA on all measured
|
||||
// hardware — DequantKV is always preferred when available.
|
||||
func (m *ggmlTQCompressedK) fusedKernelSupports() bool {
|
||||
if m.headDim != 128 {
|
||||
// D=128 on all backends; D=256 only on Metal (kernel_tq_fattn_vec_*{,_d256}).
|
||||
// CUDA still has only the D=128 kernel, so gemma3 (D=256) stays off the
|
||||
// fused path on CUDA.
|
||||
switch m.headDim {
|
||||
case 128:
|
||||
case 256:
|
||||
if !m.preferFusedAttention {
|
||||
return false
|
||||
}
|
||||
default:
|
||||
return false
|
||||
}
|
||||
if m.bits != 2 && m.bits != 3 {
|
||||
|
|
|
|||
Loading…
Reference in a new issue