Added --split flag
tkawachi committed Mar 12, 2023
1 parent 74375ec commit 3e1b510
Showing 3 changed files with 144 additions and 27 deletions.
83 changes: 56 additions & 27 deletions aichat.go
@@ -125,11 +125,13 @@ func main() {
	var verbose = false
	var listPrompts = false
	var nonStreaming = false
	var split = false
	getopt.FlagLong(&temperature, "temperature", 't', "temperature")
	getopt.FlagLong(&maxTokens, "max-tokens", 'm', "max tokens")
	getopt.FlagLong(&verbose, "verbose", 'v', "verbose output")
	getopt.FlagLong(&listPrompts, "list-prompts", 'l', "list prompts")
	getopt.FlagLong(&nonStreaming, "non-streaming", 0, "non streaming mode")
	getopt.FlagLong(&split, "split", 0, "split input")
	getopt.Parse()

	if listPrompts {
@@ -171,40 +173,67 @@ func main() {
log.Fatalf("prompt %q not found", args[0])
}
// read all from Stdin
scanner := bufio.NewScanner(os.Stdin)
input := ""
for scanner.Scan() {
input += scanner.Text() + "\n"
}
messages := prompt.CreateMessages(input)
if verbose {
log.Printf("messages: %+v", messages)
}
cnt, err := CountTokens(mapSlice(messages, func(m gogpt.ChatCompletionMessage) string { return m.Content }))
if err != nil {
log.Fatal(err)
}
if cnt > 4096 {
log.Fatalf("total tokens %d exceeds 4096", cnt)
}
request := gogpt.ChatCompletionRequest{
Model: gogpt.GPT3Dot5Turbo,
Messages: messages,
Temperature: firstNonZeroFloat32(prompt.Temperature, aiChat.options.temperature),
MaxTokens: firstNonZeroInt(prompt.MaxTokens, aiChat.options.maxTokens),
}
if aiChat.options.nonStreaming {
err = nonStreamCompletion(aiChat.client, request, os.Stdout)
	input := scanAll(bufio.NewScanner(os.Stdin))

	var messagesSlice [][]gogpt.ChatCompletionMessage

	if split {
		messagesSlice, err = prompt.CreateMessagesWithSplit(input, 0) // TODO pass maxTokens if it is specified via command line flag
		if err != nil {
			log.Fatal(err)
		}
		if verbose {
			log.Printf("messages was split to %d parts", len(messagesSlice))

		}
	} else {
		err = streamCompletion(aiChat.client, request, os.Stdout)
		messages := prompt.CreateMessages(input)
		if verbose {
			log.Printf("messages: %+v", messagesSlice)
		}
		messagesSlice = [][]gogpt.ChatCompletionMessage{messages}
	}
	if err != nil {
		log.Fatal(err)

	for _, messages := range messagesSlice {

		maxTokens := firstNonZeroInt(prompt.MaxTokens, aiChat.options.maxTokens)

		request := gogpt.ChatCompletionRequest{
			Model: gogpt.GPT3Dot5Turbo,
			Messages: messages,
			Temperature: firstNonZeroFloat32(prompt.Temperature, aiChat.options.temperature),
			MaxTokens: maxTokens,
		}

		cnt, err := CountTokens(mapSlice(messages, func(m gogpt.ChatCompletionMessage) string { return m.Content }))
		if err != nil {
			log.Fatal(err)
		}
		if cnt > 4096 {
			log.Fatalf("total tokens %d exceeds 4096", cnt)
		}

		if aiChat.options.nonStreaming {
			err = nonStreamCompletion(aiChat.client, request, os.Stdout)
		} else {
			err = streamCompletion(aiChat.client, request, os.Stdout)
		}
		if err != nil {
			log.Fatal(err)
		}
	}
	}

}

func scanAll(scanner *bufio.Scanner) string {
	input := ""
	for scanner.Scan() {
		input += scanner.Text() + "\n"
	}
	return input
}

// mapSlice maps a slice of type T to a slice of type M using the function f.
func mapSlice[T any, M any](a []T, f func(T) M) []M {
	r := make([]M, len(a))
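The Temperature and MaxTokens values in the requests above come from two helpers, firstNonZeroInt and firstNonZeroFloat32, which are defined elsewhere in the repository and are not part of this diff. Judging from how they are called (a per-prompt setting wins when it is set, otherwise the command-line value is used), they presumably behave like the sketch below; treat it as an assumption rather than the repository's actual code.

package main

import "fmt"

// firstNonZeroInt sketches the helper used above: return a unless it is zero,
// otherwise fall back to b. (Assumed behavior; the real definition is not shown
// in this diff.)
func firstNonZeroInt(a, b int) int {
	if a != 0 {
		return a
	}
	return b
}

// firstNonZeroFloat32 is the float32 counterpart, assumed to work the same way.
func firstNonZeroFloat32(a, b float32) float32 {
	if a != 0 {
		return a
	}
	return b
}

func main() {
	fmt.Println(firstNonZeroInt(0, 500))       // prompt did not set MaxTokens -> 500
	fmt.Println(firstNonZeroFloat32(0.2, 0.7)) // prompt set Temperature -> 0.2
}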
71 changes: 71 additions & 0 deletions prompt.go
@@ -6,6 +6,7 @@ import (
"path/filepath"
"strings"

tokenizer "github.com/samber/go-gpt-3-encoder"
gogpt "github.com/sashabaranov/go-gpt3"
"gopkg.in/yaml.v3"
)
@@ -37,6 +38,76 @@ func (p *Prompt) CreateMessages(input string) []gogpt.ChatCompletionMessage {
	return messages
}

// CountTokens counts the number of tokens in the prompt
func (p *Prompt) CountTokens() (int, error) {
	count := 0
	encoder, err := tokenizer.NewEncoder()
	if err != nil {
		return 0, err
	}
	for _, message := range p.Messages {
		// Encode string with GPT tokenizer
		encoded, err := encoder.Encode(message.Content)
		if err != nil {
			return 0, err
		}
		count += len(encoded)
	}
	return count, nil
}

// AllowedInputTokens returns the number of tokens allowed for the input
func (p *Prompt) AllowedInputTokens(maxTokensOverride int) (int, error) {
	maxTokens := firstNonZeroInt(maxTokensOverride, p.MaxTokens)
	promptTokens, err := p.CountTokens()
	if err != nil {
		return 0, err
	}
	return maxTokens - promptTokens, nil
}

func splitStringWithTokensLimit(s string, tokensLimit int) ([]string, error) {
	encoder, err := tokenizer.NewEncoder()
	if err != nil {
		return nil, err
	}
	encoded, err := encoder.Encode(s)
	if err != nil {
		return nil, err
	}
	var parts []string
	for {
		if len(encoded) == 0 {
			break
		}
		if len(encoded) <= tokensLimit {
			parts = append(parts, encoder.Decode(encoded))
			break
		}
		parts = append(parts, encoder.Decode(encoded[:tokensLimit]))
		encoded = encoded[tokensLimit:]
	}
	return parts, nil
}

func (p *Prompt) CreateMessagesWithSplit(input string, maxTokensOverride int) ([][]gogpt.ChatCompletionMessage, error) {
	maxTokens := firstNonZeroInt(maxTokensOverride, p.MaxTokens)
	promptTokens, err := p.CountTokens()
	if err != nil {
		return nil, err
	}
	allowedInputTokens := maxTokens - promptTokens
	inputParts, err := splitStringWithTokensLimit(input, allowedInputTokens)
	if err != nil {
		return nil, err
	}
	messages := [][]gogpt.ChatCompletionMessage{}
	for _, inputPart := range inputParts {
		messages = append(messages, p.CreateMessages(inputPart))
	}
	return messages, nil
}

func NewPromptFromFile(filename string) (*Prompt, error) {
	prompt := &Prompt{}
	if err := ReadYamlFromFile(filename, prompt); err != nil {
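Both CountTokens and splitStringWithTokensLimit above go through github.com/samber/go-gpt-3-encoder, which maps text to token IDs and back. The sketch below shows that round trip on its own, using only the calls that appear in this diff (NewEncoder, Encode, Decode); the example string and the size of the decoded chunk are illustrative, since actual token boundaries depend on the encoder's vocabulary.

package main

import (
	"fmt"
	"log"

	tokenizer "github.com/samber/go-gpt-3-encoder"
)

func main() {
	encoder, err := tokenizer.NewEncoder()
	if err != nil {
		log.Fatal(err)
	}

	// Encode text into token IDs, as CountTokens does for each prompt message.
	encoded, err := encoder.Encode("Hello, world!")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("token count: %d\n", len(encoded))

	// Decoding a prefix of the IDs recovers the leading chunk of the text,
	// which is how splitStringWithTokensLimit builds each part.
	if len(encoded) > 2 {
		fmt.Printf("first chunk: %q\n", encoder.Decode(encoded[:2]))
	}
}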
17 changes: 17 additions & 0 deletions prompt_test.go
@@ -14,3 +14,20 @@ func TestLoadPrompts(t *testing.T) {
t.Errorf("expected description, got empty string")
}
}

func TestSplitStringWithTokensLimit(t *testing.T) {
str := "Hello, world!"
tokens, err := splitStringWithTokensLimit(str, 2)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if len(tokens) != 2 {
t.Errorf("expected 2 tokens, got %d", len(tokens))
}
if tokens[0] != "Hello," {
t.Errorf("expected 'Hello,', got %q", tokens[0])
}
if tokens[1] != " world!" {
t.Errorf("expected 'world!', got %q", tokens[1])
}
}
