mirror of
https://github.com/ollama/ollama
synced 2026-04-23 08:45:14 +00:00
ml/backend/ggml: optimize the Metal TurboQuant dequant kernel
Two independent improvements to kernel_tq_dequant.
Dispatch: the kernel uses only [[thread_index_in_simdgroup]] (tiisg,
0..31 per SIMDgroup), has no sgitg stride, no threadgroup barriers, and
no atomics. It was nonetheless dispatched with 128-thread threadgroups
(four SIMDgroups x 32), so all four SIMDgroups ran the outer loop
identically and wrote the same f16 bytes four times. Drop non-outlier
dispatches to 32-thread threadgroups. The outlier kernel still
dispatches at 128 — it uses s_mask atomics and a popcount reduction
that legitimately need the full threadgroup.
Inner loop: replace the 1-element-per-iteration scalar path with a
4-elements-per-iteration vectorised path that issues a single half4
store per iteration. For bits=2 the 4 elements fit in one byte; for
bits=3 they fit in a 16-bit window (shift0 in {0,4}). The per-cell
scale is pre-multiplied into the codebook lane at kernel entry so the
decode path drops one fmul per element. The scalar fallback is
preserved for head dims that aren't a multiple of 128.
Decode throughput on llama3.2:3b tq3 at 32k context improves ~9% on
Apple Silicon (~42 -> ~46 tok/s); the K-only DequantK path used by
tq2k/tq3k benefits from the same kernel.
This commit is contained in:
parent
fda849a774
commit
c56d85ee5e
|
|
@ -0,0 +1,135 @@
|
|||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Michael Verrilli <msv@pobox.com>
|
||||
Date: Tue, 21 Apr 2026 21:02:22 +0000
|
||||
Subject: [PATCH] ml/backend/ggml: optimize the Metal TurboQuant dequant kernel
|
||||
|
||||
Two independent improvements to kernel_tq_dequant.
|
||||
|
||||
Dispatch: the kernel uses only [[thread_index_in_simdgroup]] (tiisg,
|
||||
0..31 per SIMDgroup), has no sgitg stride, no threadgroup barriers, and
|
||||
no atomics. It was nonetheless dispatched with 128-thread threadgroups
|
||||
(four SIMDgroups x 32), so all four SIMDgroups ran the outer loop
|
||||
identically and wrote the same f16 bytes four times. Drop non-outlier
|
||||
dispatches to 32-thread threadgroups. The outlier kernel still
|
||||
dispatches at 128 - it uses s_mask atomics and a popcount reduction
|
||||
that legitimately need the full threadgroup.
|
||||
|
||||
Inner loop: replace the 1-element-per-iteration scalar path with a
|
||||
4-elements-per-iteration vectorised path that issues a single half4
|
||||
store per iteration. For bits=2 the 4 elements fit in one byte; for
|
||||
bits=3 they fit in a 16-bit window (shift0 in {0,4}). The per-cell
|
||||
scale is pre-multiplied into the codebook lane at kernel entry so the
|
||||
decode path drops one fmul per element. The scalar fallback is
|
||||
preserved for head dims that aren't a multiple of 128.
|
||||
|
||||
Decode throughput on llama3.2:3b tq3 at 32k context improves ~9% on
|
||||
Apple Silicon (~42 -> ~46 tok/s); the K-only DequantK path used by
|
||||
tq2k/tq3k benefits from the same kernel.
|
||||
---
|
||||
ggml/src/ggml-metal/ggml-metal-ops.cpp | 15 ++++++++---
|
||||
ggml/src/ggml-metal/ggml-metal.metal | 37 +++++++++++++++++++++++---
|
||||
2 files changed, 45 insertions(+), 7 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
index b5ab1c14e..ea3580b0b 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
|
||||
@@ -4197,7 +4197,12 @@ int ggml_metal_op_tq_dequant(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
- const int block_size = std::min(128, headDim);
|
||||
+ // Outlier kernel uses a 128-thread TG: threadgroup barriers + atomics on
|
||||
+ // s_mask require all threads. Non-outlier kernel uses a single simdgroup
|
||||
+ // (32 threads): it only reads tiisg and has no barriers, so a larger TG
|
||||
+ // just replicates work across idle simdgroups.
|
||||
+ const int outlier_block_size = 128;
|
||||
+ const int nonoutlier_block_size = 32;
|
||||
|
||||
if (outlierCount > 0 && outlierBits > 0 && outlierCount < headDim) {
|
||||
const int regular_count = headDim - outlierCount;
|
||||
@@ -4234,7 +4239,7 @@ int ggml_metal_op_tq_dequant(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 8);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
- ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, block_size, 1, 1);
|
||||
+ ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, outlier_block_size, 1, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -4259,7 +4264,7 @@ int ggml_metal_op_tq_dequant(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); // codebook
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 4);
|
||||
|
||||
- ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, block_size, 1, 1);
|
||||
+ ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, nonoutlier_block_size, 1, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -4283,7 +4288,9 @@ int ggml_metal_op_tq_dequant_kv(ggml_metal_op_t ctx, int idx) {
|
||||
const int k_codebook_len = (int)op->src[2]->ne[0];
|
||||
const int v_codebook_len = (int)op->src[5]->ne[0];
|
||||
|
||||
- const int block_size = std::min(128, headDim);
|
||||
+ // kernel_tq_dequant is single-simdgroup (uses only tiisg, no barriers,
|
||||
+ // no atomics) — 32-thread TGs eliminate 4× redundant work vs 128-thread.
|
||||
+ const int block_size = 32;
|
||||
const size_t plane_size = (size_t)headDim * numKVHeads * nCells;
|
||||
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
index 2718e8bb1..b61f755aa 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
@@ -10246,8 +10246,40 @@ kernel void kernel_tq_dequant(
|
||||
|
||||
const int cb_mask = (1 << args.bits) - 1;
|
||||
// Load one codebook entry per lane; period ≤ 8 divides 32, so simd_shuffle is exact.
|
||||
- const float cb_lane = codebook[tiisg & cb_mask];
|
||||
+ // Pre-multiply by the per-cell scale so the decode path drops 1 fmul per element.
|
||||
+ const float scaled_cb_lane = codebook[tiisg & cb_mask] * scale;
|
||||
+
|
||||
+ // Fast path: when headDim is a multiple of 128 (=32 lanes × 4 elements per thread),
|
||||
+ // each thread decodes 4 consecutive D-positions per iter and writes a single half4.
|
||||
+ // bits=2: 4 elems = 8 bits, always byte-aligned (shift0=0 since elem_base mod 4 == 0).
|
||||
+ // bits=3: 4 elems = 12 bits, shift0 ∈ {0,4}, always fits in a 16-bit window.
|
||||
+ // A 16-bit window (2 packed bytes) suffices for both. The scalar fallback
|
||||
+ // covers non-multiple-of-128 head dims.
|
||||
+ if ((args.headDim & 127) == 0) {
|
||||
+ const int iters = args.headDim >> 7;
|
||||
+ for (int iter = 0; iter < iters; iter++) {
|
||||
+ const int elem_base = iter * 128 + (int)tiisg * 4;
|
||||
+ const int bit_offset = elem_base * args.bits;
|
||||
+ const int byte_base = bit_offset >> 3;
|
||||
+ const int shift0 = bit_offset & 7;
|
||||
+
|
||||
+ uint w = (uint)cell_packed[byte_base];
|
||||
+ if (args.bits == 3) {
|
||||
+ w |= ((uint)cell_packed[byte_base + 1] << 8);
|
||||
+ }
|
||||
+
|
||||
+ half4 v4;
|
||||
+ v4[0] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> shift0 ) & cb_mask)));
|
||||
+ v4[1] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> (shift0 + args.bits)) & cb_mask)));
|
||||
+ v4[2] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> (shift0 + 2 * args.bits)) & cb_mask)));
|
||||
+ v4[3] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> (shift0 + 3 * args.bits)) & cb_mask)));
|
||||
+
|
||||
+ *((device half4 *)(cell_out + elem_base)) = v4;
|
||||
+ }
|
||||
+ return;
|
||||
+ }
|
||||
|
||||
+ // Scalar fallback for head dims that aren't a multiple of 128.
|
||||
for (uint elem = tiisg; elem < (uint)args.headDim; elem += 32) {
|
||||
const int bit_offset = (int)elem * args.bits;
|
||||
const int byte_idx = bit_offset >> 3;
|
||||
@@ -10256,8 +10288,7 @@ kernel void kernel_tq_dequant(
|
||||
if (shift + args.bits > 8) {
|
||||
idx |= ((int)(cell_packed[byte_idx + 1] << (8 - shift))) & cb_mask;
|
||||
}
|
||||
- const float val = simd_shuffle(cb_lane, (ushort)idx) * scale;
|
||||
- cell_out[elem] = half(val);
|
||||
+ cell_out[elem] = half(simd_shuffle(scaled_cb_lane, (ushort)idx));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -13146,8 +13146,40 @@ kernel void kernel_tq_dequant(
|
|||
|
||||
const int cb_mask = (1 << args.bits) - 1;
|
||||
// Load one codebook entry per lane; period ≤ 8 divides 32, so simd_shuffle is exact.
|
||||
const float cb_lane = codebook[tiisg & cb_mask];
|
||||
// Pre-multiply by the per-cell scale so the decode path drops 1 fmul per element.
|
||||
const float scaled_cb_lane = codebook[tiisg & cb_mask] * scale;
|
||||
|
||||
// Fast path: when headDim is a multiple of 128 (=32 lanes × 4 elements per thread),
|
||||
// each thread decodes 4 consecutive D-positions per iter and writes a single half4.
|
||||
// bits=2: 4 elems = 8 bits, always byte-aligned (shift0=0 since elem_base mod 4 == 0).
|
||||
// bits=3: 4 elems = 12 bits, shift0 ∈ {0,4}, always fits in a 16-bit window.
|
||||
// A 16-bit window (2 packed bytes) suffices for both. The scalar fallback
|
||||
// covers non-multiple-of-128 head dims.
|
||||
if ((args.headDim & 127) == 0) {
|
||||
const int iters = args.headDim >> 7;
|
||||
for (int iter = 0; iter < iters; iter++) {
|
||||
const int elem_base = iter * 128 + (int)tiisg * 4;
|
||||
const int bit_offset = elem_base * args.bits;
|
||||
const int byte_base = bit_offset >> 3;
|
||||
const int shift0 = bit_offset & 7;
|
||||
|
||||
uint w = (uint)cell_packed[byte_base];
|
||||
if (args.bits == 3) {
|
||||
w |= ((uint)cell_packed[byte_base + 1] << 8);
|
||||
}
|
||||
|
||||
half4 v4;
|
||||
v4[0] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> shift0 ) & cb_mask)));
|
||||
v4[1] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> (shift0 + args.bits)) & cb_mask)));
|
||||
v4[2] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> (shift0 + 2 * args.bits)) & cb_mask)));
|
||||
v4[3] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> (shift0 + 3 * args.bits)) & cb_mask)));
|
||||
|
||||
*((device half4 *)(cell_out + elem_base)) = v4;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Scalar fallback for head dims that aren't a multiple of 128.
|
||||
for (uint elem = tiisg; elem < (uint)args.headDim; elem += 32) {
|
||||
const int bit_offset = (int)elem * args.bits;
|
||||
const int byte_idx = bit_offset >> 3;
|
||||
|
|
@ -13156,8 +13188,7 @@ kernel void kernel_tq_dequant(
|
|||
if (shift + args.bits > 8) {
|
||||
idx |= ((int)(cell_packed[byte_idx + 1] << (8 - shift))) & cb_mask;
|
||||
}
|
||||
const float val = simd_shuffle(cb_lane, (ushort)idx) * scale;
|
||||
cell_out[elem] = half(val);
|
||||
cell_out[elem] = half(simd_shuffle(scaled_cb_lane, (ushort)idx));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -4197,7 +4197,12 @@ int ggml_metal_op_tq_dequant(ggml_metal_op_t ctx, int idx) {
|
|||
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
const int block_size = std::min(128, headDim);
|
||||
// Outlier kernel uses a 128-thread TG: threadgroup barriers + atomics on
|
||||
// s_mask require all threads. Non-outlier kernel uses a single simdgroup
|
||||
// (32 threads): it only reads tiisg and has no barriers, so a larger TG
|
||||
// just replicates work across idle simdgroups.
|
||||
const int outlier_block_size = 128;
|
||||
const int nonoutlier_block_size = 32;
|
||||
|
||||
if (outlierCount > 0 && outlierBits > 0 && outlierCount < headDim) {
|
||||
const int regular_count = headDim - outlierCount;
|
||||
|
|
@ -4234,7 +4239,7 @@ int ggml_metal_op_tq_dequant(ggml_metal_op_t ctx, int idx) {
|
|||
ggml_metal_encoder_set_buffer (enc, bid_dst, 8);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, block_size, 1, 1);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, outlier_block_size, 1, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
|
@ -4259,7 +4264,7 @@ int ggml_metal_op_tq_dequant(ggml_metal_op_t ctx, int idx) {
|
|||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); // codebook
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 4);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, block_size, 1, 1);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nCells, numKVHeads, 1, nonoutlier_block_size, 1, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
|
@ -4283,7 +4288,9 @@ int ggml_metal_op_tq_dequant_kv(ggml_metal_op_t ctx, int idx) {
|
|||
const int k_codebook_len = (int)op->src[2]->ne[0];
|
||||
const int v_codebook_len = (int)op->src[5]->ne[0];
|
||||
|
||||
const int block_size = std::min(128, headDim);
|
||||
// kernel_tq_dequant is single-simdgroup (uses only tiisg, no barriers,
|
||||
// no atomics) — 32-thread TGs eliminate 4× redundant work vs 128-thread.
|
||||
const int block_size = 32;
|
||||
const size_t plane_size = (size_t)headDim * numKVHeads * nCells;
|
||||
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
|
|
|||
|
|
@ -10246,8 +10246,40 @@ kernel void kernel_tq_dequant(
|
|||
|
||||
const int cb_mask = (1 << args.bits) - 1;
|
||||
// Load one codebook entry per lane; period ≤ 8 divides 32, so simd_shuffle is exact.
|
||||
const float cb_lane = codebook[tiisg & cb_mask];
|
||||
// Pre-multiply by the per-cell scale so the decode path drops 1 fmul per element.
|
||||
const float scaled_cb_lane = codebook[tiisg & cb_mask] * scale;
|
||||
|
||||
// Fast path: when headDim is a multiple of 128 (=32 lanes × 4 elements per thread),
|
||||
// each thread decodes 4 consecutive D-positions per iter and writes a single half4.
|
||||
// bits=2: 4 elems = 8 bits, always byte-aligned (shift0=0 since elem_base mod 4 == 0).
|
||||
// bits=3: 4 elems = 12 bits, shift0 ∈ {0,4}, always fits in a 16-bit window.
|
||||
// A 16-bit window (2 packed bytes) suffices for both. The scalar fallback
|
||||
// covers non-multiple-of-128 head dims.
|
||||
if ((args.headDim & 127) == 0) {
|
||||
const int iters = args.headDim >> 7;
|
||||
for (int iter = 0; iter < iters; iter++) {
|
||||
const int elem_base = iter * 128 + (int)tiisg * 4;
|
||||
const int bit_offset = elem_base * args.bits;
|
||||
const int byte_base = bit_offset >> 3;
|
||||
const int shift0 = bit_offset & 7;
|
||||
|
||||
uint w = (uint)cell_packed[byte_base];
|
||||
if (args.bits == 3) {
|
||||
w |= ((uint)cell_packed[byte_base + 1] << 8);
|
||||
}
|
||||
|
||||
half4 v4;
|
||||
v4[0] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> shift0 ) & cb_mask)));
|
||||
v4[1] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> (shift0 + args.bits)) & cb_mask)));
|
||||
v4[2] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> (shift0 + 2 * args.bits)) & cb_mask)));
|
||||
v4[3] = half(simd_shuffle(scaled_cb_lane, (ushort)((w >> (shift0 + 3 * args.bits)) & cb_mask)));
|
||||
|
||||
*((device half4 *)(cell_out + elem_base)) = v4;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Scalar fallback for head dims that aren't a multiple of 128.
|
||||
for (uint elem = tiisg; elem < (uint)args.headDim; elem += 32) {
|
||||
const int bit_offset = (int)elem * args.bits;
|
||||
const int byte_idx = bit_offset >> 3;
|
||||
|
|
@ -10256,8 +10288,7 @@ kernel void kernel_tq_dequant(
|
|||
if (shift + args.bits > 8) {
|
||||
idx |= ((int)(cell_packed[byte_idx + 1] << (8 - shift))) & cb_mask;
|
||||
}
|
||||
const float val = simd_shuffle(cb_lane, (ushort)idx) * scale;
|
||||
cell_out[elem] = half(val);
|
||||
cell_out[elem] = half(simd_shuffle(scaled_cb_lane, (ushort)idx));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue