Add profiler support in llm foundry (#678)
j316chuck authored Oct 18, 2023
1 parent f11483f commit 92bd673
Showing 2 changed files with 149 additions and 0 deletions.
30 changes: 30 additions & 0 deletions scripts/train/train.py
@@ -11,6 +11,8 @@
from composer import Trainer
from composer.core import Evaluator
from composer.core.callback import Callback
from composer.profiler import (JSONTraceHandler, Profiler, TraceHandler,
                               cyclic_schedule)
from composer.utils import dist, get_device, reproducibility
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om
@@ -458,6 +460,33 @@ def main(cfg: DictConfig) -> Trainer:
        for name, logger_cfg in logger_configs.items()
    ] if logger_configs else None

    # Profiling
    profiler: Optional[Profiler] = None
    profiler_cfg: Optional[DictConfig] = pop_config(cfg,
                                                    'profiler',
                                                    must_exist=False,
                                                    convert=False,
                                                    default_value=None)
    if profiler_cfg:
        profiler_schedule_cfg: Dict = pop_config(profiler_cfg,
                                                 'schedule',
                                                 must_exist=True,
                                                 convert=True)
        profiler_schedule = cyclic_schedule(**profiler_schedule_cfg)
        # Only support json trace handler
        profiler_trace_handlers: List[TraceHandler] = []
        profiler_trace_cfg: Optional[Dict] = pop_config(profiler_cfg,
                                                        'json_trace_handler',
                                                        must_exist=False,
                                                        default_value=None,
                                                        convert=True)
        if profiler_trace_cfg:
            profiler_trace_handlers.append(
                JSONTraceHandler(**profiler_trace_cfg))
        profiler = Profiler(**profiler_cfg,
                            trace_handlers=profiler_trace_handlers,
                            schedule=profiler_schedule)

    # Callbacks
    callbacks: List[Callback] = [
        build_callback(str(name), callback_cfg)
@@ -576,6 +605,7 @@ def main(cfg: DictConfig) -> Trainer:
        autoresume=autoresume,
        python_log_level=python_log_level,
        dist_timeout=dist_timeout,
        profiler=profiler,
    )

    print('Logging config')
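Note on configuration (not part of the diff): the code above looks for an optional profiler section in the training config. Its schedule sub-section is required and is forwarded to Composer's cyclic_schedule, an optional json_trace_handler sub-section is forwarded to JSONTraceHandler, and any remaining keys are passed through to the Profiler constructor. A minimal sketch of such a section, assuming Composer's standard cyclic_schedule arguments (wait, warmup, active, repeat) and the folder argument of JSONTraceHandler; the values are illustrative only:

profiler:
  schedule:
    wait: 0        # batches skipped at the start of each profiling cycle (assumed cyclic_schedule kwarg)
    warmup: 1      # batches used to warm the profiler up before recording
    active: 4      # batches recorded per cycle
    repeat: 1      # number of profiling cycles
  json_trace_handler:
    folder: traces # directory the JSON trace files are written to (assumed JSONTraceHandler kwarg)

Any other keys placed under profiler would be expanded directly into the Profiler constructor by the **profiler_cfg call above.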
119 changes: 119 additions & 0 deletions scripts/train/yamls/pretrain/mpt-small-cpu.yaml
@@ -0,0 +1,119 @@
data_local: ./my-copy-c4
data_remote: # If blank, files must be present in data_local
max_seq_len: 128
global_seed: 17

# Run Name
run_name: mpt_causal_lm_cpu # If left blank, will be read from env var $RUN_NAME

# Model
model:
  name: mpt_causal_lm
  init_device: cpu
  d_model: 16
  n_heads: 4
  n_layers: 4
  expansion_ratio: 5
  max_seq_len: ${max_seq_len}
  vocab_size: 50368
  attn_config:
    attn_impl: torch
  loss_fn: torch_crossentropy

# Tokenizer
tokenizer:
  name: EleutherAI/gpt-neox-20b
  kwargs:
    model_max_length: ${max_seq_len}

# Dataloaders
train_loader:
  name: text
  dataset:
    local: ${data_local}
    remote: ${data_remote}
    split: train
    shuffle: true
    max_seq_len: ${max_seq_len}
    shuffle_seed: ${global_seed}
  drop_last: true
  num_workers: 2

eval_loader:
  name: text
  dataset:
    local: ${data_local}
    remote: ${data_remote}
    split: val
    shuffle: false
    max_seq_len: ${max_seq_len}
    shuffle_seed: ${global_seed}
  drop_last: false
  num_workers: 2

# Optimization
scheduler:
  name: cosine_with_warmup
  t_warmup: 100ba
  alpha_f: 0.1

optimizer:
  name: decoupled_adamw
  lr: 6.0e-4
  betas:
  - 0.9
  - 0.95
  eps: 1.0e-08
  weight_decay: 0.0

algorithms:
  gradient_clipping:
    clipping_type: norm
    clipping_threshold: 1.0

max_duration: 10ba
eval_interval: 5ba
eval_first: false
eval_subset_num_batches: 5
global_train_batch_size: 256
autoresume: false

# System
seed: ${global_seed}
device_eval_batch_size: 16
device_train_microbatch_size: 16
# device_train_microbatch_size: auto
precision: fp32

# FSDP
fsdp_config:
  sharding_strategy: FULL_SHARD
  mixed_precision: PURE
  activation_checkpointing: false
  activation_checkpointing_reentrant: false
  activation_cpu_offload: false
  limit_all_gathers: true
  verbose: false

# Logging
progress_bar: false
log_to_console: true
console_log_interval: 1ba

callbacks:
  speed_monitor:
    window_size: 10
  lr_monitor: {}
  memory_monitor: {}
  runtime_estimator: {}

# Checkpoint to local filesystem or remote object store
save_overwrite: true
save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK
# save_interval: 500ba
# save_folder: ./{run_name}/checkpoints
# save_folder: s3://my-bucket/my-folder/{run_name}/checkpoints

# Load from local filesystem or remote object store
# load_path: ./gpt-125m/checkpoints/latest-rank{rank}.pt
# load_path: s3://my-bucket/my-folder/gpt-125m/checkpoints/latest-rank{rank}.pt
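The small CPU config above does not itself enable the profiler; the two files in this commit are otherwise independent. As a purely illustrative sketch (a hypothetical addition, not part of the commit), the new code path could be exercised with this config by appending a profiler section like the one outlined after the train.py diff:

# Profiler (hypothetical addition, not in this commit)
profiler:
  schedule:
    wait: 0
    warmup: 1
    active: 3
    repeat: 1
  json_trace_handler:
    folder: ./{run_name}/traces # assumed JSONTraceHandler folder argument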
