Skip to content

Commit

Permalink
fix: rename cohere package into common provider
Browse files Browse the repository at this point in the history
  • Loading branch information
daulet committed Aug 25, 2024
1 parent 20523ca commit f3b3a15
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 19 deletions.
34 changes: 17 additions & 17 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ import (
"os/exec"
"strings"

"github.com/daulet/llm-cli/cohere"
"github.com/daulet/llm-cli/config"
"github.com/daulet/llm-cli/parser"
"github.com/daulet/llm-cli/provider"
)

const apiKeyEnvVar = "COHERE_API_KEY"

var (
provider cohere.Provider
cfg *config.Config
prov provider.Provider
cfg *config.Config

// config flags
showConfig = flag.Bool("config", false, "Show current config.")
Expand All @@ -44,11 +44,11 @@ func multiTurn(
ctx context.Context,
out io.WriteCloser,
in io.Reader,
turnFn func(context.Context, io.WriteCloser, []*cohere.Message) (string, error),
turnFn func(context.Context, io.WriteCloser, []*provider.Message) (string, error),
) error {
var (
r = bufio.NewScanner(in)
msgs []*cohere.Message
msgs []*provider.Message
)
for {
select {
Expand All @@ -64,8 +64,8 @@ func multiTurn(
userMsg := r.Text()

msgs = append(msgs,
&cohere.Message{
Role: cohere.User,
&provider.Message{
Role: provider.User,
Content: userMsg,
})

Expand All @@ -75,8 +75,8 @@ func multiTurn(
}

msgs = append(msgs,
&cohere.Message{
Role: cohere.Assistant,
&provider.Message{
Role: provider.Assistant,
Content: botMsg,
},
)
Expand All @@ -86,9 +86,9 @@ func multiTurn(
func generate(
ctx context.Context,
out io.WriteCloser,
msgs []*cohere.Message,
msgs []*provider.Message,
) (string, error) {
reader, err := provider.Stream(ctx, cfg, msgs)
reader, err := prov.Stream(ctx, cfg, msgs)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -165,7 +165,7 @@ func parseConfig(ctx context.Context) (bool, error) {
}

if *listModels {
modelNames, err := provider.ListModels(ctx)
modelNames, err := prov.ListModels(ctx)
if err != nil {
return false, err
}
Expand All @@ -179,7 +179,7 @@ func parseConfig(ctx context.Context) (bool, error) {
}

if *listConnectors {
connectorIDs, err := provider.ListConnectors(ctx)
connectorIDs, err := prov.ListConnectors(ctx)
if err != nil {
return false, err
}
Expand Down Expand Up @@ -237,7 +237,7 @@ func parseConfig(ctx context.Context) (bool, error) {
}

func cmd(ctx context.Context) error {
turnFn := func(ctx context.Context, out io.WriteCloser, msgs []*cohere.Message) (string, error) {
turnFn := func(ctx context.Context, out io.WriteCloser, msgs []*provider.Message) (string, error) {
var blocks []*parser.CodeBlock
done := make(chan struct{})

Expand Down Expand Up @@ -302,9 +302,9 @@ func cmd(ctx context.Context) error {
case *chat:
err = multiTurn(ctx, os.Stdout, os.Stdin, turnFn)
default:
_, err = turnFn(ctx, os.Stdout, []*cohere.Message{
_, err = turnFn(ctx, os.Stdout, []*provider.Message{
{
Role: cohere.User,
Role: provider.User,
Content: usrMsg,
},
})
Expand All @@ -315,7 +315,7 @@ func cmd(ctx context.Context) error {
func main() {
flag.Parse()

provider = cohere.NewCohereProvider(os.Getenv(apiKeyEnvVar))
prov = provider.NewCohereProvider(os.Getenv(apiKeyEnvVar))

ctx := context.Background()
done, err := parseConfig(ctx)
Expand Down
2 changes: 1 addition & 1 deletion cohere/chat.go → provider/chat.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package cohere
package provider

import (
"context"
Expand Down
2 changes: 1 addition & 1 deletion cohere/reader.go → provider/reader.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package cohere
package provider

import (
"io"
Expand Down

0 comments on commit f3b3a15

Please sign in to comment.