From f3f1a0e080425dec1ec9e5b154b150a805350922 Mon Sep 17 00:00:00 2001 From: Daulet Zhanguzin Date: Sun, 7 May 2023 15:04:49 -0700 Subject: [PATCH] handle empty encode/decode params --- tokenizer.go | 3 +++ tokenizer_test.go | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/tokenizer.go b/tokenizer.go index 7c399384..7e3a6a65 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -73,6 +73,9 @@ func (t *Tokenizer) Encode(str string, addSpecialTokens bool) []uint32 { } func (t *Tokenizer) Decode(tokenIDs []uint32, skipSpecialTokens bool) string { + if len(tokenIDs) == 0 { + return "" + } len := C.uint(len(tokenIDs)) res := C.decode(t.tokenizer, (*C.uint)(unsafe.Pointer(&tokenIDs[0])), len, C.bool(skipSpecialTokens)) defer C.free(unsafe.Pointer(res)) diff --git a/tokenizer_test.go b/tokenizer_test.go index 233ab5d9..d9254a56 100644 --- a/tokenizer_test.go +++ b/tokenizer_test.go @@ -76,6 +76,18 @@ func TestEncode(t *testing.T) { addSpecial: true, want: []uint32{101, 2829, 4419, 14523, 2058, 1996, 13971, 3899, 102}, }, + { + name: "empty string", + str: "", + addSpecial: false, + want: []uint32{}, + }, + { + name: "empty string with special tokens", + str: "", + addSpecial: false, + want: []uint32{}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -174,6 +186,12 @@ func TestDecode(t *testing.T) { skipSpecial: false, want: "[CLS] brown fox jumps over the lazy dog [SEP]", }, + { + name: "no tokens", + tokens: []uint32{}, + skipSpecial: false, + want: "", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {