Skip to content

Commit

Permalink
Remove obsolete code from NmtModelFactory
Browse files Browse the repository at this point in the history
  • Loading branch information
ddaspit committed Nov 22, 2023
1 parent 4faa596 commit 9b68138
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 7 deletions.
6 changes: 1 addition & 5 deletions machine/jobs/huggingface/hugging_face_nmt_model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,7 @@ def create_model_trainer(self, corpus: ParallelTextCorpus) -> Trainer:
add_unk_trg_tokens=self._config.huggingface.tokenizer.add_unk_trg_tokens,
)

def create_engine(self, half_previous_batch_size=False) -> TranslationEngine:
if half_previous_batch_size:
self._config.huggingface.generate_params.batch_size = max(
self._config.huggingface.generate_params.batch_size // 2, 1
)
def create_engine(self) -> TranslationEngine:
return HuggingFaceNmtEngine(
self._model,
src_lang=self._config.src_lang,
Expand Down
2 changes: 1 addition & 1 deletion machine/jobs/nmt_model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def create_model_trainer(self, corpus: ParallelTextCorpus) -> Trainer:
...

@abstractmethod
def create_engine(self, half_previous_batch_size=False) -> TranslationEngine:
def create_engine(self) -> TranslationEngine:
...

@abstractmethod
Expand Down
1 change: 0 additions & 1 deletion tests/jobs/test_nmt_engine_build_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def test_run(decoy: Decoy) -> None:
env = _TestEnvironment(decoy)
env.job.run()

decoy.verify(env.engine.translate_batch(matchers.Anything()), times=1)
pretranslations = json.loads(env.target_pretranslations)
assert len(pretranslations) == 1
assert pretranslations[0]["translation"] == "Please, I have booked a room."
Expand Down

0 comments on commit 9b68138

Please sign in to comment.