fix: preserve strip_accents preprocessing in BERT tokenizer conversion

This commit is contained in:
BillionClaw 2026-04-17 03:36:19 +08:00
parent b9cb535407
commit bc4fa6290d
7 changed files with 43 additions and 9 deletions

View file

@ -119,6 +119,7 @@ func (p *bertModel) KV(t *Tokenizer) KV {
kv["tokenizer.ggml.model"] = "bert"
kv["tokenizer.ggml.token_type_count"] = uint32(2)
kv["tokenizer.ggml.strip_accents"] = t.StripAccents
// convert to phantom space tokens
for i, e := range t.Tokens {

View file

@ -146,6 +146,7 @@ func (p *nomicbertModel) KV(t *Tokenizer) KV {
kv["tokenizer.ggml.model"] = "bert"
kv["tokenizer.ggml.token_type_count"] = uint32(2)
kv["tokenizer.ggml.strip_accents"] = t.StripAccents
// convert to phantom space tokens
for i, e := range t.Tokens {

View file

@ -29,8 +29,9 @@ type Tokenizer struct {
SpecialVocabulary []*SpecialVocabulary
Merges []string
Pre string
Template string
Pre string
Template string
StripAccents bool
}
func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error) {
@ -141,6 +142,14 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
}
}
// Read strip_accents for BERT-style tokenizers
if bts, ok := p["strip_accents"]; ok {
if err := json.Unmarshal(bts, &t.StripAccents); err != nil {
// Ignore errors - default is false
slog.Debug("tokenizer", "strip_accents parse error", err)
}
}
for _, st := range specialTokenTypes {
sv := SpecialVocabulary{Type: st}
if bts, ok := p[fmt.Sprintf("add_%s_token", st)]; ok {

View file

@ -157,7 +157,7 @@ func New(c fs.Config) (model.Model, error) {
var t tokenizer.Tokenizer
switch c.String("tokenizer.ggml.model", "bert") {
case "bert":
t = tokenizer.NewWordPiece(vocab, true)
t = tokenizer.NewWordPiece(vocab, true, c.Bool("tokenizer.ggml.strip_accents", false))
default:
return nil, model.ErrUnsupportedTokenizer
}

View file

@ -218,6 +218,7 @@ func New(c fs.Config) (model.Model, error) {
},
},
false,
c.Bool("tokenizer.ggml.strip_accents", false),
),
Layers: layers,
Options: Options{

View file

@ -10,8 +10,23 @@ import (
)
type WordPiece struct {
vocab *Vocabulary
lowercase bool
vocab *Vocabulary
lowercase bool
stripAccents bool
}
// stripAccents removes combining diacritical marks (U+0300–U+036F) from s.
//
// NOTE(review): this does NOT apply NFD normalization — it only drops marks
// that are already in decomposed form. Precomposed characters (e.g. U+00E9
// "é") pass through unchanged, and combining marks outside the
// U+0300–U+036F block are kept. Full BERT-style strip_accents would require
// Unicode normalization (golang.org/x/text/unicode/norm) — confirm whether
// inputs are guaranteed to be NFD-decomposed.
func stripAccents(s string) string {
	// strings.Map returns s unmodified (no allocation) when no rune is
	// dropped, which is the common case for accent-free text.
	return strings.Map(func(r rune) rune {
		// Dropping a rune is signaled by returning a negative value.
		if r >= 0x0300 && r <= 0x036F {
			return -1
		}
		return r
	}, s)
}
// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries.
@ -100,6 +115,11 @@ func (wpm WordPiece) words(s string) iter.Seq[string] {
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
var ids []int32
// Apply accent stripping if enabled (BERT-style preprocessing)
if wpm.stripAccents {
s = stripAccents(s)
}
// TODO: use [UNK] from config
unk := wpm.vocab.Encode("[UNK]")
for word := range wpm.words(s) {
@ -163,9 +183,10 @@ func (wpm WordPiece) Vocabulary() *Vocabulary {
var _ Tokenizer = (*WordPiece)(nil)
func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece {
func NewWordPiece(vocab *Vocabulary, lowercase, stripAccents bool) WordPiece {
return WordPiece{
vocab: vocab,
lowercase: lowercase,
vocab: vocab,
lowercase: lowercase,
stripAccents: stripAccents,
}
}

View file

@ -16,7 +16,8 @@ func TestWordPiece(t *testing.T) {
BOS: []int32{1},
EOS: []int32{2},
},
true, // lowercase
true, // lowercase
false, // stripAccents
)
ids, err := wpm.Encode("Hello world!", true)