forked from AkariAsai/OpenScholar
-
Notifications
You must be signed in to change notification settings - Fork 0
/
full_finetune_distributed.py
683 lines (566 loc) · 27.5 KB
/
full_finetune_distributed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import sys
import time
from functools import partial
from typing import Any, Dict, Optional, Tuple, Union
from warnings import warn
import torch
from omegaconf import DictConfig, ListConfig
from torch import nn
from torch.distributed import init_process_group
from torch.distributed.fsdp import (
CPUOffload,
FullOptimStateDictConfig,
FullStateDictConfig,
FullyShardedDataParallel as FSDP,
StateDictType,
)
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, utils
from torchtune.datasets import ConcatDataset
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.utils import DummyProfiler, PROFILER_KEY
from torchtune.utils.activations import apply_selective_activation_checkpointing
from tqdm import tqdm
# Module-level logger used by the recipe and its setup helpers for informational messages.
# "DEBUG" is handed to torchtune's ``utils.get_logger`` (presumably the log level — confirm).
log = utils.get_logger("DEBUG")
class FullFinetuneRecipeDistributed(FTRecipeInterface):
    """
    Full finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe supports
    distributed training and can be run on a single node (1 to 8 GPUs).

    Features:
        - FSDP. Supported using PyTorch's FSDP APIs. DDP is currently not supported. Training on CPU
            is not supported.

        - Activation Checkpointing. This can be controlled using the ``activation_checkpointing``
            flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
            activations in memory and instead recompute them during the backward pass. This is especially
            helpful for larger batch sizes when you're memory constrained. But these savings in memory
            come at the cost of training performance. In most cases training can slow-down quite a bit as
            a result of this activation recomputation.

        - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
            flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
            most cases this should halve the memory footprint of full precision (fp32) training, without
            loss in model quality (will depend on the model, training data and other settings). For
            GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16
            precision are currently not supported.

        - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is
            controlled using the ``gradient_accumulation_steps`` flag.

                Total Batch Size = batch_size * number of GPUs * gradient accumulation steps.

            For example: with batch_size=1, nproc_per_node=2 and gradient_accumulation_steps=32 we get a
            total batch size of 64.

            Gradient accumulation is especially useful when you are memory constrained. In this case,
            accumulating gradients might give you better training speed than enabling activation
            checkpointing.

        - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of
            training. Optimizer state and recipe state (seed, total_epochs, number of epochs run etc) are
            only saved at the end of a given epoch and used in case of resuming training.

            Resuming training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is
            currently not supported.

            For more details on the checkpointer, please take a look at
            our checkpointer deepdive (https://pytorch.org/torchtune/main/deep_dives/checkpointer.html).

        - Logging. Terminal, Disk, WandB and TensorBoard are all supported.

    For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
    has example commands for how to kick-off training.

    Args:
        cfg (DictConfig): OmegaConf object parsed from yaml file

    Raises:
        ValueError: If ``dtype`` is set to fp16.
        RuntimeError: If ``fsdp_cpu_offload`` and a fused optimizer are both requested while
            ``utils.torch_version_ge("2.4.0")`` is False.
    """

    def __init__(self, cfg: DictConfig) -> None:
        """Read device/dtype, logging and training settings from ``cfg`` and validate them."""
        self._device = utils.get_device(device=cfg.device)
        self._dtype = utils.get_dtype(cfg.dtype, device=self._device)

        if self._dtype == torch.float16:
            raise ValueError(
                "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
            )

        if (
            cfg.get("fsdp_cpu_offload", False)
            and cfg.optimizer.get("fused", False)
            and not utils.torch_version_ge("2.4.0")
        ):
            raise RuntimeError(
                "Using fused optimizer on CPU is only supported in PyTorch nightly."
            )

        # logging attributes
        self._output_dir = cfg.output_dir
        self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
        self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)

        # _is_rank_zero is used primarily for logging. In the future, the logger
        # should directly take care of this
        _, rank = utils.get_world_size_and_rank()
        self._is_rank_zero = rank == 0

        # Training cfg
        self._resume_from_checkpoint = cfg.resume_from_checkpoint
        self._gradient_accumulation_steps = cfg.gradient_accumulation_steps

        # These are public properties which are updated by the checkpoint loader
        # when ``resume_from_checkpoint`` is `True` or validated in tests
        self.seed = utils.set_seed(seed=cfg.seed)
        self.epochs_run = 0
        self.total_epochs = cfg.epochs
        self.max_steps_per_epoch = cfg.max_steps_per_epoch
        self.global_step = 0

    def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
        """
        Extract the checkpoint state from file and validate. If resume_from_checkpoint
        is True, this also includes the recipe state.

        Args:
            cfg_checkpointer (DictConfig): config used to instantiate the checkpointer.

        Returns:
            Dict[str, Any]: the dict returned by the checkpointer's ``load_checkpoint``.
        """
        self._checkpointer = config.instantiate(
            cfg_checkpointer,
            resume_from_checkpoint=self._resume_from_checkpoint,
        )
        checkpoint_dict = self._checkpointer.load_checkpoint()

        if self._resume_from_checkpoint:
            self._update_recipe_state(checkpoint_dict)
        return checkpoint_dict

    def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
        """
        Updates the recipe state from checkpoint.

        ``seed`` and ``max_steps_per_epoch`` are overwritten by the checkpoint values on
        mismatch (with a warning); ``total_epochs`` keeps the config value (with a warning).

        Raises:
            KeyError: If the checkpoint is missing any of the recipe-state keys.
        """
        try:
            self.epochs_run = ckpt_dict[utils.EPOCHS_KEY]

            # on mismatch, warn the user and prevent the override
            if self.seed != ckpt_dict[utils.SEED_KEY]:
                warn(
                    message=(
                        "Config value for seed does not match the checkpoint value, "
                        f"using the checkpoint value: {ckpt_dict[utils.SEED_KEY]}"
                    )
                )
                self.seed = ckpt_dict[utils.SEED_KEY]
            if self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]:
                warn(
                    message=(
                        "Config value for max_steps_per_epoch does not match the checkpoint value, "
                        f"using the checkpoint value: {ckpt_dict[utils.MAX_STEPS_KEY]}"
                    )
                )
                self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY]

            # on mismatch, warn the user but allow the override
            if self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]:
                warn(
                    message=(
                        "Config value for total_epochs does not match the checkpoint value, "
                        f"using the config value: {self.total_epochs}"
                    )
                )

        except KeyError as e:
            raise KeyError(
                "Checkpoint does not contain the required keys needed for updating recipe state. "
                "Are you sure you passed in the right recipe checkpoint?"
            ) from e

    def setup(self, cfg: DictConfig) -> None:
        """
        Sets up the recipe state correctly. This includes setting recipe attributes based
        on the ``resume_from_checkpoint`` flag.

        Order matters: checkpoint -> model -> tokenizer -> optimizer -> loss -> data ->
        step bookkeeping -> profiler.
        """
        if self._is_rank_zero:
            self._metric_logger = config.instantiate(cfg.metric_logger)

            # log config with parameter override
            self._metric_logger.log_config(cfg)

        ckpt_dict = self.load_checkpoint(cfg.checkpointer)

        # ``_setup_model`` handles initialization and loading the state dict. This method
        # should be called before ``_setup_optimizer`` since transforming the optimizer
        # state dict requires the model
        self._model = self._setup_model(
            cfg_model=cfg.model,
            enable_activation_checkpointing=cfg.enable_activation_checkpointing,
            memory_efficient_fsdp_wrap=cfg.get("memory_efficient_fsdp_wrap", False),
            fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
            model_state_dict=ckpt_dict[utils.MODEL_KEY],
            ac_mode=cfg.get("ac_mode", None),
            ac_option=cfg.get("ac_option", None),
        )

        self._tokenizer = config.instantiate(cfg.tokenizer)

        # _setup_optimizer should take in ckpt_dict only if training is resumed from
        # checkpoint. Transforming the opt state dict is handled by this method
        self._optimizer = self._setup_optimizer(
            cfg_optimizer=cfg.optimizer,
            opt_state_dict=ckpt_dict[utils.OPT_KEY]
            if self._resume_from_checkpoint
            else None,
        )

        self._loss_fn = config.instantiate(cfg.loss)

        # sampler and dataloader depend on the tokenizer and loss_fn and should be
        # setup after both of these are initialized
        self._sampler, self._dataloader = self._setup_data(
            cfg_dataset=cfg.dataset,
            shuffle=cfg.shuffle,
            batch_size=cfg.batch_size,
        )

        # Finally update the recipe state which can only be correctly set after all of the
        # other components have been initialized and updated.
        #
        # Number of training steps in each epoch depends on the number of batches produced
        # by the dataloader, the max_steps_per_epoch param set by the user and the
        # gradient_accumulation_steps param. This value is used for logging and tracking
        # training state. The computation should happen after the dataloader has been setup
        self._steps_per_epoch = (
            len(self._dataloader) // self._gradient_accumulation_steps
        )
        if (
            self.max_steps_per_epoch is not None
            and self.max_steps_per_epoch < self._steps_per_epoch
        ):
            self._steps_per_epoch = self.max_steps_per_epoch
        self.global_step = self.epochs_run * self._steps_per_epoch

        # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
        # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
        self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))

    def _setup_profiler(
        self, cfg_profiler: Optional[DictConfig] = None
    ) -> Union[torch.profiler.profile, DummyProfiler]:
        """
        Parses the `profiler` section of top-level `cfg` and sets up profiler

        Args:
            cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to
                `recipe.main`). Default None.

        Returns:
            profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods
            for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such
            that the instrumented training loop does not need to be changed when profiling is disabled.

        The profiler config can be provided in configs under the `profiler` key with the following layout:

        .. code-block:: yaml
            profiler:
                enabled: bool

                #Output directory of trace artifacts
                output_dir: str

                #`torch.profiler.ProfilerActivity` types to trace
                cpu: bool
                cuda: bool

                #Trace options
                profile_memory: bool
                with_stack: bool
                record_shapes: bool
                with_flops: bool

                # `torch.profiler.schedule` options:
                # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
                wait_steps: int
                warmup_steps: int
                active_steps: int
                num_cycles: int
        """
        # Missing profiler section in config, assume disabled
        if cfg_profiler is None:
            cfg_profiler = DictConfig({"enabled": False})

        # Check that component is included and set correctly
        if cfg_profiler.get("_component_", None) is None:
            cfg_profiler["_component_"] = "torchtune.utils.setup_torch_profiler"
        else:
            assert (
                cfg_profiler.get("_component_")
                == "torchtune.utils.setup_torch_profiler"
            ), "Only torch profiler supported currently: component must be `torchtune.utils.setup_torch_profiler`"

        profiler, profiler_cfg = config.instantiate(cfg_profiler)

        if self._is_rank_zero:
            log.info(f" Profiler config after instantiation: {profiler_cfg}")

        return profiler

    def _setup_model(
        self,
        cfg_model: DictConfig,
        enable_activation_checkpointing: bool,
        memory_efficient_fsdp_wrap: bool,
        fsdp_cpu_offload: bool,
        model_state_dict: Dict[str, Any],
        ac_mode: Optional[str] = None,
        ac_option: Optional[int] = None,
    ) -> nn.Module:
        """
        Model initialization has some important considerations:
            a. To minimize GPU peak memory, we load the model on CPU with the right
               dtype. To ensure that we don't instantiate ``world_size`` number of models,
               we initialize on meta_device for all ranks other than rank 0.
            b. Rank 0 is also responsible for calling ``load_state_dict`` and loading the
               model weights from checkpoint.
            c. While wrapping the model with FSDP, we set ``sync_module_states``
               to TRUE and broadcast module params and buffers from rank 0.
            d. The ``device_id`` param ensures that the FSDP initialization happens on
               the correct device.

        Returns:
            nn.Module: the FSDP-wrapped model.
        """
        if self._is_rank_zero:
            log.info("FSDP is enabled. Instantiating Model on CPU for Rank 0 ...")
            init_start = time.perf_counter()

            with utils.set_default_dtype(self._dtype):
                model = config.instantiate(cfg_model)

            log.info(
                f"Model instantiation took {time.perf_counter() - init_start:.2f} secs"
            )

            # Load the model weights. This should happen only on Rank 0
            model.load_state_dict(model_state_dict)

        else:
            # For non-zero ranks, load the model on meta device
            with utils.set_default_dtype(self._dtype), torch.device("meta"):
                model = config.instantiate(cfg_model)

        if self._dtype == torch.bfloat16:
            model = model.to(torch.bfloat16)

        # We currently have two versions of activation checkpointing in this recipe
        # for testing and BC purposes. ``enable_activation_checkpointing`` controls
        # the older version of AC and this behavior is unchanged
        # ac_mode and ac_option together control selective AC. This is only enabled
        # when these are set AND ``enable_activation_checkpointing`` is set to False
        # We'll clean this up as soon as testing of AC is complete
        if (not enable_activation_checkpointing) and (ac_mode is not None):
            apply_selective_activation_checkpointing(
                model,
                ac_mode,
                ac_option,
            )

        # Wrap the model with FSDP. This will ensure that the model is sharded
        # across all available GPUs.
        model = FSDP(
            module=model,
            auto_wrap_policy=utils.get_full_finetune_fsdp_wrap_policy(
                memory_efficient_fsdp_wrap=memory_efficient_fsdp_wrap,
                modules_to_wrap={modules.TransformerDecoderLayer},
            ),
            cpu_offload=CPUOffload(offload_params=fsdp_cpu_offload),
            sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD,
            device_id=self._device,
            # this recipe does not currently support mixed precision training
            mixed_precision=None,
            # Ensure we broadcast params and buffers from rank 0
            sync_module_states=True,
            # Initialize empty modules on all non-zero ranks.
            # NOTE: the conditional is part of the lambda body, so a callable is always
            # passed; on rank 0 it simply returns None without touching the module.
            param_init_fn=(
                lambda module: module.to_empty(
                    device=torch.device("cuda"), recurse=False
                )
                if not self._is_rank_zero
                else None
            ),
        )

        # Ensure no params and buffers are on meta device
        utils.validate_no_params_on_meta_device(model)

        # original activation checkpointing (full) - flip the condition above
        if enable_activation_checkpointing and ac_mode is None:
            utils.set_activation_checkpointing(
                model, auto_wrap_policy={modules.TransformerDecoderLayer}
            )

        if self._is_rank_zero:
            memory_stats = utils.get_memory_stats(device=self._device)
            utils.log_memory_stats(memory_stats)

        # synchronize before training begins
        torch.distributed.barrier()

        return model

    def _setup_optimizer(
        self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None
    ) -> Optimizer:
        """
        Set up the optimizer. This method also handles transforming the state dict
        for FSDP.

        Args:
            cfg_optimizer (DictConfig): optimizer config, instantiated with the model parameters.
            opt_state_dict (Optional[Dict[str, Any]]): optimizer state to restore; skipped when
                falsy (``None`` or empty).

        Returns:
            Optimizer: the instantiated (and optionally restored) optimizer.
        """
        optimizer = config.instantiate(cfg_optimizer, self._model.parameters())

        if opt_state_dict:
            opt_state_dict = FSDP.optim_state_dict_to_load(
                self._model, optimizer, opt_state_dict
            )
            optimizer.load_state_dict(opt_state_dict)

        if self._is_rank_zero:
            log.info("Optimizer is initialized.")
        return optimizer

    def _setup_data(
        self,
        cfg_dataset: DictConfig,
        shuffle: bool,
        batch_size: int,
    ) -> Tuple[DistributedSampler, DataLoader]:
        """
        All data related setup happens here. Currently this recipe only supports the
        DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
        iterable datasets and streaming datasets are not supported.

        A ``ListConfig`` of dataset configs is instantiated element-wise and wrapped in a
        ``ConcatDataset`` (treated as not packed); a single config is instantiated directly
        and its ``packed`` flag is read from the config.

        Returns:
            Tuple[DistributedSampler, DataLoader]: the sampler and the dataloader built on it.
        """
        world_size, rank = utils.get_world_size_and_rank()

        if isinstance(cfg_dataset, ListConfig):
            datasets = [
                config.instantiate(single_cfg_dataset, self._tokenizer)
                for single_cfg_dataset in cfg_dataset
            ]
            ds = ConcatDataset(datasets=datasets)
            packed = False
        else:
            ds = config.instantiate(cfg_dataset, self._tokenizer)
            packed = cfg_dataset.get("packed", False)

        # NOTE(review): the sampler seed is fixed at 0 rather than ``self.seed``;
        # confirm this is intentional before changing it (it affects data order).
        sampler = DistributedSampler(
            ds,
            num_replicas=world_size,
            rank=rank,
            shuffle=shuffle,
            seed=0,
        )
        dataloader = DataLoader(
            dataset=ds,
            batch_size=batch_size,
            sampler=sampler,
            # Padding collate is only used for non-packed datasets
            collate_fn=partial(
                utils.padded_collate,
                padding_idx=self._tokenizer.pad_id,
                ignore_idx=self._loss_fn.ignore_index,
            )
            if not packed
            else None,
        )

        if self._is_rank_zero:
            log.info("Dataset and Sampler are initialized.")

        return sampler, dataloader

    def save_checkpoint(self, epoch: int) -> None:
        """
        Save state dict to file. The recipe save_checkpoint method is responsible for
        correctly creating the checkpoint dict and passing to the checkpointer.

        Args:
            epoch (int): index of the epoch that just finished. Optimizer and recipe state
                are included only while ``epoch + 1 < self.total_epochs``.
        """
        checkpoint_dict = {}

        # To prevent GPU memory from spiking during checkpoint save,
        # we consolidate the full model and optim state dicts on CPU for rank 0
        with FSDP.state_dict_type(
            self._model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
        ):
            cpu_state_dict = self._model.state_dict()
            opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer)

        # Now that we have the model and opt state dict, create the actual checkpoint dict
        # to be sent to the checkpointer and ultimately written to file
        if self._is_rank_zero:
            checkpoint_dict.update({utils.MODEL_KEY: cpu_state_dict})

            # if training is in-progress, checkpoint the optimizer state as well
            if epoch + 1 < self.total_epochs:
                checkpoint_dict.update(
                    {
                        utils.OPT_KEY: opt_state_dict,
                        utils.SEED_KEY: self.seed,
                        utils.EPOCHS_KEY: self.epochs_run,
                        utils.TOTAL_EPOCHS_KEY: self.total_epochs,
                        utils.MAX_STEPS_KEY: self.max_steps_per_epoch,
                    }
                )

            self._checkpointer.save_checkpoint(
                checkpoint_dict,
                epoch=epoch,
                intermediate_checkpoint=(epoch + 1 < self.total_epochs),
            )

    def train(self) -> None:
        """
        The core training loop. Supports training on subsets of the dataset using the
        ``max_steps_per_epoch``.

        A checkpoint is saved at the end of every epoch via ``save_checkpoint``.
        """
        # clean up before training begins
        utils.cleanup_before_training()

        # zero out the gradients before starting training
        self._optimizer.zero_grad()

        # Initialize tokens count and running loss (for grad accumulation)
        t0 = time.perf_counter()
        running_loss = 0
        num_tokens = 0

        self._profiler.start()
        # self.epochs_run should be non-zero when we're resuming from a checkpoint
        for curr_epoch in range(self.epochs_run, self.total_epochs):

            # Update the sampler to ensure data is correctly shuffled across epochs
            # in case shuffle is True
            self._sampler.set_epoch(curr_epoch)

            # Only rank zero renders the progress bar
            pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero)
            for idx, batch in enumerate(self._dataloader):
                if (
                    self.max_steps_per_epoch is not None
                    and (idx // self._gradient_accumulation_steps)
                    == self.max_steps_per_epoch
                ):
                    break

                # Both are shape [b, s]
                tokens, labels = batch["tokens"], batch["labels"]
                # Get the attention mask and position ids from the dataset if they
                # exist. Currently, only sample packing in PackedDataset returns these
                mask = batch.get("mask", None)  # shape [b, s, s]
                input_pos = batch.get("input_pos", None)  # shape [b, s]

                tokens = tokens.to(self._device)
                num_tokens += tokens.numel()
                labels = labels.to(self._device)
                mask = mask.to(self._device) if mask is not None else None
                input_pos = (
                    input_pos.to(self._device) if input_pos is not None else None
                )

                logits = self._model(tokens, mask=mask, input_pos=input_pos)
                # Shift so that tokens < n predict n
                logits = logits[..., :-1, :].contiguous()
                labels = labels[..., 1:].contiguous()
                logits = logits.transpose(1, 2)
                # Compute loss
                loss = self._loss_fn(logits, labels)
                # free logits otherwise it peaks backward memory
                del logits

                loss = loss / self._gradient_accumulation_steps
                # Accumulate a detached copy: the running total is used only for logging,
                # so it must not keep the autograd graph of each micro-batch alive.
                running_loss += loss.detach()
                loss.backward()

                # Step with optimizer
                if (idx + 1) % self._gradient_accumulation_steps == 0:
                    self._optimizer.step()
                    self._optimizer.zero_grad(set_to_none=True)

                    # Update the number of steps when the weights are updated
                    self.global_step += 1

                    loss_to_log = running_loss.item()
                    pbar.update(1)
                    pbar.set_description(
                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
                    )

                    # Log per-step metrics
                    if (
                        self.global_step % self._log_every_n_steps == 0
                        and self._is_rank_zero
                    ):
                        time_per_step = time.perf_counter() - t0
                        log_dict = {
                            "loss": loss_to_log,
                            "lr": self._optimizer.param_groups[0]["lr"],
                            "tokens_per_second_per_gpu": num_tokens / time_per_step,
                        }
                        if self._log_peak_memory_stats:
                            log_dict.update(utils.get_memory_stats(device=self._device))
                        self._metric_logger.log_dict(
                            log_dict,
                            step=self.global_step,
                        )

                    # Reset running stats for the next step
                    running_loss = 0
                    num_tokens = 0
                    t0 = time.perf_counter()

                    # Step profiler
                    # Note that this is called within gradient accumulation block, hence
                    # will include multiple forward / backward passes if gradient accumulation > 1
                    self._profiler.step()

            self.epochs_run += 1
            self.save_checkpoint(epoch=curr_epoch)

        self._profiler.stop()

    def cleanup(self) -> None:
        """Close the metric logger (rank zero only) and destroy the process group."""
        if self._is_rank_zero:
            self._metric_logger.close()
        torch.distributed.destroy_process_group()
@config.parse
def recipe_main(cfg: DictConfig) -> None:
    """
    Entry point for the recipe.

    Configurable parameters are read in the following order:
        - Parameters specified in config (see available configs through ``tune ls``)
        - Overwritten by arguments from the command-line

    Raises:
        RuntimeError: If ``utils.is_distributed()`` reports a non-distributed launch.
    """
    if not utils.is_distributed():
        # The trailing space keeps the two sentences apart once the adjacent
        # string literals are concatenated.
        raise RuntimeError(
            "Distributed finetune recipe should be run via a distributed launcher. "
            "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
        )

    # gloo for CPU runs, nccl otherwise
    init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
    if cfg.get("fsdp_cpu_offload", False):
        # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
        # speed up when benchmarking fused AdamW on CPU
        utils.set_torch_num_threads()

    config.log_config(recipe_name="FullFinetuneRecipeDistributed", cfg=cfg)

    recipe = FullFinetuneRecipeDistributed(cfg=cfg)
    recipe.setup(cfg=cfg)
    recipe.train()
    recipe.cleanup()
if __name__ == "__main__":
    # Run the recipe and hand its return value to the interpreter as the exit status.
    exit_status = recipe_main()
    sys.exit(exit_status)