This commit is contained in:
Michael Verrilli 2026-04-23 09:56:50 +08:00 committed by GitHub
commit 3fe1bd8a29
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 759 additions and 136 deletions

View file

@ -288,6 +288,7 @@ func (kv KV) OllamaEngineRequired() bool {
"mllama",
"nemotron_h", "nemotron_h_moe",
"nomic-bert",
"nomic-bert-moe",
"olmo3",
"qwen25vl",
"qwen3", "qwen3moe",

View file

@ -152,12 +152,15 @@ func New(c fs.Config) (model.Model, error) {
c.Uint("tokenizer.ggml.eos_token_id"),
)),
},
AddSpacePrefix: c.Bool("tokenizer.ggml.add_space_prefix", false),
}
var t tokenizer.Tokenizer
switch c.String("tokenizer.ggml.model", "bert") {
case "bert":
t = tokenizer.NewWordPiece(vocab, true)
case "t5":
t = tokenizer.NewSentencePiece(vocab)
default:
return nil, model.ErrUnsupportedTokenizer
}

View file

@ -0,0 +1,53 @@
package bert
import (
"testing"
"github.com/ollama/ollama/tokenizer"
)
// TestBertNewT5Tokenizer exercises the SentencePiece tokenizer configuration
// that bert models declare via tokenizer.ggml.model="t5" (as bge-m3 does).
//
// NOTE(review): this test calls tokenizer.NewSentencePiece directly; it never
// invokes bert's New(), so the "t5" branch of New's tokenizer switch is not
// actually covered here — confirm whether an fs.Config-driven test is feasible.
func TestBertNewT5Tokenizer(t *testing.T) {
	vocab := &tokenizer.Vocabulary{
		// ▁-prefixed word pieces first, then <s>/</s>, then single-char fallbacks.
		Values: []string{"▁hello", "▁world", "▁test", "<s>", "</s>", "h", "e", "l", "o", "w", "r", "d"},
		Scores: []float32{-1, -1, -1, 0, 0, -5, -5, -5, -5, -5, -5, -5},
		// Type 4 at the <s>/</s> slots — presumably the CONTROL token type;
		// TODO(review): confirm against the tokenizer package's constants.
		Types:          []int32{1, 1, 1, 4, 4, 1, 1, 1, 1, 1, 1, 1},
		BOS:            []int32{3},
		EOS:            []int32{4},
		AddBOS:         true,
		AddEOS:         true,
		AddSpacePrefix: true,
	}
	spm := tokenizer.NewSentencePiece(vocab)
	t.Run("encodes_without_error", func(t *testing.T) {
		ids, err := spm.Encode("hello world", true)
		if err != nil {
			t.Fatalf("Encode: %v", err)
		}
		if len(ids) == 0 {
			t.Error("got empty token list")
		}
		// With add_space_prefix=true and BOS/EOS: [<s>, ▁hello, ▁world, </s>]
		t.Logf("ids: %v", ids)
	})
	t.Run("add_space_prefix_prepends_whitespace_token", func(t *testing.T) {
		// "hello" with add_space_prefix=true should produce ▁hello token (id=0)
		ids, err := spm.Encode("hello", false)
		if err != nil {
			t.Fatal(err)
		}
		if len(ids) != 1 || ids[0] != 0 {
			t.Errorf("got %v, want [0] (▁hello)", ids)
		}
	})
	t.Run("is_sentence_piece_not_wordpiece", func(t *testing.T) {
		// Compile-time-only check that SentencePiece satisfies the Tokenizer
		// interface; this subtest performs no runtime assertions.
		var _ tokenizer.Tokenizer = spm
	})
}

View file

