Skip to content

Commit

Permalink
Fix mlflow model logging bug (#692)
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored Oct 25, 2023
1 parent d72902a commit bc687b7
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 16 deletions.
29 changes: 26 additions & 3 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
import contextlib
import copy
import logging
import math
import os
import tempfile
from pathlib import Path
from typing import Optional, Union

import torch
from composer.core import Callback, Event, State, Time
from composer.core import Callback, Event, State, Time, TimeUnit
from composer.core.state import fsdp_state_dict_type_context
from composer.loggers import Logger, MLFlowLogger
from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
Expand Down Expand Up @@ -83,6 +84,13 @@ def __init__(

self.huggingface_folder_name_fstr = os.path.join(
'huggingface', huggingface_folder_name)

if isinstance(save_interval, str):
save_interval = Time.from_timestring(save_interval)
if isinstance(save_interval, int):
save_interval = Time(save_interval, TimeUnit.EPOCH)

self.save_interval = save_interval
self.check_interval = create_interval_scheduler(
save_interval, include_end_of_training=True)
self.upload_to_object_store = (self.backend != '')
Expand Down Expand Up @@ -128,6 +136,21 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set(
'5GB')

def _is_last_batch(self, state: State):
elapsed_duration = state.get_elapsed_duration()
if elapsed_duration is not None and elapsed_duration >= 1.0:
return True

assert state.max_duration is not None # for pyright
# If the save interval is specified as 1dur, and the max duration is in epoch units
# we need a special case to identify we are on the last batch and should write the mlflow checkpoint
if self.save_interval.unit == TimeUnit.DURATION and self.save_interval.value == 1 and state.max_duration.unit == TimeUnit.EPOCH:
assert state.dataloader_len is not None # for pyright
return int(state.timestamp.batch) % math.ceil(
state.max_duration.value * state.dataloader_len) == 0

return False

def _save_checkpoint(self, state: State, logger: Logger):
del logger # unused

Expand Down Expand Up @@ -224,8 +247,8 @@ def _save_checkpoint(self, state: State, logger: Logger):
overwrite=self.overwrite,
)

elapsed_duration = state.get_elapsed_duration()
if self.mlflow_registered_model_name is not None and elapsed_duration is not None and elapsed_duration >= 1.0:
if self.mlflow_registered_model_name and self._is_last_batch(
state):
components = {'model': new_model_instance}
if original_tokenizer is not None:
components['tokenizer'] = original_tokenizer
Expand Down
29 changes: 16 additions & 13 deletions tests/test_hf_conversion_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,25 +251,30 @@ def test_callback_inits_with_defaults():
@pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2'])
@pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None])
@pytest.mark.parametrize('log_to_mlflow', [True, False])
@pytest.mark.parametrize(
'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
[('3ba', '2ba', '7ba', 3, 4), ('1dur', '2ba', '1ep', 1, 4)])
def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
fsdp_state_dict_type: Optional[str],
log_to_mlflow: bool):
log_to_mlflow: bool,
hf_save_interval: str,
save_interval: str, max_duration: str,
expected_hf_checkpoints: int,
expected_normal_checkpoints: int):
delete_transformers_cache()

dist.initialize_dist(get_device('gpu'))

max_seq_len = 16
save_interval_batches = 2
huggingface_save_interval_batches = 3
device_batch_size = 1
dataset_size = 14
max_duration_batches = 7
precision_str = 'bfloat16'
precision = torch.bfloat16
batches_per_epoch = math.ceil(dataset_size / (device_batch_size * 2))

checkpointer_callback = HuggingFaceCheckpointer(
save_folder=os.path.join(tmp_path, 'checkpoints'),
save_interval=f'{huggingface_save_interval_batches}ba',
save_interval=hf_save_interval,
precision=precision_str,
mlflow_registered_model_name='dummy-registered-name'
if log_to_mlflow else None,
Expand Down Expand Up @@ -405,8 +410,8 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
fsdp_config=fsdp_config if fsdp_state_dict_type is not None else None,
train_dataloader=train_dataloader,
save_folder=os.path.join(tmp_path, 'checkpoints'),
save_interval=f'{save_interval_batches}ba',
max_duration=f'{max_duration_batches}ba',
save_interval=save_interval,
max_duration=max_duration,
callbacks=[checkpointer_callback],
loggers=[mlflow_logger_mock] if log_to_mlflow else [],
optimizers=optimizer,
Expand Down Expand Up @@ -442,15 +447,13 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
name for name in os.listdir(
os.path.join(tmp_path, 'checkpoints', 'huggingface'))
]
assert len(normal_checkpoints) == math.ceil(max_duration_batches /
save_interval_batches)
assert len(huggingface_checkpoints) == math.ceil(
max_duration_batches / huggingface_save_interval_batches)
assert len(normal_checkpoints) == expected_normal_checkpoints
assert len(huggingface_checkpoints) == expected_hf_checkpoints

# Load the last huggingface checkpoint
loaded_model = transformers.AutoModelForCausalLM.from_pretrained(
os.path.join(tmp_path, 'checkpoints', 'huggingface',
f'ba{max_duration_batches}'),
f'ba{batches_per_epoch}'),
trust_remote_code=True,
)

Expand All @@ -471,7 +474,7 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,

loaded_tokenizer = transformers.AutoTokenizer.from_pretrained(
os.path.join(tmp_path, 'checkpoints', 'huggingface',
f'ba{max_duration_batches}'),
f'ba{batches_per_epoch}'),
trust_remote_code=True,
)

Expand Down

0 comments on commit bc687b7

Please sign in to comment.