From c725c4ba33c816d3f35531819abdbe4ea60ea08b Mon Sep 17 00:00:00 2001 From: Thomas MK Date: Tue, 27 Jun 2023 22:28:25 +0000 Subject: [PATCH] Prepare for label noise experiments --- src/algs/adv/supmatch.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/algs/adv/supmatch.py b/src/algs/adv/supmatch.py index 01acb426..e01c623a 100644 --- a/src/algs/adv/supmatch.py +++ b/src/algs/adv/supmatch.py @@ -33,7 +33,11 @@ def _get_data_iterators(self, dm: DataModule) -> tuple[IterTr, IterDep]: dl_tr = dm.train_dataloader(balance=True) # The batch size needs to be consistent for the aggregation layer in the setwise neural # discriminator - dl_dep = dm.deployment_dataloader(batch_size=dm.batch_size_tr) + dl_dep = dm.deployment_dataloader( + batch_size=dl_tr.batch_sampler.batch_size + if dm.deployment_ids is None + else dm.batch_size_tr + ) return iter(dl_tr), iter(dl_dep) @override