Skip to content

Commit

Permalink
Prepare for label noise experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
tmke8 committed Oct 2, 2023
1 parent 68394b9 commit c725c4b
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/algs/adv/supmatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ def _get_data_iterators(self, dm: DataModule) -> tuple[IterTr, IterDep]:
dl_tr = dm.train_dataloader(balance=True)
# The batch size needs to be consistent for the aggregation layer in the setwise neural
# discriminator
dl_dep = dm.deployment_dataloader(batch_size=dm.batch_size_tr)
dl_dep = dm.deployment_dataloader(
batch_size=dl_tr.batch_sampler.batch_size
if dm.deployment_ids is None
else dm.batch_size_tr
)
return iter(dl_tr), iter(dl_dep)

@override
Expand Down

0 comments on commit c725c4b

Please sign in to comment.