diff --git a/aichat.go b/aichat.go
index 9ac415c..c699545 100644
--- a/aichat.go
+++ b/aichat.go
@@ -125,11 +125,13 @@ func main() {
 	var verbose = false
 	var listPrompts = false
 	var nonStreaming = false
+	var split = false
 	getopt.FlagLong(&temperature, "temperature", 't', "temperature")
 	getopt.FlagLong(&maxTokens, "max-tokens", 'm', "max tokens")
 	getopt.FlagLong(&verbose, "verbose", 'v', "verbose output")
 	getopt.FlagLong(&listPrompts, "list-prompts", 'l', "list prompts")
 	getopt.FlagLong(&nonStreaming, "non-streaming", 0, "non streaming mode")
+	getopt.FlagLong(&split, "split", 0, "split input")
 	getopt.Parse()
 
 	if listPrompts {
@@ -171,40 +173,67 @@ func main() {
 		log.Fatalf("prompt %q not found", args[0])
 	}
 	// read all from Stdin
-	scanner := bufio.NewScanner(os.Stdin)
-	input := ""
-	for scanner.Scan() {
-		input += scanner.Text() + "\n"
-	}
-	messages := prompt.CreateMessages(input)
-	if verbose {
-		log.Printf("messages: %+v", messages)
-	}
-	cnt, err := CountTokens(mapSlice(messages, func(m gogpt.ChatCompletionMessage) string { return m.Content }))
-	if err != nil {
-		log.Fatal(err)
-	}
-	if cnt > 4096 {
-		log.Fatalf("total tokens %d exceeds 4096", cnt)
-	}
-	request := gogpt.ChatCompletionRequest{
-		Model:       gogpt.GPT3Dot5Turbo,
-		Messages:    messages,
-		Temperature: firstNonZeroFloat32(prompt.Temperature, aiChat.options.temperature),
-		MaxTokens:   firstNonZeroInt(prompt.MaxTokens, aiChat.options.maxTokens),
-	}
-	if aiChat.options.nonStreaming {
-		err = nonStreamCompletion(aiChat.client, request, os.Stdout)
+	input := scanAll(bufio.NewScanner(os.Stdin))
+
+	var messagesSlice [][]gogpt.ChatCompletionMessage
+
+	if split {
+		messagesSlice, err = prompt.CreateMessagesWithSplit(input, 0) // TODO pass maxTokens if it is specified via command line flag
+		if err != nil {
+			log.Fatal(err)
+		}
+		if verbose {
+			log.Printf("input was split into %d parts", len(messagesSlice))
+		}
 	} else {
-		err = streamCompletion(aiChat.client, request, os.Stdout)
+		messages := prompt.CreateMessages(input)
+		if verbose {
+			log.Printf("messages: %+v", messages)
+		}
+		messagesSlice = [][]gogpt.ChatCompletionMessage{messages}
 	}
-	if err != nil {
-		log.Fatal(err)
+
+	for _, messages := range messagesSlice {
+		maxTokens := firstNonZeroInt(prompt.MaxTokens, aiChat.options.maxTokens)
+
+		request := gogpt.ChatCompletionRequest{
+			Model:       gogpt.GPT3Dot5Turbo,
+			Messages:    messages,
+			Temperature: firstNonZeroFloat32(prompt.Temperature, aiChat.options.temperature),
+			MaxTokens:   maxTokens,
+		}
+
+		cnt, err := CountTokens(mapSlice(messages, func(m gogpt.ChatCompletionMessage) string { return m.Content }))
+		if err != nil {
+			log.Fatal(err)
+		}
+		if cnt > 4096 {
+			log.Fatalf("total tokens %d exceeds 4096", cnt)
+		}
+
+		if aiChat.options.nonStreaming {
+			err = nonStreamCompletion(aiChat.client, request, os.Stdout)
+		} else {
+			err = streamCompletion(aiChat.client, request, os.Stdout)
+		}
+		if err != nil {
+			log.Fatal(err)
+		}
 	}
 }
 
+// scanAll reads everything from the scanner and returns it as a single string.
+func scanAll(scanner *bufio.Scanner) string {
+	input := ""
+	for scanner.Scan() {
+		input += scanner.Text() + "\n"
+	}
+	return input
+}
+
 // mapSlice maps a slice of type T to a slice of type M using the function f.
 func mapSlice[T any, M any](a []T, f func(T) M) []M {
 	r := make([]M, len(a))
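A note on the TODO above: the value parsed from --max-tokens is not forwarded to the split path, so a command-line override is ignored when --split is used. One possible wiring, a sketch only and not part of this change, would be to pass the option straight through and rely on the existing zero-value fallback:

// Sketch: assumes aiChat.options.maxTokens is zero when --max-tokens is not
// given, in which case CreateMessagesWithSplit falls back to the prompt's
// own MaxTokens via firstNonZeroInt.
messagesSlice, err = prompt.CreateMessagesWithSplit(input, aiChat.options.maxTokens)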
diff --git a/prompt.go b/prompt.go
index 4e9d4a2..a1bdf87 100644
--- a/prompt.go
+++ b/prompt.go
@@ -6,6 +6,7 @@ import (
 	"path/filepath"
 	"strings"
 
+	tokenizer "github.com/samber/go-gpt-3-encoder"
 	gogpt "github.com/sashabaranov/go-gpt3"
 	"gopkg.in/yaml.v3"
 )
@@ -37,6 +38,76 @@ func (p *Prompt) CreateMessages(input string) []gogpt.ChatCompletionMessage {
 	return messages
 }
 
+// CountTokens counts the number of tokens in the prompt messages.
+func (p *Prompt) CountTokens() (int, error) {
+	count := 0
+	encoder, err := tokenizer.NewEncoder()
+	if err != nil {
+		return 0, err
+	}
+	for _, message := range p.Messages {
+		// Encode the message content with the GPT tokenizer.
+		encoded, err := encoder.Encode(message.Content)
+		if err != nil {
+			return 0, err
+		}
+		count += len(encoded)
+	}
+	return count, nil
+}
+
+// AllowedInputTokens returns the number of tokens allowed for the input,
+// i.e. the token budget left after the prompt messages are accounted for.
+func (p *Prompt) AllowedInputTokens(maxTokensOverride int) (int, error) {
+	maxTokens := firstNonZeroInt(maxTokensOverride, p.MaxTokens)
+	promptTokens, err := p.CountTokens()
+	if err != nil {
+		return 0, err
+	}
+	return maxTokens - promptTokens, nil
+}
+
+// splitStringWithTokensLimit splits s into chunks of at most tokensLimit
+// tokens each, cutting on token boundaries. tokensLimit must be positive.
+func splitStringWithTokensLimit(s string, tokensLimit int) ([]string, error) {
+	encoder, err := tokenizer.NewEncoder()
+	if err != nil {
+		return nil, err
+	}
+	encoded, err := encoder.Encode(s)
+	if err != nil {
+		return nil, err
+	}
+	var parts []string
+	for len(encoded) > 0 {
+		if len(encoded) <= tokensLimit {
+			parts = append(parts, encoder.Decode(encoded))
+			break
+		}
+		parts = append(parts, encoder.Decode(encoded[:tokensLimit]))
+		encoded = encoded[tokensLimit:]
+	}
+	return parts, nil
+}
+
+// CreateMessagesWithSplit splits the input into chunks that fit within the
+// prompt's token budget and returns one message slice per chunk.
+func (p *Prompt) CreateMessagesWithSplit(input string, maxTokensOverride int) ([][]gogpt.ChatCompletionMessage, error) {
+	allowedInputTokens, err := p.AllowedInputTokens(maxTokensOverride)
+	if err != nil {
+		return nil, err
+	}
+	inputParts, err := splitStringWithTokensLimit(input, allowedInputTokens)
+	if err != nil {
+		return nil, err
+	}
+	messages := [][]gogpt.ChatCompletionMessage{}
+	for _, inputPart := range inputParts {
+		messages = append(messages, p.CreateMessages(inputPart))
+	}
+	return messages, nil
+}
+
 func NewPromptFromFile(filename string) (*Prompt, error) {
 	prompt := &Prompt{}
 	if err := ReadYamlFromFile(filename, prompt); err != nil {
diff --git a/prompt_test.go b/prompt_test.go
index 0bab13b..aeb82e6 100644
--- a/prompt_test.go
+++ b/prompt_test.go
@@ -14,3 +14,20 @@ func TestLoadPrompts(t *testing.T) {
 		t.Errorf("expected description, got empty string")
 	}
 }
+
+func TestSplitStringWithTokensLimit(t *testing.T) {
+	str := "Hello, world!"
+	parts, err := splitStringWithTokensLimit(str, 2)
+	if err != nil {
+		t.Errorf("unexpected error: %v", err)
+	}
+	if len(parts) != 2 {
+		t.Errorf("expected 2 parts, got %d", len(parts))
+	}
+	if parts[0] != "Hello," {
+		t.Errorf("expected %q, got %q", "Hello,", parts[0])
+	}
+	if parts[1] != " world!" {
+		t.Errorf("expected %q, got %q", " world!", parts[1])
+	}
+}
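For orientation, a usage sketch of the splitting API introduced in prompt.go. It is not part of the diff: it assumes everything lives in the same main package (as aichat.go and prompt.go suggest), and the prompt file name below is made up for illustration.

// splitexample.go: illustration only, not part of this change.
package main

import "log"

// splitExample shows how the new pieces fit together: NewPromptFromFile and
// CreateMessagesWithSplit are the functions from prompt.go above; the file
// path "prompts/summarize.yaml" is hypothetical.
func splitExample(input string) error {
	p, err := NewPromptFromFile("prompts/summarize.yaml")
	if err != nil {
		return err
	}
	// One []gogpt.ChatCompletionMessage per input chunk; each chunk plus the
	// prompt messages stays within the prompt's token budget.
	messagesSlice, err := p.CreateMessagesWithSplit(input, 0)
	if err != nil {
		return err
	}
	for i, messages := range messagesSlice {
		// Each messages slice would be sent as its own ChatCompletionRequest,
		// exactly as the loop in main() above does.
		log.Printf("part %d: %d messages", i, len(messages))
	}
	return nil
}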