Stop numerical tests from flaking; use assertRaisesRegex (#1991)
Summary:
Pull Request resolved: #1991

see title

Reviewed By: SebastianAment

Differential Revision: D48523999

fbshipit-source-id: f2296f8f87080f89edcd1c4d8a68b36fbf90816f
esantorella authored and facebook-github-bot committed Aug 21, 2023
1 parent b576f8d commit c43b074
Showing 2 changed files with 40 additions and 24 deletions.
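For background, a minimal self-contained sketch (not part of this commit; the test case and messages are illustrative) of what the switch buys: `assertRaisesRegex` matches the raised exception's message against a regular expression, so a test can no longer pass when an unrelated error of the expected type is raised.

```python
import unittest


class RegexAssertionDemo(unittest.TestCase):
    def test_plain_assert_raises_is_permissive(self) -> None:
        # Passes even though the error has nothing to do with the behavior
        # supposedly under test: any ValueError satisfies the assertion.
        with self.assertRaises(ValueError):
            raise ValueError("some unrelated failure")

    def test_assert_raises_regex_pins_the_message(self) -> None:
        # Only passes because the message matches the pattern; the pattern
        # is a regex, so it may match a substring of the full message.
        with self.assertRaisesRegex(ValueError, "one-dimensional tensor"):
            raise ValueError("weights must be a one-dimensional tensor.")


if __name__ == "__main__":
    unittest.main()
```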
27 changes: 17 additions & 10 deletions test/acquisition/test_objective.py

```diff
@@ -399,29 +399,36 @@ def test_identity_mc_objective(self):
 
 
 class TestLinearMCObjective(BotorchTestCase):
-    def test_linear_mc_objective(self):
+    def test_linear_mc_objective(self) -> None:
+        # Test passes for each seed
+        torch.manual_seed(torch.randint(high=1000, size=(1,)))
         for dtype in (torch.float, torch.double):
             weights = torch.rand(3, device=self.device, dtype=dtype)
             obj = LinearMCObjective(weights=weights)
             samples = torch.randn(4, 2, 3, device=self.device, dtype=dtype)
-            self.assertTrue(
-                torch.allclose(obj(samples), (samples * weights).sum(dim=-1))
-            )
+            atol = 1e-8 if dtype == torch.double else 3e-8
+            rtol = 1e-5 if dtype == torch.double else 4e-5
+            self.assertAllClose(obj(samples), samples @ weights, atol=atol, rtol=rtol)
             samples = torch.randn(5, 4, 2, 3, device=self.device, dtype=dtype)
-            self.assertTrue(
-                torch.allclose(obj(samples), (samples * weights).sum(dim=-1))
+            self.assertAllClose(
+                obj(samples),
+                samples @ weights,
+                atol=atol,
+                rtol=rtol,
             )
             # make sure this errors if sample output dimensions are incompatible
-            with self.assertRaises(RuntimeError):
+            shape_mismatch_msg = "Output shape of samples not equal to that of weights"
+            with self.assertRaisesRegex(RuntimeError, shape_mismatch_msg):
                 obj(samples=torch.randn(2, device=self.device, dtype=dtype))
-            with self.assertRaises(RuntimeError):
+            with self.assertRaisesRegex(RuntimeError, shape_mismatch_msg):
                 obj(samples=torch.randn(1, device=self.device, dtype=dtype))
             # make sure we can't construct objectives with multi-dim. weights
-            with self.assertRaises(ValueError):
+            weights_1d_msg = "weights must be a one-dimensional tensor."
+            with self.assertRaisesRegex(ValueError, expected_regex=weights_1d_msg):
                 LinearMCObjective(
                     weights=torch.rand(2, 3, device=self.device, dtype=dtype)
                 )
-            with self.assertRaises(ValueError):
+            with self.assertRaisesRegex(ValueError, expected_regex=weights_1d_msg):
                 LinearMCObjective(
                     weights=torch.tensor(1.0, device=self.device, dtype=dtype)
                 )
```
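As a standalone sanity check (not from the commit; the seed and shapes are chosen for illustration), the new expected value `samples @ weights` agrees with the old `(samples * weights).sum(dim=-1)`: with a one-dimensional right operand, `@` contracts the last dimension of `samples` with `weights`, i.e. a batched dot product. The two reductions can differ by a few ulps in float32, which is what the relaxed float tolerances above absorb.

```python
import torch

torch.manual_seed(0)  # fixed seed, mirroring the de-flaking strategy
weights = torch.rand(3, dtype=torch.float)
samples = torch.randn(5, 4, 2, 3, dtype=torch.float)

# With a 1-D right operand, matmul contracts the last dimension of
# `samples` against `weights`, yielding a (5, 4, 2) tensor.
via_matmul = samples @ weights
via_sum = (samples * weights).sum(dim=-1)

# Same math, different reduction order; the float32 tolerances from the
# diff above (atol=3e-8, rtol=4e-5) absorb any ulp-level discrepancy.
assert torch.allclose(via_matmul, via_sum, atol=3e-8, rtol=4e-5)
```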
37 changes: 23 additions & 14 deletions test/utils/probability/test_utils.py

```diff
@@ -153,20 +153,18 @@ def test_swap_along_dim_(self):
         with self.assertRaisesRegex(ValueError, "at most 1-dimensional"):
             utils.swap_along_dim_(values.view(-1), i=i_lidx, j=j, dim=0)
 
-    def test_gaussian_probabilities(self):
+    def test_gaussian_probabilities(self) -> None:
+        # test passes for each possible seed
+        torch.manual_seed(torch.randint(high=1000, size=(1,)))
         # testing Gaussian probability functions
         for dtype in (torch.float, torch.double):
             rtol = 1e-12 if dtype == torch.double else 1e-6
             atol = rtol
             n = 16
             x = 3 * torch.randn(n, device=self.device, dtype=dtype)
             # first, test consistency between regular and log versions
-            self.assertTrue(
-                torch.allclose(phi(x), log_phi(x).exp(), atol=atol, rtol=rtol)
-            )
-            self.assertTrue(
-                torch.allclose(ndtr(x), log_ndtr(x).exp(), atol=atol, rtol=rtol)
-            )
+            self.assertAllClose(phi(x), log_phi(x).exp(), atol=atol, rtol=rtol)
+            self.assertAllClose(ndtr(x), log_ndtr(x).exp(), atol=atol, rtol=rtol)
 
             # test correctness of log_erfc and log_erfcx
             for special_f, custom_log_f in zip(
@@ -291,10 +289,13 @@ def test_gaussian_probabilities(self):
             self.assertTrue((a.grad.diff() < 0).all())
 
             # testing error raising for invalid inputs
-            with self.assertRaises(ValueError):
-                a = torch.randn(3, 4, dtype=dtype, device=self.device)
-                b = torch.randn(3, 4, dtype=dtype, device=self.device)
-                a[2, 3] = b[2, 3]
+            a = torch.randn(3, 4, dtype=dtype, device=self.device)
+            b = torch.randn(3, 4, dtype=dtype, device=self.device)
+            a[2, 3] = b[2, 3]
+            with self.assertRaisesRegex(
+                ValueError,
+                "Received input tensors a, b for which not all a < b.",
+            ):
                 log_prob_normal_in(a, b)
 
             # testing gaussian hazard function
@@ -303,12 +304,20 @@ def test_gaussian_probabilities(self):
             x = torch.cat((-x, x))
             log_hx = standard_normal_log_hazard(x)
             expected_log_hx = log_phi(x) - log_ndtr(-x)
-            self.assertAllClose(expected_log_hx, log_hx)  # correctness
+            self.assertAllClose(
+                expected_log_hx,
+                log_hx,
+                atol=1e-8 if dtype == torch.double else 1e-7,
+            )  # correctness
             # NOTE: Could extend tests here similarly to log_erfc(x) tests above, but
             # since the hazard functions are built on log_erfcx, not urgent.
 
-            with self.assertRaises(TypeError):
+            float16_msg = (
+                "only supports torch.float32 and torch.float64 dtypes, but received "
+                "x.dtype = torch.float16."
+            )
+            with self.assertRaisesRegex(TypeError, expected_regex=float16_msg):
                 log_erfc(torch.tensor(1.0, dtype=torch.float16, device=self.device))
 
-            with self.assertRaises(TypeError):
+            with self.assertRaisesRegex(TypeError, expected_regex=float16_msg):
                 log_ndtr(torch.tensor(1.0, dtype=torch.float16, device=self.device))
```
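For reference, the identity behind `expected_log_hx` in the final hunk, as a standalone sketch: the standard normal hazard is phi(x) / (1 - Phi(x)) = phi(x) / Phi(-x), so its log is log phi(x) - log Phi(-x). The helper names here are hypothetical, and `torch.special.log_ndtr` stands in for the `log_ndtr` the test presumably imports from the module under test.

```python
import math

import torch


def log_phi(x: torch.Tensor) -> torch.Tensor:
    # Log of the standard normal density: -x^2 / 2 - log(2 * pi) / 2.
    return -0.5 * x.square() - 0.5 * math.log(2 * math.pi)


def log_hazard_reference(x: torch.Tensor) -> torch.Tensor:
    # h(x) = phi(x) / Phi(-x), so log h(x) = log phi(x) - log Phi(-x);
    # log_ndtr is the log of the standard normal CDF.
    return log_phi(x) - torch.special.log_ndtr(-x)


x = 3 * torch.randn(16, dtype=torch.double)
x = torch.cat((-x, x))  # symmetric test points, as in the test
print(log_hazard_reference(x))
```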
