mirror of
https://github.com/ollama/ollama
synced 2026-04-23 08:45:14 +00:00
gemma4: add Gemma 4 GGML model support
Add full Gemma 4 model family support (E2B, E4B, 26B MoE, 31B Dense) for the GGML backend including text, vision, converter, parser, and renderer. Text model features: - Sliding window + full attention with per-layer patterns - KV sharing across layers with donor map - Per-layer embeddings (PLE) with learned projections - MoE routing with RMSNorm + learned scale - Proportional RoPE with freq_factors for global attention - Final logit softcapping Vision model features: - SigLIP vision encoder with 2D RoPE - ClippableLinear with input/output clamping via packed v.clamp_data - Adaptive average pooling with nMerge kernel - Multi-modal projection with unweighted RMSNorm Converter: - Safetensors to GGUF with vision tensor renaming - Fused MoE gate_up_proj splitting - Vision patch embedding reshape (HF to Conv2D layout) - Packed clamp data tensor for ClippableLinear bounds - Proportional RoPE freq_factors generation Also includes: - BackendGet() on ml.Tensor for reading weight tensor data - Q6_K CUDA get_rows kernel support - MoE-aware ffn_down quantization layer counting - Gemma4 parser with tool calling and thinking support - Gemma4 renderer with structured tool format - Architecture-based auto-detection of renderer/parser/stop tokens - Integration test gemma4 model list additions
This commit is contained in:
parent
f6b69f3f28
commit
ea3c6a3cbe
|
|
@ -290,6 +290,8 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
|
|||
conv = &gemma3Model{Architecture: p.Architectures[0]}
|
||||
case "Gemma3nForConditionalGeneration":
|
||||
conv = &gemma3nModel{}
|
||||
case "Gemma4ForCausalLM", "Gemma4ForConditionalGeneration":
|
||||
conv = &gemma4Model{Architecture: p.Architectures[0]}
|
||||
case "Phi3ForCausalLM":
|
||||
conv = &phi3Model{}
|
||||
case "Qwen2ForCausalLM":
|
||||
|
|
|
|||
514
convert/convert_gemma4.go
Normal file
514
convert/convert_gemma4.go
Normal file
|
|
@ -0,0 +1,514 @@
|
|||
package convert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
)
|
||||
|
||||
// gemma4Model converts a Gemma 4 safetensors checkpoint (text and optional
// vision tower) into GGUF metadata and tensors. Fields mirror the HF
// config.json layout via the json tags below.
type gemma4Model struct {
	gemmaModel

	// Architecture is the HF architecture string that selected this
	// converter ("Gemma4ForCausalLM" or "Gemma4ForConditionalGeneration").
	Architecture string

	// TextModel mirrors config.json "text_config".
	TextModel struct {
		HiddenSize            uint32   `json:"hidden_size"`
		NumHiddenLayers       uint32   `json:"num_hidden_layers"`
		IntermediateSize      uint32   `json:"intermediate_size"`
		NumAttentionHeads     uint32   `json:"num_attention_heads"`
		NumKeyValueHeads      uint32   `json:"num_key_value_heads"`
		// HeadDim is the local (sliding-window) attention head dim;
		// GlobalHeadDim is the full-attention head dim.
		HeadDim               uint32   `json:"head_dim"`
		GlobalHeadDim         uint32   `json:"global_head_dim"`
		VocabSize             uint32   `json:"vocab_size"`
		RMSNormEps            float32  `json:"rms_norm_eps"`
		MaxPositionEmbeddings uint32   `json:"max_position_embeddings"`
		SlidingWindow         uint32   `json:"sliding_window"`
		// Pointer fields are optional in config.json; nil means absent.
		SlidingWindowPattern  *int32   `json:"_sliding_window_pattern"`
		// LayerTypes labels each layer "sliding_attention" or global.
		LayerTypes            []string `json:"layer_types"`
		FinalLogitSoftcapping float32  `json:"final_logit_softcapping"`

		// MoE configuration (26B MoE variant).
		EnableMoeBlock         bool    `json:"enable_moe_block"`
		NumExperts             *uint32 `json:"num_experts"`
		TopKExperts            *uint32 `json:"top_k_experts"`
		ExpertIntermediateSize *uint32 `json:"expert_intermediate_size"`

		// Per-layer embeddings (PLE) and KV sharing.
		HiddenSizePerLayerInput *uint32 `json:"hidden_size_per_layer_input"`
		NumKVSharedLayers       uint32  `json:"num_kv_shared_layers"`

		AttentionKEqV          bool    `json:"attention_k_eq_v"`
		NumGlobalKeyValueHeads *uint32 `json:"num_global_key_value_heads"`
		QueryPreAttnScalar     *uint32 `json:"query_pre_attn_scalar"`
		// UseDoubleWideMLP doubles FFN width on KV-shared layers (see KV()).
		UseDoubleWideMLP       bool    `json:"use_double_wide_mlp"`

		// RopeParameters is keyed by layer type ("full_attention",
		// "sliding_attention"); values may be nil.
		RopeParameters map[string]*struct {
			RopeTheta           float32  `json:"rope_theta"`
			PartialRotaryFactor *float32 `json:"partial_rotary_factor"`
		} `json:"rope_parameters"`
	} `json:"text_config"`

	// VisionModel mirrors config.json "vision_config"; a zero
	// NumHiddenLayers means the checkpoint has no vision tower.
	VisionModel struct {
		HiddenSize        uint32  `json:"hidden_size"`
		NumHiddenLayers   uint32  `json:"num_hidden_layers"`
		NumAttentionHeads uint32  `json:"num_attention_heads"`
		IntermediateSize  uint32  `json:"intermediate_size"`
		PatchSize         uint32  `json:"patch_size"`
		NumChannels       uint32  `json:"num_channels"`
		PoolingKernelSize uint32  `json:"pooling_kernel_size"`
		LayerNormEps      float32 `json:"layer_norm_eps"`
	} `json:"vision_config"`
}
|
||||
|
||||
// KV builds the GGUF key/value metadata for a Gemma 4 checkpoint from the
// parsed HF config, covering text attention/RoPE/MoE/PLE settings and,
// when present, the vision tower.
func (p *gemma4Model) KV(t *Tokenizer) KV {
	kv := p.ModelParameters.KV(t)
	kv["general.architecture"] = "gemma4"
	kv["tokenizer.ggml.model"] = "llama"
	kv["tokenizer.ggml.pre"] = "gemma4"

	tc := p.TextModel

	kv["gemma4.block_count"] = tc.NumHiddenLayers
	kv["gemma4.embedding_length"] = tc.HiddenSize

	// Per-layer FFN width: when use_double_wide_mlp is set, KV-shared layers get 2x FFN width.
	if tc.UseDoubleWideMLP && tc.NumKVSharedLayers > 0 {
		// KV-shared layers are the trailing NumKVSharedLayers layers.
		firstShared := int(tc.NumHiddenLayers) - int(tc.NumKVSharedLayers)
		ffnWidths := make([]int32, tc.NumHiddenLayers)
		for i := range ffnWidths {
			if i >= firstShared {
				ffnWidths[i] = int32(tc.IntermediateSize * 2)
			} else {
				ffnWidths[i] = int32(tc.IntermediateSize)
			}
		}
		kv["gemma4.feed_forward_length"] = ffnWidths
	} else {
		// Uniform FFN width across all layers.
		kv["gemma4.feed_forward_length"] = tc.IntermediateSize
	}
	kv["gemma4.context_length"] = tc.MaxPositionEmbeddings
	kv["gemma4.attention.head_count"] = tc.NumAttentionHeads
	// Per-layer KV head count array: SWA layers use NumKeyValueHeads, global layers use NumGlobalKeyValueHeads
	if tc.NumGlobalKeyValueHeads != nil && *tc.NumGlobalKeyValueHeads != tc.NumKeyValueHeads && len(tc.LayerTypes) > 0 {
		kvHeads := make([]int32, len(tc.LayerTypes))
		for i, lt := range tc.LayerTypes {
			if lt == "sliding_attention" {
				kvHeads[i] = int32(tc.NumKeyValueHeads)
			} else {
				kvHeads[i] = int32(*tc.NumGlobalKeyValueHeads)
			}
		}
		kv["gemma4.attention.head_count_kv"] = kvHeads
	} else {
		// Scalar form when local and global KV head counts agree.
		kv["gemma4.attention.head_count_kv"] = tc.NumKeyValueHeads
	}
	// key_length = global head dim, key_length_swa = local (SWA) head dim
	kv["gemma4.attention.key_length"] = tc.GlobalHeadDim
	kv["gemma4.attention.value_length"] = tc.GlobalHeadDim
	kv["gemma4.attention.key_length_swa"] = tc.HeadDim
	kv["gemma4.attention.value_length_swa"] = tc.HeadDim
	kv["gemma4.attention.layer_norm_rms_epsilon"] = tc.RMSNormEps
	kv["gemma4.attention.sliding_window"] = tc.SlidingWindow

	// Sliding window pattern from layer_types: a []bool with true on
	// sliding-attention layers, built via the iterator-to-slice helper.
	if len(tc.LayerTypes) > 0 {
		kv["gemma4.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
			for _, lt := range tc.LayerTypes {
				if !yield(lt == "sliding_attention") {
					break
				}
			}
		})
	}

	kv["gemma4.attention.shared_kv_layers"] = tc.NumKVSharedLayers

	// RoPE: dimension_count is the full global head dim (freq_factors handle partial rotation)
	if rp, ok := tc.RopeParameters["full_attention"]; ok && rp != nil {
		kv["gemma4.rope.freq_base"] = rp.RopeTheta
		kv["gemma4.rope.dimension_count"] = tc.GlobalHeadDim
	}
	if rp, ok := tc.RopeParameters["sliding_attention"]; ok && rp != nil {
		kv["gemma4.rope.freq_base_swa"] = rp.RopeTheta
		kv["gemma4.rope.dimension_count_swa"] = tc.HeadDim
	}

	if tc.FinalLogitSoftcapping > 0 {
		kv["gemma4.final_logit_softcapping"] = tc.FinalLogitSoftcapping
	}

	// MoE (only when the config enables the MoE block and declares experts).
	if tc.EnableMoeBlock && tc.NumExperts != nil {
		kv["gemma4.expert_count"] = *tc.NumExperts
		if tc.TopKExperts != nil {
			kv["gemma4.expert_used_count"] = *tc.TopKExperts
		}
		if tc.ExpertIntermediateSize != nil {
			kv["gemma4.expert_feed_forward_length"] = *tc.ExpertIntermediateSize
		}
	}

	// PLE — always emit, even when 0
	pleSize := uint32(0)
	if tc.HiddenSizePerLayerInput != nil {
		pleSize = *tc.HiddenSizePerLayerInput
	}
	kv["gemma4.embedding_length_per_layer_input"] = pleSize

	// Vision model KV metadata (only when a vision tower is present).
	vc := p.VisionModel
	if vc.NumHiddenLayers > 0 {
		kv["gemma4.vision.block_count"] = vc.NumHiddenLayers
		kv["gemma4.vision.embedding_length"] = vc.HiddenSize
		kv["gemma4.vision.attention.head_count"] = vc.NumAttentionHeads
		kv["gemma4.vision.feed_forward_length"] = vc.IntermediateSize
		kv["gemma4.vision.patch_size"] = vc.PatchSize
		numCh := vc.NumChannels
		if numCh == 0 {
			// Default to RGB when the config omits num_channels.
			numCh = 3
		}
		kv["gemma4.vision.num_channels"] = numCh
		nMerge := vc.PoolingKernelSize
		if nMerge == 0 {
			// NOTE(review): default pooling kernel of 3 when unset —
			// confirm against published Gemma 4 vision configs.
			nMerge = 3
		}
		kv["gemma4.vision.projector.scale_factor"] = nMerge
		eps := vc.LayerNormEps
		if eps == 0 {
			eps = 1e-6
		}
		kv["gemma4.vision.attention.layer_norm_epsilon"] = eps
	}

	return kv
}
|
||||
|
||||
func (p *gemma4Model) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
// First pass: collect vision clamp scalar values into a packed tensor.
|
||||
// Layout: per vision layer (0..N-1), 7 linears (q,k,v,out,gate,up,down) × 4 values (inMin,inMax,outMin,outMax).
|
||||
// Then 4 values for the projector (mm.input_projection).
|
||||
clampSuffixes := []string{".input_min", ".input_max", ".output_min", ".output_max"}
|
||||
clampMap := make(map[string]float32)
|
||||
for _, t := range ts {
|
||||
name := t.Name()
|
||||
for _, sfx := range clampSuffixes {
|
||||
if strings.HasSuffix(name, sfx) && (strings.Contains(name, "vision_tower") || strings.Contains(name, "embed_vision")) {
|
||||
var buf bytes.Buffer
|
||||
t.WriteTo(&buf)
|
||||
data := buf.Bytes()
|
||||
if len(data) >= 4 {
|
||||
clampMap[name] = math.Float32frombits(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16 | uint32(data[3])<<24)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var out []*ggml.Tensor
|
||||
for _, t := range ts {
|
||||
name := t.Name()
|
||||
|
||||
// Skip audio tensors (vision is now handled)
|
||||
if strings.Contains(name, "audio_tower") || strings.Contains(name, "embed_audio") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip embedding_post_projection_norm — used as weightless RMS norm in inference
|
||||
if strings.Contains(name, "embedding_post_projection_norm") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip clippable linear clamp scalars — packed into v.clamp_data below
|
||||
if strings.Contains(name, "input_min") || strings.Contains(name, "input_max") ||
|
||||
strings.Contains(name, "output_min") || strings.Contains(name, "output_max") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Vision tensor renaming: match published mmproj GGUF names
|
||||
if strings.HasPrefix(name, "v.blk.") {
|
||||
name = strings.Replace(name, ".attn_norm.", ".ln1.", 1)
|
||||
name = strings.Replace(name, ".ffn_norm.", ".ln2.", 1)
|
||||
name = strings.Replace(name, ".attn_output.", ".attn_out.", 1)
|
||||
name = strings.Replace(name, ".post_attention_norm.", ".attn_post_norm.", 1)
|
||||
name = strings.Replace(name, ".post_ffw_norm.", ".ffn_post_norm.", 1)
|
||||
name = strings.Replace(name, ".layer_output_scale.", ".out_scale.", 1)
|
||||
}
|
||||
|
||||
// Audio tensor post-processing: block-level norm rename and per_dim_scale softplus.
|
||||
if strings.HasPrefix(name, "a.blk.") {
|
||||
// Conformer block final norm: a.blk.N.norm.weight → a.blk.N.layer_pre_norm.weight
|
||||
if dotIdx := strings.Index(name[6:], "."); dotIdx >= 0 {
|
||||
rest := name[6+dotIdx+1:]
|
||||
if strings.HasPrefix(rest, "norm.") {
|
||||
name = name[:6+dotIdx+1] + "layer_pre_norm." + rest[5:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// per_dim_scale / per_dim_k_scale: apply softplus to weight data and add .weight suffix.
|
||||
if strings.HasPrefix(name, "a.blk.") &&
|
||||
(strings.HasSuffix(name, "per_dim_scale") || strings.HasSuffix(name, "per_dim_k_scale")) {
|
||||
name = name + ".weight"
|
||||
t.SetRepacker(softplusRepacker)
|
||||
}
|
||||
|
||||
// Depthwise conv1d: squeeze middle dimension [C, 1, K] → [C, K].
|
||||
if strings.HasPrefix(name, "a.blk.") && strings.Contains(name, "conv_dw") && strings.HasSuffix(name, ".weight") {
|
||||
t.SetRepacker(squeezeMiddleDim)
|
||||
}
|
||||
|
||||
shape := t.Shape()
|
||||
|
||||
// Convert scalar tensors (input_min/max, output_min/max) to 1D
|
||||
if len(shape) == 0 {
|
||||
shape = []uint64{1}
|
||||
}
|
||||
|
||||
// Fused MoE gate_up_proj: split [experts, 2*intermediate, hidden] into separate gate and up.
|
||||
// No transpose needed — the split shape [experts, intermediate, hidden] already matches
|
||||
// the GGUF layout after the framework's dimension reversal (ne[0]=hidden matches input).
|
||||
if strings.Contains(name, "moe.gate_up_proj") && len(shape) == 3 {
|
||||
halfDim := int(shape[1]) / 2
|
||||
newShape := slices.Clone(shape)
|
||||
newShape[1] = newShape[1] / 2
|
||||
for i, ggufName := range []string{"ffn_gate_exps.weight", "ffn_up_exps.weight"} {
|
||||
tt := t.Clone()
|
||||
tt.SetRepacker(p.sliceExperts(tensor.S(i*halfDim, (i+1)*halfDim)))
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.ReplaceAll(name, "moe.gate_up_proj", ggufName),
|
||||
Kind: tt.Kind(),
|
||||
Shape: slices.Clone(newShape),
|
||||
WriterTo: tt,
|
||||
})
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// MoE expert weights: no transpose needed. Safetensors stores [experts, out, in]
|
||||
// which the framework reverses to GGUF ne=[in, out, experts], matching ggml_mul_mat_id.
|
||||
// (transposeExperts was incorrectly swapping dims — removed)
|
||||
|
||||
// Vision patch embedding: reshape from [n_embd, ksize_sq_c] to [n_embd, 3, patch_size, patch_size]
|
||||
// Must be stored as F16 (not BF16) because the Conv2D im2col kernel requires F16/F32.
|
||||
var kindOverride *uint32
|
||||
if strings.Contains(name, "v.patch_embd.weight") && len(shape) == 2 {
|
||||
nEmbd := shape[0]
|
||||
patchSize := uint64(p.VisionModel.PatchSize)
|
||||
if patchSize == 0 {
|
||||
patchSize = 16
|
||||
}
|
||||
numCh := uint64(p.VisionModel.NumChannels)
|
||||
if numCh == 0 {
|
||||
numCh = 3
|
||||
}
|
||||
t.SetRepacker(p.reshapePatchEmbed)
|
||||
shape = []uint64{nEmbd, numCh, patchSize, patchSize}
|
||||
f16Kind := uint32(1) // tensorKindFP16
|
||||
kindOverride = &f16Kind
|
||||
}
|
||||
|
||||
// Vision position embedding: keep 3D [2, maxPos, nEmbd] — matching published mmproj format.
|
||||
// The framework reverses shape to GGUF ne=[nEmbd, maxPos, 2]. No data repacking needed.
|
||||
|
||||
kind := t.Kind()
|
||||
if kindOverride != nil {
|
||||
kind = *kindOverride
|
||||
}
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: kind,
|
||||
Shape: shape,
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
// Generate a single global rope_freqs.weight for proportional RoPE on global attention layers.
|
||||
// This matches the published GGUF format: one global tensor shared by all layers.
|
||||
// Global layers use partial_rotary_factor (0.25) — only rotate that fraction of dims.
|
||||
// Dimensions beyond the rotated portion get freq_factor=1e30 (effectively no rotation).
|
||||
tc := p.TextModel
|
||||
if tc.GlobalHeadDim > 0 {
|
||||
globalFreqsSize := tc.GlobalHeadDim / 2 // freq_factors are per dimension pair
|
||||
|
||||
// Compute number of rotated pairs for global layers
|
||||
partialRotaryFactor := float32(0.25) // default
|
||||
if rp, ok := tc.RopeParameters["full_attention"]; ok && rp != nil && rp.PartialRotaryFactor != nil {
|
||||
partialRotaryFactor = *rp.PartialRotaryFactor
|
||||
}
|
||||
nRotFull := int(float32(tc.GlobalHeadDim) * partialRotaryFactor / 2)
|
||||
|
||||
freqs := make(ropeFactor, globalFreqsSize)
|
||||
for j := range freqs {
|
||||
if j < nRotFull {
|
||||
freqs[j] = 1.0
|
||||
} else {
|
||||
freqs[j] = 1e30 // effectively disable rotation
|
||||
}
|
||||
}
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: "rope_freqs.weight",
|
||||
Kind: 0, // F32
|
||||
Shape: []uint64{uint64(len(freqs))},
|
||||
WriterTo: freqs,
|
||||
})
|
||||
}
|
||||
|
||||
// Emit packed vision clamp data as a single F32 tensor.
|
||||
// Layout: numLayers × 7 linears (q,k,v,out,gate,up,down) × 4 floats (inMin,inMax,outMin,outMax)
|
||||
// then 4 floats for the projector. Total = (numLayers*7 + 1) * 4 floats.
|
||||
if len(clampMap) > 0 {
|
||||
numLayers := int(p.VisionModel.NumHiddenLayers)
|
||||
linearNames := []string{"attn_q", "attn_k", "attn_v", "attn_out", "ffn_gate", "ffn_up", "ffn_down"}
|
||||
suffixes := []string{".input_min", ".input_max", ".output_min", ".output_max"}
|
||||
|
||||
totalFloats := (numLayers*len(linearNames) + 1) * 4 // +1 for projector
|
||||
clampData := make([]float32, totalFloats)
|
||||
|
||||
for layer := range numLayers {
|
||||
for li, ln := range linearNames {
|
||||
for si, sfx := range suffixes {
|
||||
sfxMap := map[string]string{"attn_q": "q_proj", "attn_k": "k_proj", "attn_v": "v_proj", "attn_out": "o_proj", "ffn_gate": "gate_proj", "ffn_up": "up_proj", "ffn_down": "down_proj"}
|
||||
for origName, val := range clampMap {
|
||||
if strings.Contains(origName, fmt.Sprintf("layers.%d.", layer)) && strings.HasSuffix(origName, sfx) && strings.Contains(origName, sfxMap[ln]) {
|
||||
idx := (layer*len(linearNames)+li)*4 + si
|
||||
clampData[idx] = val
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Projector clamp values
|
||||
projIdx := numLayers * len(linearNames) * 4
|
||||
for si, sfx := range suffixes {
|
||||
for origName, val := range clampMap {
|
||||
if strings.Contains(origName, "input_projection") && strings.HasSuffix(origName, sfx) {
|
||||
clampData[projIdx+si] = val
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
binary.Write(&buf, binary.LittleEndian, clampData)
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: "v.clamp_data",
|
||||
Kind: 0, // F32
|
||||
Shape: []uint64{uint64(totalFloats)},
|
||||
WriterTo: &buf,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// reshapePatchEmbed reshapes the vision patch embedding from HF layout [n_embd, ksize*ksize*channels]
|
||||
// to GGUF layout [n_embd, channels, patch_size, patch_size].
|
||||
func (*gemma4Model) reshapePatchEmbed(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
if len(shape) != 2 {
|
||||
return data, nil
|
||||
}
|
||||
nEmbd := int(shape[0])
|
||||
ksqC := int(shape[1])
|
||||
nChannels := 3
|
||||
patchSize := int(math.Sqrt(float64(ksqC / nChannels)))
|
||||
|
||||
// HF layout: [n_embd, patch_size * patch_size * channels] (row-major)
|
||||
// Need: [n_embd, channels, patch_size, patch_size]
|
||||
result := make([]float32, len(data))
|
||||
for e := range nEmbd {
|
||||
for c := range nChannels {
|
||||
for h := range patchSize {
|
||||
for w := range patchSize {
|
||||
srcIdx := e*ksqC + h*patchSize*nChannels + w*nChannels + c
|
||||
dstIdx := e*nChannels*patchSize*patchSize + c*patchSize*patchSize + h*patchSize + w
|
||||
result[dstIdx] = data[srcIdx]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
shape[0] = uint64(nEmbd)
|
||||
shape[1] = uint64(nChannels * patchSize * patchSize)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// sliceExperts returns a repacker that slices dim 1 of a 3D expert tensor.
|
||||
// Used for splitting fused gate_up_proj into separate gate and up tensors.
|
||||
func (*gemma4Model) sliceExperts(dim1Slice tensor.Slice) Repacker {
|
||||
return func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i, d := range shape {
|
||||
dims[i] = int(d)
|
||||
}
|
||||
|
||||
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
t, err := t.Slice(nil, dim1Slice)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t = tensor.Materialize(t)
|
||||
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return native.VectorF32(t.(*tensor.Dense))
|
||||
}
|
||||
}
|
||||
|
||||
// Replacements returns ordered old/new substring pairs applied to HF
// safetensors tensor names to produce GGUF tensor names. Order matters:
// more specific patterns (e.g. the numbered *_2 norms, embed_tokens_per_layer)
// must precede the shorter patterns they contain.
func (p *gemma4Model) Replacements() []string {
	return []string{
		// Vision ClippableLinear wraps nn.Linear — strip .linear. from weight path
		".linear.weight", ".weight",

		// Vision encoder
		"model.vision_tower.encoder.layers", "v.blk",
		"model.vision_tower.patch_embedder.input_proj", "v.patch_embd",
		"model.vision_tower.patch_embedder.position_embedding_table", "v.position_embd.weight",

		// Multimodal projector
		"model.embed_vision.embedding_projection", "mm.input_projection",

		// Text model (per-layer embeddings must be matched before embed_tokens)
		"model.language_model.embed_tokens_per_layer", "per_layer_token_embd",
		"model.language_model.embed_tokens", "token_embd",
		"model.language_model.per_layer_model_projection", "per_layer_model_proj",
		"model.language_model.per_layer_projection_norm", "per_layer_proj_norm",
		"model.language_model.norm", "output_norm",
		"model.language_model.layers", "blk",

		// Shared attention replacements (work for both text and vision tensors)
		"input_layernorm", "attn_norm",
		"self_attn.q_proj", "attn_q",
		"self_attn.q_norm", "attn_q_norm",
		"self_attn.k_proj", "attn_k",
		"self_attn.k_norm", "attn_k_norm",
		"self_attn.v_proj", "attn_v",
		"self_attn.o_proj", "attn_output",
		"mlp.gate_proj", "ffn_gate",
		"mlp.down_proj", "ffn_down",
		"mlp.up_proj", "ffn_up",

		// Post norms (numbered variants before their unnumbered prefixes)
		"post_attention_layernorm", "post_attention_norm",
		"pre_feedforward_layernorm_2", "pre_ffw_norm_2",
		"pre_feedforward_layernorm", "ffn_norm",
		"post_feedforward_layernorm_1", "post_ffw_norm_1",
		"post_feedforward_layernorm_2", "post_ffw_norm_2",
		"post_feedforward_layernorm", "post_ffw_norm",

		// PLE
		"per_layer_input_gate", "inp_gate",
		"per_layer_projection", "proj",
		"post_per_layer_input_norm", "post_norm",

		// MoE
		"router.proj", "ffn_gate_inp",
		"router.scale", "ffn_gate_inp.scale",
		"moe.gate_proj", "ffn_gate_exps.weight",
		"moe.up_proj", "ffn_up_exps.weight",
		"moe.down_proj", "ffn_down_exps.weight",
		"moe.per_expert_scale", "ffn_down_exps.scale",

		// Layer scalar
		"layer_scalar", "layer_output_scale.weight",
	}
}
|
||||
|
|
@ -205,8 +205,8 @@ func TestConvertInvalidDatatype(t *testing.T) {
|
|||
generateSafetensorTestData(t, tempDir, td)
|
||||
|
||||
err = ConvertModel(os.DirFS(tempDir), f)
|
||||
if err == nil || err.Error() != "unsupported safetensors model" {
|
||||
t.Errorf("expected error but didn't get one")
|
||||
if err == nil || !strings.Contains(err.Error(), "unknown data type") {
|
||||
t.Errorf("expected 'unknown data type' error but got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ func (t tensorBase) Kind() uint32 {
|
|||
strings.HasSuffix(t.name, ".ssm_conv1d.weight") || // SSM conv kernel must be F32 for Metal
|
||||
t.name == "token_types.weight" ||
|
||||
t.name == "v.positional_embedding_vlm" ||
|
||||
t.name == "v.position_embd.weight" ||
|
||||
t.name == "v.tile_position_embd.weight" ||
|
||||
t.name == "v.pre_tile_position_embd.weight" ||
|
||||
t.name == "v.post_tile_position_embd.weight" ||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import (
|
|||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
|
|
@ -53,9 +52,10 @@ func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]T
|
|||
|
||||
for _, key := range keys {
|
||||
if value := headers[key]; value.Type != "" {
|
||||
// bitsandbytes quantized models are unsupported
|
||||
// Scalar tensors (e.g. clipped linear min/max) are 0-dim in safetensors.
|
||||
// Promote them to 1-dim so they can be stored in GGUF.
|
||||
if len(value.Shape) == 0 {
|
||||
return nil, errors.New("unsupported safetensors model")
|
||||
value.Shape = []uint64{1}
|
||||
}
|
||||
ggufName := replacer.Replace(key)
|
||||
if _, ok := names[ggufName]; ok {
|
||||
|
|
|
|||
|
|
@ -281,6 +281,7 @@ func (kv KV) OllamaEngineRequired() bool {
|
|||
"deepseekocr",
|
||||
"gemma3",
|
||||
"gemma3n",
|
||||
"gemma4",
|
||||
"gptoss", "gpt-oss",
|
||||
"llama4",
|
||||
"mistral3",
|
||||
|
|
|
|||
121
llama/patches/0035-CUDA-get_rows-q6_k-support.patch
Normal file
121
llama/patches/0035-CUDA-get_rows-q6_k-support.patch
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Daniel Hiltgen <daniel@ollama.com>
|
||||
Date: Fri, 20 Mar 2026 18:50:38 -0700
|
||||
Subject: [PATCH] CUDA get_rows q6_k support
|
||||
|
||||
---
|
||||
ggml/src/ggml-cuda/getrows.cu | 80 ++++++++++++++++++++++++++++++++-
|
||||
ggml/src/ggml-cuda/ggml-cuda.cu | 1 +
|
||||
2 files changed, 80 insertions(+), 1 deletion(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu
|
||||
index 2fab33243..dc5c4f57a 100644
|
||||
--- a/ggml/src/ggml-cuda/getrows.cu
|
||||
+++ b/ggml/src/ggml-cuda/getrows.cu
|
||||
@@ -155,6 +155,81 @@ static void get_rows_cuda_float(
|
||||
s10, s11, s12/*, s13*/);
|
||||
}
|
||||
|
||||
+// Specialized GET_ROWS kernel for Q6_K — the k_get_rows template doesn't work for K-quants
|
||||
+// because they lack the simple dequantize_kernel_t (float2) interface.
|
||||
+// Based on dequantize_block_q6_K from convert.cu with row-selection logic added.
|
||||
+template<typename dst_t>
|
||||
+static __global__ void k_get_rows_q6_K(
|
||||
+ const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
|
||||
+ const int64_t ne00,
|
||||
+ const int64_t ne11, const int64_t ne12,
|
||||
+ const size_t s1, const size_t s2, const size_t s3,
|
||||
+ const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
+ const size_t s10, const size_t s11, const size_t s12) {
|
||||
+
|
||||
+ const int64_t i10 = blockIdx.x; // row index into src1
|
||||
+ const int64_t z = blockIdx.z;
|
||||
+ const int64_t i11 = z / ne12;
|
||||
+ const int64_t i12 = z % ne12;
|
||||
+
|
||||
+ const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
|
||||
+
|
||||
+ dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
|
||||
+ const char * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
|
||||
+
|
||||
+ const int64_t nb = ne00 / QK_K; // number of Q6_K blocks per row
|
||||
+
|
||||
+ // blockIdx.y iterates over Q6_K blocks within the row
|
||||
+ for (int64_t iblk = blockIdx.y; iblk < nb; iblk += gridDim.y) {
|
||||
+ const block_q6_K * x = (const block_q6_K *)src0_row + iblk;
|
||||
+
|
||||
+ // Same dequantization as dequantize_block_q6_K (assumes 64 threads)
|
||||
+ const int64_t tid = threadIdx.x;
|
||||
+ const int64_t ip = tid / 32; // 0 or 1
|
||||
+ const int64_t il = tid - 32*ip; // 0..31
|
||||
+ const int64_t is = 8*ip + il/16;
|
||||
+
|
||||
+ const int64_t y_offset = iblk * QK_K + 128*ip + il;
|
||||
+
|
||||
+ const float d = x->d;
|
||||
+ const uint8_t * ql = x->ql + 64*ip + il;
|
||||
+ const uint8_t qh = x->qh[32*ip + il];
|
||||
+ const int8_t * sc = x->scales + is;
|
||||
+
|
||||
+ if (y_offset + 0 < ne00) dst_row[y_offset + 0] = ggml_cuda_cast<dst_t>(d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32));
|
||||
+ if (y_offset + 32 < ne00) dst_row[y_offset + 32] = ggml_cuda_cast<dst_t>(d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32));
|
||||
+ if (y_offset + 64 < ne00) dst_row[y_offset + 64] = ggml_cuda_cast<dst_t>(d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32));
|
||||
+ if (y_offset + 96 < ne00) dst_row[y_offset + 96] = ggml_cuda_cast<dst_t>(d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32));
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
+template<typename dst_t>
|
||||
+static void get_rows_cuda_q6_K(
|
||||
+ const void * src0_d, const int32_t * src1_d, dst_t * dst_d,
|
||||
+ const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
+ const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
|
||||
+ const size_t nb1, const size_t nb2, const size_t nb3,
|
||||
+ cudaStream_t stream) {
|
||||
+ const int64_t nb_blocks = ne00 / QK_K;
|
||||
+ const dim3 block_dims(64, 1, 1);
|
||||
+ const dim3 block_nums(ne10, MIN(nb_blocks, (int64_t)UINT16_MAX), MIN(ne11*ne12, (int64_t)UINT16_MAX));
|
||||
+
|
||||
+ const size_t s1 = nb1 / sizeof(dst_t);
|
||||
+ const size_t s2 = nb2 / sizeof(dst_t);
|
||||
+ const size_t s3 = nb3 / sizeof(dst_t);
|
||||
+
|
||||
+ const size_t s10 = nb10 / sizeof(int32_t);
|
||||
+ const size_t s11 = nb11 / sizeof(int32_t);
|
||||
+ const size_t s12 = nb12 / sizeof(int32_t);
|
||||
+
|
||||
+ k_get_rows_q6_K<<<block_nums, block_dims, 0, stream>>>(
|
||||
+ src0_d, src1_d, dst_d,
|
||||
+ ne00, ne11, ne12,
|
||||
+ s1, s2, s3,
|
||||
+ nb01, nb02, nb03,
|
||||
+ s10, s11, s12);
|
||||
+}
|
||||
+
|
||||
template <typename dst_t>
|
||||
static void ggml_cuda_get_rows_switch_src0_type(
|
||||
const void * src0_d, const ggml_type src0_type, const int32_t * src1_d, dst_t * dst_d,
|
||||
@@ -199,8 +274,11 @@ static void ggml_cuda_get_rows_switch_src0_type(
|
||||
get_rows_cuda_q<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_d, dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
+ case GGML_TYPE_Q6_K:
|
||||
+ get_rows_cuda_q6_K(src0_d, src1_d, dst_d,
|
||||
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
+ break;
|
||||
default:
|
||||
- // TODO: k-quants
|
||||
GGML_ABORT("%s: unsupported src0 type: %s\n", __func__, ggml_type_name(src0_type));
|
||||
break;
|
||||
}
|
||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
index 5c9dfd032..b8ed3709b 100644
|
||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
@@ -4693,6 +4693,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
+ case GGML_TYPE_Q6_K:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
|
@ -137,6 +137,7 @@ type Tensor interface {
|
|||
|
||||
Bytes() []byte
|
||||
Floats() []float32
|
||||
BackendGet() []float32
|
||||
|
||||
FromBytes([]byte)
|
||||
FromFloats([]float32)
|
||||
|
|
|
|||
|
|
@ -1069,6 +1069,21 @@ func (t *Tensor) Floats() (data []float32) {
|
|||
return
|
||||
}
|
||||
|
||||
func (t *Tensor) BackendGet() []float32 {
|
||||
n := int(C.ggml_nelements(t.t))
|
||||
if n == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if t.sync != nil {
|
||||
t.sync()
|
||||
}
|
||||
|
||||
data := make([]float32, n)
|
||||
C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
|
||||
return data
|
||||
}
|
||||
|
||||
func tensorSet[S ~[]E, E byte | float32 | int32](t *Tensor, s S) {
|
||||
if len(s) == 0 {
|
||||
return
|
||||
|
|
|
|||
80
ml/backend/ggml/ggml/src/ggml-cuda/getrows.cu
vendored
80
ml/backend/ggml/ggml/src/ggml-cuda/getrows.cu
vendored
|
|
@ -155,6 +155,81 @@ static void get_rows_cuda_float(
|
|||
s10, s11, s12/*, s13*/);
|
||||
}
|
||||
|
||||
// Specialized GET_ROWS kernel for Q6_K — the k_get_rows template doesn't work for K-quants
// because they lack the simple dequantize_kernel_t (float2) interface.
// Based on dequantize_block_q6_K from convert.cu with row-selection logic added.
//
// Grid mapping: blockIdx.x = destination row (index into src1),
// blockIdx.y grid-strides over Q6_K super-blocks within the row,
// blockIdx.z = flattened (i11, i12) batch coordinates.
// Launched with 64 threads per block; each thread emits 4 of the QK_K (256)
// values of a super-block, spaced 32 elements apart.
template<typename dst_t>
static __global__ void k_get_rows_q6_K(
    const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
    const int64_t ne00,
    const int64_t ne11, const int64_t ne12,
    const size_t s1, const size_t s2, const size_t s3,
    const size_t nb01, const size_t nb02, const size_t nb03,
    const size_t s10, const size_t s11, const size_t s12) {

    const int64_t i10 = blockIdx.x; // row index into src1
    const int64_t z   = blockIdx.z;
    const int64_t i11 = z / ne12;
    const int64_t i12 = z % ne12;

    // Source row to gather, looked up through the int32 index tensor.
    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];

    dst_t      * dst_row  = dst + i10*s1 + i11*s2 + i12*s3;
    const char * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;

    const int64_t nb = ne00 / QK_K; // number of Q6_K blocks per row

    // blockIdx.y iterates over Q6_K blocks within the row (grid-stride, since
    // the launcher clamps gridDim.y).
    for (int64_t iblk = blockIdx.y; iblk < nb; iblk += gridDim.y) {
        const block_q6_K * x = (const block_q6_K *)src0_row + iblk;

        // Same dequantization as dequantize_block_q6_K (assumes 64 threads)
        const int64_t tid = threadIdx.x;
        const int64_t ip = tid / 32;    // super-block half: 0 or 1
        const int64_t il = tid - 32*ip; // lane within the half: 0..31
        const int64_t is = 8*ip + il/16; // index into the 16 group scales

        const int64_t y_offset = iblk * QK_K + 128*ip + il;

        const float d = x->d;                     // super-block scale
        const uint8_t * ql = x->ql + 64*ip + il;  // low 4 bits of quants
        const uint8_t qh = x->qh[32*ip + il];     // high 2 bits of quants
        const int8_t * sc = x->scales + is;       // per-group int8 scales

        // Reassemble each 6-bit quant from ql/qh, re-center by -32, and scale.
        if (y_offset + 0  < ne00) dst_row[y_offset + 0]  = ggml_cuda_cast<dst_t>(d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32));
        if (y_offset + 32 < ne00) dst_row[y_offset + 32] = ggml_cuda_cast<dst_t>(d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32));
        if (y_offset + 64 < ne00) dst_row[y_offset + 64] = ggml_cuda_cast<dst_t>(d * sc[4] * ((int8_t)((ql[ 0] >> 4)  | (((qh >> 4) & 3) << 4)) - 32));
        if (y_offset + 96 < ne00) dst_row[y_offset + 96] = ggml_cuda_cast<dst_t>(d * sc[6] * ((int8_t)((ql[32] >> 4)  | (((qh >> 6) & 3) << 4)) - 32));
    }
}
|
||||
|
||||
// Host-side launcher for k_get_rows_q6_K. Converts byte strides (nb*) to
// element strides (s*) and clamps oversized grid dimensions.
template<typename dst_t>
static void get_rows_cuda_q6_K(
    const void * src0_d, const int32_t * src1_d, dst_t * dst_d,
    const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
    const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
    const size_t nb1, const size_t nb2, const size_t nb3,
    cudaStream_t stream) {
    const int64_t nb_blocks = ne00 / QK_K;
    const dim3 block_dims(64, 1, 1);
    // y is clamped to the 65535 grid limit and the kernel grid-strides over it.
    // NOTE(review): z is clamped too, but the kernel does NOT stride over z —
    // if ne11*ne12 ever exceeded 65535, those batch slices would be silently
    // skipped; confirm batch shapes stay within the limit.
    const dim3 block_nums(ne10, MIN(nb_blocks, (int64_t)UINT16_MAX), MIN(ne11*ne12, (int64_t)UINT16_MAX));

    // Destination strides in elements of dst_t.
    const size_t s1 = nb1 / sizeof(dst_t);
    const size_t s2 = nb2 / sizeof(dst_t);
    const size_t s3 = nb3 / sizeof(dst_t);

    // Index-tensor strides in int32 elements.
    const size_t s10 = nb10 / sizeof(int32_t);
    const size_t s11 = nb11 / sizeof(int32_t);
    const size_t s12 = nb12 / sizeof(int32_t);

    k_get_rows_q6_K<<<block_nums, block_dims, 0, stream>>>(
        src0_d, src1_d, dst_d,
        ne00, ne11, ne12,
        s1, s2, s3,
        nb01, nb02, nb03,
        s10, s11, s12);
}
|
||||
|
||||
template <typename dst_t>
|
||||
static void ggml_cuda_get_rows_switch_src0_type(
|
||||
const void * src0_d, const ggml_type src0_type, const int32_t * src1_d, dst_t * dst_d,
|
||||
|
|
@ -199,8 +274,11 @@ static void ggml_cuda_get_rows_switch_src0_type(
|
|||
get_rows_cuda_q<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_d, dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
get_rows_cuda_q6_K(src0_d, src1_d, dst_d,
|
||||
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
default:
|
||||
// TODO: k-quants
|
||||
GGML_ABORT("%s: unsupported src0 type: %s\n", __func__, ggml_type_name(src0_type));
|
||||
break;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4693,6 +4693,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q6_K:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
|
|
|||
188
model/models/gemma4/model.go
Normal file
188
model/models/gemma4/model.go
Normal file
|
|
@ -0,0 +1,188 @@
|
|||
package gemma4
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"image"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
// Model is the full Gemma 4 multimodal model: a SigLIP-style vision encoder
// ("v" tensor prefix), a text decoder, and a projector ("mm" prefix) mapping
// vision embeddings into the text embedding space.
type Model struct {
	model.Base
	tokenizer.Tokenizer

	*VisionModel `gguf:"v"`
	*TextModel

	*MultiModalProjector `gguf:"mm"`

	ImageProcessor

	// Special token IDs bracketing image embeddings in the token stream;
	// -1 when the token is absent from the vocabulary.
	imageTokenID    int32
	imageEndTokenID int32
}

// Compile-time check that Model implements multimodal processing.
var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
// MultiModalProjector maps pooled vision embeddings into the text model's
// embedding space via a clamped linear layer followed by an unweighted RMSNorm.
type MultiModalProjector struct {
	Projection *ClippableLinear `gguf:"input_projection"`
}

// Forward projects visionOutputs and normalizes the result.
// eps is the RMSNorm epsilon shared with the text model.
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
	visionOutputs = p.Projection.Forward(ctx, visionOutputs)
	// Post-projection RMSNorm without learned weight (nil weight tensor).
	visionOutputs = visionOutputs.RMSNorm(ctx, nil, eps)
	return visionOutputs
}
|
||||
|
||||
// New constructs a Gemma 4 model from GGUF metadata: builds the BPE tokenizer
// (SentencePiece-style normalization), resolves the vision marker token IDs,
// instantiates the text/vision submodels, and wires up the hybrid KV cache
// (sliding-window + full-attention).
func New(c fs.Config) (model.Model, error) {
	vocabulary := tokenizer.Vocabulary{
		Values: c.Strings("tokenizer.ggml.tokens"),
		Scores: c.Floats("tokenizer.ggml.scores"),
		Types:  c.Ints("tokenizer.ggml.token_type"),
		Merges: c.Strings("tokenizer.ggml.merges"),
		AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
		BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
		AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
		EOS: append(
			[]int32{
				int32(c.Uint("tokenizer.ggml.eos_token_id")),
			},
			c.Ints("tokenizer.ggml.eos_token_ids")...,
		),
	}

	// End-of-turn also terminates generation; defaults to token 106 when the
	// metadata key is absent.
	vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))

	// Gemma 4 uses BPE with SentencePiece-style ▁ space markers (not GPT-2 byte-level encoding).
	// The tokenizer.json has merges and a Replace normalizer (space → ▁), with no pre-tokenizer.
	t := tokenizer.NewBytePairEncodingWithOptions(&vocabulary, []string{},
		tokenizer.WithSentencePieceNormalizer())

	// Look up special token IDs for vision; -1 means "not present".
	imageTokenID := int32(-1)
	imageEndTokenID := int32(-1)
	for i, tok := range vocabulary.Values {
		switch tok {
		case "<|image>":
			imageTokenID = int32(i)
		case "<image|>":
			imageEndTokenID = int32(i)
		}
	}

	m := Model{
		Tokenizer:           t,
		TextModel:           newTextModel(c),
		VisionModel:         newVisionModel(c),
		MultiModalProjector: &MultiModalProjector{},
		ImageProcessor:      newImageProcessor(c),
		imageTokenID:        imageTokenID,
		imageEndTokenID:     imageEndTokenID,
	}

	// Wrapper cache: type 0 (cacheTypeSWA) = sliding-window layers,
	// type 1 (cacheTypeCausal) = full-attention layers.
	slidingWindowLen := int32(c.Uint("attention.sliding_window"))
	m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))

	return &m, nil
}
|
||||
|
||||
// EncodeMultimodal decodes an image, preprocesses it, runs the vision encoder,
// and pools+projects the result into text-embedding space. Returns
// model.ErrNoVisionModel for text-only checkpoints (no vision layers loaded).
//
// NOTE(review): the per-stage slog.Info timing calls are noisy for every image
// encode; consider slog.Debug.
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
	if len(m.VisionModel.Layers) == 0 {
		return nil, model.ErrNoVisionModel
	}

	// Initialize clamp values from model tensors (lazy, once, after model is fully loaded)
	m.VisionModel.InitClamp(m.MultiModalProjector)

	t0 := time.Now()
	img, _, err := image.Decode(bytes.NewReader(multimodalData))
	if err != nil {
		return nil, err
	}
	slog.Info("vision: decode", "elapsed", time.Since(t0), "bounds", img.Bounds())

	t1 := time.Now()
	f32s, imgW, imgH, err := m.ImageProcessor.ProcessImage(img)
	if err != nil {
		return nil, err
	}
	slog.Info("vision: preprocess", "elapsed", time.Since(t1), "size", [2]int{imgW, imgH})

	pixelValues := ctx.Input().FromFloats(f32s, imgW, imgH, m.ImageProcessor.numChannels)
	slog.Info("vision: pixelValues", "shape", pixelValues.Shape(), "dim0", pixelValues.Dim(0), "dim1", pixelValues.Dim(1), "dim2", pixelValues.Dim(2))

	// Patch grid dimensions; assumes the processed image is an exact multiple
	// of patchSize (ProcessImage is expected to guarantee this — TODO confirm).
	numPatchesX := imgW / m.ImageProcessor.patchSize
	numPatchesY := imgH / m.ImageProcessor.patchSize
	slog.Info("vision: patches", "patchesX", numPatchesX, "patchesY", numPatchesY, "total", numPatchesX*numPatchesY, "patchSize", m.ImageProcessor.patchSize)

	visionOutputs := m.VisionModel.Forward(ctx, pixelValues, numPatchesX, numPatchesY)
	visionOutputs = visionPoolAndProject(ctx, visionOutputs, numPatchesX, numPatchesY, m.VisionModel.VisionModelOptions, m.MultiModalProjector)
	slog.Info("vision: encoded", "elapsed", time.Since(t0), "shape", visionOutputs.Shape())

	return []input.Multimodal{{Tensor: visionOutputs}}, nil
}
|
||||
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
var result []*input.Input
|
||||
|
||||
for _, inp := range inputs {
|
||||
if len(inp.Multimodal) == 0 {
|
||||
result = append(result, inp)
|
||||
} else {
|
||||
inputMultimodal := inp.Multimodal[0].Tensor
|
||||
numImageTokens := inputMultimodal.Dim(1)
|
||||
|
||||
// <|image>
|
||||
if m.imageTokenID >= 0 {
|
||||
result = append(result, &input.Input{Token: m.imageTokenID, SameBatch: numImageTokens + 2})
|
||||
}
|
||||
|
||||
// Image embedding placeholder tokens
|
||||
result = append(result,
|
||||
&input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash},
|
||||
)
|
||||
result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, numImageTokens-1)...)
|
||||
|
||||
// <image|>
|
||||
if m.imageEndTokenID >= 0 {
|
||||
result = append(result, &input.Input{Token: m.imageEndTokenID})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenState := m.TextModel.Forward(ctx, batch, m.Cache)
|
||||
|
||||
hiddenState = m.TextModel.Output.Forward(ctx, hiddenState)
|
||||
|
||||
if m.TextModel.TextOptions.finalLogitSoftcap > 0.0 {
|
||||
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextModel.TextOptions.finalLogitSoftcap))
|
||||
hiddenState = hiddenState.Tanh(ctx)
|
||||
hiddenState = hiddenState.Scale(ctx, float64(m.TextModel.TextOptions.finalLogitSoftcap))
|
||||
}
|
||||
|
||||
return hiddenState, nil
|
||||
}
|
||||
|
||||
// Shift re-applies RoPE to cached keys when the KV cache slides, using the
// layer-appropriate base frequency and rotary dimension count.
//
// NOTE(review): unlike TextSelfAttention.Forward, this does not pass
// rope.WithFactors for global layers — confirm whether shifted keys on
// full-attention layers need the proportional freq_factors too.
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	ropeBase, ropeDims := m.TextModel.ropeForLayer(layer)
	return nn.RoPE(ctx, key, shift, ropeDims, ropeBase, 1.0, rope.WithTypeNeoX()), nil
}
|
||||
|
||||
// init registers the gemma4 architecture with the model loader registry.
func init() {
	model.Register("gemma4", New)
}
|
||||
454
model/models/gemma4/model_text.go
Normal file
454
model/models/gemma4/model_text.go
Normal file
|
|
@ -0,0 +1,454 @@
|
|||
package gemma4
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
const (
|
||||
cacheTypeSWA = iota
|
||||
cacheTypeCausal
|
||||
)
|
||||
|
||||
// TextOptions holds the text-decoder hyperparameters read from GGUF metadata.
type TextOptions struct {
	hiddenSize           int
	numHeads, numKVHeads int
	numGlobalKVHeads     int // KV heads on full-attention layers (0 = same as numKVHeads)
	headDim, globalHeadDim int // per-head dims for sliding vs. full attention
	hiddenLayers         int
	hiddenSizePerLayerInput int // PLE embedding width (0 when PLE is absent)

	eps           float32 // RMSNorm epsilon
	ropeBase      float32 // RoPE base frequency for full-attention layers
	ropeLocalBase float32 // RoPE base frequency for sliding-window layers
	partialRotaryDims int // RoPE dims for full-attention (global) layers

	// slidingWindowPattern[i] is true when layer i uses sliding-window
	// (local) attention.
	slidingWindowPattern []bool
	// kvDonorMap maps shared layer index -> donor layer index.
	// Donor is the last non-shared layer of the same type (sliding/full).
	kvDonorMap map[int]int

	finalLogitSoftcap float32 // 0 disables softcapping

	// MoE configuration (0 for dense-only models).
	numExperts     int
	numExpertsUsed int
}
|
||||
|
||||
func (o *TextOptions) isLocal(layer int) bool {
|
||||
if layer < len(o.slidingWindowPattern) {
|
||||
return o.slidingWindowPattern[layer]
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (o *TextOptions) ropeForLayer(layer int) (base float32, dims int) {
|
||||
if o.isLocal(layer) {
|
||||
return o.ropeLocalBase, o.headDim
|
||||
}
|
||||
return o.ropeBase, o.partialRotaryDims
|
||||
}
|
||||
|
||||
func (o *TextOptions) kvHeadsForLayer(layer int) int {
|
||||
if o.isLocal(layer) {
|
||||
return o.numKVHeads
|
||||
}
|
||||
if o.numGlobalKVHeads > 0 {
|
||||
return o.numGlobalKVHeads
|
||||
}
|
||||
return o.numKVHeads
|
||||
}
|
||||
|
||||
func (o *TextOptions) headDimForLayer(layer int) int {
|
||||
if o.isLocal(layer) {
|
||||
return o.headDim
|
||||
}
|
||||
return o.globalHeadDim
|
||||
}
|
||||
|
||||
// TextModel is the Gemma 4 text decoder: token embedding, optional per-layer
// embedding projector (PLE), the transformer stack, final norm, and the output
// head (tied to the token embedding when no separate "output" tensor exists).
type TextModel struct {
	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	*PerLayerProjector
	Layers     []TextLayer `gguf:"blk"`
	OutputNorm *nn.RMSNorm `gguf:"output_norm"`
	Output     *nn.Linear  `gguf:"output,alt:token_embd"`
	TextOptions
}
|
||||
|
||||
// newTextModel reads the text-decoder hyperparameters from GGUF metadata,
// handling both dense and MoE variants, and precomputes the KV-sharing donor
// map for checkpoints with shared KV layers.
func newTextModel(c fs.Config) *TextModel {
	numLayers := int(c.Uint("block_count"))

	// Head dimensions: key_length is global head dim, key_length_swa is local (SWA) head dim.
	globalHeadDim := int(c.Uint("attention.key_length", 512))
	headDim := int(c.Uint("attention.key_length_swa", 256))

	// RoPE dimensions for global (full attention) layers with proportional RoPE.
	// The freq_factors tensor handles partial rotation (1.0 for rotated pairs,
	// 1e30 for non-rotated), so ropeDims equals the full global head dim.
	partialRotaryDims := int(c.Uint("rope.dimension_count", 0))
	if partialRotaryDims == 0 {
		// Fall back to deriving the rotary dims from the partial factor.
		partialFactor := c.Float("rope.partial_rotary_factor", 1.0)
		partialRotaryDims = int(float32(globalHeadDim) * partialFactor)
	}

	ropeBase := c.Float("rope.freq_base", 1000000.0)
	ropeLocalBase := c.Float("rope.freq_base_swa", 0)
	if ropeLocalBase == 0 {
		// Older metadata key for the sliding-window RoPE base.
		ropeLocalBase = c.Float("rope.local.freq_base", 10000.0)
	}

	numGlobalKVHeads := int(c.Uint("attention.global_head_count_kv", 0))
	slidingPattern := c.Bools("attention.sliding_window_pattern")

	// KV heads: try per-layer array first (MoE models), then fall back to scalar
	numKVHeads := 0
	kvHeadsArray := c.Ints("attention.head_count_kv")
	if len(kvHeadsArray) > 0 {
		numKVHeads = int(kvHeadsArray[0])
		if numGlobalKVHeads == 0 && len(slidingPattern) > 0 {
			// Infer the global KV-head count from the first full-attention
			// layer's entry in the per-layer array.
			for i, isLocal := range slidingPattern {
				if !isLocal && i < len(kvHeadsArray) {
					numGlobalKVHeads = int(kvHeadsArray[i])
					break
				}
			}
		}
	}
	if numKVHeads == 0 {
		numKVHeads = int(c.Uint("attention.head_count_kv", 0))
	}

	// Compute KV sharing donor map (same logic as MLX): the last sharedLayers
	// layers reuse K/V from the most recent earlier layer of the same
	// attention type (sliding vs. full).
	sharedLayers := int(c.Uint("attention.shared_kv_layers", 0))
	kvDonorMap := make(map[int]int)
	if sharedLayers > 0 && len(slidingPattern) > 0 {
		firstShared := numLayers - sharedLayers
		for i := firstShared; i < numLayers; i++ {
			isLocal := slidingPattern[i]
			// Find last non-shared layer of same type
			for j := firstShared - 1; j >= 0; j-- {
				if slidingPattern[j] == isLocal {
					kvDonorMap[i] = j
					break
				}
			}
		}
	}

	return &TextModel{
		Layers: make([]TextLayer, numLayers),
		TextOptions: TextOptions{
			hiddenSize:              int(c.Uint("embedding_length")),
			numHeads:                int(c.Uint("attention.head_count")),
			numKVHeads:              numKVHeads,
			numGlobalKVHeads:        numGlobalKVHeads,
			headDim:                 headDim,
			globalHeadDim:           globalHeadDim,
			hiddenLayers:            numLayers,
			hiddenSizePerLayerInput: int(c.Uint("embedding_length_per_layer_input", 0)),
			eps:                     c.Float("attention.layer_norm_rms_epsilon", 1e-06),
			ropeBase:                ropeBase,
			ropeLocalBase:           ropeLocalBase,
			partialRotaryDims:       partialRotaryDims,
			slidingWindowPattern:    slidingPattern,
			kvDonorMap:              kvDonorMap,
			finalLogitSoftcap:       c.Float("final_logit_softcapping", 0.0),
			numExperts:              int(c.Uint("expert_count", 0)),
			numExpertsUsed:          int(c.Uint("expert_used_count", 0)),
		},
	}
}
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.hiddenSize)))
|
||||
|
||||
// Inject vision embeddings into the hidden state
|
||||
var except []int
|
||||
for _, image := range batch.Multimodal {
|
||||
visionOutputs := image.Multimodal[0].Tensor
|
||||
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
|
||||
|
||||
for i := range visionOutputs.Dim(1) {
|
||||
except = append(except, image.Index+i)
|
||||
}
|
||||
}
|
||||
|
||||
// PLE
|
||||
var perLayerInputs ml.Tensor
|
||||
if m.PerLayerProjector != nil {
|
||||
perLayerInputs = m.PerLayerProjector.Forward(ctx, batch, hiddenState, &m.TextOptions)
|
||||
}
|
||||
|
||||
for i := range len(m.Layers) {
|
||||
layer := m.Layers[i]
|
||||
if cache != nil {
|
||||
cache.SetLayer(i)
|
||||
cacheType := cacheTypeSWA
|
||||
if !m.isLocal(i) {
|
||||
cacheType = cacheTypeCausal
|
||||
}
|
||||
wc := cache.(*kvcache.WrapperCache)
|
||||
wc.SetLayerType(cacheType)
|
||||
|
||||
if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
|
||||
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
|
||||
}
|
||||
}
|
||||
|
||||
var lastLayerOutputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
lastLayerOutputs = batch.Outputs
|
||||
}
|
||||
|
||||
var perLayerInput ml.Tensor
|
||||
if perLayerInputs != nil {
|
||||
perLayerInput = perLayerInputs.View(ctx, i*perLayerInputs.Stride(1), perLayerInputs.Dim(0), perLayerInputs.Stride(2), perLayerInputs.Dim(2))
|
||||
}
|
||||
|
||||
// KV sharing: layers >= firstShared reuse K/V from donor layers
|
||||
isShared := false
|
||||
if donorLayer, ok := m.kvDonorMap[i]; ok {
|
||||
// Set cache layer to donor so Get() reads donor's K/V
|
||||
cache.SetLayer(donorLayer)
|
||||
isShared = true
|
||||
}
|
||||
hiddenState = layer.Forward(ctx, i, hiddenState, positions, perLayerInput, lastLayerOutputs, cache, isShared, &m.TextOptions)
|
||||
}
|
||||
|
||||
return m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
}
|
||||
|
||||
// PerLayerProjector implements PLE (per-layer embeddings): each token gets a
// small learned embedding per layer, combined with a projection of the main
// hidden state.
type PerLayerProjector struct {
	TokenEmbedding *nn.Embedding `gguf:"per_layer_token_embd"`
	Projector      *nn.Linear    `gguf:"per_layer_model_proj"`
	Norm           *nn.RMSNorm   `gguf:"per_layer_proj_norm"`
}

// Forward returns the packed per-layer inputs with shape
// [pleDim, numLayers, numTokens]: normalized projection of the hidden state
// averaged (via /sqrt(2)) with the scaled per-layer token embeddings.
func (p *PerLayerProjector) Forward(ctx ml.Context, batch input.Batch, inputs ml.Tensor, opts *TextOptions) ml.Tensor {
	inputsPerLayer := p.TokenEmbedding.Forward(ctx, batch.Inputs)
	inputsPerLayer = inputsPerLayer.Scale(ctx, math.Sqrt(float64(opts.hiddenSizePerLayerInput)))
	// Reshape to [pleDim, numLayers, numTokens] — matching projection shape
	inputsPerLayer = inputsPerLayer.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, inputs.Dim(1))

	perLayerProjection := p.Projector.Forward(ctx, inputs)
	perLayerProjection = perLayerProjection.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize)))
	perLayerProjection = perLayerProjection.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, inputs.Dim(1))
	perLayerProjection = p.Norm.Forward(ctx, perLayerProjection, opts.eps)

	// NOTE(review): inputsPerLayer was just assigned above, so this nil check
	// appears to always be true; kept for safety.
	if inputsPerLayer != nil {
		perLayerProjection = perLayerProjection.Add(ctx, inputsPerLayer)
		// Average of the two terms in RMS sense.
		perLayerProjection = perLayerProjection.Scale(ctx, 1/math.Sqrt(2))
	}

	return perLayerProjection
}
|
||||
|
||||
// TextSelfAttention is Gemma 4 attention with Q/K RMSNorms, optional KV
// sharing (reuse a donor layer's cache), an optional K=V mode (no separate
// value projection), and proportional RoPE freq_factors on global layers.
type TextSelfAttention struct {
	Query       *nn.Linear  `gguf:"attn_q"`
	QueryNorm   *nn.RMSNorm `gguf:"attn_q_norm"`
	Key         *nn.Linear  `gguf:"attn_k"`
	KeyNorm     *nn.RMSNorm `gguf:"attn_k_norm"`
	Value       *nn.Linear  `gguf:"attn_v"`
	Output      *nn.Linear  `gguf:"attn_output"`
	RopeFactors ml.Tensor   `gguf:"rope_freqs.weight"` // proportional RoPE freq_factors
}

// Forward computes self-attention for one layer. When sharedKV is true, K/V
// are left nil so nn.Attention reads them from the (donor) cache.
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positions ml.Tensor, cache kvcache.Cache, sharedKV bool, opts *TextOptions) ml.Tensor {
	batchSize := hiddenState.Dim(1)
	hd := opts.headDimForLayer(layer)
	kvHeads := opts.kvHeadsForLayer(layer)
	ropeBase, ropeDims := opts.ropeForLayer(layer)

	q := sa.Query.Forward(ctx, hiddenState)
	q = q.Reshape(ctx, hd, opts.numHeads, batchSize)
	q = sa.QueryNorm.Forward(ctx, q, opts.eps)

	var k, v ml.Tensor
	if !sharedKV {
		k = sa.Key.Forward(ctx, hiddenState)
		k = k.Reshape(ctx, hd, kvHeads, batchSize)

		if sa.Value != nil {
			v = sa.Value.Forward(ctx, hiddenState)
			v = v.Reshape(ctx, hd, kvHeads, batchSize)
		} else {
			// K=V: use raw K projection (before K norm) as V.
			v = k
		}

		// Order matters: v must capture the pre-norm K projection above.
		k = sa.KeyNorm.Forward(ctx, k, opts.eps)
		v = v.RMSNorm(ctx, nil, opts.eps) // V norm: unweighted RMSNorm
	}

	// RoPE with proportional freq_factors on global layers only.
	ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
	if sa.RopeFactors != nil && !opts.isLocal(layer) {
		ropeOpts = append(ropeOpts, rope.WithFactors(sa.RopeFactors))
	}
	q = nn.RoPE(ctx, q, positions, ropeDims, ropeBase, 1.0, ropeOpts...)
	if k != nil {
		k = nn.RoPE(ctx, k, positions, ropeDims, ropeBase, 1.0, ropeOpts...)
	}

	// Scale 1.0: query scaling is not applied here (matches the reference
	// implementation's convention — TODO confirm scaling lives elsewhere).
	attention := nn.Attention(ctx, q, k, v, 1.0, cache)

	attention = attention.Reshape(ctx, hd*opts.numHeads, batchSize)
	return sa.Output.Forward(ctx, attention)
}
|
||||
|
||||
type TextMLP struct {
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
// TextRouter implements the Gemma 4 MoE router.
// It does: RMSNorm(no weight) → scale(1/sqrt(hidden)) → multiply by scale param → linear → softmax → topk
type TextRouter struct {
	Proj  *nn.Linear `gguf:"ffn_gate_inp"`
	Scale ml.Tensor  `gguf:"ffn_gate_inp.scale"`
}

// Forward returns the full softmax routing weights over all experts and the
// indices of the top numExpertsUsed experts per token.
func (r *TextRouter) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) (routingWeights, selectedExperts ml.Tensor) {
	// RMSNorm without learned weight
	x := hiddenState.RMSNorm(ctx, nil, opts.eps)
	// Scale by 1/sqrt(hidden_size)
	x = x.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize)))
	// Multiply by learned scale parameter
	x = x.Mul(ctx, r.Scale)
	// Project to expert logits
	expertScores := r.Proj.Forward(ctx, x)
	// Softmax over experts
	routingWeights = expertScores.Softmax(ctx)
	// TopK expert selection
	selectedExperts = routingWeights.TopK(ctx, opts.numExpertsUsed)
	return routingWeights, selectedExperts
}
|
||||
|
||||
// TextMoEBlock implements the Gemma 4 sparse MoE: batched expert GELU-gated
// MLPs selected per token via MulmatID, weighted by renormalized router scores.
type TextMoEBlock struct {
	Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
	Up   *nn.LinearBatch `gguf:"ffn_up_exps"`
	Down *nn.LinearBatch `gguf:"ffn_down_exps"`
}

// Forward evaluates the selected experts on hiddenState and returns the
// routing-weighted sum of their outputs.
func (moe *TextMoEBlock) Forward(ctx ml.Context, hiddenState, routingWeights, selectedExperts ml.Tensor, opts *TextOptions) ml.Tensor {
	// Select routing weights for chosen experts and renormalize so the
	// selected weights sum to 1 per token.
	routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenState.Dim(1)).Rows(ctx, selectedExperts)
	routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenState.Dim(1))
	routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx))
	routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenState.Dim(1))

	// Broadcast the token states across the expert dimension.
	hiddenState = hiddenState.Reshape(ctx, hiddenState.Dim(0), 1, hiddenState.Dim(1))

	// Expert computation using LinearBatch (MulmatID selecting experts by index)
	gateOut := moe.Gate.Forward(ctx, hiddenState, selectedExperts)
	upOut := moe.Up.Forward(ctx, hiddenState, selectedExperts)
	hiddenState = gateOut.GELU(ctx, upOut)
	experts := moe.Down.Forward(ctx, hiddenState, selectedExperts)

	// Apply routing weights
	experts = experts.Mul(ctx, routingWeights)

	// Sum across experts via strided views (expert i is a view at offset
	// i*Stride(1)).
	nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
	for i := 1; i < opts.numExpertsUsed; i++ {
		nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
	}

	return nextStates
}
|
||||
|
||||
// TextLayer is one decoder block: attention with pre/post norms, a gated MLP
// with pre/post norms, an optional parallel MoE branch, optional PLE
// injection, and an optional per-layer output scalar.
type TextLayer struct {
	AttentionNorm     *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention     *TextSelfAttention
	PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm,alt:attn_post_norm"`
	MLPNorm           *nn.RMSNorm `gguf:"ffn_norm,alt:ffn_pre_norm"`
	MLP               *TextMLP
	PostMLPNorm       *nn.RMSNorm `gguf:"post_ffw_norm,alt:ffn_post_norm"`

	// MoE (present only for models with enable_moe_block=true)
	Router       *TextRouter
	MoE          *TextMoEBlock
	MoENorm      *nn.RMSNorm `gguf:"pre_ffw_norm_2,alt:ffn_pre_norm_2"`
	PostMoENorm  *nn.RMSNorm `gguf:"post_ffw_norm_2,alt:ffn_post_norm_2"`
	PostMLPNorm1 *nn.RMSNorm `gguf:"post_ffw_norm_1,alt:ffn_post_norm_1"` // used instead of PostMLPNorm when MoE is present

	// PLE injection weights (present only on models with per-layer embeddings).
	PerLayerInputGate  *nn.Linear  `gguf:"inp_gate"`
	PerLayerProjection *nn.Linear  `gguf:"proj"`
	PostPerLayerNorm   *nn.RMSNorm `gguf:"post_norm"`
	LayerScalar        ml.Tensor   `gguf:"layer_scalar,alt:layer_output_scale.weight"`
}
|
||||
|
||||
// Forward runs one decoder block. The statement order (norm → sublayer →
// post-norm → residual add) follows Gemma's "sandwich" norm placement and must
// not be rearranged. When outputs is non-nil (final layer), the hidden state,
// residual, and PLE input are first reduced to just the requested rows.
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positions, perLayerInput, outputs ml.Tensor, cache kvcache.Cache, sharedKV bool, opts *TextOptions) ml.Tensor {
	residual := hiddenState

	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positions, cache, sharedKV, opts)
	hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)

	// Final layer: keep only the rows whose logits the caller asked for.
	if outputs != nil {
		hiddenState = hiddenState.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
		if perLayerInput != nil {
			perLayerInput = perLayerInput.Rows(ctx, outputs)
		}
	}

	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	// MLP (+ optional MoE in parallel)
	if l.Router != nil && l.MoE != nil && l.MoE.Gate != nil && l.MoE.Gate.Weight != nil {
		// MoE layers: run MLP and MoE in parallel, sum results
		mlpState := l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
		mlpState = l.MLP.Forward(ctx, mlpState)
		mlpState = l.PostMLPNorm1.Forward(ctx, mlpState, opts.eps)

		// Router sees the pre-norm hidden state; the MoE branch gets its own norm.
		routingWeights, selectedExperts := l.Router.Forward(ctx, hiddenState, opts)
		moeState := l.MoENorm.Forward(ctx, hiddenState, opts.eps)
		moeState = l.MoE.Forward(ctx, moeState, routingWeights, selectedExperts, opts)
		moeState = l.PostMoENorm.Forward(ctx, moeState, opts.eps)

		// Combine MLP + MoE, apply outer post-FFN norm, then add residual
		combined := mlpState.Add(ctx, moeState)
		combined = l.PostMLPNorm.Forward(ctx, combined, opts.eps)
		hiddenState = combined.Add(ctx, residual)
	} else {
		// Dense layers: MLP only
		hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
		hiddenState = l.MLP.Forward(ctx, hiddenState)
		hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
		hiddenState = hiddenState.Add(ctx, residual)
	}

	// PLE injection (after MLP residual): gate the per-layer embedding with a
	// GELU of the gated hidden state, project, normalize, and add.
	if perLayerInput != nil && l.PerLayerInputGate != nil {
		pleState := l.PerLayerInputGate.Forward(ctx, hiddenState)
		pleState = pleState.GELU(ctx, perLayerInput)
		pleState = l.PerLayerProjection.Forward(ctx, pleState)
		pleState = l.PostPerLayerNorm.Forward(ctx, pleState, opts.eps)
		hiddenState = hiddenState.Add(ctx, pleState)
	}

	// Layer scalar applied at end of layer (per the tensor's presence —
	// stated as full-attention layers only by the original author).
	if l.LayerScalar != nil {
		hiddenState = hiddenState.Mul(ctx, l.LayerScalar)
	}

	return hiddenState
}
|
||||
304
model/models/gemma4/model_vision.go
Normal file
304
model/models/gemma4/model_vision.go
Normal file
|
|
@ -0,0 +1,304 @@
|
|||
package gemma4
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
)
|
||||
|
||||
const batchSize = 1
|
||||
|
||||
// ClippableLinear is a linear layer with optional input/output clamping.
// Required by Gemma4 vision encoder for numerical stability with F16 weights.
// Clamp values are populated by VisionModel.InitClamp from the packed v.clamp_data tensor.
type ClippableLinear struct {
	Weight ml.Tensor `gguf:"weight"`

	// Clamp bounds; a max of 0 disables the corresponding clamp (see Forward).
	inMin, inMax, outMin, outMax float32
}

