Add multiple evaluation compat #336

Open
wants to merge 13 commits into t0loading
59 changes: 45 additions & 14 deletions finetune_t0.py
@@ -2,7 +2,9 @@

import torch

from pretrain_gpt import get_batch_pipe as get_batch_pipe_gpt
from megatron import get_args, get_tokenizer, print_rank_0, mpu
from megatron.data.gpt_dataset import build_dataset_group as build_dataset_group_gpt
from megatron.data.decoder_packed_mtf_dataset import build_train_valid_test_datasets, build_dataset_group
from megatron.enums import PositionEmbeddingType, AttnMaskType
from megatron.model import GPTModelPipe
@@ -48,6 +50,14 @@ def model_provider(pre_process=True, post_process=True):
return model


def fast_normalize(loss_mask: torch.Tensor):
"""
Turn loss_mask from [0,0,0,1,1,0,0,1,0,0,1,1,1] into [0,0,0,0.5,0.5,0,0,1,0,0,1/3,1/3,1/3]
"""
_, inverse_indices, counts = torch.unique_consecutive(loss_mask, return_inverse=True, return_counts=True)
counts = torch.gather(dim=0, index=inverse_indices, input=counts)
return loss_mask / counts
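# A quick sketch of what fast_normalize does, using the docstring's example values:
#   >>> mask = torch.tensor([0., 0., 0., 1., 1., 0., 0., 1., 0., 0., 1., 1., 1.])
#   >>> fast_normalize(mask)
#   tensor([0.0000, 0.0000, 0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 1.0000,
#           0.0000, 0.0000, 0.3333, 0.3333, 0.3333])
# Each consecutive run of target tokens now sums to 1, so loss_mask.sum() counts
# target sequences rather than target tokens. Note that two targets packed
# back-to-back with no non-target token between them are treated as a single run.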

def get_batch_pipe(data):
"""
Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator`, in a packed fashion
@@ -57,6 +67,9 @@ def get_batch_pipe(data):
decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]
decoder_is_inputs = [[1, 1, 0, 1, 1, 0, 0]]
"""
if 'text' in data:
return get_batch_pipe_gpt(data)

args = get_args()
tokenizer = get_tokenizer()

@@ -95,6 +108,10 @@ def get_batch_pipe(data):
segment_ids=segment_ids.long(),
)
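# The block below flattens the loss mask and rescales each packed target span to
# sum to 1 (see fast_normalize above), so the loss can later be averaged per
# target sequence.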

if args.norm_target_loss:
loss_mask = loss_mask.view(-1)
loss_mask = fast_normalize(loss_mask)

if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]:
raise NotImplementedError("absolute positional embeddings require us to reset position_ids accordingly.")

@@ -142,20 +159,34 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
eval(f"args.{s}_weighted_split_splits"),
eval(f"args.{s}_weighted_split_names"))
for paths, weights, splits, name in data_groups:
d = build_dataset_group(
dataset_group_name=name,
paths=paths,
weights=weights,
splits=splits,
data_impl=args.data_impl,
train_valid_test_num_samples=train_val_test_num_samples,
seq_length=args.seq_length + 1,
pad_token=tokenizer.pad,
eos_token=tokenizer.eos,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
train_valid_test=s
)
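# Data groups whose first path contains the hard-coded "merged-meg-ds_v3_pii"
# marker are loaded with the plain GPT dataset builder; all other groups use the
# packed multi-task (decoder_packed_mtf) builder.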
if "merged-meg-ds_v3_pii" in paths[0]:
d = build_dataset_group_gpt(
dataset_group_name=name,
paths=paths,
weights=weights,
splits=splits,
data_impl=args.data_impl,
train_valid_test_num_samples=train_val_test_num_samples,
seq_length=args.seq_length,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
train_valid_test=s
)
else:
d = build_dataset_group(
dataset_group_name=name,
paths=paths,
weights=weights,
splits=splits,
data_impl=args.data_impl,
train_valid_test_num_samples=train_val_test_num_samples,
seq_length=args.seq_length + 1,
pad_token=tokenizer.pad,
eos_token=tokenizer.eos,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
train_valid_test=s
)
eval(f"{s}_ds").append(d)
else:
raise NotImplementedError("No dataloading argument passed")
2 changes: 2 additions & 0 deletions megatron/arguments.py
@@ -930,6 +930,8 @@ def __call__(self, parser, args, values, option_string=None):
help='Mask loss for the end of document tokens.')
group.add_argument('--loss-on-targets-only', action='store_true',
help='Mask loss on input sequence.')
group.add_argument('--norm-target-loss', action='store_true',
help='Normalize the loss per target. Used for multi-task finetuning with packing.')
group.add_argument('--reweight-loss-based-on-position-frequency', action="store_true",
help='Some objectives require us to sample loss_mask. This might introduce bias towards '
'specific positions. This option tries to un-bias the loss by reweighting loss on specific '
5 changes: 4 additions & 1 deletion megatron/model/gpt_model.py
@@ -15,7 +15,6 @@

"""GPT-2 model."""

from functools import partial
import torch

from megatron import get_args
@@ -186,6 +185,10 @@ def CrossEntropy(output, labels):
else:
average_tokens_per_sample = sequence_length
expected_number_of_tokens = average_tokens_per_sample * micro_batch_size
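# With --norm-target-loss, the loss mask arrives flattened (dim() == 1) and
# already rescaled by fast_normalize in finetune_t0.py, so each target span sums
# to 1 and loss_mask.sum() counts target sequences; the branch below therefore
# averages the loss per target rather than per token.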
elif args.norm_target_loss and (loss_mask.dim() == 1):
expected_num_of_target_seqs = loss_mask.sum()
loss = torch.sum(losses.view(-1) * loss_mask) / expected_num_of_target_seqs
return loss
else:
expected_number_of_tokens = loss_mask.sum()

5 changes: 3 additions & 2 deletions megatron/training.py
@@ -183,7 +183,7 @@ def pretrain(train_valid_test_dataset_provider,
timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
print_rank_0('training ...')

iteration = 0
iteration = args.iteration
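# args.iteration is the step restored from a loaded checkpoint (0 for a fresh
# run); starting from it means evaluation-only runs keep the loaded step and,
# per the check below, do not re-save an unchanged checkpoint.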
if args.do_train and args.train_iters > 0:
iteration = train(forward_step_func,
model, optimizer, lr_scheduler,
@@ -199,7 +199,8 @@ def pretrain(train_valid_test_dataset_provider,
iterator, model,
iteration, False, data_group_name=name)

if args.save and iteration != 0:
# Do not save if the iteration has not changed
if args.save and iteration != args.iteration:
save_checkpoint(iteration, model, optimizer, lr_scheduler)

if args.do_test: