Skip to content

Commit

Permalink
Add log_model to MLFlowLogger (mosaicml#2541)
Browse files Browse the repository at this point in the history
* tie module run name to client run name
  • Loading branch information
dakinggg authored Sep 26, 2023
1 parent d3d3a4e commit 5cb02db
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 5 deletions.
22 changes: 22 additions & 0 deletions composer/loggers/mlflow_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
self._flush_interval = flush_interval
if self._enabled:
self.tracking_uri = str(tracking_uri or mlflow.get_tracking_uri())
mlflow.set_tracking_uri(self.tracking_uri)
# Set up MLflow state
self._run_id = None
if self.experiment_name is None:
Expand Down Expand Up @@ -112,6 +113,7 @@ def init(self, state: State, logger: Logger) -> None:
run_name=self.run_name,
)
self._run_id = new_run.info.run_id
mlflow.start_run(run_id=self._run_id)

def log_table(self, columns: List[str], rows: List[List[Any]], name: str = 'Table') -> None:
if self._enabled:
Expand Down Expand Up @@ -150,6 +152,23 @@ def log_hyperparameters(self, hyperparameters: Dict[str, Any]):
)
self._optimized_mlflow_client.flush(synchronous=False)

def log_model(self, flavor: str, **kwargs):
"""Log a model to MLFlow.
Args:
flavor (str): The MLFlow model flavor to use. Currently only ``'transformers'`` is supported.
**kwargs: Keyword arguments to pass to the MLFlow model logging function.
Raises:
NotImplementedError: If ``flavor`` is not ``'transformers'``.
"""
if self._enabled:
import mlflow
if flavor == 'transformers':
mlflow.transformers.log_model(**kwargs,)
else:
raise NotImplementedError(f'flavor {flavor} not supported.')