// Forward applies optional input clamping, the matmul, then optional output
// clamping.
//
// NOTE(review): the "max != 0" sentinel means a legitimate clamp range with
// max exactly 0 would be silently skipped — confirm the converter never emits
// such a range.
func (l *ClippableLinear) Forward(ctx ml.Context, x ml.Tensor) ml.Tensor {
	if l.inMax != 0 {
		x = x.Clamp(ctx, l.inMin, l.inMax)
	}
	out := l.Weight.Mulmat(ctx, x)
	if l.outMax != 0 {
		out = out.Clamp(ctx, l.outMin, l.outMax)
	}
	return out
}
|
||||
|
||||
// InitClamp distributes packed clamp values from v.clamp_data to ClippableLinear structs.
// Layout: numLayers × 7 linears (q,k,v,out,gate,up,down) × 4 floats (inMin,inMax,outMin,outMax)
// then 4 floats for the projector.
//
// NOTE(review): the clampInitDone flag is a plain bool with no
// synchronization — concurrent EncodeMultimodal calls could race here;
// consider sync.Once if encodes can run in parallel.
func (m *VisionModel) InitClamp(proj *MultiModalProjector) {
	if m.clampInitDone || m.ClampData == nil {
		return
	}
	m.clampInitDone = true

	// Read all clamp values from packed F32 tensor
	data := m.ClampData.BackendGet()
	if len(data) == 0 {
		return
	}

	// Distribute to layer linears: 7 per layer × 4 values each.
	// Order here must match the converter's packing order.
	linears := func(l *VisionEncoderLayer) []*ClippableLinear {
		return []*ClippableLinear{
			l.SelfAttention.Query, l.SelfAttention.Key, l.SelfAttention.Value,
			l.SelfAttention.Output, l.MLP.Gate, l.MLP.Up, l.MLP.Down,
		}
	}
	for i := range m.Layers {
		for li, cl := range linears(&m.Layers[i]) {
			idx := (i*7 + li) * 4
			// Bounds check keeps a short/truncated tensor from panicking.
			if idx+3 < len(data) {
				cl.inMin = data[idx]
				cl.inMax = data[idx+1]
				cl.outMin = data[idx+2]
				cl.outMax = data[idx+3]
			}
		}
	}

	// Projector clamp values (last 4 floats)
	if proj != nil {
		projIdx := len(m.Layers) * 7 * 4
		if projIdx+3 < len(data) {
			proj.Projection.inMin = data[projIdx]
			proj.Projection.inMax = data[projIdx+1]
			proj.Projection.outMin = data[projIdx+2]
			proj.Projection.outMax = data[projIdx+3]
		}
	}
}
|
||||
|
||||
// VisionSelfAttention is the multi-head self-attention block of a vision
// encoder layer, with clippable projections, Q/K/V normalization, and 2D RoPE.
type VisionSelfAttention struct {
	Query     *ClippableLinear `gguf:"attn_q"`
	Key       *ClippableLinear `gguf:"attn_k"`
	Value     *ClippableLinear `gguf:"attn_v"`
	QueryNorm *nn.RMSNorm      `gguf:"attn_q_norm"`
	KeyNorm   *nn.RMSNorm      `gguf:"attn_k_norm"`
	Output    *ClippableLinear `gguf:"attn_out"`
}

