Skip to content

Commit

Permalink
Default to latin tokenizer
Browse files Browse the repository at this point in the history
Fix minor issues.
  • Loading branch information
johnml1135 committed May 10, 2024
1 parent 1548698 commit ff35d57
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 22 deletions.
26 changes: 18 additions & 8 deletions machine/jobs/smt_engine_build_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,24 @@
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import Any, Callable, Optional, cast

from machine.translation.thot.thot_smt_model_trainer import ThotSmtModelTrainer
from machine.translation.thot.thot_word_alignment_model_type import (
from dynaconf.base import Settings

from ..tokenization import get_tokenizer_detokenizer
from ..translation.thot.thot_smt_model import ThotSmtParameters, ThotWordAlignmentModelType
from ..translation.thot.thot_smt_model_trainer import ThotSmtModelTrainer
from ..translation.thot.thot_word_alignment_model_type import (
checkThotWordAlignmentModelType,
getThotWordAlignmentModelType,
)
from machine.translation.unigram_truecaser_trainer import UnigramTruecaserTrainer

from ..translation.thot.thot_smt_model import ThotSmtParameters, ThotWordAlignmentModelType
from ..translation.unigram_truecaser_trainer import UnigramTruecaserTrainer
from ..utils.progress_status import ProgressStatus
from .shared_file_service import SharedFileService

logger = logging.getLogger(__name__)


class SmtEngineBuildJob:
def __init__(self, config: Any, shared_file_service: SharedFileService) -> None:
def __init__(self, config: Settings, shared_file_service: SharedFileService) -> None:
self._config = config
self._shared_file_service = shared_file_service
self._model_type = cast(str, self._config.model_type).lower()
Expand All @@ -35,6 +37,8 @@ def run(
check_canceled()

self._check_config()
(tokenizer, _) = get_tokenizer_detokenizer(str(self._config.get("tokenizer", default="latin")))
logger.info(f"Tokenizer used: {type(tokenizer).__name__}")

with TemporaryDirectory() as temp_dir:

Expand All @@ -58,14 +62,20 @@ def run(
check_canceled()

with ThotSmtModelTrainer(
getThotWordAlignmentModelType(self._model_type), parallel_corpus, parameters
word_alignment_model_type=getThotWordAlignmentModelType(self._model_type),
corpus=parallel_corpus,
config=parameters,
source_tokenizer=tokenizer,
target_tokenizer=tokenizer,
) as trainer:
logger.info("Training Model")
trainer.train(progress=progress, check_canceled=check_canceled)
trainer.save()
parameters = trainer.parameters

with UnigramTruecaserTrainer(target_corpus, os.path.join(temp_dir, "truecase.txt")) as truecase_trainer:
with UnigramTruecaserTrainer(
corpus=target_corpus, model_path=os.path.join(temp_dir, "truecase.txt"), tokenizer=tokenizer
) as truecase_trainer:
logger.info("Training Truecaser")
truecase_trainer.train(progress=progress, check_canceled=check_canceled)
truecase_trainer.save()
Expand Down
17 changes: 17 additions & 0 deletions machine/tokenization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Tuple

from .detokenizer import Detokenizer
from .latin_sentence_tokenizer import LatinSentenceTokenizer
from .latin_word_detokenizer import LatinWordDetokenizer
Expand Down Expand Up @@ -34,3 +36,18 @@
"ZwspWordDetokenizer",
"ZwspWordTokenizer",
]

TOKENIZERS = ["latin", "whitespace", "zwsp"]


def get_tokenizer_detokenizer(name: str = "") -> Tuple[Tokenizer, Detokenizer]:
    """Return a matching (tokenizer, detokenizer) pair for *name*.

    Matching is case-insensitive and by substring, so e.g. "LatinWord"
    selects the latin pair.  An empty name defaults to the latin pair.

    Args:
        name: Tokenizer family name; should contain one of the entries in
            ``TOKENIZERS``.  Empty string selects the default (latin).

    Returns:
        A 2-tuple of freshly constructed (Tokenizer, Detokenizer) instances.

    Raises:
        ValueError: If *name* is non-empty and matches no known family.
    """
    name_lower = name.lower()
    # An empty name would fall through every substring check, so it is
    # tested explicitly here to provide the documented latin default.
    if "latin" in name_lower or name == "":
        return LatinWordTokenizer(), LatinWordDetokenizer()
    if "whitespace" in name_lower:
        return WhitespaceTokenizer(), WhitespaceDetokenizer()
    if "zwsp" in name_lower:
        return ZwspWordTokenizer(), ZwspWordDetokenizer()
    # Fixed typo in the user-facing message: "Defualt" -> "Default".
    raise ValueError(
        f"Unknown tokenizer/detokenizer name: {name}. Available tokenizers are: {TOKENIZERS}. Default tokenizer is latin."
    )
22 changes: 11 additions & 11 deletions machine/translation/unigram_truecaser_trainer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Callable, Optional

from machine.tokenization.tokenizer import Tokenizer
Expand All @@ -12,9 +12,9 @@

@dataclass
class UnigramTruecaserTrainer(Trainer):
_corpus: TextCorpus
_model_path: str = ""
_new_truecaser: UnigramTruecaser = UnigramTruecaser()
corpus: TextCorpus
model_path: str = ""
new_truecaser: UnigramTruecaser = field(default_factory=UnigramTruecaser)
stats: TrainStats = TrainStats()
tokenizer: Tokenizer = WHITESPACE_TOKENIZER

Expand All @@ -25,26 +25,26 @@ def train(
) -> None:
step_count = 0
if progress is not None:
step_count = self._corpus.count(include_empty=False)
step_count = self.corpus.count(include_empty=False)
current_step = 0
for row in self._corpus.tokenize(tokenizer=self.tokenizer).filter_nonempty():
for row in self.corpus.tokenize(tokenizer=self.tokenizer).filter_nonempty():
if check_canceled is not None:
check_canceled()
self._new_truecaser.train_segment(row)
self.new_truecaser.train_segment(row)
current_step += 1
if progress is not None:
progress(ProgressStatus(current_step, step_count))
self.stats.train_corpus_size = current_step

def save(self) -> None:
if self._model_path != "":
self._new_truecaser.save(self._model_path)
if self.model_path != "":
self.new_truecaser.save(self.model_path)


class SubclassUnigramTruecaserTrainer(UnigramTruecaserTrainer):
    """Trainer variant that transfers its trained state into an existing truecaser.

    Unlike the base trainer, whose ``save`` writes ``new_truecaser`` to
    ``model_path``, this subclass copies the trained state into
    ``_true_caser`` and delegates persistence to that instance.
    """

    # Pre-existing truecaser that receives the trained state on save().
    _true_caser: UnigramTruecaser

    def save(self) -> None:
        """Copy trained state into ``_true_caser`` and persist it."""
        # NOTE(review): reaches into UnigramTruecaser's private attributes;
        # assumes _casing and _bestTokens fully capture the trained state —
        # confirm against UnigramTruecaser's definition.
        self._true_caser._casing = self.new_truecaser._casing
        self._true_caser._bestTokens = self.new_truecaser._bestTokens
        self._true_caser.save()
6 changes: 3 additions & 3 deletions tests/translation/test_unigram_truecaser.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ def test_compare_with_truecase_trainer() -> None:
"text1", [text_row("text1", i, " ".join(segment)) for i, segment in enumerate(training_segments)]
)
trainer = UnigramTruecaserTrainer(text)
trainer._new_truecaser = UnigramTruecaser()
trainer.new_truecaser = UnigramTruecaser()
trainer.train()

truecaser = create_truecaser()

assert trainer._new_truecaser._bestTokens == truecaser._bestTokens
assert trainer._new_truecaser._casing == truecaser._casing
assert trainer.new_truecaser._bestTokens == truecaser._bestTokens
assert trainer.new_truecaser._casing == truecaser._casing

0 comments on commit ff35d57

Please sign in to comment.