From 49bd0555197f84510bd03bf0cba50c7174987e9d Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Fri, 27 Nov 2020 16:45:13 -0500 Subject: [PATCH] Node - Update bindings with train_from_files --- bindings/node/native/src/tokenizer.rs | 4 +-- bindings/node/native/src/trainers.rs | 39 +++++++++++++++------------ 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/bindings/node/native/src/tokenizer.rs b/bindings/node/native/src/tokenizer.rs index 0d5264958..8dc640c11 100644 --- a/bindings/node/native/src/tokenizer.rs +++ b/bindings/node/native/src/tokenizer.rs @@ -749,7 +749,7 @@ declare_types! { // train(files: string[], trainer?: Trainer) let files = cx.extract::>(0)?; - let trainer = if let Some(val) = cx.argument_opt(1) { + let mut trainer = if let Some(val) = cx.argument_opt(1) { let js_trainer = val.downcast::().or_throw(&mut cx)?; let guard = cx.lock(); @@ -768,7 +768,7 @@ declare_types! { this.borrow_mut(&guard) .tokenizer.write().unwrap() - .train(&trainer, files) + .train_from_files(&mut trainer, files) .map_err(|e| Error(format!("{}", e)))?; Ok(cx.undefined().upcast()) diff --git a/bindings/node/native/src/trainers.rs b/bindings/node/native/src/trainers.rs index a58f77c11..4141268d8 100644 --- a/bindings/node/native/src/trainers.rs +++ b/bindings/node/native/src/trainers.rs @@ -4,8 +4,7 @@ use crate::extraction::*; use crate::models::Model; use crate::tokenizer::AddedToken; use neon::prelude::*; -use std::collections::HashMap; -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use tk::models::{ bpe::BpeTrainer, unigram::UnigramTrainer, wordlevel::WordLevelTrainer, @@ -15,13 +14,13 @@ use tk::models::{ /// Trainer #[derive(Clone)] pub struct Trainer { - pub trainer: Option>, + pub trainer: Option>>, } impl From for Trainer { fn from(trainer: TrainerWrapper) -> Self { Self { - trainer: Some(Arc::new(trainer)), + trainer: Some(Arc::new(RwLock::new(trainer))), } } } @@ -33,20 +32,19 @@ impl tk::Trainer for Trainer { self.trainer .as_ref() .expect("Uninitialized Trainer") + .read() + .unwrap() .should_show_progress() } - fn train( - &self, - words: HashMap, - model: &mut Self::Model, - ) -> tk::Result> { + fn train(&self, model: &mut Self::Model) -> tk::Result> { let special_tokens = self .trainer .as_ref() .ok_or("Uninitialized Trainer")? + .read() + .unwrap() .train( - words, &mut model .model .as_ref() @@ -58,11 +56,18 @@ impl tk::Trainer for Trainer { Ok(special_tokens) } - fn process_tokens(&self, words: &mut HashMap, tokens: Vec) { + fn feed(&mut self, iterator: I, process: F) -> tk::Result<()> + where + I: Iterator + Send, + S: AsRef + Send, + F: Fn(&str) -> tk::Result> + Sync, + { self.trainer .as_ref() - .expect("Uninitialized Trainer") - .process_tokens(words, tokens) + .ok_or("Uninitialized Trainer")? + .write() + .unwrap() + .feed(iterator, process) } } @@ -162,7 +167,7 @@ fn bpe_trainer(mut cx: FunctionContext) -> JsResult { let mut js_trainer = JsTrainer::new::<_, JsTrainer, _>(&mut cx, vec![])?; let guard = cx.lock(); - js_trainer.borrow_mut(&guard).trainer = Some(Arc::new(trainer.into())); + js_trainer.borrow_mut(&guard).trainer = Some(Arc::new(RwLock::new(trainer.into()))); Ok(js_trainer) } @@ -254,7 +259,7 @@ fn wordpiece_trainer(mut cx: FunctionContext) -> JsResult { let mut js_trainer = JsTrainer::new::<_, JsTrainer, _>(&mut cx, vec![])?; let guard = cx.lock(); - js_trainer.borrow_mut(&guard).trainer = Some(Arc::new(trainer.into())); + js_trainer.borrow_mut(&guard).trainer = Some(Arc::new(RwLock::new(trainer.into()))); Ok(js_trainer) } @@ -327,7 +332,7 @@ fn wordlevel_trainer(mut cx: FunctionContext) -> JsResult { let mut js_trainer = JsTrainer::new::<_, JsTrainer, _>(&mut cx, vec![])?; let guard = cx.lock(); - js_trainer.borrow_mut(&guard).trainer = Some(Arc::new(trainer.into())); + js_trainer.borrow_mut(&guard).trainer = Some(Arc::new(RwLock::new(trainer.into()))); Ok(js_trainer) } @@ -424,7 +429,7 @@ fn unigram_trainer(mut cx: FunctionContext) -> JsResult { let mut js_trainer = JsTrainer::new::<_, JsTrainer, _>(&mut cx, vec![])?; let guard = cx.lock(); - js_trainer.borrow_mut(&guard).trainer = Some(Arc::new(trainer.into())); + js_trainer.borrow_mut(&guard).trainer = Some(Arc::new(RwLock::new(trainer.into()))); Ok(js_trainer) }