diff --git a/cmd/main.go b/cmd/main.go index 3239879..0576312 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -15,16 +15,13 @@ import ( "github.com/daulet/llm-cli/cohere" "github.com/daulet/llm-cli/config" "github.com/daulet/llm-cli/parser" - - co "github.com/cohere-ai/cohere-go/v2" - cocli "github.com/cohere-ai/cohere-go/v2/client" ) const apiKeyEnvVar = "COHERE_API_KEY" var ( - client *cocli.Client - cfg *config.Config + provider cohere.Provider + cfg *config.Config // config flags showConfig = flag.Bool("config", false, "Show current config.") @@ -47,11 +44,11 @@ func multiTurn( ctx context.Context, out io.WriteCloser, in io.Reader, - turnFn func(context.Context, io.WriteCloser, []*co.ChatMessage) (string, error), + turnFn func(context.Context, io.WriteCloser, []*cohere.Message) (string, error), ) error { var ( r = bufio.NewScanner(in) - msgs []*co.ChatMessage + msgs []*cohere.Message ) for { select { @@ -67,9 +64,9 @@ func multiTurn( userMsg := r.Text() msgs = append(msgs, - &co.ChatMessage{ - Role: co.ChatMessageRoleUser, - Message: userMsg, + &cohere.Message{ + Role: cohere.User, + Content: userMsg, }) botMsg, err := turnFn(ctx, out, msgs) @@ -78,9 +75,9 @@ func multiTurn( } msgs = append(msgs, - &co.ChatMessage{ - Role: co.ChatMessageRoleChatbot, - Message: botMsg, + &cohere.Message{ + Role: cohere.Assistant, + Content: botMsg, }, ) } @@ -89,32 +86,17 @@ func multiTurn( func generate( ctx context.Context, out io.WriteCloser, - msgs []*co.ChatMessage, + msgs []*cohere.Message, ) (string, error) { - buf := bytes.NewBuffer(nil) - req := &co.ChatStreamRequest{ - ChatHistory: msgs[:len(msgs)-1], - Message: msgs[len(msgs)-1].Message, - - Model: cfg.Model, - Temperature: cfg.Temperature, - P: cfg.TopP, - K: cfg.TopK, - FrequencyPenalty: cfg.FrequencyPenalty, - PresencePenalty: cfg.PresencePenalty, - } - for _, connector := range cfg.Connectors { - req.Connectors = append(req.Connectors, &co.ChatConnector{Id: connector}) - } - stream, err := client.ChatStream(ctx, req) + reader, err := provider.Stream(ctx, cfg, msgs) if err != nil { return "", err } - _, err = io.Copy(parser.MultiWriter(out, buf), cohere.ReadFrom(stream)) + buf := bytes.NewBuffer(nil) + _, err = io.Copy(parser.MultiWriter(out, buf), reader) if err != nil { return "", err } - stream.Close() out.Write([]byte("\n")) return buf.String(), nil } @@ -183,15 +165,13 @@ func parseConfig(ctx context.Context) (bool, error) { } if *listModels { - resp, err := client.Models.List(ctx, &co.ModelsListRequest{ - Endpoint: (*co.CompatibleEndpoint)(co.String(string(co.CompatibleEndpointChat))), - }) + modelNames, err := provider.ListModels(ctx) if err != nil { return false, err } fmt.Println("Available models:") - for _, model := range resp.Models { - fmt.Println(*model.Name) + for _, model := range modelNames { + fmt.Println(model) } fmt.Println() fmt.Printf("Currently selected model: %s\n", *cfg.Model) @@ -199,13 +179,13 @@ func parseConfig(ctx context.Context) (bool, error) { } if *listConnectors { - resp, err := client.Connectors.List(ctx, &co.ConnectorsListRequest{}) + connectorIDs, err := provider.ListConnectors(ctx) if err != nil { return false, err } fmt.Println("Available connectors:") - for _, connector := range resp.Connectors { - fmt.Println(connector.Id) + for _, connectorID := range connectorIDs { + fmt.Println(connectorID) } fmt.Println() fmt.Printf("Currently selected connectors: %s\n", cfg.Connectors) @@ -257,7 +237,7 @@ func parseConfig(ctx context.Context) (bool, error) { } func cmd(ctx context.Context) error { - turnFn := func(ctx context.Context, out io.WriteCloser, msgs []*co.ChatMessage) (string, error) { + turnFn := func(ctx context.Context, out io.WriteCloser, msgs []*cohere.Message) (string, error) { var blocks []*parser.CodeBlock done := make(chan struct{}) @@ -322,10 +302,10 @@ func cmd(ctx context.Context) error { case *chat: err = multiTurn(ctx, os.Stdout, os.Stdin, turnFn) default: - _, err = turnFn(ctx, os.Stdout, []*co.ChatMessage{ + _, err = turnFn(ctx, os.Stdout, []*cohere.Message{ { - Role: co.ChatMessageRoleUser, - Message: usrMsg, + Role: cohere.User, + Content: usrMsg, }, }) } @@ -335,7 +315,7 @@ func cmd(ctx context.Context) error { func main() { flag.Parse() - client = cocli.NewClient(cocli.WithToken(os.Getenv(apiKeyEnvVar))) + provider = cohere.NewCohereProvider(os.Getenv(apiKeyEnvVar)) ctx := context.Background() done, err := parseConfig(ctx) diff --git a/cohere/chat.go b/cohere/chat.go new file mode 100644 index 0000000..929282a --- /dev/null +++ b/cohere/chat.go @@ -0,0 +1,106 @@ +package cohere + +import ( + "context" + "io" + "log" + + "github.com/daulet/llm-cli/config" + + co "github.com/cohere-ai/cohere-go/v2" + cocli "github.com/cohere-ai/cohere-go/v2/client" +) + +type Role string + +const ( + User Role = "user" + Assistant Role = "assistant" +) + +type Message struct { + Role Role + Content string +} + +type Provider interface { + Stream(ctx context.Context, cfg *config.Config, msgs []*Message) (io.Reader, error) + ListModels(ctx context.Context) ([]string, error) + ListConnectors(ctx context.Context) ([]string, error) +} + +var _ Provider = (*CohereProvider)(nil) + +func NewCohereProvider(apiKey string) *CohereProvider { + return &CohereProvider{client: cocli.NewClient(cocli.WithToken(apiKey))} +} + +type CohereProvider struct { + client *cocli.Client +} + +func (p *CohereProvider) Stream(ctx context.Context, cfg *config.Config, msgs []*Message) (io.Reader, error) { + var messages []*co.ChatMessage + for _, msg := range msgs { + switch msg.Role { + case User: + messages = append(messages, &co.ChatMessage{ + Role: co.ChatMessageRoleUser, + Message: msg.Content, + }) + case Assistant: + messages = append(messages, &co.ChatMessage{ + Role: co.ChatMessageRoleChatbot, + Message: msg.Content, + }) + default: + log.Fatalf("unknown role: %s", msg.Role) + } + } + req := &co.ChatStreamRequest{ + ChatHistory: messages[:len(messages)-1], + Message: messages[len(messages)-1].Message, + + Model: cfg.Model, + Temperature: cfg.Temperature, + P: cfg.TopP, + K: cfg.TopK, + FrequencyPenalty: cfg.FrequencyPenalty, + PresencePenalty: cfg.PresencePenalty, + } + for _, connector := range cfg.Connectors { + req.Connectors = append(req.Connectors, &co.ChatConnector{Id: connector}) + } + stream, err := p.client.ChatStream(ctx, req) + if err != nil { + return nil, err + } + // TODO stream.Close() + return ReadFrom(stream), nil +} + +func (p *CohereProvider) ListModels(ctx context.Context) ([]string, error) { + resp, err := p.client.Models.List(ctx, &co.ModelsListRequest{ + Endpoint: (*co.CompatibleEndpoint)(co.String(string(co.CompatibleEndpointChat))), + }) + if err != nil { + return nil, err + } + var modelNames []string + for _, model := range resp.Models { + modelNames = append(modelNames, *model.Name) + } + return modelNames, nil +} + +func (p *CohereProvider) ListConnectors(ctx context.Context) ([]string, error) { + resp, err := p.client.Connectors.List(ctx, &co.ConnectorsListRequest{}) + if err != nil { + return nil, err + } + var connectorNames []string + for _, connector := range resp.Connectors { + connectorNames = append(connectorNames, connector.Id) + } + return connectorNames, nil +} diff --git a/parser/multi.go b/parser/multi.go index 926c4da..56ae741 100644 --- a/parser/multi.go +++ b/parser/multi.go @@ -5,6 +5,7 @@ import ( "os" ) +// multiWriter exists because std doesn't implement Closer type multiWriter struct { writers []io.Writer }