Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg committed Sep 29, 2023
1 parent 0d1add2 commit 67ee5bd
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions tests/test_hf_conversion_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,9 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,

mlflow_logger_mock = MagicMock(spec=MLFlowLogger)
mlflow_logger_mock.state_dict = lambda *args, **kwargs: {}
mlflow_logger_mock.log_model = MagicMock()
mlflow_logger_mock.save_model = MagicMock()
mlflow_logger_mock.register_model = MagicMock()
mlflow_logger_mock.model_registry_prefix = ''
trainer = Trainer(
model=original_model,
device='gpu',
Expand All @@ -348,10 +350,13 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
trainer.fit()

if dist.get_global_rank() == 0:
assert mlflow_logger_mock.log_model.call_count == (1 if log_to_mlflow
assert mlflow_logger_mock.save_model.call_count == (1 if log_to_mlflow
else 0)
assert mlflow_logger_mock.register_model.call_count == (1 if log_to_mlflow
else 0)
else:
assert mlflow_logger_mock.log_model.call_count == 0
assert mlflow_logger_mock.register_model.call_count == 0

# summon full params to check equivalence
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
Expand Down

0 comments on commit 67ee5bd

Please sign in to comment.