ml/backend/ggml: optimize the Metal TurboQuant dequant kernel

Two independent improvements to kernel_tq_dequant.

Dispatch: the kernel uses only [[thread_index_in_simdgroup]] (tiisg,
0..31 per SIMDgroup), has no sgitg stride, no threadgroup barriers, and
no atomics. It was nonetheless dispatched with 128-thread threadgroups
(four SIMDgroups x 32), so all four SIMDgroups ran the outer loop
identically and wrote the same f16 bytes four times. Drop non-outlier
dispatches to 32-thread threadgroups. The outlier kernel still
dispatches at 128 - it uses s_mask atomics and a popcount reduction
that legitimately need the full threadgroup.

Inner loop: replace the 1-element-per-iteration scalar path with a
4-elements-per-iteration vectorised path that issues a single half4
store per iteration. For bits=2 the 4 elements fit in one byte; for
bits=3 they fit in a 16-bit window (shift0 in {0,4}). The per-cell
scale is pre-multiplied into the codebook lane at kernel entry so the
decode path drops one fmul per element. The scalar fallback is
preserved for head dims that aren't a multiple of 128.

Decode throughput on llama3.2:3b tq3 at 32k context improves ~9% on
Apple Silicon (~42 -> ~46 tok/s); the K-only DequantK path used by
tq2k/tq3k benefits from the same kernel.
This commit is contained in:
Michael Verrilli 2026-04-21 13:57:43 +00:00
parent fda849a774
commit c56d85ee5e
No known key found for this signature in database
GPG key ID: E4F2103B6C63B961
4 changed files with 214 additions and 10 deletions

View file

@@ -0,0 +1,135 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Verrilli <msv@pobox.com>
Date: Tue, 21 Apr 2026 21:02:22 +0000
Subject: [PATCH] ml/backend/ggml: optimize the Metal TurboQuant dequant kernel
Two independent improvements to kernel_tq_dequant.
Dispatch: the kernel uses only [[thread_index_in_simdgroup]] (tiisg,
0..31 per SIMDgroup), has no sgitg stride, no threadgroup barriers, and
no atomics. It was nonetheless dispatched with 128-thread threadgroups
(four SIMDgroups x 32), so all four SIMDgroups ran the outer loop
identically and wrote the same f16 bytes four times. Drop non-outlier
dispatches to 32-thread threadgroups. The outlier kernel still
dispatches at 128 - it uses s_mask atomics and a popcount reduction
that legitimately need the full threadgroup.
Inner loop: replace the 1-element-per-iteration scalar path with a
4-elements-per-iteration vectorised path that issues a single half4
store per iteration. For bits=2 the 4 elements fit in one byte; for
bits=3 they fit in a 16-bit window (shift0 in {0,4}). The per-cell
scale is pre-multiplied into the codebook lane at kernel entry so the
decode path drops one fmul per element. The scalar fallback is
preserved for head dims that aren't a multiple of 128.
Decode throughput on llama3.2:3b tq3 at 32k context improves ~9% on
Apple Silicon (~42 -> ~46 tok/s); the K-only DequantK path used by
tq2k/tq3k benefits from the same kernel.
---
ggml/src/ggml-metal/ggml-metal-ops.cpp | 15 ++++++++---
ggml/src/ggml-metal/ggml-metal.metal | 37 +++++++++++++++++++++++---
2 files changed, 45 insertions(+), 7 deletions(-)
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index b5ab1c14e..ea3580b0b 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -4197,7 +4197,12 @@ int ggml_metal_op_tq_dequant(ggml_metal_op_t ctx, int idx) {
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
- const int block_size = std::min(128, headDim);
+ // Outlier kernel uses a 128-thread TG: threadgroup barriers + atomics on
+ // s_mask require all threads. Non-outlier kernel uses a single simdgroup
+ // (32 threads): it only reads tiisg and has no barriers, so a larger TG
+ // just replicates work across idle simdgroups.
+ const int outlier_block_size = 128;
+ const int nonoutlier_block_size = 32;
if (outlierCount > 0 && outlierBits > 0 && outlierCount < headDim) {
const int regular_count = headDim - outlierCount;
@@ -4234,7 +4239,7 @@ int ggml_metal_op_tq_dequant(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_buffer (enc, bid_dst, 8);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
- ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, block_size, 1, 1);
+ ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, outlier_block_size, 1, 1);
return 1;
}
@@ -4259,7 +4264,7 @@ int ggml_metal_op_tq_dequant(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); // codebook
ggml_metal_encoder_set_buffer (enc, bid_dst, 4);
- ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, block_size, 1, 1);
+ ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, nonoutlier_block_size, 1, 1);
return 1;
}
@@ -4283,7 +4288,9 @@ int ggml_metal_op_tq_dequant_kv(ggml_metal_op_t ctx, int idx) {
const int k_codebook_len = (int)op->src[2]->ne[0];
const int v_codebook_len = (int)op->src[5]->ne[0];
- const int block_size = std::min(128, headDim);
+ // kernel_tq_dequant is single-simdgroup (uses only tiisg, no barriers,
+ // no atomics) — 32-thread TGs eliminate 4× redundant work vs 128-thread.
+ const int block_size = 32;
const size_t plane_size = (size_t)headDim * numKVHeads * nCells;
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 2718e8bb1..b61f755aa 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -10246,8 +10246,40 @@ kernel void kernel_tq_dequant(
const int cb_mask = (1 << args.bits) - 1;
// Load one codebook entry per lane; period ≤ 8 divides 32, so simd_shuffle is exact.
- const float cb_lane = codebook[tiisg & cb_mask];
+ // Pre-multiply by the per-cell scale so the decode path drops 1 fmul per element.
+ const float scaled_cb_lane = codebook[tiisg & cb_mask] * scale;
+
+ // Fast path: when headDim is a multiple of 128 (=32 lanes × 4 elements per thread),
+ // each thread decodes 4 consecutive D-positions per iter and writes a single half4.
+ // bits=2: 4 elems = 8 bits, always byte-aligned (shift0=0 since elem_base mod 4 == 0).
+ // bits=3: 4 elems = 12 bits, shift0 ∈ {0,4}, always fits in a 16-bit window.
+ // A 16-bit window (2 packed bytes) suffices for both. The scalar fallback
+ // covers non-multiple-of-128 head dims.
+ if ((args.headDim & 127) == 0) {
+ const int iters = args.headDim >> 7;
+ for (int iter = 0; iter < iters; iter++) {
+ const int elem_base = iter * 128 + (int)tiisg * 4;
+ const int bit_offset = elem_base * args.bits;
+ const int byte_base = bit_offset >> 3;
+ const int shift0 = bit_offset & 7;
+
+ uint w = (uint)cell_packed[byte_base];
+ if (args.bits == 3) {
+ w |= ((uint)cell_packed[byte_base + 1] << 8);
+ }
+
+ half4 v4;
+ v4[0] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> shift0 ) & cb_mask)));
+ v4[1] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> (shift0 + args.bits)) & cb_mask)));
+ v4[2] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> (shift0 + 2 * args.bits)) & cb_mask)));
+ v4[3] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> (shift0 + 3 * args.bits)) & cb_mask)));
+
+ *((device half4 *)(cell_out + elem_base)) = v4;
+ }
+ return;
+ }
+ // Scalar fallback for head dims that aren't a multiple of 128.
for (uint elem = tiisg; elem < (uint)args.headDim; elem += 32) {
const int bit_offset = (int)elem * args.bits;
const int byte_idx = bit_offset >> 3;
@@ -10256,8 +10288,7 @@ kernel void kernel_tq_dequant(
if (shift + args.bits > 8) {
idx |= ((int)(cell_packed[byte_idx + 1] << (8 - shift))) & cb_mask;
}
- const float val = simd_shuffle(cb_lane, (ushort)idx) * scale;
- cell_out[elem] = half(val);
+ cell_out[elem] = half(simd_shuffle(scaled_cb_lane, (ushort)idx));
}
}

