merge
dakinggg committed Sep 28, 2023
1 parent c755ee4 commit b55c71d
Showing 2 changed files with 39 additions and 13 deletions.
24 changes: 20 additions & 4 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -41,6 +41,11 @@ class HuggingFaceCheckpointer(Callback):
         precision: The precision to save the model in. Default is ``float32``. Options are ``bfloat16``, ``float16``, or ``float32``.
         overwrite (bool): Whether to overwrite previous checkpoints.
         log_to_mlflow (bool): Whether to log and register the checkpoint to MLFlow. Default is ``False``.
+        mlflow_task (str): The MLFlow task to log the checkpoint under. Only used if ``log_to_mlflow`` is ``True``. Default is ``text-generation``.
+        mlflow_metadata (Optional[dict]): The MLFlow metadata to log the checkpoint with. Only used if ``log_to_mlflow`` is ``True``. Default is ``None``.
+        uc_prefix (Optional[str]): Prefix to use for the MLFlow registered model. If specified, the model will be logged to UC rather than
+            the workspace model registry. The prefix must be of the form ``{catalog}.{schema}``, and the model will be registered at
+            ``{catalog}.{schema}.{model name}``. Only used if ``log_to_mlflow`` is ``True``. Default is ``None``.
     """
 
     def __init__(
@@ -53,6 +58,7 @@ def __init__(
         log_to_mlflow: bool = False,
         mlflow_task: str = 'text-generation',
         mlflow_metadata: Optional[dict] = None,
+        uc_prefix: Optional[str] = None,
     ):
         self.backend, self.bucket_name, self.save_dir_format_str = parse_uri(
             save_folder)
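For reference, constructing the callback directly with the new argument might look like the sketch below; `save_interval` and the values shown are illustrative (mirroring the YAML changed later in this commit), not part of the diff itself:

```python
# Hypothetical usage of the new `uc_prefix` argument.
from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer

checkpointer = HuggingFaceCheckpointer(
    save_folder='./checkpoints',
    save_interval='10ba',
    log_to_mlflow=True,
    uc_prefix='main.danielking',  # must be of the form '{catalog}.{schema}'
)
```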
@@ -85,6 +91,15 @@ def __init__(
         self.last_checkpoint_batch: Optional[Time] = None
         self.mlflow_loggers = []
 
+        self.uc_prefix = uc_prefix
+        if self.log_to_mlflow and uc_prefix is not None:
+            split_prefix = uc_prefix.split('.')
+            if len(split_prefix) != 2:
+                raise ValueError(
+                    f'`uc_prefix` must be of the form `{{catalog}}.{{schema}}`. Got {uc_prefix} instead.'
+                )
+
+
     def run_event(self, event: Event, state: State, logger: Logger) -> None:
         # The interval scheduler handles only returning True for the appropriate events
         if state.get_elapsed_duration() is not None and self.check_interval(
@@ -114,8 +129,8 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
                 )
 
             import mlflow
-            mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set(
-                '5GB')
+            mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set("5GB")
+            mlflow.set_registry_uri('databricks-uc')
 
     def _save_checkpoint(self, state: State, logger: Logger):
         del logger  # unused
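Standalone, the two MLflow calls added above behave like this sketch (assumes only that a recent `mlflow` is installed):

```python
import mlflow

# Cap each weight shard at 5GB when MLflow serializes the transformers model.
mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set('5GB')

# Point model registration at Databricks Unity Catalog instead of the
# workspace model registry.
mlflow.set_registry_uri('databricks-uc')
print(mlflow.get_registry_uri())  # -> databricks-uc
```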
@@ -222,14 +237,15 @@ def _save_checkpoint(self, state: State, logger: Logger):
                     components['tokenizer'] = new_tokenizer_instance
 
             log.debug('Logging Hugging Face model to MLFlow')
+            registered_model_name = f'{state.run_name}_{os.path.basename(save_dir)}'
+            registered_model_name_full = f'{self.uc_prefix}.{registered_model_name}' if self.uc_prefix is not None else registered_model_name
             for mlflow_logger in self.mlflow_loggers:
                 mlflow_logger.log_model(
                     flavor='transformers',
                     transformers_model=components,
                     artifact_path=os.path.basename(save_dir),
                     task=self.mlflow_task,
-                    registered_model_name=
-                    f'{state.run_name}_{os.path.basename(save_dir)}',
+                    registered_model_name=registered_model_name_full,
                     metadata=self.mlflow_metadata,
                     await_registration_for=None,
                 )
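The net effect of the naming logic is easiest to see with stand-in values; this sketch mirrors the string handling above (the run name and save directory are hypothetical):

```python
import os

uc_prefix = 'main.danielking'        # '{catalog}.{schema}', validated in __init__
run_name = 'test-mlflow-register-3'  # stand-in for state.run_name
save_dir = './checkpoints/ba10'      # stand-in for the HF checkpoint directory

# Same string logic as the checkpointer code above.
registered_model_name = f'{run_name}_{os.path.basename(save_dir)}'
registered_model_name_full = (f'{uc_prefix}.{registered_model_name}'
                              if uc_prefix is not None else registered_model_name)

print(registered_model_name_full)
# -> main.danielking.test-mlflow-register-3_ba10
```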
28 changes: 19 additions & 9 deletions scripts/train/yamls/pretrain/mpt-125m.yaml
@@ -4,7 +4,7 @@ max_seq_len: 2048
 global_seed: 17
 
 # Run Name
-run_name: # If left blank, will be read from env var $RUN_NAME
+run_name: test-mlflow-register-3
 
 # Model
 model:
@@ -31,7 +31,7 @@ train_loader:
   dataset:
     local: ${data_local}
     remote: ${data_remote}
-    split: train
+    split: train_small
     shuffle: true
     max_seq_len: ${max_seq_len}
     shuffle_seed: ${global_seed}
@@ -43,7 +43,7 @@ eval_loader:
   dataset:
     local: ${data_local}
     remote: ${data_remote}
-    split: val
+    split: val_small
     shuffle: false
     max_seq_len: ${max_seq_len}
     shuffle_seed: ${global_seed}
@@ -70,16 +70,16 @@ algorithms:
     clipping_type: norm
     clipping_threshold: 1.0
 
-max_duration: 4800ba # ~ 2.5B tokens
+max_duration: 10ba
 eval_interval: 500ba
 eval_first: false
-eval_subset_num_batches: -1
-global_train_batch_size: 256
+eval_subset_num_batches: 2
+global_train_batch_size: 2
 
 # System
 seed: ${global_seed}
-device_eval_batch_size: 16
-device_train_microbatch_size: 16
+device_eval_batch_size: 1
+device_train_microbatch_size: 1
 # device_train_microbatch_size: auto
 precision: amp_bf16
@@ -104,14 +104,24 @@ callbacks:
   lr_monitor: {}
   memory_monitor: {}
   runtime_estimator: {}
+  hf_checkpointer:
+    save_interval: 10ba
+    precision: bfloat16
+    save_folder: ./{run_name}/checkpoints
+    log_to_mlflow: true
+    uc_prefix: main.danielking
+
+loggers:
+  mlflow:
+    experiment_name: /Users/[email protected]/mlflow-logging-test
 
 # loggers:
 #   wandb: {}
 
 # Checkpoint to local filesystem or remote object store
 # save_interval: 500ba
 # save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK
-# save_folder: ./{run_name}/checkpoints
+save_folder: ./{run_name}/checkpoints
 # save_folder: s3://my-bucket/my-folder/{run_name}/checkpoints
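Since llm-foundry's train script reads these YAMLs through OmegaConf, a quick sanity check of the edited values might look like this (a sketch; assumes `omegaconf` is installed and the path is relative to the repo root):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load('scripts/train/yamls/pretrain/mpt-125m.yaml')
print(cfg.max_duration)                         # 10ba
print(cfg.global_train_batch_size)              # 2
print(cfg.callbacks.hf_checkpointer.uc_prefix)  # main.danielking
```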
