diff --git a/lib/src/lib.rs b/lib/src/lib.rs index 1834d243..29d8e05c 100644 --- a/lib/src/lib.rs +++ b/lib/src/lib.rs @@ -10,6 +10,23 @@ pub extern "C" fn from_bytes(bytes: *const u8, len: u32) -> *mut Tokenizer { return Box::into_raw(Box::new(tokenizer)); } +#[no_mangle] +pub extern "C" fn from_bytes_with_truncation(bytes: *const u8, len: u32, max_len: usize, dir: u8) -> *mut Tokenizer { + let bytes_slice = unsafe { std::slice::from_raw_parts(bytes, len as usize) }; + let tokenizer: Tokenizer = Tokenizer::from_bytes(bytes_slice) + .expect("failed to create tokenizer") + .with_truncation(Some(tokenizers::tokenizer::TruncationParams{ + max_length: max_len, + direction: match dir { + 0 => tokenizers::tokenizer::TruncationDirection::Left, + 1 => tokenizers::tokenizer::TruncationDirection::Right, + _ => panic!("invalid truncation direction"), + }, + ..Default::default() + })).to_owned().into(); + return Box::into_raw(Box::new(tokenizer)); +} + #[no_mangle] pub extern "C" fn from_file(config: *const libc::c_char) -> *mut libc::c_void { let config_cstr = unsafe { CStr::from_ptr(config) }; diff --git a/tokenizer.go b/tokenizer.go index 1bc6c25a..58d055d4 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -19,6 +19,13 @@ type Tokenizer struct { tokenizer unsafe.Pointer } +type TruncationDirection int + +const ( + TruncationDirectionLeft TruncationDirection = iota + TruncationDirectionRight +) + var _ io.Closer = (*Tokenizer)(nil) func FromBytes(data []byte) (*Tokenizer, error) { @@ -26,6 +33,11 @@ func FromBytes(data []byte) (*Tokenizer, error) { return &Tokenizer{tokenizer: tokenizer}, nil } +func FromBytesWithTruncation(data []byte, maxLen uint32, dir TruncationDirection) (*Tokenizer, error) { + tokenizer := C.from_bytes_with_truncation((*C.uchar)(unsafe.Pointer(&data[0])), C.uint(len(data)), C.uint(maxLen), C.uchar(dir)) + return &Tokenizer{tokenizer: tokenizer}, nil +} + func FromFile(path string) (*Tokenizer, error) { cPath := C.CString(path) defer C.free(unsafe.Pointer(cPath)) diff --git a/tokenizer_test.go b/tokenizer_test.go index f48ec22c..233ab5d9 100644 --- a/tokenizer_test.go +++ b/tokenizer_test.go @@ -85,6 +85,61 @@ func TestEncode(t *testing.T) { } } +func TestEncodeWithTruncation(t *testing.T) { + tests := []struct { + name string + str string + addSpecial bool + maxLen int + dir tokenizers.TruncationDirection + want []uint32 + }{ + { + name: "without special tokens, left truncation", + str: "brown fox jumps over the lazy dog", + addSpecial: false, + maxLen: 5, + dir: tokenizers.TruncationDirectionLeft, + want: []uint32{0x5185b, 0x3c54, 0x3a89, 0x35fc3, 0x57b4}, + }, + { + name: "without special tokens, right truncation", + str: "brown fox jumps over the lazy dog", + addSpecial: false, + maxLen: 5, + dir: tokenizers.TruncationDirectionRight, + want: []uint32{0xca3f, 0x2f304, 0x5185b, 0x3c54, 0x3a89}, + }, + { + name: "with special tokens, left truncation", + str: "brown fox jumps over the lazy dog", + addSpecial: true, + maxLen: 5, + dir: tokenizers.TruncationDirectionLeft, + want: []uint32{0x65, 0x3a89, 0x35fc3, 0x57b4, 0x66}, + }, + { + name: "with special tokens, right truncation", + str: "brown fox jumps over the lazy dog", + addSpecial: true, + maxLen: 5, + dir: tokenizers.TruncationDirectionRight, + want: []uint32{0x65, 0xca3f, 0x2f304, 0x5185b, 0x66}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tk, err := tokenizers.FromBytesWithTruncation(embeddedBytes, uint32(tt.maxLen), tt.dir) + require.NoError(t, err) + defer tk.Close() + + tk.Encode(tt.str, tt.addSpecial) + got := tk.Encode(tt.str, tt.addSpecial) + assert.Equal(t, tt.want, got) + }) + } +} + func TestDecode(t *testing.T) { tk, err := tokenizers.FromFile("./test/data/bert-base-uncased.json") require.NoError(t, err) diff --git a/tokenizers.h b/tokenizers.h index 24eac849..909c92eb 100644 --- a/tokenizers.h +++ b/tokenizers.h @@ -3,6 +3,8 @@ void *from_bytes(const uint8_t *config, uint32_t len); +void *from_bytes_with_truncation(const uint8_t *config, uint32_t len, uint32_t max_len, uint8_t direction); + void *from_file(const char *config); void free_tokenizer(void *ptr);