// Forward runs self-attention over the patch sequence. posX and posY carry the
// per-patch x/y grid coordinates used by the 2D RoPE below.
//
// NOTE(review): attnMask is accepted but never used — nn.Attention is called
// with a nil mask. Confirm whether masking is intentionally unsupported here.
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState, posX, posY, attnMask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
	numPatches := hiddenState.Dim(1)
	headDim := opts.hiddenSize / opts.numHeads

	query := sa.Query.Forward(ctx, hiddenState)
	key := sa.Key.Forward(ctx, hiddenState)
	value := sa.Value.Forward(ctx, hiddenState)

	// Split the hidden dimension into heads: [headDim, numHeads, numPatches, 1].
	query = query.Reshape(ctx, headDim, opts.numHeads, numPatches, batchSize)
	key = key.Reshape(ctx, headDim, opts.numHeads, numPatches, batchSize)
	value = value.Reshape(ctx, headDim, opts.numHeads, numPatches, batchSize)

	// Q/K norms (Gemma-style: x * (1 + weight) / rms(x))
	query = sa.QueryNorm.Forward(ctx, query, opts.eps)
	key = sa.KeyNorm.Forward(ctx, key, opts.eps)

	// V norm (RMSNorm without learned weights)
	value = value.RMSNorm(ctx, nil, opts.eps)

	// 2D RoPE: split head dim in half, apply NeoX RoPE with x positions to first half,
	// y positions to second half, then concatenate.
	halfDim := headDim / 2
	ropeOpts := rope.WithTypeNeoX()

	// First half of each head rotates with the x coordinate.
	qFirst := query.View(ctx, 0, halfDim, query.Stride(1), opts.numHeads, query.Stride(2), numPatches)
	qFirst = nn.RoPE(ctx, qFirst, posX, halfDim, opts.ropeTheta, 1.0, ropeOpts)

	kFirst := key.View(ctx, 0, halfDim, key.Stride(1), opts.numHeads, key.Stride(2), numPatches)
	kFirst = nn.RoPE(ctx, kFirst, posX, halfDim, opts.ropeTheta, 1.0, ropeOpts)

	// Second half (byte offset halfDim elements into dim 0) rotates with y.
	halfOffset := halfDim * query.Stride(0)
	qSecond := query.View(ctx, halfOffset, halfDim, query.Stride(1), opts.numHeads, query.Stride(2), numPatches)
	qSecond = nn.RoPE(ctx, qSecond, posY, halfDim, opts.ropeTheta, 1.0, ropeOpts)

	halfOffsetK := halfDim * key.Stride(0)
	kSecond := key.View(ctx, halfOffsetK, halfDim, key.Stride(1), opts.numHeads, key.Stride(2), numPatches)
	kSecond = nn.RoPE(ctx, kSecond, posY, halfDim, opts.ropeTheta, 1.0, ropeOpts)

	query = qFirst.Concat(ctx, qSecond, 0)
	key = kFirst.Concat(ctx, kSecond, 0)

	// Use flash attention for numerical stability (handles large attention scores
	// from unclamped RMSNorm weights, e.g. 26B has addOne weights up to 19.5).
	// NOTE(review): scale is 1.0 rather than 1/sqrt(headDim) — presumably the
	// Q/K RMSNorms absorb the scaling; confirm against the reference model.
	attention := nn.Attention(ctx, query, key, value, 1.0, nil)
	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)

	return sa.Output.Forward(ctx, attention)
}
|
||||
|
||||
type VisionMLP struct {
|
||||
Gate *ClippableLinear `gguf:"ffn_gate"`
|
||||
Up *ClippableLinear `gguf:"ffn_up"`
|
||||
Down *ClippableLinear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
|
||||
gate := mlp.Gate.Forward(ctx, hiddenState)
|
||||
up := mlp.Up.Forward(ctx, hiddenState)
|
||||
hiddenState = gate.QuickGELU(ctx, up)
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
// VisionEncoderLayer is one transformer block of the vision encoder:
// pre/post-normed self-attention and MLP sublayers, each with a residual
// connection, plus an optional per-layer output scale.
type VisionEncoderLayer struct {
	AttentionNorm     *nn.RMSNorm `gguf:"ln1"`
	SelfAttention     *VisionSelfAttention
	PostAttentionNorm *nn.RMSNorm `gguf:"attn_post_norm"`

	FFNNorm     *nn.RMSNorm `gguf:"ln2"`
	MLP         *VisionMLP
	PostFFNNorm *nn.RMSNorm `gguf:"ffn_post_norm"`

	// Optional learned scale applied to the layer output (nil when absent).
	LayerOutputScale ml.Tensor `gguf:"out_scale.weight"`
}

// Forward runs one encoder block over the patch sequence.
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState, posX, posY, attnMask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
	residual := hiddenState

	// Pre-attention norm -> self attention -> post-attention norm
	hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = e.SelfAttention.Forward(ctx, hiddenState, posX, posY, attnMask, opts)
	hiddenState = e.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)

	// Residual connection
	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	// Pre-FFN norm -> FFN -> post-FFN norm
	hiddenState = e.FFNNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = e.MLP.Forward(ctx, hiddenState)
	hiddenState = e.PostFFNNorm.Forward(ctx, hiddenState, opts.eps)

	// Residual connection
	hiddenState = hiddenState.Add(ctx, residual)

	// Per-layer output scale
	if e.LayerOutputScale != nil {
		hiddenState = hiddenState.Mul(ctx, e.LayerOutputScale)
	}

	return hiddenState
}
|
||||
|
||||
// VisionModelOptions carries hyperparameters shared by all vision layers,
// read from GGUF metadata in newVisionModel.
type VisionModelOptions struct {
	hiddenSize int     // embedding width of the vision tower
	numHeads   int     // attention heads per layer
	patchSize  int     // pixels per patch edge
	nMerge     int     // pooling factor applied after the encoder
	eps        float32 // RMSNorm epsilon
	ropeTheta  float32 // base frequency for the 2D RoPE
}

