Skip to content

Commit

Permalink
[fix] Fix different batches per epoch in NoDuplicatesBatchSampler (#3073)
Browse files Browse the repository at this point in the history

* Fix different batches per epoch in NoDuplicatesBatchSampler

* Use dict.fromkeys instead of a dict comprehension
  • Loading branch information
tomaarsen authored Nov 20, 2024
1 parent 0434450 commit 8fabce0
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions sentence_transformers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,10 @@ def __iter__(self) -> Iterator[list[int]]:
if self.generator and self.seed:
self.generator.manual_seed(self.seed + self.epoch)

remaining_indices = set(torch.randperm(len(self.dataset), generator=self.generator).tolist())
# We create a dictionary mapping each index to None because we need a data structure that:
# 1. Allows for cheap removal of elements
# 2. Preserves the order of elements, i.e. remains random
remaining_indices = dict.fromkeys(torch.randperm(len(self.dataset), generator=self.generator).tolist())
while remaining_indices:
batch_values = set()
batch_indices = []
Expand All @@ -209,7 +212,8 @@ def __iter__(self) -> Iterator[list[int]]:
if not self.drop_last:
yield batch_indices

remaining_indices -= set(batch_indices)
for index in batch_indices:
del remaining_indices[index]

def __len__(self) -> int:
if self.drop_last:
Expand Down

0 comments on commit 8fabce0

Please sign in to comment.