Skip to content

Commit

Permalink
Feature: Add MLFLOW_MAX_LOG_PARAMS to MLflowCallback (huggingface…
Browse files Browse the repository at this point in the history
  • Loading branch information
cecheta authored and BernardZach committed Dec 6, 2024
1 parent a790e3a commit e1a6b5f
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/transformers/integrations/integration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1218,13 +1218,16 @@ def setup(self, args, state, model):
and other parameters are ignored.
- **MLFLOW_FLATTEN_PARAMS** (`str`, *optional*, defaults to `False`):
Whether to flatten the parameters dictionary before logging.
- **MLFLOW_MAX_LOG_PARAMS** (`int`, *optional*):
Set the maximum number of parameters to log in the run.
"""
self._log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
self._nested_run = os.getenv("MLFLOW_NESTED_RUN", "FALSE").upper() in ENV_VARS_TRUE_VALUES
self._tracking_uri = os.getenv("MLFLOW_TRACKING_URI", None)
self._experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None)
self._flatten_params = os.getenv("MLFLOW_FLATTEN_PARAMS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
self._run_id = os.getenv("MLFLOW_RUN_ID", None)
self._max_log_params = os.getenv("MLFLOW_MAX_LOG_PARAMS", None)

# "synchronous" flag is only available with mlflow version >= 2.8.0
# https://github.com/mlflow/mlflow/pull/9705
Expand Down Expand Up @@ -1273,6 +1276,13 @@ def setup(self, args, state, model):
del combined_dict[name]
# MLflow cannot log more than 100 values in one go, so we have to split it
combined_dict_items = list(combined_dict.items())
if self._max_log_params and self._max_log_params.isdigit():
max_log_params = int(self._max_log_params)
if max_log_params < len(combined_dict_items):
logger.debug(
f"Reducing the number of parameters to log from {len(combined_dict_items)} to {max_log_params}."
)
combined_dict_items = combined_dict_items[:max_log_params]
for i in range(0, len(combined_dict_items), self._MAX_PARAMS_TAGS_PER_BATCH):
if self._async_log:
self._ml_flow.log_params(
Expand Down

0 comments on commit e1a6b5f

Please sign in to comment.