mlx: fix vision capability + min version (#15106)

This commit is contained in:
Patrick Devine 2026-03-27 17:09:28 -07:00 committed by GitHub
parent 3824e380a8
commit 9e7cb9697e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 208 additions and 22 deletions

View file

@ -301,7 +301,7 @@ Weigh anchor!
ParameterSize: "7B", ParameterSize: "7B",
QuantizationLevel: "FP16", QuantizationLevel: "FP16",
}, },
Requires: "0.14.0", Requires: "0.19.0",
}, false, &b); err != nil { }, false, &b); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -310,10 +310,17 @@ Weigh anchor!
architecture test architecture test
parameters 7B parameters 7B
quantization FP16 quantization FP16
requires 0.14.0 requires 0.19.0
` `
if diff := cmp.Diff(expect, b.String()); diff != "" { trimLinePadding := func(s string) string {
lines := strings.Split(s, "\n")
for i, line := range lines {
lines[i] = strings.TrimRight(line, " \t\r")
}
return strings.Join(lines, "\n")
}
if diff := cmp.Diff(trimLinePadding(expect), trimLinePadding(b.String())); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff) t.Errorf("unexpected output (-want +got):\n%s", diff)
} }
}) })
@ -1912,7 +1919,7 @@ func TestShowInfoImageGen(t *testing.T) {
QuantizationLevel: "Q8", QuantizationLevel: "Q8",
}, },
Capabilities: []model.Capability{model.CapabilityImage}, Capabilities: []model.Capability{model.CapabilityImage},
Requires: "0.14.0", Requires: "0.19.0",
}, false, &b) }, false, &b)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -1922,7 +1929,7 @@ func TestShowInfoImageGen(t *testing.T) {
" architecture ZImagePipeline \n" + " architecture ZImagePipeline \n" +
" parameters 10.3B \n" + " parameters 10.3B \n" +
" quantization Q8 \n" + " quantization Q8 \n" +
" requires 0.14.0 \n" + " requires 0.19.0 \n" +
"\n" + "\n" +
" Capabilities\n" + " Capabilities\n" +
" image \n" + " image \n" +

View file

@ -1225,9 +1225,11 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
modelDetails.ParameterSize = format.HumanNumber(uint64(paramCount)) modelDetails.ParameterSize = format.HumanNumber(uint64(paramCount))
} }
} }
// Get torch_dtype directly from config.json for quantization level // Older manifests may not have file_type populated for safetensors models.
if dtype, err := xserver.GetSafetensorsDtype(name); err == nil && dtype != "" { if modelDetails.QuantizationLevel == "" {
modelDetails.QuantizationLevel = dtype if dtype, err := xserver.GetSafetensorsDtype(name); err == nil && dtype != "" {
modelDetails.QuantizationLevel = dtype
}
} }
} }

View file

