This commit is contained in:
ParthSareen 2026-01-07 01:58:37 -08:00
parent 44179b7e53
commit 6b7456ca1f
4 changed files with 348 additions and 41 deletions

View file

@ -1104,3 +1104,108 @@ func PromptYesNo(question string) (bool, error) {
}
}
}
// CloudModelOption represents a suggested cloud model for the selection prompt.
type CloudModelOption struct {
Name string
Description string
}
// PromptModelChoice displays a model selection prompt with multiple options.
// Returns the selected model name, or empty string if user declined or cancelled.
func PromptModelChoice(question string, models []CloudModelOption) (string, error) {
fd := int(os.Stdin.Fd())
oldState, err := term.MakeRaw(fd)
if err != nil {
return "", err
}
defer term.Restore(fd, oldState)
// Build options: models + "No thanks, continue"
optionCount := len(models) + 1
selected := 0
// Total lines: question + models + "no thanks" + hint = optionCount + 2
totalLines := optionCount + 2
// Hide cursor
fmt.Fprint(os.Stderr, "\033[?25l")
defer fmt.Fprint(os.Stderr, "\033[?25h")
firstRender := true
render := func() {
if !firstRender {
fmt.Fprintf(os.Stderr, "\033[%dA\r", totalLines-1)
}
firstRender = false
// \r\n needed in raw mode for proper line breaks
fmt.Fprintf(os.Stderr, "\033[K\033[36m%s\033[0m\r\n", question)
for i, model := range models {
fmt.Fprintf(os.Stderr, "\033[K")
if i == selected {
fmt.Fprintf(os.Stderr, " \033[1;32m> %s\033[0m \033[90m%s\033[0m\r\n", model.Name, model.Description)
} else {
fmt.Fprintf(os.Stderr, " \033[90m%s %s\033[0m\r\n", model.Name, model.Description)
}
}
fmt.Fprintf(os.Stderr, "\033[K")
if selected == len(models) {
fmt.Fprintf(os.Stderr, " \033[1;32m> No thanks, continue\033[0m\r\n")
} else {
fmt.Fprintf(os.Stderr, " \033[90mNo thanks, continue\033[0m\r\n")
}
fmt.Fprintf(os.Stderr, "\033[K\033[90m(↑/↓ to navigate, Enter to confirm)\033[0m")
}
render()
buf := make([]byte, 3)
for {
n, err := os.Stdin.Read(buf)
if err != nil {
return "", err
}
if n == 1 {
switch buf[0] {
case 'j', 'J':
if selected < optionCount-1 {
selected++
}
render()
case 'k', 'K':
if selected > 0 {
selected--
}
render()
case '\r', '\n':
fmt.Fprintf(os.Stderr, "\n")
if selected < len(models) {
return models[selected].Name, nil
}
return "", nil
case 3: // Ctrl+C
fmt.Fprintf(os.Stderr, "\n")
return "", nil
}
} else if n == 3 && buf[0] == 27 && buf[1] == 91 {
switch buf[2] {
case 'A': // Up
if selected > 0 {
selected--
}
render()
case 'B': // Down
if selected < optionCount-1 {
selected++
}
render()
}
}
}
}

25
x/agent/prompt_test.go Normal file
View file

@ -0,0 +1,25 @@
package agent
import (
"testing"
)
func TestCloudModelOptionStruct(t *testing.T) {
// Test that the struct is defined correctly
models := []CloudModelOption{
{Name: "glm-4.7:cloud", Description: "GLM 4.7 Cloud"},
{Name: "qwen3-coder:480b-cloud", Description: "Qwen3 Coder 480B"},
}
if len(models) != 2 {
t.Errorf("expected 2 models, got %d", len(models))
}
if models[0].Name != "glm-4.7:cloud" {
t.Errorf("expected glm-4.7:cloud, got %s", models[0].Name)
}
if models[1].Description != "Qwen3 Coder 480B" {
t.Errorf("expected 'Qwen3 Coder 480B', got %s", models[1].Description)
}
}

