Commit db31d41

Revert initial implementation.
Respond to reviewer comments.

johnml1135 committed Nov 21, 2023
1 parent 9872305 commit db31d41

Showing 8 changed files with 36 additions and 73 deletions.

4 changes: 2 additions & 2 deletions machine/jobs/huggingface/hugging_face_nmt_model_factory.py

@@ -10,9 +10,9 @@
 from ...corpora.text_corpus import TextCorpus
 from ...translation.huggingface.hugging_face_nmt_engine import HuggingFaceNmtEngine
 from ...translation.huggingface.hugging_face_nmt_model_trainer import HuggingFaceNmtModelTrainer
-from ...translation.nmt_translation_engine import NmtTranslationEngine
 from ...translation.null_trainer import NullTrainer
 from ...translation.trainer import Trainer
+from ...translation.translation_engine import TranslationEngine
 from ..nmt_model_factory import NmtModelFactory
 from ..shared_file_service import SharedFileService

@@ -70,7 +70,7 @@ def create_model_trainer(self, corpus: ParallelTextCorpus) -> Trainer:
             add_unk_trg_tokens=self._config.huggingface.tokenizer.add_unk_trg_tokens,
         )

-    def create_engine(self, half_previous_batch_size=False) -> NmtTranslationEngine:
+    def create_engine(self, half_previous_batch_size=False) -> TranslationEngine:
         if half_previous_batch_size:
             self._config.huggingface.generate_params.batch_size = max(
                 self._config.huggingface.generate_params.batch_size // 2, 1

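The factory still halves the configured batch size on request; only the declared return type widens from the deleted NmtTranslationEngine to the base TranslationEngine. As a quick aside, a sketch (not project code) of the halving rule visible at the end of this hunk, which floors at a batch size of 1:

    def halve_batch_size(batch_size: int) -> int:
        # Same expression as in create_engine: integer-halve,
        # but never drop below a batch size of 1.
        return max(batch_size // 2, 1)

    assert halve_batch_size(16) == 8  # 16 -> 8 -> 4 -> 2 -> 1 on repeated halving
    assert halve_batch_size(1) == 1   # already at the floor
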
46 changes: 12 additions & 34 deletions machine/jobs/nmt_engine_build_job.py

@@ -3,7 +3,7 @@
 from typing import Any, Callable, Optional, Sequence

 from ..corpora.corpora_utils import batch
-from ..translation.nmt_translation_engine import NmtTranslationEngine
+from ..translation.translation_engine import TranslationEngine
 from ..utils.phased_progress_reporter import Phase, PhasedProgressReporter
 from ..utils.progress_status import ProgressStatus
 from .nmt_model_factory import NmtModelFactory

@@ -81,48 +81,26 @@ def run(
         inference_step_count = sum(1 for _ in src_pretranslations)
         with ExitStack() as stack:
             phase_progress = stack.enter_context(progress_reporter.start_next_phase())
+            engine = stack.enter_context(self._nmt_model_factory.create_engine())
             src_pretranslations = stack.enter_context(self._shared_file_service.get_source_pretranslations())
             writer = stack.enter_context(self._shared_file_service.open_target_pretranslation_writer())
             current_inference_step = 0
             phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))
             batch_size = self._config["batch_size"]
-            translate_batch = TranslateBatch(stack, self._nmt_model_factory)
             for pi_batch in batch(src_pretranslations, batch_size):
                 if check_canceled is not None:
                     check_canceled()
-                translate_batch.translate(pi_batch, writer)
+                _translate_batch(engine, pi_batch, writer)
                 current_inference_step += len(pi_batch)
                 phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))


-batch_divisor = 1
-
-
-class TranslateBatch:
-    def __init__(self, stack: ExitStack, nmt_model_factory: NmtModelFactory):
-        self._stack = stack
-        self._nmt_model_factory = nmt_model_factory
-        self._engine: NmtTranslationEngine = self._stack.enter_context(self._nmt_model_factory.create_engine())
-
-    def translate(
-        self,
-        batch: Sequence[PretranslationInfo],
-        writer: PretranslationWriter,
-    ) -> None:
-        while True:
-            source_segments = [pi["translation"] for pi in batch]
-            outer_batch_size = len(source_segments)
-            try:
-                for step in range(0, outer_batch_size, self._engine.get_batch_size()):
-                    for i, result in enumerate(
-                        self._engine.translate_batch(source_segments[step : step + self._engine.get_batch_size()])
-                    ):
-                        batch[i + step]["translation"] = result.translation
-                for i in range(len(source_segments)):
-                    writer.write(batch[i])
-                break
-            except Exception:
-                logger.info(f"Out of memory error, reducing batch size to {self._engine.get_batch_size() // 2}")
-                self._engine = self._stack.enter_context(
-                    self._nmt_model_factory.create_engine(half_previous_batch_size=True)
-                )
+def _translate_batch(
+    engine: TranslationEngine,
+    batch: Sequence[PretranslationInfo],
+    writer: PretranslationWriter,
+) -> None:
+    source_segments = [pi["translation"] for pi in batch]
+    for i, result in enumerate(engine.translate_batch(source_segments)):
+        batch[i]["translation"] = result.translation
+        writer.write(batch[i])

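With the TranslateBatch class reverted, batch sizing and out-of-memory recovery move into the engine itself (see hugging_face_nmt_engine.py below), and the job-level helper shrinks to a single pass. A minimal sketch of how the reverted _translate_batch behaves, using hypothetical stub classes in place of the real engine, writer, and PretranslationInfo type:

    from typing import Sequence

    class StubResult:
        def __init__(self, translation: str) -> None:
            self.translation = translation

    class StubEngine:
        def translate_batch(self, segments: Sequence[str]):
            # Pretend translation: reverse each segment.
            return [StubResult(s[::-1]) for s in segments]

    class StubWriter:
        def write(self, pi: dict) -> None:
            print(pi["translation"])

    # The "translation" field holds the source text going in and is
    # overwritten with the engine output, as in the function above.
    batch = [{"translation": "hello"}, {"translation": "world"}]
    engine, writer = StubEngine(), StubWriter()

    source_segments = [pi["translation"] for pi in batch]
    for i, result in enumerate(engine.translate_batch(source_segments)):
        batch[i]["translation"] = result.translation
        writer.write(batch[i])  # prints "olleh", then "dlrow"
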
4 changes: 2 additions & 2 deletions machine/jobs/nmt_model_factory.py

@@ -2,8 +2,8 @@

 from ..corpora.parallel_text_corpus import ParallelTextCorpus
 from ..corpora.text_corpus import TextCorpus
-from ..translation.nmt_translation_engine import NmtTranslationEngine
 from ..translation.trainer import Trainer
+from ..translation.translation_engine import TranslationEngine


 class NmtModelFactory(ABC):

@@ -29,7 +29,7 @@ def create_model_trainer(self, corpus: ParallelTextCorpus) -> Trainer:
         ...

     @abstractmethod
-    def create_engine(self, half_previous_batch_size=False) -> NmtTranslationEngine:
+    def create_engine(self, half_previous_batch_size=False) -> TranslationEngine:
         ...

     @abstractmethod

2 changes: 0 additions & 2 deletions machine/jobs/settings.yaml

@@ -34,7 +34,5 @@ staging:
   max_steps: 10
   huggingface:
     parent_model_name: facebook/nllb-200-distilled-600M
-    train_params:
-      group_by_length: false
     generate_params:
       num_beams: 1

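With train_params.group_by_length dropped, the staging profile now only overrides generate_params. A sketch of how such a block could feed the engine's new batch_size default; the loading code is illustrative (assuming PyYAML), not the project's actual config plumbing:

    import yaml  # assumes PyYAML is available

    config = yaml.safe_load("""
    staging:
      huggingface:
        parent_model_name: facebook/nllb-200-distilled-600M
        generate_params:
          num_beams: 1
    """)

    generate_params = dict(config["staging"]["huggingface"]["generate_params"])
    # Mirrors the engine change below: batch_size now defaults to 1
    # when the profile omits it, instead of requiring the key.
    batch_size = int(generate_params.pop("batch_size", 1))
    print(batch_size)  # -> 1
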
23 changes: 11 additions & 12 deletions machine/translation/huggingface/hugging_face_nmt_engine.py

@@ -12,7 +12,7 @@
 from ...annotations.range import Range
 from ...utils.typeshed import StrPath
-from ..nmt_translation_engine import NmtTranslationEngine
+from ..translation_engine import TranslationEngine
 from ..translation_result import TranslationResult
 from ..translation_result_builder import TranslationResultBuilder
 from ..translation_sources import TranslationSources

@@ -21,7 +21,7 @@
 logger = logging.getLogger(__name__)


-class HuggingFaceNmtEngine(NmtTranslationEngine):
+class HuggingFaceNmtEngine(TranslationEngine):
     def __init__(
         self,
         model: Union[PreTrainedModel, StrPath, str],

@@ -63,11 +63,7 @@ def __init__(
         ):
             raise ValueError(f"'{tgt_lang}' is not a valid language code.")

-        batch_size = self._pipeline_kwargs.pop("batch_size")
-        if batch_size is not None:
-            self._batch_size = int(batch_size)  # type: ignore[assignment]
-        else:
-            self._batch_size = 16
+        self._batch_size = int(self._pipeline_kwargs.pop("batch_size", 1))

         # If not set, default to not backing off (1.0).
         self._oom_batch_size_backoff_multiplier = self._pipeline_kwargs.pop("oom_batch_size_backoff_multiplier", 1.0)

@@ -88,9 +84,6 @@ def translate_n(self, n: int, segment: Union[str, Sequence[str]]) -> Sequence[TranslationResult]:
     def translate_batch(self, segments: Sequence[Union[str, Sequence[str]]]) -> Sequence[TranslationResult]:
         return [results[0] for results in self.translate_n_batch(1, segments)]

-    def get_batch_size(self) -> int:
-        return self._batch_size
-
     def translate_n_batch(
         self, n: int, segments: Sequence[Union[str, Sequence[str]]]
     ) -> Sequence[Sequence[TranslationResult]]:

@@ -106,13 +99,19 @@ def translate_n_batch(
                 all_results.extend(self._try_translate_n_batch(n, segments[step : step + self._batch_size]))
             return all_results
         except Exception as e:
+            # The out of memory error is not inherited from a specific exception type.
             if self._oom_batch_size_backoff_multiplier >= 0.9999:
                 raise Exception(
-                    "Likely an Out of Memory Error. Change oom_batch_size_backoff_multiplier to < 1 to gracefully handle these types of errors."
+                    "Likely an Out of Memory Error. Change oom_batch_size_backoff_multiplier "
+                    + "to < 1 to gracefully handle these types of errors."
                 ) from e
+            if self._batch_size == 1:
+                # Could it be another error?
+                raise e
             self._batch_size = max(int(round(self._batch_size * self._oom_batch_size_backoff_multiplier)), 1)
             logger.info(
-                f"Out of memory error caught, reducing batch size to {self._batch_size}. Remaking translation pipeline."
+                f"Out of memory error caught with message {e.args[0]}, reducing batch size to {self._batch_size}. "
+                + "Remaking translation pipeline."
             )
             self._pipeline = _TranslationPipeline(
                 model=self._model,

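The engine now owns the out-of-memory backoff: each failed attempt multiplies the batch size by oom_batch_size_backoff_multiplier (rounded, floored at 1) and rebuilds the pipeline. A standalone sketch of just that schedule, reusing the diff's expression (a guard against a non-reducing multiplier is added here so the sketch terminates):

    def backoff_schedule(batch_size: int, multiplier: float) -> list:
        # Batch sizes tried after successive OOM errors, using the same
        # expression as the diff: max(int(round(size * multiplier)), 1).
        sizes = [batch_size]
        while batch_size > 1 and multiplier < 1.0:
            batch_size = max(int(round(batch_size * multiplier)), 1)
            sizes.append(batch_size)
        return sizes

    # With the default multiplier of 1.0 the engine raises instead of
    # backing off; with 0.5 a batch of 16 retries at 8, 4, 2, then 1.
    print(backoff_schedule(16, 0.5))  # -> [16, 8, 4, 2, 1]
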
13 changes: 7 additions & 6 deletions machine/translation/huggingface/hugging_face_nmt_model_trainer.py

@@ -134,7 +134,6 @@ def train(
         # Set seed before initializing model.
         set_seed(self._training_args.seed)

-        logger.info("Initializing tokenizer.")
         if isinstance(self._model, PreTrainedModel):
             model = self._model
             self._original_use_cache = model.config.use_cache

@@ -148,6 +147,8 @@
                 num_labels=0,
             )
             model = cast(PreTrainedModel, AutoModelForSeq2SeqLM.from_pretrained(self._model, config=config))
+
+        logger.info("Initializing tokenizer")
         tokenizer = AutoTokenizer.from_pretrained(model.name_or_path, use_fast=True)

         src_lang = self._src_lang

@@ -194,8 +195,8 @@ def add_tokens(tokenizer: Any, missing_tokens: List[str]) -> Any:
             logger.info(f"Added {len(missing_tokens)} tokens to the tokenizer: {missing_tokens}")
             return AutoTokenizer.from_pretrained(str(tokenizer_dir), use_fast=True)

-        logger.info("Checking for missing tokens.")
         if self._add_unk_src_tokens or self._add_unk_trg_tokens:
+            logger.info("Checking for missing tokens")
             if not isinstance(tokenizer, PreTrainedTokenizerFast):
                 logger.warning(
                     f"Tokenizer can not be updated from default configuration: \

@@ -235,8 +236,8 @@ def add_lang_code_to_tokenizer(tokenizer: Any, lang_code: str):
             tokenizer.lang_token_to_id[lang_code] = lang_id
             tokenizer.id_to_lang_token[lang_id] = lang_code

-        logger.info("Add new language codes as tokens.")
         if isinstance(tokenizer, MULTILINGUAL_TOKENIZERS):
+            logger.info("Add new language codes as tokens")
             if self._src_lang is not None:
                 add_lang_code_to_tokenizer(tokenizer, self._src_lang)
             if self._tgt_lang is not None:

@@ -312,7 +313,7 @@ def preprocess_function(examples):
             model_inputs["labels"] = labels["input_ids"]
             return model_inputs

-        logger.info("Run tokenizer.")
+        logger.info("Run tokenizer")
         train_dataset = train_dataset.map(
             preprocess_function,
             batched=True,

@@ -343,7 +344,7 @@ def preprocess_function(examples):
             ],
         )

-        logger.info("Train NMT model.")
+        logger.info("Train NMT model")
         ckpt = None
         if self._training_args.resume_from_checkpoint is not None:
             ckpt = self._training_args.resume_from_checkpoint

@@ -357,7 +358,7 @@ def preprocess_function(examples):
         self._metrics["train_samples"] = len(train_dataset)

         self._trainer.log_metrics("train", self._metrics)
-        logger.info("Model training finished.")
+        logger.info("Model training finished")

     def save(self) -> None:
         if self._trainer is None:

12 changes: 0 additions & 12 deletions machine/translation/nmt_translation_engine.py

This file was deleted.

5 changes: 2 additions & 3 deletions tests/jobs/test_nmt_engine_build_job.py

@@ -10,7 +10,7 @@
 from machine.corpora import DictionaryTextCorpus
 from machine.jobs import NmtEngineBuildJob, NmtModelFactory, PretranslationInfo, PretranslationWriter, SharedFileService
 from machine.translation import Phrase, Trainer, TrainStats, TranslationResult, TranslationSources, WordAlignmentMatrix
-from machine.translation.nmt_translation_engine import NmtTranslationEngine
+from machine.translation.translation_engine import TranslationEngine
 from machine.utils import CanceledError, ContextManagedGenerator

@@ -45,9 +45,8 @@ def __init__(self, decoy: Decoy) -> None:
         stats.metrics["bleu"] = 30.0
         decoy.when(self.model_trainer.stats).then_return(stats)

-        self.engine = decoy.mock(cls=NmtTranslationEngine)
+        self.engine = decoy.mock(cls=TranslationEngine)
         decoy.when(self.engine.__enter__()).then_return(self.engine)
-        decoy.when(self.engine.get_batch_size()).then_return(16)
         decoy.when(self.engine.translate_batch(matchers.Anything())).then_return(
             [
                 TranslationResult(

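The test now mocks the base TranslationEngine and drops the get_batch_size stub, since that method no longer exists. For readers unfamiliar with Decoy, a minimal sketch of the rehearse-then-stub pattern used above, with a hypothetical Greeter class standing in for the engine:

    from decoy import Decoy  # the mocking library used by these tests

    class Greeter:
        def greet(self, name: str) -> str:
            raise NotImplementedError

    def test_greeting() -> None:
        decoy = Decoy()
        greeter = decoy.mock(cls=Greeter)
        # Rehearse the call, then stub its return value, exactly as the
        # test above does for engine.translate_batch(...).
        decoy.when(greeter.greet("world")).then_return("hello, world")
        assert greeter.greet("world") == "hello, world"
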
