From 36605f90116459ca326c083c6998c4442f4c9711 Mon Sep 17 00:00:00 2001 From: Daulet Zhanguzin Date: Tue, 3 Sep 2024 18:22:37 -0700 Subject: [PATCH] feat: support model preference by type: text vs StT --- cmd/main.go | 12 +++++++----- config/config.go | 12 ++++++++---- config/known.go | 18 ++++++++++++++++++ provider/cohere.go | 6 +++++- provider/openai.go | 8 ++++---- 5 files changed, 42 insertions(+), 14 deletions(-) create mode 100644 config/known.go diff --git a/cmd/main.go b/cmd/main.go index 7837d16..d111f51 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -191,8 +191,8 @@ func parseConfig(ctx context.Context) (bool, error) { fmt.Println(model) } fmt.Println() - if cfg.Model != nil { - fmt.Printf("Currently selected model: %s\n", *cfg.Model) + for modelType, model := range cfg.Model { + fmt.Printf("Currently selected model for %s: %s\n", modelType, model) } return true, nil } @@ -214,8 +214,11 @@ func parseConfig(ctx context.Context) (bool, error) { dirtyCfg := false flag.Visit(func(f *flag.Flag) { switch f.Name { + // TODO changing provider should reset model selection case "model": - cfg.Model = setModel + // TODO there is no way to unset model + modelType := config.ModelType(*setModel) + cfg.Model[modelType] = *setModel dirtyCfg = true case "connectors": cfg.Connectors = strings.Split(*setConnectors, ",") @@ -398,12 +401,11 @@ func main() { } err = cmd(ctx) - color.Set(color.FgYellow) if exitErr, ok := err.(*exec.ExitError); ok { os.Exit(exitErr.ExitCode()) } if err != nil { - fmt.Fprintf(os.Stderr, "error: %v\n", err) + color.Yellow("error: %v\n", err) os.Exit(1) } } diff --git a/config/config.go b/config/config.go index d39e887..21c1cef 100644 --- a/config/config.go +++ b/config/config.go @@ -12,14 +12,17 @@ const configPath = ".cmd/config.json" const ( ProviderGroq = "groq" ProviderCohere = "cohere" + + ModelTypeChat = "chat" + ModelTypeSpeechToText = "stt" ) type Config struct { Provider string `json:"provider,omitempty"` - Record bool `json:"record,omitempty"` - Model *string `json:"model,omitempty"` - Connectors []string `json:"connectors,omitempty"` + Record bool `json:"record,omitempty"` + Model map[string]string `json:"model,omitempty"` + Connectors []string `json:"connectors,omitempty"` // Sampling parameters Temperature *float64 `json:"temperature,omitempty"` @@ -39,6 +42,7 @@ func ReadConfig() (*Config, error) { Provider: ProviderGroq, // record by default Record: true, + Model: make(map[string]string), } data, err := os.ReadFile(path) if err != nil { @@ -74,6 +78,6 @@ func ConfigPath() (string, error) { return filepath.Join(homeDir, configPath), nil } -func ref(v string) *string { +func Ref(v string) *string { return &v } diff --git a/config/known.go b/config/known.go new file mode 100644 index 0000000..3923f35 --- /dev/null +++ b/config/known.go @@ -0,0 +1,18 @@ +package config + +import "strings" + +func ModelType(model string) string { + switch { + case strings.Contains(model, "command"): + return ModelTypeChat + case strings.Contains(model, "gemma"): + return ModelTypeChat + case strings.Contains(model, "llama"): + return ModelTypeChat + case strings.Contains(model, "whisper"): + return ModelTypeSpeechToText + default: + panic("unknown model: " + model) + } +} diff --git a/provider/cohere.go b/provider/cohere.go index d997c30..704ec1b 100644 --- a/provider/cohere.go +++ b/provider/cohere.go @@ -41,11 +41,15 @@ func (p *cohereProvider) Stream(ctx context.Context, cfg *config.Config, msgs [] log.Fatalf("unknown role: %s", msg.Role) } } + var model *string + if cfg.Model[config.ModelTypeChat] != "" { + model = config.Ref(cfg.Model[config.ModelTypeChat]) + } req := &co.ChatStreamRequest{ ChatHistory: messages[:len(messages)-1], Message: messages[len(messages)-1].Message, - Model: cfg.Model, + Model: model, Temperature: cfg.Temperature, P: cfg.TopP, K: cfg.TopK, diff --git a/provider/openai.go b/provider/openai.go index d0e752c..ce2b092 100644 --- a/provider/openai.go +++ b/provider/openai.go @@ -49,8 +49,8 @@ func (p *openAIProvider) Stream(ctx context.Context, cfg *config.Config, msgs [] } model := DEFAULT_CHAT_MODEL - if cfg.Model != nil { - model = *cfg.Model + if cfg.Model[config.ModelTypeChat] != "" { + model = cfg.Model[config.ModelTypeChat] } stream, err := p.client.CreateChatCompletionStream(ctx, openai.ChatCompletionRequest{ Model: model, @@ -64,8 +64,8 @@ func (p *openAIProvider) Stream(ctx context.Context, cfg *config.Config, msgs [] func (p *openAIProvider) Transcribe(ctx context.Context, cfg *config.Config, audio *AudioFile) ([]*AudioSegment, error) { model := DEFAULT_AUDIO_MODEL - if cfg.Model != nil { - model = *cfg.Model + if cfg.Model[config.ModelTypeSpeechToText] != "" { + model = cfg.Model[config.ModelTypeSpeechToText] } res, err := p.client.CreateTranscription(ctx, openai.AudioRequest{ Model: model,