From fb64e43c67b146cc9f28b35facd59a44c1e6031b Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Wed, 24 Nov 2021 07:49:55 -0800 Subject: [PATCH] skip remainder batch (#2464) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2464 Reviewed By: myleott Differential Revision: D31742871 Pulled By: sshleifer fbshipit-source-id: e5d29ca9d594abd92212eb24b60c991f2840a4e8 --- examples/MMPT/mmpt/tasks/fairseqmmtask.py | 29 +++-- examples/laser/laser_src/laser_task.py | 1 + .../tasks/speech_text_joint.py | 23 ++-- .../truncated_bptt/truncated_bptt_lm_task.py | 18 +-- fairseq/data/iterators.py | 62 +++++++--- fairseq/dataclass/configs.py | 54 +++++---- fairseq/tasks/fairseq_task.py | 15 ++- .../tasks/translation_multi_simple_epoch.py | 10 +- fairseq/trainer.py | 114 ++++++++---------- fairseq_cli/train.py | 60 ++++++--- tests/test_iterators.py | 59 ++++++++- 11 files changed, 285 insertions(+), 160 deletions(-) diff --git a/examples/MMPT/mmpt/tasks/fairseqmmtask.py b/examples/MMPT/mmpt/tasks/fairseqmmtask.py index 78ef7ba17c..f6b6115a39 100644 --- a/examples/MMPT/mmpt/tasks/fairseqmmtask.py +++ b/examples/MMPT/mmpt/tasks/fairseqmmtask.py @@ -25,9 +25,7 @@ def add_args(parser): parser.add_argument( "taskconfig", metavar="FILE", - help=( - "taskconfig to load all configurations" - "outside fairseq parser."), + help=("taskconfig to load all configurations" "outside fairseq parser."), ) @classmethod @@ -68,23 +66,34 @@ def get_batch_iterator( epoch=1, data_buffer_size=0, disable_iterator_cache=False, + skip_remainder_batch=False, grouped_shuffling=False, update_epoch_batch_itr=False, ): random.seed(epoch) - if dataset.mmdataset.split == "train" \ - and isinstance(self.mmtask, RetriTask): + if dataset.mmdataset.split == "train" and isinstance(self.mmtask, RetriTask): if epoch >= self.mmtask.config.retri_epoch: if not hasattr(self.mmtask, "retri_dataloader"): self.mmtask.build_dataloader() self.mmtask.retrive_candidates(epoch) return super().get_batch_iterator( - dataset, max_tokens, max_sentences, max_positions, - ignore_invalid_inputs, required_batch_size_multiple, - seed, num_shards, shard_id, num_workers, epoch, - data_buffer_size, disable_iterator_cache, - grouped_shuffling, update_epoch_batch_itr) + dataset, + max_tokens, + max_sentences, + max_positions, + ignore_invalid_inputs, + required_batch_size_multiple, + seed, + num_shards, + shard_id, + num_workers, + epoch, + data_buffer_size, + disable_iterator_cache, + grouped_shuffling, + update_epoch_batch_itr, + ) @property def source_dictionary(self): diff --git a/examples/laser/laser_src/laser_task.py b/examples/laser/laser_src/laser_task.py index 43416e0a0d..72d069fe8e 100644 --- a/examples/laser/laser_src/laser_task.py +++ b/examples/laser/laser_src/laser_task.py @@ -284,6 +284,7 @@ def get_batch_iterator( disable_iterator_cache=False, grouped_shuffling=False, update_epoch_batch_itr=False, + **kwargs, ): assert isinstance(dataset, OrderedDict) diff --git a/examples/speech_text_joint_to_text/tasks/speech_text_joint.py b/examples/speech_text_joint_to_text/tasks/speech_text_joint.py index 800ccd782a..cd9aabd583 100644 --- a/examples/speech_text_joint_to_text/tasks/speech_text_joint.py +++ b/examples/speech_text_joint_to_text/tasks/speech_text_joint.py @@ -21,7 +21,10 @@ LangPairMaskDataset, ModalityDatasetItem, ) -from fairseq.data.audio.speech_to_text_dataset import SpeechToTextDataset, SpeechToTextDatasetCreator +from fairseq.data.audio.speech_to_text_dataset import ( + SpeechToTextDataset, + SpeechToTextDatasetCreator, 
+) from fairseq.data.audio.speech_to_text_joint_dataset import ( S2TJointDataConfig, SpeechToTextJointDatasetCreator, @@ -89,9 +92,7 @@ def add_args(cls, parser): help="use mixed data in one update when update-freq > 1", ) parser.add_argument( - "--load-speech-only", - action="store_true", - help="load speech data only", + "--load-speech-only", action="store_true", help="load speech data only", ) parser.add_argument( "--mask-text-ratio", @@ -160,7 +161,9 @@ def setup_task(cls, args, **kwargs): assert infer_tgt_lang_id != tgt_dict.unk() return cls(args, src_dict, tgt_dict, infer_tgt_lang_id=infer_tgt_lang_id) - def load_langpair_dataset(self, prepend_tgt_lang_tag=False, sampling_alpha=1.0, epoch=0): + def load_langpair_dataset( + self, prepend_tgt_lang_tag=False, sampling_alpha=1.0, epoch=0 + ): lang_pairs = [] text_dataset = None split = "train" @@ -200,9 +203,7 @@ def load_langpair_dataset(self, prepend_tgt_lang_tag=False, sampling_alpha=1.0, alpha=sampling_alpha, ) lang_pairs = [ - ResamplingDataset( - d, size_ratio=r, epoch=epoch, replace=(r >= 1.0) - ) + ResamplingDataset(d, size_ratio=r, epoch=epoch, replace=(r >= 1.0)) for d, r in zip(lang_pairs, size_ratios) ] return ConcatDataset(lang_pairs) @@ -257,9 +258,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): text_dataset = None if self.args.parallel_text_data != "" and is_train_split: text_dataset = self.load_langpair_dataset( - self.data_cfg.prepend_tgt_lang_tag_no_change, - 1.0, - epoch=epoch, + self.data_cfg.prepend_tgt_lang_tag_no_change, 1.0, epoch=epoch, ) if self.args.mask_text_ratio > 0: # add mask @@ -326,6 +325,7 @@ def get_batch_iterator( epoch=0, data_buffer_size=0, disable_iterator_cache=False, + skip_remainder_batch=False, grouped_shuffling=False, update_epoch_batch_itr=False, ): @@ -345,6 +345,7 @@ def get_batch_iterator( epoch, data_buffer_size, disable_iterator_cache, + skip_remainder_batch=skip_remainder_batch, update_epoch_batch_itr=update_epoch_batch_itr, ) diff --git a/examples/truncated_bptt/truncated_bptt_lm_task.py b/examples/truncated_bptt/truncated_bptt_lm_task.py index 02be0e7fb4..9978481b6d 100644 --- a/examples/truncated_bptt/truncated_bptt_lm_task.py +++ b/examples/truncated_bptt/truncated_bptt_lm_task.py @@ -29,8 +29,7 @@ class TruncatedBPTTLMConfig(FairseqDataclass): data: str = field(default="???", metadata={"help": "path to data directory"}) tokens_per_sample: int = field( - default=1024, - metadata={"help": "max number of tokens per sequence"}, + default=1024, metadata={"help": "max number of tokens per sequence"}, ) batch_size: int = II("dataset.batch_size") # Some models use *max_target_positions* to know how many positional @@ -103,7 +102,13 @@ def dataset(self, split): return self.datasets[split] def get_batch_iterator( - self, dataset, num_workers=0, epoch=1, data_buffer_size=0, **kwargs + self, + dataset, + num_workers=0, + epoch=1, + data_buffer_size=0, + skip_remainder_batch=False, + **kwargs ): return iterators.EpochBatchIterator( dataset=dataset, @@ -115,6 +120,7 @@ def get_batch_iterator( # instead every item in *dataset* is a whole batch batch_sampler=[[i] for i in range(len(dataset))], disable_shuffling=True, + skip_remainder_batch=skip_remainder_batch, ) def _collate_fn(self, items: List[List[torch.Tensor]]): @@ -134,10 +140,8 @@ def _collate_fn(self, items: List[List[torch.Tensor]]): # fairseq expects batches to have the following structure return { - "id": torch.tensor([id]*item.size(0)), - "net_input": { - "src_tokens": item, - }, + "id": torch.tensor([id] * 
item.size(0)), + "net_input": {"src_tokens": item,}, "target": target, "nsentences": item.size(0), "ntokens": item.numel(), diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index 14b4f83330..81b5f56547 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -41,7 +41,7 @@ class CountingIterator(object): def __init__(self, iterable, start=None, total=None): self._itr = iter(iterable) self.n = start or getattr(iterable, "n", 0) - self.total = total or self.n + len(iterable) + self.total = total if total is not None else self.n + len(iterable) def __len__(self): return self.total @@ -265,6 +265,9 @@ class EpochBatchIterator(EpochBatchIterating): from workers. Should always be non-negative (default: ``0``). disable_shuffling (bool, optional): force disable shuffling (default: ``False``). + skip_remainder_batch (bool, optional): if set, discard the last batch in an epoch + for the sake of training stability, as the last batch is usually smaller than + local_batch_size * distributed_word_size (default: ``False``). grouped_shuffling (bool, optional): enable shuffling batches in groups of num_shards. Ensures that each GPU receives similar length sequences when batches are sorted by length. @@ -283,6 +286,7 @@ def __init__( buffer_size=0, timeout=0, disable_shuffling=False, + skip_remainder_batch=False, grouped_shuffling=False, ): assert isinstance(dataset, torch.utils.data.Dataset) @@ -301,6 +305,7 @@ def __init__( self.buffer_size = min(buffer_size, 20) self.timeout = timeout self.disable_shuffling = disable_shuffling + self.skip_remainder_batch = skip_remainder_batch self.grouped_shuffling = grouped_shuffling self.epoch = max(epoch, 1) # we use 1-based indexing for epochs @@ -375,9 +380,7 @@ def next_epoch_itr( # reset _frozen_batches to refresh the next epoch self._frozen_batches = None self._cur_epoch_itr = self._get_iterator_for_epoch( - self.epoch, - shuffle, - fix_batches_to_gpus=fix_batches_to_gpus, + self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus, ) self.shuffle = shuffle return self._cur_epoch_itr @@ -418,9 +421,7 @@ def load_state_dict(self, state_dict): if itr_pos > 0: # fast-forward epoch iterator self._next_epoch_itr = self._get_iterator_for_epoch( - self.epoch, - shuffle=state_dict.get("shuffle", True), - offset=itr_pos, + self.epoch, shuffle=state_dict.get("shuffle", True), offset=itr_pos, ) if self._next_epoch_itr is None: if version == 1: @@ -497,6 +498,14 @@ def shuffle_batches(batches, seed): # Wrap with CountingIterator itr = CountingIterator(itr, start=offset) + + if self.skip_remainder_batch: + # TODO: Below is a lazy implementation which discard the final batch regardless + # of whether it is a full batch or not. + total_num_itrs = len(batches) - 1 + itr.take(total_num_itrs) + logger.info(f"skip final residual batch, total_num_itrs = {total_num_itrs}") + return itr @@ -506,29 +515,47 @@ class GroupedIterator(CountingIterator): Args: iterable (iterable): iterable to wrap chunk_size (int): size of each chunk - + skip_remainder_batch (bool, optional): if set, discard the last grouped batch in + each training epoch, as the last grouped batch is usually smaller than + local_batch_size * distributed_word_size * chunk_size (default: ``False``). 
Attributes: n (int): number of elements consumed from this iterator """ - def __init__(self, iterable, chunk_size): - itr = _chunk_iterator(iterable, chunk_size) + def __init__(self, iterable, chunk_size, skip_remainder_batch=False): + if skip_remainder_batch: + total_num_itrs = int(math.floor(len(iterable) / float(chunk_size))) + logger.info( + f"skip final residual batch, grouped total_num_itrs = {total_num_itrs}" + ) + else: + total_num_itrs = int(math.ceil(len(iterable) / float(chunk_size))) + logger.info(f"grouped total_num_itrs = {total_num_itrs}") + + itr = _chunk_iterator(iterable, chunk_size, skip_remainder_batch) super().__init__( itr, start=int(math.ceil(getattr(iterable, "n", 0) / float(chunk_size))), - total=int(math.ceil(len(iterable) / float(chunk_size))), + total=total_num_itrs, ) self.chunk_size = chunk_size + if skip_remainder_batch: + self.take(total_num_itrs) + # TODO: [Hack] Here the grouped iterator modifies the base iterator size so that + # training can move into the next epoch once the grouped iterator is exhausted. + # Double-check this implementation in case unexpected behavior occurs. + iterable.take(total_num_itrs * chunk_size) -def _chunk_iterator(itr, chunk_size): + +def _chunk_iterator(itr, chunk_size, skip_remainder_batch=False): chunk = [] for x in itr: chunk.append(x) if len(chunk) == chunk_size: yield chunk chunk = [] - if len(chunk) > 0: + if not skip_remainder_batch and len(chunk) > 0: yield chunk @@ -546,7 +573,12 @@ class ShardedIterator(CountingIterator): n (int): number of elements consumed from this iterator """ - def __init__(self, iterable, num_shards, shard_id, fill_value=None): + def __init__( + self, iterable, num_shards, shard_id, fill_value=None, skip_remainder_batch=None + ): + """ + Args: + skip_remainder_batch: ignored""" if shard_id < 0 or shard_id >= num_shards: raise ValueError("shard_id must be between 0 and num_shards") sharded_len = int(math.ceil(len(iterable) / float(num_shards))) @@ -611,7 +643,7 @@ def _create_consumer(self): self._queue, self._iterable, self.total, - torch.cuda.current_device() if torch.cuda.is_available() else None + torch.cuda.current_device() if torch.cuda.is_available() else None, ) self._consumer.daemon = True self._consumer.start() diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 452deb1a19..b081e6cabf 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -95,7 +95,6 @@ def from_namespace(cls, args): return config - @dataclass class CommonConfig(FairseqDataclass): # This is the core dataclass including common parameters shared by all different jobs. Please append your params to other dataclasses if they were @@ -169,11 +168,13 @@ class CommonConfig(FairseqDataclass): metadata={ "help": "if set, the floating point conversion to fp16/bf16 runs on CPU. " "This reduces bus transfer time and GPU memory usage." 
- } + }, ) min_loss_scale: float = field( default=1e-4, - metadata={"help": "minimum FP16/AMP loss scale, after which training is stopped"}, + metadata={ + "help": "minimum FP16/AMP loss scale, after which training is stopped" + }, ) threshold_loss_scale: Optional[float] = field( default=None, metadata={"help": "threshold FP16 loss scale from below"} @@ -181,7 +182,9 @@ class CommonConfig(FairseqDataclass): amp: bool = field(default=False, metadata={"help": "use automatic mixed precision"}) amp_batch_retries: int = field( default=2, - metadata={"help": "number of retries of same batch after reducing loss scale with AMP"}, + metadata={ + "help": "number of retries of same batch after reducing loss scale with AMP" + }, ) amp_init_scale: int = field( default=2 ** 7, metadata={"help": "default AMP loss scale"} @@ -223,7 +226,7 @@ class CommonConfig(FairseqDataclass): default=False, metadata={ "help": "suppress crashes when training with the hydra_train entry point so that the " - "main method can return a value (useful for sweeps)" + "main method can return a value (useful for sweeps)" }, ) use_plasma_view: bool = field( @@ -440,6 +443,7 @@ class DistributedTrainingConfig(FairseqDataclass): default=False, metadata={"help": "not flatten parameter param for fsdp"}, ) + @dataclass class DatasetConfig(FairseqDataclass): num_workers: int = field( @@ -489,7 +493,7 @@ class DatasetConfig(FairseqDataclass): default=None, metadata={ "help": "comma separated list of data subsets to use for validation" - " (e.g. train, valid, test)", + " (e.g. train, valid, test)", "argparse_alias": "--combine-val", }, ) @@ -527,8 +531,10 @@ class DatasetConfig(FairseqDataclass): "argparse_alias": "--max-sentences-valid", }, ) - max_valid_steps: Optional[int] = field(default=None, metadata={'help': 'How many batches to evaluate', - "argparse_alias": "--nval"}) + max_valid_steps: Optional[int] = field( + default=None, + metadata={"help": "How many batches to evaluate", "argparse_alias": "--nval"}, + ) curriculum: int = field( default=0, metadata={"help": "don't shuffle batches for first N epochs"} ) @@ -558,7 +564,7 @@ class DatasetConfig(FairseqDataclass): default=False, metadata={ "help": "if true then increment seed with epoch for getting batch iterators, defautls to False.", - } + }, ) @@ -607,6 +613,13 @@ class OptimizationConfig(FairseqDataclass): "help": "specify global optimizer for syncing models on different GPUs/shards" }, ) + skip_remainder_batch: Optional[bool] = field( + default=False, + metadata={ + "help": "if set, include the last (partial) batch of each epoch in training" + " (default is to skip it)." 
+ }, + ) @dataclass @@ -669,8 +682,8 @@ class CheckpointConfig(FairseqDataclass): default=-1, metadata={ "help": "when used with --keep-interval-updates, skips deleting " - "any checkpoints with update X where " - "X %% keep_interval_updates_pattern == 0" + "any checkpoints with update X where " + "X %% keep_interval_updates_pattern == 0" }, ) keep_last_epochs: int = field( @@ -1020,25 +1033,22 @@ class InteractiveConfig(FairseqDataclass): @dataclass class EMAConfig(FairseqDataclass): store_ema: bool = field( - default=False, metadata={ - help: "store exponential moving average shadow model" - } + default=False, metadata={help: "store exponential moving average shadow model"} ) ema_decay: float = field( - default=0.9999, metadata={ - "help": 'decay for exponential moving average model' - } + default=0.9999, metadata={"help": "decay for exponential moving average model"} ) - ema_start_update : int = field( + ema_start_update: int = field( default=0, metadata={"help": "start EMA update after this many model updates"} ) - ema_seed_model : Optional[str] = field( - default=None, metadata={ + ema_seed_model: Optional[str] = field( + default=None, + metadata={ "help": "Seed to load EMA model from. " "Used to load EMA model separately from the actual model." - } + }, ) - ema_update_freq : int = field( + ema_update_freq: int = field( default=1, metadata={"help": "Do EMA update every this many model updates"} ) ema_fp32: bool = field( diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index ec62464f03..9cbdf89e6f 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -22,7 +22,6 @@ class StatefulContainer(object): - def __init__(self): self._state = dict() self._factories = dict() @@ -135,7 +134,7 @@ def load_dataset( split: str, combine: bool = False, task_cfg: FairseqDataclass = None, - **kwargs + **kwargs, ): """Load a given dataset split. @@ -220,6 +219,7 @@ def get_batch_iterator( epoch=1, data_buffer_size=0, disable_iterator_cache=False, + skip_remainder_batch=False, grouped_shuffling=False, update_epoch_batch_itr=False, ): @@ -254,6 +254,9 @@ def get_batch_iterator( disable_iterator_cache (bool, optional): don't cache the EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`) (default: False). + skip_remainder_batch (bool, optional): if set, discard the last + batch in each training epoch, as the last batch is often smaller than + local_batch_size * distributed_word_size (default: ``True``). grouped_shuffling (bool, optional): group batches with each groups containing num_shards batches and shuffle groups. Reduces difference between sequence lengths among workers for batches sorted by length. 
@@ -307,6 +310,7 @@ def get_batch_iterator( num_workers=num_workers, epoch=epoch, buffer_size=data_buffer_size, + skip_remainder_batch=skip_remainder_batch, grouped_shuffling=grouped_shuffling, ) @@ -348,7 +352,12 @@ def build_criterion(self, cfg: DictConfig): return criterions.build_criterion(cfg, self) def build_generator( - self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None, + self, + models, + args, + seq_gen_cls=None, + extra_gen_cls_kwargs=None, + prefix_allowed_tokens_fn=None, ): """ Build a :class:`~fairseq.SequenceGenerator` instance for this diff --git a/fairseq/tasks/translation_multi_simple_epoch.py b/fairseq/tasks/translation_multi_simple_epoch.py index f4797f5676..e64ab9a687 100644 --- a/fairseq/tasks/translation_multi_simple_epoch.py +++ b/fairseq/tasks/translation_multi_simple_epoch.py @@ -125,7 +125,7 @@ def check_dicts(self, dicts, source_langs, target_langs): @classmethod def setup_task(cls, args, **kwargs): langs, dicts, training = MultilingualDatasetManager.prepare( - cls.load_dictionary, args, **kwargs + cls.load_dictionary, args, **kwargs ) return cls(args, langs, dicts, training) @@ -197,11 +197,7 @@ def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None) return dataset def build_generator( - self, - models, - args, - seq_gen_cls=None, - extra_gen_cls_kwargs=None, + self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, ): if not getattr(args, "keep_inference_langtok", False): _, tgt_langtok_spec = self.args.langtoks["main"] @@ -349,6 +345,7 @@ def get_batch_iterator( epoch=1, data_buffer_size=0, disable_iterator_cache=False, + skip_remainder_batch=False, grouped_shuffling=False, update_epoch_batch_itr=False, ): @@ -412,6 +409,7 @@ def get_batch_iterator( epoch=epoch, data_buffer_size=data_buffer_size, disable_iterator_cache=disable_iterator_cache, + skip_remainder_batch=skip_remainder_batch, update_epoch_batch_itr=update_epoch_batch_itr, ) self.dataset_to_epoch_iter[dataset] = batch_iter diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 6413411604..30e12dcc98 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -64,6 +64,7 @@ def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): if self.is_fsdp: import fairscale + if self.cfg.common.bf16: raise ValueError( "FullyShardedDataParallel is not compatible with --bf16 or " @@ -74,7 +75,10 @@ def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): "FullyShardedDataParallel is not compatible with --zero-sharding " "option (it's already built in)" ) - if max(self.cfg.optimization.update_freq) > 1 and fairscale.__version__ < "0.4.0": + if ( + max(self.cfg.optimization.update_freq) > 1 + and fairscale.__version__ < "0.4.0" + ): raise RuntimeError( "Please update to fairscale 0.4.0 or newer when combining " "--update-freq with FullyShardedDataParallel" @@ -198,9 +202,7 @@ def is_data_parallel_master(self): def use_distributed_wrapper(self) -> bool: return ( self.data_parallel_world_size > 1 and not self.cfg.optimization.use_bmuf - ) or ( - self.is_fsdp and self.cfg.distributed_training.cpu_offload - ) + ) or (self.is_fsdp and self.cfg.distributed_training.cpu_offload) @property def should_save_checkpoint_on_current_rank(self) -> bool: @@ -267,9 +269,7 @@ def ema(self): def _build_ema(self): if self.cfg.ema.store_ema: self._ema = build_ema(self._model, self.cfg.ema, self.device) - logger.info( - "Exponential Moving Average Shadow Model is initialized." 
- ) + logger.info("Exponential Moving Average Shadow Model is initialized.") @property def optimizer(self): @@ -320,7 +320,9 @@ def _build_optimizer(self): self._optimizer = optim.FP16Optimizer.build_optimizer(self.cfg, params) else: if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7: - logger.info("NOTE: your device may support faster training with --fp16 or --amp") + logger.info( + "NOTE: your device may support faster training with --fp16 or --amp" + ) self._optimizer = optim.build_optimizer(self.cfg.optimizer, params) if self.is_fsdp: @@ -335,10 +337,7 @@ def _build_optimizer(self): ) if self.cfg.optimization.use_bmuf: - self._optimizer = optim.FairseqBMUF( - self.cfg.bmuf, - self._optimizer, - ) + self._optimizer = optim.FairseqBMUF(self.cfg.bmuf, self._optimizer,) if self.cfg.distributed_training.zero_sharding == "os": if ( @@ -356,8 +355,7 @@ def _build_optimizer(self): # We should initialize the learning rate scheduler immediately after # building the optimizer, so that the initial learning rate is set. self._lr_scheduler = lr_scheduler.build_lr_scheduler( - self.cfg.lr_scheduler, - self.optimizer, + self.cfg.lr_scheduler, self.optimizer, ) self._lr_scheduler.step_update(0) @@ -579,18 +577,16 @@ def load_checkpoint( "EMA not found in checkpoint. But store_ema is True. " "EMA is re-initialized from checkpoint." ) - self.ema.restore(state["model"], build_fp32_params=self.cfg.ema.ema_fp32) - else: - logger.info( - "Loading EMA from checkpoint" + self.ema.restore( + state["model"], build_fp32_params=self.cfg.ema.ema_fp32 ) + else: + logger.info("Loading EMA from checkpoint") self.ema.restore(extra_state["ema"], build_fp32_params=False) if self.cfg.ema.ema_fp32: if "ema_fp32_params" in extra_state: - logger.info( - "Loading EMA fp32 params from checkpoint" - ) + logger.info("Loading EMA fp32 params from checkpoint") self.ema.build_fp32_params(extra_state["ema_fp32_params"]) else: logger.info( @@ -648,6 +644,7 @@ def get_train_iterator( epoch=epoch, data_buffer_size=self.cfg.dataset.data_buffer_size, disable_iterator_cache=disable_iterator_cache, + skip_remainder_batch=self.cfg.optimization.skip_remainder_batch, grouped_shuffling=self.cfg.dataset.grouped_shuffling, update_epoch_batch_itr=self.cfg.dataset.update_epoch_batch_itr, ) @@ -655,9 +652,7 @@ def get_train_iterator( return batch_iterator def get_valid_iterator( - self, - subset, - disable_iterator_cache=False, + self, subset, disable_iterator_cache=False, ): """Return an EpochBatchIterator over given validation subset for a given epoch.""" batch_iterator = self.task.get_batch_iterator( @@ -665,8 +660,7 @@ def get_valid_iterator( max_tokens=self.cfg.dataset.max_tokens_valid, max_sentences=self.cfg.dataset.batch_size_valid, max_positions=utils.resolve_max_positions( - self.task.max_positions(), - self.model.max_positions(), + self.task.max_positions(), self.model.max_positions(), ), ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple, @@ -679,6 +673,7 @@ def get_valid_iterator( epoch=1, data_buffer_size=self.cfg.dataset.data_buffer_size, disable_iterator_cache=disable_iterator_cache, + skip_remainder_batch=False, ) self.reset_dummy_batch(batch_iterator.first_batch) return batch_iterator @@ -812,10 +807,9 @@ def maybe_no_sync(): # gather logging outputs from all replicas if self._sync_stats(): train_time = self._local_cumulative_training_time() - logging_outputs, ( - sample_size, - ooms, - total_train_time, + ( + 
logging_outputs, + (sample_size, ooms, total_train_time,), ) = self._aggregate_logging_outputs( logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch ) @@ -882,7 +876,9 @@ def maybe_no_sync(): self._amp_retries = 0 else: self._amp_retries += 1 - return self.train_step(samples, raise_oom) # recursion to feed in same batch + return self.train_step( + samples, raise_oom + ) # recursion to feed in same batch except FloatingPointError: # re-run the forward and backward pass with hooks attached to print @@ -918,8 +914,7 @@ def maybe_no_sync(): # after the step if hasattr(self.model, "perform_slowmo"): self.model.perform_slowmo( - self.optimizer.optimizer, - getattr(self.optimizer, "fp32_params", None) + self.optimizer.optimizer, getattr(self.optimizer, "fp32_params", None) ) logging_output = None @@ -929,8 +924,7 @@ def maybe_no_sync(): if self.cfg.ema.store_ema: # Step EMA forward with new model. self.ema.step( - self.get_model(), - self.get_num_updates(), + self.get_model(), self.get_num_updates(), ) metrics.log_scalar( "ema_decay", @@ -1064,9 +1058,7 @@ def valid_step(self, sample, raise_oom=False): # gather logging outputs from all replicas if self.data_parallel_world_size > 1: logging_outputs, (sample_size,) = self._aggregate_logging_outputs( - logging_outputs, - sample_size, - ignore=is_dummy_batch, + logging_outputs, sample_size, ignore=is_dummy_batch, ) # log validation stats @@ -1175,12 +1167,9 @@ def agg_norm_fn(total_norm): ) return total_norm ** 0.5 - should_agg_norm = ( - self.is_fsdp - and ( - self.data_parallel_process_group is not None - or torch.distributed.is_initialized() - ) + should_agg_norm = self.is_fsdp and ( + self.data_parallel_process_group is not None + or torch.distributed.is_initialized() ) return self.optimizer.clip_grad_norm( clip_norm, aggregate_norm_fn=agg_norm_fn if should_agg_norm else None @@ -1240,8 +1229,10 @@ def _prepare_sample(self, sample, is_dummy=False): if self.cuda: if self.pipeline_model_parallel: - if 'target' in sample: - sample['target'] = utils.move_to_cuda(sample['target'], device=self.last_device) + if "target" in sample: + sample["target"] = utils.move_to_cuda( + sample["target"], device=self.last_device + ) else: sample = utils.move_to_cuda(sample) elif self.tpu and is_dummy: @@ -1269,10 +1260,9 @@ def _sync_stats(self): return False elif self.cfg.optimization.use_bmuf: return ( - self.get_num_updates() + 1 - ) % self.cfg.bmuf.global_sync_iter == 0 and ( - self.get_num_updates() + 1 - ) > self.cfg.bmuf.warmup_iterations + (self.get_num_updates() + 1) % self.cfg.bmuf.global_sync_iter == 0 + and (self.get_num_updates() + 1) > self.cfg.bmuf.warmup_iterations + ) else: return True @@ -1285,10 +1275,7 @@ def _log_oom(self, exc): sys.stderr.flush() def _aggregate_logging_outputs( - self, - logging_outputs: List[Dict[str, Any]], - *extra_stats_to_sum, - ignore=False, + self, logging_outputs: List[Dict[str, Any]], *extra_stats_to_sum, ignore=False, ): if self.task.__class__.logging_outputs_can_be_summed(self.get_criterion()): return self._fast_stat_sync_sum( @@ -1300,10 +1287,7 @@ def _aggregate_logging_outputs( ) def _all_gather_list_sync( - self, - logging_outputs: List[Dict[str, Any]], - *extra_stats_to_sum, - ignore=False, + self, logging_outputs: List[Dict[str, Any]], *extra_stats_to_sum, ignore=False, ): """ Sync logging outputs across workers. 
all_gather_list_sync is @@ -1328,10 +1312,7 @@ def _all_gather_list_sync( return logging_outputs, extra_stats_to_sum def _fast_stat_sync_sum( - self, - logging_outputs: List[Dict[str, Any]], - *extra_stats_to_sum, - ignore=False, + self, logging_outputs: List[Dict[str, Any]], *extra_stats_to_sum, ignore=False, ): """ Sync logging outputs across workers. fast_stat_sync_sum is @@ -1379,10 +1360,11 @@ def _check_grad_norms(self, grad_norm): def is_consistent(tensor): max_abs_diff = torch.max(torch.abs(tensor - tensor[0])) return ( - (torch.isfinite(tensor).all() - and (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all()) - or - (self.cfg.common.amp and not torch.isfinite(tensor).all()) + ( + torch.isfinite(tensor).all() + and (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all() + ) + or (self.cfg.common.amp and not torch.isfinite(tensor).all()) # in case of amp non-finite grads are fine ) diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 8347587313..369a8a82c5 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -44,15 +44,16 @@ from omegaconf import DictConfig, OmegaConf - - def main(cfg: FairseqConfig) -> None: if isinstance(cfg, argparse.Namespace): cfg = convert_namespace_to_omegaconf(cfg) utils.import_user_module(cfg.common) - if distributed_utils.is_master(cfg.distributed_training) and "job_logging_cfg" in cfg: + if ( + distributed_utils.is_master(cfg.distributed_training) + and "job_logging_cfg" in cfg + ): # make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126) logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg)) @@ -102,15 +103,25 @@ def main(cfg: FairseqConfig) -> None: logger.info("criterion: {}".format(criterion.__class__.__name__)) logger.info( "num. shared model params: {:,} (num. trained: {:,})".format( - sum(p.numel() for p in model.parameters() if not getattr(p, "expert", False)), - sum(p.numel() for p in model.parameters() if not getattr(p, "expert", False) and p.requires_grad) + sum( + p.numel() for p in model.parameters() if not getattr(p, "expert", False) + ), + sum( + p.numel() + for p in model.parameters() + if not getattr(p, "expert", False) and p.requires_grad + ), ) ) logger.info( "num. expert model params: {} (num. 
trained: {})".format( sum(p.numel() for p in model.parameters() if getattr(p, "expert", False)), - sum(p.numel() for p in model.parameters() if getattr(p, "expert", False) and p.requires_grad), + sum( + p.numel() + for p in model.parameters() + if getattr(p, "expert", False) and p.requires_grad + ), ) ) @@ -145,8 +156,7 @@ def main(cfg: FairseqConfig) -> None: ) logger.info( "max tokens per device = {} and max sentences per device = {}".format( - cfg.dataset.max_tokens, - cfg.dataset.batch_size, + cfg.dataset.max_tokens, cfg.dataset.batch_size, ) ) @@ -160,6 +170,7 @@ def main(cfg: FairseqConfig) -> None: ) if cfg.common.tpu: import torch_xla.core.xla_model as xm + xm.rendezvous("load_checkpoint") # wait for all workers max_epoch = cfg.optimization.max_epoch or math.inf @@ -247,7 +258,9 @@ def train( if epoch_itr.epoch <= len(cfg.optimization.update_freq) else cfg.optimization.update_freq[-1] ) - itr = iterators.GroupedIterator(itr, update_freq) + itr = iterators.GroupedIterator( + itr, update_freq, skip_remainder_batch=cfg.optimization.skip_remainder_batch, + ) if cfg.common.tpu: itr = utils.tpu_data_loader(itr) progress = progress_bar.progress_bar( @@ -376,15 +389,19 @@ def validate_and_save( ) ) do_validate = ( - (not end_of_epoch and do_save) # validate during mid-epoch saves - or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0) - or should_stop - or ( - cfg.dataset.validate_interval_updates > 0 - and num_updates > 0 - and num_updates % cfg.dataset.validate_interval_updates == 0 + ( + (not end_of_epoch and do_save) # validate during mid-epoch saves + or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0) + or should_stop + or ( + cfg.dataset.validate_interval_updates > 0 + and num_updates > 0 + and num_updates % cfg.dataset.validate_interval_updates == 0 + ) ) - ) and not cfg.dataset.disable_validation and num_updates >= cfg.dataset.validate_after_updates + and not cfg.dataset.disable_validation + and num_updates >= cfg.dataset.validate_after_updates + ) # Validate valid_losses = [None] @@ -457,7 +474,10 @@ def validate( # don't pollute other aggregators (e.g., train meters) with metrics.aggregate(new_root=True) as agg: for i, sample in enumerate(progress): - if cfg.dataset.max_valid_steps is not None and i > cfg.dataset.max_valid_steps: + if ( + cfg.dataset.max_valid_steps is not None + and i > cfg.dataset.max_valid_steps + ): break trainer.valid_step(sample) @@ -497,7 +517,9 @@ def cli_main( if cfg.common.use_plasma_view: server = PlasmaStore(path=cfg.common.plasma_path) - logger.info(f"Started plasma server pid {server.server.pid} {cfg.common.plasma_path}") + logger.info( + f"Started plasma server pid {server.server.pid} {cfg.common.plasma_path}" + ) if args.profile: with torch.cuda.profiler.profile(): diff --git a/tests/test_iterators.py b/tests/test_iterators.py index 7b3dd48485..2e2eb2f0a8 100644 --- a/tests/test_iterators.py +++ b/tests/test_iterators.py @@ -5,7 +5,7 @@ import unittest -from fairseq.data import iterators +from fairseq.data import iterators, ListDataset class TestIterators(unittest.TestCase): @@ -132,6 +132,63 @@ def test_counting_iterator_buffered_iterator_take(self): self.assertFalse(itr.has_next()) self.assertRaises(StopIteration, next, buffered_itr) + def test_epoch_batch_iterator_skip_remainder_batch(self): + reference = [1, 2, 3] + itr1 = _get_epoch_batch_itr(reference, 2, True) + self.assertEqual(len(itr1), 1) + itr2 = _get_epoch_batch_itr(reference, 2, False) + self.assertEqual(len(itr2), 2) + itr3 = 
_get_epoch_batch_itr(reference, 1, True) + self.assertEqual(len(itr3), 2) + itr4 = _get_epoch_batch_itr(reference, 1, False) + self.assertEqual(len(itr4), 3) + itr5 = _get_epoch_batch_itr(reference, 4, True) + self.assertEqual(len(itr5), 0) + self.assertFalse(itr5.has_next()) + itr6 = _get_epoch_batch_itr(reference, 4, False) + self.assertEqual(len(itr6), 1) + + def test_grouped_iterator_skip_remainder_batch(self): + reference = [1, 2, 3, 4, 5, 6, 7, 8, 9] + itr1 = _get_epoch_batch_itr(reference, 3, False) + grouped_itr1 = iterators.GroupedIterator(itr1, 2, True) + self.assertEqual(len(grouped_itr1), 1) + + itr2 = _get_epoch_batch_itr(reference, 3, False) + grouped_itr2 = iterators.GroupedIterator(itr2, 2, False) + self.assertEqual(len(grouped_itr2), 2) + + itr3 = _get_epoch_batch_itr(reference, 3, True) + grouped_itr3 = iterators.GroupedIterator(itr3, 2, True) + self.assertEqual(len(grouped_itr3), 1) + + itr4 = _get_epoch_batch_itr(reference, 3, True) + grouped_itr4 = iterators.GroupedIterator(itr4, 2, False) + self.assertEqual(len(grouped_itr4), 1) + + itr5 = _get_epoch_batch_itr(reference, 5, True) + grouped_itr5 = iterators.GroupedIterator(itr5, 2, True) + self.assertEqual(len(grouped_itr5), 0) + + itr6 = _get_epoch_batch_itr(reference, 5, True) + grouped_itr6 = iterators.GroupedIterator(itr6, 2, False) + self.assertEqual(len(grouped_itr6), 1) + + +def _get_epoch_batch_itr(ref, bsz, skip_remainder_batch): + dsz = len(ref) + indices = range(dsz) + starts = indices[::bsz] + batch_sampler = [indices[s : s + bsz] for s in starts] + dataset = ListDataset(ref) + itr = iterators.EpochBatchIterator( + dataset=dataset, + collate_fn=dataset.collater, + batch_sampler=batch_sampler, + skip_remainder_batch=skip_remainder_batch, + ) + return itr.next_epoch_itr() + if __name__ == "__main__": unittest.main()
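
Note (illustrative, not part of the patch): a minimal sketch of the behaviour introduced
above, using only APIs touched in this diff (EpochBatchIterator, GroupedIterator,
ListDataset). With skip_remainder_batch enabled, GroupedIterator keeps only
floor(num_batches / update_freq) full groups and discards the trailing partial group;
during training this is driven by the new cfg.optimization.skip_remainder_batch option
(exposed as --skip-remainder-batch, assuming the usual fairseq dataclass-to-flag
conversion).

    # Sketch: 10 single-item batches grouped into updates of update_freq=4.
    from fairseq.data import iterators, ListDataset

    data = list(range(10))
    dataset = ListDataset(data)
    batch_sampler = [[i] for i in range(len(data))]  # 10 batches of size 1

    epoch_itr = iterators.EpochBatchIterator(
        dataset=dataset,
        collate_fn=dataset.collater,
        batch_sampler=batch_sampler,
        skip_remainder_batch=False,  # keep all 10 batches at the batch level
    )
    itr = epoch_itr.next_epoch_itr()

    # Group batches into updates of 4; with skip_remainder_batch=True the
    # final partial group of 2 batches is dropped.
    grouped = iterators.GroupedIterator(itr, 4, skip_remainder_batch=True)
    assert len(grouped) == 2  # floor(10 / 4); without skipping it would be ceil(10 / 4) = 3

The same flag on EpochBatchIterator itself drops the last (usually smaller) batch of an
epoch, which keeps every update at a full local_batch_size across workers and avoids the
training-stability issue the new docstrings mention.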