mlx: follow up for gemma4

Follow-up for review comments on #15244
This commit is contained in:
Daniel Hiltgen 2026-04-13 16:43:46 -07:00
parent 2cba7756c5
commit 966ab7ffa0
2 changed files with 82 additions and 168 deletions

View file

@ -132,11 +132,12 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
if isSafetensors {
modelType = "safetensors model"
spinnerKey = "create"
capabilities = inferSafetensorsCapabilities(opts.ModelDir)
// Set parser and renderer name based on architecture
parserName = getParserName(opts.ModelDir)
rendererName = getRendererName(opts.ModelDir)
// Load config.json once and share it across capability / parser /
// renderer detection.
cfg := loadSafetensorsConfig(opts.ModelDir)
capabilities = inferSafetensorsCapabilities(cfg)
parserName = getParserName(cfg)
rendererName = getRendererName(cfg)
} else {
modelType = "image generation model"
spinnerKey = "imagegen"
@ -183,19 +184,65 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
return nil
}
func inferSafetensorsCapabilities(modelDir string) []string {
// safetensorsConfig is the subset of HuggingFace config.json fields used for
// capability / parser / renderer detection at create time. Load once per
// create to avoid re-reading and re-parsing config.json for each predicate.
type safetensorsConfig struct {
// Architectures lists the HF architecture class names (e.g.
// "Qwen3ForCausalLM"); matching is done case-insensitively elsewhere.
Architectures []string `json:"architectures"`
// ModelType is the HF model_type string, used as a fallback token when
// the architectures list does not match a known family.
ModelType string `json:"model_type"`
// VisionConfig / AudioConfig are only nil-checked for presence; their
// contents are never read.
// NOTE(review): a plain map[string]any would also be nil when the key is
// absent, so the extra pointer indirection looks unnecessary — confirm no
// caller dereferences these before simplifying.
VisionConfig *map[string]any `json:"vision_config"`
AudioConfig *map[string]any `json:"audio_config"`
}
// loadSafetensorsConfig parses modelDir/config.json into a safetensorsConfig.
// A nil result means the file was absent or not valid JSON; callers treat
// nil as "unknown" and fall back to their defaults.
func loadSafetensorsConfig(modelDir string) *safetensorsConfig {
	raw, readErr := os.ReadFile(filepath.Join(modelDir, "config.json"))
	if readErr != nil {
		return nil
	}
	cfg := &safetensorsConfig{}
	if json.Unmarshal(raw, cfg) != nil {
		return nil
	}
	return cfg
}
// archAndTypeTokens returns the lowercased architecture names (declaration
// order preserved) with the lowercased model_type appended last, so that
// parser/renderer dispatch can scan both fields with a single loop instead
// of duplicating the match block. Returns nil for a nil receiver.
func (cfg *safetensorsConfig) archAndTypeTokens() []string {
	if cfg == nil {
		return nil
	}
	out := make([]string, len(cfg.Architectures), len(cfg.Architectures)+1)
	for i, arch := range cfg.Architectures {
		out[i] = strings.ToLower(arch)
	}
	if mt := cfg.ModelType; mt != "" {
		out = append(out, strings.ToLower(mt))
	}
	return out
}
func inferSafetensorsCapabilities(cfg *safetensorsConfig) []string {
capabilities := []string{"completion"}
if cfg == nil {
return capabilities
}
// Qwen3.5 multimodal checkpoints use ConditionalGeneration architectures.
if supportsVision(modelDir) {
if cfg.VisionConfig != nil {
capabilities = append(capabilities, "vision")
}
if supportsAudio(modelDir) {
if cfg.AudioConfig != nil {
capabilities = append(capabilities, "audio")
}
if supportsThinking(modelDir) {
if supportsThinking(cfg) {
capabilities = append(capabilities, "thinking")
}
@ -453,191 +500,57 @@ func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
return layers, nil
}
// supportsThinking checks if the model supports thinking mode based on its architecture.
// This reads the config.json from the model directory and checks the architectures field.
func supportsThinking(modelDir string) bool {
configPath := filepath.Join(modelDir, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
// supportsThinking reports whether the model's architecture or model_type
// matches one of the known thinking-capable families.
func supportsThinking(cfg *safetensorsConfig) bool {
if cfg == nil {
return false
}
var cfg struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return false
}
// Check architectures that support thinking
thinkingArchitectures := []string{
"glm4moe", // GLM-4 MoE models
"deepseek", // DeepSeek models
"qwen3", // Qwen3 models
}
// Check the architecture list
for _, arch := range cfg.Architectures {
archLower := strings.ToLower(arch)
for _, tok := range cfg.archAndTypeTokens() {
for _, thinkArch := range thinkingArchitectures {
if strings.Contains(archLower, thinkArch) {
if strings.Contains(tok, thinkArch) {
return true
}
}
}
// Also check model_type
if cfg.ModelType != "" {
typeLower := strings.ToLower(cfg.ModelType)
for _, thinkArch := range thinkingArchitectures {
if strings.Contains(typeLower, thinkArch) {
return true
}
}
}
return false
}
// supportsVision checks if the model has a vision encoder by looking for
// vision_config in config.json.
func supportsVision(modelDir string) bool {
data, err := os.ReadFile(filepath.Join(modelDir, "config.json"))
if err != nil {
return false
}
var cfg struct {
VisionConfig *map[string]any `json:"vision_config"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return false
}
return cfg.VisionConfig != nil
}
func supportsAudio(modelDir string) bool {
data, err := os.ReadFile(filepath.Join(modelDir, "config.json"))
if err != nil {
return false
}
var cfg struct {
AudioConfig *map[string]any `json:"audio_config"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return false
}
return cfg.AudioConfig != nil
}
// getParserName returns the parser name for a model based on its architecture.
// This reads the config.json from the model directory and determines the appropriate parser.
func getParserName(modelDir string) string {
configPath := filepath.Join(modelDir, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return ""
}
var cfg struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return ""
}
// Check architectures for known parsers
for _, arch := range cfg.Architectures {
archLower := strings.ToLower(arch)
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
func getParserName(cfg *safetensorsConfig) string {
for _, tok := range cfg.archAndTypeTokens() {
switch {
case strings.Contains(tok, "glm4"), strings.Contains(tok, "glm-4"):
return "glm-4.7"
}
if strings.Contains(archLower, "deepseek") {
case strings.Contains(tok, "deepseek"):
return "deepseek3"
}
if strings.Contains(archLower, "gemma4") {
case strings.Contains(tok, "gemma4"):
return "gemma4"
}
if strings.Contains(archLower, "qwen3") {
case strings.Contains(tok, "qwen3"):
return "qwen3"
}
}
// Also check model_type
if cfg.ModelType != "" {
typeLower := strings.ToLower(cfg.ModelType)
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
return "glm-4.7"
}
if strings.Contains(typeLower, "deepseek") {
return "deepseek3"
}
if strings.Contains(typeLower, "gemma4") {
return "gemma4"
}
if strings.Contains(typeLower, "qwen3") {
return "qwen3"
}
}
return ""
}
// getRendererName returns the renderer name for a model based on its architecture.
// This reads the config.json from the model directory and determines the appropriate renderer.
func getRendererName(modelDir string) string {
configPath := filepath.Join(modelDir, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return ""
}
var cfg struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return ""
}
// Check architectures for known renderers
for _, arch := range cfg.Architectures {
archLower := strings.ToLower(arch)
if strings.Contains(archLower, "gemma4") {
func getRendererName(cfg *safetensorsConfig) string {
for _, tok := range cfg.archAndTypeTokens() {
switch {
case strings.Contains(tok, "gemma4"):
return "gemma4"
}
if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
case strings.Contains(tok, "glm4"), strings.Contains(tok, "glm-4"):
return "glm-4.7"
}
if strings.Contains(archLower, "deepseek") {
case strings.Contains(tok, "deepseek"):
return "deepseek3"
}
if strings.Contains(archLower, "qwen3") {
case strings.Contains(tok, "qwen3"):
return "qwen3-coder"
}
}
// Also check model_type
if cfg.ModelType != "" {
typeLower := strings.ToLower(cfg.ModelType)
if strings.Contains(typeLower, "gemma4") {
return "gemma4"
}
if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
return "glm-4.7"
}
if strings.Contains(typeLower, "deepseek") {
return "deepseek3"
}
if strings.Contains(typeLower, "qwen3") {
return "qwen3-coder"
}
}
return ""
}

View file

@ -352,7 +352,8 @@ func TestInferSafetensorsCapabilities(t *testing.T) {
t.Fatal(err)
}
if got := inferSafetensorsCapabilities(dir); !slices.Equal(got, tt.want) {
cfg := loadSafetensorsConfig(dir)
if got := inferSafetensorsCapabilities(cfg); !slices.Equal(got, tt.want) {
t.Fatalf("inferSafetensorsCapabilities() = %#v, want %#v", got, tt.want)
}
})
@ -571,7 +572,7 @@ func TestSupportsThinking(t *testing.T) {
dir := t.TempDir()
os.WriteFile(filepath.Join(dir, "config.json"), []byte(tt.configJSON), 0o644)
if got := supportsThinking(dir); got != tt.want {
if got := supportsThinking(loadSafetensorsConfig(dir)); got != tt.want {
t.Errorf("supportsThinking() = %v, want %v", got, tt.want)
}
})
@ -579,7 +580,7 @@ func TestSupportsThinking(t *testing.T) {
}
func TestSupportsThinking_NoConfig(t *testing.T) {
if supportsThinking(t.TempDir()) {
if supportsThinking(loadSafetensorsConfig(t.TempDir())) {
t.Error("supportsThinking should return false for missing config.json")
}
}
@ -627,7 +628,7 @@ func TestGetParserName(t *testing.T) {
dir := t.TempDir()
os.WriteFile(filepath.Join(dir, "config.json"), []byte(tt.configJSON), 0o644)
if got := getParserName(dir); got != tt.want {
if got := getParserName(loadSafetensorsConfig(dir)); got != tt.want {
t.Errorf("getParserName() = %q, want %q", got, tt.want)
}
})
@ -667,7 +668,7 @@ func TestGetRendererName(t *testing.T) {
dir := t.TempDir()
os.WriteFile(filepath.Join(dir, "config.json"), []byte(tt.configJSON), 0o644)
if got := getRendererName(dir); got != tt.want {
if got := getRendererName(loadSafetensorsConfig(dir)); got != tt.want {
t.Errorf("getRendererName() = %q, want %q", got, tt.want)
}
})