Skip to content

Commit

Permalink
Fix moses punct (#140)
Browse files Browse the repository at this point in the history
* Fix moses punctuation

* max_number actually not there - but it's in the generation_config!
  • Loading branch information
johnml1135 authored Nov 11, 2024
1 parent 730ea67 commit 0fb9518
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
12 changes: 8 additions & 4 deletions machine/translation/huggingface/hugging_face_nmt_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def __init__(
self._tokenizer = AutoTokenizer.from_pretrained(self._model.name_or_path, use_fast=True)
if isinstance(self._tokenizer, (NllbTokenizer, NllbTokenizerFast)):
self._mpn = MosesPunctNormalizer()
self._mpn.substitutions = [
(str(re.compile(r)), sub)
self._mpn.substitutions = [ # type: ignore
(re.compile(r), sub)
for r, sub in self._mpn.substitutions
if isinstance(r, str) and isinstance(sub, str)
]
Expand Down Expand Up @@ -236,8 +236,12 @@ def _forward(self, model_inputs, **generate_kwargs):

input_tokens = model_inputs["input_tokens"]
del model_inputs["input_tokens"]
generate_kwargs["min_length"] = generate_kwargs.get("min_length", self.model.config.min_length)
generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length)
if hasattr(self.model, "generation_config") and self.model.generation_config is not None:
config = self.model.generation_config
else:
config = self.model.config
generate_kwargs["min_length"] = generate_kwargs.get("min_length", config.min_length)
generate_kwargs["max_length"] = generate_kwargs.get("max_length", config.max_length)
self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"])
output = self.model.generate(
**model_inputs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,8 @@ def __init__(
self._add_unk_src_tokens = add_unk_src_tokens
self._add_unk_tgt_tokens = add_unk_tgt_tokens
self._mpn = MosesPunctNormalizer()
self._mpn.substitutions = [
(str(re.compile(r)), sub)
for r, sub in self._mpn.substitutions
if isinstance(r, str) and isinstance(sub, str)
self._mpn.substitutions = [ # type: ignore
(re.compile(r), sub) for r, sub in self._mpn.substitutions if isinstance(r, str) and isinstance(sub, str)
]
self._stats = TrainStats()

Expand Down

0 comments on commit 0fb9518

Please sign in to comment.