Skip to content

Commit

Permalink
Node - Update bindings with train_from_files
Browse files Browse the repository at this point in the history
  • Loading branch information
n1t0 committed Nov 28, 2020
1 parent 3a8627c commit 49bd055
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 19 deletions.
4 changes: 2 additions & 2 deletions bindings/node/native/src/tokenizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ declare_types! {
// train(files: string[], trainer?: Trainer)

let files = cx.extract::<Vec<String>>(0)?;
let trainer = if let Some(val) = cx.argument_opt(1) {
let mut trainer = if let Some(val) = cx.argument_opt(1) {
let js_trainer = val.downcast::<JsTrainer>().or_throw(&mut cx)?;
let guard = cx.lock();

Expand All @@ -768,7 +768,7 @@ declare_types! {

this.borrow_mut(&guard)
.tokenizer.write().unwrap()
.train(&trainer, files)
.train_from_files(&mut trainer, files)
.map_err(|e| Error(format!("{}", e)))?;

Ok(cx.undefined().upcast())
Expand Down
39 changes: 22 additions & 17 deletions bindings/node/native/src/trainers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ use crate::extraction::*;
use crate::models::Model;
use crate::tokenizer::AddedToken;
use neon::prelude::*;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::{Arc, RwLock};

use tk::models::{
bpe::BpeTrainer, unigram::UnigramTrainer, wordlevel::WordLevelTrainer,
Expand All @@ -15,13 +14,13 @@ use tk::models::{
/// Trainer
#[derive(Clone)]
pub struct Trainer {
pub trainer: Option<Arc<TrainerWrapper>>,
pub trainer: Option<Arc<RwLock<TrainerWrapper>>>,
}

impl From<TrainerWrapper> for Trainer {
fn from(trainer: TrainerWrapper) -> Self {
Self {
trainer: Some(Arc::new(trainer)),
trainer: Some(Arc::new(RwLock::new(trainer))),
}
}
}
Expand All @@ -33,20 +32,19 @@ impl tk::Trainer for Trainer {
self.trainer
.as_ref()
.expect("Uninitialized Trainer")
.read()
.unwrap()
.should_show_progress()
}

fn train(
&self,
words: HashMap<String, u32>,
model: &mut Self::Model,
) -> tk::Result<Vec<tk::AddedToken>> {
fn train(&self, model: &mut Self::Model) -> tk::Result<Vec<tk::AddedToken>> {
let special_tokens = self
.trainer
.as_ref()
.ok_or("Uninitialized Trainer")?
.read()
.unwrap()
.train(
words,
&mut model
.model
.as_ref()
Expand All @@ -58,11 +56,18 @@ impl tk::Trainer for Trainer {
Ok(special_tokens)
}

fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>) {
fn feed<I, S, F>(&mut self, iterator: I, process: F) -> tk::Result<()>
where
I: Iterator<Item = S> + Send,
S: AsRef<str> + Send,
F: Fn(&str) -> tk::Result<Vec<String>> + Sync,
{
self.trainer
.as_ref()
.expect("Uninitialized Trainer")
.process_tokens(words, tokens)
.ok_or("Uninitialized Trainer")?
.write()
.unwrap()
.feed(iterator, process)
}
}

Expand Down Expand Up @@ -162,7 +167,7 @@ fn bpe_trainer(mut cx: FunctionContext) -> JsResult<JsTrainer> {

let mut js_trainer = JsTrainer::new::<_, JsTrainer, _>(&mut cx, vec![])?;
let guard = cx.lock();
js_trainer.borrow_mut(&guard).trainer = Some(Arc::new(trainer.into()));
js_trainer.borrow_mut(&guard).trainer = Some(Arc::new(RwLock::new(trainer.into())));

Ok(js_trainer)
}
Expand Down Expand Up @@ -254,7 +259,7 @@ fn wordpiece_trainer(mut cx: FunctionContext) -> JsResult<JsTrainer> {

let mut js_trainer = JsTrainer::new::<_, JsTrainer, _>(&mut cx, vec![])?;
let guard = cx.lock();
js_trainer.borrow_mut(&guard).trainer = Some(Arc::new(trainer.into()));
js_trainer.borrow_mut(&guard).trainer = Some(Arc::new(RwLock::new(trainer.into())));

Ok(js_trainer)
}
Expand Down Expand Up @@ -327,7 +332,7 @@ fn wordlevel_trainer(mut cx: FunctionContext) -> JsResult<JsTrainer> {

let mut js_trainer = JsTrainer::new::<_, JsTrainer, _>(&mut cx, vec![])?;
let guard = cx.lock();
js_trainer.borrow_mut(&guard).trainer = Some(Arc::new(trainer.into()));
js_trainer.borrow_mut(&guard).trainer = Some(Arc::new(RwLock::new(trainer.into())));

Ok(js_trainer)
}
Expand Down Expand Up @@ -424,7 +429,7 @@ fn unigram_trainer(mut cx: FunctionContext) -> JsResult<JsTrainer> {

let mut js_trainer = JsTrainer::new::<_, JsTrainer, _>(&mut cx, vec![])?;
let guard = cx.lock();
js_trainer.borrow_mut(&guard).trainer = Some(Arc::new(trainer.into()));
js_trainer.borrow_mut(&guard).trainer = Some(Arc::new(RwLock::new(trainer.into())));

Ok(js_trainer)
}
Expand Down

0 comments on commit 49bd055

Please sign in to comment.