diff --git a/botorch/utils/sampling.py b/botorch/utils/sampling.py index 54d276e3b9..755fb4af30 100644 --- a/botorch/utils/sampling.py +++ b/botorch/utils/sampling.py @@ -249,7 +249,9 @@ def sample_polytope( # pre-sample samples from hypersphere d = x0.size(0) # uniform samples from unit ball in d dims - Rs = sample_hypersphere(d=d, n=n_tot, dtype=A.dtype, device=A.device).unsqueeze(-1) + Rs = sample_hypersphere( + d=d, n=n_tot, dtype=A.dtype, device=A.device, seed=seed + ).unsqueeze(-1) # compute matprods in batch ARs = (A @ Rs).squeeze(-1) diff --git a/test/utils/test_sampling.py b/test/utils/test_sampling.py index 4876a3ae16..535d09d062 100644 --- a/test/utils/test_sampling.py +++ b/test/utils/test_sampling.py @@ -348,6 +348,29 @@ def test_sample_polytope(self): self.assertTrue((more_samples <= bounds[1]).all()) self.assertTrue((more_samples >= bounds[0]).all()) + def test_sample_polytope_with_seed(self): + for dtype in (torch.float, torch.double): + A = self.A.to(dtype) + b = self.b.to(dtype) + x0 = self.x0.to(dtype) + bounds = self.bounds.to(dtype) + for interior_point in [x0, None]: + sampler1 = self.sampler_class( + inequality_constraints=(A, b), + bounds=bounds, + interior_point=interior_point, + **self.sampler_kwargs, + ) + sampler2 = self.sampler_class( + inequality_constraints=(A, b), + bounds=bounds, + interior_point=interior_point, + **self.sampler_kwargs, + ) + samples1 = sampler1.draw(n=10, seed=42) + samples2 = sampler2.draw(n=10, seed=42) + self.assertTrue(torch.allclose(samples1, samples2)) + def test_sample_polytope_with_eq_constraints(self): for dtype in (torch.float, torch.double): A = self.A.to(dtype)