Skip to content

Commit

Permalink
Hotfix/polytopesampler seed (#1968)
Browse files Browse the repository at this point in the history
Summary:
<!--
Thank you for sending the PR! We appreciate you spending the time to make BoTorch better.

Help us understand your motivation by explaining why you decided to make this change.

You can learn more about contributing to BoTorch here: https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md
-->

## Motivation

During bugfixing, I came around the issue that setting a seed in `get_polytope_samples` is not leading to the same samples when called several times with the same seed. This PR fixes it.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes.

Pull Request resolved: #1968

Test Plan: Unit tests.

Reviewed By: saitcakmak

Differential Revision: D47993712

Pulled By: Balandat

fbshipit-source-id: 95b3e548e78609a5d6593addbba0cf7585a76120
  • Loading branch information
jduerholt authored and facebook-github-bot committed Aug 2, 2023
1 parent 9915f8a commit 5abab4d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
4 changes: 3 additions & 1 deletion botorch/utils/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,9 @@ def sample_polytope(
# pre-sample samples from hypersphere
d = x0.size(0)
# uniform samples from unit ball in d dims
Rs = sample_hypersphere(d=d, n=n_tot, dtype=A.dtype, device=A.device).unsqueeze(-1)
Rs = sample_hypersphere(
d=d, n=n_tot, dtype=A.dtype, device=A.device, seed=seed
).unsqueeze(-1)

# compute matprods in batch
ARs = (A @ Rs).squeeze(-1)
Expand Down
23 changes: 23 additions & 0 deletions test/utils/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,29 @@ def test_sample_polytope(self):
self.assertTrue((more_samples <= bounds[1]).all())
self.assertTrue((more_samples >= bounds[0]).all())

def test_sample_polytope_with_seed(self):
for dtype in (torch.float, torch.double):
A = self.A.to(dtype)
b = self.b.to(dtype)
x0 = self.x0.to(dtype)
bounds = self.bounds.to(dtype)
for interior_point in [x0, None]:
sampler1 = self.sampler_class(
inequality_constraints=(A, b),
bounds=bounds,
interior_point=interior_point,
**self.sampler_kwargs,
)
sampler2 = self.sampler_class(
inequality_constraints=(A, b),
bounds=bounds,
interior_point=interior_point,
**self.sampler_kwargs,
)
samples1 = sampler1.draw(n=10, seed=42)
samples2 = sampler2.draw(n=10, seed=42)
self.assertTrue(torch.allclose(samples1, samples2))

def test_sample_polytope_with_eq_constraints(self):
for dtype in (torch.float, torch.double):
A = self.A.to(dtype)
Expand Down

0 comments on commit 5abab4d

Please sign in to comment.