diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index 4f7cf3632fe549..a09116552c8e34 100755 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -1218,6 +1218,8 @@ def setup(self, args, state, model): and other parameters are ignored. - **MLFLOW_FLATTEN_PARAMS** (`str`, *optional*, defaults to `False`): Whether to flatten the parameters dictionary before logging. + - **MLFLOW_MAX_LOG_PARAMS** (`int`, *optional*): + Set the maximum number of parameters to log in the run. """ self._log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in ENV_VARS_TRUE_VALUES self._nested_run = os.getenv("MLFLOW_NESTED_RUN", "FALSE").upper() in ENV_VARS_TRUE_VALUES @@ -1225,6 +1227,7 @@ def setup(self, args, state, model): self._experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None) self._flatten_params = os.getenv("MLFLOW_FLATTEN_PARAMS", "FALSE").upper() in ENV_VARS_TRUE_VALUES self._run_id = os.getenv("MLFLOW_RUN_ID", None) + self._max_log_params = os.getenv("MLFLOW_MAX_LOG_PARAMS", None) # "synchronous" flag is only available with mlflow version >= 2.8.0 # https://github.com/mlflow/mlflow/pull/9705 @@ -1273,6 +1276,13 @@ def setup(self, args, state, model): del combined_dict[name] # MLflow cannot log more than 100 values in one go, so we have to split it combined_dict_items = list(combined_dict.items()) + if self._max_log_params and self._max_log_params.isdigit(): + max_log_params = int(self._max_log_params) + if max_log_params < len(combined_dict_items): + logger.debug( + f"Reducing the number of parameters to log from {len(combined_dict_items)} to {max_log_params}." + ) + combined_dict_items = combined_dict_items[:max_log_params] for i in range(0, len(combined_dict_items), self._MAX_PARAMS_TAGS_PER_BATCH): if self._async_log: self._ml_flow.log_params(