ClearMLCallback enhancements: support multiple runs and handle logging better (#28559)

* add clearml tracker

* support multiple train runs

* remove bad code

* add UI entries for config/hparams overrides

* handle models in different tasks

* run ruff format

* tidy code based on code review

---------

Co-authored-by: Eugen Ajechiloae <[email protected]>
2 people authored and Ita Zaporozhets committed May 14, 2024
1 parent 838d489 commit 2a49074
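
For context, a minimal sketch of the multi-run workflow this commit enables. It is illustrative only: `model_a`, `model_b`, and `train_ds` are assumed to be defined elsewhere, and the project/task names are placeholders; only `Task.init`, `Trainer`, and `TrainingArguments` are actual APIs.

# Sketch: two training runs logged into one ClearML task (assumes clearml is
# installed, so Trainer adds ClearMLCallback automatically via report_to="all").
import os

from clearml import Task
from transformers import Trainer, TrainingArguments

# Initialize the task externally, as the `_model_config_description_note` added
# below advises; successive runs then attach to it instead of opening new tasks.
task = Task.init(project_name="HuggingFace Transformers", task_name="Trainer")

os.environ["CLEARML_LOG_MODEL"] = "TRUE"  # upload checkpoints as OutputModels

args = TrainingArguments(output_dir="out", save_total_limit=2)

Trainer(model=model_a, args=args, train_dataset=train_ds).train()
# First run: scalars titled "train"/"eval", hparams under "Transformers".

Trainer(model=model_b, args=args, train_dataset=train_ds).train()
# Second run: the class-level counter adds a suffix, so the same task gets
# "train_2"/"eval_2" scalar titles and a "Transformers_2" hparams section.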
Showing 1 changed file with 153 additions and 21 deletions.
src/transformers/integrations/integration_utils.py
@@ -24,7 +24,7 @@
 import shutil
 import sys
 import tempfile
-from dataclasses import asdict
+from dataclasses import asdict, fields
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union

@@ -1438,6 +1438,24 @@ class ClearMLCallback(TrainerCallback):
         Whether to log models as artifacts during training.
     """

+    log_suffix = ""
+
+    _hparams_section = "Transformers"
+    _model_config_section = "Model Configuration"
+    _ignore_hparams_overrides = "_ignore_hparams_ui_overrides_"
+    _ignoge_model_config_overrides = "_ignore_model_config_ui_overrides_"
+    _model_config_description = "The configuration of model number {}."
+    _model_config_description_note = (
+        "Note that, when cloning this task and running it remotely,"
+        " the configuration might be applied to another model instead of this one."
+        " To avoid this, initialize the task externally by calling `Task.init`"
+        " before the `ClearMLCallback` is instantiated."
+    )
+    _train_run_counter = 0
+    _model_connect_counter = 0
+    _task_created_in_callback = False
+    _should_close_on_train_end = None
+
     def __init__(self):
         if is_clearml_available():
             import clearml
@@ -1447,25 +1465,38 @@ def __init__(self):
             raise RuntimeError("ClearMLCallback requires 'clearml' to be installed. Run `pip install clearml`.")

         self._initialized = False
-        self._initialized_externally = False
         self._clearml_task = None

-        self._log_model = os.getenv("CLEARML_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})
+        self._log_model = False
+        self._checkpoints_saved = []

     def setup(self, args, state, model, tokenizer, **kwargs):
         if self._clearml is None:
             return
         if self._initialized:
             return
+        ClearMLCallback._train_run_counter += 1
+        ClearMLCallback._model_connect_counter += 1
+        ClearMLCallback.log_suffix = (
+            "" if ClearMLCallback._train_run_counter == 1 else "_" + str(ClearMLCallback._train_run_counter)
+        )
         if state.is_world_process_zero:
             logger.info("Automatic ClearML logging enabled.")
             if self._clearml_task is None:
+                if ClearMLCallback._should_close_on_train_end is None:
+                    if not self._clearml.Task.running_locally() or self._clearml.Task.current_task():
+                        ClearMLCallback._should_close_on_train_end = False
+                    else:
+                        ClearMLCallback._should_close_on_train_end = True
+
                 # This might happen when running inside of a pipeline, where the task is already initialized
                 # from outside of Hugging Face
-                if self._clearml.Task.current_task():
+                if self._clearml.Task.running_locally() and self._clearml.Task.current_task():
                     self._clearml_task = self._clearml.Task.current_task()
-                    self._initialized = True
-                    self._initialized_externally = True
+                    self._log_model = os.getenv(
+                        "CLEARML_LOG_MODEL",
+                        "FALSE" if not ClearMLCallback._task_created_in_callback else "TRUE",
+                    ).upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})
                     logger.info("External ClearML Task has been connected.")
                 else:
                     self._clearml_task = self._clearml.Task.init(
@@ -1474,27 +1505,83 @@ def setup(self, args, state, model, tokenizer, **kwargs):
                     auto_connect_frameworks={"tensorboard": False, "pytorch": False},
                     output_uri=True,
                 )
-                self._initialized = True
+                self._log_model = os.getenv("CLEARML_LOG_MODEL", "TRUE").upper() in ENV_VARS_TRUE_VALUES.union(
+                    {"TRUE"}
+                )
+                ClearMLCallback._task_created_in_callback = True
                 logger.info("ClearML Task has been initialized.")
+            self._initialized = True

+            suffixed_hparams_section = ClearMLCallback._hparams_section + ClearMLCallback.log_suffix
+            ignore_hparams_config_section = suffixed_hparams_section + "/" + ClearMLCallback._ignore_hparams_overrides
+            if self._clearml.Task.running_locally():
+                self._copy_training_args_as_hparams(args, suffixed_hparams_section)
+                self._clearml_task.set_parameter(
+                    name=ignore_hparams_config_section,
+                    value=True,
+                    value_type=bool,
+                    description=(
+                        "If True, ignore Transformers hyperparameters overrides done in the UI/backend "
+                        + "when running remotely. Otherwise, the overrides will be applied when running remotely"
+                    ),
+                )
+            elif not self._clearml_task.get_parameter(ignore_hparams_config_section, default=True, cast=True):
+                self._clearml_task.connect(args, suffixed_hparams_section)
+            else:
+                self._copy_training_args_as_hparams(
+                    args, ClearMLCallback._hparams_section + ClearMLCallback.log_suffix
+                )
+
-            self._clearml_task.connect(args, "Args")
-            if hasattr(model, "config") and model.config is not None:
-                self._clearml_task.connect(model.config, "Model Configuration")
+            if getattr(model, "config", None) is not None:
+                ignore_model_config_section = (
+                    suffixed_hparams_section + "/" + ClearMLCallback._ignoge_model_config_overrides
+                )
+                configuration_object_description = ClearMLCallback._model_config_description.format(
+                    ClearMLCallback._model_connect_counter
+                )
+                if ClearMLCallback._model_connect_counter != ClearMLCallback._train_run_counter:
+                    configuration_object_description += " " + ClearMLCallback._model_config_description_note
+                if self._clearml.Task.running_locally():
+                    self._clearml_task.set_parameter(
+                        name=ignore_model_config_section,
+                        value=True,
+                        value_type=bool,
+                        description=(
+                            "If True, ignore Transformers model configuration overrides done in the UI/backend "
+                            + "when running remotely. Otherwise, the overrides will be applied when running remotely"
+                        ),
+                    )
+                    self._clearml_task.set_configuration_object(
+                        name=ClearMLCallback._model_config_section + ClearMLCallback.log_suffix,
+                        config_dict=model.config.to_dict(),
+                        description=configuration_object_description,
+                    )
+                elif not self._clearml_task.get_parameter(ignore_model_config_section, default=True, cast=True):
+                    model.config = model.config.from_dict(
+                        self._clearml_task.get_configuration_object_as_dict(
+                            ClearMLCallback._model_config_section + ClearMLCallback.log_suffix
+                        )
+                    )
+                else:
+                    self._clearml_task.set_configuration_object(
+                        name=ClearMLCallback._model_config_section + ClearMLCallback.log_suffix,
+                        config_dict=model.config.to_dict(),
+                        description=configuration_object_description,
+                    )

     def on_train_begin(self, args, state, control, model=None, tokenizer=None, **kwargs):
         if self._clearml is None:
             return
+        self._checkpoints_saved = []
         if state.is_hyper_param_search:
             self._initialized = False
         if not self._initialized:
             self.setup(args, state, model, tokenizer, **kwargs)

-    def on_train_end(self, args, state, control, model=None, tokenizer=None, metrics=None, logs=None, **kwargs):
-        if self._clearml is None:
-            return
-        if self._clearml_task and state.is_world_process_zero and not self._initialized_externally:
-            # Close ClearML Task at the end of training
+    def on_train_end(self, args, state, control, **kwargs):
+        if ClearMLCallback._should_close_on_train_end:
             self._clearml_task.close()
+            ClearMLCallback._train_run_counter = 0

     def on_log(self, args, state, control, model=None, tokenizer=None, logs=None, **kwargs):
         if self._clearml is None:
@@ -1517,18 +1604,29 @@ def on_log(self, args, state, control, model=None, tokenizer=None, logs=None, **kwargs):
             for k, v in logs.items():
                 if isinstance(v, (int, float)):
                     if k in single_value_scalars:
-                        self._clearml_task.get_logger().report_single_value(name=k, value=v)
+                        self._clearml_task.get_logger().report_single_value(
+                            name=k + ClearMLCallback.log_suffix, value=v
+                        )
                     elif k.startswith(eval_prefix):
                         self._clearml_task.get_logger().report_scalar(
-                            title=k[eval_prefix_len:], series="eval", value=v, iteration=state.global_step
+                            title="eval" + ClearMLCallback.log_suffix,
+                            series=k[eval_prefix_len:],
+                            value=v,
+                            iteration=state.global_step,
                         )
                     elif k.startswith(test_prefix):
                         self._clearml_task.get_logger().report_scalar(
-                            title=k[test_prefix_len:], series="test", value=v, iteration=state.global_step
+                            title="test" + ClearMLCallback.log_suffix,
+                            series=k[test_prefix_len:],
+                            value=v,
+                            iteration=state.global_step,
                        )
                     else:
                         self._clearml_task.get_logger().report_scalar(
-                            title=k, series="train", value=v, iteration=state.global_step
+                            title="train" + ClearMLCallback.log_suffix,
+                            series=k,
+                            value=v,
+                            iteration=state.global_step,
                         )
                 else:
                     logger.warning(
@@ -1542,8 +1640,42 @@ def on_save(self, args, state, control, **kwargs):
         if self._log_model and self._clearml_task and state.is_world_process_zero:
             ckpt_dir = f"checkpoint-{state.global_step}"
             artifact_path = os.path.join(args.output_dir, ckpt_dir)
-            logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. This may take time.")
-            self._clearml_task.update_output_model(artifact_path, iteration=state.global_step, auto_delete_file=False)
+            name = ckpt_dir + ClearMLCallback.log_suffix
+            logger.info(f"Logging checkpoint artifact `{name}`. This may take some time.")
+            output_model = self._clearml.OutputModel(task=self._clearml_task, name=name)
+            output_model.connect(task=self._clearml_task, name=name)
+            output_model.update_weights_package(
+                weights_path=artifact_path,
+                target_filename=ckpt_dir,
+                iteration=state.global_step,
+                auto_delete_file=False,
+            )
+            self._checkpoints_saved.append(output_model)
+            while args.save_total_limit and args.save_total_limit < len(self._checkpoints_saved):
+                try:
+                    self._clearml.model.Model.remove(
+                        self._checkpoints_saved[0],
+                        delete_weights_file=True,
+                        force=True,
+                        raise_on_errors=True,
+                    )
+                except Exception as e:
+                    logger.warning(
+                        "Could not remove checkpoint `{}` after going over the `save_total_limit`. Error is: {}".format(
+                            self._checkpoints_saved[0].name, e
+                        )
+                    )
+                    break
+                self._checkpoints_saved = self._checkpoints_saved[1:]
+
+    def _copy_training_args_as_hparams(self, training_args, prefix):
+        as_dict = {
+            field.name: getattr(training_args, field.name)
+            for field in fields(training_args)
+            if field.init and not field.name.endswith("_token")
+        }
+        flat_dict = {str(k): v for k, v in self._clearml.utilities.proxy_object.flatten_dictionary(as_dict).items()}
+        self._clearml_task._arguments.copy_from_dict(flat_dict, prefix=prefix)


 class FlyteCallback(TrainerCallback):
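
A hedged sketch of the UI/backend-override path added above, from the user side (project, task, and queue names are illustrative; the parameter keys mirror the new class constants, with no suffix for a first run):

# Clone the recorded task, allow overrides, edit a flattened hyperparameter,
# and enqueue it; setup() will then `connect()` the args so the edit applies.
from clearml import Task

template = Task.get_task(project_name="HuggingFace Transformers", task_name="Trainer")
cloned = Task.clone(source_task=template)
cloned.set_parameter("Transformers/_ignore_hparams_ui_overrides_", False)
cloned.set_parameter("Transformers/learning_rate", 1e-4)  # applied remotely
Task.enqueue(cloned, queue_name="default")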