View file

@@ -13146,8 +13146,40 @@ kernel void kernel_tq_dequant(
const int cb_mask = (1 << args.bits) - 1;
// Load one codebook entry per lane; period ≤ 8 divides 32, so simd_shuffle is exact.
const float cb_lane = codebook[tiisg & cb_mask];
// Pre-multiply by the per-cell scale so the decode path drops 1 fmul per element.
const float scaled_cb_lane = codebook[tiisg & cb_mask] * scale;
// Fast path: when headDim is a multiple of 128 (=32 lanes × 4 elements per thread),
// each thread decodes 4 consecutive D-positions per iter and writes a single half4.
// bits=2: 4 elems = 8 bits, always byte-aligned (shift0=0 since elem_base mod 4 == 0).
// bits=3: 4 elems = 12 bits, shift0 ∈ {0,4}, always fits in a 16-bit window.
// A 16-bit window (2 packed bytes) suffices for both. The scalar fallback
// covers non-multiple-of-128 head dims.
if ((args.headDim & 127) == 0) {
const int iters = args.headDim >> 7;
for (int iter = 0; iter < iters; iter++) {
const int elem_base = iter * 128 + (int)tiisg * 4;
const int bit_offset = elem_base * args.bits;
const int byte_base = bit_offset >> 3;
const int shift0 = bit_offset & 7;
uint w = (uint)cell_packed[byte_base];
if (args.bits == 3) {
w |= ((uint)cell_packed[byte_base + 1] << 8);
}
half4 v4;
v4[0] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> shift0 ) & cb_mask)));
v4[1] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> (shift0 + args.bits)) & cb_mask)));
v4[2] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> (shift0 + 2 * args.bits)) & cb_mask)));
v4[3] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> (shift0 + 3 * args.bits)) & cb_mask)));
*((device half4 *)(cell_out + elem_base)) = v4;
}
return;
}
// Scalar fallback for head dims that aren't a multiple of 128.
for (uint elem = tiisg; elem < (uint)args.headDim; elem += 32) {
const int bit_offset = (int)elem * args.bits;
const int byte_idx = bit_offset >> 3;
@@ -13156,8 +13188,7 @@ kernel void kernel_tq_dequant(
if (shift + args.bits > 8) {
idx |= ((int)(cell_packed[byte_idx + 1] << (8 - shift))) & cb_mask;
}
const float val = simd_shuffle(cb_lane, (ushort)idx) * scale;
cell_out[elem] = half(val);
cell_out[elem] = half(simd_shuffle(scaled_cb_lane, (ushort)idx));
}
}