// VisionModel is the Gemma4 vision encoder: Conv2D patch embedding, learned
// 2D position embeddings, a stack of encoder layers, and the packed clamp
// data consumed by InitClamp.
type VisionModel struct {
	PatchEmbedding    *nn.Conv2D `gguf:"patch_embd"`
	PositionEmbedding ml.Tensor  `gguf:"position_embd.weight"`
	ClampData         ml.Tensor  `gguf:"clamp_data"`

	Layers []VisionEncoderLayer `gguf:"blk"`

	*VisionModelOptions
	// clampInitDone makes InitClamp idempotent (not goroutine-safe).
	clampInitDone bool
}
|
||||
|
||||
// Forward embeds pixelValues into patch features, adds 2D position
// embeddings, and runs the encoder stack. The result is the raw encoder
// output; pooling and projection happen in visionPoolAndProject.
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, numPatchesX, numPatchesY int) ml.Tensor {
	numPatches := numPatchesX * numPatchesY

	// Patch embedding via Conv2D
	hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
	hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
	hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)

	// Conv2D with F16 weights produces F16 output via im2col; cast to F32 for encoder precision
	hiddenState = hiddenState.Cast(ctx, ml.DTypeF32)

	// 2D positional embeddings from 3D tensor [nEmbd, maxPos, 2]:
	// slice the tensor into an x table and a y table along the last dim.
	posSize := m.PositionEmbedding.Dim(1)
	nb1 := m.PositionEmbedding.Stride(1)
	tblX := m.PositionEmbedding.View(ctx, 0, m.hiddenSize, nb1, posSize)
	tblY := m.PositionEmbedding.View(ctx, posSize*nb1, m.hiddenSize, nb1, posSize)

	// Position indices for patches (row-major over the patch grid).
	posXData := make([]int32, numPatches)
	posYData := make([]int32, numPatches)
	for i := range numPatches {
		posXData[i] = int32(i % numPatchesX)
		posYData[i] = int32(i / numPatchesX)
	}

	posXEmb := ctx.Input().FromInts(posXData, numPatches)
	posYEmb := ctx.Input().FromInts(posYData, numPatches)

	hiddenState = hiddenState.Add(ctx, tblX.Rows(ctx, posXEmb))
	hiddenState = hiddenState.Add(ctx, tblY.Rows(ctx, posYEmb))

	// No attention mask — all positions are real patches
	var attnMask ml.Tensor

	// RoPE positions.
	// NOTE(review): these duplicate posXEmb/posYEmb (same data uploaded
	// twice) — presumably deliberate to keep the Rows and RoPE inputs as
	// distinct graph nodes; confirm whether the tensors could be reused.
	posXRope := ctx.Input().FromInts(posXData, numPatches)
	posYRope := ctx.Input().FromInts(posYData, numPatches)

	// Vision transformer layers
	for i := range m.Layers {
		hiddenState = m.Layers[i].Forward(ctx, hiddenState, posXRope, posYRope, attnMask, m.VisionModelOptions)
	}

	return hiddenState
}
|
||||
|
||||
func newVisionModel(c fs.Config) *VisionModel {
|
||||
return &VisionModel{
|
||||
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
|
||||
VisionModelOptions: &VisionModelOptions{
|
||||
hiddenSize: int(c.Uint("vision.embedding_length")),
|
||||
numHeads: int(c.Uint("vision.attention.head_count")),
|
||||
patchSize: int(c.Uint("vision.patch_size", 16)),
|
||||
nMerge: int(c.Uint("vision.projector.scale_factor", 3)),
|
||||
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-6),
|
||||
ropeTheta: 100.0,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// visionTokenCount returns the number of image tokens an image of the given
// size yields after patching and pooling: each axis is divided by patchSize
// (patches) and then by nMerge (pooling), with integer truncation.
func visionTokenCount(imageWidth, imageHeight, patchSize, nMerge int) int {
	tokensAlong := func(pixels int) int {
		return pixels / patchSize / nMerge
	}
	return tokensAlong(imageWidth) * tokensAlong(imageHeight)
}
|
||||
|
||||
// visionPoolAndProject downsamples the encoder output with average pooling
// (merge factor opts.nMerge per axis) and projects it into the text
// embedding space via the multi-modal projector.
func visionPoolAndProject(ctx ml.Context, hiddenState ml.Tensor, numPatchesX, numPatchesY int, opts *VisionModelOptions, proj *MultiModalProjector) ml.Tensor {
	hiddenSize := opts.hiddenSize

	// Reshape from [hiddenSize, numPatches] to spatial layout for pooling
	hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	hiddenState = hiddenState.Reshape(ctx, numPatchesX, numPatchesY, hiddenSize)

	// AvgPool2D with kernel=stride=nMerge
	hiddenState = hiddenState.AvgPool2D(ctx, opts.nMerge, opts.nMerge, 0)

	// Reshape back to [hiddenSize, numMergedPatches]
	mergedX := numPatchesX / opts.nMerge
	mergedY := numPatchesY / opts.nMerge
	hiddenState = hiddenState.Reshape(ctx, mergedX*mergedY, hiddenSize)
	hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)

	// Ensure F32 for the projection Mulmat. The Metal mul_mm kernel for F16×F32
	// casts F32 activations to F16 in shared memory, so values must stay within
	// F16 range (≤65504). The sqrt(hiddenSize) scaling from the HF reference is
	// omitted because it's normalized out by the unweighted RMSNorm that follows
	// the projection: RMSNorm(√d·x) = x/rms(x) = RMSNorm(x).
	hiddenState = hiddenState.Cast(ctx, ml.DTypeF32)

	// Project to text embedding dimension
	hiddenState = proj.Forward(ctx, hiddenState, opts.eps)

	return hiddenState
}
|
||||
103
model/models/gemma4/process_image.go
Normal file
103
model/models/gemma4/process_image.go
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
package gemma4
|
||||
|
||||
import (
|
||||
"image"
|
||||
"math"
|
||||
|
||||
"golang.org/x/image/draw"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
)
|
||||
|
||||
// ImageProcessor prepares raw images for the vision encoder: aspect-ratio
// preserving resize aligned to patch boundaries, plus [-1, 1] normalization.
type ImageProcessor struct {
	patchSize   int // pixels per patch edge
	numChannels int // color channels expected by the encoder
	nMerge      int // pooling factor; resize aligns to patchSize*nMerge
	minPixels   int // lower pixel budget derived from the min token count
	maxPixels   int // upper pixel budget derived from the max token count
}
|
||||
|
||||
func newImageProcessor(c fs.Config) ImageProcessor {
|
||||
patchSize := int(c.Uint("vision.patch_size", 16))
|
||||
nMerge := int(c.Uint("vision.projector.scale_factor", 3))
|
||||
numChannels := int(c.Uint("vision.num_channels", 3))
|
||||
|
||||
// Token limits from reference: min=40, max=280 output tokens after pooling.
|
||||
// Convert to pixel counts: tokens * nMerge^2 * patchSize^2
|
||||
minTokens := 40
|
||||
maxTokens := 280
|
||||
patchArea := patchSize * patchSize * nMerge * nMerge
|
||||
minPixels := minTokens * patchArea
|
||||
maxPixels := maxTokens * patchArea
|
||||
|
||||
return ImageProcessor{
|
||||
patchSize: patchSize,
|
||||
numChannels: numChannels,
|
||||
nMerge: nMerge,
|
||||
minPixels: minPixels,
|
||||
maxPixels: maxPixels,
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessImage resizes an image preserving aspect ratio, aligning dimensions
// to (patchSize * nMerge) boundaries, and normalizes pixels to [-1, 1].
// Returns the float32 pixel data (channel-first R,G,B planes) and the actual
// output dimensions. The error result is currently always nil; it is kept
// for interface stability.
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, int, int, error) {
	// Compute target size preserving aspect ratio
	alignSize := p.patchSize * p.nMerge
	targetW, targetH := p.smartResize(img.Bounds().Dx(), img.Bounds().Dy(), alignSize)

	// Resize directly without alpha compositing, matching MLX reference.
	// NOTE(review): draw.Over composites onto a transparent RGBA canvas;
	// draw.Src would be the usual choice for a plain resize — confirm the
	// Over/Src distinction is intentional for images with alpha.
	dst := image.NewRGBA(image.Rect(0, 0, targetW, targetH))
	draw.BiLinear.Scale(dst, dst.Bounds(), img, img.Bounds(), draw.Over, nil)

	// Normalize to [-1, 1] using mean=0.5, std=0.5: (pixel/255 - 0.5) / 0.5 = 2*pixel/255 - 1
	data := p.pack(dst)
	return data, targetW, targetH, nil
}
|
||||
|
||||
// smartResize computes target dimensions that preserve aspect ratio and
|
||||
// align to alignSize boundaries. It scales the image to fill the maximum
|
||||
// patch budget (maxPixels), matching the MLX reference.
|
||||
func (p *ImageProcessor) smartResize(origW, origH, alignSize int) (int, int) {
|
||||
totalPx := origW * origH
|
||||
|
||||
var targetW, targetH int
|
||||
if p.maxPixels > 0 && totalPx > 0 {
|
||||
factor := math.Sqrt(float64(p.maxPixels) / float64(totalPx))
|
||||
targetH = max(alignSize, int(math.Floor(factor*float64(origH)/float64(alignSize)))*alignSize)
|
||||
targetW = max(alignSize, int(math.Floor(factor*float64(origW)/float64(alignSize)))*alignSize)
|
||||
} else {
|
||||
targetH = max(alignSize, (origH/alignSize)*alignSize)
|
||||
targetW = max(alignSize, (origW/alignSize)*alignSize)
|
||||
}
|
||||
|
||||
return targetW, targetH
|
||||
}
|
||||
|
||||
// pack extracts RGB values from an image and normalizes to [-1, 1].
|
||||
// Returns channel-first layout: [R..., G..., B...].
|
||||
func (p *ImageProcessor) pack(img image.Image) []float32 {
|
||||
bounds := img.Bounds()
|
||||
w := bounds.Dx()
|
||||
h := bounds.Dy()
|
||||
size := w * h
|
||||
|
||||
pixelVals := make([]float32, 3*size)
|
||||
rOff, gOff, bOff := 0, size, 2*size
|
||||
|
||||
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
|
||||
for x := bounds.Min.X; x < bounds.Max.X; x++ {
|
||||
c := img.At(x, y)
|
||||
r, g, b, _ := c.RGBA()
|
||||
idx := (y-bounds.Min.Y)*w + (x - bounds.Min.X)
|
||||
|
||||
// Normalize [0, 255] -> [-1, 1]: 2 * (val/255) - 1
|
||||
pixelVals[rOff+idx] = float32(r>>8)/255.0*2.0 - 1.0
|
||||
pixelVals[gOff+idx] = float32(g>>8)/255.0*2.0 - 1.0
|
||||
pixelVals[bOff+idx] = float32(b>>8)/255.0*2.0 - 1.0
|
||||
}
|
||||
}
|
||||
|
||||
return pixelVals
|
||||
}
|
||||
102
model/models/gemma4/tokenizer_compare_test.go
Normal file
102
model/models/gemma4/tokenizer_compare_test.go
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
package gemma4
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
// TestTokenizerMatchesHF compares our tokenizer output against HuggingFace reference tokens.
|
||||
func TestTokenizerMatchesHF(t *testing.T) {
|
||||
modelPath := os.Getenv("GEMMA4_MODEL_PATH")
|
||||
if modelPath == "" {
|
||||
t.Skip("set GEMMA4_MODEL_PATH to a gemma4 GGUF file")
|
||||
}
|
||||
|
||||
m, err := model.New(modelPath, ml.BackendParams{AllocMemory: true})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load model: %v", err)
|
||||
}
|
||||
defer m.Backend().Close()
|
||||
|
||||
tok := m.(tokenizer.Tokenizer)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected []int32
|
||||
}{
|
||||
{
|
||||
name: "simple",
|
||||
input: "Hello, world!",
|
||||
expected: []int32{9259, 236764, 1902, 236888},
|
||||
},
|
||||
{
|
||||
name: "special_tokens",
|
||||
input: "<|turn>user\nWhat is 2+2?<turn|>\n<|turn>model\n",
|
||||
expected: []int32{105, 2364, 107, 3689, 563, 236743, 236778, 236862, 236778, 236881, 106, 107, 105, 4368, 107},
|
||||
},
|
||||
{
|
||||
name: "tool_declaration",
|
||||
input: "<|tool>declaration:bash{description:<|\"|>Run a command<|\"|>}<tool|>",
|
||||
expected: []int32{46, 163688, 236787, 42422, 236782, 7777, 236787, 52, 7306, 496, 4991, 52, 236783, 47},
|
||||
},
|
||||
{
|
||||
name: "tool_call",
|
||||
input: "<|tool_call>call:bash{command:<|\"|>ls -la<|\"|>}<tool_call|>",
|
||||
expected: []int32{48, 6639, 236787, 42422, 236782, 7674, 236787, 52, 5629, 753, 2149, 52, 236783, 49},
|
||||
},
|
||||
{
|
||||
name: "thinking",
|
||||
input: "<|channel>thought\nLet me think about this...<channel|>The answer is 42.",
|
||||
expected: []int32{100, 45518, 107, 6481, 786, 1751, 1003, 672, 1390, 101, 818, 3890, 563, 236743, 236812, 236778, 236761},
|
||||
},
|
||||
{
|
||||
name: "code",
|
||||
input: "func main() { fmt.Println(\"hello\") }",
|
||||
expected: []int32{6823, 1689, 825, 642, 22766, 236761, 29006, 885, 23391, 1373, 682},
|
||||
},
|
||||
{
|
||||
name: "numbers",
|
||||
input: "The answer is 42, not 43.5 or -1",
|
||||
expected: []int32{818, 3890, 563, 236743, 236812, 236778, 236764, 711, 236743, 236812, 236800, 236761, 236810, 653, 753, 236770},
|
||||
},
|
||||
{
|
||||
name: "mixed_chat_with_tools",
|
||||
input: "<|turn>system\nYou are a helpful assistant.\n<|tool>declaration:get_weather{description:<|\"|>Get weather<|\"|>,parameters:{properties:{city:{type:<|\"|>STRING<|\"|>}},type:<|\"|>OBJECT<|\"|>}}<tool|><turn|>\n<|turn>user\nWhat's the weather in Paris?<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
expected: []int32{105, 9731, 107, 3048, 659, 496, 11045, 16326, 236761, 107, 46, 163688, 236787, 828, 236779, 19323, 236782, 7777, 236787, 52, 3407, 7606, 52, 236764, 19031, 29616, 15921, 29616, 13319, 29616, 2084, 236787, 52, 35410, 52, 5237, 2084, 236787, 52, 60688, 52, 1807, 47, 106, 107, 105, 2364, 107, 3689, 236789, 236751, 506, 7606, 528, 9079, 236881, 106, 107, 105, 4368, 107, 100, 45518, 107, 101},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tokens, err := tok.Encode(tt.input, false) // no BOS
|
||||
if err != nil {
|
||||
t.Fatalf("encode error: %v", err)
|
||||
}
|
||||
|
||||
if len(tokens) != len(tt.expected) {
|
||||
t.Errorf("token count mismatch: got %d, want %d", len(tokens), len(tt.expected))
|
||||
t.Logf("got: %v", tokens)
|
||||
t.Logf("want: %v", tt.expected)
|
||||
return
|
||||
}
|
||||
|
||||
mismatches := 0
|
||||
for i := range tokens {
|
||||
if tokens[i] != tt.expected[i] {
|
||||
mismatches++
|
||||
if mismatches <= 5 {
|
||||
t.Errorf("mismatch at [%d]: got %d, want %d", i, tokens[i], tt.expected[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
if mismatches > 5 {
|
||||
t.Errorf("... and %d more mismatches", mismatches-5)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -7,6 +7,7 @@ import (
|
|||
_ "github.com/ollama/ollama/model/models/gemma2"
|
||||
_ "github.com/ollama/ollama/model/models/gemma3"
|
||||
_ "github.com/ollama/ollama/model/models/gemma3n"
|
||||
_ "github.com/ollama/ollama/model/models/gemma4"
|
||||
_ "github.com/ollama/ollama/model/models/glm4moelite"
|
||||
_ "github.com/ollama/ollama/model/models/glmocr"
|
||||
_ "github.com/ollama/ollama/model/models/gptoss"
|
||||
|
|
|
|||
353
model/renderers/gemma4.go
Normal file
353
model/renderers/gemma4.go
Normal file
|
|
@ -0,0 +1,353 @@
|
|||
package renderers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// Gemma4Renderer renders prompts using Gemma 4's chat format with
// <|turn>/<turn|> markers, <|"|> string delimiters, and <|tool>/
// <|tool_call>/<|tool_response> tags for function calling.
type Gemma4Renderer struct {
	// useImgTags controls whether [img-N] placeholders are emitted for
	// message images in renderContent.
	useImgTags bool
}

const (
	g4Q = `<|"|>` // Gemma 4 string delimiter
)
|
||||
|
||||
// Render builds the full Gemma 4 prompt from the conversation: an optional
// system turn (system message, tool declarations, thinking marker), the
// message turns, and a trailing generation prompt when the conversation does
// not already end in an open model turn.
func (r *Gemma4Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
	var sb strings.Builder
	imageOffset := 0

	// BOS token — Gemma 4 models have add_bos_token=false in their tokenizer
	// config, so the tokenizer does not auto-prepend BOS. We must emit it
	// explicitly in the rendered prompt, matching the HF chat template.
	sb.WriteString("<bos>")
	// Extract system message if present.
	var systemMessage string
	var loopMessages []api.Message
	if len(messages) > 0 && (messages[0].Role == "system" || messages[0].Role == "developer") {
		systemMessage = messages[0].Content
		loopMessages = messages[1:]
	} else {
		loopMessages = messages
	}

	// Emit system turn if there's a system message, tools, or thinking.
	hasThink := thinkValue != nil && thinkValue.Bool()
	if systemMessage != "" || len(tools) > 0 || hasThink {
		sb.WriteString("<|turn>system\n")
		if hasThink {
			sb.WriteString("<|think|>")
		}
		if systemMessage != "" {
			sb.WriteString(systemMessage)
		}
		for _, tool := range tools {
			sb.WriteString(r.renderToolDeclaration(tool))
		}
		sb.WriteString("<turn|>\n")
	}

	// inModelTurn tracks whether we're inside an open <|turn>model block.
	// Tool responses are appended inline (no separate turn), and the model
	// turn is only closed when we see a non-tool message or reach the end.
	inModelTurn := false

	for i, message := range loopMessages {
		switch message.Role {
		case "user":
			if inModelTurn {
				// Check if the preceding content was a tool response (no <turn|>
				// between tool response and next user turn per HF reference).
				prevIsToolResponse := i > 0 && loopMessages[i-1].Role == "tool"
				if !prevIsToolResponse {
					sb.WriteString("<turn|>\n")
				}
				inModelTurn = false
			}
			sb.WriteString("<|turn>user\n")
			r.renderContent(&sb, message, &imageOffset)
			sb.WriteString("<turn|>\n")

		case "assistant":
			// Close any previous model turn before opening a new one.
			if inModelTurn {
				sb.WriteString("<turn|>\n")
			}
			sb.WriteString("<|turn>model\n")
			inModelTurn = true
			if message.Content != "" {
				sb.WriteString(message.Content)
			}
			for _, tc := range message.ToolCalls {
				sb.WriteString(r.formatToolCall(tc))
			}

		case "tool":
			// Tool responses are rendered inline in the preceding model turn,
			// matching the reference format from HuggingFace's chat template.
			// Format: <|tool_response>response:NAME{key:value,...}<tool_response|>
			toolName := r.findToolName(loopMessages, i)
			sb.WriteString("<|tool_response>response:" + toolName + "{")
			r.renderToolResponseContent(&sb, message.Content)
			sb.WriteString("}<tool_response|>")
			// Keep the model turn open — it will be closed when we see the
			// next non-tool message or the assistant adds content after the response.

		default:
			// Unknown roles are rendered as plain turns with their role name.
			if inModelTurn {
				sb.WriteString("<turn|>\n")
				inModelTurn = false
			}
			sb.WriteString("<|turn>" + message.Role + "\n")
			sb.WriteString(message.Content)
			sb.WriteString("<turn|>\n")
		}
	}

	// If the last message is not an open assistant turn, add the generation prompt.
	if !inModelTurn {
		sb.WriteString("<|turn>model\n")
	}

	return sb.String(), nil
}
|
||||
|
||||
// renderContent writes a message's content, interleaving [img-N] tags for images.
|
||||
func (r *Gemma4Renderer) renderContent(sb *strings.Builder, msg api.Message, imageOffset *int) {
|
||||
if len(msg.Images) > 0 && r.useImgTags {
|
||||
for range msg.Images {
|
||||
sb.WriteString(fmt.Sprintf("[img-%d]", *imageOffset))
|
||||
*imageOffset++
|
||||
}
|
||||
}
|
||||
sb.WriteString(msg.Content)
|
||||
}
|
||||
|
||||
// renderToolDeclaration renders one tool definition in Gemma 4's structured
// format: <|tool>declaration:NAME{description:...,parameters:{...}}<tool|>.
// Fields are emitted in the fixed order description, properties, required,
// type, with commas inserted only between present fields.
func (r *Gemma4Renderer) renderToolDeclaration(tool api.Tool) string {
	var sb strings.Builder
	fn := tool.Function

	sb.WriteString("<|tool>declaration:" + fn.Name + "{")
	sb.WriteString("description:" + g4Q + fn.Description + g4Q)

	if fn.Parameters.Properties != nil || fn.Parameters.Type != "" {
		sb.WriteString(",parameters:{")

		// needsComma tracks whether an earlier field was written so each
		// subsequent field knows whether to prepend a separator.
		needsComma := false

		if fn.Parameters.Properties != nil && fn.Parameters.Properties.Len() > 0 {
			sb.WriteString("properties:{")
			r.writeProperties(&sb, fn.Parameters.Properties)
			sb.WriteString("}")
			needsComma = true
		}

		if len(fn.Parameters.Required) > 0 {
			if needsComma {
				sb.WriteString(",")
			}
			sb.WriteString("required:[")
			for i, req := range fn.Parameters.Required {
				if i > 0 {
					sb.WriteString(",")
				}
				sb.WriteString(g4Q + req + g4Q)
			}
			sb.WriteString("]")
			needsComma = true
		}

		if fn.Parameters.Type != "" {
			if needsComma {
				sb.WriteString(",")
			}
			// Types are upper-cased per the reference format (e.g. OBJECT).
			sb.WriteString("type:" + g4Q + strings.ToUpper(fn.Parameters.Type) + g4Q)
		}

		sb.WriteString("}")
	}

	sb.WriteString("}<tool|>")
	return sb.String()
}
|
||||
|
||||
// writeProperties renders a tool's parameter properties in sorted key order,
// each as name:{description:...,enum:[...],type:...} with only the present
// sub-fields and comma separators between them.
func (r *Gemma4Renderer) writeProperties(sb *strings.Builder, props *api.ToolPropertiesMap) {
	// Sort keys for deterministic output regardless of map iteration order.
	keys := make([]string, 0, props.Len())
	for k := range props.All() {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	first := true
	for _, name := range keys {
		prop, _ := props.Get(name)
		if !first {
			sb.WriteString(",")
		}
		first = false

		sb.WriteString(name + ":{")
		if prop.Description != "" {
			sb.WriteString("description:" + g4Q + prop.Description + g4Q)
		}
		if len(prop.Enum) > 0 {
			if prop.Description != "" {
				sb.WriteString(",")
			}
			sb.WriteString("enum:[")
			for j, e := range prop.Enum {
				if j > 0 {
					sb.WriteString(",")
				}
				sb.WriteString(g4Q + fmt.Sprintf("%v", e) + g4Q)
			}
			sb.WriteString("]")
		}
		if len(prop.Type) > 0 {
			if prop.Description != "" || len(prop.Enum) > 0 {
				sb.WriteString(",")
			}
			// Only the first declared type is emitted, upper-cased.
			sb.WriteString("type:" + g4Q + strings.ToUpper(prop.Type[0]) + g4Q)
		}
		sb.WriteString("}")
	}
}
|
||||
|
||||
// formatToolCall renders an assistant tool call as
// <|tool_call>call:NAME{key:value,...}<tool_call|>, with arguments in
// sorted key order for deterministic output.
func (r *Gemma4Renderer) formatToolCall(tc api.ToolCall) string {
	var sb strings.Builder
	sb.WriteString("<|tool_call>call:" + tc.Function.Name + "{")

	keys := make([]string, 0, tc.Function.Arguments.Len())
	for k := range tc.Function.Arguments.All() {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	first := true
	for _, key := range keys {
		value, _ := tc.Function.Arguments.Get(key)
		if !first {
			sb.WriteString(",")
		}
		first = false
		sb.WriteString(key + ":" + r.formatArgValue(value))
	}

	sb.WriteString("}<tool_call|>")
	return sb.String()
}
|
||||
|
||||
func (r *Gemma4Renderer) formatArgValue(value any) string {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return g4Q + v + g4Q
|
||||
case bool:
|
||||
if v {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
case float64:
|
||||
if v == float64(int64(v)) {
|
||||
return fmt.Sprintf("%d", int64(v))
|
||||
}
|
||||
return fmt.Sprintf("%v", v)
|
||||
case int, int64, int32:
|
||||
return fmt.Sprintf("%d", v)
|
||||
case map[string]any:
|
||||
return r.formatMapValue(v)
|
||||
case []any:
|
||||
return r.formatArrayValue(v)
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Gemma4Renderer) formatMapValue(m map[string]any) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("{")
|
||||
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
first := true
|
||||
for _, key := range keys {
|
||||
if !first {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
first = false
|
||||
sb.WriteString(key + ":" + r.formatArgValue(m[key]))
|
||||
}
|
||||
|
||||
sb.WriteString("}")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (r *Gemma4Renderer) formatArrayValue(arr []any) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("[")
|
||||
for i, item := range arr {
|
||||
if i > 0 {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
sb.WriteString(r.formatArgValue(item))
|
||||
}
|
||||
sb.WriteString("]")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// renderToolResponseContent renders tool response content in Gemma 4 format.
|
||||
// If the content is valid JSON, it renders each field as key:value pairs with
|
||||
// proper type formatting (strings get <|"|> delimiters, numbers/bools are bare).
|
||||
// If not valid JSON, wraps the entire content as a single "value" string.
|
||||
func (r *Gemma4Renderer) renderToolResponseContent(sb *strings.Builder, content string) {
|
||||
// Try to parse as JSON object.
|
||||
var obj map[string]any
|
||||
if err := json.Unmarshal([]byte(content), &obj); err == nil {
|
||||
keys := make([]string, 0, len(obj))
|
||||
for k := range obj {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
first := true
|
||||
for _, key := range keys {
|
||||
if !first {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
first = false
|
||||
sb.WriteString(key + ":" + r.formatArgValue(obj[key]))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Not JSON — wrap as a single string value.
|
||||
sb.WriteString("value:" + g4Q + content + g4Q)
|
||||
}
|
||||
|
||||
// findToolName walks backwards from tool message index to find the matching tool call name.
|
||||
func (r *Gemma4Renderer) findToolName(messages []api.Message, toolIdx int) string {
|
||||
for j := toolIdx - 1; j >= 0; j-- {
|
||||
if messages[j].Role == "assistant" && len(messages[j].ToolCalls) > 0 {
|
||||
toolOffset := 0
|
||||
for k := j + 1; k < toolIdx; k++ {
|
||||
if messages[k].Role == "tool" {
|
||||
toolOffset++
|
||||
}
|
||||
}
|
||||
if toolOffset < len(messages[j].ToolCalls) {
|
||||
return messages[j].ToolCalls[toolOffset].Function.Name
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
|
@ -509,6 +509,24 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
|
|||
config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(layer.GGML.KV().ParameterCount()))
|
||||
config.FileType = cmp.Or(config.FileType, layer.GGML.KV().FileType().String())
|
||||
config.ModelFamilies = append(config.ModelFamilies, layer.GGML.KV().Architecture())
|
||||
|
||||
// Auto-detect renderer, parser, and stop tokens from GGUF architecture.
|
||||
// TODO: abstract this into a registry/lookup table when multiple models
|
||||
// need architecture-based renderer/parser/stop defaults.
|
||||
if config.Renderer == "" || config.Parser == "" {
|
||||
arch := layer.GGML.KV().Architecture()
|
||||
switch arch {
|
||||
case "gemma4":
|
||||
config.Renderer = cmp.Or(config.Renderer, "gemma4")
|
||||
config.Parser = cmp.Or(config.Parser, "gemma4")
|
||||
if _, ok := r.Parameters["stop"]; !ok {
|
||||
if r.Parameters == nil {
|
||||
r.Parameters = make(map[string]any)
|
||||
}
|
||||
r.Parameters["stop"] = []string{"<turn|>"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
layers = append(layers, layer.Layer)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -153,7 +153,16 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType
|
|||
// MLA tensors need higher precision to avoid quality degradation
|
||||
newType = fsggml.TensorTypeQ8_0
|
||||
} else if strings.Contains(name, "ffn_down") {
|
||||
iLayer := qs.iFfnDown
|
||||
// For MoE models, ffn_down.weight (dense) and ffn_down_exps.weight (expert) both
|
||||
// exist per layer and should get the same useMoreBits treatment. Dense sorts before
|
||||
// expert alphabetically, so dense increments the counter and expert uses counter-1.
|
||||
var iLayer int
|
||||
if strings.Contains(name, "_exps") {
|
||||
iLayer = max(0, qs.iFfnDown-1)
|
||||
} else {
|
||||
iLayer = qs.iFfnDown
|
||||
qs.iFfnDown++
|
||||
}
|
||||
n_layer := qs.nFfnDown
|
||||
if ftype == fsggml.FileTypeQ4_K_M {
|
||||
if useMoreBits(iLayer, n_layer) {
|
||||
|
|
@ -162,7 +171,6 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType
|
|||
} else if ftype == fsggml.FileTypeQ4_K_S && iLayer < n_layer/8 {
|
||||
newType = fsggml.TensorTypeQ5_K
|
||||
}
|
||||
qs.iFfnDown++
|
||||
} else if strings.Contains(name, "attn_output.weight") {
|
||||
if nExperts == 8 {
|
||||
if ftype == fsggml.FileTypeQ4_K_S || ftype == fsggml.FileTypeQ4_K_M {
|
||||
|
|
|
|||
Loading…
Reference in a new issue