support more attributes from the Encoding structure (#5)
* support more attributes from the Encoding structure

MiniLM requires the attention mask to perform the mean pooling operation, as can be seen at
https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
(a pooling sketch follows the changed-files summary below)

* adapt the example and the readme

* introduce EncodeWithOptions method for selecting the returned attributes (see the usage sketch after the diff below)

* add missing options

* fix the example and the readme

* fix benchmarks

---------

Co-authored-by: Clément Michaud <[email protected]>
clems4ever and cmichaudav authored Nov 15, 2023
1 parent 786da40 commit 38a9a14
Showing 5 changed files with 31,085 additions and 84 deletions.
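For context on that pooling step (not part of this commit): the all-MiniLM-L6-v2 model card computes a sentence embedding by averaging the token embeddings of the non-masked positions only, which is why the attention mask has to cross the FFI boundary. Below is a minimal sketch of that idea, assuming the token embeddings are already available as plain vectors; the mean_pool helper is hypothetical.

// Sketch only: attention-mask-weighted mean pooling over token embeddings.
// `embeddings[i]` is the hidden vector for token i; `attention_mask[i]` is 0 or 1.
fn mean_pool(embeddings: &[Vec<f32>], attention_mask: &[u32]) -> Vec<f32> {
    let dim = embeddings.first().map_or(0, |e| e.len());
    let mut sum = vec![0.0f32; dim];
    let mut count = 0.0f32;
    for (vector, &mask) in embeddings.iter().zip(attention_mask) {
        if mask == 1 {
            for (acc, value) in sum.iter_mut().zip(vector) {
                *acc += value;
            }
            count += 1.0;
        }
    }
    // Guard against an all-zero mask (padding-only input).
    let count = count.max(1e-9);
    sum.iter().map(|value| value / count).collect()
}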
src/lib.rs (100 changes: 78 additions & 22 deletions)
@@ -6,6 +6,9 @@ use tokenizers::tokenizer::Tokenizer;
 #[repr(C)]
 pub struct Buffer {
     ids: *mut u32,
+    type_ids: *mut u32,
+    special_tokens_mask: *mut u32,
+    attention_mask: *mut u32,
     tokens: *mut *mut libc::c_char,
     len: usize,
 }
@@ -50,36 +53,71 @@ pub extern "C" fn from_file(config: *const libc::c_char) -> *mut libc::c_void {
     }
 }

+#[repr(C)]
+pub struct EncodeOptions {
+    add_special_tokens: bool,
+
+    return_type_ids: bool,
+    return_tokens: bool,
+    return_special_tokens_mask: bool,
+    return_attention_mask: bool,
+}
+
 #[no_mangle]
-pub extern "C" fn encode(ptr: *mut libc::c_void, message: *const libc::c_char, add_special_tokens: bool) -> Buffer {
+pub extern "C" fn encode(ptr: *mut libc::c_void, message: *const libc::c_char, options: &EncodeOptions) -> Buffer {
     let tokenizer: &Tokenizer;
     unsafe {
         tokenizer = ptr.cast::<Tokenizer>().as_ref().expect("failed to cast tokenizer");
     }
     let message_cstr = unsafe { CStr::from_ptr(message) };
     let message = message_cstr.to_str();
     if message.is_err() {
-        return Buffer { ids: ptr::null_mut(), tokens: ptr::null_mut(), len: 0 };
+        return Buffer { ids: ptr::null_mut(), tokens: ptr::null_mut(), len: 0, type_ids: ptr::null_mut(), special_tokens_mask: ptr::null_mut(), attention_mask: ptr::null_mut() };
     }

-    let encoding = tokenizer.encode(message.unwrap(), add_special_tokens).expect("failed to encode input");
+    let encoding = tokenizer.encode(message.unwrap(), options.add_special_tokens).expect("failed to encode input");
     let mut vec_ids = encoding.get_ids().to_vec();
-    let mut vec_tokens = encoding.get_tokens()
-        .to_vec().into_iter()
-        .map(|s| std::ffi::CString::new(s).unwrap().into_raw())
-        .collect::<Vec<_>>();
-
     vec_ids.shrink_to_fit();
-    vec_tokens.shrink_to_fit();
-
     let ids = vec_ids.as_mut_ptr();
-    let tokens = vec_tokens.as_mut_ptr();
     let len = vec_ids.len();
-
     std::mem::forget(vec_ids);
-    std::mem::forget(vec_tokens);

-    Buffer { ids, tokens, len }
+    let mut type_ids: *mut u32 = ptr::null_mut();
+    if options.return_type_ids {
+        let mut vec_type_ids = encoding.get_type_ids().to_vec();
+        vec_type_ids.shrink_to_fit();
+        type_ids = vec_type_ids.as_mut_ptr();
+        std::mem::forget(vec_type_ids);
+    }
+
+    let mut tokens: *mut *mut i8 = ptr::null_mut();
+    if options.return_tokens {
+        let mut vec_tokens = encoding.get_tokens()
+            .to_vec().into_iter()
+            .map(|s| std::ffi::CString::new(s).unwrap().into_raw())
+            .collect::<Vec<_>>();
+        vec_tokens.shrink_to_fit();
+        tokens = vec_tokens.as_mut_ptr();
+        std::mem::forget(vec_tokens);
+    }
+
+    let mut special_tokens_mask: *mut u32 = ptr::null_mut();
+    if options.return_special_tokens_mask {
+        let mut vec_special_tokens_mask = encoding.get_special_tokens_mask().to_vec();
+        vec_special_tokens_mask.shrink_to_fit();
+        special_tokens_mask = vec_special_tokens_mask.as_mut_ptr();
+        std::mem::forget(vec_special_tokens_mask);
+    }
+
+    let mut attention_mask: *mut u32 = ptr::null_mut();
+    if options.return_attention_mask {
+        let mut vec_attention_mask = encoding.get_attention_mask().to_vec();
+        vec_attention_mask.shrink_to_fit();
+        attention_mask = vec_attention_mask.as_mut_ptr();
+        std::mem::forget(vec_attention_mask);
+    }
+
+    Buffer { ids, type_ids, special_tokens_mask, attention_mask, tokens, len }
 }

 #[no_mangle]
@@ -118,15 +156,33 @@ pub extern "C" fn free_tokenizer(ptr: *mut ::libc::c_void) {

 #[no_mangle]
 pub extern "C" fn free_buffer(buf: Buffer) {
-    if buf.ids.is_null() {
-        return;
-    }
-    unsafe {
-        Vec::from_raw_parts(buf.ids, buf.len, buf.len);
-        let strings = Vec::from_raw_parts(buf.tokens, buf.len, buf.len);
-        for s in strings {
-            drop(std::ffi::CString::from_raw(s));
-        }
-    }
+    if !buf.ids.is_null() {
+        unsafe {
+            Vec::from_raw_parts(buf.ids, buf.len, buf.len);
+        }
+    }
+    if !buf.type_ids.is_null() {
+        unsafe {
+            Vec::from_raw_parts(buf.type_ids, buf.len, buf.len);
+        }
+    }
+    if !buf.special_tokens_mask.is_null() {
+        unsafe {
+            Vec::from_raw_parts(buf.special_tokens_mask, buf.len, buf.len);
+        }
+    }
+    if !buf.attention_mask.is_null() {
+        unsafe {
+            Vec::from_raw_parts(buf.attention_mask, buf.len, buf.len);
+        }
+    }
+    if !buf.tokens.is_null() {
+        unsafe {
+            let strings = Vec::from_raw_parts(buf.tokens, buf.len, buf.len);
+            for s in strings {
+                drop(std::ffi::CString::from_raw(s));
+            }
+        }
+    }
 }
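To show how the pieces above fit together (the Go-side EncodeWithOptions presumably fills the same struct before crossing the FFI boundary), here is a minimal sketch of driving the new interface from inside the crate, e.g. from a test module in src/lib.rs. The tokenizer path, the input string, and the example_usage name are placeholders, and the sketch assumes the tokenizer file loads successfully; it is not part of this commit.

// Sketch only: calling the new options-driven C interface from within the crate.
use std::ffi::CString;

fn example_usage() {
    // Load a tokenizer definition; the path is a placeholder.
    let config = CString::new("path/to/tokenizer.json").unwrap();
    let tokenizer = from_file(config.as_ptr());

    // Request only the attributes we need; anything not requested stays null.
    let options = EncodeOptions {
        add_special_tokens: true,
        return_type_ids: false,
        return_tokens: true,
        return_special_tokens_mask: false,
        return_attention_mask: true, // required for MiniLM-style mean pooling
    };
    let message = CString::new("Hello, world!").unwrap();
    let buffer = encode(tokenizer, message.as_ptr(), &options);

    // Attributes that were not requested come back as null pointers.
    assert!(buffer.type_ids.is_null());

    // Every populated array holds exactly `buffer.len` elements.
    let mask = unsafe { std::slice::from_raw_parts(buffer.attention_mask, buffer.len) };
    println!("{} of {} positions are real tokens", mask.iter().filter(|&&m| m == 1).count(), buffer.len);

    // Hand every allocated attribute back to Rust.
    free_buffer(buffer);
    free_tokenizer(tokenizer);
}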

(Diffs for the remaining 4 changed files are not shown here.)
