skip remainder batch (facebookresearch#2464)
Summary: Pull Request resolved: fairinternal/fairseq-py#2464

Reviewed By: myleott

Differential Revision: D31742871

Pulled By: sshleifer

fbshipit-source-id: e5d29ca9d594abd92212eb24b60c991f2840a4e8
sshleifer authored and facebook-github-bot committed Nov 24, 2021
1 parent 7f5ec30 commit fb64e43
Showing 11 changed files with 285 additions and 160 deletions.
29 changes: 19 additions & 10 deletions examples/MMPT/mmpt/tasks/fairseqmmtask.py
@@ -25,9 +25,7 @@ def add_args(parser):
parser.add_argument(
"taskconfig",
metavar="FILE",
help=(
"taskconfig to load all configurations"
"outside fairseq parser."),
help=("taskconfig to load all configurations" "outside fairseq parser."),
)

@classmethod
@@ -68,23 +66,34 @@ def get_batch_iterator(
epoch=1,
data_buffer_size=0,
disable_iterator_cache=False,
skip_remainder_batch=False,
grouped_shuffling=False,
update_epoch_batch_itr=False,
):
random.seed(epoch)
if dataset.mmdataset.split == "train" \
and isinstance(self.mmtask, RetriTask):
if dataset.mmdataset.split == "train" and isinstance(self.mmtask, RetriTask):
if epoch >= self.mmtask.config.retri_epoch:
if not hasattr(self.mmtask, "retri_dataloader"):
self.mmtask.build_dataloader()
self.mmtask.retrive_candidates(epoch)

return super().get_batch_iterator(
dataset, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs, required_batch_size_multiple,
seed, num_shards, shard_id, num_workers, epoch,
data_buffer_size, disable_iterator_cache,
grouped_shuffling, update_epoch_batch_itr)
dataset,
max_tokens,
max_sentences,
max_positions,
ignore_invalid_inputs,
required_batch_size_multiple,
seed,
num_shards,
shard_id,
num_workers,
epoch,
data_buffer_size,
disable_iterator_cache,
grouped_shuffling,
update_epoch_batch_itr,
)

@property
def source_dictionary(self):
1 change: 1 addition & 0 deletions examples/laser/laser_src/laser_task.py
@@ -284,6 +284,7 @@ def get_batch_iterator(
disable_iterator_cache=False,
grouped_shuffling=False,
update_epoch_batch_itr=False,
**kwargs,
):

assert isinstance(dataset, OrderedDict)
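Note: the one-line laser change above shows the compatibility pattern the rest of this commit depends on. A task that overrides get_batch_iterator must either name the new skip_remainder_batch argument explicitly or absorb unknown keywords with **kwargs. A minimal sketch of the **kwargs style (the subclass name is hypothetical and not part of this commit):

from fairseq.tasks import LegacyFairseqTask


class MyTask(LegacyFairseqTask):  # hypothetical subclass, for illustration only
    def get_batch_iterator(self, dataset, max_tokens=None, max_sentences=None, **kwargs):
        # Absorb newer trainer keywords (skip_remainder_batch, grouped_shuffling,
        # update_epoch_batch_itr, ...) and forward them unchanged, so this override
        # keeps working as the base signature grows.
        return super().get_batch_iterator(
            dataset, max_tokens=max_tokens, max_sentences=max_sentences, **kwargs
        )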
23 changes: 12 additions & 11 deletions examples/speech_text_joint_to_text/tasks/speech_text_joint.py
@@ -21,7 +21,10 @@
LangPairMaskDataset,
ModalityDatasetItem,
)
from fairseq.data.audio.speech_to_text_dataset import SpeechToTextDataset, SpeechToTextDatasetCreator
from fairseq.data.audio.speech_to_text_dataset import (
SpeechToTextDataset,
SpeechToTextDatasetCreator,
)
from fairseq.data.audio.speech_to_text_joint_dataset import (
S2TJointDataConfig,
SpeechToTextJointDatasetCreator,
@@ -89,9 +92,7 @@ def add_args(cls, parser):
help="use mixed data in one update when update-freq > 1",
)
parser.add_argument(
"--load-speech-only",
action="store_true",
help="load speech data only",
"--load-speech-only", action="store_true", help="load speech data only",
)
parser.add_argument(
"--mask-text-ratio",
@@ -160,7 +161,9 @@ def setup_task(cls, args, **kwargs):
assert infer_tgt_lang_id != tgt_dict.unk()
return cls(args, src_dict, tgt_dict, infer_tgt_lang_id=infer_tgt_lang_id)

def load_langpair_dataset(self, prepend_tgt_lang_tag=False, sampling_alpha=1.0, epoch=0):
def load_langpair_dataset(
self, prepend_tgt_lang_tag=False, sampling_alpha=1.0, epoch=0
):
lang_pairs = []
text_dataset = None
split = "train"
@@ -200,9 +203,7 @@ def load_langpair_dataset(self, prepend_tgt_lang_tag=False, sampling_alpha=1.0,
alpha=sampling_alpha,
)
lang_pairs = [
ResamplingDataset(
d, size_ratio=r, epoch=epoch, replace=(r >= 1.0)
)
ResamplingDataset(d, size_ratio=r, epoch=epoch, replace=(r >= 1.0))
for d, r in zip(lang_pairs, size_ratios)
]
return ConcatDataset(lang_pairs)
@@ -257,9 +258,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
text_dataset = None
if self.args.parallel_text_data != "" and is_train_split:
text_dataset = self.load_langpair_dataset(
self.data_cfg.prepend_tgt_lang_tag_no_change,
1.0,
epoch=epoch,
self.data_cfg.prepend_tgt_lang_tag_no_change, 1.0, epoch=epoch,
)
if self.args.mask_text_ratio > 0:
# add mask
@@ -326,6 +325,7 @@ def get_batch_iterator(
epoch=0,
data_buffer_size=0,
disable_iterator_cache=False,
skip_remainder_batch=False,
grouped_shuffling=False,
update_epoch_batch_itr=False,
):
@@ -345,6 +345,7 @@ def get_batch_iterator(
epoch,
data_buffer_size,
disable_iterator_cache,
skip_remainder_batch=skip_remainder_batch,
update_epoch_batch_itr=update_epoch_batch_itr,
)

18 changes: 11 additions & 7 deletions examples/truncated_bptt/truncated_bptt_lm_task.py
@@ -29,8 +29,7 @@
class TruncatedBPTTLMConfig(FairseqDataclass):
data: str = field(default="???", metadata={"help": "path to data directory"})
tokens_per_sample: int = field(
default=1024,
metadata={"help": "max number of tokens per sequence"},
default=1024, metadata={"help": "max number of tokens per sequence"},
)
batch_size: int = II("dataset.batch_size")
# Some models use *max_target_positions* to know how many positional
@@ -103,7 +102,13 @@ def dataset(self, split):
return self.datasets[split]

def get_batch_iterator(
self, dataset, num_workers=0, epoch=1, data_buffer_size=0, **kwargs
self,
dataset,
num_workers=0,
epoch=1,
data_buffer_size=0,
skip_remainder_batch=False,
**kwargs
):
return iterators.EpochBatchIterator(
dataset=dataset,
@@ -115,6 +120,7 @@ def get_batch_iterator(
# instead every item in *dataset* is a whole batch
batch_sampler=[[i] for i in range(len(dataset))],
disable_shuffling=True,
skip_remainder_batch=skip_remainder_batch,
)

def _collate_fn(self, items: List[List[torch.Tensor]]):
@@ -134,10 +140,8 @@ def _collate_fn(self, items: List[List[torch.Tensor]]):

# fairseq expects batches to have the following structure
return {
"id": torch.tensor([id]*item.size(0)),
"net_input": {
"src_tokens": item,
},
"id": torch.tensor([id] * item.size(0)),
"net_input": {"src_tokens": item,},
"target": target,
"nsentences": item.size(0),
"ntokens": item.numel(),
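Note: the truncated BPTT task now threads skip_remainder_batch straight into iterators.EpochBatchIterator, which is documented in the fairseq/data/iterators.py changes below. A minimal sketch of the effect at the iterator level, using an invented toy dataset and batch layout:

import torch
from fairseq.data import iterators


class ToyDataset(torch.utils.data.Dataset):  # invented for illustration
    def __getitem__(self, index):
        return torch.tensor([index])

    def __len__(self):
        return 10


# 10 items in batches of 3 leave a trailing 1-item batch.
batch_sampler = [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
epoch_itr = iterators.EpochBatchIterator(
    dataset=ToyDataset(),
    collate_fn=torch.stack,        # stack the per-item tensors into one batch tensor
    batch_sampler=batch_sampler,
    skip_remainder_batch=True,     # drop the final batch of each epoch
)
batches = list(epoch_itr.next_epoch_itr(shuffle=False))
print(len(batches))  # expected: 3, the remainder batch [9] is skipped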
62 changes: 47 additions & 15 deletions fairseq/data/iterators.py
@@ -41,7 +41,7 @@ class CountingIterator(object):
def __init__(self, iterable, start=None, total=None):
self._itr = iter(iterable)
self.n = start or getattr(iterable, "n", 0)
self.total = total or self.n + len(iterable)
self.total = total if total is not None else self.n + len(iterable)

def __len__(self):
return self.total
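Note on the CountingIterator change above: `total or ...` treats an explicit total=0 as missing, because 0 is falsy, whereas `total if total is not None else ...` respects it. This matters once a computed total can legitimately be 0 (for example the floor-division count used by GroupedIterator further down when there are fewer batches than the chunk size). A plain-Python sketch of the difference:

def old_total(total, fallback):
    return total or fallback                         # total=0 silently becomes the fallback

def new_total(total, fallback):
    return total if total is not None else fallback  # total=0 is kept

print(old_total(0, 10))  # 10 -- ignores the caller's explicit 0
print(new_total(0, 10))  # 0  -- respects it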
@@ -265,6 +265,9 @@ class EpochBatchIterator(EpochBatchIterating):
from workers. Should always be non-negative (default: ``0``).
disable_shuffling (bool, optional): force disable shuffling
(default: ``False``).
skip_remainder_batch (bool, optional): if set, discard the last batch in an epoch
for the sake of training stability, as the last batch is usually smaller than
local_batch_size * distributed_world_size (default: ``False``).
grouped_shuffling (bool, optional): enable shuffling batches in groups
of num_shards. Ensures that each GPU receives similar length sequences when
batches are sorted by length.
@@ -283,6 +286,7 @@ def __init__(
buffer_size=0,
timeout=0,
disable_shuffling=False,
skip_remainder_batch=False,
grouped_shuffling=False,
):
assert isinstance(dataset, torch.utils.data.Dataset)
@@ -301,6 +305,7 @@
self.buffer_size = min(buffer_size, 20)
self.timeout = timeout
self.disable_shuffling = disable_shuffling
self.skip_remainder_batch = skip_remainder_batch
self.grouped_shuffling = grouped_shuffling

self.epoch = max(epoch, 1) # we use 1-based indexing for epochs
@@ -375,9 +380,7 @@ def next_epoch_itr(
# reset _frozen_batches to refresh the next epoch
self._frozen_batches = None
self._cur_epoch_itr = self._get_iterator_for_epoch(
self.epoch,
shuffle,
fix_batches_to_gpus=fix_batches_to_gpus,
self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus,
)
self.shuffle = shuffle
return self._cur_epoch_itr
@@ -418,9 +421,7 @@ def load_state_dict(self, state_dict):
if itr_pos > 0:
# fast-forward epoch iterator
self._next_epoch_itr = self._get_iterator_for_epoch(
self.epoch,
shuffle=state_dict.get("shuffle", True),
offset=itr_pos,
self.epoch, shuffle=state_dict.get("shuffle", True), offset=itr_pos,
)
if self._next_epoch_itr is None:
if version == 1:
@@ -497,6 +498,14 @@ def shuffle_batches(batches, seed):

# Wrap with CountingIterator
itr = CountingIterator(itr, start=offset)

if self.skip_remainder_batch:
# TODO: Below is a lazy implementation which discards the final batch regardless
# of whether it is a full batch or not.
total_num_itrs = len(batches) - 1
itr.take(total_num_itrs)
logger.info(f"skip final residual batch, total_num_itrs = {total_num_itrs}")

return itr


@@ -506,29 +515,47 @@ class GroupedIterator(CountingIterator):
Args:
iterable (iterable): iterable to wrap
chunk_size (int): size of each chunk
skip_remainder_batch (bool, optional): if set, discard the last grouped batch in
each training epoch, as the last grouped batch is usually smaller than
local_batch_size * distributed_world_size * chunk_size (default: ``False``).
Attributes:
n (int): number of elements consumed from this iterator
"""

def __init__(self, iterable, chunk_size):
itr = _chunk_iterator(iterable, chunk_size)
def __init__(self, iterable, chunk_size, skip_remainder_batch=False):
if skip_remainder_batch:
total_num_itrs = int(math.floor(len(iterable) / float(chunk_size)))
logger.info(
f"skip final residual batch, grouped total_num_itrs = {total_num_itrs}"
)
else:
total_num_itrs = int(math.ceil(len(iterable) / float(chunk_size)))
logger.info(f"grouped total_num_itrs = {total_num_itrs}")

itr = _chunk_iterator(iterable, chunk_size, skip_remainder_batch)
super().__init__(
itr,
start=int(math.ceil(getattr(iterable, "n", 0) / float(chunk_size))),
total=int(math.ceil(len(iterable) / float(chunk_size))),
total=total_num_itrs,
)
self.chunk_size = chunk_size

if skip_remainder_batch:
self.take(total_num_itrs)
# TODO: [Hack] Here the grouped iterator modifies the base iterator size so that
# training can move into the next epoch once the grouped iterator is exhausted.
# Double-check this implementation in case unexpected behavior occurs.
iterable.take(total_num_itrs * chunk_size)

def _chunk_iterator(itr, chunk_size):

def _chunk_iterator(itr, chunk_size, skip_remainder_batch=False):
chunk = []
for x in itr:
chunk.append(x)
if len(chunk) == chunk_size:
yield chunk
chunk = []
if len(chunk) > 0:
if not skip_remainder_batch and len(chunk) > 0:
yield chunk


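Note: with skip_remainder_batch set, GroupedIterator counts floor(len/chunk_size) groups instead of ceil, and _chunk_iterator stops yielding the trailing partial chunk. A plain-list mirror of the logic above, with arbitrary numbers:

def chunk_batches(batches, chunk_size, skip_remainder_batch=False):
    # mirrors _chunk_iterator above, but collects into a list for illustration
    chunk, out = [], []
    for x in batches:
        chunk.append(x)
        if len(chunk) == chunk_size:
            out.append(chunk)
            chunk = []
    if not skip_remainder_batch and chunk:
        out.append(chunk)
    return out

print(chunk_batches(list(range(7)), 3))                             # [[0, 1, 2], [3, 4, 5], [6]] -> ceil(7/3) = 3
print(chunk_batches(list(range(7)), 3, skip_remainder_batch=True))  # [[0, 1, 2], [3, 4, 5]]      -> floor(7/3) = 2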
@@ -546,7 +573,12 @@ class ShardedIterator(CountingIterator):
n (int): number of elements consumed from this iterator
"""

def __init__(self, iterable, num_shards, shard_id, fill_value=None):
def __init__(
self, iterable, num_shards, shard_id, fill_value=None, skip_remainder_batch=None
):
"""
Args:
skip_remainder_batch: ignored"""
if shard_id < 0 or shard_id >= num_shards:
raise ValueError("shard_id must be between 0 and num_shards")
sharded_len = int(math.ceil(len(iterable) / float(num_shards)))
@@ -611,7 +643,7 @@ def _create_consumer(self):
self._queue,
self._iterable,
self.total,
torch.cuda.current_device() if torch.cuda.is_available() else None
torch.cuda.current_device() if torch.cuda.is_available() else None,
)
self._consumer.daemon = True
self._consumer.start()
Expand Down
[Diff for the remaining 6 changed files not shown.]
