From 48831b7d1189f16a1fca9d4dad723bf4e828c041 Mon Sep 17 00:00:00 2001 From: JacobLinCool Date: Tue, 5 Nov 2024 00:36:27 +0800 Subject: [PATCH] Add text support to the Trainer's TensorBoard integration (#34418) * feat: add text support to TensorBoardCallback * feat: ignore long strings in trainer progress * docs: add docstring for max_str_len * style: remove trailing whitespace --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- .../integrations/integration_utils.py | 2 ++ src/transformers/trainer_callback.py | 20 +++++++++++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index a09116552c8e34..be9a4aff3c7e7f 100755 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -697,6 +697,8 @@ def on_log(self, args, state, control, logs=None, **kwargs): for k, v in logs.items(): if isinstance(v, (int, float)): self.tb_writer.add_scalar(k, v, state.global_step) + elif isinstance(v, str): + self.tb_writer.add_text(k, v, state.global_step) else: logger.warning( "Trainer is attempting to log a value of " diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py index ce9f2a26732c2e..cf9a83aa188a30 100644 --- a/src/transformers/trainer_callback.py +++ b/src/transformers/trainer_callback.py @@ -589,11 +589,21 @@ def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: Tr class ProgressCallback(TrainerCallback): """ A [`TrainerCallback`] that displays the progress of training or evaluation. + You can modify `max_str_len` to control how long strings are truncated when logging. """ - def __init__(self): + def __init__(self, max_str_len: int = 100): + """ + Initialize the callback with optional max_str_len parameter to control string truncation length.
+ + Args: + max_str_len (`int`): + Maximum length of strings to display in logs. + Longer strings will be truncated with a message. + """ self.training_bar = None self.prediction_bar = None + self.max_str_len = max_str_len def on_train_begin(self, args, state, control, **kwargs): if state.is_world_process_zero: @@ -631,7 +641,13 @@ def on_log(self, args, state, control, logs=None, **kwargs): # but avoid doing any value pickling. shallow_logs = {} for k, v in logs.items(): - shallow_logs[k] = v + if isinstance(v, str) and len(v) > self.max_str_len: + shallow_logs[k] = ( + f"[String too long to display, length: {len(v)} > {self.max_str_len}. " + "Consider increasing `max_str_len` if needed.]" + ) + else: + shallow_logs[k] = v _ = shallow_logs.pop("total_flos", None) # round numbers so that it looks better in console if "epoch" in shallow_logs: