Skip to content

Commit

Permalink
Replaced inefficient sums in batch sampler.
Browse files: browse the repository at this point in the history
  • Loading branch information
joelloo committed Apr 30, 2023
1 parent fcb40ea commit a3afb3f
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions dataset.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -354,17 +354,18 @@ def __iter__(self):
# self.train_set.init_dataset()  # Possibly buggy here: the batch intentions appear to get messed up, so manual initialisation is preferred.
self.forward, self.left, self.right, self.elevator = self.group_samples()
if self.shuffle_on:
self.shuffle()
all_groups = []
self.shuffle()
batch_lists = []
for group in [self.forward, self.left, self.right, self.elevator]:
# Process each group in turn; when shuffling is off, the easy samples come first.
for value in group.values():
all_groups.append(chunk_by_max_len(value, self.batch_size, drop_last=self.drop_last))
all = sum(all_groups, [])
batch_by_seq_len = chunk_by_max_len(value, self.batch_size, drop_last=self.drop_last)
for batch_list in batch_by_seq_len:
batch_lists.append(batch_list)
if self.shuffle_on:
random.shuffle(all)
all = sum(all, [])
return iter(all)
random.shuffle(batch_lists)
flattened = [idx for batch_list in batch_lists for idx in batch_list]
return iter(flattened)

# Length protocol hook: returns the value stored in self.length.
# NOTE(review): presumably the number of items the sampler yields per pass;
# where self.length is computed is not visible here, so confirm.
def __len__(self):
return self.length

0 comments on commit a3afb3f

Please sign in to comment.