Implementing HF Padding-Free and GraniteLM Support (instructlab#257)
Updating the data collator for models with HF padding-free support, adding support for the upcoming Granite HF model class, and updating the flags/interface accordingly.

------------------------------------------------

* only compute lengths in the token dataset when a `len` column is not already present in the dataset

Signed-off-by: aldo pareja-cardona <[email protected]>

* Refactor padding function to support position_ids for FlashAttention

- Added `supports_flash_attention` function to check GPU compatibility for FlashAttention.
- Updated `make_collate_fn` to return `position_ids` instead of `attention_mask` when FlashAttention is supported (sketched below).
- Integrated the new padding logic into `setup_dataloader` to ensure compatibility with both Granite and non-Granite configurations.
- Ensured backward compatibility by maintaining the original padding logic for GPUs that do not support FlashAttention.
- Updated `main_ds.py` to use the new `supports_flash_attention` check for determining padding strategy.

Signed-off-by: aldo pareja-cardona <[email protected]>
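
For context, a minimal sketch of the two pieces named in the bullets above. The helper names and return shapes are illustrative; the signatures in the repository may differ.

```python
# Illustrative sketch only -- not the exact implementation in this commit.
import torch


def supports_flash_attention(device_id: int = 0) -> bool:
    """Best-effort check that the local GPU can run FlashAttention."""
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability(device_id)
    # NVIDIA: Ampere (compute capability 8.0) or newer is required.
    nvidia_ok = major >= 8
    # AMD: ROCm builds expose gcnArchName; allow known CDNA parts.
    arch = getattr(torch.cuda.get_device_properties(device_id), "gcnArchName", "")
    amd_ok = any(g in arch for g in ("gfx90a", "gfx940", "gfx941", "gfx942"))
    return nvidia_ok or amd_ok


def collate_padding_free(batch) -> dict:
    """Concatenate samples into a single row and emit position_ids instead of
    an attention_mask, as the HF padding-free FlashAttention path expects."""
    input_ids = torch.cat([torch.tensor(s["input_ids"]) for s in batch]).unsqueeze(0)
    labels = torch.cat([torch.tensor(s["labels"]) for s in batch]).unsqueeze(0)
    # position_ids restart at 0 for every packed sample so each one is
    # treated as an independent sequence.
    position_ids = torch.cat(
        [torch.arange(len(s["input_ids"])) for s in batch]
    ).unsqueeze(0)
    return {
        "input_ids": input_ids,
        "labels": labels,
        "position_ids": position_ids,
        "num_loss_counted_tokens": int((labels != -100).sum().item()),
        "num_samples": len(batch),
    }
```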

* logging the global gradnorm now

Signed-off-by: aldo pareja-cardona <[email protected]>

* fixing deepspeed because it's not working with the scheduler we want

Signed-off-by: aldo pareja-cardona <[email protected]>

* fixing accelerate lr_scheduler

Signed-off-by: aldo pareja-cardona <[email protected]>

* fixing accelerate lr_scheduler

Signed-off-by: aldo pareja-cardona <[email protected]>

* samples_seen was broken because the collated samples now form a single row

Signed-off-by: aldo pareja-cardona <[email protected]>

* the packing max-batch-len computation was wrong: when flash attention is supported, padding should not be assumed when building the buckets

Signed-off-by: aldo pareja-cardona <[email protected]>
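
Roughly, the distinction this fix addresses, shown as an illustrative helper (the real bucket-building logic lives in the multipack sampler):

```python
def bucket_token_cost(lengths: list[int], is_padding: bool) -> int:
    """Illustrative: effective token cost of a bucket of samples.

    With padding, every sample is padded up to the longest one in the
    bucket; with the padding-free FlashAttention path only real tokens
    occupy batch capacity, so buckets can hold more samples.
    """
    if not lengths:
        return 0
    if is_padding:
        return max(lengths) * len(lengths)
    return sum(lengths)
```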

* black formatting

Signed-off-by: aldo pareja-cardona <[email protected]>

* it should no longer fail on Granite 8B models

Signed-off-by: aldo pareja-cardona <[email protected]>

* linting

Signed-off-by: aldo pareja-cardona <[email protected]>

* linting

Signed-off-by: aldo pareja-cardona <[email protected]>

* fix a padding bug when creating the multipack sampler

Signed-off-by: aldo pareja-cardona <[email protected]>

* linter

Signed-off-by: aldo pareja-cardona <[email protected]>

* linter

Signed-off-by: aldo pareja-cardona <[email protected]>

* Change old padding-free and granite flags to use_dolomite

Signed-off-by: Mustafa Eyceoz <[email protected]>

* Add safeguards and checks for flash attention when enabled/disabled

Signed-off-by: Mustafa Eyceoz <[email protected]>
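
A plausible shape for that safeguard, reusing the `supports_flash_attention` helper sketched earlier; the actual error messages and ordering may differ.

```python
def check_flash_attn_enabled(disable_flash_attn: bool, use_dolomite: bool) -> bool:
    """Decide whether FlashAttention will be used, failing fast on
    unsupported combinations (sketch)."""
    if disable_flash_attn:
        if use_dolomite:
            raise RuntimeError(
                "Padding-free training with dolomite requires flash attention."
            )
        return False
    # supports_flash_attention() is the hardware check sketched above.
    if not supports_flash_attention():
        raise RuntimeError(
            "Flash attention is enabled but not supported on this hardware; "
            "pass --disable_flash_attn to train without it."
        )
    return True
```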

* Rework flash attention checks for better modularity

Signed-off-by: Mustafa Eyceoz <[email protected]>

* Fix arg name

Signed-off-by: Mustafa Eyceoz <[email protected]>

* Update transformers to a version with Granite model class

Signed-off-by: Mustafa Eyceoz <[email protected]>

* Adding state guards for dolomite and granite, plus a model path check

Signed-off-by: Mustafa Eyceoz <[email protected]>
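
One way the early validation (later moved into utils) might look, as a sketch under the assumption that `TrainingArgs` exposes `model_path`, `use_dolomite`, `disable_flash_attn`, `max_batch_len`, and `max_seq_len`:

```python
import os


def check_valid_train_args(train_args) -> None:
    """Early validation before launching torchrun (illustrative sketch)."""
    if train_args.max_batch_len < train_args.max_seq_len:
        raise ValueError(
            "`max_batch_len` cannot be less than `max_seq_len`: "
            f"{train_args.max_batch_len=} < {train_args.max_seq_len=}"
        )
    # Assumed field name: the path to the base model being fine-tuned.
    if not os.path.exists(train_args.model_path):
        raise FileNotFoundError(
            f"Model path does not appear to exist: {train_args.model_path}"
        )
    if train_args.use_dolomite and train_args.disable_flash_attn:
        raise RuntimeError(
            "Padding-free training with dolomite requires flash attention."
        )
    # The real implementation also guards against mixing dolomite-format
    # checkpoints with the Granite HF model class (not sketched here).
```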

* Missing update

Signed-off-by: Mustafa Eyceoz <[email protected]>

* Clean up early validation checks and move to utils

Signed-off-by: Mustafa Eyceoz <[email protected]>

* Fix spelling mistake

Signed-off-by: Mustafa Eyceoz <[email protected]>

* Include AMD in flash attn check

Signed-off-by: Mustafa Eyceoz <[email protected]>

* Re-add is_padding_free with deprecation warning

Signed-off-by: Mustafa Eyceoz <[email protected]>
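
A hedged sketch of what such a deprecation shim might look like; whether the old flag maps onto any new behaviour is left to the real implementation.

```python
import warnings


def warn_deprecated_padding_free(train_args) -> None:
    """Sketch: accept the legacy flag, but steer users to the new ones."""
    if getattr(train_args, "is_padding_free", False):
        warnings.warn(
            "`is_padding_free` is deprecated; padding-free batching is chosen "
            "automatically when flash attention is available, and "
            "`use_dolomite` selects the dolomite code path.",
            DeprecationWarning,
            stacklevel=2,
        )
```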

* Make use_dolomite default false

Signed-off-by: Mustafa Eyceoz <[email protected]>

* switching to `<|MASK|>` is needed because the plain `<MASK>` tag is too common and some datasets will fail

Signed-off-by: Mustafa Eyceoz <[email protected]>

* added a warning in case the special tokens used for data processing are present in the dataset

Signed-off-by: Mustafa Eyceoz <[email protected]>

* added a warning in case the special tokens used for data processing are present in the dataset

Signed-off-by: Mustafa Eyceoz <[email protected]>
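
A sketch of the kind of check those two commits describe; it assumes the raw samples carry a `messages` list with `content` strings, and the real column names and warning text may differ.

```python
def warn_on_processing_tokens(dataset, num_proc: int = 8) -> None:
    """Sketch: warn when raw samples already contain the temporary tokens
    reserved for data processing, since they can distort masking."""
    reserved = ("<|pretrain|>", "<|/pretrain|>", "<|MASK|>")
    flagged = dataset.filter(
        lambda x: any(
            tok in turn["content"] for turn in x["messages"] for tok in reserved
        ),
        num_proc=num_proc,
    )
    if len(flagged) > 0:
        print(
            f"\033[93mWarning: {len(flagged)} samples contain reserved data-"
            f"processing tokens ({', '.join(reserved)}) and may be masked "
            "or dropped incorrectly.\033[0m"
        )
```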

* Update valid data filter

Signed-off-by: Mustafa Eyceoz <[email protected]>

* Fix ruff formatting

Signed-off-by: Mustafa Eyceoz <[email protected]>

* Apply review feedback

Signed-off-by: Mustafa Eyceoz <[email protected]>

* Added comments

Signed-off-by: Mustafa Eyceoz <[email protected]>

---------

Signed-off-by: aldo pareja-cardona <[email protected]>
Signed-off-by: Mustafa Eyceoz <[email protected]>
Co-authored-by: aldo pareja-cardona <[email protected]>
Co-authored-by: Mustafa Eyceoz <[email protected]>
3 people authored Oct 25, 2024
1 parent ed8d6e2 commit 03d1b62
Showing 6 changed files with 237 additions and 107 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -5,7 +5,7 @@ py-cpuinfo
# we set this to be above 0a0 so that it doesn't
# replace custom pytorch images with the 2.3.0
torch>=2.3.0a0
transformers>=4.41.2
transformers>=4.45.2
accelerate>=0.34.2
datasets>=2.15.0
numba
3 changes: 2 additions & 1 deletion src/instructlab/training/config.py
@@ -171,8 +171,9 @@ class TrainingArgs(BaseModel):
save_samples: int
learning_rate: float
warmup_steps: int
is_padding_free: bool
random_seed: int = 42
use_dolomite: bool = False
is_padding_free: bool = False # TODO: deprecate
checkpoint_at_epoch: bool = True
accelerate_full_state_at_epoch: bool = True

25 changes: 21 additions & 4 deletions src/instructlab/training/data_process.py
@@ -199,7 +199,7 @@ def print_masked_samples(data, tokenizer, is_pretrain, num_proc):
def get_masked_and_orig_text(sample):
labels = sample["labels"]
input_ids = sample["input_ids"]
mask_id = get_sp_token(tokenizer, "<MASK>")[0]
mask_id = get_sp_token(tokenizer, "<|MASK|>")[0]
label = [mask_id if tk == -100 else tk for tk in labels]
text = tokenizer.decode(label)
orig_text = tokenizer.decode(input_ids)
@@ -239,7 +239,7 @@ def main(args: DataProcessArgs):

# Adding after tokenizer setup as these are temp tokens, not to be saved
tokenizer.add_special_tokens(
{"additional_special_tokens": ["<|pretrain|>", "<|/pretrain|>", "<MASK>"]}
{"additional_special_tokens": ["<|pretrain|>", "<|/pretrain|>", "<|MASK|>"]}
)

try:
@@ -347,9 +347,26 @@ def main(args: DataProcessArgs):
)

# extract only labels and messages formatted into a new dataset
data_with_labels = data_with_labels.select_columns(["labels", "input_ids"])
data_with_labels = data_with_labels.map(
lambda x: {
"len": len(x["input_ids"]),
},
num_proc=NUM_PROC,
)
data_with_labels = data_with_labels.select_columns(["labels", "input_ids", "len"])
# MASK and both pretrain tokens should not be in the final tokens, those are special tokens added only for data processing purposes.
max_id = len(tokenizer) - 3
final_valid_data = data_with_labels.filter(
lambda x: all(tk < max_id for tk in x["labels"]), num_proc=NUM_PROC
)
# Dropping samples that could break training due to oob ids
if len(final_valid_data) < len(data_with_labels):
dropped_samples = len(data_with_labels) - len(final_valid_data)
print(
f"\033[93mWarning: {dropped_samples} samples were dropped because they contained token IDs greater than or equal to {max_id}.\033[0m"
)
# use path to get the stem of the file
data_with_labels.to_json(Path(args.data_output_path) / f"data.jsonl")
final_valid_data.to_json(Path(args.data_output_path) / "data.jsonl")


if __name__ == "__main__":
64 changes: 31 additions & 33 deletions src/instructlab/training/main_ds.py
@@ -41,8 +41,10 @@
StreamablePopen,
add_noisy_embeddings,
apply_gradient_checkpointing,
check_flash_attn_enabled,
check_valid_train_args,
convert_loss_to_reduce_sum,
ensure_loadable_granite_checkpoint,
ensure_loadable_dolomite_checkpoint,
get_projection_layer_names,
load_latest_full_state,
prepare_peft_model,
@@ -84,7 +86,7 @@ def setup_optimizer(args, model):
return optimizer


def setup_model(args, tokenizer, train_loader, grad_accum):
def setup_model(args, tokenizer, train_loader, grad_accum, flash_enabled):
bnb_config = None
if args.lora_r > 0 and args.lora_quant_bits == 4:
# Third Party
@@ -102,15 +104,11 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
"torch_dtype": torch.bfloat16,
"quantization_config": bnb_config,
}
if not args.disable_flash_attn:
if flash_enabled:
base_model_args["attn_implementation"] = "flash_attention_2"
elif args.is_granite:
raise RuntimeError(
"ERROR: Trying to use padding-free transformer without flash attention is not supported"
)

if args.is_granite:
with ensure_loadable_granite_checkpoint(
if args.use_dolomite:
with ensure_loadable_dolomite_checkpoint(
args.model_name_or_path, args.output_dir
) as path:
base_model_args["pretrained_model_name_or_path"] = path
@@ -165,9 +163,10 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
"Starcoder2ForCausalLM",
"GemmaForCausalLM",
"MixtralForCausalLM",
"GraniteForCausalLM",
], f"Model class name: {model.__class__.__name__} is not supported."

model = convert_loss_to_reduce_sum(model, is_granite=args.is_granite)
model = convert_loss_to_reduce_sum(model, use_dolomite=args.use_dolomite)
model = add_noisy_embeddings(model, noise_alpha=args.NEFTune_alpha)

# handling of gradient checkpointing
@@ -212,15 +211,15 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
target_modules=args.lora_target_modules,
)
model = prepare_peft_model(
model, peft_config, gradient_checkpointing=not args.is_granite
model, peft_config, gradient_checkpointing=not args.use_dolomite
)

elif not args.is_granite:
elif not args.use_dolomite:
model.gradient_checkpointing_enable()

# granite gradient checkpointing is handled uniformly
# for both lora and full here
if args.is_granite:
if args.use_dolomite:
block_name = model._no_split_modules[0]
apply_gradient_checkpointing(
model,
@@ -252,6 +251,9 @@ def make_inputs_require_grad(module, input, output):
deepcopy(train_loader),
lr_scheduler,
)
# Necessary so that Accelerate does not step once per GPU
# see https://github.com/huggingface/accelerate/blob/127818fc27ebe5cb236357fff59ff1748326d643/src/accelerate/scheduler.py#L69
lr_scheduler.split_batches = True
return model, lr_scheduler, optimizer, accelerator


@@ -381,8 +383,8 @@ def train(
num_loss_counted_tokens = float(
torch.tensor([batch.pop("num_loss_counted_tokens")])
)
micro_batch_size = float(len(batch["input_ids"]))
if not args.is_granite:
micro_batch_size = float(torch.tensor([batch.pop("num_samples")]))
if not args.use_dolomite:
for k in batch:
batch[k] = batch[k].to(local_rank)
output = model(
@@ -453,7 +455,7 @@ def train(
"batch_size": int(micro_batch_size),
"total_loss": float(log_loss / num_loss_counted_tokens),
"samples_seen": samples_seen,
# "gradnorm": global_grad_norm,
"gradnorm": global_grad_norm,
# "weight_norm": weight_norm,
}
)
@@ -535,6 +537,8 @@ def main(args):
torch.distributed.all_reduce(tensor)
torch.distributed.barrier()

flash_enabled = check_flash_attn_enabled(args.disable_flash_attn, args.use_dolomite)

dataset = setup_dataset(
args.data_path,
mock=args.mock_data,
@@ -547,7 +551,7 @@
avg_sample_len=dataset.get_lengths().mean(),
effective_batch_size=args.effective_batch_size,
max_batch_len_per_gpu=args.max_batch_len,
is_padding=not args.is_granite,
is_padding=not (args.use_dolomite or flash_enabled),
dataset=dataset,
seed=args.seed,
)
@@ -570,7 +574,8 @@
dataset,
tokenizer.pad_token_id,
num_workers=8,
is_granite=args.is_granite,
use_dolomite=args.use_dolomite,
flash_enabled=flash_enabled,
max_batch_len=args.max_batch_len,
packing_max_batch_len=packing_max_batch_len,
samples_per_gpu=args.samples_per_gpu,
@@ -589,7 +594,8 @@
dataset,
tokenizer.pad_token_id,
num_workers=8,
is_granite=args.is_granite,
use_dolomite=args.use_dolomite,
flash_enabled=flash_enabled,
max_batch_len=args.max_batch_len,
packing_max_batch_len=packing_max_batch_len,
samples_per_gpu=args.samples_per_gpu,
@@ -613,7 +619,7 @@
)

model, lr_scheduler, optimizer, accelerator = setup_model(
args, tokenizer, train_loader, grad_accum
args, tokenizer, train_loader, grad_accum, flash_enabled
)

load_latest_full_state(args=args, accelerator=accelerator)
@@ -639,11 +645,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
"""
Wrapper around the main training job that calls torchrun.
"""
# early validation logic here
if train_args.max_batch_len < train_args.max_seq_len:
raise ValueError(
f"the `max_batch_len` cannot be less than `max_seq_len`: {train_args.max_batch_len=} < {train_args.max_seq_len=}"
)
check_valid_train_args(train_args)

if train_args.process_data:
dp.main(
@@ -697,14 +699,10 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
if train_args.mock_len:
command.append(f"--mock_len={train_args.mock_len}")

if train_args.is_padding_free:
command.append("--is_granite")
if train_args.use_dolomite:
command.append("--use_dolomite")

if train_args.disable_flash_attn:
if train_args.is_padding_free:
raise RuntimeError(
"ERROR: Trying to use padding-free transformer without flash attention is not supported"
)
command.append("--disable_flash_attn")

if train_args.lora:
@@ -888,7 +886,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
default="SHARD_GRAD_OP",
help="Sharding strategy to be used for FSDP distributed training.",
)
parser.add_argument("--is_granite", action="store_true")
parser.add_argument("--use_dolomite", action="store_true")
parser.add_argument("--lora_r", type=int, default=0) # set to > 0 to activate lora
parser.add_argument("--lora_alpha", type=int, default=32)
parser.add_argument("--lora_dropout", type=float, default=0.1)
@@ -977,7 +975,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
--save_samples=250000 \
--log_level="INFO" \
--fsdp_sharding_strategy="SHARD_GRAD_OP" \
--is_granite \
--use_dolomite \
--max_batch_len 70000 \
--seed=42
"""
25 changes: 16 additions & 9 deletions src/instructlab/training/token_dataset.py
@@ -17,12 +17,15 @@
class TokenDataset(Dataset):
def __init__(self, data_path):
self.data = load_dataset("json", data_files=data_path, split="train")
self.lengths = np.array(
self.data.map(
lambda x: {"len": len(x["input_ids"])},
num_proc=8,
)["len"]
)
if "len" not in self.data.column_names:
self.lengths = np.array(
self.data.map(
lambda x: {"len": len(x["input_ids"])},
num_proc=8,
)["len"]
)
else:
self.lengths = np.array(self.data["len"])

def __len__(self):
return len(self.data)
@@ -87,15 +90,19 @@ def setup_dataloader(
dataset: Dataset,
pad_token_id: int,
num_workers: int = 8,
is_granite=False,
use_dolomite=False,
flash_enabled=True,
max_batch_len=60000,
packing_max_batch_len=60000,
samples_per_gpu=None,
sampler="multipack",
seed=47,
) -> DataLoader:
collate_fn = make_collate_fn(
pad_token_id, is_granite=is_granite, max_batch_len=max_batch_len
pad_token_id,
use_dolomite=use_dolomite,
flash_enabled=flash_enabled,
max_batch_len=max_batch_len,
)
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
@@ -108,7 +115,7 @@
num_replicas=world_size,
rank=rank,
seed=seed,
padding=not is_granite,
padding=not flash_enabled,
)
sampler = {"batch_sampler": sampler}
elif sampler == "distributed":
