Skip to content

Commit

Permalink
Add exception when dummy batch is missing (fixes facebookresearch#1726) (facebookresearch#1735)
Browse files Browse the repository at this point in the history

Summary: Pull Request resolved: facebookresearch#1735

Differential Revision: D20034557

Pulled By: myleott

fbshipit-source-id: a2ed5acd5e79b2cfd0b073bb8aabcb39172f7dc5
  • Loading branch information
myleott authored and facebook-github-bot committed Feb 24, 2020
1 parent b3ca86a commit ed4aa2c
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion fairseq/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(self, args, task, model, criterion):
self._criterion = self._criterion.to(device=self.device)
self._model = self._model.to(device=self.device)

self._dummy_batch = None
self._dummy_batch = "DUMMY" # indicates we don't have a dummy batch at first
self._lr_scheduler = None
self._num_updates = 0
self._optim_history = None
Expand Down Expand Up @@ -420,6 +420,9 @@ def maybe_no_sync():
@metrics.aggregate("valid")
def valid_step(self, sample, raise_oom=False):
"""Do forward pass in evaluation mode."""
if self._dummy_batch is None:
self._dummy_batch = sample

with torch.no_grad():
self.model.eval()
self.criterion.eval()
Expand Down Expand Up @@ -542,6 +545,13 @@ def set_num_updates(self, num_updates):
metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200)

def _prepare_sample(self, sample):
if sample == "DUMMY":
raise Exception(
"Trying to use an uninitialized 'dummy' batch. This usually indicates "
"that the total number of batches is smaller than the number of "
"participating GPUs. Try reducing the batch size or using fewer GPUs."
)

if sample is None or len(sample) == 0:
return None

Expand Down

0 comments on commit ed4aa2c

Please sign in to comment.