Skip to content

Commit

Permalink
feat: better flags
Browse files Browse the repository at this point in the history
  • Loading branch information
daulet committed Sep 8, 2024
1 parent faf98a4 commit 88a9091
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 66 deletions.
146 changes: 83 additions & 63 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"bytes"
"context"
"encoding/json"
"flag"
"fmt"
"io"
"log"
Expand All @@ -18,6 +17,7 @@ import (
"github.com/daulet/cmd/provider"

"github.com/fatih/color"
"github.com/jessevdk/go-flags"
)

const (
Expand All @@ -27,25 +27,28 @@ const (
var (
prov provider.Provider
cfg *config.Config

// config flags
showConfig = flag.Bool("config", false, "Show current config.")
listModels = flag.Bool("list-models", false, "List available models.")
setModel = flag.String("model", "", "Set model to use.")
listConnectors = flag.Bool("list-connectors", false, "List available connectors.")
setConnectors = flag.String("connectors", "", "Set comma delimited list of connectors to use.")
setTemp = flag.Float64("temperature", 0.0, "Set temperature value.")
setTopP = flag.Float64("top-p", 0.0, "Set top-p value.")
setTopK = flag.Int("top-k", 0, "Set top-k value.")
setFreqPen = flag.Float64("frequency-penalty", 0.0, "Set frequency penalty value.")
setPresPen = flag.Float64("presence-penalty", 0.0, "Set presence penalty value.")

// TODO make -i option
chat = flag.Bool("chat", false, "Start chat session with LLM, other flags apply.")
execute = flag.Bool("exec", false, "Execute generated command/code, do not show LLM output.")
run = flag.Bool("run", false, "Stream LLM output and run generated command/code at the end.")
)

type flagValues struct {
Interactive bool `short:"i" long:"interactive" description:"Start chat session with LLM, other flags apply."`
Execute bool `short:"e" long:"execute" description:"Execute generated command/code, do not show LLM output."`
Run bool `short:"r" long:"run" description:"Stream LLM output and run generated command/code at the end."`

ShowConfig bool `short:"c" long:"config" description:"Show current config."`

ListModels bool `long:"list-models" description:"List available models."`
SetModel *string `long:"model" description:"Set model to use."`

ListConnectors bool `long:"list-connectors" description:"List available connectors."`
SetConnectors []string `long:"connector" description:"Set connectors to use."`

SetTemperature *float64 `short:"t" long:"temperature" description:"Set temperature value."`
SetTopP *float64 `short:"p" long:"top-p" description:"Set top-p value."`
SetTopK *int `short:"k" long:"top-k" description:"Set top-k value."`
SetFrequencyPenalty *float64 `long:"freq" description:"Set frequency penalty value."`
SetPresencePenalty *float64 `long:"pres" description:"Set presence penalty value."`
}

func multiTurn(
ctx context.Context,
out io.WriteCloser,
Expand Down Expand Up @@ -162,14 +165,14 @@ func runCmd(prog string, args ...string) error {
return cmd.Run()
}

func parseConfig(ctx context.Context) (bool, error) {
func parseConfig(ctx context.Context, flagDefs []*flags.Option, flagVals *flagValues) (bool, error) {
var err error
cfg, err = config.ReadConfig()
if err != nil {
return false, err
}

if *showConfig {
if flagVals.ShowConfig {
data, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
return false, err
Expand All @@ -179,7 +182,7 @@ func parseConfig(ctx context.Context) (bool, error) {
return true, nil
}

if *listModels {
if flagVals.ListModels {
modelNames, err := prov.ListModels(ctx)
if err != nil {
return false, err
Expand All @@ -195,7 +198,7 @@ func parseConfig(ctx context.Context) (bool, error) {
return true, nil
}

if *listConnectors {
if flagVals.ListConnectors {
connectorIDs, err := prov.ListConnectors(ctx)
if err != nil {
return false, err
Expand All @@ -210,37 +213,44 @@ func parseConfig(ctx context.Context) (bool, error) {
}

dirtyCfg := false
flag.Visit(func(f *flag.Flag) {
switch f.Name {
// TODO changing provider should reset model selection
case "model":
// TODO there is no way to unset model
modelType := config.ModelType(*setModel)
cfg.Model[modelType] = *setModel
dirtyCfg = true
case "connectors":
cfg.Connectors = strings.Split(*setConnectors, ",")
if *setConnectors == "" {
cfg.Connectors = nil
}
dirtyCfg = true
case "temperature":
cfg.Temperature = setTemp
dirtyCfg = true
case "top-p":
cfg.TopP = setTopP
dirtyCfg = true
case "top-k":
cfg.TopK = setTopK
dirtyCfg = true
case "frequency-penalty":
cfg.FrequencyPenalty = setFreqPen
dirtyCfg = true
case "presence-penalty":
cfg.PresencePenalty = setPresPen
dirtyCfg = true
}
})
// TODO changing provider should reset model selection
if flagVals.SetModel != nil {
dirtyCfg = true
model := *flagVals.SetModel
// TODO there is no way to unset model
cfg.Model[config.ModelType(model)] = model
}

if flagVals.SetConnectors != nil {
dirtyCfg = true
cfg.Connectors = flagVals.SetConnectors
}

if flagVals.SetTemperature != nil {
dirtyCfg = true
cfg.Temperature = flagVals.SetTemperature
}

if flagVals.SetTopP != nil {
dirtyCfg = true
cfg.TopP = flagVals.SetTopP
}

if flagVals.SetTopK != nil {
dirtyCfg = true
cfg.TopK = flagVals.SetTopK
}

if flagVals.SetFrequencyPenalty != nil {
dirtyCfg = true
cfg.FrequencyPenalty = flagVals.SetFrequencyPenalty
}

if flagVals.SetPresencePenalty != nil {
dirtyCfg = true
cfg.PresencePenalty = flagVals.SetPresencePenalty
}

if dirtyCfg {
if err := config.WriteConfig(cfg); err != nil {
return false, err
Expand All @@ -256,13 +266,13 @@ func parseConfig(ctx context.Context) (bool, error) {
return false, nil
}

func cmd(ctx context.Context) error {
func cmd(ctx context.Context, usrMsg string, flagVals *flagValues) error {
turnFn := func(ctx context.Context, out io.WriteCloser, msgs []*provider.Message) (string, error) {
var blocks []*parser.CodeBlock
done := make(chan struct{})

switch {
case *execute:
case flagVals.Execute:
codeW, blockCh := parser.NewCode()
go func() {
defer close(done)
Expand All @@ -272,7 +282,7 @@ func cmd(ctx context.Context) error {
}()
// no output to the user, we just execute the code
out = codeW
case *run:
case flagVals.Run:
codeW, blockCh := parser.NewCode()
go func() {
defer close(done)
Expand Down Expand Up @@ -318,16 +328,15 @@ func cmd(ctx context.Context) error {
}
pipeContent = string(pipeBytes)
}
if *chat {
if flagVals.Interactive {
in, err = os.Open("/dev/tty")
if err != nil {
return fmt.Errorf("failed to open /dev/tty: %w", err)
}
}
if flag.NArg() == 0 && pipeContent == "" {
if usrMsg == "" && pipeContent == "" {
return fmt.Errorf("what's your command?")
}
usrMsg := strings.Join(flag.Args(), " ")
if pipeContent != "" {
usrMsg = fmt.Sprintf(CONTEXT_TEMPLATE, pipeContent, usrMsg)
}
Expand All @@ -351,7 +360,7 @@ func cmd(ctx context.Context) error {
}

switch {
case *chat:
case flagVals.Interactive:
err = multiTurn(ctx, os.Stdout, in, usrMsg, turnFn)
default:
_, err = turnFn(ctx, os.Stdout, []*provider.Message{
Expand All @@ -365,7 +374,18 @@ func cmd(ctx context.Context) error {
}

func main() {
flag.Parse()
flagVals := &flagValues{}
parser := flags.NewParser(nil, flags.Default)
g, err := parser.AddGroup("Application Options", "", flagVals)
if err != nil {
panic(err)
}
unparsed, err := parser.Parse()
if err != nil {
panic(err)
}
usrMsg := strings.Join(unparsed, " ")
flagDefs := g.Options()

// read config first so we can use the right provider
cfg, err := config.ReadConfig()
Expand All @@ -387,7 +407,7 @@ func main() {
}

ctx := context.Background()
done, err := parseConfig(ctx)
done, err := parseConfig(ctx, flagDefs, flagVals)
if err != nil {
panic(err)
}
Expand All @@ -404,7 +424,7 @@ func main() {
prov = closer
}

err = cmd(ctx)
err = cmd(ctx, usrMsg, flagVals)
if exitErr, ok := err.(*exec.ExitError); ok {
os.Exit(exitErr.ExitCode())
}
Expand Down
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ go 1.23.0
require (
github.com/cohere-ai/cohere-go/v2 v2.7.0
github.com/fatih/color v1.17.0
github.com/jessevdk/go-flags v1.6.1
github.com/sashabaranov/go-openai v1.29.0
)

require (
github.com/google/uuid v1.4.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
golang.org/x/sys v0.18.0 // indirect
golang.org/x/sys v0.21.0 // indirect
)
6 changes: 4 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ github.com/fatih/color v1.17.0 h1:GlRw1BRJxkpqUCBKzKOw098ed57fEsKeNjpTe3cSjK4=
github.com/fatih/color v1.17.0/go.mod h1:YZ7TlrGPkiz6ku9fK3TLD/pl3CpsiFyu8N92HLgmosI=
github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4=
github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jessevdk/go-flags v1.6.1 h1:Cvu5U8UGrLay1rZfv/zP7iLpSHGUZ/Ou68T0iX1bBK4=
github.com/jessevdk/go-flags v1.6.1/go.mod h1:Mk8T1hIAWpOiJiHa9rJASDK2UGWji0EuPGBnNLMooyc=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
Expand All @@ -19,7 +21,7 @@ github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5Cc
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

0 comments on commit 88a9091

Please sign in to comment.