From 80959211c228bc50e928ffefe30ff2457d7814e9 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Sun, 8 Sep 2024 21:06:19 -0400 Subject: [PATCH] Fixed Signed-off-by: Adam Li --- sklearn/ensemble/_forest.py | 28 ++++++++++++++++------------ sklearn/tree/_classes.py | 6 +++++- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 3f6ed9d5040ab..b01a27f14462d 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -36,10 +36,10 @@ class calls the ``fit`` method of each sub-estimator on random samples # SPDX-License-Identifier: BSD-3-Clause -from time import time import threading from abc import ABCMeta, abstractmethod from numbers import Integral, Real +from time import time from warnings import catch_warnings, simplefilter, warn import numpy as np @@ -54,22 +54,20 @@ class calls the ``fit`` method of each sub-estimator on random samples _fit_context, is_classifier, ) +from sklearn.ensemble._base import BaseEnsemble, _partition_estimators +from sklearn.ensemble._hist_gradient_boosting.binning import _BinMapper from sklearn.exceptions import DataConversionWarning from sklearn.metrics import accuracy_score, r2_score from sklearn.preprocessing import OneHotEncoder -from ..tree import ( - BaseDecisionTree, - DecisionTreeClassifier, - DecisionTreeRegressor, - ExtraTreeClassifier, - ExtraTreeRegressor, -) -from ..tree._tree import DOUBLE, DTYPE from sklearn.utils import check_random_state, compute_sample_weight from sklearn.utils._openmp_helpers import _openmp_effective_n_threads from sklearn.utils._param_validation import Interval, RealNotInt, StrOptions from sklearn.utils._tags import get_tags -from sklearn.utils.multiclass import check_classification_targets, type_of_target +from sklearn.utils.multiclass import ( + _check_partial_fit_first_call, + check_classification_targets, + type_of_target, +) from sklearn.utils.parallel import Parallel, delayed from sklearn.utils.validation import ( _check_feature_names_in, @@ -78,9 +76,15 @@ class calls the ``fit`` method of each sub-estimator on random samples check_is_fitted, validate_data, ) -from sklearn.ensemble._hist_gradient_boosting.binning import _BinMapper -from ._base import BaseEnsemble, _partition_estimators +from ..tree import ( + BaseDecisionTree, + DecisionTreeClassifier, + DecisionTreeRegressor, + ExtraTreeClassifier, + ExtraTreeRegressor, +) +from ..tree._tree import DOUBLE, DTYPE __all__ = [ "RandomForestClassifier", diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index e0f30bf864010..2e792e768c17d 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -26,7 +26,10 @@ ) from sklearn.utils import Bunch, check_random_state, compute_sample_weight from sklearn.utils._param_validation import Hidden, Interval, RealNotInt, StrOptions -from sklearn.utils.multiclass import check_classification_targets +from sklearn.utils.multiclass import ( + _check_partial_fit_first_call, + check_classification_targets, +) from sklearn.utils.validation import ( _assert_all_finite_element_wise, _check_n_features, @@ -35,6 +38,7 @@ check_is_fitted, validate_data, ) + from . import _criterion, _splitter, _tree from ._criterion import BaseCriterion from ._splitter import BaseSplitter