From f3b3a15f6dbcfbe7b3812e5ea69bcaea85affd76 Mon Sep 17 00:00:00 2001 From: Daulet Zhanguzin Date: Sun, 25 Aug 2024 16:32:07 -0700 Subject: [PATCH] fix: rename cohere package into common provider --- cmd/main.go | 34 +++++++++++++++++----------------- {cohere => provider}/chat.go | 2 +- {cohere => provider}/reader.go | 2 +- 3 files changed, 19 insertions(+), 19 deletions(-) rename {cohere => provider}/chat.go (99%) rename {cohere => provider}/reader.go (97%) diff --git a/cmd/main.go b/cmd/main.go index 0576312..40e78eb 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -12,16 +12,16 @@ import ( "os/exec" "strings" - "github.com/daulet/llm-cli/cohere" "github.com/daulet/llm-cli/config" "github.com/daulet/llm-cli/parser" + "github.com/daulet/llm-cli/provider" ) const apiKeyEnvVar = "COHERE_API_KEY" var ( - provider cohere.Provider - cfg *config.Config + prov provider.Provider + cfg *config.Config // config flags showConfig = flag.Bool("config", false, "Show current config.") @@ -44,11 +44,11 @@ func multiTurn( ctx context.Context, out io.WriteCloser, in io.Reader, - turnFn func(context.Context, io.WriteCloser, []*cohere.Message) (string, error), + turnFn func(context.Context, io.WriteCloser, []*provider.Message) (string, error), ) error { var ( r = bufio.NewScanner(in) - msgs []*cohere.Message + msgs []*provider.Message ) for { select { @@ -64,8 +64,8 @@ func multiTurn( userMsg := r.Text() msgs = append(msgs, - &cohere.Message{ - Role: cohere.User, + &provider.Message{ + Role: provider.User, Content: userMsg, }) @@ -75,8 +75,8 @@ func multiTurn( } msgs = append(msgs, - &cohere.Message{ - Role: cohere.Assistant, + &provider.Message{ + Role: provider.Assistant, Content: botMsg, }, ) @@ -86,9 +86,9 @@ func multiTurn( func generate( ctx context.Context, out io.WriteCloser, - msgs []*cohere.Message, + msgs []*provider.Message, ) (string, error) { - reader, err := provider.Stream(ctx, cfg, msgs) + reader, err := prov.Stream(ctx, cfg, msgs) if err != nil { return "", err } @@ -165,7 +165,7 @@ func parseConfig(ctx context.Context) (bool, error) { } if *listModels { - modelNames, err := provider.ListModels(ctx) + modelNames, err := prov.ListModels(ctx) if err != nil { return false, err } @@ -179,7 +179,7 @@ func parseConfig(ctx context.Context) (bool, error) { } if *listConnectors { - connectorIDs, err := provider.ListConnectors(ctx) + connectorIDs, err := prov.ListConnectors(ctx) if err != nil { return false, err } @@ -237,7 +237,7 @@ func parseConfig(ctx context.Context) (bool, error) { } func cmd(ctx context.Context) error { - turnFn := func(ctx context.Context, out io.WriteCloser, msgs []*cohere.Message) (string, error) { + turnFn := func(ctx context.Context, out io.WriteCloser, msgs []*provider.Message) (string, error) { var blocks []*parser.CodeBlock done := make(chan struct{}) @@ -302,9 +302,9 @@ func cmd(ctx context.Context) error { case *chat: err = multiTurn(ctx, os.Stdout, os.Stdin, turnFn) default: - _, err = turnFn(ctx, os.Stdout, []*cohere.Message{ + _, err = turnFn(ctx, os.Stdout, []*provider.Message{ { - Role: cohere.User, + Role: provider.User, Content: usrMsg, }, }) @@ -315,7 +315,7 @@ func cmd(ctx context.Context) error { func main() { flag.Parse() - provider = cohere.NewCohereProvider(os.Getenv(apiKeyEnvVar)) + prov = provider.NewCohereProvider(os.Getenv(apiKeyEnvVar)) ctx := context.Background() done, err := parseConfig(ctx) diff --git a/cohere/chat.go b/provider/chat.go similarity index 99% rename from cohere/chat.go rename to provider/chat.go index 929282a..0e0ebb5 100644 --- a/cohere/chat.go +++ b/provider/chat.go @@ -1,4 +1,4 @@ -package cohere +package provider import ( "context" diff --git a/cohere/reader.go b/provider/reader.go similarity index 97% rename from cohere/reader.go rename to provider/reader.go index 733e793..3d84fe3 100644 --- a/cohere/reader.go +++ b/provider/reader.go @@ -1,4 +1,4 @@ -package cohere +package provider import ( "io"