
Commit

Rework to add to huggingface directly
johnml1135 committed Nov 20, 2023
1 parent f77b1c8 commit 9872305
Showing 3 changed files with 66 additions and 14 deletions.
3 changes: 3 additions & 0 deletions machine/jobs/settings.yaml
@@ -3,6 +3,7 @@ default:
   max_steps: 20000
   data_dir: ~/machine
   batch_size: 1024
+  oom_batch_size_backoff_multiplier: 0.5
   huggingface:
     parent_model_name: facebook/nllb-200-distilled-1.3B
     train_params:
@@ -33,5 +34,7 @@ staging:
   max_steps: 10
   huggingface:
     parent_model_name: facebook/nllb-200-distilled-600M
+    train_params:
+      group_by_length: false
     generate_params:
       num_beams: 1
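The new oom_batch_size_backoff_multiplier setting controls how far the engine shrinks its translation batch after an out-of-memory failure: at the in-code default of 1.0 there is no backoff, while the 0.5 used here halves the batch on every retry until it reaches 1. A minimal sketch of the resulting schedule, assuming the default profile's starting batch_size of 1024 (the helper name is hypothetical):

```python
from typing import List


def backoff_schedule(batch_size: int, multiplier: float) -> List[int]:
    """Illustrative only: the batch sizes that would be tried after repeated OOM errors."""
    sizes = [batch_size]
    while multiplier < 1.0 and batch_size > 1:
        # Mirror the engine's rule: scale, round, and never go below 1.
        batch_size = max(int(round(batch_size * multiplier)), 1)
        sizes.append(batch_size)
    return sizes


print(backoff_schedule(1024, 0.5))  # [1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1]
```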
67 changes: 54 additions & 13 deletions machine/translation/huggingface/hugging_face_nmt_engine.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import gc
+import logging
 from math import exp, prod
 from typing import Any, Iterable, List, Sequence, Tuple, Union, cast

@@ -17,29 +18,35 @@
 from ..translation_sources import TranslationSources
 from ..word_alignment_matrix import WordAlignmentMatrix

+logger = logging.getLogger(__name__)
+

 class HuggingFaceNmtEngine(NmtTranslationEngine):
     def __init__(
         self,
         model: Union[PreTrainedModel, StrPath, str],
         **pipeline_kwargs,
     ) -> None:
-        if isinstance(model, PreTrainedModel):
-            model.eval()
+        self._model = model
+        self._pipeline_kwargs = pipeline_kwargs
+        if isinstance(self._model, PreTrainedModel):
+            self._model.eval()
         else:
-            model_config = AutoConfig.from_pretrained(str(model), label2id={}, id2label={}, num_labels=0)
-            model = cast(PreTrainedModel, AutoModelForSeq2SeqLM.from_pretrained(str(model), config=model_config))
-        self._tokenizer = AutoTokenizer.from_pretrained(model.name_or_path, use_fast=True)
+            model_config = AutoConfig.from_pretrained(str(self._model), label2id={}, id2label={}, num_labels=0)
+            self._model = cast(
+                PreTrainedModel, AutoModelForSeq2SeqLM.from_pretrained(str(self._model), config=model_config)
+            )
+        self._tokenizer = AutoTokenizer.from_pretrained(self._model.name_or_path, use_fast=True)

-        src_lang = pipeline_kwargs.get("src_lang")
-        tgt_lang = pipeline_kwargs.get("tgt_lang")
+        src_lang = self._pipeline_kwargs.get("src_lang")
+        tgt_lang = self._pipeline_kwargs.get("tgt_lang")
         if (
             src_lang is not None
             and tgt_lang is not None
-            and "prefix" not in pipeline_kwargs
-            and (model.name_or_path.startswith("t5-") or model.name_or_path.startswith("google/mt5-"))
+            and "prefix" not in self._pipeline_kwargs
+            and (self._model.name_or_path.startswith("t5-") or self._model.name_or_path.startswith("google/mt5-"))
         ):
-            pipeline_kwargs["prefix"] = f"translate {src_lang} to {tgt_lang}: "
+            self._pipeline_kwargs["prefix"] = f"translate {src_lang} to {tgt_lang}: "
         else:
             additional_special_tokens = self._tokenizer.additional_special_tokens
             if (
@@ -56,16 +63,20 @@ def __init__(
             ):
                 raise ValueError(f"'{tgt_lang}' is not a valid language code.")

-        batch_size = pipeline_kwargs.get("batch_size")
+        batch_size = self._pipeline_kwargs.pop("batch_size")
         if batch_size is not None:
             self._batch_size = int(batch_size)  # type: ignore[assignment]
         else:
             self._batch_size = 16

+        # If not set, default to not backing off (1.0).
+        self._oom_batch_size_backoff_multiplier = self._pipeline_kwargs.pop("oom_batch_size_backoff_multiplier", 1.0)
+
         self._pipeline = _TranslationPipeline(
-            model=model,
+            model=self._model,
             tokenizer=self._tokenizer,
-            **pipeline_kwargs,
+            batch_size=self._batch_size,
+            **self._pipeline_kwargs,
         )

     def translate(self, segment: Union[str, Sequence[str]]) -> TranslationResult:
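The reworked constructor keeps the model and keyword arguments on the instance and pops the engine-level options (batch_size, oom_batch_size_backoff_multiplier) out of pipeline_kwargs before building the Hugging Face pipeline, so only pipeline arguments are forwarded and the pipeline can later be rebuilt with a smaller batch. A hedged usage sketch of how the new option might be passed; the import path and the NLLB language codes are assumptions, not taken from this diff:

```python
from machine.translation.huggingface import HuggingFaceNmtEngine  # import path assumed

engine = HuggingFaceNmtEngine(
    "facebook/nllb-200-distilled-600M",     # parent model from the staging profile above
    src_lang="eng_Latn",                    # placeholder NLLB language codes
    tgt_lang="fra_Latn",
    batch_size=16,
    oom_batch_size_backoff_multiplier=0.5,  # halve the batch after a suspected OOM error
)
result = engine.translate("Hello, world!")
```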
@@ -82,6 +93,36 @@ def get_batch_size(self) -> int:

     def translate_n_batch(
         self, n: int, segments: Sequence[Union[str, Sequence[str]]]
     ) -> Sequence[Sequence[TranslationResult]]:
+        while True:
+            if type(segments) is str:
+                segments = [segments]
+            else:
+                segments = [segment for segment in segments]
+            outer_batch_size = len(segments)
+            all_results: List[Sequence[TranslationResult]] = []
+            try:
+                for step in range(0, outer_batch_size, self._batch_size):
+                    all_results.extend(self._try_translate_n_batch(n, segments[step : step + self._batch_size]))
+                return all_results
+            except Exception as e:
+                if self._oom_batch_size_backoff_multiplier >= 0.9999:
+                    raise Exception(
+                        "Likely an Out of Memory Error. Change oom_batch_size_backoff_multiplier to < 1 to gracefully handle these types of errors."
+                    ) from e
+                self._batch_size = max(int(round(self._batch_size * self._oom_batch_size_backoff_multiplier)), 1)
+                logger.info(
+                    f"Out of memory error caught, reducing batch size to {self._batch_size}. Remaking translation pipeline."
+                )
+                self._pipeline = _TranslationPipeline(
+                    model=self._model,
+                    tokenizer=self._tokenizer,
+                    batch_size=self._batch_size,
+                    **self._pipeline_kwargs,
+                )
+
+    def _try_translate_n_batch(
+        self, n: int, segments: Sequence[Union[str, Sequence[str]]]
+    ) -> Sequence[Sequence[TranslationResult]]:
         all_results: List[List[TranslationResult]] = []
         i = 0
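translate_n_batch now wraps the real work in a retry loop: the request is split into chunks of the current batch size, and any exception is treated as a likely OOM, so the batch is shrunk by the multiplier and the pipeline rebuilt before the whole request is retried. A stripped-down sketch of that control flow, independent of the Hugging Face classes (the names here are illustrative):

```python
import logging
from typing import Callable, List, Sequence

logger = logging.getLogger(__name__)


def translate_with_backoff(
    segments: Sequence[str],
    translate_chunk: Callable[[Sequence[str]], List[str]],  # stand-in for _try_translate_n_batch
    batch_size: int,
    backoff_multiplier: float,
) -> List[str]:
    """Retry the whole request with a smaller batch size whenever a chunk raises."""
    while True:
        try:
            results: List[str] = []
            for step in range(0, len(segments), batch_size):
                results.extend(translate_chunk(segments[step : step + batch_size]))
            return results
        except Exception:
            if backoff_multiplier >= 1.0:
                raise  # no backoff configured; surface the original error
            # Shrink the batch and try again; like the diff above, this bottoms out at 1.
            batch_size = max(int(round(batch_size * backoff_multiplier)), 1)
            logger.info("Reducing batch size to %d and retrying.", batch_size)
```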
@@ -134,6 +134,7 @@ def train(
         # Set seed before initializing model.
         set_seed(self._training_args.seed)

+        logger.info("Initializing tokenizer.")
         if isinstance(self._model, PreTrainedModel):
             model = self._model
             self._original_use_cache = model.config.use_cache
@@ -193,6 +194,7 @@ def add_tokens(tokenizer: Any, missing_tokens: List[str]) -> Any:
             logger.info(f"Added {len(missing_tokens)} tokens to the tokenizer: {missing_tokens}")
             return AutoTokenizer.from_pretrained(str(tokenizer_dir), use_fast=True)

+        logger.info("Checking for missing tokens.")
         if self._add_unk_src_tokens or self._add_unk_trg_tokens:
             if not isinstance(tokenizer, PreTrainedTokenizerFast):
                 logger.warning(
@@ -233,6 +235,7 @@ def add_lang_code_to_tokenizer(tokenizer: Any, lang_code: str):
             tokenizer.lang_token_to_id[lang_code] = lang_id
             tokenizer.id_to_lang_token[lang_id] = lang_code

+        logger.info("Add new language codes as tokens.")
         if isinstance(tokenizer, MULTILINGUAL_TOKENIZERS):
             if self._src_lang is not None:
                 add_lang_code_to_tokenizer(tokenizer, self._src_lang)
@@ -309,6 +312,7 @@ def preprocess_function(examples):
             model_inputs["labels"] = labels["input_ids"]
             return model_inputs

+        logger.info("Run tokenizer.")
         train_dataset = train_dataset.map(
             preprocess_function,
             batched=True,
@@ -339,17 +343,21 @@ def preprocess_function(examples):
             ],
         )

+        logger.info("Train NMT model.")
         ckpt = None
         if self._training_args.resume_from_checkpoint is not None:
             ckpt = self._training_args.resume_from_checkpoint
         elif last_checkpoint is not None:
             ckpt = last_checkpoint
-        train_result = self._trainer.train(resume_from_checkpoint=ckpt)
+        train_result = self._trainer.train(
+            resume_from_checkpoint=ckpt,
+        )

         self._metrics = train_result.metrics
         self._metrics["train_samples"] = len(train_dataset)

         self._trainer.log_metrics("train", self._metrics)
+        logger.info("Model training finished.")

     def save(self) -> None:
         if self._trainer is None:
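The resume logic shown as context prefers an explicitly configured resume_from_checkpoint and otherwise falls back to last_checkpoint. A small sketch of that selection, assuming last_checkpoint is found with transformers' get_last_checkpoint helper (an assumption; the discovery code is not part of this diff):

```python
from typing import Optional

from transformers import Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint


def pick_checkpoint(args: Seq2SeqTrainingArguments) -> Optional[str]:
    """Prefer an explicitly requested checkpoint, else the newest one in output_dir."""
    if args.resume_from_checkpoint is not None:
        return args.resume_from_checkpoint
    # Returns None when output_dir contains no checkpoint-* subdirectory yet.
    return get_last_checkpoint(args.output_dir)


# train_result = trainer.train(resume_from_checkpoint=pick_checkpoint(training_args))
```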
