diff --git a/README.md b/README.md index 90c9bd70..fe568833 100644 --- a/README.md +++ b/README.md @@ -36,9 +36,9 @@ Encode text and decode tokens: fmt.Println("Vocab size:", tk.VocabSize()) // Vocab size: 30522 fmt.Println(tk.Encode("brown fox jumps over the lazy dog", false)) -// [2829 4419 14523 2058 1996 13971 3899] +// [2829 4419 14523 2058 1996 13971 3899] [brown fox jumps over the lazy dog] fmt.Println(tk.Encode("brown fox jumps over the lazy dog", true)) -// [101 2829 4419 14523 2058 1996 13971 3899 102] +// [101 2829 4419 14523 2058 1996 13971 3899 102] [[CLS] brown fox jumps over the lazy dog [SEP]] fmt.Println(tk.Decode([]uint32{2829, 4419, 14523, 2058, 1996, 13971, 3899}, true)) // brown fox jumps over the lazy dog ``` diff --git a/example/main.go b/example/main.go index cb73e5bb..0436e51d 100644 --- a/example/main.go +++ b/example/main.go @@ -16,9 +16,9 @@ func main() { fmt.Println("Vocab size:", tk.VocabSize()) // Vocab size: 30522 fmt.Println(tk.Encode("brown fox jumps over the lazy dog", false)) - // [2829 4419 14523 2058 1996 13971 3899] + // [2829 4419 14523 2058 1996 13971 3899] [brown fox jumps over the lazy dog] fmt.Println(tk.Encode("brown fox jumps over the lazy dog", true)) - // [101 2829 4419 14523 2058 1996 13971 3899 102] + // [101 2829 4419 14523 2058 1996 13971 3899 102] [[CLS] brown fox jumps over the lazy dog [SEP]] fmt.Println(tk.Decode([]uint32{2829, 4419, 14523, 2058, 1996, 13971, 3899}, true)) // brown fox jumps over the lazy dog } diff --git a/lib/src/lib.rs b/lib/src/lib.rs index 29d8e05c..df361ff5 100644 --- a/lib/src/lib.rs +++ b/lib/src/lib.rs @@ -3,6 +3,13 @@ use std::path::PathBuf; use std::ptr; use tokenizers::tokenizer::Tokenizer; +#[repr(C)] +pub struct Buffer { + ids: *mut u32, + tokens: *mut *mut libc::c_char, + len: usize, +} + #[no_mangle] pub extern "C" fn from_bytes(bytes: *const u8, len: u32) -> *mut Tokenizer { let bytes_slice = unsafe { std::slice::from_raw_parts(bytes, len as usize) }; @@ -44,15 +51,7 @@ pub extern "C" fn from_file(config: *const libc::c_char) -> *mut libc::c_void { } #[no_mangle] -pub extern "C" fn free_tokenizer(ptr: *mut ::libc::c_void) { - if ptr.is_null() { - return; - } - ptr.cast::(); -} - -#[no_mangle] -pub extern "C" fn encode(ptr: *mut libc::c_void, message: *const libc::c_char, len: *mut u32, add_special_tokens: bool) -> *mut u32 { +pub extern "C" fn encode(ptr: *mut libc::c_void, message: *const libc::c_char, add_special_tokens: bool) -> Buffer { let tokenizer: &Tokenizer; unsafe { tokenizer = ptr.cast::().as_ref().expect("failed to cast tokenizer"); @@ -61,14 +60,23 @@ pub extern "C" fn encode(ptr: *mut libc::c_void, message: *const libc::c_char, l let message = message_cstr.to_str().unwrap(); let encoding = tokenizer.encode(message, add_special_tokens).expect("failed to encode input"); - let mut vec = encoding.get_ids().to_vec(); - vec.shrink_to_fit(); - unsafe { - *len = vec.len() as u32; - } - let vec_ptr = vec.as_mut_ptr(); - std::mem::forget(vec); - vec_ptr + let mut vec_ids = encoding.get_ids().to_vec(); + let mut vec_tokens = encoding.get_tokens() + .to_vec().into_iter() + .map(|s| std::ffi::CString::new(s).unwrap().into_raw()) + .collect::>(); + + vec_ids.shrink_to_fit(); + vec_tokens.shrink_to_fit(); + + let ids = vec_ids.as_mut_ptr(); + let tokens = vec_tokens.as_mut_ptr(); + let len = vec_ids.len(); + + std::mem::forget(vec_ids); + std::mem::forget(vec_tokens); + + Buffer { ids, tokens, len } } #[no_mangle] @@ -92,3 +100,35 @@ pub extern "C" fn vocab_size(ptr: *mut libc::c_void) -> u32 { } tokenizer.get_vocab_size(true) as u32 } + +#[no_mangle] +pub extern "C" fn free_tokenizer(ptr: *mut ::libc::c_void) { + if ptr.is_null() { + return; + } + ptr.cast::(); +} + +#[no_mangle] +pub extern "C" fn free_buffer(buf: Buffer) { + if buf.ids.is_null() { + return; + } + unsafe { + Vec::from_raw_parts(buf.ids, buf.len, buf.len); + let strings = Vec::from_raw_parts(buf.tokens, buf.len, buf.len); + for s in strings { + drop(std::ffi::CString::from_raw(s)); + } + } +} + +#[no_mangle] +pub extern "C" fn free_string(ptr: *mut libc::c_char) { + if ptr.is_null() { + return; + } + unsafe { + drop(std::ffi::CString::from_raw(ptr)); + } +} \ No newline at end of file diff --git a/release/main.go b/release/main.go index cb73e5bb..0436e51d 100644 --- a/release/main.go +++ b/release/main.go @@ -16,9 +16,9 @@ func main() { fmt.Println("Vocab size:", tk.VocabSize()) // Vocab size: 30522 fmt.Println(tk.Encode("brown fox jumps over the lazy dog", false)) - // [2829 4419 14523 2058 1996 13971 3899] + // [2829 4419 14523 2058 1996 13971 3899] [brown fox jumps over the lazy dog] fmt.Println(tk.Encode("brown fox jumps over the lazy dog", true)) - // [101 2829 4419 14523 2058 1996 13971 3899 102] + // [101 2829 4419 14523 2058 1996 13971 3899 102] [[CLS] brown fox jumps over the lazy dog [SEP]] fmt.Println(tk.Decode([]uint32{2829, 4419, 14523, 2058, 1996, 13971, 3899}, true)) // brown fox jumps over the lazy dog } diff --git a/tokenizer.go b/tokenizer.go index 7e3a6a65..b0c39ae7 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -54,22 +54,27 @@ func (t *Tokenizer) Close() error { return nil } -func (t *Tokenizer) Encode(str string, addSpecialTokens bool) []uint32 { +func (t *Tokenizer) Encode(str string, addSpecialTokens bool) ([]uint32, []string) { cStr := C.CString(str) defer C.free(unsafe.Pointer(cStr)) - var len C.uint - res := C.encode(t.tokenizer, cStr, &len, C.bool(addSpecialTokens)) - if len > 0 { - // can't dealloc nil - defer C.free(unsafe.Pointer(res)) + res := C.encode(t.tokenizer, cStr, C.bool(addSpecialTokens)) + len := int(res.len) + if len == 0 { + return nil, nil } - slice := unsafe.Slice(res, len) + defer C.free_buffer(res) + ids := unsafe.Slice(res.ids, len) tokenIDs := make([]uint32, len) - for i, v := range slice { + for i, v := range ids { tokenIDs[i] = uint32(v) } - return tokenIDs + + tokens := make([]string, len) + for i, s := range (*[1 << 30]*C.char)(unsafe.Pointer(res.tokens))[:len:len] { + tokens[i] = C.GoString(s) + } + return tokenIDs, tokens } func (t *Tokenizer) Decode(tokenIDs []uint32, skipSpecialTokens bool) string { @@ -78,7 +83,7 @@ func (t *Tokenizer) Decode(tokenIDs []uint32, skipSpecialTokens bool) string { } len := C.uint(len(tokenIDs)) res := C.decode(t.tokenizer, (*C.uint)(unsafe.Pointer(&tokenIDs[0])), len, C.bool(skipSpecialTokens)) - defer C.free(unsafe.Pointer(res)) + defer C.free_string(res) return C.GoString(res) } diff --git a/tokenizer_test.go b/tokenizer_test.go index d9254a56..72331c5e 100644 --- a/tokenizer_test.go +++ b/tokenizer_test.go @@ -30,26 +30,30 @@ func TestEmbeddingConfig(t *testing.T) { name string str string addSpecial bool - want []uint32 + wantIDs []uint32 + wantTokens []string }{ { name: "without special tokens", str: "brown fox jumps over the lazy dog", addSpecial: false, - want: []uint32{0xca3f, 0x2f304, 0x5185b, 0x3c54, 0x3a89, 0x35fc3, 0x57b4}, + wantIDs: []uint32{0xca3f, 0x2f304, 0x5185b, 0x3c54, 0x3a89, 0x35fc3, 0x57b4}, + wantTokens: []string{"brown", "fox", "jumps", "over", "the", "lazy", "dog"}, }, { name: "with special tokens", str: "brown fox jumps over the lazy dog", addSpecial: true, - want: []uint32{0x65, 0xca3f, 0x2f304, 0x5185b, 0x3c54, 0x3a89, 0x35fc3, 0x57b4, 0x66}, + wantIDs: []uint32{0x65, 0xca3f, 0x2f304, 0x5185b, 0x3c54, 0x3a89, 0x35fc3, 0x57b4, 0x66}, + wantTokens: []string{"[CLS]", "brown", "fox", "jumps", "over", "the", "lazy", "dog", "[SEP]"}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tk.Encode(tt.str, tt.addSpecial) - got := tk.Encode(tt.str, tt.addSpecial) - assert.Equal(t, tt.want, got) + gotIDs, gotTokens := tk.Encode(tt.str, tt.addSpecial) + assert.Equal(t, tt.wantIDs, gotIDs) + assert.Equal(t, tt.wantTokens, gotTokens) }) } } @@ -62,37 +66,39 @@ func TestEncode(t *testing.T) { name string str string addSpecial bool - want []uint32 + wantIDs []uint32 + wantTokens []string }{ { name: "without special tokens", str: "brown fox jumps over the lazy dog", addSpecial: false, - want: []uint32{2829, 4419, 14523, 2058, 1996, 13971, 3899}, + wantIDs: []uint32{2829, 4419, 14523, 2058, 1996, 13971, 3899}, + wantTokens: []string{"brown", "fox", "jumps", "over", "the", "lazy", "dog"}, }, { name: "with special tokens", str: "brown fox jumps over the lazy dog", addSpecial: true, - want: []uint32{101, 2829, 4419, 14523, 2058, 1996, 13971, 3899, 102}, + wantIDs: []uint32{101, 2829, 4419, 14523, 2058, 1996, 13971, 3899, 102}, + wantTokens: []string{"[CLS]", "brown", "fox", "jumps", "over", "the", "lazy", "dog", "[SEP]"}, }, { name: "empty string", str: "", addSpecial: false, - want: []uint32{}, }, { name: "empty string with special tokens", str: "", addSpecial: false, - want: []uint32{}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := tk.Encode(tt.str, tt.addSpecial) - assert.Equal(t, tt.want, got) + gotIDs, gotTokens := tk.Encode(tt.str, tt.addSpecial) + assert.Equal(t, tt.wantIDs, gotIDs) + assert.Equal(t, tt.wantTokens, gotTokens) }) } } @@ -104,7 +110,8 @@ func TestEncodeWithTruncation(t *testing.T) { addSpecial bool maxLen int dir tokenizers.TruncationDirection - want []uint32 + wantIDs []uint32 + wantTokens []string }{ { name: "without special tokens, left truncation", @@ -112,7 +119,8 @@ func TestEncodeWithTruncation(t *testing.T) { addSpecial: false, maxLen: 5, dir: tokenizers.TruncationDirectionLeft, - want: []uint32{0x5185b, 0x3c54, 0x3a89, 0x35fc3, 0x57b4}, + wantIDs: []uint32{0x5185b, 0x3c54, 0x3a89, 0x35fc3, 0x57b4}, + wantTokens: []string{"jumps", "over", "the", "lazy", "dog"}, }, { name: "without special tokens, right truncation", @@ -120,7 +128,8 @@ func TestEncodeWithTruncation(t *testing.T) { addSpecial: false, maxLen: 5, dir: tokenizers.TruncationDirectionRight, - want: []uint32{0xca3f, 0x2f304, 0x5185b, 0x3c54, 0x3a89}, + wantIDs: []uint32{0xca3f, 0x2f304, 0x5185b, 0x3c54, 0x3a89}, + wantTokens: []string{"brown", "fox", "jumps", "over", "the"}, }, { name: "with special tokens, left truncation", @@ -128,7 +137,8 @@ func TestEncodeWithTruncation(t *testing.T) { addSpecial: true, maxLen: 5, dir: tokenizers.TruncationDirectionLeft, - want: []uint32{0x65, 0x3a89, 0x35fc3, 0x57b4, 0x66}, + wantIDs: []uint32{0x65, 0x3a89, 0x35fc3, 0x57b4, 0x66}, + wantTokens: []string{"[CLS]", "the", "lazy", "dog", "[SEP]"}, }, { name: "with special tokens, right truncation", @@ -136,7 +146,8 @@ func TestEncodeWithTruncation(t *testing.T) { addSpecial: true, maxLen: 5, dir: tokenizers.TruncationDirectionRight, - want: []uint32{0x65, 0xca3f, 0x2f304, 0x5185b, 0x66}, + wantIDs: []uint32{0x65, 0xca3f, 0x2f304, 0x5185b, 0x66}, + wantTokens: []string{"[CLS]", "brown", "fox", "jumps", "[SEP]"}, }, } for _, tt := range tests { @@ -146,8 +157,9 @@ func TestEncodeWithTruncation(t *testing.T) { defer tk.Close() tk.Encode(tt.str, tt.addSpecial) - got := tk.Encode(tt.str, tt.addSpecial) - assert.Equal(t, tt.want, got) + gotIDs, gotTokens := tk.Encode(tt.str, tt.addSpecial) + assert.Equal(t, tt.wantIDs, gotIDs) + assert.Equal(t, tt.wantTokens, gotTokens) }) } } @@ -215,7 +227,7 @@ func BenchmarkEncodeNTimes(b *testing.B) { expected := []uint32{2829, 4419, 14523, 2058, 1996, 13971, 3899} b.ResetTimer() for i := 0; i < b.N; i++ { - tokens := tk.Encode("brown fox jumps over the lazy dog", false) + tokens, _ := tk.Encode("brown fox jumps over the lazy dog", false) assert.Equal(b, expected, tokens) } } @@ -230,7 +242,7 @@ func BenchmarkEncodeNChars(b *testing.B) { } str := string(input) b.ResetTimer() - tokens := tk.Encode(str, false) + tokens, _ := tk.Encode(str, false) assert.Greater(b, len(tokens), 0) } diff --git a/tokenizers.h b/tokenizers.h index 909c92eb..665f6c9b 100644 --- a/tokenizers.h +++ b/tokenizers.h @@ -1,16 +1,26 @@ #include #include +struct Buffer { + uint32_t *ids; + char *tokens; + uint32_t len; +}; + void *from_bytes(const uint8_t *config, uint32_t len); void *from_bytes_with_truncation(const uint8_t *config, uint32_t len, uint32_t max_len, uint8_t direction); void *from_file(const char *config); -void free_tokenizer(void *ptr); - -uint32_t *encode(void *ptr, const char *message, uint32_t *len, bool add_special_tokens); +struct Buffer encode(void *ptr, const char *message, bool add_special_tokens); char *decode(void *ptr, const uint32_t *ids, uint32_t len, bool skip_special_tokens); uint32_t vocab_size(void *ptr); + +void free_tokenizer(void *ptr); + +void free_buffer(struct Buffer buffer); + +void free_string(char *string);