diff --git a/composer/loggers/mlflow_logger.py b/composer/loggers/mlflow_logger.py index b5426935d9..3e8ff05179 100644 --- a/composer/loggers/mlflow_logger.py +++ b/composer/loggers/mlflow_logger.py @@ -68,6 +68,7 @@ def __init__( self._flush_interval = flush_interval if self._enabled: self.tracking_uri = str(tracking_uri or mlflow.get_tracking_uri()) + mlflow.set_tracking_uri(self.tracking_uri) # Set up MLflow state self._run_id = None if self.experiment_name is None: @@ -112,6 +113,7 @@ def init(self, state: State, logger: Logger) -> None: run_name=self.run_name, ) self._run_id = new_run.info.run_id + mlflow.start_run(run_id=self._run_id) def log_table(self, columns: List[str], rows: List[List[Any]], name: str = 'Table') -> None: if self._enabled: @@ -150,6 +152,23 @@ def log_hyperparameters(self, hyperparameters: Dict[str, Any]): ) self._optimized_mlflow_client.flush(synchronous=False) + def log_model(self, flavor: str, **kwargs): + """Log a model to MLFlow. + + Args: + flavor (str): The MLFlow model flavor to use. Currently only ``'transformers'`` is supported. + **kwargs: Keyword arguments to pass to the MLFlow model logging function. + + Raises: + NotImplementedError: If ``flavor`` is not ``'transformers'``. + """ + if self._enabled: + import mlflow + if flavor == 'transformers': + mlflow.transformers.log_model(**kwargs,) + else: + raise NotImplementedError(f'flavor {flavor} not supported.') + def log_images( self, images: Union[np.ndarray, torch.Tensor, Sequence[Union[np.ndarray, torch.Tensor]]], @@ -176,10 +195,13 @@ def log_images( def post_close(self): if self._enabled: + import mlflow + # We use MlflowClient for run termination because MlflowAutologgingQueueingClient's # run termination relies on scheduling Python futures, which is not supported within # the Python atexit handler in which post_close() is called self._mlflow_client.set_terminated(self._run_id) + mlflow.end_run() def _flush(self): """Test-only method to synchronously flush all queued metrics.""" diff --git a/tests/loggers/test_mlflow_logger.py b/tests/loggers/test_mlflow_logger.py index 9ba79a42f6..ff2b49ff55 100644 --- a/tests/loggers/test_mlflow_logger.py +++ b/tests/loggers/test_mlflow_logger.py @@ -18,9 +18,11 @@ from tests.common.datasets import RandomImageDataset from tests.common.markers import device from tests.common.models import SimpleConvModel +from tests.models.test_hf_model import check_hf_model_equivalence, check_hf_tokenizer_equivalence def _get_latest_mlflow_run(experiment_name, tracking_uri=None): + pytest.importorskip('mlflow') from mlflow import MlflowClient # NB: Convert tracking URI to string because MlflowClient doesn't support non-string @@ -43,9 +45,12 @@ def test_mlflow_experiment_init_unspecified(monkeypatch): This mocks the mlflow library to check that the correct calls are made to set up the experiment """ - import mlflow + mlflow = pytest.importorskip('mlflow') from mlflow import MlflowClient + monkeypatch.setattr(mlflow, 'set_tracking_uri', MagicMock()) + monkeypatch.setattr(mlflow, 'start_run', MagicMock()) + mock_state = MagicMock() mock_state.run_name = 'dummy-run-name' @@ -63,13 +68,17 @@ def test_mlflow_experiment_init_unspecified(monkeypatch): ).info.run_name == unspecified.run_name) -def test_mlflow_experiment_init_specified(): +def test_mlflow_experiment_init_specified(monkeypatch): """ Test that MLFlow experiment is set up correctly when all parameters are specified This mocks the mlflow library to check that the correct calls are made to set up the experiment """ + mlflow = pytest.importorskip('mlflow') from mlflow import MlflowClient + monkeypatch.setattr(mlflow, 'set_tracking_uri', MagicMock()) + monkeypatch.setattr(mlflow, 'start_run', MagicMock()) + mock_state = MagicMock() mock_state.run_name = 'dummy-run-name' # Not used @@ -103,7 +112,7 @@ def test_mlflow_experiment_init_ids(monkeypatch): This mocks the mlflow library to check that the correct calls are made to set up the experiment """ - import mlflow + mlflow = pytest.importorskip('mlflow') monkeypatch.setattr(mlflow, 'set_tracking_uri', MagicMock()) monkeypatch.setattr(mlflow, 'set_experiment', MagicMock()) @@ -123,7 +132,7 @@ def test_mlflow_experiment_init_ids(monkeypatch): assert id_logger.run_name == 'dummy-run-name' # Defaults are set, but we don't use them assert id_logger.experiment_name == 'my-mlflow-experiment' - assert mlflow.set_tracking_uri.call_count == 0 + assert mlflow.set_tracking_uri.call_count == 1 # We call this once in the init assert mlflow.set_experiment.called_with(experiment_id=mlflow_exp_id) assert mlflow.start_run.called_with(run_id=mlflow_run_id) @@ -133,7 +142,7 @@ def test_mlflow_experiment_init_experiment_name(monkeypatch): This mocks the mlflow library to check that the correct calls are made to set up the experiment """ - import mlflow + mlflow = pytest.importorskip('mlflow') monkeypatch.setattr(mlflow, 'set_tracking_uri', MagicMock()) monkeypatch.setattr(mlflow, 'set_experiment', MagicMock()) @@ -151,6 +160,8 @@ def test_mlflow_experiment_init_experiment_name(monkeypatch): assert id_logger.experiment_name == exp_name assert mlflow.set_experiment.called_with(experiment_name=exp_name) + id_logger.post_close() + def test_mlflow_experiment_set_up(tmp_path): """ Test that MLFlow experiment is set up correctly within mlflow @@ -207,6 +218,8 @@ def test_mlflow_experiment_set_up(tmp_path): def test_mlflow_log_table(tmp_path): + pytest.importorskip('mlflow') + mlflow_uri = tmp_path / Path('my-test-mlflow-uri') mlflow_exp_name = 'test-log-table-exp-name' test_mlflow_logger = MLFlowLogger( @@ -245,8 +258,52 @@ def test_mlflow_log_table(tmp_path): assert table['data'] == rows +@pytest.mark.filterwarnings("ignore:.*The 'transformers' MLflow Models integration.*:FutureWarning") +def test_mlflow_log_model(tmp_path, tiny_gpt2_model, tiny_gpt2_tokenizer): + mlflow = pytest.importorskip('mlflow') + + mlflow_uri = tmp_path / Path('my-test-mlflow-uri') + mlflow_exp_name = 'test-log-model-exp-name' + test_mlflow_logger = MLFlowLogger( + tracking_uri=mlflow_uri, + experiment_name=mlflow_exp_name, + ) + + mock_state = MagicMock() + mock_state.run_name = 'dummy-run-name' # this run name should be unused. + mock_logger = MagicMock() + + test_mlflow_logger.init(state=mock_state, logger=mock_logger) + test_mlflow_logger.log_model( + flavor='transformers', + transformers_model={ + 'model': tiny_gpt2_model, + 'tokenizer': tiny_gpt2_tokenizer, + }, + artifact_path='my_model', + metadata={'task': 'llm/v1/completions'}, + task='text-generation', + ) + test_mlflow_logger._flush() + test_mlflow_logger.post_close() + + run = _get_latest_mlflow_run(mlflow_exp_name, tracking_uri=mlflow_uri) + run_info = run.info + run_id = run_info.run_id + experiment_id = run_info.experiment_id + run_file_path = mlflow_uri / Path(experiment_id) / Path(run_id) + + model_directory = run_file_path / Path('artifacts') / Path('my_model') + loaded_model = mlflow.transformers.load_model(model_directory, return_type='components') + + check_hf_model_equivalence(loaded_model['model'], tiny_gpt2_model) + check_hf_tokenizer_equivalence(loaded_model['tokenizer'], tiny_gpt2_tokenizer) + + @device('cpu') def test_mlflow_logging_works(tmp_path, device): + pytest.importorskip('mlflow') + mlflow_uri = tmp_path / Path('my-test-mlflow-uri') experiment_name = 'mlflow_logging_test' test_mlflow_logger = MLFlowLogger( @@ -268,6 +325,7 @@ def test_mlflow_logging_works(tmp_path, device): device=device) trainer.fit() test_mlflow_logger._flush() + test_mlflow_logger.post_close() run = _get_latest_mlflow_run( experiment_name=experiment_name, @@ -300,6 +358,7 @@ def test_mlflow_logging_works(tmp_path, device): @device('cpu') def test_mlflow_log_image_works(tmp_path, device): + pytest.importorskip('mlflow') class ImageLogger(Callback): @@ -335,6 +394,7 @@ def before_forward(self, state: State, logger: Logger): trainer.fit() test_mlflow_logger._flush() + test_mlflow_logger.post_close() run = _get_latest_mlflow_run( experiment_name=experiment_name, diff --git a/tests/models/test_hf_model.py b/tests/models/test_hf_model.py index 2ca6e83bae..00e851227c 100644 --- a/tests/models/test_hf_model.py +++ b/tests/models/test_hf_model.py @@ -265,6 +265,10 @@ def check_hf_tokenizer_equivalence(tokenizer1, tokenizer2): def check_hf_model_equivalence(model1, model2): expected_model_config_dict = model1.config.to_dict() new_model_config_dict = model2.config.to_dict() + + # _name_or_path is different depending on where the model was loaded from, so don't compare it + expected_model_config_dict.pop('_name_or_path') + new_model_config_dict.pop('_name_or_path') assert expected_model_config_dict == new_model_config_dict assert sum(p.numel() for p in model1.parameters()) == sum(p.numel() for p in model2.parameters()) assert all(type(module1) == type(module2) for module1, module2 in zip(model1.modules(), model2.modules()))