diff --git a/aichat.go b/aichat.go index 0915e12..9ac415c 100644 --- a/aichat.go +++ b/aichat.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/pborman/getopt/v2" + tokenizer "github.com/samber/go-gpt-3-encoder" gogpt "github.com/sashabaranov/go-gpt3" ) @@ -179,6 +180,13 @@ func main() { if verbose { log.Printf("messages: %+v", messages) } + cnt, err := CountTokens(mapSlice(messages, func(m gogpt.ChatCompletionMessage) string { return m.Content })) + if err != nil { + log.Fatal(err) + } + if cnt > 4096 { + log.Fatalf("total tokens %d exceeds 4096", cnt) + } request := gogpt.ChatCompletionRequest{ Model: gogpt.GPT3Dot5Turbo, Messages: messages, @@ -196,3 +204,30 @@ func main() { } } + +// mapSlice maps a slice of type T to a slice of type M using the function f. +func mapSlice[T any, M any](a []T, f func(T) M) []M { + r := make([]M, len(a)) + for i, v := range a { + r[i] = f(v) + } + return r +} + +// CountTokens returns the number of tokens in the messages. +func CountTokens(messages []string) (int, error) { + count := 0 + encoder, err := tokenizer.NewEncoder() + if err != nil { + return 0, err + } + for _, message := range messages { + // Encode string with GPT tokenizer + encoded, err := encoder.Encode(message) + if err != nil { + return 0, err + } + count += len(encoded) + } + return count, nil +} diff --git a/aichat_test.go b/aichat_test.go new file mode 100644 index 0000000..faecbb7 --- /dev/null +++ b/aichat_test.go @@ -0,0 +1,19 @@ +package main + +import ( + "testing" +) + +func TestCountTokens(t *testing.T) { + messages := []string{ + "Hello, world!", + "How are you?", + } + count, err := CountTokens(messages) + if err != nil { + t.Errorf("CountTokens() returned an error: %v", err) + } + if count != 8 { + t.Errorf("CountTokens() returned %d, expected 8", count) + } +} diff --git a/go.mod b/go.mod index 391add9..433d78b 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,13 @@ go 1.20 require ( github.com/pborman/getopt/v2 v2.1.0 + github.com/samber/go-gpt-3-encoder v0.3.1 github.com/sashabaranov/go-gpt3 v1.4.0 gopkg.in/yaml.v3 v3.0.1 ) + +require ( + github.com/dlclark/regexp2 v1.8.1 // indirect + github.com/samber/lo v1.37.0 // indirect + golang.org/x/exp v0.0.0-20230310171629-522b1b587ee0 // indirect +) diff --git a/go.sum b/go.sum index a99d6d1..d9b9d82 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,18 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0= +github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/pborman/getopt/v2 v2.1.0 h1:eNfR+r+dWLdWmV8g5OlpyrTYHkhVNxHBdN2cCrJmOEA= github.com/pborman/getopt/v2 v2.1.0/go.mod h1:4NtW75ny4eBw9fO1bhtNdYTlZKYX5/tBLtsOpwKIKd0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/samber/go-gpt-3-encoder v0.3.1 h1:YWb9GsGYUgSX/wPtsEHjyNGRQXsQ9vDCg9SU2x9uMeU= +github.com/samber/go-gpt-3-encoder v0.3.1/go.mod h1:27nvdvk9ZtALyNtgs9JsPCMYja0Eleow/XzgjqwRtLU= +github.com/samber/lo v1.37.0 h1:XjVcB8g6tgUp8rsPsJ2CvhClfImrpL04YpQHXeHPhRw= +github.com/samber/lo v1.37.0/go.mod h1:9vaz2O4o8oOnK23pd2TrXufcbdbJIa3b6cstBWKpopA= github.com/sashabaranov/go-gpt3 v1.4.0 h1:UqHYdXgJNtNvTtbzDnnQgkQ9TgTnHtCXx966uFTYXvU= github.com/sashabaranov/go-gpt3 v1.4.0/go.mod h1:BIZdbwdzxZbCrcKGMGH6u2eyGe1xFuX9Anmh3tCP8lQ= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +golang.org/x/exp v0.0.0-20230310171629-522b1b587ee0 h1:LGJsf5LRplCck6jUCH3dBL2dmycNruWNF5xugkSlfXw= +golang.org/x/exp v0.0.0-20230310171629-522b1b587ee0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=