package mlxrunner import ( "context" "errors" "log/slog" "net" "net/http" "strings" "golang.org/x/sync/errgroup" "github.com/ollama/ollama/api" "github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/model" "github.com/ollama/ollama/x/mlxrunner/model/base" "github.com/ollama/ollama/x/mlxrunner/sample" "github.com/ollama/ollama/x/tokenizer" ) // Request is a short-lived struct that carries a completion request through // a channel from the HTTP handler to the runner goroutine. The ctx field // must travel with the request so that cancellation propagates across the // channel boundary. type Request struct { CompletionRequest Responses chan CompletionResponse Pipeline func(context.Context, Request) error Ctx context.Context //nolint:containedctx Tokens []int32 Sampler *sample.Sampler } type Runner struct { Model base.Model Tokenizer *tokenizer.Tokenizer Requests chan Request cache kvCache contextLength int } func (r *Runner) Load(modelName string) error { root, err := model.Open(modelName) if err != nil { return err } defer root.Close() m, err := base.New(root) if err != nil { return err } // Load all tensor blobs from manifest tensors, err := loadTensorsFromManifest(root) if err != nil { return err } // Assign weights to model (model-specific logic) loadWeights := base.Weights(m) if err := loadWeights(tensors); err != nil { return err } r.Model = m r.Tokenizer = m.Tokenizer() r.contextLength = m.MaxContextLength() mlx.EnableCompile() return nil } // loadTensorsFromManifest loads all tensor blobs from the manifest into a // flat map, deduplicating by digest and remapping safetensors key suffixes. // // Uses a two-phase approach: first loads all raw tensors, then remaps // .bias → _qbias with complete knowledge of which base names have .scale // entries. This avoids a race condition where Go map iteration order could // cause .bias to be processed before .scale within the same blob. func loadTensorsFromManifest(root *model.Root) (map[string]*mlx.Array, error) { // Phase 1: Load all tensors raw from all blobs rawTensors := make(map[string]*mlx.Array) seen := make(map[string]bool) for _, layer := range root.Manifest.GetTensorLayers("") { if seen[layer.Digest] { continue } seen[layer.Digest] = true blobPath := root.Manifest.BlobPath(layer.Digest) for name, arr := range mlx.Load(blobPath) { rawTensors[name] = arr } } // Phase 2: Identify all base names that have .scale tensors and remap them scaleBaseNames := make(map[string]bool) allTensors := make(map[string]*mlx.Array, len(rawTensors)) for name, arr := range rawTensors { if strings.HasSuffix(name, ".scale") { baseName := strings.TrimSuffix(name, ".scale") allTensors[baseName+"_scale"] = arr scaleBaseNames[baseName] = true } } // Phase 3: Process remaining tensors with complete scale knowledge for name, arr := range rawTensors { if strings.HasSuffix(name, ".scale") { continue // already handled } if strings.HasSuffix(name, ".bias") && !strings.HasSuffix(name, ".weight_qbias") { baseName := strings.TrimSuffix(name, ".bias") if scaleBaseNames[baseName] { allTensors[baseName+"_qbias"] = arr } else { allTensors[name] = arr } } else { allTensors[name] = arr } } slog.Info("Loaded tensors from manifest", "count", len(allTensors)) return allTensors, nil } func (r *Runner) Run(host, port string, mux http.Handler) error { g, ctx := errgroup.WithContext(context.Background()) g.Go(func() error { for { select { case <-ctx.Done(): return nil case request := <-r.Requests: if err := request.Pipeline(request.Ctx, request); err != nil { slog.Info("Request terminated", "error", err) var statusErr api.StatusError if !errors.As(err, &statusErr) { statusErr = api.StatusError{ StatusCode: http.StatusInternalServerError, ErrorMessage: err.Error(), } } select { case request.Responses <- CompletionResponse{Error: &statusErr}: case <-request.Ctx.Done(): } } close(request.Responses) } } }) g.Go(func() error { slog.Info("Starting HTTP server", "host", host, "port", port) return http.ListenAndServe(net.JoinHostPort(host, port), mux) }) return g.Wait() }