From 300da692caae73f405abc7e16499d55d20ee7e64 Mon Sep 17 00:00:00 2001 From: reugn Date: Wed, 4 Dec 2024 11:59:17 +0200 Subject: [PATCH] feat!: add support for model configuration and persistent history storage --- README.md | 85 +++++++++--- cmd/gemini/.gitignore | 1 + cmd/gemini/main.go | 37 ++++-- gemini/chat_session.go | 50 ++++++- gemini/generative_model_builder.go | 134 +++++++++++++++++++ gemini/serializable_content.go | 52 ++++++++ gemini/system_instruction.go | 13 ++ internal/chat/chat.go | 146 +++++++++++++++++++++ internal/chat/chat_options.go | 9 ++ internal/cli/chat.go | 120 ----------------- internal/cli/command.go | 173 ------------------------- internal/cli/prompt.go | 22 ++-- internal/cli/select.go | 51 -------- internal/cli/spinner.go | 46 ++++--- internal/config/application_data.go | 38 ++++++ internal/config/configuration.go | 61 +++++++++ internal/handler/gemini_query.go | 59 +++++++++ internal/handler/handler.go | 8 ++ internal/handler/history_command.go | 171 ++++++++++++++++++++++++ internal/handler/input_mode_command.go | 80 ++++++++++++ internal/handler/model_command.go | 110 ++++++++++++++++ internal/handler/prompt_command.go | 75 +++++++++++ internal/handler/quit_command.go | 18 +++ internal/handler/response.go | 46 +++++++ internal/handler/system_command.go | 63 +++++++++ 25 files changed, 1257 insertions(+), 411 deletions(-) create mode 100644 gemini/generative_model_builder.go create mode 100644 gemini/serializable_content.go create mode 100644 gemini/system_instruction.go create mode 100644 internal/chat/chat.go create mode 100644 internal/chat/chat_options.go delete mode 100644 internal/cli/chat.go delete mode 100644 internal/cli/command.go delete mode 100644 internal/cli/select.go create mode 100644 internal/config/application_data.go create mode 100644 internal/config/configuration.go create mode 100644 internal/handler/gemini_query.go create mode 100644 internal/handler/handler.go create mode 100644 
internal/handler/history_command.go create mode 100644 internal/handler/input_mode_command.go create mode 100644 internal/handler/model_command.go create mode 100644 internal/handler/prompt_command.go create mode 100644 internal/handler/quit_command.go create mode 100644 internal/handler/response.go create mode 100644 internal/handler/system_command.go diff --git a/README.md b/README.md index 47a1a76..08e1cf2 100644 --- a/README.md +++ b/README.md @@ -8,16 +8,16 @@ A command-line interface (CLI) for [Google Gemini](https://deepmind.google/techn Google Gemini is a family of multimodal artificial intelligence (AI) large language models that have capabilities in language, audio, code and video understanding. -The current version only supports multi-turn conversations (chat), using the `gemini-pro` model. +This application offers a command-line interface for interacting with various generative models through +multi-turn chat. Model selection is controlled via [system command](#system-commands) inputs. ## Installation -Choose a binary from the [releases](https://github.com/reugn/gemini-cli/releases). +Choose a binary from [releases](https://github.com/reugn/gemini-cli/releases). ### Build from Source Download and [install Go](https://golang.org/doc/install). Install the application: - ```sh go install github.com/reugn/gemini-cli/cmd/gemini@latest ``` @@ -25,24 +25,73 @@ go install github.com/reugn/gemini-cli/cmd/gemini@latest See the [go install](https://go.dev/ref/mod#go-install) instructions for more information about the command. ## Usage +> [!NOTE] +> For information on the available regions for the Gemini API and Google AI Studio, +> see [here](https://ai.google.dev/available_regions#available_regions). ### API key To use `gemini-cli`, you'll need an API key set in the `GEMINI_API_KEY` environment variable. If you don't already have one, create a key in [Google AI Studio](https://makersuite.google.com/app/apikey). 
-> [!NOTE] -> For information on the available regions for the Gemini API and Google AI Studio, see [here](https://ai.google.dev/available_regions#available_regions). +To set the environment variable in the terminal: +```console +export GEMINI_API_KEY= +``` ### System commands The system chat message must begin with an exclamation mark and is used for internal operations. A short list of supported system commands: -| Command | Description | -|---------|------------------------------------------------------| -| !q | Quit the application | -| !p | Delete the history used as chat context by the model | -| !i | Toggle input mode (single-line <-> multi-line) | -| !m | Select generative model | +| Command | Description | +|---------|----------------------------------------------------| +| !q | Quit the application | +| !p | Select the system prompt for the chat 1 | +| !i | Toggle input mode (single-line <-> multi-line) | +| !m | Select a model operation 2 | +| !h | Select a history operation 3 | + +1 System instruction (also known as "system prompt") is a more forceful prompt to the model. +The model will adhere to the instructions more strongly than if they appeared in a normal prompt. +The system instructions must be specified by the user in the [configuration file](#configuration-file). + +2 Model operations: +* Select a generative model from the list of available models +* Show the selected model information + +3 History operations: +* Clear the chat history +* Store the chat history to the configuration file +* Load a chat history record from the configuration file +* Delete all history records from the configuration file + +### Configuration file +The application uses a configuration file to store generative model settings and chat history. This file is optional. +If it doesn't exist, the application will attempt to create it using default values. You can use the +[config flag](#cli-help) to specify the location of the configuration file. 
+ +An example of basic configuration: +```json +{ + "SystemPrompts": { + "Software Engineer": "You are an experienced software engineer.", + "Technical Writer": "Act as a tech writer. I will provide you with the basic steps of an app functionality, and you will come up with an engaging article on how to do those steps." + }, + "SafetySettings": [ + { + "Category": 7, + "Threshold": 1 + }, + { + "Category": 10, + "Threshold": 1 + } + ], + "History": { + } +} +``` +Upon user request, the `History` map will be populated with records. Note that the chat history is stored in plain +text format. See [history operations](#system-commands) for details. ### CLI help ```console @@ -53,13 +102,13 @@ Usage: [flags] Flags: - -f, --format render markdown-formatted response (default true) - -h, --help help for this command - -m, --model string generative model name (default "gemini-pro") - --multiline read input as a multi-line string - -s, --style string markdown format style (ascii, dark, light, pink, notty, dracula) (default "auto") - -t, --term string multi-line input terminator (default "$") - -v, --version version for this command + -c, --config string path to configuration file in JSON format (default "gemini_cli_config.json") + -h, --help help for this command + -m, --model string generative model name (default "gemini-pro") + --multiline read input as a multi-line string + -s, --style string markdown format style (ascii, dark, light, pink, notty, dracula) (default "auto") + -t, --term string multi-line input terminator (default "$") + -v, --version version for this command ``` ## License diff --git a/cmd/gemini/.gitignore b/cmd/gemini/.gitignore index 18edafa..8995900 100644 --- a/cmd/gemini/.gitignore +++ b/cmd/gemini/.gitignore @@ -1 +1,2 @@ gemini +*.json diff --git a/cmd/gemini/main.go b/cmd/gemini/main.go index 2c97370..40060a2 100644 --- a/cmd/gemini/main.go +++ b/cmd/gemini/main.go @@ -6,13 +6,15 @@ import ( "os/user" "github.com/reugn/gemini-cli/gemini" - 
"github.com/reugn/gemini-cli/internal/cli" + "github.com/reugn/gemini-cli/internal/chat" + "github.com/reugn/gemini-cli/internal/config" "github.com/spf13/cobra" ) const ( - version = "0.3.1" - apiKeyEnv = "GEMINI_API_KEY" //nolint:gosec + version = "0.3.1" + apiKeyEnv = "GEMINI_API_KEY" //nolint:gosec + defaultConfigPath = "gemini_cli_config.json" ) func run() int { @@ -21,26 +23,39 @@ func run() int { Version: version, } - var opts cli.ChatOpts - rootCmd.Flags().StringVarP(&opts.Model, "model", "m", gemini.DefaultModel, "generative model name") - rootCmd.Flags().BoolVarP(&opts.Format, "format", "f", true, "render markdown-formatted response") + var opts chat.Opts + var configPath string + rootCmd.Flags().StringVarP(&opts.GenerativeModel, "model", "m", gemini.DefaultModel, + "generative model name") rootCmd.Flags().StringVarP(&opts.Style, "style", "s", "auto", "markdown format style (ascii, dark, light, pink, notty, dracula)") - rootCmd.Flags().BoolVar(&opts.Multiline, "multiline", false, "read input as a multi-line string") - rootCmd.Flags().StringVarP(&opts.Terminator, "term", "t", "$", "multi-line input terminator") + rootCmd.Flags().BoolVar(&opts.Multiline, "multiline", false, + "read input as a multi-line string") + rootCmd.Flags().StringVarP(&opts.LineTerminator, "term", "t", "$", + "multi-line input terminator") + rootCmd.Flags().StringVarP(&configPath, "config", "c", defaultConfigPath, + "path to configuration file in JSON format") rootCmd.RunE = func(_ *cobra.Command, _ []string) error { + configuration, err := config.NewConfiguration(configPath) + if err != nil { + return err + } + + modelBuilder := gemini.NewGenerativeModelBuilder(). + WithName(opts.GenerativeModel). 
+ WithSafetySettings(configuration.Data.SafetySettings) apiKey := os.Getenv(apiKeyEnv) - chatSession, err := gemini.NewChatSession(context.Background(), opts.Model, apiKey) + chatSession, err := gemini.NewChatSession(context.Background(), modelBuilder, apiKey) if err != nil { return err } - chat, err := cli.NewChat(getCurrentUser(), chatSession, &opts) + chatHandler, err := chat.New(getCurrentUser(), chatSession, configuration, os.Stdout, &opts) if err != nil { return err } - chat.StartChat() + chatHandler.Start() return chatSession.Close() } diff --git a/gemini/chat_session.go b/gemini/chat_session.go index 0c5aea6..62fb331 100644 --- a/gemini/chat_session.go +++ b/gemini/chat_session.go @@ -2,6 +2,8 @@ package gemini import ( "context" + "encoding/json" + "fmt" "sync" "github.com/google/generative-ai-go/genai" @@ -15,6 +17,7 @@ type ChatSession struct { ctx context.Context client *genai.Client + model *genai.GenerativeModel session *genai.ChatSession loadModels sync.Once @@ -22,16 +25,20 @@ type ChatSession struct { } // NewChatSession returns a new [ChatSession]. -func NewChatSession(ctx context.Context, model, apiKey string) (*ChatSession, error) { +func NewChatSession( + ctx context.Context, modelBuilder *GenerativeModelBuilder, apiKey string, +) (*ChatSession, error) { client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey)) if err != nil { return nil, err } + generativeModel := modelBuilder.build(client) return &ChatSession{ ctx: ctx, client: client, - session: client.GenerativeModel(model).StartChat(), + model: generativeModel, + session: generativeModel.StartChat(), }, nil } @@ -45,14 +52,33 @@ func (c *ChatSession) SendMessageStream(input string) *genai.GenerateContentResp return c.session.SendMessageStream(c.ctx, genai.Text(input)) } -// SetGenerativeModel sets the name of the generative model for the chat. -// It preserves the history from the previous chat session. 
-func (c *ChatSession) SetGenerativeModel(model string) { +// SetModel sets a new generative model configured with the builder and starts +// a new chat session. It preserves the history of the previous chat session. +func (c *ChatSession) SetModel(modelBuilder *GenerativeModelBuilder) { history := c.session.History - c.session = c.client.GenerativeModel(model).StartChat() + c.model = modelBuilder.build(c.client) + c.session = c.model.StartChat() c.session.History = history } +// CopyModelBuilder returns a copy builder for the chat generative model. +func (c *ChatSession) CopyModelBuilder() *GenerativeModelBuilder { + return newCopyGenerativeModelBuilder(c.model) +} + +// ModelInfo returns information about the chat generative model in JSON format. +func (c *ChatSession) ModelInfo() (string, error) { + modelInfo, err := c.model.Info(c.ctx) + if err != nil { + return "", err + } + encoded, err := json.MarshalIndent(modelInfo, "", " ") + if err != nil { + return "", fmt.Errorf("error encoding model info: %w", err) + } + return string(encoded), nil +} + // ListModels returns a list of the supported generative model names. func (c *ChatSession) ListModels() []string { c.loadModels.Do(func() { @@ -69,7 +95,17 @@ func (c *ChatSession) ListModels() []string { return c.models } -// ClearHistory clears chat history. +// GetHistory returns the chat session history. +func (c *ChatSession) GetHistory() []*genai.Content { + return c.session.History +} + +// SetHistory sets the chat session history. +func (c *ChatSession) SetHistory(content []*genai.Content) { + c.session.History = content +} + +// ClearHistory clears the chat session history. 
func (c *ChatSession) ClearHistory() { c.session.History = make([]*genai.Content, 0) } diff --git a/gemini/generative_model_builder.go b/gemini/generative_model_builder.go new file mode 100644 index 0000000..724a404 --- /dev/null +++ b/gemini/generative_model_builder.go @@ -0,0 +1,134 @@ +package gemini + +import ( + "github.com/google/generative-ai-go/genai" +) + +type boxed[T any] struct { + value T +} + +// GenerativeModelBuilder implements the builder pattern for [genai.GenerativeModel]. +type GenerativeModelBuilder struct { + copy *genai.GenerativeModel + + name *boxed[string] + generationConfig *boxed[genai.GenerationConfig] + safetySettings *boxed[[]*genai.SafetySetting] + tools *boxed[[]*genai.Tool] + toolConfig *boxed[*genai.ToolConfig] + systemInstruction *boxed[*genai.Content] + cachedContentName *boxed[string] +} + +// NewGenerativeModelBuilder returns a new [GenerativeModelBuilder] with empty default values. +func NewGenerativeModelBuilder() *GenerativeModelBuilder { + return &GenerativeModelBuilder{} +} + +// newCopyGenerativeModelBuilder creates a new [GenerativeModelBuilder], +// taking the default values from an existing [genai.GenerativeModel] object. +func newCopyGenerativeModelBuilder(copy *genai.GenerativeModel) *GenerativeModelBuilder { + return &GenerativeModelBuilder{copy: copy} +} + +// WithName sets the model name. +func (b *GenerativeModelBuilder) WithName( + modelName string, +) *GenerativeModelBuilder { + b.name = &boxed[string]{modelName} + return b +} + +// WithGenerationConfig sets the generation config. +func (b *GenerativeModelBuilder) WithGenerationConfig( + generationConfig genai.GenerationConfig, +) *GenerativeModelBuilder { + b.generationConfig = &boxed[genai.GenerationConfig]{generationConfig} + return b +} + +// WithSafetySettings sets the safety settings. 
+func (b *GenerativeModelBuilder) WithSafetySettings( + safetySettings []*genai.SafetySetting, +) *GenerativeModelBuilder { + b.safetySettings = &boxed[[]*genai.SafetySetting]{safetySettings} + return b +} + +// WithTools sets the tools. +func (b *GenerativeModelBuilder) WithTools( + tools []*genai.Tool, +) *GenerativeModelBuilder { + b.tools = &boxed[[]*genai.Tool]{tools} + return b +} + +// WithToolConfig sets the tool config. +func (b *GenerativeModelBuilder) WithToolConfig( + toolConfig *genai.ToolConfig, +) *GenerativeModelBuilder { + b.toolConfig = &boxed[*genai.ToolConfig]{toolConfig} + return b +} + +// WithSystemInstruction sets the system instruction. +func (b *GenerativeModelBuilder) WithSystemInstruction( + systemInstruction *genai.Content, +) *GenerativeModelBuilder { + b.systemInstruction = &boxed[*genai.Content]{systemInstruction} + return b +} + +// WithCachedContentName sets the name of the [genai.CachedContent] to use. +func (b *GenerativeModelBuilder) WithCachedContentName( + cachedContentName string, +) *GenerativeModelBuilder { + b.cachedContentName = &boxed[string]{cachedContentName} + return b +} + +// build builds and returns a new [genai.GenerativeModel] using the given [genai.Client]. +// It will panic if the copy and the model name are not set. +func (b *GenerativeModelBuilder) build(client *genai.Client) *genai.GenerativeModel { + if b.copy == nil && b.name == nil { + panic("model name is required") + } + + model := b.copy + if b.name != nil { + model = client.GenerativeModel(b.name.value) + if b.copy != nil { + model.GenerationConfig = b.copy.GenerationConfig + model.SafetySettings = b.copy.SafetySettings + model.Tools = b.copy.Tools + model.ToolConfig = b.copy.ToolConfig + model.SystemInstruction = b.copy.SystemInstruction + model.CachedContentName = b.copy.CachedContentName + } + } + b.configure(model) + return model +} + +// configure configures the given generative model using the builder values. 
+func (b *GenerativeModelBuilder) configure(model *genai.GenerativeModel) { + if b.generationConfig != nil { + model.GenerationConfig = b.generationConfig.value + } + if b.safetySettings != nil { + model.SafetySettings = b.safetySettings.value + } + if b.tools != nil { + model.Tools = b.tools.value + } + if b.toolConfig != nil { + model.ToolConfig = b.toolConfig.value + } + if b.systemInstruction != nil { + model.SystemInstruction = b.systemInstruction.value + } + if b.cachedContentName != nil { + model.CachedContentName = b.cachedContentName.value + } +} diff --git a/gemini/serializable_content.go b/gemini/serializable_content.go new file mode 100644 index 0000000..1b267c8 --- /dev/null +++ b/gemini/serializable_content.go @@ -0,0 +1,52 @@ +package gemini + +import ( + "fmt" + + "github.com/google/generative-ai-go/genai" +) + +// SerializableContent is the data type containing multipart text message content. +// It is a serializable equivalent of [genai.Content], where message content parts +// are represented as strings. +type SerializableContent struct { + // Ordered parts that constitute a single message. + Parts []string + // The producer of the content. Must be either 'user' or 'model'. + Role string +} + +// NewSerializableContent instantiates and returns a new SerializableContent from +// the given [genai.Content]. +// It will panic if the content type is not supported. +func NewSerializableContent(c *genai.Content) *SerializableContent { + parts := make([]string, len(c.Parts)) + for i, part := range c.Parts { + parts[i] = partToString(part) + } + return &SerializableContent{ + Parts: parts, + Role: c.Role, + } +} + +// ToContent converts the SerializableContent into a [genai.Content]. 
+func (c *SerializableContent) ToContent() *genai.Content { + parts := make([]genai.Part, len(c.Parts)) + for i, part := range c.Parts { + parts[i] = genai.Text(part) + } + return &genai.Content{ + Parts: parts, + Role: c.Role, + } +} + +func partToString(part genai.Part) string { + switch p := part.(type) { + case genai.Text: + return string(p) + default: + panic(fmt.Errorf("unsupported part type: %T", part)) + } +} diff --git a/gemini/system_instruction.go b/gemini/system_instruction.go new file mode 100644 index 0000000..347470b --- /dev/null +++ b/gemini/system_instruction.go @@ -0,0 +1,13 @@ +package gemini + +import "github.com/google/generative-ai-go/genai" + +// SystemInstruction represents a serializable system prompt, a more forceful +// instruction to the language model. The model will prioritize adhering to +// system instructions over regular prompts. +type SystemInstruction string + +// ToContent converts the SystemInstruction to [genai.Content]. +func (si SystemInstruction) ToContent() *genai.Content { + return genai.NewUserContent(genai.Text(si)) +} diff --git a/internal/chat/chat.go b/internal/chat/chat.go new file mode 100644 index 0000000..d7f55eb --- /dev/null +++ b/internal/chat/chat.go @@ -0,0 +1,146 @@ +package chat + +import ( + "errors" + "fmt" + "io" + "strings" + "time" + + "github.com/chzyer/readline" + "github.com/reugn/gemini-cli/gemini" + "github.com/reugn/gemini-cli/internal/cli" + "github.com/reugn/gemini-cli/internal/config" + "github.com/reugn/gemini-cli/internal/handler" +) + +// Chat handles the interactive exchange of messages between user and model. +type Chat struct { + reader *readline.Instance + writer io.Writer + terminalPrompt *cli.Prompt + opts *Opts + + geminiHandler handler.MessageHandler + systemHandler handler.MessageHandler +} + +// New returns a new Chat. 
+func New( + user string, session *gemini.ChatSession, + configuration *config.Configuration, writer io.Writer, opts *Opts, +) (*Chat, error) { + reader, err := readline.NewEx(&readline.Config{}) + if err != nil { + return nil, err + } + + terminalPrompt := cli.NewPrompt(user) + reader.SetPrompt(terminalPrompt.User) + if opts.Multiline { + // disable history for multiline input mode + reader.HistoryDisable() + } + + spinner := cli.NewSpinner(writer, time.Second, 5) + geminiHandler, err := handler.NewGeminiQuery(session, spinner, opts.Style) + if err != nil { + return nil, err + } + + systemHandler := handler.NewSystemCommand(session, configuration, reader, + &opts.Multiline, opts.GenerativeModel) + + return &Chat{ + terminalPrompt: terminalPrompt, + reader: reader, + writer: writer, + opts: opts, + geminiHandler: geminiHandler, + systemHandler: systemHandler, + }, nil +} + +// Start starts the main chat loop between user and model. +func (c *Chat) Start() { + for { + message, ok := c.read() + if !ok { + continue + } + + // process the message + messageHandler, terminalPrompt := c.getHandlerPrompt(message) + response, quit := messageHandler.Handle(message) + _ = response.Print(c.writer, terminalPrompt) + + if quit { + break + } + } +} + +func (c *Chat) read() (string, bool) { + if c.opts.Multiline { + return c.readMultiLine() + } + return c.readLine() +} + +func (c *Chat) readLine() (string, bool) { + input, err := c.reader.Readline() + if err != nil { + return c.handleReadError(len(input), err) + } + return validateInput(input) +} + +func (c *Chat) readMultiLine() (string, bool) { + var builder strings.Builder + term := c.opts.LineTerminator + for { + input, err := c.reader.Readline() + if err != nil { + c.reader.SetPrompt(c.terminalPrompt.User) + return c.handleReadError(builder.Len()+len(input), err) + } + + if strings.HasSuffix(input, term) || + strings.HasPrefix(input, handler.SystemCmdPrefix) { + builder.WriteString(strings.TrimSuffix(input, term)) + break + } 
+ + if builder.Len() == 0 { + c.reader.SetPrompt(c.terminalPrompt.UserNext) + } + + builder.WriteString(input + "\n") + } + c.reader.SetPrompt(c.terminalPrompt.User) + return validateInput(builder.String()) +} + +func (c *Chat) handleReadError(inputLen int, err error) (string, bool) { + if errors.Is(err, readline.ErrInterrupt) { + if inputLen == 0 { + return handler.SystemCmdPrefix + handler.SystemCmdQuit, true + } + } else { + handler.PrintError(c.writer, c.terminalPrompt.Cli, err) + } + return "", false +} + +func (c *Chat) getHandlerPrompt(message string) (handler.MessageHandler, string) { + if strings.HasPrefix(message, handler.SystemCmdPrefix) { + return c.systemHandler, c.terminalPrompt.Cli + } + _, _ = fmt.Fprint(c.writer, c.terminalPrompt.Gemini) + return c.geminiHandler, "" +} + +func validateInput(input string) (string, bool) { + input = strings.TrimSpace(input) + return input, input != "" +} diff --git a/internal/chat/chat_options.go b/internal/chat/chat_options.go new file mode 100644 index 0000000..51d1f66 --- /dev/null +++ b/internal/chat/chat_options.go @@ -0,0 +1,9 @@ +package chat + +// Opts represents the Chat configuration options. +type Opts struct { + GenerativeModel string + Style string + Multiline bool + LineTerminator string +} diff --git a/internal/cli/chat.go b/internal/cli/chat.go deleted file mode 100644 index e9d6f2f..0000000 --- a/internal/cli/chat.go +++ /dev/null @@ -1,120 +0,0 @@ -package cli - -import ( - "errors" - "fmt" - "strings" - - "github.com/chzyer/readline" - "github.com/reugn/gemini-cli/gemini" -) - -// ChatOpts represents Chat configuration options. -type ChatOpts struct { - Model string - Format bool - Style string - Multiline bool - Terminator string -} - -// Chat controls the chat flow. -type Chat struct { - model *gemini.ChatSession - prompt *prompt - reader *readline.Instance - opts *ChatOpts -} - -// NewChat returns a new Chat. 
-func NewChat(user string, model *gemini.ChatSession, opts *ChatOpts) (*Chat, error) { - reader, err := readline.NewEx(&readline.Config{}) - if err != nil { - return nil, err - } - prompt := newPrompt(user) - reader.SetPrompt(prompt.user) - if opts.Multiline { - reader.HistoryDisable() - } - return &Chat{ - model: model, - prompt: prompt, - reader: reader, - opts: opts, - }, nil -} - -// StartChat starts the chat loop. -func (c *Chat) StartChat() { - for { - message, ok := c.read() - if !ok { - continue - } - command := c.parseCommand(message) - if quit := command.run(message); quit { - break - } - } -} - -func (c *Chat) read() (string, bool) { - if c.opts.Multiline { - return c.readMultiLine() - } - return c.readLine() -} - -func (c *Chat) readLine() (string, bool) { - input, err := c.reader.Readline() - if err != nil { - return c.handleReadError(len(input), err) - } - return validateInput(input) -} - -func (c *Chat) readMultiLine() (string, bool) { - var builder strings.Builder - term := c.opts.Terminator - for { - input, err := c.reader.Readline() - if err != nil { - c.reader.SetPrompt(c.prompt.user) - return c.handleReadError(builder.Len()+len(input), err) - } - if strings.HasSuffix(input, term) { - builder.WriteString(strings.TrimSuffix(input, term)) - break - } - if builder.Len() == 0 { - c.reader.SetPrompt(c.prompt.userNext) - } - builder.WriteString(input + "\n") - } - c.reader.SetPrompt(c.prompt.user) - return validateInput(builder.String()) -} - -func (c *Chat) parseCommand(message string) command { - if strings.HasPrefix(message, systemCmdPrefix) { - return newSystemCommand(c) - } - return newGeminiCommand(c) -} - -func (c *Chat) handleReadError(inputLen int, err error) (string, bool) { - if errors.Is(err, readline.ErrInterrupt) { - if inputLen == 0 { - return systemCmdQuit, true - } - return "", false - } - fmt.Printf("%s%s\n", c.prompt.cli, err) - return "", false -} - -func validateInput(input string) (string, bool) { - input = 
strings.TrimSpace(input) - return input, input != "" -} diff --git a/internal/cli/command.go b/internal/cli/command.go deleted file mode 100644 index 7215714..0000000 --- a/internal/cli/command.go +++ /dev/null @@ -1,173 +0,0 @@ -package cli - -import ( - "bufio" - "errors" - "fmt" - "os" - "strings" - "time" - - "github.com/charmbracelet/glamour" - "github.com/reugn/gemini-cli/internal/cli/color" - "google.golang.org/api/iterator" -) - -const ( - systemCmdPrefix = "!" - systemCmdQuit = "!q" - systemCmdPurgeHistory = "!p" - systemCmdSelectInputMode = "!i" - systemCmdSelectModel = "!m" -) - -type command interface { - run(message string) bool -} - -type systemCommand struct { - chat *Chat -} - -var _ command = (*systemCommand)(nil) - -func newSystemCommand(chat *Chat) command { - return &systemCommand{ - chat: chat, - } -} - -func (c *systemCommand) run(message string) bool { - switch message { - case systemCmdQuit: - c.print("Exiting gemini-cli...") - return true - case systemCmdPurgeHistory: - c.chat.model.ClearHistory() - c.print("Cleared the chat history.") - case systemCmdSelectInputMode: - multiline, err := selectInputMode(c.chat.opts.Multiline) - if err != nil { - c.error(err) - break - } - if multiline == c.chat.opts.Multiline { - c.printSelectedCurrent() - break - } - c.chat.opts.Multiline = multiline - if c.chat.opts.Multiline { - c.print("Switched to multi-line input mode.") - // disable history for multi-line messages since it is - // unusable for future requests - c.chat.reader.HistoryDisable() - } else { - c.print("Switched to single-line input mode.") - c.chat.reader.HistoryEnable() - } - case systemCmdSelectModel: - model, err := selectModel(c.chat.opts.Model, c.chat.model.ListModels()) - if err != nil { - c.error(err) - break - } - if model == c.chat.opts.Model { - c.printSelectedCurrent() - break - } - c.chat.opts.Model = model - c.chat.model.SetGenerativeModel(model) - c.print(fmt.Sprintf("Selected '%s' generative model.", model)) - default: - 
c.print("Unknown system command.") - } - return false -} - -func (c *systemCommand) print(message string) { - fmt.Printf("%s%s\n", c.chat.prompt.cli, message) -} - -func (c *systemCommand) printSelectedCurrent() { - fmt.Printf("%sThe selection is unchanged.\n", c.chat.prompt.cli) -} - -func (c *systemCommand) error(err error) { - fmt.Printf(color.Red("%s%s\n"), c.chat.prompt.cli, err) -} - -type geminiCommand struct { - chat *Chat - spinner *spinner - writer *bufio.Writer -} - -var _ command = (*geminiCommand)(nil) - -func newGeminiCommand(chat *Chat) command { - writer := bufio.NewWriter(os.Stdout) - return &geminiCommand{ - chat: chat, - spinner: newSpinner(5, time.Second, writer), - writer: writer, - } -} - -func (c *geminiCommand) run(message string) bool { - c.printFlush(c.chat.prompt.gemini) - c.spinner.start() - if c.chat.opts.Format { - // requires the entire response to be formatted - c.runBlocking(message) - } else { - c.runStreaming(message) - } - return false -} - -func (c *geminiCommand) runBlocking(message string) { - response, err := c.chat.model.SendMessage(message) - c.spinner.stop() - if err != nil { - fmt.Println(color.Red(err.Error())) - } else { - var buf strings.Builder - for _, candidate := range response.Candidates { - for _, part := range candidate.Content.Parts { - fmt.Fprintf(&buf, "%s", part) - } - } - output, err := glamour.Render(buf.String(), c.chat.opts.Style) - if err != nil { - fmt.Printf(color.Red("Failed to format: %s\n"), err) - fmt.Println(buf.String()) - return - } - fmt.Print(output) - } -} - -func (c *geminiCommand) runStreaming(message string) { - responseIterator := c.chat.model.SendMessageStream(message) - c.spinner.stop() - for { - response, err := responseIterator.Next() - if err != nil { - if !errors.Is(err, iterator.Done) { - fmt.Print(color.Red(err.Error())) - } - break - } - for _, candidate := range response.Candidates { - for _, part := range candidate.Content.Parts { - c.printFlush(fmt.Sprintf("%s", part)) - } - 
} - } - fmt.Print("\n") -} - -func (c *geminiCommand) printFlush(message string) { - _, _ = c.writer.WriteString(message) - _ = c.writer.Flush() -} diff --git a/internal/cli/prompt.go b/internal/cli/prompt.go index 5217c69..c5104fb 100644 --- a/internal/cli/prompt.go +++ b/internal/cli/prompt.go @@ -13,11 +13,11 @@ const ( cliUser = "cli" ) -type prompt struct { - user string - userNext string - gemini string - cli string +type Prompt struct { + User string + UserNext string + Gemini string + Cli string } type promptColor struct { @@ -41,14 +41,14 @@ func newPromptColor() *promptColor { } } -func newPrompt(currentUser string) *prompt { +func NewPrompt(currentUser string) *Prompt { maxLength := maxLength(currentUser, geminiUser, cliUser) pc := newPromptColor() - return &prompt{ - user: pc.user(buildPrompt(currentUser, maxLength)), - userNext: pc.user(buildPrompt(strings.Repeat(" ", len(currentUser)), maxLength)), - gemini: pc.gemini(buildPrompt(geminiUser, maxLength)), - cli: pc.cli(buildPrompt(cliUser, maxLength)), + return &Prompt{ + User: pc.user(buildPrompt(currentUser, maxLength)), + UserNext: pc.user(buildPrompt(strings.Repeat(" ", len(currentUser)), maxLength)), + Gemini: pc.gemini(buildPrompt(geminiUser, maxLength)), + Cli: pc.cli(buildPrompt(cliUser, maxLength)), } } diff --git a/internal/cli/select.go b/internal/cli/select.go deleted file mode 100644 index d82576c..0000000 --- a/internal/cli/select.go +++ /dev/null @@ -1,51 +0,0 @@ -package cli - -import ( - "slices" - - "github.com/manifoldco/promptui" -) - -var ( - inputMode = []string{"single-line", "multi-line"} -) - -// selectModel returns the selected generative model name. 
-func selectModel(current string, models []string) (string, error) { - prompt := promptui.Select{ - Label: "Select generative model", - HideSelected: true, - Items: models, - CursorPos: slices.Index(models, current), - } - - _, result, err := prompt.Run() - if err != nil { - return "", err - } - - return result, nil -} - -// selectInputMode returns true if multiline input is selected; -// otherwise, it returns false. -func selectInputMode(multiline bool) (bool, error) { - var cursorPos int - if multiline { - cursorPos = 1 - } - - prompt := promptui.Select{ - Label: "Select input mode", - HideSelected: true, - Items: inputMode, - CursorPos: cursorPos, - } - - _, result, err := prompt.Run() - if err != nil { - return false, err - } - - return result == inputMode[1], nil -} diff --git a/internal/cli/spinner.go b/internal/cli/spinner.go index da72766..fc92562 100644 --- a/internal/cli/spinner.go +++ b/internal/cli/spinner.go @@ -3,6 +3,7 @@ package cli import ( "bufio" "fmt" + "io" "time" ) @@ -12,44 +13,49 @@ const ( progressRune = '.' ) -type spinner struct { - length int - interval time.Duration +// Spinner is a visual indicator of progress displayed in the terminal as a +// scrolling dot animation. +type Spinner struct { writer *bufio.Writer + interval time.Duration signal chan struct{} + + maxLength int + length int } -func newSpinner(length int, interval time.Duration, writer *bufio.Writer) *spinner { - return &spinner{ - length: length, - interval: interval, - writer: writer, - signal: make(chan struct{}), +// NewSpinner returns a new Spinner. 
+func NewSpinner(w io.Writer, interval time.Duration, length int) *Spinner { + return &Spinner{ + writer: bufio.NewWriter(w), + interval: interval, + signal: make(chan struct{}), + maxLength: length, } } //nolint:errcheck -func (s *spinner) start() { +func (s *Spinner) Start() { go func() { ticker := time.NewTicker(s.interval) defer ticker.Stop() - var n int + s.length = 0 for { select { case <-s.signal: - if n > 0 { - s.clear(n) + if s.length > 0 { + s.Clear() } s.signal <- struct{}{} return case <-ticker.C: - if n < s.length { + if s.length < s.maxLength { s.writer.WriteRune(progressRune) s.writer.Flush() - n++ + s.length++ } else { - s.clear(n) - n = 0 + s.Clear() + s.length = 0 } } } @@ -57,13 +63,13 @@ func (s *spinner) start() { } //nolint:errcheck -func (s *spinner) clear(n int) { - s.writer.WriteString(fmt.Sprintf(moveCursorBackward, n)) +func (s *Spinner) Clear() { + s.writer.WriteString(fmt.Sprintf(moveCursorBackward, s.length)) s.writer.WriteString(clearLineFromCursor) s.writer.Flush() } -func (s *spinner) stop() { +func (s *Spinner) Stop() { s.signal <- struct{}{} <-s.signal } diff --git a/internal/config/application_data.go b/internal/config/application_data.go new file mode 100644 index 0000000..821b273 --- /dev/null +++ b/internal/config/application_data.go @@ -0,0 +1,38 @@ +package config + +import ( + "github.com/google/generative-ai-go/genai" + "github.com/reugn/gemini-cli/gemini" +) + +// ApplicationData encapsulates application state and configuration. +// Note that the chat history is stored in plain text format. +type ApplicationData struct { + SystemPrompts map[string]gemini.SystemInstruction + SafetySettings []*genai.SafetySetting + History map[string][]*gemini.SerializableContent +} + +// newDefaultApplicationData returns a new ApplicationData with default values. 
+func newDefaultApplicationData() *ApplicationData { + defaultSafetySettings := []*genai.SafetySetting{ + {Category: genai.HarmCategoryHarassment, Threshold: genai.HarmBlockLowAndAbove}, + {Category: genai.HarmCategoryHateSpeech, Threshold: genai.HarmBlockLowAndAbove}, + {Category: genai.HarmCategorySexuallyExplicit, Threshold: genai.HarmBlockLowAndAbove}, + {Category: genai.HarmCategoryDangerousContent, Threshold: genai.HarmBlockLowAndAbove}, + } + return &ApplicationData{ + SystemPrompts: make(map[string]gemini.SystemInstruction), + SafetySettings: defaultSafetySettings, + History: make(map[string][]*gemini.SerializableContent), + } +} + +// AddHistoryRecord adds a history record to the application data. +func (d *ApplicationData) AddHistoryRecord(label string, content []*genai.Content) { + serializableContent := make([]*gemini.SerializableContent, len(content)) + for i, c := range content { + serializableContent[i] = gemini.NewSerializableContent(c) + } + d.History[label] = serializableContent +} diff --git a/internal/config/configuration.go b/internal/config/configuration.go new file mode 100644 index 0000000..528e198 --- /dev/null +++ b/internal/config/configuration.go @@ -0,0 +1,61 @@ +package config + +import ( + "encoding/json" + "fmt" + "os" +) + +// Configuration contains the details of the application configuration. +type Configuration struct { + // filePath is the path to the configuration file. This file contains the + // application data in JSON format. + filePath string + // Data is the application data. This data is loaded from the configuration + // file and is used to configure the application. + Data *ApplicationData +} + +// NewConfiguration returns a new Configuration from a JSON file. 
+func NewConfiguration(filePath string) (*Configuration, error) { + configuration := &Configuration{ + filePath: filePath, + Data: newDefaultApplicationData(), + } + + file, err := os.Open(filePath) + if err != nil { + if os.IsNotExist(err) { + _ = configuration.Flush() // ignore error if file write failed + return configuration, nil + } + return nil, fmt.Errorf("error opening file: %w", err) + } + defer file.Close() + + decoder := json.NewDecoder(file) + err = decoder.Decode(configuration.Data) + if err != nil { + return nil, fmt.Errorf("error decoding JSON: %w", err) + } + + return configuration, nil +} + +// Flush serializes and writes the configuration to the file. +func (c *Configuration) Flush() error { + file, err := os.Create(c.filePath) + if err != nil { + return fmt.Errorf("error opening file: %w", err) + } + defer file.Close() + + encoder := json.NewEncoder(file) + encoder.SetIndent("", " ") + err = encoder.Encode(c.Data) + if err != nil { + return fmt.Errorf("error encoding JSON: %w", err) + } + + return file.Sync() +} diff --git a/internal/handler/gemini_query.go b/internal/handler/gemini_query.go new file mode 100644 index 0000000..8e37c48 --- /dev/null +++ b/internal/handler/gemini_query.go @@ -0,0 +1,59 @@ +package handler + +import ( + "fmt" + "strings" + + "github.com/charmbracelet/glamour" + "github.com/reugn/gemini-cli/gemini" + "github.com/reugn/gemini-cli/internal/cli" +) + +// GeminiQuery processes queries to gemini models. +// It implements the MessageHandler interface. +type GeminiQuery struct { + session *gemini.ChatSession + spinner *cli.Spinner + renderer *glamour.TermRenderer +} + +var _ MessageHandler = (*GeminiQuery)(nil) + +// NewGeminiQuery returns a new GeminiQuery message handler. 
+func NewGeminiQuery(session *gemini.ChatSession, spinner *cli.Spinner, + style string) (*GeminiQuery, error) { + renderer, err := glamour.NewTermRenderer(glamour.WithStylePath(style)) + if err != nil { + return nil, fmt.Errorf("failed to instantiate terminal renderer: %w", err) + } + + return &GeminiQuery{ + session: session, + spinner: spinner, + renderer: renderer, + }, nil +} + +// Handle processes the chat message. +func (h *GeminiQuery) Handle(message string) (Response, bool) { + h.spinner.Start() + response, err := h.session.SendMessage(message) + h.spinner.Stop() + if err != nil { + return newErrorResponse(err), false + } + + var b strings.Builder + for _, candidate := range response.Candidates { + for _, part := range candidate.Content.Parts { + _, _ = fmt.Fprintf(&b, "%s", part) + } + } + + rendered, err := h.renderer.Render(b.String()) + if err != nil { + return newErrorResponse(fmt.Errorf("failed to format response: %w", err)), false + } + + return dataResponse(rendered), false +} diff --git a/internal/handler/handler.go b/internal/handler/handler.go new file mode 100644 index 0000000..8822071 --- /dev/null +++ b/internal/handler/handler.go @@ -0,0 +1,8 @@ +package handler + +// MessageHandler handles chat messages from the user. +type MessageHandler interface { + // Handle processes the message and returns a response, along with a flag + // indicating whether the application should terminate. 
+ Handle(message string) (Response, bool) +} diff --git a/internal/handler/history_command.go b/internal/handler/history_command.go new file mode 100644 index 0000000..13ad898 --- /dev/null +++ b/internal/handler/history_command.go @@ -0,0 +1,171 @@ +package handler + +import ( + "fmt" + "slices" + "time" + + "github.com/google/generative-ai-go/genai" + "github.com/manifoldco/promptui" + "github.com/reugn/gemini-cli/gemini" + "github.com/reugn/gemini-cli/internal/config" +) + +var historyOptions = []string{ + "Clear chat history", + "Store chat history", + "Load chat history", + "Delete stored history records", +} + +// HistoryCommand processes the chat history system commands. +// It implements the MessageHandler interface. +type HistoryCommand struct { + session *gemini.ChatSession + configuration *config.Configuration +} + +var _ MessageHandler = (*HistoryCommand)(nil) + +// NewHistoryCommand returns a new HistoryCommand. +func NewHistoryCommand(session *gemini.ChatSession, + configuration *config.Configuration) *HistoryCommand { + return &HistoryCommand{ + session: session, + configuration: configuration, + } +} + +// Handle processes the history system command. +func (h *HistoryCommand) Handle(_ string) (Response, bool) { + option, err := h.selectHistoryOption() + if err != nil { + return newErrorResponse(err), false + } + var response Response + switch option { + case historyOptions[0]: + response = h.handleClear() + case historyOptions[1]: + response = h.handleStore() + case historyOptions[2]: + response = h.handleLoad() + case historyOptions[3]: + response = h.handleDelete() + default: + response = newErrorResponse(fmt.Errorf("unsupported option: %s", option)) + } + return response, false +} + +// handleClear handles the chat history clear request. +func (h *HistoryCommand) handleClear() Response { + h.session.ClearHistory() + return dataResponse("Cleared the chat history.") +} + +// handleStore handles the chat history store request. 
+func (h *HistoryCommand) handleStore() Response { + historyLabel, err := h.promptHistoryLabel() + if err != nil { + return newErrorResponse(err) + } + + timeLabel := time.Now().In(time.Local).Format(time.DateTime) + recordLabel := fmt.Sprintf("%s - %s", timeLabel, historyLabel) + h.configuration.Data.AddHistoryRecord( + recordLabel, + h.session.GetHistory(), + ) + + if err := h.configuration.Flush(); err != nil { + return newErrorResponse(err) + } + + return dataResponse(fmt.Sprintf("%q has been saved to the file.", recordLabel)) +} + +// handleLoad handles the chat history load request. +func (h *HistoryCommand) handleLoad() Response { + label, history, err := h.loadHistory() + if err != nil { + return newErrorResponse(err) + } + + h.session.SetHistory(history) + return dataResponse(fmt.Sprintf("%q has been loaded to the chat history.", label)) +} + +// handleDelete handles deletion of the stored history records. +func (h *HistoryCommand) handleDelete() Response { + h.configuration.Data.History = make(map[string][]*gemini.SerializableContent) + if err := h.configuration.Flush(); err != nil { + return newErrorResponse(err) + } + return dataResponse("History records have been removed from the file.") +} + +// loadHistory returns history data to be set. 
+func (h *HistoryCommand) loadHistory() (string, []*genai.Content, error) { + promptNames := make([]string, len(h.configuration.Data.History)+1) + promptNames[0] = empty + i := 1 + for p := range h.configuration.Data.History { + promptNames[i] = p + i++ + } + prompt := promptui.Select{ + Label: "Select conversation history to load", + HideSelected: true, + Items: promptNames, + CursorPos: slices.Index(promptNames, empty), + } + + _, result, err := prompt.Run() + if err != nil { + return result, nil, err + } + + if result == empty { + return result, nil, nil + } + + serializedContent := h.configuration.Data.History[result] + content := make([]*genai.Content, len(serializedContent)) + for i, c := range serializedContent { + content[i] = c.ToContent() + } + + return result, content, nil +} + +// promptHistoryLabel returns a label for the history record. +func (h *HistoryCommand) promptHistoryLabel() (string, error) { + prompt := promptui.Prompt{ + Label: "Enter a label for the history record", + HideEntered: true, + } + + label, err := prompt.Run() + if err != nil { + return "", err + } + + return label, nil +} + +// selectHistoryOption returns the selected history action name. +func (h *HistoryCommand) selectHistoryOption() (string, error) { + prompt := promptui.Select{ + Label: "Select history option", + HideSelected: true, + Items: historyOptions, + } + + _, result, err := prompt.Run() + if err != nil { + return "", err + } + + return result, nil +} diff --git a/internal/handler/input_mode_command.go b/internal/handler/input_mode_command.go new file mode 100644 index 0000000..979b625 --- /dev/null +++ b/internal/handler/input_mode_command.go @@ -0,0 +1,80 @@ +package handler + +import ( + "fmt" + + "github.com/chzyer/readline" + "github.com/manifoldco/promptui" +) + +var inputModeOptions = []string{ + "Single-line", + "Multi-line", +} + +// InputModeCommand processes the chat input mode system command. +// It implements the MessageHandler interface. 
+type InputModeCommand struct { + reader *readline.Instance + multiline *bool +} + +var _ MessageHandler = (*InputModeCommand)(nil) + +// NewInputModeCommand returns a new InputModeCommand. +func NewInputModeCommand(reader *readline.Instance, multiline *bool) *InputModeCommand { + return &InputModeCommand{ + reader: reader, + multiline: multiline, + } +} + +// Handle processes the chat input mode system command. +func (h *InputModeCommand) Handle(_ string) (Response, bool) { + multiline, err := h.selectInputMode() + if err != nil { + return newErrorResponse(err), false + } + + if *h.multiline == multiline { + // the same input mode is selected + return dataResponse(unchangedMessage), false + } + + *h.multiline = multiline + if *h.multiline { + // disable history for multi-line messages since it is + // unusable for future requests + h.reader.HistoryDisable() + } else { + h.reader.HistoryEnable() + } + + mode := inputModeOptions[modeIndex(*h.multiline)] + return dataResponse(fmt.Sprintf("Switched to %q input mode.", mode)), false +} + +// selectInputMode returns true if multiline input is selected; +// otherwise, it returns false. 
+func (h *InputModeCommand) selectInputMode() (bool, error) { + prompt := promptui.Select{ + Label: "Select input mode", + HideSelected: true, + Items: inputModeOptions, + CursorPos: modeIndex(*h.multiline), + } + + _, result, err := prompt.Run() + if err != nil { + return false, err + } + + return result == inputModeOptions[1], nil +} + +func modeIndex(b bool) int { + if b { + return 1 + } + return 0 +} diff --git a/internal/handler/model_command.go b/internal/handler/model_command.go new file mode 100644 index 0000000..0cd918a --- /dev/null +++ b/internal/handler/model_command.go @@ -0,0 +1,110 @@ +package handler + +import ( + "fmt" + "slices" + + "github.com/manifoldco/promptui" + "github.com/reugn/gemini-cli/gemini" +) + +var modelOptions = []string{ + "Select generative model", + "Chat model info", +} + +// ModelCommand processes the chat model system commands. +// It implements the MessageHandler interface. +type ModelCommand struct { + session *gemini.ChatSession + currentModel string +} + +var _ MessageHandler = (*ModelCommand)(nil) + +// NewModelCommand returns a new ModelCommand. +func NewModelCommand(session *gemini.ChatSession, modelName string) *ModelCommand { + return &ModelCommand{ + session: session, + currentModel: modelName, + } +} + +// Handle processes the chat model system command. +func (h *ModelCommand) Handle(_ string) (Response, bool) { + option, err := h.selectModelOption() + if err != nil { + return newErrorResponse(err), false + } + + var response Response + switch option { + case modelOptions[0]: + response = h.handleSelectModel() + case modelOptions[1]: + response = h.handleModelInfo() + default: + response = newErrorResponse(fmt.Errorf("unsupported option: %s", option)) + } + return response, false +} + +// handleSelectModel handles the generative model selection. 
+func (h *ModelCommand) handleSelectModel() Response {
+ model, err := h.selectModel(h.session.ListModels())
+ if err != nil {
+ return newErrorResponse(err)
+ }
+
+ if h.currentModel == model {
+ return dataResponse(unchangedMessage)
+ }
+
+ modelBuilder := h.session.CopyModelBuilder().WithName(model)
+ h.session.SetModel(modelBuilder)
+ h.currentModel = model
+
+ return dataResponse(fmt.Sprintf("Selected %q generative model.", model))
+}
+
+// handleModelInfo handles the current generative model info request.
+func (h *ModelCommand) handleModelInfo() Response {
+ modelInfo, err := h.session.ModelInfo()
+ if err != nil {
+ return newErrorResponse(err)
+ }
+ return dataResponse(modelInfo)
+}
+
+// selectModelOption returns the selected action name.
+func (h *ModelCommand) selectModelOption() (string, error) {
+ prompt := promptui.Select{
+ Label: "Select model option",
+ HideSelected: true,
+ Items: modelOptions,
+ }
+
+ _, result, err := prompt.Run()
+ if err != nil {
+ return "", err
+ }
+
+ return result, nil
+}
+
+// selectModel returns the selected generative model name.
+func (h *ModelCommand) selectModel(models []string) (string, error) {
+ prompt := promptui.Select{
+ Label: "Select generative model",
+ HideSelected: true,
+ Items: models,
+ CursorPos: slices.Index(models, h.currentModel),
+ }
+
+ _, result, err := prompt.Run()
+ if err != nil {
+ return "", err
+ }
+
+ return result, nil
+}
diff --git a/internal/handler/prompt_command.go b/internal/handler/prompt_command.go
new file mode 100644
index 0000000..15b7ea0
--- /dev/null
+++ b/internal/handler/prompt_command.go
@@ -0,0 +1,75 @@
+package handler
+
+import (
+ "fmt"
+ "slices"
+
+ "github.com/google/generative-ai-go/genai"
+ "github.com/manifoldco/promptui"
+ "github.com/reugn/gemini-cli/gemini"
+ "github.com/reugn/gemini-cli/internal/config"
+)
+
+// SystemPromptCommand processes the chat prompt system command.
+// It implements the MessageHandler interface. 
+type SystemPromptCommand struct { + session *gemini.ChatSession + applicationData *config.ApplicationData + + currentPrompt string +} + +var _ MessageHandler = (*SystemPromptCommand)(nil) + +// NewSystemPromptCommand returns a new SystemPromptCommand. +func NewSystemPromptCommand(session *gemini.ChatSession, + applicationData *config.ApplicationData) *SystemPromptCommand { + return &SystemPromptCommand{ + session: session, + applicationData: applicationData, + } +} + +// Handle processes the chat prompt system command. +func (h *SystemPromptCommand) Handle(_ string) (Response, bool) { + label, systemPrompt, err := h.selectSystemPrompt() + if err != nil { + return newErrorResponse(err), false + } + + modelBuilder := h.session.CopyModelBuilder(). + WithSystemInstruction(systemPrompt) + h.session.SetModel(modelBuilder) + + return dataResponse(fmt.Sprintf("Selected %q system instruction.", label)), false +} + +// selectSystemPrompt returns a system instruction to be set. +func (h *SystemPromptCommand) selectSystemPrompt() (string, *genai.Content, error) { + promptNames := make([]string, len(h.applicationData.SystemPrompts)+1) + promptNames[0] = empty + i := 1 + for p := range h.applicationData.SystemPrompts { + promptNames[i] = p + i++ + } + prompt := promptui.Select{ + Label: "Select system instruction", + HideSelected: true, + Items: promptNames, + CursorPos: slices.Index(promptNames, h.currentPrompt), + } + + _, result, err := prompt.Run() + if err != nil { + return result, nil, err + } + + h.currentPrompt = result + if result == empty { + return result, nil, nil + } + + systemInstruction := h.applicationData.SystemPrompts[result] + return result, systemInstruction.ToContent(), nil +} diff --git a/internal/handler/quit_command.go b/internal/handler/quit_command.go new file mode 100644 index 0000000..af2f408 --- /dev/null +++ b/internal/handler/quit_command.go @@ -0,0 +1,18 @@ +package handler + +// QuitCommand processes the chat quit system command. 
+// It implements the MessageHandler interface. +type QuitCommand struct { +} + +var _ MessageHandler = (*QuitCommand)(nil) + +// NewQuitCommand returns a new QuitCommand. +func NewQuitCommand() *QuitCommand { + return &QuitCommand{} +} + +// Handle processes the chat quit command. +func (h *QuitCommand) Handle(_ string) (Response, bool) { + return dataResponse("Exiting gemini-cli..."), true +} diff --git a/internal/handler/response.go b/internal/handler/response.go new file mode 100644 index 0000000..f748b2d --- /dev/null +++ b/internal/handler/response.go @@ -0,0 +1,46 @@ +package handler + +import ( + "fmt" + "io" + + "github.com/reugn/gemini-cli/internal/cli/color" +) + +const ( + empty = "Empty" + unchangedMessage = "The selection is unchanged." +) + +// Response represents a response from a chat message handler. +type Response interface { + Print(w io.Writer, prompt string) error +} + +type dataResponse string + +var _ Response = (*dataResponse)(nil) + +func (r dataResponse) Print(w io.Writer, prompt string) error { + _, err := fmt.Fprintf(w, "%s%s\n", prompt, r) + return err +} + +type errorResponse struct { + error +} + +func newErrorResponse(err error) errorResponse { + return errorResponse{error: err} +} + +var _ Response = (*errorResponse)(nil) + +func (r errorResponse) Print(w io.Writer, prompt string) error { + _, err := fmt.Fprintf(w, "%s%s\n", prompt, color.Red(r.Error())) + return err +} + +func PrintError(w io.Writer, prompt string, err error) { + _ = newErrorResponse(err).Print(w, prompt) +} diff --git a/internal/handler/system_command.go b/internal/handler/system_command.go new file mode 100644 index 0000000..f05c0cd --- /dev/null +++ b/internal/handler/system_command.go @@ -0,0 +1,63 @@ +package handler + +import ( + "fmt" + "strings" + + "github.com/chzyer/readline" + "github.com/reugn/gemini-cli/gemini" + "github.com/reugn/gemini-cli/internal/config" +) + +const ( + SystemCmdPrefix = "!" 
+ SystemCmdQuit = "q"
+ systemCmdSelectPrompt = "p"
+ systemCmdSelectInputMode = "i"
+ systemCmdModel = "m"
+ systemCmdHistory = "h"
+)
+
+// SystemCommand processes chat system commands; implements the MessageHandler interface.
+// It aggregates the processing by delegating it to one of the underlying handlers.
+type SystemCommand struct {
+ handlers map[string]MessageHandler
+}
+
+var _ MessageHandler = (*SystemCommand)(nil)
+
+// NewSystemCommand returns a new SystemCommand.
+func NewSystemCommand(session *gemini.ChatSession, configuration *config.Configuration,
+ reader *readline.Instance, multiline *bool, modelName string) *SystemCommand {
+ handlers := map[string]MessageHandler{
+ SystemCmdQuit: NewQuitCommand(),
+ systemCmdSelectPrompt: NewSystemPromptCommand(session, configuration.Data),
+ systemCmdSelectInputMode: NewInputModeCommand(reader, multiline),
+ systemCmdModel: NewModelCommand(session, modelName),
+ systemCmdHistory: NewHistoryCommand(session, configuration),
+ }
+
+ return &SystemCommand{
+ handlers: handlers,
+ }
+}
+
+// Handle processes the chat system command.
+func (s *SystemCommand) Handle(message string) (Response, bool) {
+ if !strings.HasPrefix(message, SystemCmdPrefix) {
+ return newErrorResponse(fmt.Errorf("system command mismatch")), false
+ }
+
+ var args string
+ t := strings.SplitN(message, " ", 2)
+ if len(t) == 2 {
+ args = t[1]
+ }
+
+ systemHandler, ok := s.handlers[t[0][1:]]
+ if !ok {
+ return newErrorResponse(fmt.Errorf("unknown system command")), false
+ }
+
+ return systemHandler.Handle(args)
+}