diff --git a/tests/test_randomizedsearch.py b/tests/test_randomizedsearch.py index 0240e98..45f7dfb 100644 --- a/tests/test_randomizedsearch.py +++ b/tests/test_randomizedsearch.py @@ -171,10 +171,6 @@ def test_multi_best_classification(self): scoring = ("accuracy", "f1_micro") search_methods = ["random", "bayesian", "hyperopt", "bohb", "optuna"] for search_method in search_methods: - if search_method == "bohb": - print("bobh test currently failing") - continue - tune_search = TuneSearchCV( model, parameter_grid, @@ -206,9 +202,6 @@ def test_multi_best_classification_scoring_dict(self): scoring = {"acc": "accuracy", "f1": "f1_micro"} search_methods = ["random", "bayesian", "hyperopt", "bohb", "optuna"] for search_method in search_methods: - if search_method == "bohb": - print("bobh test currently failing") - continue tune_search = TuneSearchCV( model, parameter_grid, @@ -239,9 +232,6 @@ def test_multi_best_regression(self): search_methods = ["random", "bayesian", "hyperopt", "bohb", "optuna"] for search_method in search_methods: - if search_method == "bohb": - print("bobh test currently failing") - continue tune_search = TuneSearchCV( model, parameter_grid, diff --git a/tune_sklearn/tune_search.py b/tune_sklearn/tune_search.py index 6fd35a6..caabbd2 100644 --- a/tune_sklearn/tune_search.py +++ b/tune_sklearn/tune_search.py @@ -383,6 +383,7 @@ def __init__(self, if early_stopping is False: raise ValueError( "early_stopping must not be False when using BOHB") + early_stopping = "HyperBandForBOHB" elif not isinstance(early_stopping, HyperBandForBOHB): if early_stopping != "HyperBandForBOHB": warnings.warn("Ignoring early_stopping value, "