diff --git a/src/algs/adv/supmatch.py b/src/algs/adv/supmatch.py index 01acb426..e01c623a 100644 --- a/src/algs/adv/supmatch.py +++ b/src/algs/adv/supmatch.py @@ -33,7 +33,11 @@ def _get_data_iterators(self, dm: DataModule) -> tuple[IterTr, IterDep]: dl_tr = dm.train_dataloader(balance=True) # The batch size needs to be consistent for the aggregation layer in the setwise neural # discriminator - dl_dep = dm.deployment_dataloader(batch_size=dm.batch_size_tr) + dl_dep = dm.deployment_dataloader( + batch_size=dl_tr.batch_sampler.batch_size + if dm.deployment_ids is None + else dm.batch_size_tr + ) return iter(dl_tr), iter(dl_dep) @override