Skip to content

Commit

Permalink
Add text support to the Trainer's TensorBoard integration (#34418)
Browse files Browse the repository at this point in the history
* feat: add text support to TensorBoardCallback

* feat: ignore long strings in trainer progress

* docs: add docstring for max_str_len

* style: remove trailing whitespace

---------

Co-authored-by: Marc Sun <[email protected]>
  • Loading branch information
JacobLinCool and SunMarc authored Nov 4, 2024
1 parent 34927b0 commit 48831b7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/transformers/integrations/integration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,8 @@ def on_log(self, args, state, control, logs=None, **kwargs):
for k, v in logs.items():
if isinstance(v, (int, float)):
self.tb_writer.add_scalar(k, v, state.global_step)
elif isinstance(v, str):
self.tb_writer.add_text(k, v, state.global_step)
else:
logger.warning(
"Trainer is attempting to log a value of "
Expand Down
20 changes: 18 additions & 2 deletions src/transformers/trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,11 +589,21 @@ def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: Tr
class ProgressCallback(TrainerCallback):
"""
A [`TrainerCallback`] that displays the progress of training or evaluation.
You can modify `max_str_len` to control how long strings are truncated when logging.
"""

def __init__(self):
def __init__(self, max_str_len: int = 100):
    """
    Build the progress callback.

    Args:
        max_str_len (`int`):
            Maximum length of strings to display in logs.
            Longer strings will be truncated with a message.
    """
    # Progress bars are created lazily in on_train_begin / on_prediction_step.
    self.training_bar = None
    self.prediction_bar = None
    # Threshold above which string log values are replaced by a placeholder in on_log.
    self.max_str_len = max_str_len

def on_train_begin(self, args, state, control, **kwargs):
if state.is_world_process_zero:
Expand Down Expand Up @@ -631,7 +641,13 @@ def on_log(self, args, state, control, logs=None, **kwargs):
# but avoid doing any value pickling.
shallow_logs = {}
for k, v in logs.items():
shallow_logs[k] = v
if isinstance(v, str) and len(v) > self.max_str_len:
shallow_logs[k] = (
f"[String too long to display, length: {len(v)} > {self.max_str_len}. "
"Consider increasing `max_str_len` if needed.]"
)
else:
shallow_logs[k] = v
_ = shallow_logs.pop("total_flos", None)
# round numbers so that it looks better in console
if "epoch" in shallow_logs:
Expand Down

0 comments on commit 48831b7

Please sign in to comment.