41
x/cmd/cloudmodel_test.go Normal file
View file

@ -0,0 +1,41 @@
package cmd
import (
"errors"
"testing"
)
func TestCloudModelSwitchRequest(t *testing.T) {
// Test the error type
req := &CloudModelSwitchRequest{Model: "glm-4.7:cloud"}
// Test Error() method
errMsg := req.Error()
expected := "switch to model: glm-4.7:cloud"
if errMsg != expected {
t.Errorf("expected %q, got %q", expected, errMsg)
}
// Test errors.As
var err error = req
var switchReq *CloudModelSwitchRequest
if !errors.As(err, &switchReq) {
t.Error("errors.As should return true for CloudModelSwitchRequest")
}
if switchReq.Model != "glm-4.7:cloud" {
t.Errorf("expected model glm-4.7:cloud, got %s", switchReq.Model)
}
}
func TestSuggestedCloudModels(t *testing.T) {
// Verify the suggested models are defined
if len(suggestedCloudModels) == 0 {
t.Error("suggestedCloudModels should not be empty")
}
// Check first model
if suggestedCloudModels[0].Name != "glm-4.7:cloud" {
t.Errorf("expected first model to be glm-4.7:cloud, got %s", suggestedCloudModels[0].Name)
}
}

View file

@ -37,6 +37,22 @@ const (
charsPerToken = 4
)
// suggestedCloudModels are the models suggested to users after signing in.
// TODO(parthsareen): Dynamically recommend models based on user context instead of hardcoding
var suggestedCloudModels = []agent.CloudModelOption{
{Name: "glm-4.7:cloud", Description: "GLM 4.7 Cloud"},
{Name: "qwen3-coder:480b-cloud", Description: "Qwen3 Coder 480B"},
}
// CloudModelSwitchRequest signals that the user wants to switch to a different model.
type CloudModelSwitchRequest struct {
Model string
}
func (c *CloudModelSwitchRequest) Error() string {
return fmt.Sprintf("switch to model: %s", c.Model)
}
// isLocalModel checks if the model is running locally (not a cloud model).
// TODO: Improve local/cloud model identification - could check model metadata
func isLocalModel(modelName string) bool {
@ -119,6 +135,21 @@ func waitForOllamaSignin(ctx context.Context) error {
return nil
}
// promptCloudModelSuggestion shows cloud model suggestions after successful sign-in.
// Returns the selected model name, or empty string if user declines.
func promptCloudModelSuggestion() string {
fmt.Fprintf(os.Stderr, "\n")
fmt.Fprintf(os.Stderr, "\033[1;36mTry cloud models for free!\033[0m\n")
fmt.Fprintf(os.Stderr, "\033[90mCloud models offer powerful capabilities without local hardware requirements.\033[0m\n")
fmt.Fprintf(os.Stderr, "\n")
selectedModel, err := agent.PromptModelChoice("Try a cloud model now?", suggestedCloudModels)
if err != nil || selectedModel == "" {
return ""
}
return selectedModel
}
// RunOptions contains options for running an interactive agent session.
type RunOptions struct {
Model string
@ -144,6 +175,40 @@ type RunOptions struct {
// LastToolOutputTruncated stores the truncated version shown inline
LastToolOutputTruncated *string
// ActiveModel points to the current model name - can be updated mid-turn
// for model switching. If nil, opts.Model is used.
ActiveModel *string
}
// getActiveModel returns the current model name, checking ActiveModel pointer first.
func getActiveModel(opts *RunOptions) string {
if opts.ActiveModel != nil && *opts.ActiveModel != "" {
return *opts.ActiveModel
}
return opts.Model
}
// showModelConnection displays "Connecting to X on ollama.com" for cloud models.
func showModelConnection(ctx context.Context, modelName string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
info, err := client.Show(ctx, &api.ShowRequest{Model: modelName})
if err != nil {
return err
}
if info.RemoteHost != "" {
if strings.HasPrefix(info.RemoteHost, "https://ollama.com") {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
} else {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
}
}
return nil
}
// Chat runs an agent chat loop with tool support.
@ -243,7 +308,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
// Agentic loop: continue until no more tool calls
for {
req := &api.ChatRequest{
Model: opts.Model,
Model: getActiveModel(&opts),
Messages: messages,
Format: json.RawMessage(opts.Format),
Options: opts.Options,
@ -267,7 +332,6 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
return nil, nil
}
// Check for 401 Unauthorized - prompt user to sign in
var authErr api.AuthorizationError
if errors.As(err, &authErr) {
p.StopAndClear()
@ -275,9 +339,13 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
if promptErr == nil && result {
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
// Retry the chat request
suggestedModel := promptCloudModelSuggestion()
if suggestedModel != "" {
return nil, &CloudModelSwitchRequest{Model: suggestedModel}
}
fmt.Fprintf(os.Stderr, "\033[90mRetrying...\033[0m\n")
continue // Retry the loop
continue
}
}
return nil, fmt.Errorf("authentication required - run 'ollama signin' to authenticate")
@ -415,19 +483,20 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
fmt.Fprintf(os.Stderr, "\033[90m▶ Running: %s\033[0m\n", formatToolShort(toolName, args))
}
// Execute the tool
toolResult, err := toolRegistry.Execute(call)
if err != nil {
// Check if web search needs authentication
if errors.Is(err, tools.ErrWebSearchAuthRequired) {
// Prompt user to sign in
fmt.Fprintf(os.Stderr, "\033[33m Web search requires authentication.\033[0m\n")
result, promptErr := agent.PromptYesNo("Sign in to Ollama?")
if promptErr == nil && result {
// Get signin URL and wait for auth completion
if signinErr := waitForOllamaSignin(ctx); signinErr == nil {
// Retry the web search
fmt.Fprintf(os.Stderr, "\033[90m Retrying web search...\033[0m\n")
suggestedModel := promptCloudModelSuggestion()
if suggestedModel != "" && opts.ActiveModel != nil {
*opts.ActiveModel = suggestedModel
showModelConnection(ctx, suggestedModel)
}
fmt.Fprintf(os.Stderr, "\033[90mRetrying web search...\033[0m\n")
toolResult, err = toolRegistry.Execute(call)
if err == nil {
goto toolSuccess
@ -466,7 +535,7 @@ func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
}
// Truncate output to prevent context overflow
toolResultForLLM := truncateToolOutput(toolResult, opts.Model)
toolResultForLLM := truncateToolOutput(toolResult, getActiveModel(&opts))
toolResults = append(toolResults, api.Message{
Role: "tool",
@ -625,25 +694,28 @@ func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
return out
}
// checkModelCapabilities checks if the model supports tools.
func checkModelCapabilities(ctx context.Context, modelName string) (supportsTools bool, err error) {
// checkModelCapabilities checks if the model supports tools and thinking.
func checkModelCapabilities(ctx context.Context, modelName string) (supportsTools bool, supportsThinking bool, err error) {
client, err := api.ClientFromEnvironment()
if err != nil {
return false, err
return false, false, err
}
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelName})
if err != nil {
return false, err
return false, false, err
}
for _, cap := range resp.Capabilities {
if cap == model.CapabilityTools {
return true, nil
supportsTools = true
}
if cap == model.CapabilityThinking {
supportsThinking = true
}
}
return false, nil
return supportsTools, supportsThinking, nil
}
// GenerateInteractive runs an interactive agent session.
@ -663,13 +735,17 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
fmt.Print(readline.StartBracketedPaste)
defer fmt.Printf(readline.EndBracketedPaste)
// Check if model supports tools
supportsTools, err := checkModelCapabilities(cmd.Context(), modelName)
// Check if model supports tools and thinking
supportsTools, supportsThinking, err := checkModelCapabilities(cmd.Context(), modelName)
if err != nil {
fmt.Fprintf(os.Stderr, "\033[33mWarning: Could not check model capabilities: %v\033[0m\n", err)
supportsTools = false
supportsThinking = false
}
// Track if session is using thinking mode
usingThinking := think != nil && supportsThinking
// Create tool registry only if model supports tools
var toolRegistry *tools.Registry
if supportsTools {
@ -757,30 +833,44 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
if sb.Len() > 0 {
newMessage := api.Message{Role: "user", Content: sb.String()}
messages = append(messages, newMessage)
opts := RunOptions{
Model: modelName,
Messages: messages,
WordWrap: wordWrap,
Options: options,
Think: think,
HideThinking: hideThinking,
KeepAlive: keepAlive,
Tools: toolRegistry,
Approval: approval,
YoloMode: yoloMode,
LastToolOutput: &lastToolOutput,
LastToolOutputTruncated: &lastToolOutputTruncated,
}
// Reset expanded state for new tool execution
toolOutputExpanded = false
assistant, err := Chat(cmd.Context(), opts)
if err != nil {
return err
}
if assistant != nil {
messages = append(messages, *assistant)
retryChat:
for {
opts := RunOptions{
Model: modelName,
Messages: messages,
WordWrap: wordWrap,
Options: options,
Think: think,
HideThinking: hideThinking,
KeepAlive: keepAlive,
Tools: toolRegistry,
Approval: approval,
YoloMode: yoloMode,
LastToolOutput: &lastToolOutput,
LastToolOutputTruncated: &lastToolOutputTruncated,
ActiveModel: &modelName,
}
assistant, err := Chat(cmd.Context(), opts)
if err != nil {
var switchReq *CloudModelSwitchRequest
if errors.As(err, &switchReq) {
newModel := switchReq.Model
if err := switchToModel(cmd.Context(), newModel, &modelName, &supportsTools, &supportsThinking, &toolRegistry, usingThinking); err != nil {
fmt.Fprintf(os.Stderr, "\033[33m%v\033[0m\n", err)
fmt.Fprintf(os.Stderr, "\033[90mContinuing with %s...\033[0m\n", modelName)
}
continue retryChat
}
return err
}
if assistant != nil {
messages = append(messages, *assistant)
}
break retryChat
}
sb.Reset()
@ -788,6 +878,52 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
}
}
// switchToModel handles model switching with capability checks and UI updates.
func switchToModel(ctx context.Context, newModel string, modelName *string, supportsTools, supportsThinking *bool, toolRegistry **tools.Registry, usingThinking bool) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return fmt.Errorf("could not create client: %w", err)
}
newSupportsTools, newSupportsThinking, capErr := checkModelCapabilities(ctx, newModel)
if capErr != nil {
return fmt.Errorf("could not check model capabilities: %w", capErr)
}
// TODO(parthsareen): Handle thinking -> non-thinking model switch gracefully
if usingThinking && !newSupportsThinking {
return fmt.Errorf("%s does not support thinking mode", newModel)
}
// Show "Connecting to X on ollama.com" for cloud models
info, err := client.Show(ctx, &api.ShowRequest{Model: newModel})
if err == nil && info.RemoteHost != "" {
if strings.HasPrefix(info.RemoteHost, "https://ollama.com") {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
} else {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
}
}
*modelName = newModel
*supportsTools = newSupportsTools
*supportsThinking = newSupportsThinking
if *supportsTools {
if *toolRegistry == nil {
*toolRegistry = tools.DefaultRegistry()
}
if (*toolRegistry).Count() > 0 {
fmt.Fprintf(os.Stderr, "\033[90mTools available: %s\033[0m\n", strings.Join((*toolRegistry).Names(), ", "))
}
} else {
*toolRegistry = nil
fmt.Fprintf(os.Stderr, "\033[33mNote: Model does not support tools - running in chat-only mode\033[0m\n")
}
return nil
}
// showToolsStatus displays the current tools and approval status.
func showToolsStatus(registry *tools.Registry, approval *agent.ApprovalManager, supportsTools bool) {
if !supportsTools || registry == nil {