Skip to content

Commit

Permalink
add support for CPU and MPS
Browse files Browse the repository at this point in the history
do not use distributed when not available, instead use CPU or MPS.

This entails a few changes:

--device is now a valid flag to the library since `ilab` can pass CPU, MPS, or default to cuda
when using CPU or MPS, do not initialize DS, instead put the model on the device and initialize `Adafactor` optimizer which is more efficient and than Adam based one
inside of `train` add logic for handling if torch.cuda.is_available and torch.distributed.is_initialized() we dont use distributed torch on consumer systems
the train loop needs some custom step and loss logic for a LlamaForCausalLM model, add that in
when using CPU or MPS we are always world_size == 1 and local_rank == 0

Signed-off-by: Charlie Doern <[email protected]>
  • Loading branch information
cdoern committed Aug 28, 2024
1 parent 0de1e36 commit ef8bbe5
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 74 deletions.
208 changes: 150 additions & 58 deletions src/instructlab/training/main_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,18 @@

# pylint: disable=no-name-in-module
from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
from torch import nn
from torch.distributed import ReduceOp, all_reduce
from tqdm import tqdm
from transformers import AutoModelForCausalLM, get_scheduler
from transformers import (
Adafactor,
AutoModelForCausalLM,
LlamaForCausalLM,
get_scheduler,
)
import deepspeed
import torch
import torch.distributed

# First Party
from instructlab.training import config
Expand Down Expand Up @@ -83,7 +90,7 @@ def get_ds_config(world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOption
return ds_config


def setup_model(args, tokenizer, train_loader, grad_accum):
def setup_model(args, tokenizer, train_loader, grad_accum, device):
bnb_config = None
if args.lora_r > 0 and args.lora_quant_bits == 4:
# Third Party
Expand Down Expand Up @@ -250,25 +257,35 @@ def make_inputs_require_grad(module, input, output):
)

# pylint: disable=unbalanced-tuple-unpacking
model, _, _, lr_scheduler = deepspeed.initialize(
model=model,
optimizer=optimizer,
config=get_ds_config(
world_size=torch.distributed.get_world_size(),
samples_per_gpu=args.samples_per_gpu,
grad_accum=grad_accum,
opts=DeepSpeedOptions(
cpu_offload_optimizer=args.cpu_offload_optimizer,
cpu_offload_optimizer_ratio=args.cpu_offload_optimizer_ratio,
cpu_offload_optimizer_pin_memory=args.cpu_offload_optimizer_pin_memory,
save_samples=args.save_samples_ds,
optimizer = None
if device.type == "cuda":
model, _, _, lr_scheduler = deepspeed.initialize(
model=model,
optimizer=optimizer,
config=get_ds_config(
world_size=torch.distributed.get_world_size(),
samples_per_gpu=args.samples_per_gpu,
grad_accum=grad_accum,
opts=DeepSpeedOptions(
cpu_offload_optimizer=args.cpu_offload_optimizer,
cpu_offload_optimizer_ratio=args.cpu_offload_optimizer_ratio,
cpu_offload_optimizer_pin_memory=args.cpu_offload_optimizer_pin_memory,
save_samples=args.save_samples_ds,
),
),
),
lr_scheduler=lr_scheduler,
dist_init_required=True,
)
# model = torch.compile(model)
return model
lr_scheduler=lr_scheduler,
dist_init_required=True,
)
else:
# If we are using CPU or MPS just place model on that device
# also, initialize Adafactor, a Transformers Optimizer designed to use less resources.
# if we use AdamW here most people will always run out of RAM
model = model.to(device)
optimizer = Adafactor(
model.parameters(), lr=1e-5, scale_parameter=True, relative_step=False
)
model.gradient_checkpointing_enable()
return model, optimizer


# this function is to check if the checkpoint provided can be resumed
Expand Down Expand Up @@ -331,7 +348,9 @@ def maybe_resume_training(args, model):
return model


def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
def train(
args, model, tokenizer, train_loader, grad_accum, metric_logger, device, optimizer
):
model.train()

global_step = 1
Expand Down Expand Up @@ -359,7 +378,8 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
)

for epoch in range(args.num_epochs):
torch.distributed.barrier()
if torch.cuda.is_available():
torch.distributed.barrier()
if args.sampler in ("multipack"):
train_loader.batch_sampler.set_epoch(epoch)
elif args.sampler in ("distributed"):
Expand All @@ -370,7 +390,12 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
if local_rank == 0:
inner_pb = tqdm(range(len(train_loader)), desc=f"Epoch {epoch}")

