diff --git a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py index 75969210..c26a1b4c 100644 --- a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py +++ b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py @@ -1042,7 +1042,10 @@ def _execute_train_loop( training_loss_tracker = [] + step_count = 1 + for epoch in range(num_epochs): + step_loss_log = {} model.train() total_loss = 0 tqdm_loader = tqdm(train_dataloader, disable=silence_progress_bars) @@ -1060,6 +1063,8 @@ def _execute_train_loop( optimizer.step() lr_scheduler.step() optimizer.zero_grad() + step_loss_log[step_count] = loss + step_count += 1 except torch.cuda.OutOfMemoryError: error( "", @@ -1067,14 +1072,18 @@ def _execute_train_loop( ) log.info("", {"loss": float(loss), "epoch": epoch}) - # Below is added to be propagated and stored as training_metadata - training_loss_tracker.append( - { - "epoch": epoch, - "value": float(loss), - "timestamp": datetime.isoformat(datetime.now()), - } - ) + + for step, loss_val in step_loss_log.items(): + + # Below is added to be propagated and stored as training_metadata + training_loss_tracker.append( + { + "epoch": epoch, + "step": step, + "value": float(loss_val), + "timestamp": datetime.isoformat(datetime.now()), + } + ) if eval_dataloader is not None: model.eval() @@ -1132,7 +1141,6 @@ def _execute_train_loop( eval_epoch_loss, ) - error.value_check("", len(training_loss_tracker) == num_epochs) return {"loss": training_loss_tracker} @classmethod diff --git a/caikit_nlp/modules/text_generation/text_generation_local.py b/caikit_nlp/modules/text_generation/text_generation_local.py index 705bc39c..63807594 100644 --- a/caikit_nlp/modules/text_generation/text_generation_local.py +++ b/caikit_nlp/modules/text_generation/text_generation_local.py @@ -14,8 +14,9 @@ # Standard -from typing import Optional, Union +from typing import Any, Dict, Optional, Union import gc +import json import os import tempfile @@ -52,6 +53,8 @@ error = error_handler.get(log) +TRAINING_LOSS_LOG_FILENAME = "training_logs.jsonl" + # pylint: disable=too-many-lines,too-many-instance-attributes @module( id="f9181353-4ccf-4572-bd1e-f12bcda26792", @@ -95,6 +98,7 @@ def __init__( sep_token: Optional[str] = None, eos_token: Optional[str] = None, pad_token: Optional[str] = None, + training_metadata: Union[Dict[str, Any], None] = None, ): super().__init__() @@ -106,6 +110,9 @@ def __init__( self._sep_token = sep_token self._eos_token = eos_token self._pad_token = pad_token + self.training_metadata = ( + training_metadata if training_metadata is not None else {} + ) # pylint: disable=duplicate-code def __del__(self): @@ -359,6 +366,8 @@ def train( "dataloader_pin_memory": False, "gradient_accumulation_steps": accumulate_steps, "gradient_checkpointing": True, + "logging_strategy": "steps", + "logging_steps": 1, # logging at every step # NOTE: This is explicitly set to false since it will # negatively impact the performance "full_determinism": False, @@ -395,11 +404,16 @@ def train( # NOTE: torch distributed can hang if run on CPUs, # to avoid that, specially for unit tests, we are only # running below when GPUs are available - torch.distributed.launcher.api.elastic_launch( + training_loss_history = torch.distributed.launcher.api.elastic_launch( launch_config, cls._launch_training )(base_model, training_dataset, training_args, checkpoint_dir) + + # NOTE: We are currently only storing the loss information from + # rank 0, i.e main process. training_loss_history is dictionary containing + # rank of the process as key + training_loss_history = training_loss_history[0] else: - cls._launch_training( + training_loss_history = cls._launch_training( base_model, training_dataset, training_args, checkpoint_dir ) @@ -423,6 +437,7 @@ def train( sep_token=model.tokenizer.sep_token or None, eos_token=model.tokenizer.eos_token or None, pad_token=model.tokenizer.pad_token or None, + training_metadata={"loss": training_loss_history}, ) @classmethod @@ -485,6 +500,24 @@ def save(self, model_path): base_model_dirname=artifacts_dir, ) + training_loss_filename = TRAINING_LOSS_LOG_FILENAME + + saver.update_config({"training_logs": training_loss_filename}) + + # We are currently only saving logs containing loss in jsonl format + if "loss" in self.training_metadata: + loss_log_lines = self.training_metadata.get("loss") + error.type_check("", list, loss_log_lines=loss_log_lines) + with open( + os.path.join(model_path, training_loss_filename), + "w", + encoding="utf-8", + ) as f: + for loss_log in loss_log_lines: + loss_log = {"name": "loss", "data": loss_log} + json.dump(loss_log, f) + f.write("\n") + def run( self, text: str, @@ -596,6 +629,10 @@ def _launch_training( # save tokenizer explicitly base_model.tokenizer.save_pretrained(checkpoint_dir) + # Below will return log history but launch will automatically attach rank to it. + # if started in distributed fashion + return trainer.state.log_history + @staticmethod def infer_max_steps( num_epochs: int, diff --git a/caikit_nlp/resources/pretrained_model/base.py b/caikit_nlp/resources/pretrained_model/base.py index 59048a2f..eba74744 100644 --- a/caikit_nlp/resources/pretrained_model/base.py +++ b/caikit_nlp/resources/pretrained_model/base.py @@ -15,7 +15,7 @@ # Standard from abc import ABC, abstractmethod from collections.abc import Mapping -from typing import Callable, List, Optional, Tuple, Type, Union +from typing import Callable, Dict, List, Optional, Tuple, Type, Union import json import os @@ -41,11 +41,29 @@ # Local from ...data_model import GenerationTrainRecord, PromptOutputModelType from ...toolkit.data_type_utils import get_torch_dtype, str_to_torch_dtype +from ...toolkit.trainer_utils import log_step log = alog.use_channel("HFRBAS") error = error_handler.get(log) +class LoggingTrainer(Trainer): + def log(self, logs: Dict[str, float]) -> None: + """ + Log `logs` on the various objects watching training. + + Subclass and override this method to inject custom behavior. + + Args: + logs (`Dict[str, float]`): + The values to log. + """ + self.state = log_step(self.state, logs) + self.control = self.callback_handler.on_log( + self.args, self.state, self.control, logs + ) + + class PretrainedModelBase(ABC, ModuleBase): """Common abstractions and requirements for pretrained model resources""" @@ -286,7 +304,7 @@ def get_trainer( "eval_dataset": eval_dataset, } - return Trainer(self._model, training_args, **trainer_arguments) + return LoggingTrainer(self._model, training_args, **trainer_arguments) def _get_data_collator(self, **kwargs): """Function to return appropriate data collator based on resource. diff --git a/caikit_nlp/resources/pretrained_model/hf_auto_seq2seq_lm.py b/caikit_nlp/resources/pretrained_model/hf_auto_seq2seq_lm.py index be736996..598b0136 100644 --- a/caikit_nlp/resources/pretrained_model/hf_auto_seq2seq_lm.py +++ b/caikit_nlp/resources/pretrained_model/hf_auto_seq2seq_lm.py @@ -16,7 +16,7 @@ """ # Standard from collections.abc import Mapping -from typing import List, Union +from typing import Dict, List, Union # Third Party from torch.utils.data import IterableDataset @@ -35,6 +35,7 @@ # Local from ...data_model import GenerationTrainRecord, PromptOutputModelType +from ...toolkit.trainer_utils import log_step from ...toolkit.verbalizer_utils import render_verbalizer from .base import PretrainedModelBase @@ -44,6 +45,23 @@ IGNORE_ID = -100 +class LoggingTrainer(Seq2SeqTrainer): + def log(self, logs: Dict[str, float]) -> None: + """ + Log `logs` on the various objects watching training. + + Subclass and override this method to inject custom behavior. + + Args: + logs (`Dict[str, float]`): + The values to log. + """ + self.state = log_step(self.state, logs) + self.control = self.callback_handler.on_log( + self.args, self.state, self.control, logs + ) + + @module( id="6759e891-287b-405b-bd8b-54a4a4d51c25", name="HF Transformers Auto Seq2Seq LM", @@ -110,7 +128,7 @@ def get_trainer( # "generation_max_length": max_target_length, } - return Seq2SeqTrainer(self._model, training_args, **trainer_arguments) + return LoggingTrainer(self._model, training_args, **trainer_arguments) def _get_data_collator(self, **kwargs): """Function to return appropriate data collator based on resource. diff --git a/caikit_nlp/toolkit/trainer_utils.py b/caikit_nlp/toolkit/trainer_utils.py new file mode 100644 index 00000000..736e9d90 --- /dev/null +++ b/caikit_nlp/toolkit/trainer_utils.py @@ -0,0 +1,57 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains toolkit functionality for huggingface Trainer""" +# Standard +from datetime import datetime + +# Third Party +import torch + +# First Party +import alog + +log = alog.use_channel("TRNR_UTILS") + + +def log_step(state, logs): + if state.epoch is not None: + logs["epoch"] = round(state.epoch, 2) + + # Get Rank + if torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + else: + rank = 0 + + if "loss" in logs: + if state.epoch is not None: + logs["epoch"] = round(state.epoch, 2) + + log.debug( + "process rank: {} loss: {} step: {}".format( + rank, float(logs["loss"]), state.global_step + ) + ) + output = { + "epoch": float(logs["epoch"]), + "step": state.global_step, + "value": float(logs["loss"]), + "timestamp": datetime.isoformat(datetime.now()), + } + state.log_history.append(output) + else: + output = {**logs, **{"step": state.global_step}} + state.log_history.append(output) + + return state diff --git a/tests/modules/text_generation/test_peft_prompt_tuning.py b/tests/modules/text_generation/test_peft_prompt_tuning.py index 1d600305..36a10024 100644 --- a/tests/modules/text_generation/test_peft_prompt_tuning.py +++ b/tests/modules/text_generation/test_peft_prompt_tuning.py @@ -69,13 +69,13 @@ def test_save_log_loss_file(causal_lm_dummy_model): """Ensure saving a model saves the log loss file""" with tempfile.TemporaryDirectory() as model_dir: causal_lm_dummy_model.save(model_dir, save_base_model=False) - assert os.path.isfile( - os.path.join( - model_dir, - caikit_nlp.modules.text_generation.peft_prompt_tuning.TRAINING_LOSS_LOG_FILENAME, - ) + file_path = os.path.join( + model_dir, + caikit_nlp.modules.text_generation.peft_prompt_tuning.TRAINING_LOSS_LOG_FILENAME, ) + assert os.path.isfile(file_path) + def test_run_model(causal_lm_dummy_model): """Ensure that we can run a model and get the right type out."""