From c43b074f559831b5e8a1863a9a469bf2978046a8 Mon Sep 17 00:00:00 2001
From: Elizabeth Santorella
Date: Mon, 21 Aug 2023 11:36:10 -0700
Subject: [PATCH] Stop numerical tests from flaking; use assertRaisesRegex
 (#1991)

Summary:
Pull Request resolved: https://github.com/pytorch/botorch/pull/1991

see title

Reviewed By: SebastianAment

Differential Revision: D48523999

fbshipit-source-id: f2296f8f87080f89edcd1c4d8a68b36fbf90816f
---
 test/acquisition/test_objective.py   | 27 ++++++++++++--------
 test/utils/probability/test_utils.py | 37 +++++++++++++++++-----------
 2 files changed, 40 insertions(+), 24 deletions(-)

diff --git a/test/acquisition/test_objective.py b/test/acquisition/test_objective.py
index b726e41ddb..ba2f8c9817 100644
--- a/test/acquisition/test_objective.py
+++ b/test/acquisition/test_objective.py
@@ -399,29 +399,36 @@ def test_identity_mc_objective(self):
 
 
 class TestLinearMCObjective(BotorchTestCase):
-    def test_linear_mc_objective(self):
+    def test_linear_mc_objective(self) -> None:
+        # Test passes for each seed
+        torch.manual_seed(torch.randint(high=1000, size=(1,)))
         for dtype in (torch.float, torch.double):
             weights = torch.rand(3, device=self.device, dtype=dtype)
             obj = LinearMCObjective(weights=weights)
             samples = torch.randn(4, 2, 3, device=self.device, dtype=dtype)
-            self.assertTrue(
-                torch.allclose(obj(samples), (samples * weights).sum(dim=-1))
-            )
+            atol = 1e-8 if dtype == torch.double else 3e-8
+            rtol = 1e-5 if dtype == torch.double else 4e-5
+            self.assertAllClose(obj(samples), samples @ weights, atol=atol, rtol=rtol)
             samples = torch.randn(5, 4, 2, 3, device=self.device, dtype=dtype)
-            self.assertTrue(
-                torch.allclose(obj(samples), (samples * weights).sum(dim=-1))
+            self.assertAllClose(
+                obj(samples),
+                samples @ weights,
+                atol=atol,
+                rtol=rtol,
             )
             # make sure this errors if sample output dimensions are incompatible
-            with self.assertRaises(RuntimeError):
+            shape_mismatch_msg = "Output shape of samples not equal to that of weights"
+            with self.assertRaisesRegex(RuntimeError, shape_mismatch_msg):
                 obj(samples=torch.randn(2, device=self.device, dtype=dtype))
-            with self.assertRaises(RuntimeError):
+            with self.assertRaisesRegex(RuntimeError, shape_mismatch_msg):
                 obj(samples=torch.randn(1, device=self.device, dtype=dtype))
             # make sure we can't construct objectives with multi-dim. weights
-            with self.assertRaises(ValueError):
+            weights_1d_msg = "weights must be a one-dimensional tensor."
+            with self.assertRaisesRegex(ValueError, expected_regex=weights_1d_msg):
                 LinearMCObjective(
                     weights=torch.rand(2, 3, device=self.device, dtype=dtype)
                 )
-            with self.assertRaises(ValueError):
+            with self.assertRaisesRegex(ValueError, expected_regex=weights_1d_msg):
                 LinearMCObjective(
                     weights=torch.tensor(1.0, device=self.device, dtype=dtype)
                 )
diff --git a/test/utils/probability/test_utils.py b/test/utils/probability/test_utils.py
index 8c17fc28d4..2ef39717fd 100644
--- a/test/utils/probability/test_utils.py
+++ b/test/utils/probability/test_utils.py
@@ -153,7 +153,9 @@ def test_swap_along_dim_(self):
             with self.assertRaisesRegex(ValueError, "at most 1-dimensional"):
                 utils.swap_along_dim_(values.view(-1), i=i_lidx, j=j, dim=0)
 
-    def test_gaussian_probabilities(self):
+    def test_gaussian_probabilities(self) -> None:
+        # test passes for each possible seed
+        torch.manual_seed(torch.randint(high=1000, size=(1,)))
         # testing Gaussian probability functions
         for dtype in (torch.float, torch.double):
             rtol = 1e-12 if dtype == torch.double else 1e-6
@@ -161,12 +163,8 @@ def test_gaussian_probabilities(self):
             n = 16
             x = 3 * torch.randn(n, device=self.device, dtype=dtype)
             # first, test consistency between regular and log versions
-            self.assertTrue(
-                torch.allclose(phi(x), log_phi(x).exp(), atol=atol, rtol=rtol)
-            )
-            self.assertTrue(
-                torch.allclose(ndtr(x), log_ndtr(x).exp(), atol=atol, rtol=rtol)
-            )
+            self.assertAllClose(phi(x), log_phi(x).exp(), atol=atol, rtol=rtol)
+            self.assertAllClose(ndtr(x), log_ndtr(x).exp(), atol=atol, rtol=rtol)
 
             # test correctness of log_erfc and log_erfcx
             for special_f, custom_log_f in zip(
@@ -291,10 +289,13 @@ def test_gaussian_probabilities(self):
             self.assertTrue((a.grad.diff() < 0).all())
 
             # testing error raising for invalid inputs
-            with self.assertRaises(ValueError):
-                a = torch.randn(3, 4, dtype=dtype, device=self.device)
-                b = torch.randn(3, 4, dtype=dtype, device=self.device)
-                a[2, 3] = b[2, 3]
+            a = torch.randn(3, 4, dtype=dtype, device=self.device)
+            b = torch.randn(3, 4, dtype=dtype, device=self.device)
+            a[2, 3] = b[2, 3]
+            with self.assertRaisesRegex(
+                ValueError,
+                "Received input tensors a, b for which not all a < b.",
+            ):
                 log_prob_normal_in(a, b)
 
             # testing gaussian hazard function
@@ -303,12 +304,20 @@ def test_gaussian_probabilities(self):
             x = torch.cat((-x, x))
             log_hx = standard_normal_log_hazard(x)
             expected_log_hx = log_phi(x) - log_ndtr(-x)
-            self.assertAllClose(expected_log_hx, log_hx)  # correctness
+            self.assertAllClose(
+                expected_log_hx,
+                log_hx,
+                atol=1e-8 if dtype == torch.double else 1e-7,
+            )  # correctness
             # NOTE: Could extend tests here similarly to log_erfc(x) tests above, but
             # since the hazard functions are built on log_erfcx, not urgent.
-            with self.assertRaises(TypeError):
+            float16_msg = (
+                "only supports torch.float32 and torch.float64 dtypes, but received "
+                "x.dtype = torch.float16."
+            )
+            with self.assertRaisesRegex(TypeError, expected_regex=float16_msg):
                 log_erfc(torch.tensor(1.0, dtype=torch.float16, device=self.device))
-            with self.assertRaises(TypeError):
+            with self.assertRaisesRegex(TypeError, expected_regex=float16_msg):
                 log_ndtr(torch.tensor(1.0, dtype=torch.float16, device=self.device))
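
A note on the seeding pattern above: drawing a fresh seed on every run, and
pairing it with explicit dtype-dependent tolerances, is what lets these tests
stop flaking without freezing them to one lucky seed. The sketch below is a
minimal standalone version of that pattern, assuming only stock unittest and
torch; torch.testing.assert_close stands in for BoTorch's
BotorchTestCase.assertAllClose, and the test class and method names are
hypothetical.

    import unittest

    import torch


    class TestSeededNumerics(unittest.TestCase):
        def test_linear_combination_matches_matmul(self) -> None:
            # A fresh seed per run: across many CI runs this exercises many
            # seeds, while the explicit tolerances keep any one run stable.
            torch.manual_seed(int(torch.randint(high=1000, size=(1,))))
            for dtype in (torch.float, torch.double):
                # Single precision needs looser tolerances than double.
                atol = 1e-8 if dtype == torch.double else 3e-8
                rtol = 1e-5 if dtype == torch.double else 4e-5
                weights = torch.rand(3, dtype=dtype)
                samples = torch.randn(4, 2, 3, dtype=dtype)
                # The two formulations agree only up to round-off, which is
                # why a bare allclose with default tolerances can flake.
                torch.testing.assert_close(
                    (samples * weights).sum(dim=-1),
                    samples @ weights,
                    atol=atol,
                    rtol=rtol,
                )


    if __name__ == "__main__":
        unittest.main()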
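
The patch's other theme, moving from assertRaises to assertRaisesRegex, pins
each test to a specific error message rather than only an exception type, so
a test cannot silently pass on an unrelated error. One caveat: unittest
treats the second argument as a regular expression searched against the
exception text, so a literal message containing metacharacters is safest
wrapped in re.escape (the messages above still match literally because '.'
matches any character, including itself). A small sketch, with a hypothetical
check_weights validator standing in for LinearMCObjective's constructor
check:

    import re
    import unittest

    import torch


    def check_weights(weights: torch.Tensor) -> None:
        # Hypothetical stand-in for LinearMCObjective's weights validation.
        if weights.dim() != 1:
            raise ValueError("weights must be a one-dimensional tensor.")


    class TestErrorMessages(unittest.TestCase):
        def test_weights_must_be_1d(self) -> None:
            # The assertion fails if nothing is raised or the message does
            # not match; an unexpected exception type propagates as an error.
            msg = re.escape("weights must be a one-dimensional tensor.")
            with self.assertRaisesRegex(ValueError, expected_regex=msg):
                check_weights(torch.rand(2, 3))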