Skip to content

Commit

Permalink
fix: no panics
Browse files Browse the repository at this point in the history
  • Loading branch information
daulet committed Sep 8, 2024
1 parent 88a9091 commit 03eb9ec
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 18 deletions.
32 changes: 21 additions & 11 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,12 @@ func parseConfig(ctx context.Context, flagDefs []*flags.Option, flagVals *flagVa
if flagVals.SetModel != nil {
dirtyCfg = true
model := *flagVals.SetModel
modelType, err := config.ModelType(model)
if err != nil {
return false, err
}
// TODO there is no way to unset model
cfg.Model[config.ModelType(model)] = model
cfg.Model[modelType] = model
}

if flagVals.SetConnectors != nil {
Expand Down Expand Up @@ -373,24 +377,24 @@ func cmd(ctx context.Context, usrMsg string, flagVals *flagValues) error {
return err
}

func main() {
func run() error {
flagVals := &flagValues{}
parser := flags.NewParser(nil, flags.Default)
g, err := parser.AddGroup("Application Options", "", flagVals)
if err != nil {
panic(err)
return err
}
unparsed, err := parser.Parse()
if err != nil {
panic(err)
return err
}
usrMsg := strings.Join(unparsed, " ")
flagDefs := g.Options()

// read config first so we can use the right provider
cfg, err := config.ReadConfig()
if err != nil {
panic(err)
return err
}

switch cfg.Provider {
Expand All @@ -399,20 +403,19 @@ func main() {
case config.ProviderCohere:
prov, err = provider.NewCohereProvider()
default:
log.Fatalf("unknown provider: %s", cfg.Provider)
return fmt.Errorf("unknown provider in config: %s", cfg.Provider)
}
if err != nil {
color.Yellow("error: %v\n", err)
os.Exit(1)
return err
}

ctx := context.Background()
done, err := parseConfig(ctx, flagDefs, flagVals)
if err != nil {
panic(err)
return err
}
if done {
return
return nil
}

if cfg.Record {
Expand All @@ -424,7 +427,14 @@ func main() {
prov = closer
}

err = cmd(ctx, usrMsg, flagVals)
return cmd(ctx, usrMsg, flagVals)
}

func main() {
err := run()
if flagsErr, ok := err.(*flags.Error); ok && flagsErr.Type == flags.ErrHelp {
os.Exit(1)
}
if exitErr, ok := err.(*exec.ExitError); ok {
os.Exit(exitErr.ExitCode())
}
Expand Down
17 changes: 10 additions & 7 deletions config/known.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
package config

import "strings"
import (
"fmt"
"strings"
)

func ModelType(model string) string {
func ModelType(model string) (string, error) {
switch {
case strings.Contains(model, "command"):
return ModelTypeChat
return ModelTypeChat, nil
case strings.Contains(model, "gemma"):
return ModelTypeChat
return ModelTypeChat, nil
case strings.Contains(model, "llama"):
return ModelTypeChat
return ModelTypeChat, nil
case strings.Contains(model, "whisper"):
return ModelTypeSpeechToText
return ModelTypeSpeechToText, nil
default:
panic("unknown model: " + model)
return "", fmt.Errorf("unknown model: %s", model)
}
}

0 comments on commit 03eb9ec

Please sign in to comment.