@ -196,30 +196,40 @@ func New(c fs.Config) (model.Model, error) {
}
}
vocab := &tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.cls_token_id"),
c.Uint("tokenizer.ggml.bos_token_id"),
)),
},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
EOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.separator_token_id"),
c.Uint("tokenizer.ggml.eos_token_id"),
)),
},
AddSpacePrefix: c.Bool("tokenizer.ggml.add_space_prefix", false),
}
var t tokenizer.Tokenizer
switch c.String("tokenizer.ggml.model", "bert") {
case "bert":
t = tokenizer.NewWordPiece(vocab, false)
case "t5":
t = tokenizer.NewSentencePiece(vocab)
default:
return nil, model.ErrUnsupportedTokenizer
}
return &Model{
Tokenizer: tokenizer.NewWordPiece(
&tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.cls_token_id"),
c.Uint("tokenizer.ggml.bos_token_id"),
)),
},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
EOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.separator_token_id"),
c.Uint("tokenizer.ggml.eos_token_id"),
)),
},
},
false,
),
Layers: layers,
Tokenizer: t,
Layers: layers,
Options: Options{
hiddenSize: hiddenSize,
numHeads: numHeads,

View file

@ -1,9 +1,9 @@
package tokenizer
import (
"container/heap"
"fmt"
"log/slog"
"math"
"strconv"
"strings"
@ -24,7 +24,8 @@ func (spm SentencePiece) Vocabulary() *Vocabulary {
}
func NewSentencePiece(vocab *Vocabulary) SentencePiece {
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
end := min(5, len(vocab.Values))
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:end], "scores", vocab.Scores[:end], "types", vocab.Types[:end])
counter := map[int]int{}
var maxTokenLen int
@ -53,6 +54,9 @@ func (spm SentencePiece) Is(id int32, special Special) bool {
}
func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
if spm.vocab.AddSpacePrefix {
s = " " + s
}
fragments := []fragment{{value: s}}
for _, special := range spm.vocab.SpecialVocabulary() {
id := spm.vocab.Encode(special)
@ -89,96 +93,7 @@ func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
text := strings.ReplaceAll(frag.value, " ", spmWhitespaceSep)
if id := spm.vocab.Encode(text); id >= 0 {
ids = append(ids, id)
continue
}
q := &queue{}
heap.Init(q)
runes := []rune(text)
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
}
pairwise := func(a, b int) *candidate {
if a < 0 || b >= len(runes) {
return nil
}
left, right := string(merges[a].runes), string(merges[b].runes)
if id := spm.vocab.Encode(left + right); id >= 0 {
return &candidate{
a: a,
b: b,
score: spm.vocab.Scores[id],
size: len(left) + len(right),
}
}
return nil
}
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
heap.Push(q, pair)
}
}
for q.Len() > 0 {
pair := heap.Pop(q).(*candidate)
left, right := merges[pair.a], merges[pair.b]
if string(left.runes) == "" || string(right.runes) == "" || len(string(left.runes))+len(string(right.runes)) != pair.size {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
heap.Push(q, pair)
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
heap.Push(q, pair)
}
}
for _, merge := range merges {
if token := string(merge.runes); token != "" {
id := spm.vocab.Encode(token)
if id >= 0 {
ids = append(ids, id)
continue
}
// Fallback to byte tokenization
var result []int32
for _, b := range []byte(token) {
byteToken := fmt.Sprintf("<0x%02X>", b)
unknownID := spm.vocab.Encode(byteToken)
if unknownID >= 0 {
result = append(result, unknownID)
} else {
slog.Debug("unknown byte token", "byte", b, "token", byteToken)
}
}
ids = append(ids, result...)
}
}
ids = append(ids, spm.tokenizeViterbi(text)...)
}
if addSpecial {
@ -189,33 +104,82 @@ func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
return ids, nil
}
type candidate struct {
a, b int
score float32
size int
}
// tokenizeViterbi segments text into vocabulary tokens using the Viterbi algorithm,
// finding the globally optimal (highest log-probability) segmentation.
func (spm SentencePiece) tokenizeViterbi(text string) []int32 {
runes := []rune(text)
n := len(runes)
if n == 0 {
return nil
}
type queue []*candidate
// dp[i] = best cumulative score for segmenting runes[0:i]
dp := make([]float32, n+1)
for i := range dp {
dp[i] = float32(math.Inf(-1))
}
dp[0] = 0
func (q queue) Len() int { return len(q) }
// back[i] = rune-length of the token ending at position i in the best path
back := make([]int, n+1)
func (q queue) Less(i, j int) bool {
return (q[i].score > q[j].score) || (q[i].score == q[j].score && q[i].a < q[j].a)
}
for i := 0; i < n; i++ {
if math.IsInf(float64(dp[i]), -1) {
continue
}
for l := 1; i+l <= n && l <= spm.maxTokenLen; l++ {
piece := string(runes[i : i+l])
id := spm.vocab.Encode(piece)
if id < 0 {
continue
}
score := dp[i] + spm.vocab.Scores[id]
if score > dp[i+l] {
dp[i+l] = score
back[i+l] = l
}
}
}
func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }
// Traceback from position n to 0
var ids []int32
pos := n
for pos > 0 {
l := back[pos]
if l == 0 {
// Position unreachable via vocab — fall back to byte tokens for this char
ch := string(runes[pos-1 : pos])
var result []int32
for _, b := range []byte(ch) {
byteToken := fmt.Sprintf("<0x%02X>", b)
uid := spm.vocab.Encode(byteToken)
if uid >= 0 {
result = append(result, uid)
} else {
slog.Debug("unknown byte token", "byte", b, "token", byteToken)
}
}
// Prepend in reverse order since we're tracing backwards
for i := len(result) - 1; i >= 0; i-- {
ids = append(ids, result[i])
}
pos--
continue
}
piece := string(runes[pos-l : pos])
id := spm.vocab.Encode(piece)
if id >= 0 {
ids = append(ids, id)
}
pos -= l
}
func (q *queue) Push(x interface{}) {
item := x.(*candidate)
*q = append(*q, item)
}
// Reverse: traceback produces tokens in reverse order
for i, j := 0, len(ids)-1; i < j; i, j = i+1, j-1 {
ids[i], ids[j] = ids[j], ids[i]
}
func (q *queue) Pop() interface{} {
old := *q
n := len(old)
item := old[n-1]
*q = old[0 : n-1]
return item
return ids
}
func (spm SentencePiece) Decode(ids []int32) (string, error) {

View file

@ -0,0 +1,569 @@
package tokenizer
import (
"slices"
"testing"
)
// makeUnigramVocab assembles a SentencePiece tokenizer from (piece, score,
// type) triples. A zero type is interpreted as TOKEN_TYPE_NORMAL; any other
// value is passed through unchanged.
func makeUnigramVocab(pieces []struct {
	value string
	score float32
	typ   int32
}) SentencePiece {
	n := len(pieces)
	v := Vocabulary{
		Values: make([]string, 0, n),
		Scores: make([]float32, 0, n),
		Types:  make([]int32, 0, n),
	}
	for _, p := range pieces {
		typ := p.typ
		if typ == 0 {
			typ = TOKEN_TYPE_NORMAL
		}
		v.Values = append(v.Values, p.value)
		v.Scores = append(v.Scores, p.score)
		v.Types = append(v.Types, typ)
	}
	return NewSentencePiece(&v)
}
// TestSentencePieceViterbiGlobalOptimum checks that the Unigram tokenizer
// performs a global dynamic-programming search rather than a greedy scan.
//
// With scores "a"=-1, "b"=-1, "c"=-1, "ab"=-0.5, "bc"=-0.3, "abc"=-3.0,
// a greedy left-to-right pass over "abc" would take [ab, c] (total -1.5),
// while the true optimum is [a, bc] (total -1.3).
func TestSentencePieceViterbiGlobalOptimum(t *testing.T) {
	spm := makeUnigramVocab([]struct {
		value string
		score float32
		typ   int32
	}{
		{"a", -1.0, 0},
		{"b", -1.0, 0},
		{"c", -1.0, 0},
		{"ab", -0.5, 0},  // the tempting greedy prefix
		{"bc", -0.3, 0},  // completes the globally best split [a, bc]
		{"abc", -3.0, 0}, // whole-word token scores worst
	})

	got, err := spm.Encode("abc", false)
	if err != nil {
		t.Fatalf("Encode: %v", err)
	}

	// Viterbi must prefer [a, bc] = -1.3 over [ab, c] = -1.5.
	want := []int32{0, 4}
	if slices.Equal(got, want) {
		return
	}
	pieces := make([]string, len(got))
	for i, id := range got {
		pieces[i] = spm.vocab.Values[id]
	}
	t.Errorf("got %v (%v), want %v ([a, bc] — global optimum)", got, pieces, want)
}
// TestSentencePieceViterbiPrefersSingleToken checks that a single long token
// wins when its score beats every segmentation into smaller pieces.
func TestSentencePieceViterbiPrefersSingleToken(t *testing.T) {
	spm := makeUnigramVocab([]struct {
		value string
		score float32
		typ   int32
	}{
		{"x", -5.0, 0},
		{"y", -5.0, 0},
		{"z", -5.0, 0},
		{"xy", -3.0, 0},
		{"yz", -3.0, 0},
		{"xyz", -1.0, 0}, // best option: keep the word whole
	})

	got, err := spm.Encode("xyz", false)
	if err != nil {
		t.Fatalf("Encode: %v", err)
	}

	// [xyz] = -1.0 beats [xy, z] = -8.0, [x, yz] = -8.0, and [x, y, z] = -15.0.
	want := []int32{5}
	if slices.Equal(got, want) {
		return
	}
	pieces := make([]string, len(got))
	for i, id := range got {
		pieces[i] = spm.vocab.Values[id]
	}
	t.Errorf("got %v (%v), want %v ([xyz])", got, pieces, want)
}
// TestSentencePieceViterbiThreeWaySplit covers a case where neither a greedy
// left-to-right nor a greedy right-to-left scan finds the best segmentation.
//
// Vocab: "a"=-1, "b"=-2, "c"=-2, "d"=-1, "ab"=-0.8, "bc"=-0.5, "cd"=-5.0.
// For "abcd" the optimum is [a, bc, d] = -2.5, whereas greedy takes "ab"
// first (since -0.8 > -1.0) and ends at [ab, c, d] = -3.8.
func TestSentencePieceViterbiThreeWaySplit(t *testing.T) {
	spm := makeUnigramVocab([]struct {
		value string
		score float32
		typ   int32
	}{
		{"a", -1.0, 0},  // 0
		{"b", -2.0, 0},  // 1
		{"c", -2.0, 0},  // 2
		{"d", -1.0, 0},  // 3
		{"ab", -0.8, 0}, // 4: the greedy trap at position 0
		{"bc", -0.5, 0}, // 5: the piece the optimal path routes through
		{"cd", -5.0, 0}, // 6: poor score, never chosen
	})

	got, err := spm.Encode("abcd", false)
	if err != nil {
		t.Fatalf("Encode: %v", err)
	}

	// Candidate totals:
	//   [a, bc, d]   = -2.5  ← optimal
	//   [ab, c, d]   = -3.8
	//   [a, b, cd]   = -8.0
	//   [a, b, c, d] = -7.0
	want := []int32{0, 5, 3}
	if slices.Equal(got, want) {
		return
	}
	pieces := make([]string, len(got))
	for i, id := range got {
		pieces[i] = spm.vocab.Values[id]
	}
	t.Errorf("got %v (%v), want %v ([a, bc, d])", got, pieces, want)
}
// TestSentencePieceAddSpacePrefix checks the add_dummy_prefix behavior:
// with AddSpacePrefix=true a space is prepended before encoding, so the
// first word resolves to a ▁-prefixed vocabulary entry.
func TestSentencePieceAddSpacePrefix(t *testing.T) {
	pieces := []struct {
		value string
		score float32
		typ   int32
	}{
		{"hello", -1.0, 0},  // 0: bare form
		{"▁hello", -1.0, 0}, // 1: ▁-prefixed form
		{"▁world", -1.0, 0}, // 2
		{"world", -1.0, 0},  // 3: bare form
		{"h", -10.0, 0},     // 4: single-char fallbacks
		{"e", -10.0, 0},     // 5
		{"l", -10.0, 0},     // 6
		{"o", -10.0, 0},     // 7
		{"w", -10.0, 0},     // 8
		{"r", -10.0, 0},     // 9
		{"d", -10.0, 0},     // 10
		{"▁", -100.0, 0},    // 11
	}
	withoutPrefix := makeUnigramVocab(pieces)
	withoutPrefix.vocab.AddSpacePrefix = false
	withPrefix := makeUnigramVocab(pieces)
	withPrefix.vocab.AddSpacePrefix = true

	// names renders token ids back to their vocabulary strings for messages.
	names := func(spm SentencePiece, ids []int32) []string {
		out := make([]string, len(ids))
		for i, id := range ids {
			out[i] = spm.vocab.Values[id]
		}
		return out
	}

	t.Run("without_add_space_prefix", func(t *testing.T) {
		// Space→▁ substitution alone: "hello world" → "hello▁world".
		ids, err := withoutPrefix.Encode("hello world", false)
		if err != nil {
			t.Fatal(err)
		}
		want := []int32{0, 2} // bare "hello", then "▁world"
		if !slices.Equal(ids, want) {
			t.Errorf("got %v (%v), want %v ([hello, ▁world])", ids, names(withoutPrefix, ids), want)
		}
	})
	t.Run("with_add_space_prefix", func(t *testing.T) {
		// " hello world" → "▁hello▁world" → both words ▁-prefixed.
		ids, err := withPrefix.Encode("hello world", false)
		if err != nil {
			t.Fatal(err)
		}
		want := []int32{1, 2}
		if !slices.Equal(ids, want) {
			t.Errorf("got %v (%v), want %v ([▁hello, ▁world])", ids, names(withPrefix, ids), want)
		}
	})
	t.Run("first_word_gets_prefix", func(t *testing.T) {
		// A lone word still gains the dummy prefix: "hello" → "▁hello".
		ids, err := withPrefix.Encode("hello", false)
		if err != nil {
			t.Fatal(err)
		}
		want := []int32{1}
		if !slices.Equal(ids, want) {
			t.Errorf("got %v (%v), want %v ([▁hello])", ids, names(withPrefix, ids), want)
		}
	})
	t.Run("empty_string", func(t *testing.T) {
		// "" becomes " " → "▁", which may encode to the ▁ token (id 11) or
		// to nothing at all; anything beyond one token is wrong.
		ids, err := withPrefix.Encode("", false)
		if err != nil {
			t.Fatal(err)
		}
		if len(ids) > 1 {
			t.Errorf("empty string with add_space_prefix got %d tokens, want ≤1", len(ids))
		}
	})
}
// TestSentencePieceByteTokenFallback checks that runes absent from the
// vocabulary are emitted as <0xNN> byte tokens rather than being dropped.
func TestSentencePieceByteTokenFallback(t *testing.T) {
	spm := makeUnigramVocab([]struct {
		value string
		score float32
		typ   int32
	}{
		{"hello", -1.0, TOKEN_TYPE_NORMAL},  // 0
		{"<0xC3>", -1.0, TOKEN_TYPE_BYTE},   // 1: first UTF-8 byte of é
		{"<0xA9>", -1.0, TOKEN_TYPE_BYTE},   // 2: second UTF-8 byte of é
		{"<0x21>", -1.0, TOKEN_TYPE_BYTE},   // 3: '!'
		{"world", -1.0, TOKEN_TYPE_NORMAL},  // 4: padding to reach min trace size
	})

	t.Run("utf8_byte_fallback", func(t *testing.T) {
		// No "é" entry exists, so both of its UTF-8 bytes are emitted.
		got, err := spm.Encode("é", false)
		if err != nil {
			t.Fatal(err)
		}
		if want := []int32{1, 2}; !slices.Equal(got, want) {
			t.Errorf("got %v, want %v (<0xC3>, <0xA9>)", got, want)
		}
	})
	t.Run("mixed_known_and_byte_fallback", func(t *testing.T) {
		// "hello" resolves normally; "!" drops down to its byte token.
		got, err := spm.Encode("hello!", false)
		if err != nil {
			t.Fatal(err)
		}
		if want := []int32{0, 3}; !slices.Equal(got, want) {
			t.Errorf("got %v, want %v (hello, <0x21>)", got, want)
		}
	})
	t.Run("multi_byte_utf8_byte_fallback", func(t *testing.T) {
		// Byte fallback for é, followed by the byte token for '!'.
		got, err := spm.Encode("é!", false)
		if err != nil {
			t.Fatal(err)
		}
		if want := []int32{1, 2, 3}; !slices.Equal(got, want) {
			t.Errorf("got %v, want %v (<0xC3>, <0xA9>, <0x21>)", got, want)
		}
	})
}
// TestSentencePieceEdgeCases covers boundary conditions of Encode.
func TestSentencePieceEdgeCases(t *testing.T) {
	spm := makeUnigramVocab([]struct {
		value string
		score float32
		typ   int32
	}{
		{"a", -1.0, 0},
		{"b", -1.0, 0},
		{"ab", -0.5, 0},
		{"▁a", -1.0, 0},
		{"▁b", -1.0, 0},
		{"<0x61>", -5.0, TOKEN_TYPE_BYTE}, // byte token for 'a'
	})

	t.Run("empty_string", func(t *testing.T) {
		got, err := spm.Encode("", false)
		if err != nil {
			t.Fatal(err)
		}
		if len(got) != 0 {
			t.Errorf("got %v, want empty", got)
		}
	})
	t.Run("single_char", func(t *testing.T) {
		// NOTE(review): overlaps with single_char_in_vocab below, which makes
		// a strictly stronger assertion on the same input; kept for parity.
		got, err := spm.Encode("a", false)
		if err != nil {
			t.Fatal(err)
		}
		if len(got) == 0 {
			t.Error("got empty, want non-empty")
		}
	})
	t.Run("single_char_in_vocab", func(t *testing.T) {
		got, err := spm.Encode("a", false)
		if err != nil {
			t.Fatal(err)
		}
		if want := []int32{0}; !slices.Equal(got, want) {
			t.Errorf("got %v, want %v", got, want)
		}
	})
	t.Run("exact_vocab_match_whole_string", func(t *testing.T) {
		// The whole input is one vocab entry whose score beats splitting:
		// "ab" (-0.5) > "a" + "b" (-2.0).
		got, err := spm.Encode("ab", false)
		if err != nil {
			t.Fatal(err)
		}
		if want := []int32{2}; !slices.Equal(got, want) {
			t.Errorf("got %v, want %v ([ab])", got, want)
		}
	})
	t.Run("space_becomes_spm_whitespace", func(t *testing.T) {
		// A leading space maps to ▁: " a" → "▁a".
		got, err := spm.Encode(" a", false)
		if err != nil {
			t.Fatal(err)
		}
		if want := []int32{3}; !slices.Equal(got, want) {
			t.Errorf("got %v, want %v ([▁a])", got, want)
		}
	})
	t.Run("space_word_space_word", func(t *testing.T) {
		// Every space maps to ▁: " a b" → "▁a▁b".
		got, err := spm.Encode(" a b", false)
		if err != nil {
			t.Fatal(err)
		}
		if want := []int32{3, 4}; !slices.Equal(got, want) {
			t.Errorf("got %v, want %v ([▁a, ▁b])", got, want)
		}
	})
}
// TestSentencePieceSpecialTokenHandling checks that control tokens are split
// out of the input before tokenization and pass through with their own ids.
func TestSentencePieceSpecialTokenHandling(t *testing.T) {
	vocab := &Vocabulary{
		Values: []string{
			"hello",  // 0
			"▁world", // 1
			"<s>",    // 2: BOS control token
			"</s>",   // 3: EOS control token
			"h",      // 4: single-char fallbacks
			"e",      // 5
			"l",      // 6
			"o",      // 7
		},
		Scores: []float32{-1.0, -1.0, 0.0, 0.0, -5.0, -5.0, -5.0, -5.0},
		Types: []int32{
			TOKEN_TYPE_NORMAL,
			TOKEN_TYPE_NORMAL,
			TOKEN_TYPE_CONTROL,
			TOKEN_TYPE_CONTROL,
			TOKEN_TYPE_NORMAL,
			TOKEN_TYPE_NORMAL,
			TOKEN_TYPE_NORMAL,
			TOKEN_TYPE_NORMAL,
		},
		BOS:    []int32{2},
		EOS:    []int32{3},
		AddBOS: true,
		AddEOS: true,
	}
	spm := NewSentencePiece(vocab)

	t.Run("addSpecial_wraps_with_bos_eos", func(t *testing.T) {
		got, err := spm.Encode("hello", true)
		if err != nil {
			t.Fatal(err)
		}
		// <s> and </s> wrap the single word token.
		if want := []int32{2, 0, 3}; !slices.Equal(got, want) {
			t.Errorf("got %v, want %v", got, want)
		}
	})
	t.Run("no_addSpecial", func(t *testing.T) {
		got, err := spm.Encode("hello", false)
		if err != nil {
			t.Fatal(err)
		}
		if want := []int32{0}; !slices.Equal(got, want) {
			t.Errorf("got %v, want %v", got, want)
		}
	})
	t.Run("control_token_in_text_is_passthrough", func(t *testing.T) {
		// Literal "<s>"/"</s>" in the input resolve to their control ids.
		got, err := spm.Encode("<s>hello</s>", false)
		if err != nil {
			t.Fatal(err)
		}
		if want := []int32{2, 0, 3}; !slices.Equal(got, want) {
			t.Errorf("got %v, want %v", got, want)
		}
	})
}
// TestSentencePieceRoundtrip checks encode→decode round trips for the Unigram
// model. The vocabulary carries only ▁-prefixed word pieces, so each input
// starts with a space (which encodes to ▁ and decodes back to a space).
func TestSentencePieceRoundtrip(t *testing.T) {
	spm := makeUnigramVocab([]struct {
		value string
		score float32
		typ   int32
	}{
		{"▁hello", -1.0, 0},
		{"▁world", -1.0, 0},
		{"▁", -5.0, 0},
		{"h", -10.0, 0},
		{"e", -10.0, 0},
		{"l", -10.0, 0},
		{"o", -10.0, 0},
		{"w", -10.0, 0},
		{"r", -10.0, 0},
		{"d", -10.0, 0},
		{"▁t", -2.0, 0},
		{"es", -2.0, 0},
		{"t", -10.0, 0},
	})
	spm.vocab.AddSpacePrefix = false

	for _, want := range []string{
		" hello",
		" hello world",
		" test",
	} {
		ids, err := spm.Encode(want, false)
		if err != nil {
			t.Fatalf("Encode(%q): %v", want, err)
		}
		got, err := spm.Decode(ids)
		if err != nil {
			t.Fatalf("Decode: %v", err)
		}
		if got != want {
			t.Errorf("roundtrip(%q): got %q", want, got)
		}
	}
}
// TestSentencePieceMaxTokenLen exercises matching at the longest token length
// in the vocabulary.
//
// NOTE(review): maxTokenLen is derived from the vocabulary itself (here 4,
// because "abcd" is present), so this test only confirms that a token of
// maximal length is still matchable — it does NOT prove the search refuses
// to look beyond maxTokenLen. A stronger test would need a way to pin
// maxTokenLen below the longest candidate.
func TestSentencePieceMaxTokenLen(t *testing.T) {
	// Vocab contains single chars plus "abc" and "abcd" — so maxTokenLen is 4.
	spm := makeUnigramVocab([]struct {
		value string
		score float32
		typ   int32
	}{
		{"a", -1.0, 0},
		{"b", -1.0, 0},
		{"c", -1.0, 0},
		{"d", -1.0, 0},
		{"abc", -0.1, 0},
		{"abcd", -0.05, 0}, // best-scoring candidate, exactly maxTokenLen long
	})
	// "abcd" should be tokenized as [abcd] since it's in vocab and best overall
	ids, err := spm.Encode("abcd", false)
	if err != nil {
		t.Fatal(err)
	}
	// [abcd] = -0.05 is best
	want := []int32{5} // "abcd"
	if !slices.Equal(ids, want) {
		pieces := make([]string, len(ids))
		for i, id := range ids {
			pieces[i] = spm.vocab.Values[id]
		}
		t.Errorf("got %v (%v), want [abcd]", ids, pieces)
	}
}
// TestSentencePieceUnicodeRunes checks that segmentation operates on runes,
// so multi-byte Unicode characters are never split mid-codepoint.
func TestSentencePieceUnicodeRunes(t *testing.T) {
	spm := makeUnigramVocab([]struct {
		value string
		score float32
		typ   int32
	}{
		{"▁こんにちは", -1.0, 0}, // Japanese greeting as one token
		{"▁世界", -1.0, 0},    // "world" in Chinese/Japanese
		{"こ", -5.0, 0},
		{"ん", -5.0, 0},
		{"に", -5.0, 0},
		{"ち", -5.0, 0},
		{"は", -5.0, 0},
		{"世", -5.0, 0},
		{"界", -5.0, 0},
	})
	spm.vocab.AddSpacePrefix = true

	t.Run("japanese_single_token", func(t *testing.T) {
		// The dummy prefix turns "こんにちは" into "▁こんにちは" → id 0.
		got, err := spm.Encode("こんにちは", false)
		if err != nil {
			t.Fatal(err)
		}
		want := []int32{0}
		if slices.Equal(got, want) {
			return
		}
		pieces := make([]string, len(got))
		for i, id := range got {
			pieces[i] = spm.vocab.Values[id]
		}
		t.Errorf("got %v (%v), want [▁こんにちは]", got, pieces)
	})
	t.Run("roundtrip_japanese", func(t *testing.T) {
		// The dummy prefix decodes back to a leading space, so the expected
		// round-trip output carries one.
		want := " こんにちは 世界"
		ids, err := spm.Encode("こんにちは 世界", false)
		if err != nil {
			t.Fatal(err)
		}
		got, err := spm.Decode(ids)
		if err != nil {
			t.Fatal(err)
		}
		if got != want {
			t.Errorf("roundtrip: got %q, want %q", got, want)
		}
	})
}

View file

@ -21,6 +21,7 @@ type Vocabulary struct {
BOS, EOS []int32
AddBOS, AddEOS bool
AddSpacePrefix bool
specialOnce sync.Once
special []string

View file

@ -139,6 +139,9 @@ func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
if len(pieces) > 0 {
ids = append(ids, pieces...)
} else {
if unk < 0 {
return nil, fmt.Errorf("token %q not in vocabulary and no [UNK] fallback available", word)
}
ids = append(ids, unk)
}
}

View file

@ -38,6 +38,25 @@ func TestWordPiece(t *testing.T) {
}
}
func TestWordPieceEncodeReturnsErrorWhenUnkMissing(t *testing.T) {
// Regression test for issue #15174: WordPiece used to silently emit
// -1 when a word could not be tokenized and [UNK] was absent from the
// vocab. That -1 then crashed the embedding forward pass with a
// GGML_ASSERT in ggml_get_rows.
wpm := NewWordPiece(
&Vocabulary{
Values: []string{"[CLS]", "[SEP]", "▁hello"},
BOS: []int32{0},
EOS: []int32{1},
},
true,
)
if _, err := wpm.Encode("hello world!", false); err == nil {
t.Error("expected error when word is OOV and [UNK] is missing, got nil")
}
}
func TestWordPieceWords(t *testing.T) {
var wpm WordPiece