Add precision arg for pretraining (#1353)
Co-authored-by: Carlos Mocholí <[email protected]>
rasbt and carmocca authored Apr 25, 2024
1 parent 43c4432 commit b9ddd8b
Showing 4 changed files with 16 additions and 2 deletions.
3 changes: 3 additions & 0 deletions config_hub/pretrain/debug.yaml
@@ -11,6 +11,9 @@ model_config:
 # /teamspace/jobs/<job-name>/share. (type: <class 'Path'>, default: out/pretrain)
 out_dir: out/pretrain/debug

+# The precision to use for pretraining. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
+precision: bf16-mixed
+
 # Optional path to a checkpoint directory to initialize the model from.
 # Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null)
 initial_checkpoint_dir:
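The same precision block is added to each shipped pretraining config, with `bf16-mixed` as the explicit value and `null` meaning "let the runtime pick a default". As a standalone illustration (not litgpt's actual config loader), a sketch that reads one of these files and checks the value against the documented choices, assuming PyYAML is installed and the repository layout above:

import yaml  # assumption: PyYAML is available

ALLOWED_PRECISION = {"bf16-true", "bf16-mixed", "32-true", None}

# Load one of the shipped configs; `null` in the YAML becomes None here.
with open("config_hub/pretrain/debug.yaml") as f:
    cfg = yaml.safe_load(f)

precision = cfg.get("precision")
if precision not in ALLOWED_PRECISION:
    raise ValueError(f"Unsupported precision {precision!r}; expected one of {sorted(p for p in ALLOWED_PRECISION if p)} or null")

print(precision or "no explicit precision; the runtime default applies")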
3 changes: 3 additions & 0 deletions config_hub/pretrain/tinyllama.yaml
@@ -11,6 +11,9 @@ model_config:
 # /teamspace/jobs/<job-name>/share. (type: <class 'Path'>, default: out/pretrain)
 out_dir: out/pretrain/tiny-llama

+# The precision to use for pretraining. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
+precision: bf16-mixed
+
 # Optional path to a checkpoint directory to initialize the model from.
 # Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null)
 initial_checkpoint_dir:
3 changes: 3 additions & 0 deletions config_hub/pretrain/tinystories.yaml
@@ -27,6 +27,9 @@ model_config:
 # /teamspace/jobs/<job-name>/share. (type: <class 'Path'>, default: out/pretrain)
 out_dir: out/pretrain/stories15M

+# The precision to use for pretraining. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
+precision: bf16-mixed
+
 # Optional path to a checkpoint directory to initialize the model from.
 # Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null)
 initial_checkpoint_dir:
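With the new argument in place, the same choice can also be made programmatically. A minimal sketch, assuming litgpt is installed, that "pythia-14m" is one of the registered model names, and that a prepared dataset/tokenizer is available so the remaining `setup` defaults suffice:

# Hedged sketch: launch pretraining with an explicit precision via the Python API.
# "pythia-14m" and the output directory are assumptions; any registered model_name
# works, and a prepared dataset/tokenizer is still required in practice.
from pathlib import Path

from litgpt.pretrain import setup

setup(
    model_name="pythia-14m",
    out_dir=Path("out/pretrain/demo"),
    precision="bf16-true",  # or None to fall back to the detected default
)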
9 changes: 7 additions & 2 deletions litgpt/pretrain.py
@@ -29,6 +29,7 @@
     choose_logger,
     chunked_cross_entropy,
     copy_config_files,
+    get_default_supported_precision,
     init_out_dir,
     num_parameters,
     parse_devices,
@@ -42,6 +43,7 @@ def setup(
     model_name: Optional[str] = None,
     model_config: Optional[Config] = None,
     out_dir: Path = Path("out/pretrain"),
+    precision: Literal["bf16-true", "bf16-mixed", "32-true", None] = None,
     initial_checkpoint_dir: Optional[Path] = None,
     resume: Union[bool, Path] = False,
     data: Optional[DataModule] = None,
@@ -75,6 +77,7 @@ def setup(
             ``model_config``.
         out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
             /teamspace/jobs/<job-name>/share.
+        precision: The precision to use for pretraining. Determines a compatible precision setting by default.
         initial_checkpoint_dir: Optional path to a checkpoint directory to initialize the model from.
             Useful for continued pretraining. Mutually exclusive with ``resume``.
         resume: Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
@@ -96,6 +99,7 @@
         available_models = "\n".join(sorted(name_to_config))
         raise ValueError(f"Please specify --model_name <model_name>. Available values:\n{available_models}")
     config = Config.from_name(model_name) if model_config is None else model_config
+    precision = precision or get_default_supported_precision(training=True)
     devices = parse_devices(devices)
     out_dir = init_out_dir(out_dir)
     # in case the dataset requires the Tokenizer
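The new fallback defers to `get_default_supported_precision(training=True)`, whose implementation is not part of this diff. A rough sketch of how such a training-time default is commonly chosen (an assumption, not litgpt's verbatim helper): prefer bf16 mixed precision where the hardware supports it, otherwise fall back to fp16 mixed precision.

import torch

def default_training_precision() -> str:
    # Assumption: mirrors the usual heuristic rather than litgpt's exact code.
    if torch.cuda.is_available() and not torch.cuda.is_bf16_supported():
        return "16-mixed"   # older GPUs without bfloat16 support
    return "bf16-mixed"     # bf16-capable GPUs (and CPU runs)

precision = None  # e.g. the value coming from the YAML config or CLI
precision = precision or default_training_precision()
print(precision)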
@@ -109,7 +113,7 @@
         strategy = FSDPStrategy(auto_wrap_policy={Block}, state_dict_type="full", sharding_strategy="HYBRID_SHARD")
     else:
         strategy = "auto"
-    fabric = L.Fabric(devices=devices, strategy=strategy, precision="bf16-mixed", loggers=[logger])
+    fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger])
     fabric.launch()

     fabric.print(pprint.pformat(hparams))
@@ -169,12 +173,13 @@ def main(

     model = torch.compile(model)
     model = fabric.setup(model)
+
     optimizer = torch.optim.AdamW(
         model.parameters(),
         lr=train.learning_rate,
         weight_decay=train.weight_decay,
         betas=(train.beta1, train.beta2),
-        fused=True,
+        fused=fabric.device.type == "cuda",
     )
     optimizer = fabric.setup_optimizers(optimizer)

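The last hunk also swaps the hard-coded `fused=True` for a device check: PyTorch's fused AdamW kernels are a CUDA fast path and may be unavailable on other device types depending on the PyTorch version, which matters now that a CPU-only `32-true` run is selectable. An isolated sketch of the guard, with a placeholder model and the hyperparameter names from the diff filled in with example values:

import torch

# Placeholder model; in pretrain.py this is the GPT model set up by Fabric.
model = torch.nn.Linear(8, 8)
device = next(model.parameters()).device

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,                      # train.learning_rate
    weight_decay=0.1,             # train.weight_decay
    betas=(0.9, 0.95),            # (train.beta1, train.beta2)
    fused=device.type == "cuda",  # enable the fused kernel only on CUDA
)
print(f"fused AdamW enabled: {device.type == 'cuda'}")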
