From 88a9091f494255a420827bfa331eeaca475ba387 Mon Sep 17 00:00:00 2001 From: Daulet Zhanguzin Date: Sun, 8 Sep 2024 01:33:46 -0700 Subject: [PATCH] feat: better flags --- cmd/main.go | 146 +++++++++++++++++++++++++++++----------------------- go.mod | 3 +- go.sum | 6 ++- 3 files changed, 89 insertions(+), 66 deletions(-) diff --git a/cmd/main.go b/cmd/main.go index b807ea6..e347584 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -5,7 +5,6 @@ import ( "bytes" "context" "encoding/json" - "flag" "fmt" "io" "log" @@ -18,6 +17,7 @@ import ( "github.com/daulet/cmd/provider" "github.com/fatih/color" + "github.com/jessevdk/go-flags" ) const ( @@ -27,25 +27,28 @@ const ( var ( prov provider.Provider cfg *config.Config - - // config flags - showConfig = flag.Bool("config", false, "Show current config.") - listModels = flag.Bool("list-models", false, "List available models.") - setModel = flag.String("model", "", "Set model to use.") - listConnectors = flag.Bool("list-connectors", false, "List available connectors.") - setConnectors = flag.String("connectors", "", "Set comma delimited list of connectors to use.") - setTemp = flag.Float64("temperature", 0.0, "Set temperature value.") - setTopP = flag.Float64("top-p", 0.0, "Set top-p value.") - setTopK = flag.Int("top-k", 0, "Set top-k value.") - setFreqPen = flag.Float64("frequency-penalty", 0.0, "Set frequency penalty value.") - setPresPen = flag.Float64("presence-penalty", 0.0, "Set presence penalty value.") - - // TODO make -i option - chat = flag.Bool("chat", false, "Start chat session with LLM, other flags apply.") - execute = flag.Bool("exec", false, "Execute generated command/code, do not show LLM output.") - run = flag.Bool("run", false, "Stream LLM output and run generated command/code at the end.") ) +type flagValues struct { + Interactive bool `short:"i" long:"interactive" description:"Start chat session with LLM, other flags apply."` + Execute bool `short:"e" long:"execute" description:"Execute generated command/code, do not show LLM output."` + Run bool `short:"r" long:"run" description:"Stream LLM output and run generated command/code at the end."` + + ShowConfig bool `short:"c" long:"config" description:"Show current config."` + + ListModels bool `long:"list-models" description:"List available models."` + SetModel *string `long:"model" description:"Set model to use."` + + ListConnectors bool `long:"list-connectors" description:"List available connectors."` + SetConnectors []string `long:"connector" description:"Set connectors to use."` + + SetTemperature *float64 `short:"t" long:"temperature" description:"Set temperature value."` + SetTopP *float64 `short:"p" long:"top-p" description:"Set top-p value."` + SetTopK *int `short:"k" long:"top-k" description:"Set top-k value."` + SetFrequencyPenalty *float64 `long:"freq" description:"Set frequency penalty value."` + SetPresencePenalty *float64 `long:"pres" description:"Set presence penalty value."` +} + func multiTurn( ctx context.Context, out io.WriteCloser, @@ -162,14 +165,14 @@ func runCmd(prog string, args ...string) error { return cmd.Run() } -func parseConfig(ctx context.Context) (bool, error) { +func parseConfig(ctx context.Context, flagDefs []*flags.Option, flagVals *flagValues) (bool, error) { var err error cfg, err = config.ReadConfig() if err != nil { return false, err } - if *showConfig { + if flagVals.ShowConfig { data, err := json.MarshalIndent(cfg, "", " ") if err != nil { return false, err @@ -179,7 +182,7 @@ func parseConfig(ctx context.Context) (bool, error) { return true, nil } - if *listModels { + if flagVals.ListModels { modelNames, err := prov.ListModels(ctx) if err != nil { return false, err @@ -195,7 +198,7 @@ func parseConfig(ctx context.Context) (bool, error) { return true, nil } - if *listConnectors { + if flagVals.ListConnectors { connectorIDs, err := prov.ListConnectors(ctx) if err != nil { return false, err @@ -210,37 +213,44 @@ func parseConfig(ctx context.Context) (bool, error) { } dirtyCfg := false - flag.Visit(func(f *flag.Flag) { - switch f.Name { - // TODO changing provider should reset model selection - case "model": - // TODO there is no way to unset model - modelType := config.ModelType(*setModel) - cfg.Model[modelType] = *setModel - dirtyCfg = true - case "connectors": - cfg.Connectors = strings.Split(*setConnectors, ",") - if *setConnectors == "" { - cfg.Connectors = nil - } - dirtyCfg = true - case "temperature": - cfg.Temperature = setTemp - dirtyCfg = true - case "top-p": - cfg.TopP = setTopP - dirtyCfg = true - case "top-k": - cfg.TopK = setTopK - dirtyCfg = true - case "frequency-penalty": - cfg.FrequencyPenalty = setFreqPen - dirtyCfg = true - case "presence-penalty": - cfg.PresencePenalty = setPresPen - dirtyCfg = true - } - }) + // TODO changing provider should reset model selection + if flagVals.SetModel != nil { + dirtyCfg = true + model := *flagVals.SetModel + // TODO there is no way to unset model + cfg.Model[config.ModelType(model)] = model + } + + if flagVals.SetConnectors != nil { + dirtyCfg = true + cfg.Connectors = flagVals.SetConnectors + } + + if flagVals.SetTemperature != nil { + dirtyCfg = true + cfg.Temperature = flagVals.SetTemperature + } + + if flagVals.SetTopP != nil { + dirtyCfg = true + cfg.TopP = flagVals.SetTopP + } + + if flagVals.SetTopK != nil { + dirtyCfg = true + cfg.TopK = flagVals.SetTopK + } + + if flagVals.SetFrequencyPenalty != nil { + dirtyCfg = true + cfg.FrequencyPenalty = flagVals.SetFrequencyPenalty + } + + if flagVals.SetPresencePenalty != nil { + dirtyCfg = true + cfg.PresencePenalty = flagVals.SetPresencePenalty + } + if dirtyCfg { if err := config.WriteConfig(cfg); err != nil { return false, err @@ -256,13 +266,13 @@ func parseConfig(ctx context.Context) (bool, error) { return false, nil } -func cmd(ctx context.Context) error { +func cmd(ctx context.Context, usrMsg string, flagVals *flagValues) error { turnFn := func(ctx context.Context, out io.WriteCloser, msgs []*provider.Message) (string, error) { var blocks []*parser.CodeBlock done := make(chan struct{}) switch { - case *execute: + case flagVals.Execute: codeW, blockCh := parser.NewCode() go func() { defer close(done) @@ -272,7 +282,7 @@ func cmd(ctx context.Context) error { }() // no output to the user, we just execute the code out = codeW - case *run: + case flagVals.Run: codeW, blockCh := parser.NewCode() go func() { defer close(done) @@ -318,16 +328,15 @@ func cmd(ctx context.Context) error { } pipeContent = string(pipeBytes) } - if *chat { + if flagVals.Interactive { in, err = os.Open("/dev/tty") if err != nil { return fmt.Errorf("failed to open /dev/tty: %w", err) } } - if flag.NArg() == 0 && pipeContent == "" { + if usrMsg == "" && pipeContent == "" { return fmt.Errorf("what's your command?") } - usrMsg := strings.Join(flag.Args(), " ") if pipeContent != "" { usrMsg = fmt.Sprintf(CONTEXT_TEMPLATE, pipeContent, usrMsg) } @@ -351,7 +360,7 @@ func cmd(ctx context.Context) error { } switch { - case *chat: + case flagVals.Interactive: err = multiTurn(ctx, os.Stdout, in, usrMsg, turnFn) default: _, err = turnFn(ctx, os.Stdout, []*provider.Message{ @@ -365,7 +374,18 @@ func cmd(ctx context.Context) error { } func main() { - flag.Parse() + flagVals := &flagValues{} + parser := flags.NewParser(nil, flags.Default) + g, err := parser.AddGroup("Application Options", "", flagVals) + if err != nil { + panic(err) + } + unparsed, err := parser.Parse() + if err != nil { + panic(err) + } + usrMsg := strings.Join(unparsed, " ") + flagDefs := g.Options() // read config first so we can use the right provider cfg, err := config.ReadConfig() @@ -387,7 +407,7 @@ func main() { } ctx := context.Background() - done, err := parseConfig(ctx) + done, err := parseConfig(ctx, flagDefs, flagVals) if err != nil { panic(err) } @@ -404,7 +424,7 @@ func main() { prov = closer } - err = cmd(ctx) + err = cmd(ctx, usrMsg, flagVals) if exitErr, ok := err.(*exec.ExitError); ok { os.Exit(exitErr.ExitCode()) } diff --git a/go.mod b/go.mod index 03e4087..d1bac52 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.23.0 require ( github.com/cohere-ai/cohere-go/v2 v2.7.0 github.com/fatih/color v1.17.0 + github.com/jessevdk/go-flags v1.6.1 github.com/sashabaranov/go-openai v1.29.0 ) @@ -12,5 +13,5 @@ require ( github.com/google/uuid v1.4.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - golang.org/x/sys v0.18.0 // indirect + golang.org/x/sys v0.21.0 // indirect ) diff --git a/go.sum b/go.sum index 4ecf21b..f0ff657 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/fatih/color v1.17.0 h1:GlRw1BRJxkpqUCBKzKOw098ed57fEsKeNjpTe3cSjK4= github.com/fatih/color v1.17.0/go.mod h1:YZ7TlrGPkiz6ku9fK3TLD/pl3CpsiFyu8N92HLgmosI= github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jessevdk/go-flags v1.6.1 h1:Cvu5U8UGrLay1rZfv/zP7iLpSHGUZ/Ou68T0iX1bBK4= +github.com/jessevdk/go-flags v1.6.1/go.mod h1:Mk8T1hIAWpOiJiHa9rJASDK2UGWji0EuPGBnNLMooyc= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= @@ -19,7 +21,7 @@ github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5Cc github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= +golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=