Skip to content

Commit

Permalink
Add CountTokens function and test
Browse files Browse the repository at this point in the history
This commit adds a CountTokens function that counts the number of tokens in a
given slice of messages. The function uses the go-gpt-3-encoder package to
encode the messages and count the tokens. A test for the CountTokens function
is also added to the aichat_test.go file.
  • Loading branch information
tkawachi committed Mar 12, 2023
1 parent 690f23c commit 1dd017b
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 0 deletions.
35 changes: 35 additions & 0 deletions aichat.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"strings"

"github.com/pborman/getopt/v2"
tokenizer "github.com/samber/go-gpt-3-encoder"
gogpt "github.com/sashabaranov/go-gpt3"
)

Expand Down Expand Up @@ -179,6 +180,13 @@ func main() {
if verbose {
log.Printf("messages: %+v", messages)
}
cnt, err := CountTokens(mapSlice(messages, func(m gogpt.ChatCompletionMessage) string { return m.Content }))
if err != nil {
log.Fatal(err)
}
if cnt > 4096 {
log.Fatalf("total tokens %d exceeds 4096", cnt)
}
request := gogpt.ChatCompletionRequest{
Model: gogpt.GPT3Dot5Turbo,
Messages: messages,
Expand All @@ -196,3 +204,30 @@ func main() {
}

}

// mapSlice maps a slice of type T to a slice of type M using the function f.
func mapSlice[T any, M any](a []T, f func(T) M) []M {
r := make([]M, len(a))
for i, v := range a {
r[i] = f(v)
}
return r
}

// CountTokens returns the number of tokens in the messages.
func CountTokens(messages []string) (int, error) {
count := 0
encoder, err := tokenizer.NewEncoder()
if err != nil {
return 0, err
}
for _, message := range messages {
// Encode string with GPT tokenizer
encoded, err := encoder.Encode(message)
if err != nil {
return 0, err
}
count += len(encoded)
}
return count, nil
}
19 changes: 19 additions & 0 deletions aichat_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package main

import (
"testing"
)

func TestCountTokens(t *testing.T) {
messages := []string{
"Hello, world!",
"How are you?",
}
count, err := CountTokens(messages)
if err != nil {
t.Errorf("CountTokens() returned an error: %v", err)
}
if count != 8 {
t.Errorf("CountTokens() returned %d, expected 8", count)
}
}
7 changes: 7 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@ go 1.20

require (
github.com/pborman/getopt/v2 v2.1.0
github.com/samber/go-gpt-3-encoder v0.3.1
github.com/sashabaranov/go-gpt3 v1.4.0
gopkg.in/yaml.v3 v3.0.1
)

require (
github.com/dlclark/regexp2 v1.8.1 // indirect
github.com/samber/lo v1.37.0 // indirect
golang.org/x/exp v0.0.0-20230310171629-522b1b587ee0 // indirect
)
11 changes: 11 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0=
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/pborman/getopt/v2 v2.1.0 h1:eNfR+r+dWLdWmV8g5OlpyrTYHkhVNxHBdN2cCrJmOEA=
github.com/pborman/getopt/v2 v2.1.0/go.mod h1:4NtW75ny4eBw9fO1bhtNdYTlZKYX5/tBLtsOpwKIKd0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/samber/go-gpt-3-encoder v0.3.1 h1:YWb9GsGYUgSX/wPtsEHjyNGRQXsQ9vDCg9SU2x9uMeU=
github.com/samber/go-gpt-3-encoder v0.3.1/go.mod h1:27nvdvk9ZtALyNtgs9JsPCMYja0Eleow/XzgjqwRtLU=
github.com/samber/lo v1.37.0 h1:XjVcB8g6tgUp8rsPsJ2CvhClfImrpL04YpQHXeHPhRw=
github.com/samber/lo v1.37.0/go.mod h1:9vaz2O4o8oOnK23pd2TrXufcbdbJIa3b6cstBWKpopA=
github.com/sashabaranov/go-gpt3 v1.4.0 h1:UqHYdXgJNtNvTtbzDnnQgkQ9TgTnHtCXx966uFTYXvU=
github.com/sashabaranov/go-gpt3 v1.4.0/go.mod h1:BIZdbwdzxZbCrcKGMGH6u2eyGe1xFuX9Anmh3tCP8lQ=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
golang.org/x/exp v0.0.0-20230310171629-522b1b587ee0 h1:LGJsf5LRplCck6jUCH3dBL2dmycNruWNF5xugkSlfXw=
golang.org/x/exp v0.0.0-20230310171629-522b1b587ee0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
Expand Down

0 comments on commit 1dd017b

Please sign in to comment.