From 03342e33478a5ecc94f010745844393de36a1ec9 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 5 Oct 2022 19:54:18 +0200 Subject: [PATCH] Fix `early_stopping=True` with BOHB not setting the correct scheduler (#254) * Fix `early_stopping=True` with BOHB Signed-off-by: Antoni Baum * Enable BOHB tests Signed-off-by: Antoni Baum Signed-off-by: Antoni Baum --- tests/test_randomizedsearch.py | 10 ---------- tune_sklearn/tune_search.py | 1 + 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/tests/test_randomizedsearch.py b/tests/test_randomizedsearch.py index 0240e98..45f7dfb 100644 --- a/tests/test_randomizedsearch.py +++ b/tests/test_randomizedsearch.py @@ -171,10 +171,6 @@ def test_multi_best_classification(self): scoring = ("accuracy", "f1_micro") search_methods = ["random", "bayesian", "hyperopt", "bohb", "optuna"] for search_method in search_methods: - if search_method == "bohb": - print("bobh test currently failing") - continue - tune_search = TuneSearchCV( model, parameter_grid, @@ -206,9 +202,6 @@ def test_multi_best_classification_scoring_dict(self): scoring = {"acc": "accuracy", "f1": "f1_micro"} search_methods = ["random", "bayesian", "hyperopt", "bohb", "optuna"] for search_method in search_methods: - if search_method == "bohb": - print("bobh test currently failing") - continue tune_search = TuneSearchCV( model, parameter_grid, @@ -239,9 +232,6 @@ def test_multi_best_regression(self): search_methods = ["random", "bayesian", "hyperopt", "bohb", "optuna"] for search_method in search_methods: - if search_method == "bohb": - print("bobh test currently failing") - continue tune_search = TuneSearchCV( model, parameter_grid, diff --git a/tune_sklearn/tune_search.py b/tune_sklearn/tune_search.py index 6fd35a6..caabbd2 100644 --- a/tune_sklearn/tune_search.py +++ b/tune_sklearn/tune_search.py @@ -383,6 +383,7 @@ def __init__(self, if early_stopping is False: raise ValueError( "early_stopping must not be False when using BOHB") + early_stopping = "HyperBandForBOHB" elif not isinstance(early_stopping, HyperBandForBOHB): if early_stopping != "HyperBandForBOHB": warnings.warn("Ignoring early_stopping value, "