Skip to content

Commit

Permalink
feat: support model preference by type: text vs STT
Browse files Browse the repository at this point in the history
  • Loading branch information
daulet committed Sep 4, 2024
1 parent 63de608 commit 36605f9
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 14 deletions.
12 changes: 7 additions & 5 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,8 @@ func parseConfig(ctx context.Context) (bool, error) {
fmt.Println(model)
}
fmt.Println()
if cfg.Model != nil {
fmt.Printf("Currently selected model: %s\n", *cfg.Model)
for modelType, model := range cfg.Model {
fmt.Printf("Currently selected model for %s: %s\n", modelType, model)
}
return true, nil
}
Expand All @@ -214,8 +214,11 @@ func parseConfig(ctx context.Context) (bool, error) {
dirtyCfg := false
flag.Visit(func(f *flag.Flag) {
switch f.Name {
// TODO changing provider should reset model selection
case "model":
cfg.Model = setModel
// TODO there is no way to unset model
modelType := config.ModelType(*setModel)
cfg.Model[modelType] = *setModel
dirtyCfg = true
case "connectors":
cfg.Connectors = strings.Split(*setConnectors, ",")
Expand Down Expand Up @@ -398,12 +401,11 @@ func main() {
}

err = cmd(ctx)
color.Set(color.FgYellow)
if exitErr, ok := err.(*exec.ExitError); ok {
os.Exit(exitErr.ExitCode())
}
if err != nil {
fmt.Fprintf(os.Stderr, "error: %v\n", err)
color.Yellow("error: %v\n", err)
os.Exit(1)
}
}
12 changes: 8 additions & 4 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@ const configPath = ".cmd/config.json"
const (
	// Known LLM provider identifiers, stored in Config.Provider.
	ProviderGroq   = "groq"
	ProviderCohere = "cohere"

	// Model-type keys used in Config.Model so a distinct model can be
	// selected per task (chat completion vs speech-to-text).
	ModelTypeChat         = "chat"
	ModelTypeSpeechToText = "stt"
)

type Config struct {
Provider string `json:"provider,omitempty"`

Record bool `json:"record,omitempty"`
Model *string `json:"model,omitempty"`
Connectors []string `json:"connectors,omitempty"`
Record bool `json:"record,omitempty"`
Model map[string]string `json:"model,omitempty"`
Connectors []string `json:"connectors,omitempty"`

// Sampling parameters
Temperature *float64 `json:"temperature,omitempty"`
Expand All @@ -39,6 +42,7 @@ func ReadConfig() (*Config, error) {
Provider: ProviderGroq,
// record by default
Record: true,
Model: make(map[string]string),
}
data, err := os.ReadFile(path)
if err != nil {
Expand Down Expand Up @@ -74,6 +78,6 @@ func ConfigPath() (string, error) {
return filepath.Join(homeDir, configPath), nil
}

func ref(v string) *string {
// Ref returns a pointer to a copy of v. It is a convenience helper for
// populating optional fields declared as *string.
func Ref(v string) *string {
	s := v
	return &s
}
18 changes: 18 additions & 0 deletions config/known.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package config

import "strings"

// ModelType classifies a model name by task: ModelTypeChat for known
// chat-completion model families, ModelTypeSpeechToText for known
// speech-to-text families. It panics when the model name matches no
// known family.
func ModelType(model string) string {
	// Checked in order; a chat-family substring wins over "whisper",
	// matching the original case order.
	chatFamilies := []string{"command", "gemma", "llama"}
	for _, family := range chatFamilies {
		if strings.Contains(model, family) {
			return ModelTypeChat
		}
	}
	if strings.Contains(model, "whisper") {
		return ModelTypeSpeechToText
	}
	panic("unknown model: " + model)
}
6 changes: 5 additions & 1 deletion provider/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,15 @@ func (p *cohereProvider) Stream(ctx context.Context, cfg *config.Config, msgs []
log.Fatalf("unknown role: %s", msg.Role)
}
}
var model *string
if cfg.Model[config.ModelTypeChat] != "" {
model = config.Ref(cfg.Model[config.ModelTypeChat])
}
req := &co.ChatStreamRequest{
ChatHistory: messages[:len(messages)-1],
Message: messages[len(messages)-1].Message,

Model: cfg.Model,
Model: model,
Temperature: cfg.Temperature,
P: cfg.TopP,
K: cfg.TopK,
Expand Down
8 changes: 4 additions & 4 deletions provider/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ func (p *openAIProvider) Stream(ctx context.Context, cfg *config.Config, msgs []
}

model := DEFAULT_CHAT_MODEL
if cfg.Model != nil {
model = *cfg.Model
if cfg.Model[config.ModelTypeChat] != "" {
model = cfg.Model[config.ModelTypeChat]
}
stream, err := p.client.CreateChatCompletionStream(ctx, openai.ChatCompletionRequest{
Model: model,
Expand All @@ -64,8 +64,8 @@ func (p *openAIProvider) Stream(ctx context.Context, cfg *config.Config, msgs []

func (p *openAIProvider) Transcribe(ctx context.Context, cfg *config.Config, audio *AudioFile) ([]*AudioSegment, error) {
model := DEFAULT_AUDIO_MODEL
if cfg.Model != nil {
model = *cfg.Model
if cfg.Model[config.ModelTypeSpeechToText] != "" {
model = cfg.Model[config.ModelTypeSpeechToText]
}
res, err := p.client.CreateTranscription(ctx, openai.AudioRequest{
Model: model,
Expand Down

0 comments on commit 36605f9

Please sign in to comment.