aggregated_values = torch.zeros(3, dtype=torch.float32).to(local_rank)
if not torch.cuda.is_available():
aggregated_values = torch.zeros(3, dtype=torch.float32, device=device).to(
device=device
)
else:
aggregated_values = torch.zeros(3, dtype=torch.float16).to(local_rank)
for batch in train_loader:
if global_step <= args.last_step:
# in the case of resuming, last_step > 0
Expand All @@ -384,7 +409,10 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
aggregated_values[1] = len(batch["input_ids"])
if not args.is_granite:
for k in batch:
batch[k] = batch[k].to(local_rank)
if torch.cuda.is_available():
batch[k] = batch[k].to(local_rank)
else:
batch[k] = batch[k].to(device="cpu")

output = model(
**batch,
Expand All @@ -394,7 +422,8 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):

aggregated_values[2] = loss.item()

all_reduce(aggregated_values, op=ReduceOp.SUM)
if torch.cuda.is_available() and torch.distributed.is_initialized():
all_reduce(aggregated_values, op=ReduceOp.SUM)

num_loss_counted_tokens = aggregated_values[0]
loss = (
Expand All @@ -404,32 +433,65 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
print(
f"\033[93mPer-token loss scaled by world size: {(loss/num_loss_counted_tokens) * world_size}\033[0m"
)
print(
f"Epoch: {epoch}, Step: {global_step}, Rank: {torch.distributed.get_rank()}, loss = {loss}"
)

model.backward(loss)
model.step()
if torch.cuda.is_available():
rank = torch.distributed.get_rank()
else:
rank = 0
print(f"Epoch: {epoch}, Step: {global_step}, Rank: {rank}, loss = {loss}")

# If using a LlamaForCausalLM model (single device CPU, GPU, or MPS) then we cannot use the DS .backward, .step from the model_engine
# instead, use the AdaFactor Optimizer's zero_grad, the loss.backward() and step the optimizer itself.
if torch.cuda.is_available():
model.backward(loss)
model.step()
else:
optimizer.zero_grad()
loss.backward()
optimizer.step()

if local_rank == 0:
elapsed_time = time.time() - start
overall_throughput = args.samples_per_gpu * world_size / elapsed_time
current_lr = model.lr_scheduler.get_last_lr()[0]
cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
global_grad_norm = model.get_global_grad_norm()
cuda_malloc_retries = 0
cuda_mem_allocated = 0
if torch.cuda.is_available():
cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
norm = None
if not isinstance(model, LlamaForCausalLM):
global_grad_norm = model.get_global_grad_norm()
norm = model.optimizer.single_partition_of_fp32_groups[0].norm()
current_lr = model.lr_scheduler.get_last_lr()[0]
else:
global_grad_norm = nn.utils.clip_grad_norm_(
model.parameters(), max_norm=float("inf")
)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
optimizer, step_size=10, gamma=0.1
)
fp32_params = [
param.data
for param in model.parameters()
if param.requires_grad
]
norm = torch.norm(fp32_params[0])
# for name, param in model.named_parameters():
# if param.requires_grad:
# fp32_weights = param.data
# fp32_norm = torch.norm(fp32_weights)
# print(f"Norm of {name}: {fp32_norm.item()}")
current_lr = lr_scheduler.get_last_lr()[0]
global_grad_norm = (
float(global_grad_norm) if global_grad_norm is not None else None
)
weight_norm = float(
model.optimizer.single_partition_of_fp32_groups[0].norm()
)

weight_norm = float(norm)

