Skip to content

Commit

Permalink
fix: refactor out cohere into separate provider
Browse files Browse the repository at this point in the history
  • Loading branch information
daulet committed Aug 25, 2024
1 parent 31f0578 commit 20523ca
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 45 deletions.
70 changes: 25 additions & 45 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,13 @@ import (
"github.com/daulet/llm-cli/cohere"
"github.com/daulet/llm-cli/config"
"github.com/daulet/llm-cli/parser"

co "github.com/cohere-ai/cohere-go/v2"
cocli "github.com/cohere-ai/cohere-go/v2/client"
)

const apiKeyEnvVar = "COHERE_API_KEY"

var (
client *cocli.Client
cfg *config.Config
provider cohere.Provider
cfg *config.Config

// config flags
showConfig = flag.Bool("config", false, "Show current config.")
Expand All @@ -47,11 +44,11 @@ func multiTurn(
ctx context.Context,
out io.WriteCloser,
in io.Reader,
turnFn func(context.Context, io.WriteCloser, []*co.ChatMessage) (string, error),
turnFn func(context.Context, io.WriteCloser, []*cohere.Message) (string, error),
) error {
var (
r = bufio.NewScanner(in)
msgs []*co.ChatMessage
msgs []*cohere.Message
)
for {
select {
Expand All @@ -67,9 +64,9 @@ func multiTurn(
userMsg := r.Text()

msgs = append(msgs,
&co.ChatMessage{
Role: co.ChatMessageRoleUser,
Message: userMsg,
&cohere.Message{
Role: cohere.User,
Content: userMsg,
})

botMsg, err := turnFn(ctx, out, msgs)
Expand All @@ -78,9 +75,9 @@ func multiTurn(
}

msgs = append(msgs,
&co.ChatMessage{
Role: co.ChatMessageRoleChatbot,
Message: botMsg,
&cohere.Message{
Role: cohere.Assistant,
Content: botMsg,
},
)
}
Expand All @@ -89,32 +86,17 @@ func multiTurn(
func generate(
ctx context.Context,
out io.WriteCloser,
msgs []*co.ChatMessage,
msgs []*cohere.Message,
) (string, error) {
buf := bytes.NewBuffer(nil)
req := &co.ChatStreamRequest{
ChatHistory: msgs[:len(msgs)-1],
Message: msgs[len(msgs)-1].Message,

Model: cfg.Model,
Temperature: cfg.Temperature,
P: cfg.TopP,
K: cfg.TopK,
FrequencyPenalty: cfg.FrequencyPenalty,
PresencePenalty: cfg.PresencePenalty,
}
for _, connector := range cfg.Connectors {
req.Connectors = append(req.Connectors, &co.ChatConnector{Id: connector})
}
stream, err := client.ChatStream(ctx, req)
reader, err := provider.Stream(ctx, cfg, msgs)
if err != nil {
return "", err
}
_, err = io.Copy(parser.MultiWriter(out, buf), cohere.ReadFrom(stream))
buf := bytes.NewBuffer(nil)
_, err = io.Copy(parser.MultiWriter(out, buf), reader)
if err != nil {
return "", err
}
stream.Close()
out.Write([]byte("\n"))
return buf.String(), nil
}
Expand Down Expand Up @@ -183,29 +165,27 @@ func parseConfig(ctx context.Context) (bool, error) {
}

if *listModels {
resp, err := client.Models.List(ctx, &co.ModelsListRequest{
Endpoint: (*co.CompatibleEndpoint)(co.String(string(co.CompatibleEndpointChat))),
})
modelNames, err := provider.ListModels(ctx)
if err != nil {
return false, err
}
fmt.Println("Available models:")
for _, model := range resp.Models {
fmt.Println(*model.Name)
for _, model := range modelNames {
fmt.Println(model)
}
fmt.Println()
fmt.Printf("Currently selected model: %s\n", *cfg.Model)
return true, nil
}

if *listConnectors {
resp, err := client.Connectors.List(ctx, &co.ConnectorsListRequest{})
connectorIDs, err := provider.ListConnectors(ctx)
if err != nil {
return false, err
}
fmt.Println("Available connectors:")
for _, connector := range resp.Connectors {
fmt.Println(connector.Id)
for _, connectorID := range connectorIDs {
fmt.Println(connectorID)
}
fmt.Println()
fmt.Printf("Currently selected connectors: %s\n", cfg.Connectors)
Expand Down Expand Up @@ -257,7 +237,7 @@ func parseConfig(ctx context.Context) (bool, error) {
}

func cmd(ctx context.Context) error {
turnFn := func(ctx context.Context, out io.WriteCloser, msgs []*co.ChatMessage) (string, error) {
turnFn := func(ctx context.Context, out io.WriteCloser, msgs []*cohere.Message) (string, error) {
var blocks []*parser.CodeBlock
done := make(chan struct{})

Expand Down Expand Up @@ -322,10 +302,10 @@ func cmd(ctx context.Context) error {
case *chat:
err = multiTurn(ctx, os.Stdout, os.Stdin, turnFn)
default:
_, err = turnFn(ctx, os.Stdout, []*co.ChatMessage{
_, err = turnFn(ctx, os.Stdout, []*cohere.Message{
{
Role: co.ChatMessageRoleUser,
Message: usrMsg,
Role: cohere.User,
Content: usrMsg,
},
})
}
Expand All @@ -335,7 +315,7 @@ func cmd(ctx context.Context) error {
func main() {
flag.Parse()

client = cocli.NewClient(cocli.WithToken(os.Getenv(apiKeyEnvVar)))
provider = cohere.NewCohereProvider(os.Getenv(apiKeyEnvVar))

ctx := context.Background()
done, err := parseConfig(ctx)
Expand Down
106 changes: 106 additions & 0 deletions cohere/chat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package cohere

import (
"context"
"io"
"log"

"github.com/daulet/llm-cli/config"

co "github.com/cohere-ai/cohere-go/v2"
cocli "github.com/cohere-ai/cohere-go/v2/client"
)

type Role string

const (
User Role = "user"
Assistant Role = "assistant"
)

type Message struct {
Role Role
Content string
}

type Provider interface {
Stream(ctx context.Context, cfg *config.Config, msgs []*Message) (io.Reader, error)
ListModels(ctx context.Context) ([]string, error)
ListConnectors(ctx context.Context) ([]string, error)
}

var _ Provider = (*CohereProvider)(nil)

func NewCohereProvider(apiKey string) *CohereProvider {
return &CohereProvider{client: cocli.NewClient(cocli.WithToken(apiKey))}
}

type CohereProvider struct {
client *cocli.Client
}

func (p *CohereProvider) Stream(ctx context.Context, cfg *config.Config, msgs []*Message) (io.Reader, error) {
var messages []*co.ChatMessage
for _, msg := range msgs {
switch msg.Role {
case User:
messages = append(messages, &co.ChatMessage{
Role: co.ChatMessageRoleUser,
Message: msg.Content,
})
case Assistant:
messages = append(messages, &co.ChatMessage{
Role: co.ChatMessageRoleChatbot,
Message: msg.Content,
})
default:
log.Fatalf("unknown role: %s", msg.Role)
}
}
req := &co.ChatStreamRequest{
ChatHistory: messages[:len(messages)-1],
Message: messages[len(messages)-1].Message,

Model: cfg.Model,
Temperature: cfg.Temperature,
P: cfg.TopP,
K: cfg.TopK,
FrequencyPenalty: cfg.FrequencyPenalty,
PresencePenalty: cfg.PresencePenalty,
}
for _, connector := range cfg.Connectors {
req.Connectors = append(req.Connectors, &co.ChatConnector{Id: connector})
}
stream, err := p.client.ChatStream(ctx, req)
if err != nil {
return nil, err
}
// TODO stream.Close()
return ReadFrom(stream), nil
}

func (p *CohereProvider) ListModels(ctx context.Context) ([]string, error) {
resp, err := p.client.Models.List(ctx, &co.ModelsListRequest{
Endpoint: (*co.CompatibleEndpoint)(co.String(string(co.CompatibleEndpointChat))),
})
if err != nil {
return nil, err
}
var modelNames []string
for _, model := range resp.Models {
modelNames = append(modelNames, *model.Name)
}
return modelNames, nil
}

func (p *CohereProvider) ListConnectors(ctx context.Context) ([]string, error) {
resp, err := p.client.Connectors.List(ctx, &co.ConnectorsListRequest{})
if err != nil {
return nil, err
}
var connectorNames []string
for _, connector := range resp.Connectors {
connectorNames = append(connectorNames, connector.Id)
}
return connectorNames, nil
}
1 change: 1 addition & 0 deletions parser/multi.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"os"
)

// multiWriter exists because std doesn't implement Closer
type multiWriter struct {
writers []io.Writer
}
Expand Down

0 comments on commit 20523ca

Please sign in to comment.