def log_images(
self,
images: Union[np.ndarray, torch.Tensor, Sequence[Union[np.ndarray, torch.Tensor]]],
Expand All @@ -176,10 +195,13 @@ def log_images(

def post_close(self):
if self._enabled:
import mlflow

# We use MlflowClient for run termination because MlflowAutologgingQueueingClient's
# run termination relies on scheduling Python futures, which is not supported within
# the Python atexit handler in which post_close() is called
self._mlflow_client.set_terminated(self._run_id)
mlflow.end_run()

def _flush(self):
"""Test-only method to synchronously flush all queued metrics."""
Expand Down
70 changes: 65 additions & 5 deletions tests/loggers/test_mlflow_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
from tests.common.datasets import RandomImageDataset
from tests.common.markers import device
from tests.common.models import SimpleConvModel
from tests.models.test_hf_model import check_hf_model_equivalence, check_hf_tokenizer_equivalence


def _get_latest_mlflow_run(experiment_name, tracking_uri=None):
pytest.importorskip('mlflow')
from mlflow import MlflowClient

# NB: Convert tracking URI to string because MlflowClient doesn't support non-string
Expand All @@ -43,9 +45,12 @@ def test_mlflow_experiment_init_unspecified(monkeypatch):
This mocks the mlflow library to check that the correct calls are made to set up the experiment
"""
import mlflow
mlflow = pytest.importorskip('mlflow')
from mlflow import MlflowClient

monkeypatch.setattr(mlflow, 'set_tracking_uri', MagicMock())
monkeypatch.setattr(mlflow, 'start_run', MagicMock())

mock_state = MagicMock()
mock_state.run_name = 'dummy-run-name'

Expand All @@ -63,13 +68,17 @@ def test_mlflow_experiment_init_unspecified(monkeypatch):
).info.run_name == unspecified.run_name)


def test_mlflow_experiment_init_specified():
def test_mlflow_experiment_init_specified(monkeypatch):
""" Test that MLFlow experiment is set up correctly when all parameters are specified
This mocks the mlflow library to check that the correct calls are made to set up the experiment
"""
mlflow = pytest.importorskip('mlflow')
from mlflow import MlflowClient

monkeypatch.setattr(mlflow, 'set_tracking_uri', MagicMock())
monkeypatch.setattr(mlflow, 'start_run', MagicMock())

mock_state = MagicMock()
mock_state.run_name = 'dummy-run-name' # Not used

Expand Down Expand Up @@ -103,7 +112,7 @@ def test_mlflow_experiment_init_ids(monkeypatch):
This mocks the mlflow library to check that the correct calls are made to set up the experiment
"""
import mlflow
mlflow = pytest.importorskip('mlflow')

monkeypatch.setattr(mlflow, 'set_tracking_uri', MagicMock())
monkeypatch.setattr(mlflow, 'set_experiment', MagicMock())
Expand All @@ -123,7 +132,7 @@ def test_mlflow_experiment_init_ids(monkeypatch):

assert id_logger.run_name == 'dummy-run-name' # Defaults are set, but we don't use them
assert id_logger.experiment_name == 'my-mlflow-experiment'
assert mlflow.set_tracking_uri.call_count == 0
assert mlflow.set_tracking_uri.call_count == 1 # We call this once in the init
assert mlflow.set_experiment.called_with(experiment_id=mlflow_exp_id)
assert mlflow.start_run.called_with(run_id=mlflow_run_id)

Expand All @@ -133,7 +142,7 @@ def test_mlflow_experiment_init_experiment_name(monkeypatch):
This mocks the mlflow library to check that the correct calls are made to set up the experiment
"""
import mlflow
mlflow = pytest.importorskip('mlflow')

monkeypatch.setattr(mlflow, 'set_tracking_uri', MagicMock())
monkeypatch.setattr(mlflow, 'set_experiment', MagicMock())
Expand All @@ -151,6 +160,8 @@ def test_mlflow_experiment_init_experiment_name(monkeypatch):
assert id_logger.experiment_name == exp_name
assert mlflow.set_experiment.called_with(experiment_name=exp_name)

id_logger.post_close()


def test_mlflow_experiment_set_up(tmp_path):
""" Test that MLFlow experiment is set up correctly within mlflow
Expand Down Expand Up @@ -207,6 +218,8 @@ def test_mlflow_experiment_set_up(tmp_path):


def test_mlflow_log_table(tmp_path):
pytest.importorskip('mlflow')

mlflow_uri = tmp_path / Path('my-test-mlflow-uri')
mlflow_exp_name = 'test-log-table-exp-name'
test_mlflow_logger = MLFlowLogger(
Expand Down Expand Up @@ -245,8 +258,52 @@ def test_mlflow_log_table(tmp_path):
assert table['data'] == rows


@pytest.mark.filterwarnings("ignore:.*The 'transformers' MLflow Models integration.*:FutureWarning")
def test_mlflow_log_model(tmp_path, tiny_gpt2_model, tiny_gpt2_tokenizer):
mlflow = pytest.importorskip('mlflow')

mlflow_uri = tmp_path / Path('my-test-mlflow-uri')
mlflow_exp_name = 'test-log-model-exp-name'
test_mlflow_logger = MLFlowLogger(
tracking_uri=mlflow_uri,
experiment_name=mlflow_exp_name,
)

mock_state = MagicMock()
mock_state.run_name = 'dummy-run-name' # this run name should be unused.
mock_logger = MagicMock()

test_mlflow_logger.init(state=mock_state, logger=mock_logger)
test_mlflow_logger.log_model(
flavor='transformers',
transformers_model={
'model': tiny_gpt2_model,
'tokenizer': tiny_gpt2_tokenizer,
},
artifact_path='my_model',
metadata={'task': 'llm/v1/completions'},
task='text-generation',
)
test_mlflow_logger._flush()
test_mlflow_logger.post_close()

run = _get_latest_mlflow_run(mlflow_exp_name, tracking_uri=mlflow_uri)
run_info = run.info
run_id = run_info.run_id
experiment_id = run_info.experiment_id
run_file_path = mlflow_uri / Path(experiment_id) / Path(run_id)

model_directory = run_file_path / Path('artifacts') / Path('my_model')
loaded_model = mlflow.transformers.load_model(model_directory, return_type='components')

check_hf_model_equivalence(loaded_model['model'], tiny_gpt2_model)
check_hf_tokenizer_equivalence(loaded_model['tokenizer'], tiny_gpt2_tokenizer)


@device('cpu')
def test_mlflow_logging_works(tmp_path, device):
pytest.importorskip('mlflow')

mlflow_uri = tmp_path / Path('my-test-mlflow-uri')
experiment_name = 'mlflow_logging_test'
test_mlflow_logger = MLFlowLogger(
Expand All @@ -268,6 +325,7 @@ def test_mlflow_logging_works(tmp_path, device):
device=device)
trainer.fit()
test_mlflow_logger._flush()
test_mlflow_logger.post_close()

run = _get_latest_mlflow_run(
experiment_name=experiment_name,
Expand Down Expand Up @@ -300,6 +358,7 @@ def test_mlflow_logging_works(tmp_path, device):

@device('cpu')
def test_mlflow_log_image_works(tmp_path, device):
pytest.importorskip('mlflow')

class ImageLogger(Callback):

Expand Down Expand Up @@ -335,6 +394,7 @@ def before_forward(self, state: State, logger: Logger):

trainer.fit()
test_mlflow_logger._flush()
test_mlflow_logger.post_close()

run = _get_latest_mlflow_run(
experiment_name=experiment_name,
Expand Down
4 changes: 4 additions & 0 deletions tests/models/test_hf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,10 @@ def check_hf_tokenizer_equivalence(tokenizer1, tokenizer2):
def check_hf_model_equivalence(model1, model2):
expected_model_config_dict = model1.config.to_dict()
new_model_config_dict = model2.config.to_dict()

# _name_or_path is different depending on where the model was loaded from, so don't compare it
expected_model_config_dict.pop('_name_or_path')
new_model_config_dict.pop('_name_or_path')
assert expected_model_config_dict == new_model_config_dict
assert sum(p.numel() for p in model1.parameters()) == sum(p.numel() for p in model2.parameters())
assert all(type(module1) == type(module2) for module1, module2 in zip(model1.modules(), model2.modules()))
Expand Down

0 comments on commit 5cb02db

Please sign in to comment.