metric_logger.log_sync(
{
"epoch": epoch,
"step": global_step,
"rank": torch.distributed.get_rank(),
"rank": rank,
"loss": loss.item(),
"overall_throughput": overall_throughput,
"lr": current_lr,
Expand Down Expand Up @@ -470,7 +532,8 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger):
global_step += 1
if local_rank == 0:
inner_pb.update(1)
torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()

if args.checkpoint_at_epoch:
save_hf_format_ds(
Expand Down Expand Up @@ -507,13 +570,28 @@ def main(args):
# device = torch.device("cuda", args.local_rank)

#### distributed init #####
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
args.local_rank = int(os.environ["LOCAL_RANK"])
deepspeed.init_distributed(timeout=timedelta(minutes=30))
args.global_rank = torch.distributed.get_rank()
tensor = torch.ByteTensor([False]).cuda()
torch.distributed.all_reduce(tensor)
torch.distributed.barrier()
world_size = 1
device = None
if not torch.cuda.is_available():
if (
args.device == "mps"
and torch.backends.mps.is_available()
and torch.backend.mps.is_built()
):
device = torch.device("mps")
else:
device = torch.device("cpu")
args.local_rank = 0
args.global_rank = 0
elif torch.distributed.is_available():
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
args.local_rank = int(os.environ["LOCAL_RANK"])
deepspeed.init_distributed(timeout=timedelta(minutes=10))
args.global_rank = torch.distributed.get_rank()
tensor = torch.ByteTensor([False]).cuda()
torch.distributed.all_reduce(tensor)
torch.distributed.barrier()

dataset = setup_dataset(
args.data_path,
Expand All @@ -523,7 +601,7 @@ def main(args):

try:
packing_max_batch_len, grad_accum = find_packing_max_batch_len_and_grad_accum(
num_gpus=torch.distributed.get_world_size(),
num_gpus=world_size,
avg_sample_len=dataset.get_lengths().mean(),
effective_batch_size=args.effective_batch_size,
max_batch_len_per_gpu=args.max_batch_len,
Expand All @@ -542,9 +620,7 @@ def main(args):
grad_accum = 1
args.sampler = "distributed"

args.samples_per_gpu = (
args.effective_batch_size // grad_accum // torch.distributed.get_world_size()
)
args.samples_per_gpu = args.effective_batch_size // grad_accum // world_size

train_loader = setup_dataloader(
dataset,
Expand Down Expand Up @@ -580,7 +656,7 @@ def main(args):
if args.local_rank == 0:
metric_logger.log_sync(
{
"num_gpus": torch.distributed.get_world_size(),
"num_gpus": world_size,
"avg_sample_len": dataset.get_lengths().mean(),
"effective_batch_size": args.effective_batch_size,
"max_batch_len_per_gpu": args.max_batch_len,
Expand All @@ -592,13 +668,24 @@ def main(args):
}
)

model = setup_model(args, tokenizer, train_loader, grad_accum)
model = maybe_resume_training(args, model)

train(args, model, tokenizer, train_loader, grad_accum, metric_logger)
model, optimizer = setup_model(args, tokenizer, train_loader, grad_accum, device)
if device.type == "cuda":
model = maybe_resume_training(args, model)

train(
args,
model,
tokenizer,
train_loader,
grad_accum,
metric_logger,
device,
optimizer,
)

torch.distributed.barrier()
torch.distributed.destroy_process_group()
if torch.cuda.is_available() and torch.distributed.is_available():
torch.distributed.barrier()
torch.distributed.destroy_process_group()


# public API
Expand Down Expand Up @@ -705,6 +792,9 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
if train_args.deepspeed_options.cpu_offload_optimizer_pin_memory:
command.append("--cpu_offload_optimizer_pin_memory")

if torch_args.nproc_per_node == 1:
command.append("--standalone")

print(f"\033[92mRunning command: {' '.join(command)}\033[0m")
process = None
try:
Expand Down Expand Up @@ -831,6 +921,8 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
),
)
parser.add_argument("--disable_flash_attn", action="store_true")
parser.add_argument("--standalone", action="store_true")
parser.add_argument("--device", type=str, default="cuda")
args = parser.parse_args()
set_random_seed(args.seed)
main(args)
Expand Down
10 changes: 8 additions & 2 deletions src/instructlab/training/multipack_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from typing import List, Optional

# Third Party
import torch.distributed
from torch.utils.data import Sampler
import numba
import numpy as np
Expand Down Expand Up @@ -67,11 +68,16 @@ def get_effective_samples_per_minibatch(num_tokens_per_gpu):
The function creates a sampler using the MultipackDistributedBatchSampler class, generates batches using the sampler, and then returns the ratio of the dataset size to the number of batches.
"""
num_replicas = 1
rank = 0
if torch.distributed.is_initialized():
num_replicas = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
sampler = MultipackDistributedBatchSampler(
batch_max_length=num_tokens_per_gpu,
lengths=dataset.get_lengths(),
num_replicas=torch.distributed.get_world_size(),
rank=torch.distributed.get_rank(),
num_replicas=num_replicas,
rank=rank,
seed=seed,
padding=True,
)
Expand Down
Loading

0 comments on commit ef8bbe5

Please sign in to comment.