@ -26,6 +26,7 @@ import (
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/openai" "github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
@ -547,6 +548,38 @@ func TestRoutes(t *testing.T) {
} }
} }
// TestGetModelInfo_SafetensorsUsesStoredFileType verifies that when a
// safetensors manifest config already carries a file_type, GetModelInfo
// reports that value as the quantization level instead of inspecting blobs.
func TestGetModelInfo_SafetensorsUsesStoredFileType(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	// Build a minimal config blob with file_type pre-populated.
	payload, err := json.Marshal(model.ConfigV2{
		ModelFormat:  "safetensors",
		FileType:     "mxfp8",
		Capabilities: []string{"completion"},
	})
	if err != nil {
		t.Fatalf("failed to marshal config: %v", err)
	}

	layer, err := manifest.NewLayer(bytes.NewReader(payload), "application/vnd.docker.container.image.v1+json")
	if err != nil {
		t.Fatalf("failed to create config layer: %v", err)
	}

	name := model.ParseName("show-safetensors")
	if err := manifest.WriteManifest(name, layer, nil); err != nil {
		t.Fatalf("failed to write manifest: %v", err)
	}

	info, err := GetModelInfo(api.ShowRequest{Model: name.String()})
	if err != nil {
		t.Fatalf("GetModelInfo() error = %v", err)
	}
	if got, want := info.Details.QuantizationLevel, "mxfp8"; got != want {
		t.Fatalf("QuantizationLevel = %q, want %q", got, want)
	}
}
func casingShuffle(s string) string { func casingShuffle(s string) string {
rr := []rune(s) rr := []rune(s)
for i := range rr { for i := range rr {

View file

@ -26,7 +26,7 @@ import (
) )
// MinOllamaVersion is the minimum Ollama version required for safetensors models. // MinOllamaVersion is the minimum Ollama version required for safetensors models.
const MinOllamaVersion = "0.14.0" const MinOllamaVersion = "0.19.0"
// ModelfileConfig holds configuration extracted from a Modelfile. // ModelfileConfig holds configuration extracted from a Modelfile.
type ModelfileConfig struct { type ModelfileConfig struct {
@ -132,12 +132,7 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
if isSafetensors { if isSafetensors {
modelType = "safetensors model" modelType = "safetensors model"
spinnerKey = "create" spinnerKey = "create"
capabilities = []string{"completion"} capabilities = inferSafetensorsCapabilities(opts.ModelDir)
// Check if model supports thinking based on architecture
if supportsThinking(opts.ModelDir) {
capabilities = append(capabilities, "thinking")
}
// Set parser and renderer name based on architecture // Set parser and renderer name based on architecture
parserName = getParserName(opts.ModelDir) parserName = getParserName(opts.ModelDir)
@ -188,6 +183,21 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
return nil return nil
} }
// inferSafetensorsCapabilities derives the capability list for a safetensors
// checkpoint from its config.json. Every model advertises "completion";
// "vision" and "thinking" are appended when the architecture supports them.
func inferSafetensorsCapabilities(modelDir string) []string {
	caps := make([]string, 0, 3)
	caps = append(caps, "completion")
	// Qwen3.5 multimodal checkpoints use ConditionalGeneration architectures.
	if supportsVision(modelDir) {
		caps = append(caps, "vision")
	}
	if supportsThinking(modelDir) {
		caps = append(caps, "thinking")
	}
	return caps
}
// newLayerCreator returns a LayerCreator callback for creating config/JSON layers. // newLayerCreator returns a LayerCreator callback for creating config/JSON layers.
func newLayerCreator() create.LayerCreator { func newLayerCreator() create.LayerCreator {
return func(r io.Reader, mediaType, name string) (create.LayerInfo, error) { return func(r io.Reader, mediaType, name string) (create.LayerInfo, error) {
@ -338,6 +348,7 @@ func newManifestWriter(opts CreateOptions, capabilities []string, parserName, re
// Create config blob with version requirement // Create config blob with version requirement
configData := model.ConfigV2{ configData := model.ConfigV2{
ModelFormat: "safetensors", ModelFormat: "safetensors",
FileType: strings.ToLower(strings.TrimSpace(opts.Quantize)),
Capabilities: caps, Capabilities: caps,
Requires: MinOllamaVersion, Requires: MinOllamaVersion,
Parser: resolveParserName(opts.Modelfile, parserName), Parser: resolveParserName(opts.Modelfile, parserName),
@ -485,6 +496,34 @@ func supportsThinking(modelDir string) bool {
return false return false
} }
// supportsVision checks if the model supports image input based on its architecture.
// Qwen3.5 multimodal checkpoints are published as ConditionalGeneration architectures.
func supportsVision(modelDir string) bool {
configPath := filepath.Join(modelDir, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return false
}
var cfg struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return false
}
for _, arch := range cfg.Architectures {
archLower := strings.ToLower(arch)
if strings.Contains(archLower, "qwen3") && strings.Contains(archLower, "conditionalgeneration") {
return true
}
}
typeLower := strings.ToLower(cfg.ModelType)
return strings.Contains(typeLower, "qwen3") && strings.Contains(typeLower, "conditionalgeneration")
}
// getParserName returns the parser name for a model based on its architecture. // getParserName returns the parser name for a model based on its architecture.
// This reads the config.json from the model directory and determines the appropriate parser. // This reads the config.json from the model directory and determines the appropriate parser.
func getParserName(modelDir string) string { func getParserName(modelDir string) string {

View file

@ -3,11 +3,15 @@ package client
import ( import (
"encoding/json" "encoding/json"
"os" "os"
"path/filepath"
"slices"
"strings" "strings"
"testing" "testing"
"github.com/ollama/ollama/manifest" "github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/create"
) )
func TestModelfileConfig(t *testing.T) { func TestModelfileConfig(t *testing.T) {
@ -120,8 +124,8 @@ func TestMinOllamaVersion(t *testing.T) {
if MinOllamaVersion == "" { if MinOllamaVersion == "" {
t.Error("MinOllamaVersion should not be empty") t.Error("MinOllamaVersion should not be empty")
} }
if MinOllamaVersion != "0.14.0" { if MinOllamaVersion != "0.19.0" {
t.Errorf("MinOllamaVersion = %q, want %q", MinOllamaVersion, "0.14.0") t.Errorf("MinOllamaVersion = %q, want %q", MinOllamaVersion, "0.19.0")
} }
} }
@ -289,6 +293,52 @@ func TestCreateOptions_Defaults(t *testing.T) {
} }
} }
// TestInferSafetensorsCapabilities checks capability detection against
// text-only, multimodal, and non-qwen config.json contents.
func TestInferSafetensorsCapabilities(t *testing.T) {
	cases := []struct {
		name   string
		config string
		expect []string
	}{
		{
			name: "qwen3.5 text model",
			config: `{
				"architectures": ["Qwen3_5ForCausalLM"],
				"model_type": "qwen3"
			}`,
			expect: []string{"completion", "thinking"},
		},
		{
			name: "qwen3.5 multimodal model",
			config: `{
				"architectures": ["Qwen3_5ForConditionalGeneration"],
				"model_type": "qwen3"
			}`,
			expect: []string{"completion", "vision", "thinking"},
		},
		{
			name: "non-qwen conditional generation model",
			config: `{
				"architectures": ["SomeOtherForConditionalGeneration"],
				"model_type": "other"
			}`,
			expect: []string{"completion"},
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			dir := t.TempDir()
			cfgPath := filepath.Join(dir, "config.json")
			if err := os.WriteFile(cfgPath, []byte(tc.config), 0o644); err != nil {
				t.Fatal(err)
			}
			got := inferSafetensorsCapabilities(dir)
			if !slices.Equal(got, tc.expect) {
				t.Fatalf("inferSafetensorsCapabilities() = %#v, want %#v", got, tc.expect)
			}
		})
	}
}
func TestQuantizeSupported(t *testing.T) { func TestQuantizeSupported(t *testing.T) {
// This just verifies the function exists and returns a boolean // This just verifies the function exists and returns a boolean
// The actual value depends on build tags (mlx vs non-mlx) // The actual value depends on build tags (mlx vs non-mlx)
@ -339,3 +389,43 @@ func TestCreateModelfileLayersIncludesParameters(t *testing.T) {
t.Fatalf("temperature = %v, want %v", got["temperature"], float64(0.7)) t.Fatalf("temperature = %v, want %v", got["temperature"], float64(0.7))
} }
} }
// TestNewManifestWriter_PopulatesFileTypeFromQuantize verifies that the
// quantize option is lowercased and stored as file_type in the written
// manifest's config blob.
func TestNewManifestWriter_PopulatesFileTypeFromQuantize(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	opts := CreateOptions{
		ModelName: "test-quantized",
		ModelDir:  t.TempDir(),
		Quantize:  "MXFP8",
	}
	write := newManifestWriter(opts, []string{"completion"}, "qwen3", "qwen3")
	if err := write(opts.ModelName, create.LayerInfo{}, nil); err != nil {
		t.Fatalf("newManifestWriter() error = %v", err)
	}

	// Read back the config blob referenced by the manifest.
	mf, err := manifest.ParseNamedManifest(model.ParseName(opts.ModelName))
	if err != nil {
		t.Fatalf("ParseNamedManifest() error = %v", err)
	}
	blobPath, err := manifest.BlobsPath(mf.Config.Digest)
	if err != nil {
		t.Fatalf("BlobsPath() error = %v", err)
	}
	raw, err := os.ReadFile(blobPath)
	if err != nil {
		t.Fatalf("ReadFile() error = %v", err)
	}

	var cfg model.ConfigV2
	if err := json.Unmarshal(raw, &cfg); err != nil {
		t.Fatalf("Unmarshal() error = %v", err)
	}
	if got, want := cfg.FileType, "mxfp8"; got != want {
		t.Fatalf("FileType = %q, want %q", got, want)
	}
}

View file

@ -15,6 +15,10 @@ import (
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
) )
func canonicalQuantType(quantType string) string {
return strings.ToLower(strings.TrimSpace(quantType))
}
// modelConfig represents the HuggingFace config.json structure // modelConfig represents the HuggingFace config.json structure
type modelConfig struct { type modelConfig struct {
Architectures []string `json:"architectures"` Architectures []string `json:"architectures"`
@ -256,7 +260,7 @@ func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) {
} }
if info.QuantType != "" { if info.QuantType != "" {
quantType := strings.ToUpper(info.QuantType) quantType := canonicalQuantType(info.QuantType)
shape := make([]uint64, len(info.Shape)) shape := make([]uint64, len(info.Shape))
for i, s := range info.Shape { for i, s := range info.Shape {
@ -323,8 +327,8 @@ func GetSafetensorsDtype(name model.Name) (string, error) {
if err != nil { if err != nil {
continue continue
} }
if info.QuantType != "" { if quantType := canonicalQuantType(info.QuantType); quantType != "" {
return strings.ToUpper(info.QuantType), nil return quantType, nil
} }
// Only check the first tensor blob // Only check the first tensor blob
break break

View file

@ -705,8 +705,8 @@ func TestGetTensorInfoFromManifest_Quantized(t *testing.T) {
if tensor.Name != "model.layers.0.mlp.up_proj.weight" { if tensor.Name != "model.layers.0.mlp.up_proj.weight" {
t.Errorf("Name = %v, want model.layers.0.mlp.up_proj.weight", tensor.Name) t.Errorf("Name = %v, want model.layers.0.mlp.up_proj.weight", tensor.Name)
} }
if tensor.Type != "INT4" { if tensor.Type != "int4" {
t.Errorf("Type = %v, want INT4", tensor.Type) t.Errorf("Type = %v, want int4", tensor.Type)
} }
// Shape should be unpacked: 320 * 8 = 2560 // Shape should be unpacked: 320 * 8 = 2560
if len(tensor.Shape) != 2 || tensor.Shape[0] != 2560 || tensor.Shape[1] != 2560 { if len(tensor.Shape) != 2 || tensor.Shape[0] != 2560 || tensor.Shape[1] != 2560 {
@ -1196,6 +1196,17 @@ func TestGetTensorInfoFromManifest_Packed(t *testing.T) {
if !packedNames["model.layers.0.mlp.experts.0.gate_proj.weight"] { if !packedNames["model.layers.0.mlp.experts.0.gate_proj.weight"] {
t.Error("missing packed tensor: model.layers.0.mlp.experts.0.gate_proj.weight") t.Error("missing packed tensor: model.layers.0.mlp.experts.0.gate_proj.weight")
} }
packedTypes := make(map[string]string)
for _, r := range result[1:] {
packedTypes[r.Name] = r.Type
}
if packedTypes["model.layers.0.mlp.experts.0.down_proj.weight"] != "int8" {
t.Errorf("down_proj.Type = %v, want int8", packedTypes["model.layers.0.mlp.experts.0.down_proj.weight"])
}
if packedTypes["model.layers.0.mlp.experts.0.gate_proj.weight"] != "int4" {
t.Errorf("gate_proj.Type = %v, want int4", packedTypes["model.layers.0.mlp.experts.0.gate_proj.weight"])
}
} }
func TestReadSafetensorsHeader(t *testing.T) { func TestReadSafetensorsHeader(t *testing.T) {