Stop some warnings in unit tests (#1992)
Summary:
## Motivation

Warning output from unit tests sometimes indicates a serious problem, and sometimes it merely clutters the output so that the serious problems go unnoticed. The following warnings are the latter:
* `InputDataWarning`: "Input data is not contained to the unit cube. Please consider min-max scaling the input data." (occurred 194 times, now 0)
* `BadInitialCandidatesWarning`: "Unable to find non-zero acquisition function values - initial conditions are being selected randomly." (occurred 40 times, now 0)
* "The first positional argument of samplers, `num_samples`, has been deprecated and replaced with `sample_shape`, which expects a `torch.Size` object." (occurred 35 times, now 0; the migration is sketched below)
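
The third warning is resolved throughout the tests by constructing samplers with an explicit `sample_shape`. A minimal sketch of the migration, assuming the sampler import path of recent BoTorch releases:

```python
import torch

from botorch.sampling import IIDNormalSampler, SobolQMCNormalSampler

# Deprecated: the first positional argument, num_samples
# sampler = SobolQMCNormalSampler(16)

# Current: pass sample_shape explicitly as a torch.Size
sobol_sampler = SobolQMCNormalSampler(sample_shape=torch.Size([16]))
iid_sampler = IIDNormalSampler(sample_shape=torch.Size([16]), seed=0)
```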

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: #1992

Test Plan:
Unit tests

## Related PRs

#1792, #1539

Reviewed By: saitcakmak

Differential Revision: D48530764

Pulled By: esantorella

fbshipit-source-id: c5b1898ce8156a6f02550acb09bd5eba0c157c5e
esantorella authored and facebook-github-bot committed Aug 22, 2023
1 parent c43b074 commit 4cc5ed5
Showing 20 changed files with 99 additions and 64 deletions.
38 changes: 22 additions & 16 deletions botorch/utils/testing.py
@@ -43,25 +43,31 @@ class BotorchTestCase(TestCase):

device = torch.device("cpu")

def setUp(self):
def setUp(self, suppress_input_warnings: bool = True) -> None:
warnings.resetwarnings()
settings.debug._set_state(False)
warnings.simplefilter("always", append=True)
warnings.filterwarnings(
"ignore",
message="The model inputs are of type",
category=UserWarning,
)
warnings.filterwarnings(
"ignore",
message="Non-strict enforcement of botorch tensor conventions.",
category=BotorchTensorDimensionWarning,
)
warnings.filterwarnings(
"ignore",
message="Input data is not standardized.",
category=InputDataWarning,
)
if suppress_input_warnings:
warnings.filterwarnings(
"ignore",
message="The model inputs are of type",
category=UserWarning,
)
warnings.filterwarnings(
"ignore",
message="Non-strict enforcement of botorch tensor conventions.",
category=BotorchTensorDimensionWarning,
)
warnings.filterwarnings(
"ignore",
message="Input data is not standardized.",
category=InputDataWarning,
)
warnings.filterwarnings(
"ignore",
message="Input data is not contained to the unit cube.",
category=InputDataWarning,
)

def assertAllClose(
self,
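With the new `suppress_input_warnings` flag above, a test case that needs to observe input-data warnings can opt out of the suppression, as `TestInputDataChecks` does later in this diff. A minimal sketch (the class name here is hypothetical):

```python
from botorch.utils.testing import BotorchTestCase


class MyInputWarningTests(BotorchTestCase):  # hypothetical test class
    def setUp(self) -> None:
        # Keep InputDataWarning visible so tests can assert on it.
        super().setUp(suppress_input_warnings=False)
```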
1 change: 1 addition & 0 deletions test/acquisition/test_input_constructors.py
@@ -97,6 +97,7 @@ class DummyAcquisitionFunction(AcquisitionFunction):

class InputConstructorBaseTestCase:
def setUp(self) -> None:
super().setUp()
self.mock_model = MockModel(
posterior=MockPosterior(mean=None, variance=None, base_shape=(1,))
)
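Most of the remaining hunks follow this same pattern: a `setUp` override that previously shadowed `BotorchTestCase.setUp` now calls `super().setUp()` first, so the base class's warning filters actually get installed. A minimal sketch (class and fixture names are hypothetical):

```python
import torch

from botorch.utils.testing import BotorchTestCase


class MyTests(BotorchTestCase):  # hypothetical test class
    def setUp(self) -> None:
        # Without this call, the warning filters configured in
        # BotorchTestCase.setUp never take effect for these tests.
        super().setUp()
        self.train_X = torch.rand(5, 2)  # fixtures come after
```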
2 changes: 1 addition & 1 deletion test/acquisition/test_monte_carlo.py
@@ -556,7 +556,7 @@ def test_cache_root(self):
"prune_baseline": False,
"cache_root": True,
"posterior_transform": ScalarizedPosteriorTransform(weights=torch.ones(m)),
"sampler": SobolQMCNormalSampler(5),
"sampler": SobolQMCNormalSampler(sample_shape=torch.Size([5])),
}
acqf = qNoisyExpectedImprovement(**nei_args)
X = torch.randn_like(X_baseline)
3 changes: 2 additions & 1 deletion test/acquisition/test_preference.py
@@ -18,7 +18,8 @@


class TestPreferenceAcquisitionFunctions(BotorchTestCase):
def setUp(self):
def setUp(self) -> None:
super().setUp()
self.twargs = {"dtype": torch.double}
self.X_dim = 3
self.Y_dim = 2
1 change: 1 addition & 0 deletions test/acquisition/test_prior_guided.py
@@ -39,6 +39,7 @@ def get_weighted_val(ei_val, prob, exponent, use_log):

class TestPriorGuidedAcquisitionFunction(BotorchTestCase):
def setUp(self):
super().setUp()
self.prior = DummyPrior()
self.train_X = torch.rand(5, 3, dtype=torch.double, device=self.device)
self.train_Y = self.train_X.norm(dim=-1, keepdim=True)
2 changes: 1 addition & 1 deletion test/models/test_gpytorch.py
@@ -296,7 +296,7 @@ def test_fantasize_flag(self):
self.assertFalse(model.last_fantasize_flag)
model.posterior(test_X)
self.assertFalse(model.last_fantasize_flag)
model.fantasize(test_X, SobolQMCNormalSampler(2))
model.fantasize(test_X, SobolQMCNormalSampler(sample_shape=torch.Size([2])))
self.assertTrue(model.last_fantasize_flag)
model.last_fantasize_flag = False
with fantasize():
16 changes: 9 additions & 7 deletions test/models/test_model_list_gp_regression.py
@@ -382,7 +382,9 @@ def test_fantasize(self):
m1 = SingleTaskGP(torch.rand(5, 2), torch.rand(5, 1)).eval()
m2 = SingleTaskGP(torch.rand(5, 2), torch.rand(5, 1)).eval()
modellist = ModelListGP(m1, m2)
fm = modellist.fantasize(torch.rand(3, 2), sampler=IIDNormalSampler(2))
fm = modellist.fantasize(
torch.rand(3, 2), sampler=IIDNormalSampler(sample_shape=torch.Size([2]))
)
self.assertIsInstance(fm, ModelListGP)
for i in range(2):
fm_i = fm.models[i]
@@ -391,8 +393,8 @@
self.assertEqual(fm_i.train_targets.shape, torch.Size([2, 8]))

# test decoupled
sampler1 = IIDNormalSampler(2)
sampler2 = IIDNormalSampler(2)
sampler1 = IIDNormalSampler(sample_shape=torch.Size([2]))
sampler2 = IIDNormalSampler(sample_shape=torch.Size([2]))
eval_mask = torch.tensor(
[[1, 0], [0, 1], [1, 0]],
dtype=torch.bool,
@@ -457,7 +459,7 @@ def _get_fant_mean(
return fant.posterior(target_x).mean.mean(dim=(-2, -3))

# ~0
sampler = IIDNormalSampler(10, seed=0)
sampler = IIDNormalSampler(sample_shape=torch.Size([10]), seed=0)
fant_mean_with_manual_transform = _get_fant_mean(
model_manually_transformed, sampler=sampler
)
@@ -490,8 +492,8 @@ def _get_fant_mean(
)
# test decoupled
sampler = ListSampler(
IIDNormalSampler(10, seed=0),
IIDNormalSampler(10, seed=0),
IIDNormalSampler(sample_shape=torch.Size([10]), seed=0),
IIDNormalSampler(sample_shape=torch.Size([10]), seed=0),
)
fant_mean_with_manual_transform = _get_fant_mean(
model_manually_transformed,
@@ -539,7 +541,7 @@ def test_fantasize_with_outcome_transform_fixed_noise(self) -> None:
100 at x=0. If transforms are not properly applied, we'll get answers
on the order of ~1. Answers between 99 and 101 are acceptable.
"""
n_fants = 20
n_fants = torch.Size([20])
y_at_low_x = 100.0
y_at_high_x = -40.0

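The same sampler migration applies when fantasizing from a model, as in the first hunk above. A minimal self-contained sketch mirroring that hunk (standard `SingleTaskGP` usage assumed):

```python
import torch

from botorch.models import ModelListGP, SingleTaskGP
from botorch.sampling import IIDNormalSampler

m1 = SingleTaskGP(torch.rand(5, 2), torch.rand(5, 1)).eval()
m2 = SingleTaskGP(torch.rand(5, 2), torch.rand(5, 1)).eval()
modellist = ModelListGP(m1, m2)
# sample_shape replaces the deprecated positional num_samples
fm = modellist.fantasize(
    torch.rand(3, 2), sampler=IIDNormalSampler(sample_shape=torch.Size([2]))
)
```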
13 changes: 8 additions & 5 deletions test/models/utils/test_assorted.py
@@ -72,6 +72,11 @@ def test_add_output_dim(self):


class TestInputDataChecks(BotorchTestCase):
def setUp(self) -> None:
# The super class usually disables input data warnings in unit tests.
# Don't do that here.
super().setUp(suppress_input_warnings=False)

def test_check_no_nans(self):
check_no_nans(torch.tensor([1.0, 2.0]))
with self.assertRaises(InputDataError):
@@ -87,12 +92,10 @@ def test_check_min_max_scaling(self):
any(issubclass(w.category, InputDataWarning) for w in ws)
)
check_min_max_scaling(X=X, raise_on_fail=True)
with warnings.catch_warnings(record=True) as ws:
with self.assertWarnsRegex(
expected_warning=InputDataWarning, expected_regex="not scaled"
):
check_min_max_scaling(X=X, strict=True)
self.assertTrue(
any(issubclass(w.category, InputDataWarning) for w in ws)
)
self.assertTrue(any("not scaled" in str(w.message) for w in ws))
with self.assertRaises(InputDataError):
check_min_max_scaling(X=X, strict=True, raise_on_fail=True)
# check proper input
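The rewritten check above swaps manual warning recording for `unittest`'s `assertWarnsRegex` context manager, which fails the test unless a matching warning is emitted. A minimal standalone sketch (the warning text and `scale_check` function are hypothetical stand-ins):

```python
import unittest
import warnings


def scale_check() -> None:  # hypothetical stand-in for check_min_max_scaling
    warnings.warn("Input data is not scaled.", UserWarning)


class ExampleTest(unittest.TestCase):
    def test_warns(self) -> None:
        # Passes only if a UserWarning matching the regex is emitted.
        with self.assertWarnsRegex(
            expected_warning=UserWarning, expected_regex="not scaled"
        ):
            scale_check()


if __name__ == "__main__":
    unittest.main()
```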
10 changes: 7 additions & 3 deletions test/optim/test_fit.py
@@ -25,7 +25,8 @@


class TestFitGPyTorchMLLScipy(BotorchTestCase):
def setUp(self):
def setUp(self) -> None:
super().setUp()
self.mlls = {}
with torch.random.fork_rng():
torch.manual_seed(0)
@@ -172,7 +173,8 @@ def _assert_np_array_is_float64_type(array) -> bool:


class TestFitGPyTorchMLLTorch(BotorchTestCase):
def setUp(self):
def setUp(self) -> None:
super().setUp()
self.mlls = {}
with torch.random.fork_rng():
torch.manual_seed(0)
@@ -236,7 +238,8 @@ def _test_fit_gpytorch_mll_torch(self, mll):


class TestFitGPyTorchScipy(BotorchTestCase):
def setUp(self):
def setUp(self) -> None:
super().setUp()
self.mlls = {}
with torch.random.fork_rng():
torch.manual_seed(0)
@@ -372,6 +375,7 @@ def _test_fit_gpytorch_scipy(self, mll):

class TestFitGPyTorchTorch(BotorchTestCase):
def setUp(self):
super().setUp()
self.mlls = {}
with torch.random.fork_rng():
torch.manual_seed(0)
50 changes: 28 additions & 22 deletions test/optim/test_initializers.py
@@ -188,7 +188,10 @@ def test_gen_batch_initial_conditions(self):
MockAcquisitionFunction,
"__call__",
wraps=mock_acqf.__call__,
) as mock_acqf_call:
) as mock_acqf_call, warnings.catch_warnings():
warnings.simplefilter(
"ignore", category=BadInitialCandidatesWarning
)
batch_initial_conditions = gen_batch_initial_conditions(
acq_function=mock_acqf,
bounds=bounds,
@@ -248,6 +251,9 @@ def test_gen_batch_initial_conditions_highdim(self):
[True, False], [None, 1234], [None, ffs_map], [True, False]
):
with warnings.catch_warnings(record=True) as ws, settings.debug(True):
warnings.simplefilter(
"ignore", category=BadInitialCandidatesWarning
)
batch_initial_conditions = gen_batch_initial_conditions(
acq_function=MockAcquisitionFunction(),
bounds=bounds,
@@ -279,19 +285,17 @@
torch.all(batch_initial_conditions[..., idx] == val)
)

def test_gen_batch_initial_conditions_warning(self):
def test_gen_batch_initial_conditions_warning(self) -> None:
for dtype in (torch.float, torch.double):
bounds = torch.tensor([[0, 0], [1, 1]], device=self.device, dtype=dtype)
samples = torch.zeros(10, 1, 2, device=self.device, dtype=dtype)
with ExitStack() as es:
ws = es.enter_context(warnings.catch_warnings(record=True))
es.enter_context(settings.debug(True))
es.enter_context(
mock.patch(
"botorch.optim.initializers.draw_sobol_samples",
return_value=samples,
)
)
with self.assertWarnsRegex(
expected_warning=BadInitialCandidatesWarning,
expected_regex="Unable to find non-zero acquisition",
), mock.patch(
"botorch.optim.initializers.draw_sobol_samples",
return_value=samples,
):
batch_initial_conditions = gen_batch_initial_conditions(
acq_function=MockAcquisitionFunction(),
bounds=bounds,
@@ -300,16 +304,12 @@ def test_gen_batch_initial_conditions_warning(self):
raw_samples=10,
options={"seed": 1234},
)
self.assertEqual(len(ws), 1)
self.assertTrue(
any(issubclass(w.category, BadInitialCandidatesWarning) for w in ws)
)
self.assertTrue(
torch.equal(
batch_initial_conditions,
torch.zeros(2, 1, 2, device=self.device, dtype=dtype),
)
self.assertTrue(
torch.equal(
batch_initial_conditions,
torch.zeros(2, 1, 2, device=self.device, dtype=dtype),
)
)

def test_gen_batch_initial_conditions_transform_intra_point_constraint(self):
for dtype in (torch.float, torch.double):
@@ -549,7 +549,10 @@ def test_gen_batch_initial_conditions_constraints(self):
MockAcquisitionFunction,
"__call__",
wraps=mock_acqf.__call__,
) as mock_acqf_call:
) as mock_acqf_call, warnings.catch_warnings():
warnings.simplefilter(
"ignore", category=BadInitialCandidatesWarning
)
batch_initial_conditions = gen_batch_initial_conditions(
acq_function=mock_acqf,
bounds=bounds,
@@ -723,7 +726,10 @@ def generator(n: int, q: int, seed: int):
MockAcquisitionFunction,
"__call__",
wraps=mock_acqf.__call__,
):
), warnings.catch_warnings():
warnings.simplefilter(
"ignore", category=BadInitialCandidatesWarning
)
batch_initial_conditions = gen_batch_initial_conditions(
acq_function=mock_acqf,
bounds=bounds,
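Rather than silencing `BadInitialCandidatesWarning` globally, the hunks above scope the filter to the block that is expected to trigger it. A minimal sketch of that pattern:

```python
import warnings

from botorch.exceptions import BadInitialCandidatesWarning

with warnings.catch_warnings():
    # The "ignore" filter applies only inside this block; the previous
    # filter state is restored on exit.
    warnings.simplefilter("ignore", category=BadInitialCandidatesWarning)
    ...  # e.g., a gen_batch_initial_conditions call that may warn
```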
3 changes: 2 additions & 1 deletion test/optim/test_numpy_converter.py
@@ -230,7 +230,8 @@ def test_set_parameters(self):


class TestScipyObjectiveAndGrad(BotorchTestCase):
def setUp(self):
def setUp(self) -> None:
super().setUp()
with torch.random.fork_rng():
torch.manual_seed(0)
train_X = torch.linspace(0, 1, 10).unsqueeze(-1)
3 changes: 2 additions & 1 deletion test/optim/utils/test_model_utils.py
@@ -116,6 +116,7 @@ def test_get_data_loader(self):

class TestGetParameters(BotorchTestCase):
def setUp(self):
super().setUp()
self.module = GaussianLikelihood(
noise_constraint=GreaterThan(1e-6, initial_value=0.123),
)
@@ -124,7 +125,7 @@ def test_get_parameters(self):
self.assertEqual(0, len(get_parameters(self.module, requires_grad=False)))

params = get_parameters(self.module)
self.assertTrue(1 == len(params))
self.assertEqual(1, len(params))
self.assertEqual(next(iter(params)), "noise_covar.raw_noise")
self.assertTrue(
self.module.noise_covar.raw_noise.equal(next(iter(params.values())))
9 changes: 6 additions & 3 deletions test/test_fit.py
@@ -69,7 +69,8 @@ def __call__(self, mll, closure: Optional[Callable] = None):
class TestFitAPI(BotorchTestCase):
r"""Unit tests for general fitting API"""

def setUp(self):
def setUp(self) -> None:
super().setUp()
with torch.random.fork_rng():
torch.manual_seed(0)
train_X = torch.linspace(0, 1, 10).unsqueeze(-1)
@@ -172,7 +173,8 @@ def mock_fit_gpytorch_mll(*args, **kwargs):


class TestFitFallback(BotorchTestCase):
def setUp(self):
def setUp(self) -> None:
super().setUp()
with torch.random.fork_rng():
torch.manual_seed(0)
train_X = torch.linspace(0, 1, 10).unsqueeze(-1)
@@ -377,7 +379,8 @@ def _test_exceptions(self, mll, ckpt):


class TestFitFallbackAppoximate(BotorchTestCase):
def setUp(self):
def setUp(self) -> None:
super().setUp()
with torch.random.fork_rng():
torch.manual_seed(0)
train_X = torch.linspace(0, 1, 10).unsqueeze(-1)
@@ -49,6 +49,7 @@ def _get_single_cell(self):

class TestBoxDecomposition(BotorchTestCase):
def setUp(self):
super().setUp()
self.ref_point_raw = torch.zeros(3, device=self.device)
self.Y_raw = torch.tensor(
[
@@ -269,6 +269,7 @@ class TestFastPartitioningUtils(BotorchTestCase):
"""

def setUp(self):
super().setUp()
self.ref_point = -torch.tensor([10.0, 10.0, 10.0], device=self.device)
self.U = -self.ref_point.clone().view(1, -1)
self.Z = torch.empty(1, 3, 3, device=self.device)
3 changes: 2 additions & 1 deletion test/utils/probability/test_truncated_multivariate_normal.py
@@ -28,7 +28,8 @@ def setUp(
upper_quantile_min: float = 0.1, # MC methods will not produce any samples.
num_log_probs: int = 4,
seed: int = 1,
):
) -> None:
super().setUp()
self.seed_generator = count(seed)
self.num_log_probs = num_log_probs

1 change: 1 addition & 0 deletions test/utils/test_context_managers.py
@@ -23,6 +23,7 @@

class TestContextManagers(BotorchTestCase):
def setUp(self):
super().setUp()
module = self.module = Module()
for i, name in enumerate(ascii_lowercase[:3], start=1):
values = torch.rand(2).to(torch.float16)