Merge pull request caikit#208 from gkumbhat/add_ft_loss_logging
Add ft loss logging
gkumbhat authored Sep 27, 2023
2 parents fcf3df9 + b1f4d81 commit ef283e4
Showing 6 changed files with 159 additions and 21 deletions.
26 changes: 17 additions & 9 deletions caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -1042,7 +1042,10 @@ def _execute_train_loop(

training_loss_tracker = []

step_count = 1

for epoch in range(num_epochs):
step_loss_log = {}
model.train()
total_loss = 0
tqdm_loader = tqdm(train_dataloader, disable=silence_progress_bars)
@@ -1060,21 +1063,27 @@ def _execute_train_loop(
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
step_loss_log[step_count] = loss
step_count += 1
except torch.cuda.OutOfMemoryError:
error(
"<NLP07175292E>",
MemoryError("Not enough memory available for training!"),
)

log.info("<NLP46114010I>", {"loss": float(loss), "epoch": epoch})
# Below is added to be propagated and stored as training_metadata
training_loss_tracker.append(
{
"epoch": epoch,
"value": float(loss),
"timestamp": datetime.isoformat(datetime.now()),
}
)

for step, loss_val in step_loss_log.items():

# The entries below are propagated and stored as training_metadata
training_loss_tracker.append(
{
"epoch": epoch,
"step": step,
"value": float(loss_val),
"timestamp": datetime.isoformat(datetime.now()),
}
)

if eval_dataloader is not None:
model.eval()
@@ -1132,7 +1141,6 @@ def _execute_train_loop(
eval_epoch_loss,
)

error.value_check("<NLP66129758E>", len(training_loss_tracker) == num_epochs)
return {"loss": training_loss_tracker}

@classmethod
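For context, a minimal standalone sketch of what the per-step tracking above accumulates; the epoch count and loss values are made up, this is not the module's real training loop:

# Sketch only: illustrates the shape of training_loss_tracker after the change,
# using dummy losses instead of real model outputs.
from datetime import datetime

training_loss_tracker = []
step_count = 1
for epoch in range(2):  # pretend num_epochs = 2
    step_loss_log = {}
    for loss in (0.91, 0.74, 0.63):  # pretend per-batch losses
        step_loss_log[step_count] = loss
        step_count += 1
    for step, loss_val in step_loss_log.items():
        training_loss_tracker.append(
            {
                "epoch": epoch,
                "step": step,
                "value": float(loss_val),
                "timestamp": datetime.isoformat(datetime.now()),
            }
        )

# Each record now carries both the epoch and the global step, e.g.:
# {"epoch": 0, "step": 2, "value": 0.74, "timestamp": "2023-09-27T..."}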
43 changes: 40 additions & 3 deletions caikit_nlp/modules/text_generation/text_generation_local.py
@@ -14,8 +14,9 @@


# Standard
from typing import Optional, Union
from typing import Any, Dict, Optional, Union
import gc
import json
import os
import tempfile

@@ -52,6 +53,8 @@
error = error_handler.get(log)


TRAINING_LOSS_LOG_FILENAME = "training_logs.jsonl"

# pylint: disable=too-many-lines,too-many-instance-attributes
@module(
id="f9181353-4ccf-4572-bd1e-f12bcda26792",
@@ -95,6 +98,7 @@ def __init__(
sep_token: Optional[str] = None,
eos_token: Optional[str] = None,
pad_token: Optional[str] = None,
training_metadata: Union[Dict[str, Any], None] = None,
):
super().__init__()

@@ -106,6 +110,9 @@ def __init__(
self._sep_token = sep_token
self._eos_token = eos_token
self._pad_token = pad_token
self.training_metadata = (
training_metadata if training_metadata is not None else {}
)

# pylint: disable=duplicate-code
def __del__(self):
@@ -359,6 +366,8 @@ def train(
"dataloader_pin_memory": False,
"gradient_accumulation_steps": accumulate_steps,
"gradient_checkpointing": True,
"logging_strategy": "steps",
"logging_steps": 1, # logging at every step
# NOTE: This is explicitly set to false since it will
# negatively impact the performance
"full_determinism": False,
@@ -395,11 +404,16 @@
# NOTE: torch distributed can hang when run on CPUs;
# to avoid that, especially in unit tests, the code below
# only runs when GPUs are available
torch.distributed.launcher.api.elastic_launch(
training_loss_history = torch.distributed.launcher.api.elastic_launch(
launch_config, cls._launch_training
)(base_model, training_dataset, training_args, checkpoint_dir)

# NOTE: We currently only store the loss information from rank 0,
# i.e. the main process. training_loss_history is a dictionary keyed
# by process rank.
training_loss_history = training_loss_history[0]
else:
cls._launch_training(
training_loss_history = cls._launch_training(
base_model, training_dataset, training_args, checkpoint_dir
)

@@ -423,6 +437,7 @@
sep_token=model.tokenizer.sep_token or None,
eos_token=model.tokenizer.eos_token or None,
pad_token=model.tokenizer.pad_token or None,
training_metadata={"loss": training_loss_history},
)

@classmethod
@@ -485,6 +500,24 @@ def save(self, model_path):
base_model_dirname=artifacts_dir,
)

training_loss_filename = TRAINING_LOSS_LOG_FILENAME

saver.update_config({"training_logs": training_loss_filename})

# We currently only save the logs containing the loss, in JSONL format
if "loss" in self.training_metadata:
loss_log_lines = self.training_metadata.get("loss")
error.type_check("<NLP60269855E>", list, loss_log_lines=loss_log_lines)
with open(
os.path.join(model_path, training_loss_filename),
"w",
encoding="utf-8",
) as f:
for loss_log in loss_log_lines:
loss_log = {"name": "loss", "data": loss_log}
json.dump(loss_log, f)
f.write("\n")

def run(
self,
text: str,
@@ -596,6 +629,10 @@ def _launch_training(
# save tokenizer explicitly
base_model.tokenizer.save_pretrained(checkpoint_dir)

# The log history is returned below; when training is launched in a
# distributed fashion, elastic_launch keys the returned history by process rank.
return trainer.state.log_history

@staticmethod
def infer_max_steps(
num_epochs: int,
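To make the new save() behavior concrete, here is a small self-contained sketch of the JSONL serialization it performs, using made-up loss records (the filename constant mirrors TRAINING_LOSS_LOG_FILENAME above):

# Sketch only: mirrors the loss-log serialization in save() with dummy records.
import json
import os
import tempfile

TRAINING_LOSS_LOG_FILENAME = "training_logs.jsonl"

loss_log_lines = [
    {"epoch": 0.5, "step": 1, "value": 0.91, "timestamp": "2023-09-27T12:00:00"},
    {"epoch": 1.0, "step": 2, "value": 0.74, "timestamp": "2023-09-27T12:00:05"},
]

with tempfile.TemporaryDirectory() as model_path:
    with open(
        os.path.join(model_path, TRAINING_LOSS_LOG_FILENAME), "w", encoding="utf-8"
    ) as f:
        for loss_log in loss_log_lines:
            # Each record is wrapped as {"name": "loss", "data": ...} and written
            # as one JSON object per line, producing a JSONL file.
            json.dump({"name": "loss", "data": loss_log}, f)
            f.write("\n")

# Resulting training_logs.jsonl:
# {"name": "loss", "data": {"epoch": 0.5, "step": 1, "value": 0.91, ...}}
# {"name": "loss", "data": {"epoch": 1.0, "step": 2, "value": 0.74, ...}}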
22 changes: 20 additions & 2 deletions caikit_nlp/resources/pretrained_model/base.py
@@ -15,7 +15,7 @@
# Standard
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Callable, List, Optional, Tuple, Type, Union
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
import json
import os

@@ -41,11 +41,29 @@
# Local
from ...data_model import GenerationTrainRecord, PromptOutputModelType
from ...toolkit.data_type_utils import get_torch_dtype, str_to_torch_dtype
from ...toolkit.trainer_utils import log_step

log = alog.use_channel("HFRBAS")
error = error_handler.get(log)


class LoggingTrainer(Trainer):
def log(self, logs: Dict[str, float]) -> None:
"""
Log `logs` on the various objects watching training.
Subclass and override this method to inject custom behavior.
Args:
logs (`Dict[str, float]`):
The values to log.
"""
self.state = log_step(self.state, logs)
self.control = self.callback_handler.on_log(
self.args, self.state, self.control, logs
)


class PretrainedModelBase(ABC, ModuleBase):
"""Common abstractions and requirements for pretrained model resources"""

@@ -286,7 +304,7 @@ def get_trainer(
"eval_dataset": eval_dataset,
}

return Trainer(self._model, training_args, **trainer_arguments)
return LoggingTrainer(self._model, training_args, **trainer_arguments)

def _get_data_collator(self, **kwargs):
"""Function to return appropriate data collator based on resource.
22 changes: 20 additions & 2 deletions caikit_nlp/resources/pretrained_model/hf_auto_seq2seq_lm.py
@@ -16,7 +16,7 @@
"""
# Standard
from collections.abc import Mapping
from typing import List, Union
from typing import Dict, List, Union

# Third Party
from torch.utils.data import IterableDataset
@@ -35,6 +35,7 @@

# Local
from ...data_model import GenerationTrainRecord, PromptOutputModelType
from ...toolkit.trainer_utils import log_step
from ...toolkit.verbalizer_utils import render_verbalizer
from .base import PretrainedModelBase

@@ -44,6 +45,23 @@
IGNORE_ID = -100


class LoggingTrainer(Seq2SeqTrainer):
def log(self, logs: Dict[str, float]) -> None:
"""
Log `logs` on the various objects watching training.
Subclass and override this method to inject custom behavior.
Args:
logs (`Dict[str, float]`):
The values to log.
"""
self.state = log_step(self.state, logs)
self.control = self.callback_handler.on_log(
self.args, self.state, self.control, logs
)


@module(
id="6759e891-287b-405b-bd8b-54a4a4d51c25",
name="HF Transformers Auto Seq2Seq LM",
@@ -110,7 +128,7 @@ def get_trainer(
# "generation_max_length": max_target_length,
}

return Seq2SeqTrainer(self._model, training_args, **trainer_arguments)
return LoggingTrainer(self._model, training_args, **trainer_arguments)

def _get_data_collator(self, **kwargs):
"""Function to return appropriate data collator based on resource.
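A note on the design: the log() override is duplicated for Trainer and Seq2SeqTrainer. A hypothetical alternative (not what this PR does) would be a small shared mixin; a sketch under that assumption, with invented class names:

# Hypothetical sketch only: one mixin instead of two copies of LoggingTrainer.
from typing import Dict

from transformers import Seq2SeqTrainer, Trainer

from caikit_nlp.toolkit.trainer_utils import log_step


class StepLoggingMixin:
    def log(self, logs: Dict[str, float]) -> None:
        # Same body as the overrides above: record the step in state.log_history,
        # then fan the logs out to the registered callbacks.
        self.state = log_step(self.state, logs)
        self.control = self.callback_handler.on_log(
            self.args, self.state, self.control, logs
        )


class LoggingTrainer(StepLoggingMixin, Trainer):
    pass


class LoggingSeq2SeqTrainer(StepLoggingMixin, Seq2SeqTrainer):
    pass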
57 changes: 57 additions & 0 deletions caikit_nlp/toolkit/trainer_utils.py
@@ -0,0 +1,57 @@
# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains toolkit functionality for huggingface Trainer"""
# Standard
from datetime import datetime

# Third Party
import torch

# First Party
import alog

log = alog.use_channel("TRNR_UTILS")


def log_step(state, logs):
if state.epoch is not None:
logs["epoch"] = round(state.epoch, 2)

# Get Rank
if torch.distributed.is_initialized():
rank = torch.distributed.get_rank()
else:
rank = 0

if "loss" in logs:
if state.epoch is not None:
logs["epoch"] = round(state.epoch, 2)

log.debug(
"process rank: {} loss: {} step: {}".format(
rank, float(logs["loss"]), state.global_step
)
)
output = {
"epoch": float(logs["epoch"]),
"step": state.global_step,
"value": float(logs["loss"]),
"timestamp": datetime.isoformat(datetime.now()),
}
state.log_history.append(output)
else:
output = {**logs, **{"step": state.global_step}}
state.log_history.append(output)

return state
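
A minimal usage sketch of log_step, driving it with a bare TrainerState instead of a full training run (the epoch, step, and loss values are illustrative only):

# Sketch only: exercises log_step with a hand-built TrainerState.
from transformers import TrainerState

from caikit_nlp.toolkit.trainer_utils import log_step

state = TrainerState()
state.epoch = 0.5
state.global_step = 10

state = log_step(state, {"loss": 0.42})
print(state.log_history[-1])
# -> {"epoch": 0.5, "step": 10, "value": 0.42, "timestamp": "2023-..."}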
10 changes: 5 additions & 5 deletions tests/modules/text_generation/test_peft_prompt_tuning.py
@@ -69,13 +69,13 @@ def test_save_log_loss_file(causal_lm_dummy_model):
"""Ensure saving a model saves the log loss file"""
with tempfile.TemporaryDirectory() as model_dir:
causal_lm_dummy_model.save(model_dir, save_base_model=False)
assert os.path.isfile(
os.path.join(
model_dir,
caikit_nlp.modules.text_generation.peft_prompt_tuning.TRAINING_LOSS_LOG_FILENAME,
)
file_path = os.path.join(
model_dir,
caikit_nlp.modules.text_generation.peft_prompt_tuning.TRAINING_LOSS_LOG_FILENAME,
)

assert os.path.isfile(file_path)


def test_run_model(causal_lm_dummy_model):
"""Ensure that we can run a model and get the right type out."""
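As a hypothetical complement to the test above (not part of this PR's test suite), the saved JSONL log can be read back to reconstruct the loss curve:

# Hypothetical helper, assuming the training_logs.jsonl layout shown earlier.
import json
import os


def read_loss_curve(model_dir, filename="training_logs.jsonl"):
    """Return the list of loss values recorded in a saved model directory."""
    losses = []
    with open(os.path.join(model_dir, filename), encoding="utf-8") as f:
        for line in f:
            record = json.loads(line)
            if record.get("name") == "loss":
                losses.append(record["data"]["value"])
    return losses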
