diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 0c5f953aa7..16b1d3e8c9 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -234,6 +234,8 @@ def _save_checkpoint(self, state: State, logger: Logger): log.debug(f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{registered_model_name}') local_save_path = str( Path(temp_save_dir) / f'mlflow_save_{i}') + import mlflow + mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: "" mlflow_logger.save_model( flavor='transformers', transformers_model=components,