View file

@@ -4197,7 +4197,12 @@ int ggml_metal_op_tq_dequant(ggml_metal_op_t ctx, int idx) {
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
const int block_size = std::min(128, headDim);
// Outlier kernel uses a 128-thread TG: threadgroup barriers + atomics on
// s_mask require all threads. Non-outlier kernel uses a single simdgroup
// (32 threads): it only reads tiisg and has no barriers, so a larger TG
// just replicates work across idle simdgroups.
const int outlier_block_size = 128;
const int nonoutlier_block_size = 32;
if (outlierCount > 0 && outlierBits > 0 && outlierCount < headDim) {
const int regular_count = headDim - outlierCount;
@@ -4234,7 +4239,7 @@ int ggml_metal_op_tq_dequant(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_buffer (enc, bid_dst, 8);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, block_size, 1, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, outlier_block_size, 1, 1);
return 1;
}
@@ -4259,7 +4264,7 @@ int ggml_metal_op_tq_dequant(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); // codebook
ggml_metal_encoder_set_buffer (enc, bid_dst, 4);
ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, block_size, 1, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, nonoutlier_block_size, 1, 1);
return 1;
}
@@ -4283,7 +4288,9 @@ int ggml_metal_op_tq_dequant_kv(ggml_metal_op_t ctx, int idx) {
const int k_codebook_len = (int)op->src[2]->ne[0];
const int v_codebook_len = (int)op->src[5]->ne[0];
const int block_size = std::min(128, headDim);
// kernel_tq_dequant is single-simdgroup (uses only tiisg, no barriers,
// no atomics) — 32-thread TGs eliminate 4× redundant work vs 128-thread.
const int block_size = 32;
const size_t plane_size = (size_t)headDim * numKVHeads * nCells;
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);

View file

@@ -10246,8 +10246,40 @@ kernel void kernel_tq_dequant(
const int cb_mask = (1 << args.bits) - 1;
// Load one codebook entry per lane; period ≤ 8 divides 32, so simd_shuffle is exact.
const float cb_lane = codebook[tiisg & cb_mask];
// Pre-multiply by the per-cell scale so the decode path drops 1 fmul per element.
const float scaled_cb_lane = codebook[tiisg & cb_mask] * scale;
// Fast path: when headDim is a multiple of 128 (=32 lanes × 4 elements per thread),
// each thread decodes 4 consecutive D-positions per iter and writes a single half4.
// bits=2: 4 elems = 8 bits, always byte-aligned (shift0=0 since elem_base mod 4 == 0).
// bits=3: 4 elems = 12 bits, shift0 ∈ {0,4}, always fits in a 16-bit window.
// A 16-bit window (2 packed bytes) suffices for both. The scalar fallback
// covers non-multiple-of-128 head dims.
if ((args.headDim & 127) == 0) {
const int iters = args.headDim >> 7;
for (int iter = 0; iter < iters; iter++) {
const int elem_base = iter * 128 + (int)tiisg * 4;
const int bit_offset = elem_base * args.bits;
const int byte_base = bit_offset >> 3;
const int shift0 = bit_offset & 7;
uint w = (uint)cell_packed[byte_base];
if (args.bits == 3) {
w |= ((uint)cell_packed[byte_base + 1] << 8);
}
half4 v4;
v4[0] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> shift0 ) & cb_mask)));
v4[1] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> (shift0 + args.bits)) & cb_mask)));
v4[2] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> (shift0 + 2 * args.bits)) & cb_mask)));
v4[3] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> (shift0 + 3 * args.bits)) & cb_mask)));
*((device half4 *)(cell_out + elem_base)) = v4;
}
return;
}
// Scalar fallback for head dims that aren't a multiple of 128.
for (uint elem = tiisg; elem < (uint)args.headDim; elem += 32) {
const int bit_offset = (int)elem * args.bits;
const int byte_idx = bit_offset >> 3;
@@ -10256,8 +10288,7 @@ kernel void kernel_tq_dequant(
if (shift + args.bits > 8) {
idx |= ((int)(cell_packed[byte_idx + 1] << (8 - shift))) & cb_mask;
}
const float val = simd_shuffle(cb_lane, (ushort)idx) * scale;
cell_out[elem] = half(val);
cell_out[elem] = half(simd_shuffle(scaled_cb_lane, (ushort)idx));
}
}