From 51b7ea597907e194094acd18c1b7545fac636124 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Fri, 11 Aug 2023 10:59:31 -0400 Subject: [PATCH 1/7] Adding three diffs Signed-off-by: Adam Li --- sklearn/tree/_classes.py | 199 ++- sklearn/tree/_tree.pxd | 2462 +++++++++++++++++++++++++++++++++++--- sklearn/tree/_tree.pyx | 303 ++++- 3 files changed, 2790 insertions(+), 174 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 26267a1355f6f..6234cb976aa08 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -11,12 +11,12 @@ # Joly Arnaud # Fares Hedayati # Nelson Liu +# Haoyin Xu # # License: BSD 3 clause import copy import numbers -import warnings from abc import ABCMeta, abstractmethod from math import ceil from numbers import Integral, Real @@ -35,7 +35,10 @@ ) from sklearn.utils import Bunch, check_random_state, compute_sample_weight from sklearn.utils._param_validation import Hidden, Interval, RealNotInt, StrOptions -from sklearn.utils.multiclass import check_classification_targets +from sklearn.utils.multiclass import ( + _check_partial_fit_first_call, + check_classification_targets, +) from sklearn.utils.validation import ( _assert_all_finite_element_wise, _check_sample_weight, @@ -237,6 +240,7 @@ def _fit( self, X, y, + classes=None, sample_weight=None, check_input=True, missing_values_in_feature_mask=None, @@ -291,7 +295,6 @@ def _fit( is_classification = False if y is not None: is_classification = is_classifier(self) - y = np.atleast_1d(y) expanded_class_weight = None @@ -313,10 +316,28 @@ def _fit( y_original = np.copy(y) y_encoded = np.zeros(y.shape, dtype=int) - for k in range(self.n_outputs_): - classes_k, y_encoded[:, k] = np.unique(y[:, k], return_inverse=True) - self.classes_.append(classes_k) - self.n_classes_.append(classes_k.shape[0]) + if classes is not None: + classes = np.atleast_1d(classes) + if classes.ndim == 1: + classes = np.array([classes]) + + for k in classes: + self.classes_.append(np.array(k)) + self.n_classes_.append(np.array(k).shape[0]) + + for i in range(n_samples): + for j in range(self.n_outputs_): + y_encoded[i, j] = np.where(self.classes_[j] == y[i, j])[0][ + 0 + ] + else: + for k in range(self.n_outputs_): + classes_k, y_encoded[:, k] = np.unique( + y[:, k], return_inverse=True + ) + self.classes_.append(classes_k) + self.n_classes_.append(classes_k.shape[0]) + y = y_encoded if self.class_weight is not None: @@ -355,24 +376,8 @@ def _fit( if self.max_features == "auto": if is_classification: max_features = max(1, int(np.sqrt(self.n_features_in_))) - warnings.warn( - ( - "`max_features='auto'` has been deprecated in 1.1 " - "and will be removed in 1.3. To keep the past behaviour, " - "explicitly set `max_features='sqrt'`." - ), - FutureWarning, - ) else: max_features = self.n_features_in_ - warnings.warn( - ( - "`max_features='auto'` has been deprecated in 1.1 " - "and will be removed in 1.3. To keep the past behaviour, " - "explicitly set `max_features=1.0'`." 
- ), - FutureWarning, - ) elif self.max_features == "sqrt": max_features = max(1, int(np.sqrt(self.n_features_in_))) elif self.max_features == "log2": @@ -538,7 +543,7 @@ def _build_tree( # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise if max_leaf_nodes < 0: - builder = DepthFirstTreeBuilder( + self.builder_ = DepthFirstTreeBuilder( splitter, min_samples_split, min_samples_leaf, @@ -548,7 +553,7 @@ def _build_tree( self.store_leaf_values, ) else: - builder = BestFirstTreeBuilder( + self.builder_ = BestFirstTreeBuilder( splitter, min_samples_split, min_samples_leaf, @@ -558,7 +563,7 @@ def _build_tree( self.min_impurity_decrease, self.store_leaf_values, ) - builder.build(self.tree_, X, y, sample_weight, missing_values_in_feature_mask) + self.builder_.build(self.tree_, X, y, sample_weight, missing_values_in_feature_mask) if self.n_outputs_ == 1 and is_classifier(self): self.n_classes_ = self.n_classes_[0] @@ -1119,6 +1124,9 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py` for basic usage of these attributes. + builder_ : TreeBuilder instance + The underlying TreeBuilder object. + See Also -------- DecisionTreeRegressor : A decision tree regressor. @@ -1209,7 +1217,14 @@ def __init__( ) @_fit_context(prefer_skip_nested_validation=True) - def fit(self, X, y, sample_weight=None, check_input=True): + def fit( + self, + X, + y, + sample_weight=None, + check_input=True, + classes=None, + ): """Build a decision tree classifier from the training set (X, y). Parameters @@ -1233,6 +1248,11 @@ def fit(self, X, y, sample_weight=None, check_input=True): Allow to bypass several input checking. Don't use this parameter unless you know what you're doing. + classes : array-like of shape (n_classes,), default=None + List of all the classes that can possibly appear in the y vector. + Must be provided at the first call to partial_fit, can be omitted + in subsequent calls. + Returns ------- self : DecisionTreeClassifier @@ -1243,9 +1263,110 @@ def fit(self, X, y, sample_weight=None, check_input=True): y, sample_weight=sample_weight, check_input=check_input, + classes=classes, ) return self + def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): + """Update a decision tree classifier from the training set (X, y). + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + The training input samples. Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csc_matrix``. + + y : array-like of shape (n_samples,) or (n_samples, n_outputs) + The target values (class labels) as integers or strings. + + classes : array-like of shape (n_classes,), default=None + List of all the classes that can possibly appear in the y vector. + Must be provided at the first call to partial_fit, can be omitted + in subsequent calls. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. If None, then samples are equally weighted. Splits + that would create child nodes with net zero or negative weight are + ignored while searching for a split in each node. Splits are also + ignored if they would result in any single class carrying a + negative weight in either child node. + + check_input : bool, default=True + Allow to bypass several input checking. + Don't use this parameter unless you know what you do. + + Returns + ------- + self : DecisionTreeClassifier + Fitted estimator. 
+ """ + + first_call = _check_partial_fit_first_call(self, classes=classes) + + # Fit if no tree exists yet + if first_call: + self.fit( + X, + y, + sample_weight=sample_weight, + check_input=check_input, + classes=classes, + ) + return self + + if check_input: + # Need to validate separately here. + # We can't pass multi_ouput=True because that would allow y to be + # csr. + check_X_params = dict(dtype=DTYPE, accept_sparse="csc") + check_y_params = dict(ensure_2d=False, dtype=None) + X, y = self._validate_data( + X, y, reset=False, validate_separately=(check_X_params, check_y_params) + ) + if issparse(X): + X.sort_indices() + + if X.indices.dtype != np.intc or X.indptr.dtype != np.intc: + raise ValueError( + "No support for np.int64 index based sparse matrices" + ) + + if X.shape[1] != self.n_features_in_: + msg = "Number of features %d does not match previous data %d." + raise ValueError(msg % (X.shape[1], self.n_features_in_)) + + y = np.atleast_1d(y) + + if y.ndim == 1: + # reshape is necessary to preserve the data contiguity against vs + # [:, np.newaxis] that does not. + y = np.reshape(y, (-1, 1)) + + check_classification_targets(y) + y = np.copy(y) + + classes = self.classes_ + if self.n_outputs_ == 1: + classes = [classes] + + y_encoded = np.zeros(y.shape, dtype=int) + for i in range(X.shape[0]): + for j in range(self.n_outputs_): + y_encoded[i, j] = np.where(classes[j] == y[i, j])[0][0] + y = y_encoded + + if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous: + y = np.ascontiguousarray(y, dtype=DOUBLE) + + # Update tree + self.builder_.initialize_node_queue(self.tree_, X, y, sample_weight) + self.builder_.build(self.tree_, X, y, sample_weight) + + self._prune_tree() + + return self + def predict_proba(self, X, check_input=True): """Predict class probabilities of the input samples X. @@ -1518,6 +1639,9 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py` for basic usage of these attributes. + builder_ : TreeBuilder instance + The underlying TreeBuilder object. + See Also -------- DecisionTreeClassifier : A decision tree classifier. @@ -1600,7 +1724,14 @@ def __init__( ) @_fit_context(prefer_skip_nested_validation=True) - def fit(self, X, y, sample_weight=None, check_input=True): + def fit( + self, + X, + y, + sample_weight=None, + check_input=True, + classes=None, + ): """Build a decision tree regressor from the training set (X, y). Parameters @@ -1623,6 +1754,9 @@ def fit(self, X, y, sample_weight=None, check_input=True): Allow to bypass several input checking. Don't use this parameter unless you know what you're doing. + classes : array-like of shape (n_classes,), default=None + List of all the classes that can possibly appear in the y vector. + Returns ------- self : DecisionTreeRegressor @@ -1634,6 +1768,7 @@ def fit(self, X, y, sample_weight=None, check_input=True): y, sample_weight=sample_weight, check_input=check_input, + classes=classes, ) return self @@ -1885,6 +2020,9 @@ class ExtraTreeClassifier(DecisionTreeClassifier): :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py` for basic usage of these attributes. + builder_ : TreeBuilder instance + The underlying TreeBuilder object. + See Also -------- ExtraTreeRegressor : An extremely randomized tree regressor. @@ -2147,6 +2285,9 @@ class ExtraTreeRegressor(DecisionTreeRegressor): :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py` for basic usage of these attributes. 
+ builder_ : TreeBuilder instance + The underlying TreeBuilder object. + See Also -------- ExtraTreeClassifier : An extremely randomized tree classifier. @@ -2214,4 +2355,4 @@ def __init__( ccp_alpha=ccp_alpha, store_leaf_values=store_leaf_values, monotonic_cst=monotonic_cst, - ) + ) \ No newline at end of file diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index dedd820c41e0f..9d9763810edb7 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -1,168 +1,2358 @@ +""" +This module gathers tree-based methods, including decision, regression and +randomized trees. Single and multi-output problems are both handled. +""" + # Authors: Gilles Louppe # Peter Prettenhofer # Brian Holt -# Joel Nothman -# Arnaud Joly -# Jacob Schreiber +# Noel Dawe +# Satrajit Gosh +# Joly Arnaud +# Fares Hedayati # Nelson Liu +# Haoyin Xu # # License: BSD 3 clause -# See _tree.pyx for details. +import copy +import numbers +from abc import ABCMeta, abstractmethod +from math import ceil +from numbers import Integral, Real import numpy as np +from scipy.sparse import issparse -cimport numpy as cnp -from libcpp.unordered_map cimport unordered_map -from libcpp.vector cimport vector +from sklearn.base import ( + BaseEstimator, + ClassifierMixin, + MultiOutputMixin, + RegressorMixin, + _fit_context, + clone, + is_classifier, +) +from sklearn.utils import Bunch, check_random_state, compute_sample_weight +from sklearn.utils._param_validation import Hidden, Interval, RealNotInt, StrOptions +from sklearn.utils.multiclass import ( + _check_partial_fit_first_call, + check_classification_targets, +) +from sklearn.utils.validation import ( + _assert_all_finite_element_wise, + _check_sample_weight, + assert_all_finite, + check_is_fitted, +) -ctypedef cnp.npy_float32 DTYPE_t # Type of X -ctypedef cnp.npy_float64 DOUBLE_t # Type of y, sample_weight -ctypedef cnp.npy_intp SIZE_t # Type for indices and counters -ctypedef cnp.npy_int32 INT32_t # Signed 32 bit integer -ctypedef cnp.npy_uint32 UINT32_t # Unsigned 32 bit integer +from . 
import _criterion, _splitter, _tree +from ._criterion import BaseCriterion +from ._splitter import BaseSplitter +from ._tree import ( + BestFirstTreeBuilder, + DepthFirstTreeBuilder, + Tree, + _build_pruned_tree_ccp, + ccp_pruning_path, +) +from ._utils import _any_isnan_axis0 -from ._splitter cimport SplitRecord, Splitter +__all__ = [ + "DecisionTreeClassifier", + "DecisionTreeRegressor", + "ExtraTreeClassifier", + "ExtraTreeRegressor", +] -cdef struct Node: - # Base storage structure for the nodes in a Tree object +# ============================================================================= +# Types and constants +# ============================================================================= - SIZE_t left_child # id of the left child of the node - SIZE_t right_child # id of the right child of the node - SIZE_t feature # Feature used for splitting the node - DOUBLE_t threshold # Threshold value at the node - DOUBLE_t impurity # Impurity of the node (i.e., the value of the criterion) - SIZE_t n_node_samples # Number of samples at the node - DOUBLE_t weighted_n_node_samples # Weighted number of samples at the node - unsigned char missing_go_to_left # Whether features have missing values +DTYPE = _tree.DTYPE +DOUBLE = _tree.DOUBLE +CRITERIA_CLF = { + "gini": _criterion.Gini, + "log_loss": _criterion.Entropy, + "entropy": _criterion.Entropy, +} +CRITERIA_REG = { + "squared_error": _criterion.MSE, + "friedman_mse": _criterion.FriedmanMSE, + "absolute_error": _criterion.MAE, + "poisson": _criterion.Poisson, +} -cdef class BaseTree: - # Inner structures: values are stored separately from node structure, - # since size is determined at runtime. - cdef public SIZE_t max_depth # Max depth of the tree - cdef public SIZE_t node_count # Counter for node IDs - cdef public SIZE_t capacity # Capacity of tree, in terms of nodes - cdef Node* nodes # Array of nodes +DENSE_SPLITTERS = {"best": _splitter.BestSplitter, "random": _splitter.RandomSplitter} - cdef SIZE_t value_stride # The dimensionality of a vectorized output per sample - cdef double* value # Array of values prediction values for each node +SPARSE_SPLITTERS = { + "best": _splitter.BestSparseSplitter, + "random": _splitter.RandomSparseSplitter, +} - # Generic Methods: These are generic methods used by any tree. 
- cdef int _resize(self, SIZE_t capacity) except -1 nogil - cdef int _resize_c(self, SIZE_t capacity=*) except -1 nogil - cdef SIZE_t _add_node( - self, - SIZE_t parent, - bint is_left, - bint is_leaf, - SplitRecord* split_node, - double impurity, - SIZE_t n_node_samples, - double weighted_n_node_samples, - unsigned char missing_go_to_left - ) except -1 nogil - - # Python API methods: These are methods exposed to Python - cpdef cnp.ndarray apply(self, object X) - cdef cnp.ndarray _apply_dense(self, object X) - cdef cnp.ndarray _apply_sparse_csr(self, object X) - - cpdef object decision_path(self, object X) - cdef object _decision_path_dense(self, object X) - cdef object _decision_path_sparse_csr(self, object X) - - cpdef compute_node_depths(self) - cpdef compute_feature_importances(self, normalize=*) - - # Abstract methods: these functions must be implemented by any decision tree - cdef int _set_split_node( - self, - SplitRecord* split_node, - Node* node - ) except -1 nogil - cdef int _set_leaf_node( +# ============================================================================= +# Base decision tree +# ============================================================================= + + +class BaseDecisionTree(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta): + """Base class for decision trees. + + Warning: This class should not be used directly. + Use derived classes instead. + """ + + _parameter_constraints: dict = { + "splitter": [StrOptions({"best", "random"})], + "max_depth": [Interval(Integral, 1, None, closed="left"), None], + "min_samples_split": [ + Interval(Integral, 2, None, closed="left"), + Interval(RealNotInt, 0.0, 1.0, closed="right"), + ], + "min_samples_leaf": [ + Interval(Integral, 1, None, closed="left"), + Interval(RealNotInt, 0.0, 1.0, closed="neither"), + ], + "min_weight_fraction_leaf": [Interval(Real, 0.0, 0.5, closed="both")], + "max_features": [ + Interval(Integral, 1, None, closed="left"), + Interval(RealNotInt, 0.0, 1.0, closed="right"), + StrOptions({"sqrt", "log2"}), + None, + ], + "random_state": ["random_state"], + "max_leaf_nodes": [Interval(Integral, 2, None, closed="left"), None], + "min_impurity_decrease": [Interval(Real, 0.0, None, closed="left")], + "ccp_alpha": [Interval(Real, 0.0, None, closed="left")], + "store_leaf_values": ["boolean"], + "monotonic_cst": ["array-like", None], + } + + @abstractmethod + def __init__( self, - SplitRecord* split_node, - Node* node - ) except -1 nogil - cdef DTYPE_t _compute_feature( + *, + criterion, + splitter, + max_depth, + min_samples_split, + min_samples_leaf, + min_weight_fraction_leaf, + max_features, + max_leaf_nodes, + random_state, + min_impurity_decrease, + class_weight=None, + ccp_alpha=0.0, + store_leaf_values=False, + monotonic_cst=None, + ): + self.criterion = criterion + self.splitter = splitter + self.max_depth = max_depth + self.min_samples_split = min_samples_split + self.min_samples_leaf = min_samples_leaf + self.min_weight_fraction_leaf = min_weight_fraction_leaf + self.max_features = max_features + self.max_leaf_nodes = max_leaf_nodes + self.random_state = random_state + self.min_impurity_decrease = min_impurity_decrease + self.class_weight = class_weight + self.ccp_alpha = ccp_alpha + self.store_leaf_values = store_leaf_values + self.monotonic_cst = monotonic_cst + + def get_depth(self): + """Return the depth of the decision tree. + + The depth of a tree is the maximum distance between the root + and any leaf. + + Returns + ------- + self.tree_.max_depth : int + The maximum depth of the tree. 
+ """ + check_is_fitted(self) + return self.tree_.max_depth + + def get_n_leaves(self): + """Return the number of leaves of the decision tree. + + Returns + ------- + self.tree_.n_leaves : int + Number of leaves. + """ + check_is_fitted(self) + return self.tree_.n_leaves + + def _support_missing_values(self, X): + return ( + not issparse(X) + and self._get_tags()["allow_nan"] + and self.monotonic_cst is None + ) + + def _compute_missing_values_in_feature_mask(self, X, estimator_name=None): + """Return boolean mask denoting if there are missing values for each feature. + + This method also ensures that X is finite. + + Parameter + --------- + X : array-like of shape (n_samples, n_features), dtype=DOUBLE + Input data. + + estimator_name : str or None, default=None + Name to use when raising an error. Defaults to the class name. + + Returns + ------- + missing_values_in_feature_mask : ndarray of shape (n_features,), or None + Missing value mask. If missing values are not supported or there + are no missing values, return None. + """ + estimator_name = estimator_name or self.__class__.__name__ + common_kwargs = dict(estimator_name=estimator_name, input_name="X") + + if not self._support_missing_values(X): + assert_all_finite(X, **common_kwargs) + return None + + with np.errstate(over="ignore"): + overall_sum = np.sum(X) + + if not np.isfinite(overall_sum): + # Raise a ValueError in case of the presence of an infinite element. + _assert_all_finite_element_wise(X, xp=np, allow_nan=True, **common_kwargs) + + # If the sum is not nan, then there are no missing values + if not np.isnan(overall_sum): + return None + + missing_values_in_feature_mask = _any_isnan_axis0(X) + return missing_values_in_feature_mask + + def _fit( self, - const DTYPE_t[:, :] X_ndarray, - SIZE_t sample_index, - Node *node - ) noexcept nogil - cdef void _compute_feature_importances( + X, + y, + classes=None, + sample_weight=None, + check_input=True, + missing_values_in_feature_mask=None, + ): + random_state = check_random_state(self.random_state) + + if check_input: + # Need to validate separately here. + # We can't pass multi_output=True because that would allow y to be + # csr. + + # _compute_missing_values_in_feature_mask will check for finite values and + # compute the missing mask if the tree supports missing values + check_X_params = dict( + dtype=DTYPE, accept_sparse="csc", force_all_finite=False + ) + check_y_params = dict(ensure_2d=False, dtype=None) + if y is not None or self._get_tags()["requires_y"]: + X, y = self._validate_data( + X, y, validate_separately=(check_X_params, check_y_params) + ) + else: + X = self._validate_data(X, **check_X_params) + + missing_values_in_feature_mask = ( + self._compute_missing_values_in_feature_mask(X) + ) + if issparse(X): + X.sort_indices() + + if X.indices.dtype != np.intc or X.indptr.dtype != np.intc: + raise ValueError( + "No support for np.int64 index based sparse matrices" + ) + + if y is not None and self.criterion == "poisson": + if np.any(y < 0): + raise ValueError( + "Some value(s) of y are negative which is" + " not allowed for Poisson regression." + ) + if np.sum(y) <= 0: + raise ValueError( + "Sum of y is not positive which is " + "necessary for Poisson regression." 
+ ) + + # Determine output settings + n_samples, self.n_features_in_ = X.shape + + # Do preprocessing if 'y' is passed + is_classification = False + if y is not None: + is_classification = is_classifier(self) + y = np.atleast_1d(y) + expanded_class_weight = None + + if y.ndim == 1: + # reshape is necessary to preserve the data contiguity against vs + # [:, np.newaxis] that does not. + y = np.reshape(y, (-1, 1)) + + self.n_outputs_ = y.shape[1] + + if is_classification: + check_classification_targets(y) + y = np.copy(y) + + self.classes_ = [] + self.n_classes_ = [] + + if self.class_weight is not None: + y_original = np.copy(y) + + y_encoded = np.zeros(y.shape, dtype=int) + if classes is not None: + classes = np.atleast_1d(classes) + if classes.ndim == 1: + classes = np.array([classes]) + + for k in classes: + self.classes_.append(np.array(k)) + self.n_classes_.append(np.array(k).shape[0]) + + for i in range(n_samples): + for j in range(self.n_outputs_): + y_encoded[i, j] = np.where(self.classes_[j] == y[i, j])[0][ + 0 + ] + else: + for k in range(self.n_outputs_): + classes_k, y_encoded[:, k] = np.unique( + y[:, k], return_inverse=True + ) + self.classes_.append(classes_k) + self.n_classes_.append(classes_k.shape[0]) + + y = y_encoded + + if self.class_weight is not None: + expanded_class_weight = compute_sample_weight( + self.class_weight, y_original + ) + + self.n_classes_ = np.array(self.n_classes_, dtype=np.intp) + + if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous: + y = np.ascontiguousarray(y, dtype=DOUBLE) + + if len(y) != n_samples: + raise ValueError( + "Number of labels=%d does not match number of samples=%d" + % (len(y), n_samples) + ) + + # set decision-tree model parameters + max_depth = np.iinfo(np.int32).max if self.max_depth is None else self.max_depth + + if isinstance(self.min_samples_leaf, numbers.Integral): + min_samples_leaf = self.min_samples_leaf + else: # float + min_samples_leaf = int(ceil(self.min_samples_leaf * n_samples)) + + if isinstance(self.min_samples_split, numbers.Integral): + min_samples_split = self.min_samples_split + else: # float + min_samples_split = int(ceil(self.min_samples_split * n_samples)) + min_samples_split = max(2, min_samples_split) + + min_samples_split = max(min_samples_split, 2 * min_samples_leaf) + + if isinstance(self.max_features, str): + if self.max_features == "auto": + if is_classification: + max_features = max(1, int(np.sqrt(self.n_features_in_))) + else: + max_features = self.n_features_in_ + elif self.max_features == "sqrt": + max_features = max(1, int(np.sqrt(self.n_features_in_))) + elif self.max_features == "log2": + max_features = max(1, int(np.log2(self.n_features_in_))) + elif self.max_features is None: + max_features = self.n_features_in_ + elif isinstance(self.max_features, numbers.Integral): + max_features = self.max_features + else: # float + if self.max_features > 0.0: + max_features = max(1, int(self.max_features * self.n_features_in_)) + else: + max_features = 0 + + self.max_features_ = max_features + + max_leaf_nodes = -1 if self.max_leaf_nodes is None else self.max_leaf_nodes + + if sample_weight is not None: + sample_weight = _check_sample_weight(sample_weight, X, DOUBLE) + + if y is not None and expanded_class_weight is not None: + if sample_weight is not None: + sample_weight = sample_weight * expanded_class_weight + else: + sample_weight = expanded_class_weight + + # Set min_weight_leaf from min_weight_fraction_leaf + if sample_weight is None: + min_weight_leaf = self.min_weight_fraction_leaf 
* n_samples + else: + min_weight_leaf = self.min_weight_fraction_leaf * np.sum(sample_weight) + + # build the actual tree now with the parameters + self._build_tree( + X, + y, + sample_weight, + missing_values_in_feature_mask, + min_samples_leaf, + min_weight_leaf, + max_leaf_nodes, + min_samples_split, + max_depth, + random_state, + ) + + return self + + def _build_tree( self, - cnp.float64_t[:] importances, - Node* node, - ) noexcept nogil - -cdef class Tree(BaseTree): - # The Supervised Tree object is a binary tree structure constructed by the - # TreeBuilder. The tree structure is used for predictions and - # feature importances. - # - # Value of upstream properties: - # - value_stride = n_outputs * max_n_classes - # - value = (capacity, n_outputs, max_n_classes) array of values - - # Input/Output layout for supervised tree - cdef public SIZE_t n_features # Number of features in X - cdef SIZE_t* n_classes # Number of classes in y[:, k] - cdef public SIZE_t n_outputs # Number of outputs in y - cdef public SIZE_t max_n_classes # max(n_classes) - - # Enables the use of tree to store distributions of the output to allow - # arbitrary usage of the the leaves. This is used in the quantile - # estimators for example. - # for storing samples at each leaf node with leaf's node ID as the key and - # the sample values as the value - cdef unordered_map[SIZE_t, vector[vector[DOUBLE_t]]] value_samples - - # Methods - cdef cnp.ndarray _get_value_ndarray(self) - cdef cnp.ndarray _get_node_ndarray(self) - cdef cnp.ndarray _get_value_samples_ndarray(self, SIZE_t node_id) - cdef cnp.ndarray _get_value_samples_keys(self) - - cpdef cnp.ndarray predict(self, object X) + X, + y, + sample_weight, + missing_values_in_feature_mask, + min_samples_leaf, + min_weight_leaf, + max_leaf_nodes, + min_samples_split, + max_depth, + random_state, + ): + """Build the actual tree. + + Parameters + ---------- + X : Array-like + X dataset. + y : Array-like + Y targets. + sample_weight : Array-like + Sample weights + min_samples_leaf : float + Number of samples required to be a leaf. + min_weight_leaf : float + Weight of samples required to be a leaf. + max_leaf_nodes : float + Maximum number of leaf nodes allowed in tree. + min_samples_split : float + Minimum number of samples to split on. + max_depth : int + The maximum depth of any tree. + random_state : int + Random seed. + """ + + n_samples = X.shape[0] + + # Build tree + criterion = self.criterion + if not isinstance(criterion, BaseCriterion): + if is_classifier(self): + criterion = CRITERIA_CLF[self.criterion]( + self.n_outputs_, self.n_classes_ + ) + else: + criterion = CRITERIA_REG[self.criterion](self.n_outputs_, n_samples) + else: + # Make a deepcopy in case the criterion has mutable attributes that + # might be shared and modified concurrently during parallel fitting + criterion = copy.deepcopy(criterion) + + SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS + + if self.monotonic_cst is None: + monotonic_cst = None + else: + if self.n_outputs_ > 1: + raise ValueError( + "Monotonicity constraints are not supported with multiple outputs." + ) + # Check to correct monotonicity constraint' specification, + # by applying element-wise logical conjunction + # Note: we do not cast `np.asarray(self.monotonic_cst, dtype=np.int8)` + # straight away here so as to generate error messages for invalid + # values using the original values prior to any dtype related conversion. 
+ monotonic_cst = np.asarray(self.monotonic_cst) + if monotonic_cst.shape[0] != X.shape[1]: + raise ValueError( + "monotonic_cst has shape {} but the input data " + "X has {} features.".format(monotonic_cst.shape[0], X.shape[1]) + ) + valid_constraints = np.isin(monotonic_cst, (-1, 0, 1)) + if not np.all(valid_constraints): + unique_constaints_value = np.unique(monotonic_cst) + raise ValueError( + "monotonic_cst must be None or an array-like of -1, 0 or 1, but" + f" got {unique_constaints_value}" + ) + monotonic_cst = np.asarray(monotonic_cst, dtype=np.int8) + if is_classifier(self): + if self.n_classes_[0] > 2: + raise ValueError( + "Monotonicity constraints are not supported with multiclass " + "classification" + ) + # Binary classification trees are built by constraining probabilities + # of the *negative class* in order to make the implementation similar + # to regression trees. + # Since self.monotonic_cst encodes constraints on probabilities of the + # *positive class*, all signs must be flipped. + monotonic_cst *= -1 + + if not isinstance(self.splitter, BaseSplitter): + splitter = SPLITTERS[self.splitter]( + criterion, + self.max_features_, + min_samples_leaf, + min_weight_leaf, + random_state, + monotonic_cst, + ) + + if is_classifier(self): + self.tree_ = Tree(self.n_features_in_, self.n_classes_, self.n_outputs_) + else: + self.tree_ = Tree( + self.n_features_in_, + # TODO: tree shouldn't need this in this case + np.array([1] * self.n_outputs_, dtype=np.intp), + self.n_outputs_, + ) + + # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise + if max_leaf_nodes < 0: + self.builder_ = DepthFirstTreeBuilder( + splitter, + min_samples_split, + min_samples_leaf, + min_weight_leaf, + max_depth, + self.min_impurity_decrease, + self.store_leaf_values, + ) + else: + self.builder_ = BestFirstTreeBuilder( + splitter, + min_samples_split, + min_samples_leaf, + min_weight_leaf, + max_depth, + max_leaf_nodes, + self.min_impurity_decrease, + self.store_leaf_values, + ) + self.builder_.build(self.tree_, X, y, sample_weight, missing_values_in_feature_mask) + + if self.n_outputs_ == 1 and is_classifier(self): + self.n_classes_ = self.n_classes_[0] + self.classes_ = self.classes_[0] + + self._prune_tree() + + def _validate_X_predict(self, X, check_input): + """Validate the training data on predict (probabilities).""" + if check_input: + if self._support_missing_values(X): + force_all_finite = "allow-nan" + else: + force_all_finite = True + X = self._validate_data( + X, + dtype=DTYPE, + accept_sparse="csr", + reset=False, + force_all_finite=force_all_finite, + ) + if issparse(X) and ( + X.indices.dtype != np.intc or X.indptr.dtype != np.intc + ): + raise ValueError("No support for np.int64 index based sparse matrices") + else: + # The number of features is checked regardless of `check_input` + self._check_n_features(X, reset=False) + return X + + def predict(self, X, check_input=True): + """Predict class or regression value for X. + + For a classification model, the predicted class for each sample in X is + returned. For a regression model, the predicted value based on X is + returned. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + The input samples. Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csr_matrix``. + + check_input : bool, default=True + Allow to bypass several input checking. + Don't use this parameter unless you know what you're doing. 
+ + Returns + ------- + y : array-like of shape (n_samples,) or (n_samples, n_outputs) + The predicted classes, or the predict values. + """ + check_is_fitted(self) + X = self._validate_X_predict(X, check_input) + + # proba is a count matrix of leaves that fall into + # (n_samples, n_outputs, max_n_classes) array + proba = self.tree_.predict(X) + n_samples = X.shape[0] + + # Classification + if is_classifier(self): + if self.n_outputs_ == 1: + return self.classes_.take(np.argmax(proba, axis=1), axis=0) + + else: + class_type = self.classes_[0].dtype + predictions = np.zeros((n_samples, self.n_outputs_), dtype=class_type) + for k in range(self.n_outputs_): + predictions[:, k] = self.classes_[k].take( + np.argmax(proba[:, k], axis=1), axis=0 + ) + + return predictions + + # Regression + else: + if self.n_outputs_ == 1: + return proba[:, 0] + + else: + return proba[:, :, 0] + + def get_leaf_node_samples(self, X, check_input=True): + """For each datapoint x in X, get the training samples in the leaf node. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + Dataset to apply the forest to. + check_input : bool, default=True + Allow to bypass several input checking. + + Returns + ------- + leaf_nodes_samples : a list of array-like of length (n_samples,) + Each sample is represented by the indices of the training samples that + reached the leaf node. The ``n_leaf_node_samples`` may vary between + samples, since the number of samples that fall in a leaf node is + variable. Each array has shape (n_leaf_node_samples, n_outputs). + """ + if not self.store_leaf_values: + raise RuntimeError( + "leaf node samples are not stored when store_leaf_values=False" + ) + + # get indices of leaves per sample (n_samples,) + X_leaves = self.apply(X, check_input=check_input) + n_samples = X_leaves.shape[0] + + # get array of samples per leaf (n_node_samples, n_outputs) + leaf_samples = self.tree_.leaf_nodes_samples + + leaf_nodes_samples = [] + for idx in range(n_samples): + leaf_id = X_leaves[idx] + leaf_nodes_samples.append(leaf_samples[leaf_id]) + return leaf_nodes_samples + + def predict_quantiles(self, X, quantiles=0.5, method="nearest", check_input=True): + """Predict class or regression value for X at given quantiles. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Input data. + quantiles : float, optional + The quantiles at which to evaluate, by default 0.5 (median). + method : str, optional + The method to interpolate, by default 'linear'. Can be any keyword + argument accepted by :func:`~np.quantile`. + check_input : bool, optional + Whether or not to check input, by default True. + + Returns + ------- + predictions : array-like of shape (n_samples, n_outputs, len(quantiles)) + The predicted quantiles. + """ + if not self.store_leaf_values: + raise RuntimeError( + "Predicting quantiles requires that the tree stores leaf node samples." 
+ ) + + check_is_fitted(self) + + # Check data + X = self._validate_X_predict(X, check_input) + + if not isinstance(quantiles, (np.ndarray, list)): + quantiles = np.array([quantiles]) + + # get indices of leaves per sample + X_leaves = self.apply(X) + + # get array of samples per leaf (n_node_samples, n_outputs) + leaf_samples = self.tree_.leaf_nodes_samples + + # compute quantiles (n_samples, n_quantiles, n_outputs) + n_samples = X.shape[0] + n_quantiles = len(quantiles) + proba = np.zeros((n_samples, n_quantiles, self.n_outputs_)) + for idx, leaf_id in enumerate(X_leaves): + # predict by taking the quantile across the samples in the leaf for + # each output + try: + proba[idx, ...] = np.quantile( + leaf_samples[leaf_id], quantiles, axis=0, method=method + ) + except TypeError: + proba[idx, ...] = np.quantile( + leaf_samples[leaf_id], quantiles, axis=0, interpolation=method + ) + + # Classification + if is_classifier(self): + if self.n_outputs_ == 1: + # return the class with the highest probability for each quantile + # (n_samples, n_quantiles) + class_preds = np.zeros( + (n_samples, n_quantiles), dtype=self.classes_.dtype + ) + for i in range(n_quantiles): + class_pred_per_sample = ( + proba[:, i, :].squeeze().astype(self.classes_.dtype) + ) + class_preds[:, i] = self.classes_.take( + class_pred_per_sample, axis=0 + ) + return class_preds + else: + class_type = self.classes_[0].dtype + predictions = np.zeros( + (n_samples, n_quantiles, self.n_outputs_), dtype=class_type + ) + for k in range(self.n_outputs_): + for i in range(n_quantiles): + class_pred_per_sample = proba[:, i, k].squeeze().astype(int) + predictions[:, i, k] = self.classes_[k].take( + class_pred_per_sample, axis=0 + ) + + return predictions + # Regression + else: + if self.n_outputs_ == 1: + return proba[:, :, 0] + + else: + return proba + + def apply(self, X, check_input=True): + """Return the index of the leaf that each sample is predicted as. + + .. versionadded:: 0.17 + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + The input samples. Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csr_matrix``. + + check_input : bool, default=True + Allow to bypass several input checking. + Don't use this parameter unless you know what you're doing. + + Returns + ------- + X_leaves : array-like of shape (n_samples,) + For each datapoint x in X, return the index of the leaf x + ends up in. Leaves are numbered within + ``[0; self.tree_.node_count)``, possibly with gaps in the + numbering. + """ + check_is_fitted(self) + X = self._validate_X_predict(X, check_input) + return self.tree_.apply(X) + + def decision_path(self, X, check_input=True): + """Return the decision path in the tree. + + .. versionadded:: 0.18 + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + The input samples. Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csr_matrix``. + + check_input : bool, default=True + Allow to bypass several input checking. + Don't use this parameter unless you know what you're doing. + + Returns + ------- + indicator : sparse matrix of shape (n_samples, n_nodes) + Return a node indicator CSR matrix where non zero elements + indicates that the samples goes through the nodes. 
+ """ + X = self._validate_X_predict(X, check_input) + return self.tree_.decision_path(X) + + def _prune_tree(self): + """Prune tree using Minimal Cost-Complexity Pruning.""" + check_is_fitted(self) + + if self.ccp_alpha == 0.0: + return + + # build pruned tree + if is_classifier(self): + n_classes = np.atleast_1d(self.n_classes_) + pruned_tree = Tree(self.n_features_in_, n_classes, self.n_outputs_) + else: + pruned_tree = Tree( + self.n_features_in_, + # TODO: the tree shouldn't need this param + np.array([1] * self.n_outputs_, dtype=np.intp), + self.n_outputs_, + ) + _build_pruned_tree_ccp(pruned_tree, self.tree_, self.ccp_alpha) + + self.tree_ = pruned_tree + + def cost_complexity_pruning_path(self, X, y, sample_weight=None): + """Compute the pruning path during Minimal Cost-Complexity Pruning. + + See :ref:`minimal_cost_complexity_pruning` for details on the pruning + process. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + The training input samples. Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csc_matrix``. + + y : array-like of shape (n_samples,) or (n_samples, n_outputs) + The target values (class labels) as integers or strings. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. If None, then samples are equally weighted. Splits + that would create child nodes with net zero or negative weight are + ignored while searching for a split in each node. Splits are also + ignored if they would result in any single class carrying a + negative weight in either child node. + + Returns + ------- + ccp_path : :class:`~sklearn.utils.Bunch` + Dictionary-like object, with the following attributes. + + ccp_alphas : ndarray + Effective alphas of subtree during pruning. + + impurities : ndarray + Sum of the impurities of the subtree leaves for the + corresponding alpha value in ``ccp_alphas``. + """ + est = clone(self).set_params(ccp_alpha=0.0) + est.fit(X, y, sample_weight=sample_weight) + return Bunch(**ccp_pruning_path(est.tree_)) + + @property + def feature_importances_(self): + """Return the feature importances. + + The importance of a feature is computed as the (normalized) total + reduction of the criterion brought by that feature. + It is also known as the Gini importance. + + Warning: impurity-based feature importances can be misleading for + high cardinality features (many unique values). See + :func:`sklearn.inspection.permutation_importance` as an alternative. + + Returns + ------- + feature_importances_ : ndarray of shape (n_features,) + Normalized total reduction of criteria by feature + (Gini importance). + """ + check_is_fitted(self) + + return self.tree_.compute_feature_importances() + # ============================================================================= -# Tree builder +# Public estimators # ============================================================================= -cdef class TreeBuilder: - # The TreeBuilder recursively builds a Tree object from training samples, - # using a Splitter object for splitting internal nodes and assigning - # values to leaves. - # - # This class controls the various stopping criteria and the node splitting - # evaluation order, e.g. depth-first or best-first. - cdef Splitter splitter # Splitting algorithm +class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): + """A decision tree classifier. + + Read more in the :ref:`User Guide `. 
+ + Parameters + ---------- + criterion : {"gini", "entropy", "log_loss"}, default="gini" + The function to measure the quality of a split. Supported criteria are + "gini" for the Gini impurity and "log_loss" and "entropy" both for the + Shannon information gain, see :ref:`tree_mathematical_formulation`. + + splitter : {"best", "random"}, default="best" + The strategy used to choose the split at each node. Supported + strategies are "best" to choose the best split and "random" to choose + the best random split. + + max_depth : int, default=None + The maximum depth of the tree. If None, then nodes are expanded until + all leaves are pure or until all leaves contain less than + min_samples_split samples. + + min_samples_split : int or float, default=2 + The minimum number of samples required to split an internal node: + + - If int, then consider `min_samples_split` as the minimum number. + - If float, then `min_samples_split` is a fraction and + `ceil(min_samples_split * n_samples)` are the minimum + number of samples for each split. + + .. versionchanged:: 0.18 + Added float values for fractions. + + min_samples_leaf : int or float, default=1 + The minimum number of samples required to be at a leaf node. + A split point at any depth will only be considered if it leaves at + least ``min_samples_leaf`` training samples in each of the left and + right branches. This may have the effect of smoothing the model, + especially in regression. + + - If int, then consider `min_samples_leaf` as the minimum number. + - If float, then `min_samples_leaf` is a fraction and + `ceil(min_samples_leaf * n_samples)` are the minimum + number of samples for each node. + + .. versionchanged:: 0.18 + Added float values for fractions. + + min_weight_fraction_leaf : float, default=0.0 + The minimum weighted fraction of the sum total of weights (of all + the input samples) required to be at a leaf node. Samples have + equal weight when sample_weight is not provided. + + max_features : int, float or {"auto", "sqrt", "log2"}, default=None + The number of features to consider when looking for the best split: + + - If int, then consider `max_features` features at each split. + - If float, then `max_features` is a fraction and + `max(1, int(max_features * n_features_in_))` features are considered at + each split. + - If "sqrt", then `max_features=sqrt(n_features)`. + - If "log2", then `max_features=log2(n_features)`. + - If None, then `max_features=n_features`. + + Note: the search for a split does not stop until at least one + valid partition of the node samples is found, even if it requires to + effectively inspect more than ``max_features`` features. + + random_state : int, RandomState instance or None, default=None + Controls the randomness of the estimator. The features are always + randomly permuted at each split, even if ``splitter`` is set to + ``"best"``. When ``max_features < n_features``, the algorithm will + select ``max_features`` at random at each split before finding the best + split among them. But the best found split may vary across different + runs, even if ``max_features=n_features``. That is the case, if the + improvement of the criterion is identical for several splits and one + split has to be selected at random. To obtain a deterministic behaviour + during fitting, ``random_state`` has to be fixed to an integer. + See :term:`Glossary ` for details. + + max_leaf_nodes : int, default=None + Grow a tree with ``max_leaf_nodes`` in best-first fashion. 
+ Best nodes are defined as relative reduction in impurity. + If None then unlimited number of leaf nodes. + + min_impurity_decrease : float, default=0.0 + A node will be split if this split induces a decrease of the impurity + greater than or equal to this value. + + The weighted impurity decrease equation is the following:: + + N_t / N * (impurity - N_t_R / N_t * right_impurity + - N_t_L / N_t * left_impurity) + + where ``N`` is the total number of samples, ``N_t`` is the number of + samples at the current node, ``N_t_L`` is the number of samples in the + left child, and ``N_t_R`` is the number of samples in the right child. + + ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum, + if ``sample_weight`` is passed. + + .. versionadded:: 0.19 + + class_weight : dict, list of dict or "balanced", default=None + Weights associated with classes in the form ``{class_label: weight}``. + If None, all classes are supposed to have weight one. For + multi-output problems, a list of dicts can be provided in the same + order as the columns of y. + + Note that for multioutput (including multilabel) weights should be + defined for each class of every column in its own dict. For example, + for four-class multilabel classification weights should be + [{0: 1, 1: 1}, {0: 1, 1: 5}, {0: 1, 1: 1}, {0: 1, 1: 1}] instead of + [{1:1}, {2:5}, {3:1}, {4:1}]. + + The "balanced" mode uses the values of y to automatically adjust + weights inversely proportional to class frequencies in the input data + as ``n_samples / (n_classes * np.bincount(y))`` + + For multi-output, the weights of each column of y will be multiplied. + + Note that these weights will be multiplied with sample_weight (passed + through the fit method) if sample_weight is specified. + + ccp_alpha : non-negative float, default=0.0 + Complexity parameter used for Minimal Cost-Complexity Pruning. The + subtree with the largest cost complexity that is smaller than + ``ccp_alpha`` will be chosen. By default, no pruning is performed. See + :ref:`minimal_cost_complexity_pruning` for details. + + .. versionadded:: 0.22 + + store_leaf_values : bool, default=False + Whether to store the samples that fall into leaves in the ``tree_`` attribute. + Each leaf will store a 2D array corresponding to the samples that fall into it + keyed by node_id. + + XXX: This is currently experimental and may change without notice. + Moreover, it can be improved upon since storing the samples twice is not ideal. + One could instead store the indices in ``y_train`` that fall into each leaf, + which would lower RAM/diskspace usage. + + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonicity constraint to enforce on each feature. + - 1: monotonic increase + - 0: no constraint + - -1: monotonic decrease + + If monotonic_cst is None, no constraints are applied. + + Monotonicity constraints are not supported for: + - multiclass classifications (i.e. when `n_classes > 2`), + - multioutput classifications (i.e. when `n_outputs_ > 1`), + - classifications trained on data with missing values. + + The constraints hold over the probability of the positive class. + + Read more in the :ref:`User Guide `. + + .. versionadded:: 1.4 + + Attributes + ---------- + classes_ : ndarray of shape (n_classes,) or list of ndarray + The classes labels (single output problem), + or a list of arrays of class labels (multi-output problem). + + feature_importances_ : ndarray of shape (n_features,) + The impurity-based feature importances. 
+ The higher, the more important the feature. + The importance of a feature is computed as the (normalized) + total reduction of the criterion brought by that feature. It is also + known as the Gini importance [4]_. + + Warning: impurity-based feature importances can be misleading for + high cardinality features (many unique values). See + :func:`sklearn.inspection.permutation_importance` as an alternative. + + max_features_ : int + The inferred value of max_features. + + n_classes_ : int or list of int + The number of classes (for single output problems), + or a list containing the number of classes for each + output (for multi-output problems). + + n_features_in_ : int + Number of features seen during :term:`fit`. + + .. versionadded:: 0.24 + + feature_names_in_ : ndarray of shape (`n_features_in_`,) + Names of features seen during :term:`fit`. Defined only when `X` + has feature names that are all strings. + + .. versionadded:: 1.0 + + n_outputs_ : int + The number of outputs when ``fit`` is performed. + + tree_ : Tree instance + The underlying Tree object. Please refer to + ``help(sklearn.tree._tree.Tree)`` for attributes of Tree object and + :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py` + for basic usage of these attributes. + + builder_ : TreeBuilder instance + The underlying TreeBuilder object. + + See Also + -------- + DecisionTreeRegressor : A decision tree regressor. + + Notes + ----- + The default values for the parameters controlling the size of the trees + (e.g. ``max_depth``, ``min_samples_leaf``, etc.) lead to fully grown and + unpruned trees which can potentially be very large on some data sets. To + reduce memory consumption, the complexity and size of the trees should be + controlled by setting those parameter values. + + The :meth:`predict` method operates using the :func:`numpy.argmax` + function on the outputs of :meth:`predict_proba`. This means that in + case the highest predicted probabilities are tied, the classifier will + predict the tied class with the lowest index in :term:`classes_`. + + References + ---------- + + .. [1] https://en.wikipedia.org/wiki/Decision_tree_learning + + .. [2] L. Breiman, J. Friedman, R. Olshen, and C. Stone, "Classification + and Regression Trees", Wadsworth, Belmont, CA, 1984. + + .. [3] T. Hastie, R. Tibshirani and J. Friedman. "Elements of Statistical + Learning", Springer, 2009. + + .. [4] L. Breiman, and A. Cutler, "Random Forests", + https://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm + + Examples + -------- + >>> from sklearn.datasets import load_iris + >>> from sklearn.model_selection import cross_val_score + >>> from sklearn.tree import DecisionTreeClassifier + >>> clf = DecisionTreeClassifier(random_state=0) + >>> iris = load_iris() + >>> cross_val_score(clf, iris.data, iris.target, cv=10) + ... # doctest: +SKIP + ... + array([ 1. , 0.93..., 0.86..., 0.93..., 0.93..., + 0.93..., 0.93..., 1. , 0.93..., 1. 
]) + """ + + _parameter_constraints: dict = { + **BaseDecisionTree._parameter_constraints, + "criterion": [ + StrOptions({"gini", "entropy", "log_loss"}), + Hidden(BaseCriterion), + ], + "class_weight": [dict, list, StrOptions({"balanced"}), None], + } + + def __init__( + self, + *, + criterion="gini", + splitter="best", + max_depth=None, + min_samples_split=2, + min_samples_leaf=1, + min_weight_fraction_leaf=0.0, + max_features=None, + random_state=None, + max_leaf_nodes=None, + min_impurity_decrease=0.0, + class_weight=None, + ccp_alpha=0.0, + store_leaf_values=False, + monotonic_cst=None, + ): + super().__init__( + criterion=criterion, + splitter=splitter, + max_depth=max_depth, + min_samples_split=min_samples_split, + min_samples_leaf=min_samples_leaf, + min_weight_fraction_leaf=min_weight_fraction_leaf, + max_features=max_features, + max_leaf_nodes=max_leaf_nodes, + class_weight=class_weight, + random_state=random_state, + min_impurity_decrease=min_impurity_decrease, + monotonic_cst=monotonic_cst, + ccp_alpha=ccp_alpha, + store_leaf_values=store_leaf_values, + ) + + @_fit_context(prefer_skip_nested_validation=True) + def fit( + self, + X, + y, + sample_weight=None, + check_input=True, + classes=None, + ): + """Build a decision tree classifier from the training set (X, y). + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + The training input samples. Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csc_matrix``. + + y : array-like of shape (n_samples,) or (n_samples, n_outputs) + The target values (class labels) as integers or strings. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. If None, then samples are equally weighted. Splits + that would create child nodes with net zero or negative weight are + ignored while searching for a split in each node. Splits are also + ignored if they would result in any single class carrying a + negative weight in either child node. + + check_input : bool, default=True + Allow to bypass several input checking. + Don't use this parameter unless you know what you're doing. + + classes : array-like of shape (n_classes,), default=None + List of all the classes that can possibly appear in the y vector. + Must be provided at the first call to partial_fit, can be omitted + in subsequent calls. + + Returns + ------- + self : DecisionTreeClassifier + Fitted estimator. + """ + super()._fit( + X, + y, + sample_weight=sample_weight, + check_input=check_input, + classes=classes, + ) + return self + + def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): + """Update a decision tree classifier from the training set (X, y). + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + The training input samples. Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csc_matrix``. + + y : array-like of shape (n_samples,) or (n_samples, n_outputs) + The target values (class labels) as integers or strings. + + classes : array-like of shape (n_classes,), default=None + List of all the classes that can possibly appear in the y vector. + Must be provided at the first call to partial_fit, can be omitted + in subsequent calls. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. If None, then samples are equally weighted. 
Splits + that would create child nodes with net zero or negative weight are + ignored while searching for a split in each node. Splits are also + ignored if they would result in any single class carrying a + negative weight in either child node. + + check_input : bool, default=True + Allow to bypass several input checking. + Don't use this parameter unless you know what you do. + + Returns + ------- + self : DecisionTreeClassifier + Fitted estimator. + """ + + first_call = _check_partial_fit_first_call(self, classes=classes) + + # Fit if no tree exists yet + if first_call: + self.fit( + X, + y, + sample_weight=sample_weight, + check_input=check_input, + classes=classes, + ) + return self + + if check_input: + # Need to validate separately here. + # We can't pass multi_ouput=True because that would allow y to be + # csr. + check_X_params = dict(dtype=DTYPE, accept_sparse="csc") + check_y_params = dict(ensure_2d=False, dtype=None) + X, y = self._validate_data( + X, y, reset=False, validate_separately=(check_X_params, check_y_params) + ) + if issparse(X): + X.sort_indices() - cdef SIZE_t min_samples_split # Minimum number of samples in an internal node - cdef SIZE_t min_samples_leaf # Minimum number of samples in a leaf - cdef double min_weight_leaf # Minimum weight in a leaf - cdef SIZE_t max_depth # Maximal tree depth - cdef double min_impurity_decrease # Impurity threshold for early stopping + if X.indices.dtype != np.intc or X.indptr.dtype != np.intc: + raise ValueError( + "No support for np.int64 index based sparse matrices" + ) - cdef unsigned char store_leaf_values # Whether to store leaf values + if X.shape[1] != self.n_features_in_: + msg = "Number of features %d does not match previous data %d." + raise ValueError(msg % (X.shape[1], self.n_features_in_)) - cpdef build( + y = np.atleast_1d(y) + + if y.ndim == 1: + # reshape is necessary to preserve the data contiguity against vs + # [:, np.newaxis] that does not. + y = np.reshape(y, (-1, 1)) + + check_classification_targets(y) + y = np.copy(y) + + classes = self.classes_ + if self.n_outputs_ == 1: + classes = [classes] + + y_encoded = np.zeros(y.shape, dtype=int) + for i in range(X.shape[0]): + for j in range(self.n_outputs_): + y_encoded[i, j] = np.where(classes[j] == y[i, j])[0][0] + y = y_encoded + + if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous: + y = np.ascontiguousarray(y, dtype=DOUBLE) + + # Update tree + self.builder_.initialize_node_queue(self.tree_, X, y, sample_weight) + self.builder_.build(self.tree_, X, y, sample_weight) + + self._prune_tree() + + return self + + def predict_proba(self, X, check_input=True): + """Predict class probabilities of the input samples X. + + The predicted class probability is the fraction of samples of the same + class in a leaf. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + The input samples. Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csr_matrix``. + + check_input : bool, default=True + Allow to bypass several input checking. + Don't use this parameter unless you know what you're doing. + + Returns + ------- + proba : ndarray of shape (n_samples, n_classes) or list of n_outputs \ + such arrays if n_outputs > 1 + The class probabilities of the input samples. The order of the + classes corresponds to that in the attribute :term:`classes_`. 
+ """ + check_is_fitted(self) + X = self._validate_X_predict(X, check_input) + proba = self.tree_.predict(X) + + if self.n_outputs_ == 1: + proba = proba[:, : self.n_classes_] + normalizer = proba.sum(axis=1)[:, np.newaxis] + normalizer[normalizer == 0.0] = 1.0 + proba /= normalizer + + return proba + + else: + all_proba = [] + + for k in range(self.n_outputs_): + proba_k = proba[:, k, : self.n_classes_[k]] + normalizer = proba_k.sum(axis=1)[:, np.newaxis] + normalizer[normalizer == 0.0] = 1.0 + proba_k /= normalizer + all_proba.append(proba_k) + + return all_proba + + def predict_log_proba(self, X): + """Predict class log-probabilities of the input samples X. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + The input samples. Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csr_matrix``. + + Returns + ------- + proba : ndarray of shape (n_samples, n_classes) or list of n_outputs \ + such arrays if n_outputs > 1 + The class log-probabilities of the input samples. The order of the + classes corresponds to that in the attribute :term:`classes_`. + """ + proba = self.predict_proba(X) + + if self.n_outputs_ == 1: + return np.log(proba) + + else: + for k in range(self.n_outputs_): + proba[k] = np.log(proba[k]) + + return proba + + def _more_tags(self): + # XXX: nan is only support for dense arrays, but we set this for common test to + # pass, specifically: check_estimators_nan_inf + allow_nan = self.splitter == "best" and self.criterion in { + "gini", + "log_loss", + "entropy", + } + return {"multilabel": True, "allow_nan": allow_nan} + + +class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): + """A decision tree regressor. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + criterion : {"squared_error", "friedman_mse", "absolute_error", \ + "poisson"}, default="squared_error" + The function to measure the quality of a split. Supported criteria + are "squared_error" for the mean squared error, which is equal to + variance reduction as feature selection criterion and minimizes the L2 + loss using the mean of each terminal node, "friedman_mse", which uses + mean squared error with Friedman's improvement score for potential + splits, "absolute_error" for the mean absolute error, which minimizes + the L1 loss using the median of each terminal node, and "poisson" which + uses reduction in Poisson deviance to find splits. + + .. versionadded:: 0.18 + Mean Absolute Error (MAE) criterion. + + .. versionadded:: 0.24 + Poisson deviance criterion. + + splitter : {"best", "random"}, default="best" + The strategy used to choose the split at each node. Supported + strategies are "best" to choose the best split and "random" to choose + the best random split. + + max_depth : int, default=None + The maximum depth of the tree. If None, then nodes are expanded until + all leaves are pure or until all leaves contain less than + min_samples_split samples. + + min_samples_split : int or float, default=2 + The minimum number of samples required to split an internal node: + + - If int, then consider `min_samples_split` as the minimum number. + - If float, then `min_samples_split` is a fraction and + `ceil(min_samples_split * n_samples)` are the minimum + number of samples for each split. + + .. versionchanged:: 0.18 + Added float values for fractions. + + min_samples_leaf : int or float, default=1 + The minimum number of samples required to be at a leaf node. 
+ A split point at any depth will only be considered if it leaves at + least ``min_samples_leaf`` training samples in each of the left and + right branches. This may have the effect of smoothing the model, + especially in regression. + + - If int, then consider `min_samples_leaf` as the minimum number. + - If float, then `min_samples_leaf` is a fraction and + `ceil(min_samples_leaf * n_samples)` are the minimum + number of samples for each node. + + .. versionchanged:: 0.18 + Added float values for fractions. + + min_weight_fraction_leaf : float, default=0.0 + The minimum weighted fraction of the sum total of weights (of all + the input samples) required to be at a leaf node. Samples have + equal weight when sample_weight is not provided. + + max_features : int, float or {"auto", "sqrt", "log2"}, default=None + The number of features to consider when looking for the best split: + + - If int, then consider `max_features` features at each split. + - If float, then `max_features` is a fraction and + `max(1, int(max_features * n_features_in_))` features are considered at each + split. + - If "sqrt", then `max_features=sqrt(n_features)`. + - If "log2", then `max_features=log2(n_features)`. + - If None, then `max_features=n_features`. + + Note: the search for a split does not stop until at least one + valid partition of the node samples is found, even if it requires to + effectively inspect more than ``max_features`` features. + + random_state : int, RandomState instance or None, default=None + Controls the randomness of the estimator. The features are always + randomly permuted at each split, even if ``splitter`` is set to + ``"best"``. When ``max_features < n_features``, the algorithm will + select ``max_features`` at random at each split before finding the best + split among them. But the best found split may vary across different + runs, even if ``max_features=n_features``. That is the case, if the + improvement of the criterion is identical for several splits and one + split has to be selected at random. To obtain a deterministic behaviour + during fitting, ``random_state`` has to be fixed to an integer. + See :term:`Glossary ` for details. + + max_leaf_nodes : int, default=None + Grow a tree with ``max_leaf_nodes`` in best-first fashion. + Best nodes are defined as relative reduction in impurity. + If None then unlimited number of leaf nodes. + + min_impurity_decrease : float, default=0.0 + A node will be split if this split induces a decrease of the impurity + greater than or equal to this value. + + The weighted impurity decrease equation is the following:: + + N_t / N * (impurity - N_t_R / N_t * right_impurity + - N_t_L / N_t * left_impurity) + + where ``N`` is the total number of samples, ``N_t`` is the number of + samples at the current node, ``N_t_L`` is the number of samples in the + left child, and ``N_t_R`` is the number of samples in the right child. + + ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum, + if ``sample_weight`` is passed. + + .. versionadded:: 0.19 + + ccp_alpha : non-negative float, default=0.0 + Complexity parameter used for Minimal Cost-Complexity Pruning. The + subtree with the largest cost complexity that is smaller than + ``ccp_alpha`` will be chosen. By default, no pruning is performed. See + :ref:`minimal_cost_complexity_pruning` for details. + + .. versionadded:: 0.22 + + store_leaf_values : bool, default=False + Whether to store the samples that fall into leaves in the ``tree_`` attribute. 
+ Each leaf will store a 2D array corresponding to the samples that fall into it + keyed by node_id. + + XXX: This is currently experimental and may change without notice. + Moreover, it can be improved upon since storing the samples twice is not ideal. + One could instead store the indices in ``y_train`` that fall into each leaf, + which would lower RAM/diskspace usage. + + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonicity constraint to enforce on each feature. + - 1: monotonic increase + - 0: no constraint + - -1: monotonic decrease + + If monotonic_cst is None, no constraints are applied. + + Monotonicity constraints are not supported for: + - multioutput regressions (i.e. when `n_outputs_ > 1`), + - regressions trained on data with missing values. + + Read more in the :ref:`User Guide `. + + .. versionadded:: 1.4 + + Attributes + ---------- + feature_importances_ : ndarray of shape (n_features,) + The feature importances. + The higher, the more important the feature. + The importance of a feature is computed as the + (normalized) total reduction of the criterion brought + by that feature. It is also known as the Gini importance [4]_. + + Warning: impurity-based feature importances can be misleading for + high cardinality features (many unique values). See + :func:`sklearn.inspection.permutation_importance` as an alternative. + + max_features_ : int + The inferred value of max_features. + + n_features_in_ : int + Number of features seen during :term:`fit`. + + .. versionadded:: 0.24 + + feature_names_in_ : ndarray of shape (`n_features_in_`,) + Names of features seen during :term:`fit`. Defined only when `X` + has feature names that are all strings. + + .. versionadded:: 1.0 + + n_outputs_ : int + The number of outputs when ``fit`` is performed. + + tree_ : Tree instance + The underlying Tree object. Please refer to + ``help(sklearn.tree._tree.Tree)`` for attributes of Tree object and + :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py` + for basic usage of these attributes. + + builder_ : TreeBuilder instance + The underlying TreeBuilder object. + + See Also + -------- + DecisionTreeClassifier : A decision tree classifier. + + Notes + ----- + The default values for the parameters controlling the size of the trees + (e.g. ``max_depth``, ``min_samples_leaf``, etc.) lead to fully grown and + unpruned trees which can potentially be very large on some data sets. To + reduce memory consumption, the complexity and size of the trees should be + controlled by setting those parameter values. + + References + ---------- + + .. [1] https://en.wikipedia.org/wiki/Decision_tree_learning + + .. [2] L. Breiman, J. Friedman, R. Olshen, and C. Stone, "Classification + and Regression Trees", Wadsworth, Belmont, CA, 1984. + + .. [3] T. Hastie, R. Tibshirani and J. Friedman. "Elements of Statistical + Learning", Springer, 2009. + + .. [4] L. Breiman, and A. Cutler, "Random Forests", + https://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm + + Examples + -------- + >>> from sklearn.datasets import load_diabetes + >>> from sklearn.model_selection import cross_val_score + >>> from sklearn.tree import DecisionTreeRegressor + >>> X, y = load_diabetes(return_X_y=True) + >>> regressor = DecisionTreeRegressor(random_state=0) + >>> cross_val_score(regressor, X, y, cv=10) + ... # doctest: +SKIP + ... 
+ array([-0.39..., -0.46..., 0.02..., 0.06..., -0.50..., + 0.16..., 0.11..., -0.73..., -0.30..., -0.00...]) + """ + + _parameter_constraints: dict = { + **BaseDecisionTree._parameter_constraints, + "criterion": [ + StrOptions({"squared_error", "friedman_mse", "absolute_error", "poisson"}), + Hidden(BaseCriterion), + ], + } + + def __init__( + self, + *, + criterion="squared_error", + splitter="best", + max_depth=None, + min_samples_split=2, + min_samples_leaf=1, + min_weight_fraction_leaf=0.0, + max_features=None, + random_state=None, + max_leaf_nodes=None, + min_impurity_decrease=0.0, + ccp_alpha=0.0, + store_leaf_values=False, + monotonic_cst=None, + ): + super().__init__( + criterion=criterion, + splitter=splitter, + max_depth=max_depth, + min_samples_split=min_samples_split, + min_samples_leaf=min_samples_leaf, + min_weight_fraction_leaf=min_weight_fraction_leaf, + max_features=max_features, + max_leaf_nodes=max_leaf_nodes, + random_state=random_state, + min_impurity_decrease=min_impurity_decrease, + ccp_alpha=ccp_alpha, + store_leaf_values=store_leaf_values, + monotonic_cst=monotonic_cst, + ) + + @_fit_context(prefer_skip_nested_validation=True) + def fit( self, - Tree tree, - object X, - const DOUBLE_t[:, ::1] y, - const DOUBLE_t[:] sample_weight=*, - const unsigned char[::1] missing_values_in_feature_mask=*, - ) - - cdef _check_input( + X, + y, + sample_weight=None, + check_input=True, + classes=None, + ): + """Build a decision tree regressor from the training set (X, y). + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + The training input samples. Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csc_matrix``. + + y : array-like of shape (n_samples,) or (n_samples, n_outputs) + The target values (real numbers). Use ``dtype=np.float64`` and + ``order='C'`` for maximum efficiency. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. If None, then samples are equally weighted. Splits + that would create child nodes with net zero or negative weight are + ignored while searching for a split in each node. + + check_input : bool, default=True + Allow to bypass several input checking. + Don't use this parameter unless you know what you're doing. + + classes : array-like of shape (n_classes,), default=None + List of all the classes that can possibly appear in the y vector. + + Returns + ------- + self : DecisionTreeRegressor + Fitted estimator. + """ + + super()._fit( + X, + y, + sample_weight=sample_weight, + check_input=check_input, + classes=classes, + ) + return self + + def _compute_partial_dependence_recursion(self, grid, target_features): + """Fast partial dependence computation. + + Parameters + ---------- + grid : ndarray of shape (n_samples, n_target_features) + The grid points on which the partial dependence should be + evaluated. + target_features : ndarray of shape (n_target_features) + The set of target features for which the partial dependence + should be evaluated. + + Returns + ------- + averaged_predictions : ndarray of shape (n_samples,) + The value of the partial dependence function on each grid point. 
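# Illustrative sketch only: the private hook defined below is the recursion
# entry point used by ``sklearn.inspection.partial_dependence`` when
# ``method="recursion"`` is selected for tree models.  The dataset and the
# feature index are arbitrary choices, not taken from this patch.
from sklearn.datasets import load_diabetes
from sklearn.inspection import partial_dependence
from sklearn.tree import DecisionTreeRegressor

X, y = load_diabetes(return_X_y=True)
reg = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)

# "recursion" traverses the fitted tree and weights branches by their
# training sample fractions instead of re-predicting on modified copies of X
# (the "brute" method), which is what the hook below implements.
pd_result = partial_dependence(reg, X, features=[2], method="recursion")
print(pd_result["average"].shape)  # one averaged prediction per grid point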
+ """ + grid = np.asarray(grid, dtype=DTYPE, order="C") + averaged_predictions = np.zeros( + shape=grid.shape[0], dtype=np.float64, order="C" + ) + + self.tree_.compute_partial_dependence( + grid, target_features, averaged_predictions + ) + return averaged_predictions + + def _more_tags(self): + # XXX: nan is only support for dense arrays, but we set this for common test to + # pass, specifically: check_estimators_nan_inf + allow_nan = self.splitter == "best" and self.criterion in { + "squared_error", + "friedman_mse", + "poisson", + } + return {"allow_nan": allow_nan} + + +class ExtraTreeClassifier(DecisionTreeClassifier): + """An extremely randomized tree classifier. + + Extra-trees differ from classic decision trees in the way they are built. + When looking for the best split to separate the samples of a node into two + groups, random splits are drawn for each of the `max_features` randomly + selected features and the best split among those is chosen. When + `max_features` is set 1, this amounts to building a totally random + decision tree. + + Warning: Extra-trees should only be used within ensemble methods. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + criterion : {"gini", "entropy", "log_loss"}, default="gini" + The function to measure the quality of a split. Supported criteria are + "gini" for the Gini impurity and "log_loss" and "entropy" both for the + Shannon information gain, see :ref:`tree_mathematical_formulation`. + + splitter : {"random", "best"}, default="random" + The strategy used to choose the split at each node. Supported + strategies are "best" to choose the best split and "random" to choose + the best random split. + + max_depth : int, default=None + The maximum depth of the tree. If None, then nodes are expanded until + all leaves are pure or until all leaves contain less than + min_samples_split samples. + + min_samples_split : int or float, default=2 + The minimum number of samples required to split an internal node: + + - If int, then consider `min_samples_split` as the minimum number. + - If float, then `min_samples_split` is a fraction and + `ceil(min_samples_split * n_samples)` are the minimum + number of samples for each split. + + .. versionchanged:: 0.18 + Added float values for fractions. + + min_samples_leaf : int or float, default=1 + The minimum number of samples required to be at a leaf node. + A split point at any depth will only be considered if it leaves at + least ``min_samples_leaf`` training samples in each of the left and + right branches. This may have the effect of smoothing the model, + especially in regression. + + - If int, then consider `min_samples_leaf` as the minimum number. + - If float, then `min_samples_leaf` is a fraction and + `ceil(min_samples_leaf * n_samples)` are the minimum + number of samples for each node. + + .. versionchanged:: 0.18 + Added float values for fractions. + + min_weight_fraction_leaf : float, default=0.0 + The minimum weighted fraction of the sum total of weights (of all + the input samples) required to be at a leaf node. Samples have + equal weight when sample_weight is not provided. + + max_features : int, float, {"auto", "sqrt", "log2"} or None, default="sqrt" + The number of features to consider when looking for the best split: + + - If int, then consider `max_features` features at each split. + - If float, then `max_features` is a fraction and + `max(1, int(max_features * n_features_in_))` features are considered at + each split. + - If "sqrt", then `max_features=sqrt(n_features)`. 
+ - If "log2", then `max_features=log2(n_features)`. + - If None, then `max_features=n_features`. + + .. versionchanged:: 1.1 + The default of `max_features` changed from `"auto"` to `"sqrt"`. + + Note: the search for a split does not stop until at least one + valid partition of the node samples is found, even if it requires to + effectively inspect more than ``max_features`` features. + + random_state : int, RandomState instance or None, default=None + Used to pick randomly the `max_features` used at each split. + See :term:`Glossary ` for details. + + max_leaf_nodes : int, default=None + Grow a tree with ``max_leaf_nodes`` in best-first fashion. + Best nodes are defined as relative reduction in impurity. + If None then unlimited number of leaf nodes. + + min_impurity_decrease : float, default=0.0 + A node will be split if this split induces a decrease of the impurity + greater than or equal to this value. + + The weighted impurity decrease equation is the following:: + + N_t / N * (impurity - N_t_R / N_t * right_impurity + - N_t_L / N_t * left_impurity) + + where ``N`` is the total number of samples, ``N_t`` is the number of + samples at the current node, ``N_t_L`` is the number of samples in the + left child, and ``N_t_R`` is the number of samples in the right child. + + ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum, + if ``sample_weight`` is passed. + + .. versionadded:: 0.19 + + class_weight : dict, list of dict or "balanced", default=None + Weights associated with classes in the form ``{class_label: weight}``. + If None, all classes are supposed to have weight one. For + multi-output problems, a list of dicts can be provided in the same + order as the columns of y. + + Note that for multioutput (including multilabel) weights should be + defined for each class of every column in its own dict. For example, + for four-class multilabel classification weights should be + [{0: 1, 1: 1}, {0: 1, 1: 5}, {0: 1, 1: 1}, {0: 1, 1: 1}] instead of + [{1:1}, {2:5}, {3:1}, {4:1}]. + + The "balanced" mode uses the values of y to automatically adjust + weights inversely proportional to class frequencies in the input data + as ``n_samples / (n_classes * np.bincount(y))`` + + For multi-output, the weights of each column of y will be multiplied. + + Note that these weights will be multiplied with sample_weight (passed + through the fit method) if sample_weight is specified. + + ccp_alpha : non-negative float, default=0.0 + Complexity parameter used for Minimal Cost-Complexity Pruning. The + subtree with the largest cost complexity that is smaller than + ``ccp_alpha`` will be chosen. By default, no pruning is performed. See + :ref:`minimal_cost_complexity_pruning` for details. + + .. versionadded:: 0.22 + + store_leaf_values : bool, default=False + Whether to store the samples that fall into leaves in the ``tree_`` attribute. + Each leaf will store a 2D array corresponding to the samples that fall into it + keyed by node_id. + + XXX: This is currently experimental and may change without notice. + Moreover, it can be improved upon since storing the samples twice is not ideal. + One could instead store the indices in ``y_train`` that fall into each leaf, + which would lower RAM/diskspace usage. + + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonicity constraint to enforce on each feature. + - 1: monotonic increase + - 0: no constraint + - -1: monotonic decrease + + If monotonic_cst is None, no constraints are applied. 
+ + Monotonicity constraints are not supported for: + - multiclass classifications (i.e. when `n_classes > 2`), + - multioutput classifications (i.e. when `n_outputs_ > 1`), + - classifications trained on data with missing values. + + The constraints hold over the probability of the positive class. + + Read more in the :ref:`User Guide `. + + .. versionadded:: 1.4 + + Attributes + ---------- + classes_ : ndarray of shape (n_classes,) or list of ndarray + The classes labels (single output problem), + or a list of arrays of class labels (multi-output problem). + + max_features_ : int + The inferred value of max_features. + + n_classes_ : int or list of int + The number of classes (for single output problems), + or a list containing the number of classes for each + output (for multi-output problems). + + feature_importances_ : ndarray of shape (n_features,) + The impurity-based feature importances. + The higher, the more important the feature. + The importance of a feature is computed as the (normalized) + total reduction of the criterion brought by that feature. It is also + known as the Gini importance. + + Warning: impurity-based feature importances can be misleading for + high cardinality features (many unique values). See + :func:`sklearn.inspection.permutation_importance` as an alternative. + + n_features_in_ : int + Number of features seen during :term:`fit`. + + .. versionadded:: 0.24 + + feature_names_in_ : ndarray of shape (`n_features_in_`,) + Names of features seen during :term:`fit`. Defined only when `X` + has feature names that are all strings. + + .. versionadded:: 1.0 + + n_outputs_ : int + The number of outputs when ``fit`` is performed. + + tree_ : Tree instance + The underlying Tree object. Please refer to + ``help(sklearn.tree._tree.Tree)`` for attributes of Tree object and + :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py` + for basic usage of these attributes. + + builder_ : TreeBuilder instance + The underlying TreeBuilder object. + + See Also + -------- + ExtraTreeRegressor : An extremely randomized tree regressor. + sklearn.ensemble.ExtraTreesClassifier : An extra-trees classifier. + sklearn.ensemble.ExtraTreesRegressor : An extra-trees regressor. + sklearn.ensemble.RandomForestClassifier : A random forest classifier. + sklearn.ensemble.RandomForestRegressor : A random forest regressor. + sklearn.ensemble.RandomTreesEmbedding : An ensemble of + totally random trees. + + Notes + ----- + The default values for the parameters controlling the size of the trees + (e.g. ``max_depth``, ``min_samples_leaf``, etc.) lead to fully grown and + unpruned trees which can potentially be very large on some data sets. To + reduce memory consumption, the complexity and size of the trees should be + controlled by setting those parameter values. + + References + ---------- + + .. [1] P. Geurts, D. Ernst., and L. Wehenkel, "Extremely randomized trees", + Machine Learning, 63(1), 3-42, 2006. + + Examples + -------- + >>> from sklearn.datasets import load_iris + >>> from sklearn.model_selection import train_test_split + >>> from sklearn.ensemble import BaggingClassifier + >>> from sklearn.tree import ExtraTreeClassifier + >>> X, y = load_iris(return_X_y=True) + >>> X_train, X_test, y_train, y_test = train_test_split( + ... X, y, random_state=0) + >>> extra_tree = ExtraTreeClassifier(random_state=0) + >>> cls = BaggingClassifier(extra_tree, random_state=0).fit( + ... X_train, y_train) + >>> cls.score(X_test, y_test) + 0.8947... 
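# Illustrative sketch of the incremental-update path added in this patch;
# ExtraTreeClassifier inherits ``partial_fit`` from DecisionTreeClassifier,
# so the same pattern applies here.  The toy batches below are invented for
# demonstration only.
import numpy as np
from sklearn.tree import DecisionTreeClassifier

X_batch1 = np.array([[0.0], [1.0], [2.0], [3.0]])
y_batch1 = np.array([0, 0, 1, 1])
X_batch2 = np.array([[4.0], [5.0]])
y_batch2 = np.array([1, 2])

clf = DecisionTreeClassifier(random_state=0)
# First call: every label that may ever appear must be declared, even if it
# has not been seen yet (class 2 here); internally this first call falls
# back to a plain fit().
clf.partial_fit(X_batch1, y_batch1, classes=[0, 1, 2])
# Subsequent calls may omit ``classes`` and update the already-built tree
# through the stored ``builder_``.
clf.partial_fit(X_batch2, y_batch2)
print(clf.predict(np.array([[4.5]])))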
+ """ + + def __init__( + self, + *, + criterion="gini", + splitter="random", + max_depth=None, + min_samples_split=2, + min_samples_leaf=1, + min_weight_fraction_leaf=0.0, + max_features="sqrt", + random_state=None, + max_leaf_nodes=None, + min_impurity_decrease=0.0, + class_weight=None, + ccp_alpha=0.0, + store_leaf_values=False, + monotonic_cst=None, + ): + super().__init__( + criterion=criterion, + splitter=splitter, + max_depth=max_depth, + min_samples_split=min_samples_split, + min_samples_leaf=min_samples_leaf, + min_weight_fraction_leaf=min_weight_fraction_leaf, + max_features=max_features, + max_leaf_nodes=max_leaf_nodes, + class_weight=class_weight, + min_impurity_decrease=min_impurity_decrease, + random_state=random_state, + ccp_alpha=ccp_alpha, + store_leaf_values=store_leaf_values, + monotonic_cst=monotonic_cst, + ) + + +class ExtraTreeRegressor(DecisionTreeRegressor): + """An extremely randomized tree regressor. + + Extra-trees differ from classic decision trees in the way they are built. + When looking for the best split to separate the samples of a node into two + groups, random splits are drawn for each of the `max_features` randomly + selected features and the best split among those is chosen. When + `max_features` is set 1, this amounts to building a totally random + decision tree. + + Warning: Extra-trees should only be used within ensemble methods. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + criterion : {"squared_error", "friedman_mse", "absolute_error", "poisson"}, \ + default="squared_error" + The function to measure the quality of a split. Supported criteria + are "squared_error" for the mean squared error, which is equal to + variance reduction as feature selection criterion and minimizes the L2 + loss using the mean of each terminal node, "friedman_mse", which uses + mean squared error with Friedman's improvement score for potential + splits, "absolute_error" for the mean absolute error, which minimizes + the L1 loss using the median of each terminal node, and "poisson" which + uses reduction in Poisson deviance to find splits. + + .. versionadded:: 0.18 + Mean Absolute Error (MAE) criterion. + + .. versionadded:: 0.24 + Poisson deviance criterion. + + splitter : {"random", "best"}, default="random" + The strategy used to choose the split at each node. Supported + strategies are "best" to choose the best split and "random" to choose + the best random split. + + max_depth : int, default=None + The maximum depth of the tree. If None, then nodes are expanded until + all leaves are pure or until all leaves contain less than + min_samples_split samples. + + min_samples_split : int or float, default=2 + The minimum number of samples required to split an internal node: + + - If int, then consider `min_samples_split` as the minimum number. + - If float, then `min_samples_split` is a fraction and + `ceil(min_samples_split * n_samples)` are the minimum + number of samples for each split. + + .. versionchanged:: 0.18 + Added float values for fractions. + + min_samples_leaf : int or float, default=1 + The minimum number of samples required to be at a leaf node. + A split point at any depth will only be considered if it leaves at + least ``min_samples_leaf`` training samples in each of the left and + right branches. This may have the effect of smoothing the model, + especially in regression. + + - If int, then consider `min_samples_leaf` as the minimum number. 
+ - If float, then `min_samples_leaf` is a fraction and + `ceil(min_samples_leaf * n_samples)` are the minimum + number of samples for each node. + + .. versionchanged:: 0.18 + Added float values for fractions. + + min_weight_fraction_leaf : float, default=0.0 + The minimum weighted fraction of the sum total of weights (of all + the input samples) required to be at a leaf node. Samples have + equal weight when sample_weight is not provided. + + max_features : int, float, {"auto", "sqrt", "log2"} or None, default=1.0 + The number of features to consider when looking for the best split: + + - If int, then consider `max_features` features at each split. + - If float, then `max_features` is a fraction and + `max(1, int(max_features * n_features_in_))` features are considered at each + split. + - If "sqrt", then `max_features=sqrt(n_features)`. + - If "log2", then `max_features=log2(n_features)`. + - If None, then `max_features=n_features`. + + .. versionchanged:: 1.1 + The default of `max_features` changed from `"auto"` to `1.0`. + + Note: the search for a split does not stop until at least one + valid partition of the node samples is found, even if it requires to + effectively inspect more than ``max_features`` features. + + random_state : int, RandomState instance or None, default=None + Used to pick randomly the `max_features` used at each split. + See :term:`Glossary ` for details. + + min_impurity_decrease : float, default=0.0 + A node will be split if this split induces a decrease of the impurity + greater than or equal to this value. + + The weighted impurity decrease equation is the following:: + + N_t / N * (impurity - N_t_R / N_t * right_impurity + - N_t_L / N_t * left_impurity) + + where ``N`` is the total number of samples, ``N_t`` is the number of + samples at the current node, ``N_t_L`` is the number of samples in the + left child, and ``N_t_R`` is the number of samples in the right child. + + ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum, + if ``sample_weight`` is passed. + + .. versionadded:: 0.19 + + max_leaf_nodes : int, default=None + Grow a tree with ``max_leaf_nodes`` in best-first fashion. + Best nodes are defined as relative reduction in impurity. + If None then unlimited number of leaf nodes. + + ccp_alpha : non-negative float, default=0.0 + Complexity parameter used for Minimal Cost-Complexity Pruning. The + subtree with the largest cost complexity that is smaller than + ``ccp_alpha`` will be chosen. By default, no pruning is performed. See + :ref:`minimal_cost_complexity_pruning` for details. + + .. versionadded:: 0.22 + + store_leaf_values : bool, default=False + Whether to store the samples that fall into leaves in the ``tree_`` attribute. + Each leaf will store a 2D array corresponding to the samples that fall into it + keyed by node_id. + + XXX: This is currently experimental and may change without notice. + Moreover, it can be improved upon since storing the samples twice is not ideal. + One could instead store the indices in ``y_train`` that fall into each leaf, + which would lower RAM/diskspace usage. + + monotonic_cst : array-like of int of shape (n_features), default=None + Indicates the monotonicity constraint to enforce on each feature. + - 1: monotonic increase + - 0: no constraint + - -1: monotonic decrease + + If monotonic_cst is None, no constraints are applied. + + Monotonicity constraints are not supported for: + - multioutput regressions (i.e. when `n_outputs_ > 1`), + - regressions trained on data with missing values. 
+ + Read more in the :ref:`User Guide `. + + .. versionadded:: 1.4 + + Attributes + ---------- + max_features_ : int + The inferred value of max_features. + + n_features_in_ : int + Number of features seen during :term:`fit`. + + .. versionadded:: 0.24 + + feature_names_in_ : ndarray of shape (`n_features_in_`,) + Names of features seen during :term:`fit`. Defined only when `X` + has feature names that are all strings. + + .. versionadded:: 1.0 + + feature_importances_ : ndarray of shape (n_features,) + Return impurity-based feature importances (the higher, the more + important the feature). + + Warning: impurity-based feature importances can be misleading for + high cardinality features (many unique values). See + :func:`sklearn.inspection.permutation_importance` as an alternative. + + n_outputs_ : int + The number of outputs when ``fit`` is performed. + + tree_ : Tree instance + The underlying Tree object. Please refer to + ``help(sklearn.tree._tree.Tree)`` for attributes of Tree object and + :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py` + for basic usage of these attributes. + + builder_ : TreeBuilder instance + The underlying TreeBuilder object. + + See Also + -------- + ExtraTreeClassifier : An extremely randomized tree classifier. + sklearn.ensemble.ExtraTreesClassifier : An extra-trees classifier. + sklearn.ensemble.ExtraTreesRegressor : An extra-trees regressor. + + Notes + ----- + The default values for the parameters controlling the size of the trees + (e.g. ``max_depth``, ``min_samples_leaf``, etc.) lead to fully grown and + unpruned trees which can potentially be very large on some data sets. To + reduce memory consumption, the complexity and size of the trees should be + controlled by setting those parameter values. + + References + ---------- + + .. [1] P. Geurts, D. Ernst., and L. Wehenkel, "Extremely randomized trees", + Machine Learning, 63(1), 3-42, 2006. + + Examples + -------- + >>> from sklearn.datasets import load_diabetes + >>> from sklearn.model_selection import train_test_split + >>> from sklearn.ensemble import BaggingRegressor + >>> from sklearn.tree import ExtraTreeRegressor + >>> X, y = load_diabetes(return_X_y=True) + >>> X_train, X_test, y_train, y_test = train_test_split( + ... X, y, random_state=0) + >>> extra_tree = ExtraTreeRegressor(random_state=0) + >>> reg = BaggingRegressor(extra_tree, random_state=0).fit( + ... X_train, y_train) + >>> reg.score(X_test, y_test) + 0.33... 
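# A short sketch of the ``monotonic_cst`` parameter described above; the
# data and the constraint vector are invented for illustration.
import numpy as np
from sklearn.tree import ExtraTreeRegressor

rng = np.random.RandomState(0)
X = rng.uniform(size=(200, 2))
# The target grows with feature 0; feature 1 is pure noise.
y = 3.0 * X[:, 0] + rng.normal(scale=0.1, size=200)

# One entry per feature: 1 = monotonic increase, 0 = no constraint,
# -1 = monotonic decrease.  Predictions are forced to be non-decreasing in
# feature 0 and left unconstrained in feature 1.
reg = ExtraTreeRegressor(max_depth=4, monotonic_cst=[1, 0], random_state=0)
reg.fit(X, y)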
+ """ + + def __init__( self, - object X, - const DOUBLE_t[:, ::1] y, - const DOUBLE_t[:] sample_weight, - ) + *, + criterion="squared_error", + splitter="random", + max_depth=None, + min_samples_split=2, + min_samples_leaf=1, + min_weight_fraction_leaf=0.0, + max_features=1.0, + random_state=None, + min_impurity_decrease=0.0, + max_leaf_nodes=None, + ccp_alpha=0.0, + store_leaf_values=False, + monotonic_cst=None, + ): + super().__init__( + criterion=criterion, + splitter=splitter, + max_depth=max_depth, + min_samples_split=min_samples_split, + min_samples_leaf=min_samples_leaf, + min_weight_fraction_leaf=min_weight_fraction_leaf, + max_features=max_features, + max_leaf_nodes=max_leaf_nodes, + min_impurity_decrease=min_impurity_decrease, + random_state=random_state, + ccp_alpha=ccp_alpha, + store_leaf_values=store_leaf_values, + monotonic_cst=monotonic_cst, + ) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 492b5219fa18e..e307a53b850a5 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -12,6 +12,7 @@ # Fares Hedayati # Jacob Schreiber # Nelson Liu +# Haoyin Xu # # License: BSD 3 clause @@ -91,6 +92,17 @@ NODE_DTYPE = np.asarray((&dummy)).dtype cdef class TreeBuilder: """Interface for different tree building strategies.""" + cpdef initialize_node_queue( + self, + Tree tree, + object X, + const DOUBLE_t[:, ::1] y, + const DOUBLE_t[:] sample_weight=None, + const unsigned char[::1] missing_values_in_feature_mask=None, + ): + """Build a decision tree from the training set (X, y).""" + pass + cpdef build( self, Tree tree, @@ -165,7 +177,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): double min_weight_leaf, SIZE_t max_depth, double min_impurity_decrease, - unsigned char store_leaf_values=False + unsigned char store_leaf_values=False, + cnp.ndarray initial_roots=None, ): self.splitter = splitter self.min_samples_split = min_samples_split @@ -174,6 +187,71 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): self.max_depth = max_depth self.min_impurity_decrease = min_impurity_decrease self.store_leaf_values = store_leaf_values + self.initial_roots = initial_roots + + def __reduce__(self): + """Reduce re-implementation, for pickling.""" + return(DepthFirstTreeBuilder, (self.splitter, self.min_samples_split, + self.min_samples_leaf, + self.min_weight_leaf, + self.max_depth, + self.min_impurity_decrease, + self.store_leaf_values, + self.initial_roots)) + + cpdef initialize_node_queue( + self, + Tree tree, + object X, + const DOUBLE_t[:, ::1] y, + const DOUBLE_t[:] sample_weight=None, + const unsigned char[::1] missing_values_in_feature_mask=None, + ): + """Initialize a list of roots""" + X, y, sample_weight = self._check_input(X, y, sample_weight) + + # organize samples by decision paths + paths = tree.decision_path(X) + cdef int PARENT + cdef int CHILD + false_roots = {} + X_copy = {} + y_copy = {} + for i in range(X.shape[0]): + depth_i = paths[i].indices.shape[0] - 1 + PARENT = depth_i - 1 + CHILD = depth_i + + if PARENT < 0: + parent_i = 0 + else: + parent_i = paths[i].indices[PARENT] + child_i = paths[i].indices[CHILD] + left = 0 + if tree.children_left[parent_i] == child_i: + left = 1 + + if (parent_i, left) in false_roots: + false_roots[(parent_i, left)][0] += 1 + X_copy[(parent_i, left)].append(X[i]) + y_copy[(parent_i, left)].append(y[i]) + else: + false_roots[(parent_i, left)] = [1, depth_i] + X_copy[(parent_i, left)] = [X[i]] + y_copy[(parent_i, left)] = [y[i]] + + X_list = [] + y_list = [] + for key, value in reversed(sorted(X_copy.items())): + X_list 
= X_list + value + y_list = y_list + y_copy[key] + cdef object X_new = np.array(X_list) + cdef cnp.ndarray y_new = np.array(y_list) + + cdef Splitter splitter = self.splitter + splitter.init(X_new, y_new, sample_weight, missing_values_in_feature_mask) + + self.initial_roots = false_roots cpdef build( self, @@ -206,11 +284,14 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef SIZE_t min_samples_split = self.min_samples_split cdef double min_impurity_decrease = self.min_impurity_decrease - # Recursive partition (without actual recursion) - splitter.init(X, y, sample_weight, missing_values_in_feature_mask) + cdef bint first = 0 + if self.initial_roots is None: + # Recursive partition (without actual recursion) + splitter.init(X, y, sample_weight, missing_values_in_feature_mask) + first = 1 - cdef SIZE_t start - cdef SIZE_t end + cdef SIZE_t start = 0 + cdef SIZE_t end = 0 cdef SIZE_t depth cdef SIZE_t parent cdef bint is_left @@ -227,14 +308,33 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef double middle_value cdef SIZE_t n_constant_features cdef bint is_leaf - cdef bint first = 1 - cdef SIZE_t max_depth_seen = -1 + cdef SIZE_t max_depth_seen = -1 if first else tree.max_depth cdef int rc = 0 cdef stack[StackRecord] builder_stack + cdef stack[StackRecord] update_stack cdef StackRecord stack_record - with nogil: + if not first: + # push reached leaf nodes onto stack + for key, value in reversed(sorted(self.initial_roots.items())): + end += value[0] + update_stack.push({ + "start": start, + "end": end, + "depth": value[1], + "parent": key[0], + "is_left": key[1], + "impurity": tree.impurity[key[0]], + "n_constant_features": 0, + "lower_bound": -INFINITY, + "upper_bound": INFINITY, + }) + start += value[0] + if rc == -1: + # got return code -1 - out-of-memory + raise MemoryError() + else: # push root node onto stack builder_stack.push({ "start": 0, @@ -247,6 +347,135 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): "lower_bound": -INFINITY, "upper_bound": INFINITY, }) + if rc == -1: + # got return code -1 - out-of-memory + raise MemoryError() + + with nogil: + while not update_stack.empty(): + stack_record = update_stack.top() + update_stack.pop() + + start = stack_record.start + end = stack_record.end + depth = stack_record.depth + parent = stack_record.parent + is_left = stack_record.is_left + impurity = stack_record.impurity + n_constant_features = stack_record.n_constant_features + lower_bound = stack_record.lower_bound + upper_bound = stack_record.upper_bound + + n_node_samples = end - start + splitter.node_reset(start, end, &weighted_n_node_samples) + + is_leaf = (depth >= max_depth or + n_node_samples < min_samples_split or + n_node_samples < 2 * min_samples_leaf or + weighted_n_node_samples < 2 * min_weight_leaf) + + # impurity == 0 with tolerance due to rounding errors + is_leaf = is_leaf or impurity <= EPSILON + + if not is_leaf: + splitter.node_split( + impurity, + split_ptr, + &n_constant_features, + lower_bound, + upper_bound + ) + + # assign local copy of SplitRecord to assign + # pos, improvement, and impurity scores + split = deref(split_ptr) + + # If EPSILON=0 in the below comparison, float precision + # issues stop splitting, producing trees that are + # dissimilar to v0.18 + is_leaf = (is_leaf or split.pos >= end or + (split.improvement + EPSILON < + min_impurity_decrease)) + + node_id = tree._update_node(parent, is_left, is_leaf, + split_ptr, impurity, n_node_samples, + weighted_n_node_samples, + split.missing_go_to_left) + + if node_id == INTPTR_MAX: + rc = -1 + 
break + + # Store value for all nodes, to facilitate tree/model + # inspection and interpretation + splitter.node_value(tree.value + node_id * tree.value_stride) + if splitter.with_monotonic_cst: + splitter.clip_node_value(tree.value + node_id * tree.value_stride, lower_bound, upper_bound) + + if not is_leaf: + if ( + not splitter.with_monotonic_cst or + splitter.monotonic_cst[split.feature] == 0 + ): + # Split on a feature with no monotonicity constraint + + # Current bounds must always be propagated to both children. + # If a monotonic constraint is active, bounds are used in + # node value clipping. + left_child_min = right_child_min = lower_bound + left_child_max = right_child_max = upper_bound + elif splitter.monotonic_cst[split.feature] == 1: + # Split on a feature with monotonic increase constraint + left_child_min = lower_bound + right_child_max = upper_bound + + # Lower bound for right child and upper bound for left child + # are set to the same value. + middle_value = splitter.criterion.middle_value() + right_child_min = middle_value + left_child_max = middle_value + else: # i.e. splitter.monotonic_cst[split.feature] == -1 + # Split on a feature with monotonic decrease constraint + right_child_min = lower_bound + left_child_max = upper_bound + + # Lower bound for left child and upper bound for right child + # are set to the same value. + middle_value = splitter.criterion.middle_value() + left_child_min = middle_value + right_child_max = middle_value + + # Push right child on stack + builder_stack.push({ + "start": split.pos, + "end": end, + "depth": depth + 1, + "parent": node_id, + "is_left": 0, + "impurity": split.impurity_right, + "n_constant_features": n_constant_features, + "lower_bound": right_child_min, + "upper_bound": right_child_max, + }) + + # Push left child on stack + builder_stack.push({ + "start": start, + "end": split.pos, + "depth": depth + 1, + "parent": node_id, + "is_left": 1, + "impurity": split.impurity_left, + "n_constant_features": n_constant_features, + "lower_bound": left_child_min, + "upper_bound": left_child_max, + }) + elif self.store_leaf_values and is_leaf: + # copy leaf values to leaf_values array + splitter.node_samples(tree.value_samples[node_id]) + + if depth > max_depth_seen: + max_depth_seen = depth while not builder_stack.empty(): stack_record = builder_stack.top() @@ -272,7 +501,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): if first: impurity = splitter.node_impurity() - first = 0 + first=0 # impurity == 0 with tolerance due to rounding errors is_leaf = is_leaf or impurity <= EPSILON @@ -388,6 +617,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): if rc == -1: raise MemoryError() + self.initial_roots = None # Best first builder ---------------------------------------------------------- cdef struct FrontierRecord: @@ -441,6 +671,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): SIZE_t max_leaf_nodes, double min_impurity_decrease, unsigned char store_leaf_values=False, + cnp.ndarray initial_roots=None, ): self.splitter = splitter self.min_samples_split = min_samples_split @@ -450,6 +681,17 @@ cdef class BestFirstTreeBuilder(TreeBuilder): self.max_leaf_nodes = max_leaf_nodes self.min_impurity_decrease = min_impurity_decrease self.store_leaf_values = store_leaf_values + self.initial_roots = initial_roots + + def __reduce__(self): + """Reduce re-implementation, for pickling.""" + return(BestFirstTreeBuilder, (self.splitter, self.min_samples_split, + self.min_samples_leaf, + self.min_weight_leaf, self.max_depth, + self.max_leaf_nodes, + 
self.min_impurity_decrease, + self.store_leaf_values, + self.initial_roots)) cpdef build( self, @@ -896,6 +1138,49 @@ cdef class BaseTree: return node_id + cdef SIZE_t _update_node( + self, + SIZE_t parent, + bint is_left, + bint is_leaf, + SplitRecord* split_node, + double impurity, + SIZE_t n_node_samples, + double weighted_n_node_samples, + unsigned char missing_go_to_left + ) except -1 nogil: + """Update a node on the tree. + + The updated node remains on the same position. + Returns (size_t)(-1) on error. + """ + cdef SIZE_t node_id + if is_left: + node_id = self.nodes[parent].left_child + else: + node_id = self.nodes[parent].right_child + + if node_id >= self.capacity: + if self._resize_c() != 0: + return INTPTR_MAX + + cdef Node* node = &self.nodes[node_id] + node.impurity = impurity + node.n_node_samples = n_node_samples + node.weighted_n_node_samples = weighted_n_node_samples + + if is_leaf: + if self._set_leaf_node(split_node, node) != 1: + with gil: + raise RuntimeError + else: + if self._set_split_node(split_node, node) != 1: + with gil: + raise RuntimeError + node.missing_go_to_left = missing_go_to_left + + return node_id + cpdef cnp.ndarray apply(self, object X): """Finds the terminal region (=leaf node) for each sample in X.""" if issparse(X): From 4e6d809e07a22843bf8376001b5204b4fcb29f01 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Fri, 11 Aug 2023 11:01:08 -0400 Subject: [PATCH 2/7] Fix Signed-off-by: Adam Li --- sklearn/tree/_tree.pxd | 2489 +++------------------------------------- 1 file changed, 162 insertions(+), 2327 deletions(-) diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 9d9763810edb7..7bb8c50ee38d8 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -1,2358 +1,193 @@ -""" -This module gathers tree-based methods, including decision, regression and -randomized trees. Single and multi-output problems are both handled. -""" - # Authors: Gilles Louppe # Peter Prettenhofer # Brian Holt -# Noel Dawe -# Satrajit Gosh -# Joly Arnaud -# Fares Hedayati +# Joel Nothman +# Arnaud Joly +# Jacob Schreiber # Nelson Liu # Haoyin Xu # # License: BSD 3 clause -import copy -import numbers -from abc import ABCMeta, abstractmethod -from math import ceil -from numbers import Integral, Real +# See _tree.pyx for details. import numpy as np -from scipy.sparse import issparse - -from sklearn.base import ( - BaseEstimator, - ClassifierMixin, - MultiOutputMixin, - RegressorMixin, - _fit_context, - clone, - is_classifier, -) -from sklearn.utils import Bunch, check_random_state, compute_sample_weight -from sklearn.utils._param_validation import Hidden, Interval, RealNotInt, StrOptions -from sklearn.utils.multiclass import ( - _check_partial_fit_first_call, - check_classification_targets, -) -from sklearn.utils.validation import ( - _assert_all_finite_element_wise, - _check_sample_weight, - assert_all_finite, - check_is_fitted, -) - -from . 
import _criterion, _splitter, _tree -from ._criterion import BaseCriterion -from ._splitter import BaseSplitter -from ._tree import ( - BestFirstTreeBuilder, - DepthFirstTreeBuilder, - Tree, - _build_pruned_tree_ccp, - ccp_pruning_path, -) -from ._utils import _any_isnan_axis0 - -__all__ = [ - "DecisionTreeClassifier", - "DecisionTreeRegressor", - "ExtraTreeClassifier", - "ExtraTreeRegressor", -] +cimport numpy as cnp +from libcpp.unordered_map cimport unordered_map +from libcpp.vector cimport vector -# ============================================================================= -# Types and constants -# ============================================================================= - -DTYPE = _tree.DTYPE -DOUBLE = _tree.DOUBLE +ctypedef cnp.npy_float32 DTYPE_t # Type of X +ctypedef cnp.npy_float64 DOUBLE_t # Type of y, sample_weight +ctypedef cnp.npy_intp SIZE_t # Type for indices and counters +ctypedef cnp.npy_int32 INT32_t # Signed 32 bit integer +ctypedef cnp.npy_uint32 UINT32_t # Unsigned 32 bit integer -CRITERIA_CLF = { - "gini": _criterion.Gini, - "log_loss": _criterion.Entropy, - "entropy": _criterion.Entropy, -} -CRITERIA_REG = { - "squared_error": _criterion.MSE, - "friedman_mse": _criterion.FriedmanMSE, - "absolute_error": _criterion.MAE, - "poisson": _criterion.Poisson, -} +from ._splitter cimport SplitRecord, Splitter -DENSE_SPLITTERS = {"best": _splitter.BestSplitter, "random": _splitter.RandomSplitter} -SPARSE_SPLITTERS = { - "best": _splitter.BestSparseSplitter, - "random": _splitter.RandomSparseSplitter, -} +cdef struct Node: + # Base storage structure for the nodes in a Tree object -# ============================================================================= -# Base decision tree -# ============================================================================= + SIZE_t left_child # id of the left child of the node + SIZE_t right_child # id of the right child of the node + SIZE_t feature # Feature used for splitting the node + DOUBLE_t threshold # Threshold value at the node + DOUBLE_t impurity # Impurity of the node (i.e., the value of the criterion) + SIZE_t n_node_samples # Number of samples at the node + DOUBLE_t weighted_n_node_samples # Weighted number of samples at the node + unsigned char missing_go_to_left # Whether features have missing values -class BaseDecisionTree(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta): - """Base class for decision trees. +cdef class BaseTree: + # Inner structures: values are stored separately from node structure, + # since size is determined at runtime. + cdef public SIZE_t max_depth # Max depth of the tree + cdef public SIZE_t node_count # Counter for node IDs + cdef public SIZE_t capacity # Capacity of tree, in terms of nodes + cdef Node* nodes # Array of nodes - Warning: This class should not be used directly. - Use derived classes instead. 
- """ + cdef SIZE_t value_stride # The dimensionality of a vectorized output per sample + cdef double* value # Array of values prediction values for each node - _parameter_constraints: dict = { - "splitter": [StrOptions({"best", "random"})], - "max_depth": [Interval(Integral, 1, None, closed="left"), None], - "min_samples_split": [ - Interval(Integral, 2, None, closed="left"), - Interval(RealNotInt, 0.0, 1.0, closed="right"), - ], - "min_samples_leaf": [ - Interval(Integral, 1, None, closed="left"), - Interval(RealNotInt, 0.0, 1.0, closed="neither"), - ], - "min_weight_fraction_leaf": [Interval(Real, 0.0, 0.5, closed="both")], - "max_features": [ - Interval(Integral, 1, None, closed="left"), - Interval(RealNotInt, 0.0, 1.0, closed="right"), - StrOptions({"sqrt", "log2"}), - None, - ], - "random_state": ["random_state"], - "max_leaf_nodes": [Interval(Integral, 2, None, closed="left"), None], - "min_impurity_decrease": [Interval(Real, 0.0, None, closed="left")], - "ccp_alpha": [Interval(Real, 0.0, None, closed="left")], - "store_leaf_values": ["boolean"], - "monotonic_cst": ["array-like", None], - } + # Generic Methods: These are generic methods used by any tree. + cdef int _resize(self, SIZE_t capacity) except -1 nogil + cdef int _resize_c(self, SIZE_t capacity=*) except -1 nogil - @abstractmethod - def __init__( + cdef SIZE_t _add_node( self, - *, - criterion, - splitter, - max_depth, - min_samples_split, - min_samples_leaf, - min_weight_fraction_leaf, - max_features, - max_leaf_nodes, - random_state, - min_impurity_decrease, - class_weight=None, - ccp_alpha=0.0, - store_leaf_values=False, - monotonic_cst=None, - ): - self.criterion = criterion - self.splitter = splitter - self.max_depth = max_depth - self.min_samples_split = min_samples_split - self.min_samples_leaf = min_samples_leaf - self.min_weight_fraction_leaf = min_weight_fraction_leaf - self.max_features = max_features - self.max_leaf_nodes = max_leaf_nodes - self.random_state = random_state - self.min_impurity_decrease = min_impurity_decrease - self.class_weight = class_weight - self.ccp_alpha = ccp_alpha - self.store_leaf_values = store_leaf_values - self.monotonic_cst = monotonic_cst - - def get_depth(self): - """Return the depth of the decision tree. - - The depth of a tree is the maximum distance between the root - and any leaf. - - Returns - ------- - self.tree_.max_depth : int - The maximum depth of the tree. - """ - check_is_fitted(self) - return self.tree_.max_depth - - def get_n_leaves(self): - """Return the number of leaves of the decision tree. - - Returns - ------- - self.tree_.n_leaves : int - Number of leaves. - """ - check_is_fitted(self) - return self.tree_.n_leaves - - def _support_missing_values(self, X): - return ( - not issparse(X) - and self._get_tags()["allow_nan"] - and self.monotonic_cst is None - ) - - def _compute_missing_values_in_feature_mask(self, X, estimator_name=None): - """Return boolean mask denoting if there are missing values for each feature. - - This method also ensures that X is finite. - - Parameter - --------- - X : array-like of shape (n_samples, n_features), dtype=DOUBLE - Input data. - - estimator_name : str or None, default=None - Name to use when raising an error. Defaults to the class name. - - Returns - ------- - missing_values_in_feature_mask : ndarray of shape (n_features,), or None - Missing value mask. If missing values are not supported or there - are no missing values, return None. 
- """ - estimator_name = estimator_name or self.__class__.__name__ - common_kwargs = dict(estimator_name=estimator_name, input_name="X") - - if not self._support_missing_values(X): - assert_all_finite(X, **common_kwargs) - return None - - with np.errstate(over="ignore"): - overall_sum = np.sum(X) - - if not np.isfinite(overall_sum): - # Raise a ValueError in case of the presence of an infinite element. - _assert_all_finite_element_wise(X, xp=np, allow_nan=True, **common_kwargs) - - # If the sum is not nan, then there are no missing values - if not np.isnan(overall_sum): - return None - - missing_values_in_feature_mask = _any_isnan_axis0(X) - return missing_values_in_feature_mask - - def _fit( + SIZE_t parent, + bint is_left, + bint is_leaf, + SplitRecord* split_node, + double impurity, + SIZE_t n_node_samples, + double weighted_n_node_samples, + unsigned char missing_go_to_left + ) except -1 nogil + cdef SIZE_t _update_node( self, - X, - y, - classes=None, - sample_weight=None, - check_input=True, - missing_values_in_feature_mask=None, - ): - random_state = check_random_state(self.random_state) - - if check_input: - # Need to validate separately here. - # We can't pass multi_output=True because that would allow y to be - # csr. - - # _compute_missing_values_in_feature_mask will check for finite values and - # compute the missing mask if the tree supports missing values - check_X_params = dict( - dtype=DTYPE, accept_sparse="csc", force_all_finite=False - ) - check_y_params = dict(ensure_2d=False, dtype=None) - if y is not None or self._get_tags()["requires_y"]: - X, y = self._validate_data( - X, y, validate_separately=(check_X_params, check_y_params) - ) - else: - X = self._validate_data(X, **check_X_params) - - missing_values_in_feature_mask = ( - self._compute_missing_values_in_feature_mask(X) - ) - if issparse(X): - X.sort_indices() - - if X.indices.dtype != np.intc or X.indptr.dtype != np.intc: - raise ValueError( - "No support for np.int64 index based sparse matrices" - ) - - if y is not None and self.criterion == "poisson": - if np.any(y < 0): - raise ValueError( - "Some value(s) of y are negative which is" - " not allowed for Poisson regression." - ) - if np.sum(y) <= 0: - raise ValueError( - "Sum of y is not positive which is " - "necessary for Poisson regression." - ) - - # Determine output settings - n_samples, self.n_features_in_ = X.shape - - # Do preprocessing if 'y' is passed - is_classification = False - if y is not None: - is_classification = is_classifier(self) - y = np.atleast_1d(y) - expanded_class_weight = None - - if y.ndim == 1: - # reshape is necessary to preserve the data contiguity against vs - # [:, np.newaxis] that does not. 
- y = np.reshape(y, (-1, 1)) - - self.n_outputs_ = y.shape[1] - - if is_classification: - check_classification_targets(y) - y = np.copy(y) - - self.classes_ = [] - self.n_classes_ = [] - - if self.class_weight is not None: - y_original = np.copy(y) - - y_encoded = np.zeros(y.shape, dtype=int) - if classes is not None: - classes = np.atleast_1d(classes) - if classes.ndim == 1: - classes = np.array([classes]) - - for k in classes: - self.classes_.append(np.array(k)) - self.n_classes_.append(np.array(k).shape[0]) - - for i in range(n_samples): - for j in range(self.n_outputs_): - y_encoded[i, j] = np.where(self.classes_[j] == y[i, j])[0][ - 0 - ] - else: - for k in range(self.n_outputs_): - classes_k, y_encoded[:, k] = np.unique( - y[:, k], return_inverse=True - ) - self.classes_.append(classes_k) - self.n_classes_.append(classes_k.shape[0]) - - y = y_encoded - - if self.class_weight is not None: - expanded_class_weight = compute_sample_weight( - self.class_weight, y_original - ) - - self.n_classes_ = np.array(self.n_classes_, dtype=np.intp) - - if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous: - y = np.ascontiguousarray(y, dtype=DOUBLE) - - if len(y) != n_samples: - raise ValueError( - "Number of labels=%d does not match number of samples=%d" - % (len(y), n_samples) - ) - - # set decision-tree model parameters - max_depth = np.iinfo(np.int32).max if self.max_depth is None else self.max_depth - - if isinstance(self.min_samples_leaf, numbers.Integral): - min_samples_leaf = self.min_samples_leaf - else: # float - min_samples_leaf = int(ceil(self.min_samples_leaf * n_samples)) - - if isinstance(self.min_samples_split, numbers.Integral): - min_samples_split = self.min_samples_split - else: # float - min_samples_split = int(ceil(self.min_samples_split * n_samples)) - min_samples_split = max(2, min_samples_split) - - min_samples_split = max(min_samples_split, 2 * min_samples_leaf) - - if isinstance(self.max_features, str): - if self.max_features == "auto": - if is_classification: - max_features = max(1, int(np.sqrt(self.n_features_in_))) - else: - max_features = self.n_features_in_ - elif self.max_features == "sqrt": - max_features = max(1, int(np.sqrt(self.n_features_in_))) - elif self.max_features == "log2": - max_features = max(1, int(np.log2(self.n_features_in_))) - elif self.max_features is None: - max_features = self.n_features_in_ - elif isinstance(self.max_features, numbers.Integral): - max_features = self.max_features - else: # float - if self.max_features > 0.0: - max_features = max(1, int(self.max_features * self.n_features_in_)) - else: - max_features = 0 - - self.max_features_ = max_features - - max_leaf_nodes = -1 if self.max_leaf_nodes is None else self.max_leaf_nodes - - if sample_weight is not None: - sample_weight = _check_sample_weight(sample_weight, X, DOUBLE) - - if y is not None and expanded_class_weight is not None: - if sample_weight is not None: - sample_weight = sample_weight * expanded_class_weight - else: - sample_weight = expanded_class_weight - - # Set min_weight_leaf from min_weight_fraction_leaf - if sample_weight is None: - min_weight_leaf = self.min_weight_fraction_leaf * n_samples - else: - min_weight_leaf = self.min_weight_fraction_leaf * np.sum(sample_weight) - - # build the actual tree now with the parameters - self._build_tree( - X, - y, - sample_weight, - missing_values_in_feature_mask, - min_samples_leaf, - min_weight_leaf, - max_leaf_nodes, - min_samples_split, - max_depth, - random_state, - ) - - return self - - def _build_tree( - self, - 
X, - y, - sample_weight, - missing_values_in_feature_mask, - min_samples_leaf, - min_weight_leaf, - max_leaf_nodes, - min_samples_split, - max_depth, - random_state, - ): - """Build the actual tree. - - Parameters - ---------- - X : Array-like - X dataset. - y : Array-like - Y targets. - sample_weight : Array-like - Sample weights - min_samples_leaf : float - Number of samples required to be a leaf. - min_weight_leaf : float - Weight of samples required to be a leaf. - max_leaf_nodes : float - Maximum number of leaf nodes allowed in tree. - min_samples_split : float - Minimum number of samples to split on. - max_depth : int - The maximum depth of any tree. - random_state : int - Random seed. - """ - - n_samples = X.shape[0] - - # Build tree - criterion = self.criterion - if not isinstance(criterion, BaseCriterion): - if is_classifier(self): - criterion = CRITERIA_CLF[self.criterion]( - self.n_outputs_, self.n_classes_ - ) - else: - criterion = CRITERIA_REG[self.criterion](self.n_outputs_, n_samples) - else: - # Make a deepcopy in case the criterion has mutable attributes that - # might be shared and modified concurrently during parallel fitting - criterion = copy.deepcopy(criterion) - - SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS - - if self.monotonic_cst is None: - monotonic_cst = None - else: - if self.n_outputs_ > 1: - raise ValueError( - "Monotonicity constraints are not supported with multiple outputs." - ) - # Check to correct monotonicity constraint' specification, - # by applying element-wise logical conjunction - # Note: we do not cast `np.asarray(self.monotonic_cst, dtype=np.int8)` - # straight away here so as to generate error messages for invalid - # values using the original values prior to any dtype related conversion. - monotonic_cst = np.asarray(self.monotonic_cst) - if monotonic_cst.shape[0] != X.shape[1]: - raise ValueError( - "monotonic_cst has shape {} but the input data " - "X has {} features.".format(monotonic_cst.shape[0], X.shape[1]) - ) - valid_constraints = np.isin(monotonic_cst, (-1, 0, 1)) - if not np.all(valid_constraints): - unique_constaints_value = np.unique(monotonic_cst) - raise ValueError( - "monotonic_cst must be None or an array-like of -1, 0 or 1, but" - f" got {unique_constaints_value}" - ) - monotonic_cst = np.asarray(monotonic_cst, dtype=np.int8) - if is_classifier(self): - if self.n_classes_[0] > 2: - raise ValueError( - "Monotonicity constraints are not supported with multiclass " - "classification" - ) - # Binary classification trees are built by constraining probabilities - # of the *negative class* in order to make the implementation similar - # to regression trees. - # Since self.monotonic_cst encodes constraints on probabilities of the - # *positive class*, all signs must be flipped. 
- monotonic_cst *= -1 - - if not isinstance(self.splitter, BaseSplitter): - splitter = SPLITTERS[self.splitter]( - criterion, - self.max_features_, - min_samples_leaf, - min_weight_leaf, - random_state, - monotonic_cst, - ) - - if is_classifier(self): - self.tree_ = Tree(self.n_features_in_, self.n_classes_, self.n_outputs_) - else: - self.tree_ = Tree( - self.n_features_in_, - # TODO: tree shouldn't need this in this case - np.array([1] * self.n_outputs_, dtype=np.intp), - self.n_outputs_, - ) - - # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise - if max_leaf_nodes < 0: - self.builder_ = DepthFirstTreeBuilder( - splitter, - min_samples_split, - min_samples_leaf, - min_weight_leaf, - max_depth, - self.min_impurity_decrease, - self.store_leaf_values, - ) - else: - self.builder_ = BestFirstTreeBuilder( - splitter, - min_samples_split, - min_samples_leaf, - min_weight_leaf, - max_depth, - max_leaf_nodes, - self.min_impurity_decrease, - self.store_leaf_values, - ) - self.builder_.build(self.tree_, X, y, sample_weight, missing_values_in_feature_mask) - - if self.n_outputs_ == 1 and is_classifier(self): - self.n_classes_ = self.n_classes_[0] - self.classes_ = self.classes_[0] - - self._prune_tree() - - def _validate_X_predict(self, X, check_input): - """Validate the training data on predict (probabilities).""" - if check_input: - if self._support_missing_values(X): - force_all_finite = "allow-nan" - else: - force_all_finite = True - X = self._validate_data( - X, - dtype=DTYPE, - accept_sparse="csr", - reset=False, - force_all_finite=force_all_finite, - ) - if issparse(X) and ( - X.indices.dtype != np.intc or X.indptr.dtype != np.intc - ): - raise ValueError("No support for np.int64 index based sparse matrices") - else: - # The number of features is checked regardless of `check_input` - self._check_n_features(X, reset=False) - return X - - def predict(self, X, check_input=True): - """Predict class or regression value for X. - - For a classification model, the predicted class for each sample in X is - returned. For a regression model, the predicted value based on X is - returned. - - Parameters - ---------- - X : {array-like, sparse matrix} of shape (n_samples, n_features) - The input samples. Internally, it will be converted to - ``dtype=np.float32`` and if a sparse matrix is provided - to a sparse ``csr_matrix``. - - check_input : bool, default=True - Allow to bypass several input checking. - Don't use this parameter unless you know what you're doing. - - Returns - ------- - y : array-like of shape (n_samples,) or (n_samples, n_outputs) - The predicted classes, or the predict values. - """ - check_is_fitted(self) - X = self._validate_X_predict(X, check_input) - - # proba is a count matrix of leaves that fall into - # (n_samples, n_outputs, max_n_classes) array - proba = self.tree_.predict(X) - n_samples = X.shape[0] - - # Classification - if is_classifier(self): - if self.n_outputs_ == 1: - return self.classes_.take(np.argmax(proba, axis=1), axis=0) - - else: - class_type = self.classes_[0].dtype - predictions = np.zeros((n_samples, self.n_outputs_), dtype=class_type) - for k in range(self.n_outputs_): - predictions[:, k] = self.classes_[k].take( - np.argmax(proba[:, k], axis=1), axis=0 - ) - - return predictions - - # Regression - else: - if self.n_outputs_ == 1: - return proba[:, 0] - - else: - return proba[:, :, 0] - - def get_leaf_node_samples(self, X, check_input=True): - """For each datapoint x in X, get the training samples in the leaf node. 
- - Parameters - ---------- - X : array-like of shape (n_samples, n_features) - Dataset to apply the forest to. - check_input : bool, default=True - Allow to bypass several input checking. - - Returns - ------- - leaf_nodes_samples : a list of array-like of length (n_samples,) - Each sample is represented by the indices of the training samples that - reached the leaf node. The ``n_leaf_node_samples`` may vary between - samples, since the number of samples that fall in a leaf node is - variable. Each array has shape (n_leaf_node_samples, n_outputs). - """ - if not self.store_leaf_values: - raise RuntimeError( - "leaf node samples are not stored when store_leaf_values=False" - ) - - # get indices of leaves per sample (n_samples,) - X_leaves = self.apply(X, check_input=check_input) - n_samples = X_leaves.shape[0] - - # get array of samples per leaf (n_node_samples, n_outputs) - leaf_samples = self.tree_.leaf_nodes_samples - - leaf_nodes_samples = [] - for idx in range(n_samples): - leaf_id = X_leaves[idx] - leaf_nodes_samples.append(leaf_samples[leaf_id]) - return leaf_nodes_samples - - def predict_quantiles(self, X, quantiles=0.5, method="nearest", check_input=True): - """Predict class or regression value for X at given quantiles. - - Parameters - ---------- - X : {array-like, sparse matrix} of shape (n_samples, n_features) - Input data. - quantiles : float, optional - The quantiles at which to evaluate, by default 0.5 (median). - method : str, optional - The method to interpolate, by default 'linear'. Can be any keyword - argument accepted by :func:`~np.quantile`. - check_input : bool, optional - Whether or not to check input, by default True. - - Returns - ------- - predictions : array-like of shape (n_samples, n_outputs, len(quantiles)) - The predicted quantiles. - """ - if not self.store_leaf_values: - raise RuntimeError( - "Predicting quantiles requires that the tree stores leaf node samples." - ) - - check_is_fitted(self) - - # Check data - X = self._validate_X_predict(X, check_input) - - if not isinstance(quantiles, (np.ndarray, list)): - quantiles = np.array([quantiles]) - - # get indices of leaves per sample - X_leaves = self.apply(X) - - # get array of samples per leaf (n_node_samples, n_outputs) - leaf_samples = self.tree_.leaf_nodes_samples - - # compute quantiles (n_samples, n_quantiles, n_outputs) - n_samples = X.shape[0] - n_quantiles = len(quantiles) - proba = np.zeros((n_samples, n_quantiles, self.n_outputs_)) - for idx, leaf_id in enumerate(X_leaves): - # predict by taking the quantile across the samples in the leaf for - # each output - try: - proba[idx, ...] = np.quantile( - leaf_samples[leaf_id], quantiles, axis=0, method=method - ) - except TypeError: - proba[idx, ...] 
= np.quantile( - leaf_samples[leaf_id], quantiles, axis=0, interpolation=method - ) - - # Classification - if is_classifier(self): - if self.n_outputs_ == 1: - # return the class with the highest probability for each quantile - # (n_samples, n_quantiles) - class_preds = np.zeros( - (n_samples, n_quantiles), dtype=self.classes_.dtype - ) - for i in range(n_quantiles): - class_pred_per_sample = ( - proba[:, i, :].squeeze().astype(self.classes_.dtype) - ) - class_preds[:, i] = self.classes_.take( - class_pred_per_sample, axis=0 - ) - return class_preds - else: - class_type = self.classes_[0].dtype - predictions = np.zeros( - (n_samples, n_quantiles, self.n_outputs_), dtype=class_type - ) - for k in range(self.n_outputs_): - for i in range(n_quantiles): - class_pred_per_sample = proba[:, i, k].squeeze().astype(int) - predictions[:, i, k] = self.classes_[k].take( - class_pred_per_sample, axis=0 - ) - - return predictions - # Regression - else: - if self.n_outputs_ == 1: - return proba[:, :, 0] - - else: - return proba - - def apply(self, X, check_input=True): - """Return the index of the leaf that each sample is predicted as. - - .. versionadded:: 0.17 - - Parameters - ---------- - X : {array-like, sparse matrix} of shape (n_samples, n_features) - The input samples. Internally, it will be converted to - ``dtype=np.float32`` and if a sparse matrix is provided - to a sparse ``csr_matrix``. - - check_input : bool, default=True - Allow to bypass several input checking. - Don't use this parameter unless you know what you're doing. - - Returns - ------- - X_leaves : array-like of shape (n_samples,) - For each datapoint x in X, return the index of the leaf x - ends up in. Leaves are numbered within - ``[0; self.tree_.node_count)``, possibly with gaps in the - numbering. - """ - check_is_fitted(self) - X = self._validate_X_predict(X, check_input) - return self.tree_.apply(X) - - def decision_path(self, X, check_input=True): - """Return the decision path in the tree. - - .. versionadded:: 0.18 - - Parameters - ---------- - X : {array-like, sparse matrix} of shape (n_samples, n_features) - The input samples. Internally, it will be converted to - ``dtype=np.float32`` and if a sparse matrix is provided - to a sparse ``csr_matrix``. - - check_input : bool, default=True - Allow to bypass several input checking. - Don't use this parameter unless you know what you're doing. - - Returns - ------- - indicator : sparse matrix of shape (n_samples, n_nodes) - Return a node indicator CSR matrix where non zero elements - indicates that the samples goes through the nodes. - """ - X = self._validate_X_predict(X, check_input) - return self.tree_.decision_path(X) - - def _prune_tree(self): - """Prune tree using Minimal Cost-Complexity Pruning.""" - check_is_fitted(self) - - if self.ccp_alpha == 0.0: - return - - # build pruned tree - if is_classifier(self): - n_classes = np.atleast_1d(self.n_classes_) - pruned_tree = Tree(self.n_features_in_, n_classes, self.n_outputs_) - else: - pruned_tree = Tree( - self.n_features_in_, - # TODO: the tree shouldn't need this param - np.array([1] * self.n_outputs_, dtype=np.intp), - self.n_outputs_, - ) - _build_pruned_tree_ccp(pruned_tree, self.tree_, self.ccp_alpha) - - self.tree_ = pruned_tree - - def cost_complexity_pruning_path(self, X, y, sample_weight=None): - """Compute the pruning path during Minimal Cost-Complexity Pruning. - - See :ref:`minimal_cost_complexity_pruning` for details on the pruning - process. 
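Both ``get_leaf_node_samples`` and ``predict_quantiles`` above read back the training samples stored per leaf, so they require an estimator fitted with ``store_leaf_values=True``. A minimal usage sketch, assuming the patched ``DecisionTreeRegressor`` from this branch is importable; the data and parameter values are illustrative only:

    import numpy as np
    from sklearn.tree import DecisionTreeRegressor  # patched estimator from this branch

    rng = np.random.RandomState(0)
    X = rng.normal(size=(200, 3))
    y = X[:, 0] + 0.1 * rng.normal(size=200)

    # store_leaf_values=True makes the builder record, for every leaf, the
    # training samples that reached it; both methods below rely on that store.
    reg = DecisionTreeRegressor(
        min_samples_leaf=10, store_leaf_values=True, random_state=0
    ).fit(X, y)

    X_test = rng.normal(size=(5, 3))
    leaf_samples = reg.get_leaf_node_samples(X_test)  # one (n_leaf_samples, n_outputs) array per row of X_test
    quantile_pred = reg.predict_quantiles(X_test, quantiles=[0.1, 0.5, 0.9])
    print(len(leaf_samples), quantile_pred.shape)  # 5 and (5, 3): one column per requested quantile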
- - Parameters - ---------- - X : {array-like, sparse matrix} of shape (n_samples, n_features) - The training input samples. Internally, it will be converted to - ``dtype=np.float32`` and if a sparse matrix is provided - to a sparse ``csc_matrix``. - - y : array-like of shape (n_samples,) or (n_samples, n_outputs) - The target values (class labels) as integers or strings. - - sample_weight : array-like of shape (n_samples,), default=None - Sample weights. If None, then samples are equally weighted. Splits - that would create child nodes with net zero or negative weight are - ignored while searching for a split in each node. Splits are also - ignored if they would result in any single class carrying a - negative weight in either child node. - - Returns - ------- - ccp_path : :class:`~sklearn.utils.Bunch` - Dictionary-like object, with the following attributes. - - ccp_alphas : ndarray - Effective alphas of subtree during pruning. - - impurities : ndarray - Sum of the impurities of the subtree leaves for the - corresponding alpha value in ``ccp_alphas``. - """ - est = clone(self).set_params(ccp_alpha=0.0) - est.fit(X, y, sample_weight=sample_weight) - return Bunch(**ccp_pruning_path(est.tree_)) - - @property - def feature_importances_(self): - """Return the feature importances. - - The importance of a feature is computed as the (normalized) total - reduction of the criterion brought by that feature. - It is also known as the Gini importance. - - Warning: impurity-based feature importances can be misleading for - high cardinality features (many unique values). See - :func:`sklearn.inspection.permutation_importance` as an alternative. - - Returns - ------- - feature_importances_ : ndarray of shape (n_features,) - Normalized total reduction of criteria by feature - (Gini importance). - """ - check_is_fitted(self) - - return self.tree_.compute_feature_importances() - - -# ============================================================================= -# Public estimators -# ============================================================================= - - -class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): - """A decision tree classifier. - - Read more in the :ref:`User Guide `. - - Parameters - ---------- - criterion : {"gini", "entropy", "log_loss"}, default="gini" - The function to measure the quality of a split. Supported criteria are - "gini" for the Gini impurity and "log_loss" and "entropy" both for the - Shannon information gain, see :ref:`tree_mathematical_formulation`. - - splitter : {"best", "random"}, default="best" - The strategy used to choose the split at each node. Supported - strategies are "best" to choose the best split and "random" to choose - the best random split. - - max_depth : int, default=None - The maximum depth of the tree. If None, then nodes are expanded until - all leaves are pure or until all leaves contain less than - min_samples_split samples. - - min_samples_split : int or float, default=2 - The minimum number of samples required to split an internal node: - - - If int, then consider `min_samples_split` as the minimum number. - - If float, then `min_samples_split` is a fraction and - `ceil(min_samples_split * n_samples)` are the minimum - number of samples for each split. - - .. versionchanged:: 0.18 - Added float values for fractions. - - min_samples_leaf : int or float, default=1 - The minimum number of samples required to be at a leaf node. 
- A split point at any depth will only be considered if it leaves at - least ``min_samples_leaf`` training samples in each of the left and - right branches. This may have the effect of smoothing the model, - especially in regression. - - - If int, then consider `min_samples_leaf` as the minimum number. - - If float, then `min_samples_leaf` is a fraction and - `ceil(min_samples_leaf * n_samples)` are the minimum - number of samples for each node. - - .. versionchanged:: 0.18 - Added float values for fractions. - - min_weight_fraction_leaf : float, default=0.0 - The minimum weighted fraction of the sum total of weights (of all - the input samples) required to be at a leaf node. Samples have - equal weight when sample_weight is not provided. - - max_features : int, float or {"auto", "sqrt", "log2"}, default=None - The number of features to consider when looking for the best split: - - - If int, then consider `max_features` features at each split. - - If float, then `max_features` is a fraction and - `max(1, int(max_features * n_features_in_))` features are considered at - each split. - - If "sqrt", then `max_features=sqrt(n_features)`. - - If "log2", then `max_features=log2(n_features)`. - - If None, then `max_features=n_features`. - - Note: the search for a split does not stop until at least one - valid partition of the node samples is found, even if it requires to - effectively inspect more than ``max_features`` features. - - random_state : int, RandomState instance or None, default=None - Controls the randomness of the estimator. The features are always - randomly permuted at each split, even if ``splitter`` is set to - ``"best"``. When ``max_features < n_features``, the algorithm will - select ``max_features`` at random at each split before finding the best - split among them. But the best found split may vary across different - runs, even if ``max_features=n_features``. That is the case, if the - improvement of the criterion is identical for several splits and one - split has to be selected at random. To obtain a deterministic behaviour - during fitting, ``random_state`` has to be fixed to an integer. - See :term:`Glossary ` for details. - - max_leaf_nodes : int, default=None - Grow a tree with ``max_leaf_nodes`` in best-first fashion. - Best nodes are defined as relative reduction in impurity. - If None then unlimited number of leaf nodes. - - min_impurity_decrease : float, default=0.0 - A node will be split if this split induces a decrease of the impurity - greater than or equal to this value. - - The weighted impurity decrease equation is the following:: - - N_t / N * (impurity - N_t_R / N_t * right_impurity - - N_t_L / N_t * left_impurity) - - where ``N`` is the total number of samples, ``N_t`` is the number of - samples at the current node, ``N_t_L`` is the number of samples in the - left child, and ``N_t_R`` is the number of samples in the right child. - - ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum, - if ``sample_weight`` is passed. - - .. versionadded:: 0.19 - - class_weight : dict, list of dict or "balanced", default=None - Weights associated with classes in the form ``{class_label: weight}``. - If None, all classes are supposed to have weight one. For - multi-output problems, a list of dicts can be provided in the same - order as the columns of y. - - Note that for multioutput (including multilabel) weights should be - defined for each class of every column in its own dict. 
For example, - for four-class multilabel classification weights should be - [{0: 1, 1: 1}, {0: 1, 1: 5}, {0: 1, 1: 1}, {0: 1, 1: 1}] instead of - [{1:1}, {2:5}, {3:1}, {4:1}]. - - The "balanced" mode uses the values of y to automatically adjust - weights inversely proportional to class frequencies in the input data - as ``n_samples / (n_classes * np.bincount(y))`` - - For multi-output, the weights of each column of y will be multiplied. - - Note that these weights will be multiplied with sample_weight (passed - through the fit method) if sample_weight is specified. - - ccp_alpha : non-negative float, default=0.0 - Complexity parameter used for Minimal Cost-Complexity Pruning. The - subtree with the largest cost complexity that is smaller than - ``ccp_alpha`` will be chosen. By default, no pruning is performed. See - :ref:`minimal_cost_complexity_pruning` for details. - - .. versionadded:: 0.22 - - store_leaf_values : bool, default=False - Whether to store the samples that fall into leaves in the ``tree_`` attribute. - Each leaf will store a 2D array corresponding to the samples that fall into it - keyed by node_id. - - XXX: This is currently experimental and may change without notice. - Moreover, it can be improved upon since storing the samples twice is not ideal. - One could instead store the indices in ``y_train`` that fall into each leaf, - which would lower RAM/diskspace usage. - - monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonicity constraint to enforce on each feature. - - 1: monotonic increase - - 0: no constraint - - -1: monotonic decrease - - If monotonic_cst is None, no constraints are applied. - - Monotonicity constraints are not supported for: - - multiclass classifications (i.e. when `n_classes > 2`), - - multioutput classifications (i.e. when `n_outputs_ > 1`), - - classifications trained on data with missing values. - - The constraints hold over the probability of the positive class. - - Read more in the :ref:`User Guide `. - - .. versionadded:: 1.4 - - Attributes - ---------- - classes_ : ndarray of shape (n_classes,) or list of ndarray - The classes labels (single output problem), - or a list of arrays of class labels (multi-output problem). - - feature_importances_ : ndarray of shape (n_features,) - The impurity-based feature importances. - The higher, the more important the feature. - The importance of a feature is computed as the (normalized) - total reduction of the criterion brought by that feature. It is also - known as the Gini importance [4]_. - - Warning: impurity-based feature importances can be misleading for - high cardinality features (many unique values). See - :func:`sklearn.inspection.permutation_importance` as an alternative. - - max_features_ : int - The inferred value of max_features. - - n_classes_ : int or list of int - The number of classes (for single output problems), - or a list containing the number of classes for each - output (for multi-output problems). - - n_features_in_ : int - Number of features seen during :term:`fit`. - - .. versionadded:: 0.24 - - feature_names_in_ : ndarray of shape (`n_features_in_`,) - Names of features seen during :term:`fit`. Defined only when `X` - has feature names that are all strings. - - .. versionadded:: 1.0 - - n_outputs_ : int - The number of outputs when ``fit`` is performed. - - tree_ : Tree instance - The underlying Tree object. 
Please refer to - ``help(sklearn.tree._tree.Tree)`` for attributes of Tree object and - :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py` - for basic usage of these attributes. - - builder_ : TreeBuilder instance - The underlying TreeBuilder object. - - See Also - -------- - DecisionTreeRegressor : A decision tree regressor. - - Notes - ----- - The default values for the parameters controlling the size of the trees - (e.g. ``max_depth``, ``min_samples_leaf``, etc.) lead to fully grown and - unpruned trees which can potentially be very large on some data sets. To - reduce memory consumption, the complexity and size of the trees should be - controlled by setting those parameter values. - - The :meth:`predict` method operates using the :func:`numpy.argmax` - function on the outputs of :meth:`predict_proba`. This means that in - case the highest predicted probabilities are tied, the classifier will - predict the tied class with the lowest index in :term:`classes_`. - - References - ---------- - - .. [1] https://en.wikipedia.org/wiki/Decision_tree_learning - - .. [2] L. Breiman, J. Friedman, R. Olshen, and C. Stone, "Classification - and Regression Trees", Wadsworth, Belmont, CA, 1984. - - .. [3] T. Hastie, R. Tibshirani and J. Friedman. "Elements of Statistical - Learning", Springer, 2009. - - .. [4] L. Breiman, and A. Cutler, "Random Forests", - https://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm - - Examples - -------- - >>> from sklearn.datasets import load_iris - >>> from sklearn.model_selection import cross_val_score - >>> from sklearn.tree import DecisionTreeClassifier - >>> clf = DecisionTreeClassifier(random_state=0) - >>> iris = load_iris() - >>> cross_val_score(clf, iris.data, iris.target, cv=10) - ... # doctest: +SKIP - ... - array([ 1. , 0.93..., 0.86..., 0.93..., 0.93..., - 0.93..., 0.93..., 1. , 0.93..., 1. 
]) - """ - - _parameter_constraints: dict = { - **BaseDecisionTree._parameter_constraints, - "criterion": [ - StrOptions({"gini", "entropy", "log_loss"}), - Hidden(BaseCriterion), - ], - "class_weight": [dict, list, StrOptions({"balanced"}), None], - } - - def __init__( + SIZE_t parent, + bint is_left, + bint is_leaf, + SplitRecord* split_node, + double impurity, + SIZE_t n_node_samples, + double weighted_n_node_samples, + unsigned char missing_go_to_left + ) except -1 nogil + + # Python API methods: These are methods exposed to Python + cpdef cnp.ndarray apply(self, object X) + cdef cnp.ndarray _apply_dense(self, object X) + cdef cnp.ndarray _apply_sparse_csr(self, object X) + + cpdef object decision_path(self, object X) + cdef object _decision_path_dense(self, object X) + cdef object _decision_path_sparse_csr(self, object X) + + cpdef compute_node_depths(self) + cpdef compute_feature_importances(self, normalize=*) + + # Abstract methods: these functions must be implemented by any decision tree + cdef int _set_split_node( self, - *, - criterion="gini", - splitter="best", - max_depth=None, - min_samples_split=2, - min_samples_leaf=1, - min_weight_fraction_leaf=0.0, - max_features=None, - random_state=None, - max_leaf_nodes=None, - min_impurity_decrease=0.0, - class_weight=None, - ccp_alpha=0.0, - store_leaf_values=False, - monotonic_cst=None, - ): - super().__init__( - criterion=criterion, - splitter=splitter, - max_depth=max_depth, - min_samples_split=min_samples_split, - min_samples_leaf=min_samples_leaf, - min_weight_fraction_leaf=min_weight_fraction_leaf, - max_features=max_features, - max_leaf_nodes=max_leaf_nodes, - class_weight=class_weight, - random_state=random_state, - min_impurity_decrease=min_impurity_decrease, - monotonic_cst=monotonic_cst, - ccp_alpha=ccp_alpha, - store_leaf_values=store_leaf_values, - ) - - @_fit_context(prefer_skip_nested_validation=True) - def fit( + SplitRecord* split_node, + Node* node + ) except -1 nogil + cdef int _set_leaf_node( self, - X, - y, - sample_weight=None, - check_input=True, - classes=None, - ): - """Build a decision tree classifier from the training set (X, y). - - Parameters - ---------- - X : {array-like, sparse matrix} of shape (n_samples, n_features) - The training input samples. Internally, it will be converted to - ``dtype=np.float32`` and if a sparse matrix is provided - to a sparse ``csc_matrix``. - - y : array-like of shape (n_samples,) or (n_samples, n_outputs) - The target values (class labels) as integers or strings. - - sample_weight : array-like of shape (n_samples,), default=None - Sample weights. If None, then samples are equally weighted. Splits - that would create child nodes with net zero or negative weight are - ignored while searching for a split in each node. Splits are also - ignored if they would result in any single class carrying a - negative weight in either child node. - - check_input : bool, default=True - Allow to bypass several input checking. - Don't use this parameter unless you know what you're doing. - - classes : array-like of shape (n_classes,), default=None - List of all the classes that can possibly appear in the y vector. - Must be provided at the first call to partial_fit, can be omitted - in subsequent calls. - - Returns - ------- - self : DecisionTreeClassifier - Fitted estimator. 
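The ``classes`` argument documented above lets the label encoding cover classes that do not occur in the current batch of ``y``, which is what makes a later ``partial_fit`` call with new labels possible. A small sketch, assuming the patched ``DecisionTreeClassifier`` from this branch; the data is synthetic:

    import numpy as np
    from sklearn.tree import DecisionTreeClassifier  # patched estimator from this branch

    X = np.array([[0.0], [0.1], [0.9], [1.0]])
    y = np.array([0, 0, 1, 1])  # class 2 is absent from this batch

    clf = DecisionTreeClassifier(random_state=0)
    clf.fit(X, y, classes=[0, 1, 2])  # encode y against the full label set up front

    print(clf.classes_)  # [0 1 2], not only the labels seen in y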
- """ - super()._fit( - X, - y, - sample_weight=sample_weight, - check_input=check_input, - classes=classes, - ) - return self - - def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): - """Update a decision tree classifier from the training set (X, y). - - Parameters - ---------- - X : {array-like, sparse matrix} of shape (n_samples, n_features) - The training input samples. Internally, it will be converted to - ``dtype=np.float32`` and if a sparse matrix is provided - to a sparse ``csc_matrix``. - - y : array-like of shape (n_samples,) or (n_samples, n_outputs) - The target values (class labels) as integers or strings. - - classes : array-like of shape (n_classes,), default=None - List of all the classes that can possibly appear in the y vector. - Must be provided at the first call to partial_fit, can be omitted - in subsequent calls. - - sample_weight : array-like of shape (n_samples,), default=None - Sample weights. If None, then samples are equally weighted. Splits - that would create child nodes with net zero or negative weight are - ignored while searching for a split in each node. Splits are also - ignored if they would result in any single class carrying a - negative weight in either child node. - - check_input : bool, default=True - Allow to bypass several input checking. - Don't use this parameter unless you know what you do. - - Returns - ------- - self : DecisionTreeClassifier - Fitted estimator. - """ - - first_call = _check_partial_fit_first_call(self, classes=classes) - - # Fit if no tree exists yet - if first_call: - self.fit( - X, - y, - sample_weight=sample_weight, - check_input=check_input, - classes=classes, - ) - return self - - if check_input: - # Need to validate separately here. - # We can't pass multi_ouput=True because that would allow y to be - # csr. - check_X_params = dict(dtype=DTYPE, accept_sparse="csc") - check_y_params = dict(ensure_2d=False, dtype=None) - X, y = self._validate_data( - X, y, reset=False, validate_separately=(check_X_params, check_y_params) - ) - if issparse(X): - X.sort_indices() - - if X.indices.dtype != np.intc or X.indptr.dtype != np.intc: - raise ValueError( - "No support for np.int64 index based sparse matrices" - ) - - if X.shape[1] != self.n_features_in_: - msg = "Number of features %d does not match previous data %d." - raise ValueError(msg % (X.shape[1], self.n_features_in_)) - - y = np.atleast_1d(y) - - if y.ndim == 1: - # reshape is necessary to preserve the data contiguity against vs - # [:, np.newaxis] that does not. - y = np.reshape(y, (-1, 1)) - - check_classification_targets(y) - y = np.copy(y) - - classes = self.classes_ - if self.n_outputs_ == 1: - classes = [classes] - - y_encoded = np.zeros(y.shape, dtype=int) - for i in range(X.shape[0]): - for j in range(self.n_outputs_): - y_encoded[i, j] = np.where(classes[j] == y[i, j])[0][0] - y = y_encoded - - if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous: - y = np.ascontiguousarray(y, dtype=DOUBLE) - - # Update tree - self.builder_.initialize_node_queue(self.tree_, X, y, sample_weight) - self.builder_.build(self.tree_, X, y, sample_weight) - - self._prune_tree() - - return self - - def predict_proba(self, X, check_input=True): - """Predict class probabilities of the input samples X. - - The predicted class probability is the fraction of samples of the same - class in a leaf. - - Parameters - ---------- - X : {array-like, sparse matrix} of shape (n_samples, n_features) - The input samples. 
Internally, it will be converted to - ``dtype=np.float32`` and if a sparse matrix is provided - to a sparse ``csr_matrix``. - - check_input : bool, default=True - Allow to bypass several input checking. - Don't use this parameter unless you know what you're doing. - - Returns - ------- - proba : ndarray of shape (n_samples, n_classes) or list of n_outputs \ - such arrays if n_outputs > 1 - The class probabilities of the input samples. The order of the - classes corresponds to that in the attribute :term:`classes_`. - """ - check_is_fitted(self) - X = self._validate_X_predict(X, check_input) - proba = self.tree_.predict(X) - - if self.n_outputs_ == 1: - proba = proba[:, : self.n_classes_] - normalizer = proba.sum(axis=1)[:, np.newaxis] - normalizer[normalizer == 0.0] = 1.0 - proba /= normalizer - - return proba - - else: - all_proba = [] - - for k in range(self.n_outputs_): - proba_k = proba[:, k, : self.n_classes_[k]] - normalizer = proba_k.sum(axis=1)[:, np.newaxis] - normalizer[normalizer == 0.0] = 1.0 - proba_k /= normalizer - all_proba.append(proba_k) - - return all_proba - - def predict_log_proba(self, X): - """Predict class log-probabilities of the input samples X. - - Parameters - ---------- - X : {array-like, sparse matrix} of shape (n_samples, n_features) - The input samples. Internally, it will be converted to - ``dtype=np.float32`` and if a sparse matrix is provided - to a sparse ``csr_matrix``. - - Returns - ------- - proba : ndarray of shape (n_samples, n_classes) or list of n_outputs \ - such arrays if n_outputs > 1 - The class log-probabilities of the input samples. The order of the - classes corresponds to that in the attribute :term:`classes_`. - """ - proba = self.predict_proba(X) - - if self.n_outputs_ == 1: - return np.log(proba) - - else: - for k in range(self.n_outputs_): - proba[k] = np.log(proba[k]) - - return proba - - def _more_tags(self): - # XXX: nan is only support for dense arrays, but we set this for common test to - # pass, specifically: check_estimators_nan_inf - allow_nan = self.splitter == "best" and self.criterion in { - "gini", - "log_loss", - "entropy", - } - return {"multilabel": True, "allow_nan": allow_nan} - - -class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): - """A decision tree regressor. - - Read more in the :ref:`User Guide `. - - Parameters - ---------- - criterion : {"squared_error", "friedman_mse", "absolute_error", \ - "poisson"}, default="squared_error" - The function to measure the quality of a split. Supported criteria - are "squared_error" for the mean squared error, which is equal to - variance reduction as feature selection criterion and minimizes the L2 - loss using the mean of each terminal node, "friedman_mse", which uses - mean squared error with Friedman's improvement score for potential - splits, "absolute_error" for the mean absolute error, which minimizes - the L1 loss using the median of each terminal node, and "poisson" which - uses reduction in Poisson deviance to find splits. - - .. versionadded:: 0.18 - Mean Absolute Error (MAE) criterion. - - .. versionadded:: 0.24 - Poisson deviance criterion. - - splitter : {"best", "random"}, default="best" - The strategy used to choose the split at each node. Supported - strategies are "best" to choose the best split and "random" to choose - the best random split. - - max_depth : int, default=None - The maximum depth of the tree. If None, then nodes are expanded until - all leaves are pure or until all leaves contain less than - min_samples_split samples. 
- - min_samples_split : int or float, default=2 - The minimum number of samples required to split an internal node: - - - If int, then consider `min_samples_split` as the minimum number. - - If float, then `min_samples_split` is a fraction and - `ceil(min_samples_split * n_samples)` are the minimum - number of samples for each split. - - .. versionchanged:: 0.18 - Added float values for fractions. - - min_samples_leaf : int or float, default=1 - The minimum number of samples required to be at a leaf node. - A split point at any depth will only be considered if it leaves at - least ``min_samples_leaf`` training samples in each of the left and - right branches. This may have the effect of smoothing the model, - especially in regression. - - - If int, then consider `min_samples_leaf` as the minimum number. - - If float, then `min_samples_leaf` is a fraction and - `ceil(min_samples_leaf * n_samples)` are the minimum - number of samples for each node. - - .. versionchanged:: 0.18 - Added float values for fractions. - - min_weight_fraction_leaf : float, default=0.0 - The minimum weighted fraction of the sum total of weights (of all - the input samples) required to be at a leaf node. Samples have - equal weight when sample_weight is not provided. - - max_features : int, float or {"auto", "sqrt", "log2"}, default=None - The number of features to consider when looking for the best split: - - - If int, then consider `max_features` features at each split. - - If float, then `max_features` is a fraction and - `max(1, int(max_features * n_features_in_))` features are considered at each - split. - - If "sqrt", then `max_features=sqrt(n_features)`. - - If "log2", then `max_features=log2(n_features)`. - - If None, then `max_features=n_features`. - - Note: the search for a split does not stop until at least one - valid partition of the node samples is found, even if it requires to - effectively inspect more than ``max_features`` features. - - random_state : int, RandomState instance or None, default=None - Controls the randomness of the estimator. The features are always - randomly permuted at each split, even if ``splitter`` is set to - ``"best"``. When ``max_features < n_features``, the algorithm will - select ``max_features`` at random at each split before finding the best - split among them. But the best found split may vary across different - runs, even if ``max_features=n_features``. That is the case, if the - improvement of the criterion is identical for several splits and one - split has to be selected at random. To obtain a deterministic behaviour - during fitting, ``random_state`` has to be fixed to an integer. - See :term:`Glossary ` for details. - - max_leaf_nodes : int, default=None - Grow a tree with ``max_leaf_nodes`` in best-first fashion. - Best nodes are defined as relative reduction in impurity. - If None then unlimited number of leaf nodes. - - min_impurity_decrease : float, default=0.0 - A node will be split if this split induces a decrease of the impurity - greater than or equal to this value. - - The weighted impurity decrease equation is the following:: - - N_t / N * (impurity - N_t_R / N_t * right_impurity - - N_t_L / N_t * left_impurity) - - where ``N`` is the total number of samples, ``N_t`` is the number of - samples at the current node, ``N_t_L`` is the number of samples in the - left child, and ``N_t_R`` is the number of samples in the right child. - - ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum, - if ``sample_weight`` is passed. - - .. 
versionadded:: 0.19 - - ccp_alpha : non-negative float, default=0.0 - Complexity parameter used for Minimal Cost-Complexity Pruning. The - subtree with the largest cost complexity that is smaller than - ``ccp_alpha`` will be chosen. By default, no pruning is performed. See - :ref:`minimal_cost_complexity_pruning` for details. - - .. versionadded:: 0.22 - - store_leaf_values : bool, default=False - Whether to store the samples that fall into leaves in the ``tree_`` attribute. - Each leaf will store a 2D array corresponding to the samples that fall into it - keyed by node_id. - - XXX: This is currently experimental and may change without notice. - Moreover, it can be improved upon since storing the samples twice is not ideal. - One could instead store the indices in ``y_train`` that fall into each leaf, - which would lower RAM/diskspace usage. - - monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonicity constraint to enforce on each feature. - - 1: monotonic increase - - 0: no constraint - - -1: monotonic decrease - - If monotonic_cst is None, no constraints are applied. - - Monotonicity constraints are not supported for: - - multioutput regressions (i.e. when `n_outputs_ > 1`), - - regressions trained on data with missing values. - - Read more in the :ref:`User Guide `. - - .. versionadded:: 1.4 - - Attributes - ---------- - feature_importances_ : ndarray of shape (n_features,) - The feature importances. - The higher, the more important the feature. - The importance of a feature is computed as the - (normalized) total reduction of the criterion brought - by that feature. It is also known as the Gini importance [4]_. - - Warning: impurity-based feature importances can be misleading for - high cardinality features (many unique values). See - :func:`sklearn.inspection.permutation_importance` as an alternative. - - max_features_ : int - The inferred value of max_features. - - n_features_in_ : int - Number of features seen during :term:`fit`. - - .. versionadded:: 0.24 - - feature_names_in_ : ndarray of shape (`n_features_in_`,) - Names of features seen during :term:`fit`. Defined only when `X` - has feature names that are all strings. - - .. versionadded:: 1.0 - - n_outputs_ : int - The number of outputs when ``fit`` is performed. - - tree_ : Tree instance - The underlying Tree object. Please refer to - ``help(sklearn.tree._tree.Tree)`` for attributes of Tree object and - :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py` - for basic usage of these attributes. - - builder_ : TreeBuilder instance - The underlying TreeBuilder object. - - See Also - -------- - DecisionTreeClassifier : A decision tree classifier. - - Notes - ----- - The default values for the parameters controlling the size of the trees - (e.g. ``max_depth``, ``min_samples_leaf``, etc.) lead to fully grown and - unpruned trees which can potentially be very large on some data sets. To - reduce memory consumption, the complexity and size of the trees should be - controlled by setting those parameter values. - - References - ---------- - - .. [1] https://en.wikipedia.org/wiki/Decision_tree_learning - - .. [2] L. Breiman, J. Friedman, R. Olshen, and C. Stone, "Classification - and Regression Trees", Wadsworth, Belmont, CA, 1984. - - .. [3] T. Hastie, R. Tibshirani and J. Friedman. "Elements of Statistical - Learning", Springer, 2009. - - .. [4] L. Breiman, and A. 
Cutler, "Random Forests", - https://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm - - Examples - -------- - >>> from sklearn.datasets import load_diabetes - >>> from sklearn.model_selection import cross_val_score - >>> from sklearn.tree import DecisionTreeRegressor - >>> X, y = load_diabetes(return_X_y=True) - >>> regressor = DecisionTreeRegressor(random_state=0) - >>> cross_val_score(regressor, X, y, cv=10) - ... # doctest: +SKIP - ... - array([-0.39..., -0.46..., 0.02..., 0.06..., -0.50..., - 0.16..., 0.11..., -0.73..., -0.30..., -0.00...]) - """ - - _parameter_constraints: dict = { - **BaseDecisionTree._parameter_constraints, - "criterion": [ - StrOptions({"squared_error", "friedman_mse", "absolute_error", "poisson"}), - Hidden(BaseCriterion), - ], - } - - def __init__( + SplitRecord* split_node, + Node* node + ) except -1 nogil + cdef DTYPE_t _compute_feature( self, - *, - criterion="squared_error", - splitter="best", - max_depth=None, - min_samples_split=2, - min_samples_leaf=1, - min_weight_fraction_leaf=0.0, - max_features=None, - random_state=None, - max_leaf_nodes=None, - min_impurity_decrease=0.0, - ccp_alpha=0.0, - store_leaf_values=False, - monotonic_cst=None, - ): - super().__init__( - criterion=criterion, - splitter=splitter, - max_depth=max_depth, - min_samples_split=min_samples_split, - min_samples_leaf=min_samples_leaf, - min_weight_fraction_leaf=min_weight_fraction_leaf, - max_features=max_features, - max_leaf_nodes=max_leaf_nodes, - random_state=random_state, - min_impurity_decrease=min_impurity_decrease, - ccp_alpha=ccp_alpha, - store_leaf_values=store_leaf_values, - monotonic_cst=monotonic_cst, - ) - - @_fit_context(prefer_skip_nested_validation=True) - def fit( + const DTYPE_t[:, :] X_ndarray, + SIZE_t sample_index, + Node *node + ) noexcept nogil + cdef void _compute_feature_importances( self, - X, - y, - sample_weight=None, - check_input=True, - classes=None, - ): - """Build a decision tree regressor from the training set (X, y). - - Parameters - ---------- - X : {array-like, sparse matrix} of shape (n_samples, n_features) - The training input samples. Internally, it will be converted to - ``dtype=np.float32`` and if a sparse matrix is provided - to a sparse ``csc_matrix``. - - y : array-like of shape (n_samples,) or (n_samples, n_outputs) - The target values (real numbers). Use ``dtype=np.float64`` and - ``order='C'`` for maximum efficiency. - - sample_weight : array-like of shape (n_samples,), default=None - Sample weights. If None, then samples are equally weighted. Splits - that would create child nodes with net zero or negative weight are - ignored while searching for a split in each node. - - check_input : bool, default=True - Allow to bypass several input checking. - Don't use this parameter unless you know what you're doing. - - classes : array-like of shape (n_classes,), default=None - List of all the classes that can possibly appear in the y vector. - - Returns - ------- - self : DecisionTreeRegressor - Fitted estimator. - """ - - super()._fit( - X, - y, - sample_weight=sample_weight, - check_input=check_input, - classes=classes, - ) - return self - - def _compute_partial_dependence_recursion(self, grid, target_features): - """Fast partial dependence computation. - - Parameters - ---------- - grid : ndarray of shape (n_samples, n_target_features) - The grid points on which the partial dependence should be - evaluated. 
- target_features : ndarray of shape (n_target_features) - The set of target features for which the partial dependence - should be evaluated. - - Returns - ------- - averaged_predictions : ndarray of shape (n_samples,) - The value of the partial dependence function on each grid point. - """ - grid = np.asarray(grid, dtype=DTYPE, order="C") - averaged_predictions = np.zeros( - shape=grid.shape[0], dtype=np.float64, order="C" - ) - - self.tree_.compute_partial_dependence( - grid, target_features, averaged_predictions - ) - return averaged_predictions - - def _more_tags(self): - # XXX: nan is only support for dense arrays, but we set this for common test to - # pass, specifically: check_estimators_nan_inf - allow_nan = self.splitter == "best" and self.criterion in { - "squared_error", - "friedman_mse", - "poisson", - } - return {"allow_nan": allow_nan} - - -class ExtraTreeClassifier(DecisionTreeClassifier): - """An extremely randomized tree classifier. - - Extra-trees differ from classic decision trees in the way they are built. - When looking for the best split to separate the samples of a node into two - groups, random splits are drawn for each of the `max_features` randomly - selected features and the best split among those is chosen. When - `max_features` is set 1, this amounts to building a totally random - decision tree. - - Warning: Extra-trees should only be used within ensemble methods. - - Read more in the :ref:`User Guide `. - - Parameters - ---------- - criterion : {"gini", "entropy", "log_loss"}, default="gini" - The function to measure the quality of a split. Supported criteria are - "gini" for the Gini impurity and "log_loss" and "entropy" both for the - Shannon information gain, see :ref:`tree_mathematical_formulation`. - - splitter : {"random", "best"}, default="random" - The strategy used to choose the split at each node. Supported - strategies are "best" to choose the best split and "random" to choose - the best random split. - - max_depth : int, default=None - The maximum depth of the tree. If None, then nodes are expanded until - all leaves are pure or until all leaves contain less than - min_samples_split samples. - - min_samples_split : int or float, default=2 - The minimum number of samples required to split an internal node: + cnp.float64_t[:] importances, + Node* node, + ) noexcept nogil + +cdef class Tree(BaseTree): + # The Supervised Tree object is a binary tree structure constructed by the + # TreeBuilder. The tree structure is used for predictions and + # feature importances. + # + # Value of upstream properties: + # - value_stride = n_outputs * max_n_classes + # - value = (capacity, n_outputs, max_n_classes) array of values + + # Input/Output layout for supervised tree + cdef public SIZE_t n_features # Number of features in X + cdef SIZE_t* n_classes # Number of classes in y[:, k] + cdef public SIZE_t n_outputs # Number of outputs in y + cdef public SIZE_t max_n_classes # max(n_classes) + + # Enables the use of tree to store distributions of the output to allow + # arbitrary usage of the the leaves. This is used in the quantile + # estimators for example. 
+ # for storing samples at each leaf node with leaf's node ID as the key and + # the sample values as the value + cdef unordered_map[SIZE_t, vector[vector[DOUBLE_t]]] value_samples + + # Methods + cdef cnp.ndarray _get_value_ndarray(self) + cdef cnp.ndarray _get_node_ndarray(self) + cdef cnp.ndarray _get_value_samples_ndarray(self, SIZE_t node_id) + cdef cnp.ndarray _get_value_samples_keys(self) + + cpdef cnp.ndarray predict(self, object X) - - If int, then consider `min_samples_split` as the minimum number. - - If float, then `min_samples_split` is a fraction and - `ceil(min_samples_split * n_samples)` are the minimum - number of samples for each split. - - .. versionchanged:: 0.18 - Added float values for fractions. - - min_samples_leaf : int or float, default=1 - The minimum number of samples required to be at a leaf node. - A split point at any depth will only be considered if it leaves at - least ``min_samples_leaf`` training samples in each of the left and - right branches. This may have the effect of smoothing the model, - especially in regression. - - - If int, then consider `min_samples_leaf` as the minimum number. - - If float, then `min_samples_leaf` is a fraction and - `ceil(min_samples_leaf * n_samples)` are the minimum - number of samples for each node. - - .. versionchanged:: 0.18 - Added float values for fractions. - - min_weight_fraction_leaf : float, default=0.0 - The minimum weighted fraction of the sum total of weights (of all - the input samples) required to be at a leaf node. Samples have - equal weight when sample_weight is not provided. - - max_features : int, float, {"auto", "sqrt", "log2"} or None, default="sqrt" - The number of features to consider when looking for the best split: - - - If int, then consider `max_features` features at each split. - - If float, then `max_features` is a fraction and - `max(1, int(max_features * n_features_in_))` features are considered at - each split. - - If "sqrt", then `max_features=sqrt(n_features)`. - - If "log2", then `max_features=log2(n_features)`. - - If None, then `max_features=n_features`. - - .. versionchanged:: 1.1 - The default of `max_features` changed from `"auto"` to `"sqrt"`. - - Note: the search for a split does not stop until at least one - valid partition of the node samples is found, even if it requires to - effectively inspect more than ``max_features`` features. - - random_state : int, RandomState instance or None, default=None - Used to pick randomly the `max_features` used at each split. - See :term:`Glossary ` for details. - - max_leaf_nodes : int, default=None - Grow a tree with ``max_leaf_nodes`` in best-first fashion. - Best nodes are defined as relative reduction in impurity. - If None then unlimited number of leaf nodes. - - min_impurity_decrease : float, default=0.0 - A node will be split if this split induces a decrease of the impurity - greater than or equal to this value. - - The weighted impurity decrease equation is the following:: - - N_t / N * (impurity - N_t_R / N_t * right_impurity - - N_t_L / N_t * left_impurity) - - where ``N`` is the total number of samples, ``N_t`` is the number of - samples at the current node, ``N_t_L`` is the number of samples in the - left child, and ``N_t_R`` is the number of samples in the right child. - - ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum, - if ``sample_weight`` is passed. - - .. 
versionadded:: 0.19 - - class_weight : dict, list of dict or "balanced", default=None - Weights associated with classes in the form ``{class_label: weight}``. - If None, all classes are supposed to have weight one. For - multi-output problems, a list of dicts can be provided in the same - order as the columns of y. - - Note that for multioutput (including multilabel) weights should be - defined for each class of every column in its own dict. For example, - for four-class multilabel classification weights should be - [{0: 1, 1: 1}, {0: 1, 1: 5}, {0: 1, 1: 1}, {0: 1, 1: 1}] instead of - [{1:1}, {2:5}, {3:1}, {4:1}]. - - The "balanced" mode uses the values of y to automatically adjust - weights inversely proportional to class frequencies in the input data - as ``n_samples / (n_classes * np.bincount(y))`` - - For multi-output, the weights of each column of y will be multiplied. - - Note that these weights will be multiplied with sample_weight (passed - through the fit method) if sample_weight is specified. - - ccp_alpha : non-negative float, default=0.0 - Complexity parameter used for Minimal Cost-Complexity Pruning. The - subtree with the largest cost complexity that is smaller than - ``ccp_alpha`` will be chosen. By default, no pruning is performed. See - :ref:`minimal_cost_complexity_pruning` for details. - - .. versionadded:: 0.22 - - store_leaf_values : bool, default=False - Whether to store the samples that fall into leaves in the ``tree_`` attribute. - Each leaf will store a 2D array corresponding to the samples that fall into it - keyed by node_id. - - XXX: This is currently experimental and may change without notice. - Moreover, it can be improved upon since storing the samples twice is not ideal. - One could instead store the indices in ``y_train`` that fall into each leaf, - which would lower RAM/diskspace usage. - - monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonicity constraint to enforce on each feature. - - 1: monotonic increase - - 0: no constraint - - -1: monotonic decrease - - If monotonic_cst is None, no constraints are applied. - - Monotonicity constraints are not supported for: - - multiclass classifications (i.e. when `n_classes > 2`), - - multioutput classifications (i.e. when `n_outputs_ > 1`), - - classifications trained on data with missing values. - - The constraints hold over the probability of the positive class. - - Read more in the :ref:`User Guide `. - - .. versionadded:: 1.4 - - Attributes - ---------- - classes_ : ndarray of shape (n_classes,) or list of ndarray - The classes labels (single output problem), - or a list of arrays of class labels (multi-output problem). - - max_features_ : int - The inferred value of max_features. - - n_classes_ : int or list of int - The number of classes (for single output problems), - or a list containing the number of classes for each - output (for multi-output problems). - - feature_importances_ : ndarray of shape (n_features,) - The impurity-based feature importances. - The higher, the more important the feature. - The importance of a feature is computed as the (normalized) - total reduction of the criterion brought by that feature. It is also - known as the Gini importance. - - Warning: impurity-based feature importances can be misleading for - high cardinality features (many unique values). See - :func:`sklearn.inspection.permutation_importance` as an alternative. - - n_features_in_ : int - Number of features seen during :term:`fit`. - - .. 
versionadded:: 0.24 - - feature_names_in_ : ndarray of shape (`n_features_in_`,) - Names of features seen during :term:`fit`. Defined only when `X` - has feature names that are all strings. - - .. versionadded:: 1.0 - - n_outputs_ : int - The number of outputs when ``fit`` is performed. - - tree_ : Tree instance - The underlying Tree object. Please refer to - ``help(sklearn.tree._tree.Tree)`` for attributes of Tree object and - :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py` - for basic usage of these attributes. - - builder_ : TreeBuilder instance - The underlying TreeBuilder object. - - See Also - -------- - ExtraTreeRegressor : An extremely randomized tree regressor. - sklearn.ensemble.ExtraTreesClassifier : An extra-trees classifier. - sklearn.ensemble.ExtraTreesRegressor : An extra-trees regressor. - sklearn.ensemble.RandomForestClassifier : A random forest classifier. - sklearn.ensemble.RandomForestRegressor : A random forest regressor. - sklearn.ensemble.RandomTreesEmbedding : An ensemble of - totally random trees. - - Notes - ----- - The default values for the parameters controlling the size of the trees - (e.g. ``max_depth``, ``min_samples_leaf``, etc.) lead to fully grown and - unpruned trees which can potentially be very large on some data sets. To - reduce memory consumption, the complexity and size of the trees should be - controlled by setting those parameter values. - - References - ---------- - - .. [1] P. Geurts, D. Ernst., and L. Wehenkel, "Extremely randomized trees", - Machine Learning, 63(1), 3-42, 2006. - - Examples - -------- - >>> from sklearn.datasets import load_iris - >>> from sklearn.model_selection import train_test_split - >>> from sklearn.ensemble import BaggingClassifier - >>> from sklearn.tree import ExtraTreeClassifier - >>> X, y = load_iris(return_X_y=True) - >>> X_train, X_test, y_train, y_test = train_test_split( - ... X, y, random_state=0) - >>> extra_tree = ExtraTreeClassifier(random_state=0) - >>> cls = BaggingClassifier(extra_tree, random_state=0).fit( - ... X_train, y_train) - >>> cls.score(X_test, y_test) - 0.8947... - """ +# ============================================================================= +# Tree builder +# ============================================================================= - def __init__( +cdef class TreeBuilder: + # The TreeBuilder recursively builds a Tree object from training samples, + # using a Splitter object for splitting internal nodes and assigning + # values to leaves. + # + # This class controls the various stopping criteria and the node splitting + # evaluation order, e.g. depth-first or best-first. 
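Whether an estimator ends up with a depth-first or a best-first builder is decided in ``BaseDecisionTree._build_tree`` earlier in this patch ("Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise"), and the chosen builder is now kept on the estimator as ``builder_``. A minimal sketch of observing that from Python, assuming the patched estimators from this branch; data is synthetic:

    import numpy as np
    from sklearn.tree import DecisionTreeClassifier  # patched estimator from this branch

    rng = np.random.RandomState(0)
    X, y = rng.normal(size=(100, 4)), rng.randint(0, 2, size=100)

    # max_leaf_nodes=None -> depth-first growth; an integer -> best-first growth.
    depth_first = DecisionTreeClassifier(random_state=0).fit(X, y)
    best_first = DecisionTreeClassifier(max_leaf_nodes=8, random_state=0).fit(X, y)

    print(type(depth_first.builder_).__name__)  # DepthFirstTreeBuilder
    print(type(best_first.builder_).__name__)   # BestFirstTreeBuilder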
+ + cdef Splitter splitter # Splitting algorithm + + cdef SIZE_t min_samples_split # Minimum number of samples in an internal node + cdef SIZE_t min_samples_leaf # Minimum number of samples in a leaf + cdef double min_weight_leaf # Minimum weight in a leaf + cdef SIZE_t max_depth # Maximal tree depth + cdef double min_impurity_decrease # Impurity threshold for early stopping + cdef object initial_roots # Leaf nodes for streaming updates + + cdef unsigned char store_leaf_values # Whether to store leaf values + + cpdef initialize_node_queue( + self, + Tree tree, + object X, + const DOUBLE_t[:, ::1] y, + const DOUBLE_t[:] sample_weight=*, + const unsigned char[::1] missing_values_in_feature_mask=*, + ) + + cdef unsigned char store_leaf_values # Whether to store leaf values + + cpdef build( self, - *, - criterion="gini", - splitter="random", - max_depth=None, - min_samples_split=2, - min_samples_leaf=1, - min_weight_fraction_leaf=0.0, - max_features="sqrt", - random_state=None, - max_leaf_nodes=None, - min_impurity_decrease=0.0, - class_weight=None, - ccp_alpha=0.0, - store_leaf_values=False, - monotonic_cst=None, - ): - super().__init__( - criterion=criterion, - splitter=splitter, - max_depth=max_depth, - min_samples_split=min_samples_split, - min_samples_leaf=min_samples_leaf, - min_weight_fraction_leaf=min_weight_fraction_leaf, - max_features=max_features, - max_leaf_nodes=max_leaf_nodes, - class_weight=class_weight, - min_impurity_decrease=min_impurity_decrease, - random_state=random_state, - ccp_alpha=ccp_alpha, - store_leaf_values=store_leaf_values, - monotonic_cst=monotonic_cst, - ) - - -class ExtraTreeRegressor(DecisionTreeRegressor): - """An extremely randomized tree regressor. - - Extra-trees differ from classic decision trees in the way they are built. - When looking for the best split to separate the samples of a node into two - groups, random splits are drawn for each of the `max_features` randomly - selected features and the best split among those is chosen. When - `max_features` is set 1, this amounts to building a totally random - decision tree. - - Warning: Extra-trees should only be used within ensemble methods. - - Read more in the :ref:`User Guide `. - - Parameters - ---------- - criterion : {"squared_error", "friedman_mse", "absolute_error", "poisson"}, \ - default="squared_error" - The function to measure the quality of a split. Supported criteria - are "squared_error" for the mean squared error, which is equal to - variance reduction as feature selection criterion and minimizes the L2 - loss using the mean of each terminal node, "friedman_mse", which uses - mean squared error with Friedman's improvement score for potential - splits, "absolute_error" for the mean absolute error, which minimizes - the L1 loss using the median of each terminal node, and "poisson" which - uses reduction in Poisson deviance to find splits. - - .. versionadded:: 0.18 - Mean Absolute Error (MAE) criterion. - - .. versionadded:: 0.24 - Poisson deviance criterion. - - splitter : {"random", "best"}, default="random" - The strategy used to choose the split at each node. Supported - strategies are "best" to choose the best split and "random" to choose - the best random split. - - max_depth : int, default=None - The maximum depth of the tree. If None, then nodes are expanded until - all leaves are pure or until all leaves contain less than - min_samples_split samples. 
- - min_samples_split : int or float, default=2 - The minimum number of samples required to split an internal node: - - - If int, then consider `min_samples_split` as the minimum number. - - If float, then `min_samples_split` is a fraction and - `ceil(min_samples_split * n_samples)` are the minimum - number of samples for each split. - - .. versionchanged:: 0.18 - Added float values for fractions. - - min_samples_leaf : int or float, default=1 - The minimum number of samples required to be at a leaf node. - A split point at any depth will only be considered if it leaves at - least ``min_samples_leaf`` training samples in each of the left and - right branches. This may have the effect of smoothing the model, - especially in regression. - - - If int, then consider `min_samples_leaf` as the minimum number. - - If float, then `min_samples_leaf` is a fraction and - `ceil(min_samples_leaf * n_samples)` are the minimum - number of samples for each node. - - .. versionchanged:: 0.18 - Added float values for fractions. - - min_weight_fraction_leaf : float, default=0.0 - The minimum weighted fraction of the sum total of weights (of all - the input samples) required to be at a leaf node. Samples have - equal weight when sample_weight is not provided. - - max_features : int, float, {"auto", "sqrt", "log2"} or None, default=1.0 - The number of features to consider when looking for the best split: - - - If int, then consider `max_features` features at each split. - - If float, then `max_features` is a fraction and - `max(1, int(max_features * n_features_in_))` features are considered at each - split. - - If "sqrt", then `max_features=sqrt(n_features)`. - - If "log2", then `max_features=log2(n_features)`. - - If None, then `max_features=n_features`. - - .. versionchanged:: 1.1 - The default of `max_features` changed from `"auto"` to `1.0`. - - Note: the search for a split does not stop until at least one - valid partition of the node samples is found, even if it requires to - effectively inspect more than ``max_features`` features. - - random_state : int, RandomState instance or None, default=None - Used to pick randomly the `max_features` used at each split. - See :term:`Glossary ` for details. - - min_impurity_decrease : float, default=0.0 - A node will be split if this split induces a decrease of the impurity - greater than or equal to this value. - - The weighted impurity decrease equation is the following:: - - N_t / N * (impurity - N_t_R / N_t * right_impurity - - N_t_L / N_t * left_impurity) - - where ``N`` is the total number of samples, ``N_t`` is the number of - samples at the current node, ``N_t_L`` is the number of samples in the - left child, and ``N_t_R`` is the number of samples in the right child. - - ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum, - if ``sample_weight`` is passed. - - .. versionadded:: 0.19 - - max_leaf_nodes : int, default=None - Grow a tree with ``max_leaf_nodes`` in best-first fashion. - Best nodes are defined as relative reduction in impurity. - If None then unlimited number of leaf nodes. - - ccp_alpha : non-negative float, default=0.0 - Complexity parameter used for Minimal Cost-Complexity Pruning. The - subtree with the largest cost complexity that is smaller than - ``ccp_alpha`` will be chosen. By default, no pruning is performed. See - :ref:`minimal_cost_complexity_pruning` for details. - - .. versionadded:: 0.22 - - store_leaf_values : bool, default=False - Whether to store the samples that fall into leaves in the ``tree_`` attribute. 
- Each leaf will store a 2D array corresponding to the samples that fall into it - keyed by node_id. - - XXX: This is currently experimental and may change without notice. - Moreover, it can be improved upon since storing the samples twice is not ideal. - One could instead store the indices in ``y_train`` that fall into each leaf, - which would lower RAM/diskspace usage. - - monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonicity constraint to enforce on each feature. - - 1: monotonic increase - - 0: no constraint - - -1: monotonic decrease - - If monotonic_cst is None, no constraints are applied. - - Monotonicity constraints are not supported for: - - multioutput regressions (i.e. when `n_outputs_ > 1`), - - regressions trained on data with missing values. - - Read more in the :ref:`User Guide `. - - .. versionadded:: 1.4 - - Attributes - ---------- - max_features_ : int - The inferred value of max_features. - - n_features_in_ : int - Number of features seen during :term:`fit`. - - .. versionadded:: 0.24 - - feature_names_in_ : ndarray of shape (`n_features_in_`,) - Names of features seen during :term:`fit`. Defined only when `X` - has feature names that are all strings. - - .. versionadded:: 1.0 - - feature_importances_ : ndarray of shape (n_features,) - Return impurity-based feature importances (the higher, the more - important the feature). - - Warning: impurity-based feature importances can be misleading for - high cardinality features (many unique values). See - :func:`sklearn.inspection.permutation_importance` as an alternative. - - n_outputs_ : int - The number of outputs when ``fit`` is performed. - - tree_ : Tree instance - The underlying Tree object. Please refer to - ``help(sklearn.tree._tree.Tree)`` for attributes of Tree object and - :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py` - for basic usage of these attributes. - - builder_ : TreeBuilder instance - The underlying TreeBuilder object. - - See Also - -------- - ExtraTreeClassifier : An extremely randomized tree classifier. - sklearn.ensemble.ExtraTreesClassifier : An extra-trees classifier. - sklearn.ensemble.ExtraTreesRegressor : An extra-trees regressor. - - Notes - ----- - The default values for the parameters controlling the size of the trees - (e.g. ``max_depth``, ``min_samples_leaf``, etc.) lead to fully grown and - unpruned trees which can potentially be very large on some data sets. To - reduce memory consumption, the complexity and size of the trees should be - controlled by setting those parameter values. - - References - ---------- - - .. [1] P. Geurts, D. Ernst., and L. Wehenkel, "Extremely randomized trees", - Machine Learning, 63(1), 3-42, 2006. - - Examples - -------- - >>> from sklearn.datasets import load_diabetes - >>> from sklearn.model_selection import train_test_split - >>> from sklearn.ensemble import BaggingRegressor - >>> from sklearn.tree import ExtraTreeRegressor - >>> X, y = load_diabetes(return_X_y=True) - >>> X_train, X_test, y_train, y_test = train_test_split( - ... X, y, random_state=0) - >>> extra_tree = ExtraTreeRegressor(random_state=0) - >>> reg = BaggingRegressor(extra_tree, random_state=0).fit( - ... X_train, y_train) - >>> reg.score(X_test, y_test) - 0.33... 
- """ - - def __init__( + Tree tree, + object X, + const DOUBLE_t[:, ::1] y, + const DOUBLE_t[:] sample_weight=*, + const unsigned char[::1] missing_values_in_feature_mask=*, + ) + + cdef _check_input( self, - *, - criterion="squared_error", - splitter="random", - max_depth=None, - min_samples_split=2, - min_samples_leaf=1, - min_weight_fraction_leaf=0.0, - max_features=1.0, - random_state=None, - min_impurity_decrease=0.0, - max_leaf_nodes=None, - ccp_alpha=0.0, - store_leaf_values=False, - monotonic_cst=None, - ): - super().__init__( - criterion=criterion, - splitter=splitter, - max_depth=max_depth, - min_samples_split=min_samples_split, - min_samples_leaf=min_samples_leaf, - min_weight_fraction_leaf=min_weight_fraction_leaf, - max_features=max_features, - max_leaf_nodes=max_leaf_nodes, - min_impurity_decrease=min_impurity_decrease, - random_state=random_state, - ccp_alpha=ccp_alpha, - store_leaf_values=store_leaf_values, - monotonic_cst=monotonic_cst, - ) + object X, + const DOUBLE_t[:, ::1] y, + const DOUBLE_t[:] sample_weight, + ) \ No newline at end of file From 5f06b27e5bfbd413b001c7426d9df44a0e9c5507 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Fri, 11 Aug 2023 11:17:03 -0400 Subject: [PATCH 3/7] FIX correct linting --- sklearn/tree/_classes.py | 6 ++++-- sklearn/tree/_tree.pxd | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 6234cb976aa08..98ee31d88c35b 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -563,7 +563,9 @@ def _build_tree( self.min_impurity_decrease, self.store_leaf_values, ) - self.builder_.build(self.tree_, X, y, sample_weight, missing_values_in_feature_mask) + self.builder_.build( + self.tree_, X, y, sample_weight, missing_values_in_feature_mask + ) if self.n_outputs_ == 1 and is_classifier(self): self.n_classes_ = self.n_classes_[0] @@ -2355,4 +2357,4 @@ def __init__( ccp_alpha=ccp_alpha, store_leaf_values=store_leaf_values, monotonic_cst=monotonic_cst, - ) \ No newline at end of file + ) diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 7bb8c50ee38d8..b4cb8444ecb69 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -190,4 +190,4 @@ cdef class TreeBuilder: object X, const DOUBLE_t[:, ::1] y, const DOUBLE_t[:] sample_weight, - ) \ No newline at end of file + ) From b1f75cf4d3e422e5af8eb309c6810ef23bbcc0c5 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Fri, 11 Aug 2023 11:24:53 -0400 Subject: [PATCH 4/7] FIX remove duplicate variable --- sklearn/tree/_tree.pxd | 2 -- 1 file changed, 2 deletions(-) diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index b4cb8444ecb69..507140cf52dce 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -174,8 +174,6 @@ cdef class TreeBuilder: const unsigned char[::1] missing_values_in_feature_mask=*, ) - cdef unsigned char store_leaf_values # Whether to store leaf values - cpdef build( self, Tree tree, From 63637f7da3469195e04339dc0a9c5aeae99407ed Mon Sep 17 00:00:00 2001 From: Adam Li Date: Fri, 11 Aug 2023 15:35:00 -0400 Subject: [PATCH 5/7] Fix unit tests Signed-off-by: Adam Li --- sklearn/tree/_classes.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 98ee31d88c35b..b3fac9d433297 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -1303,7 +1303,9 @@ def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): self : DecisionTreeClassifier Fitted 
estimator. """ - + self._validate_params() + + # validate input parameters first_call = _check_partial_fit_first_call(self, classes=classes) # Fit if no tree exists yet From b6050760c520dddc64f1e6702290fbb9a2f0e0b7 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Fri, 11 Aug 2023 15:39:53 -0400 Subject: [PATCH 6/7] Fix lint Signed-off-by: Adam Li --- sklearn/tree/_classes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index b3fac9d433297..8783d45d0bfd7 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -1304,7 +1304,7 @@ def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): Fitted estimator. """ self._validate_params() - + # validate input parameters first_call = _check_partial_fit_first_call(self, classes=classes) From 368df7a381b14952f9cee477a7e81b3768dcd2d0 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Mon, 14 Aug 2023 12:59:08 -0400 Subject: [PATCH 7/7] Fixed sklearn set split node Signed-off-by: Adam Li --- sklearn/tree/_tree.pxd | 6 ++++-- sklearn/tree/_tree.pyx | 22 ++++++++++++++-------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 507140cf52dce..3f95ab2abfd6a 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -93,12 +93,14 @@ cdef class BaseTree: cdef int _set_split_node( self, SplitRecord* split_node, - Node* node + Node* node, + SIZE_t node_id, ) except -1 nogil cdef int _set_leaf_node( self, SplitRecord* split_node, - Node* node + Node* node, + SIZE_t node_id, ) except -1 nogil cdef DTYPE_t _compute_feature( self, diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index e307a53b850a5..afa6a1b8b040b 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -1018,10 +1018,11 @@ cdef class BaseTree: self.capacity = capacity return 0 - cdef int _set_split_node( + cdef inline int _set_split_node( self, SplitRecord* split_node, - Node* node + Node* node, + SIZE_t node_id, ) except -1 nogil: """Set split node data. @@ -1031,16 +1032,19 @@ cdef class BaseTree: The pointer to the record of the split node data. node : Node* The pointer to the node that will hold the split node. + node_id : SIZE_t + The index of the node. """ # left_child and right_child will be set later for a split node node.feature = split_node.feature node.threshold = split_node.threshold return 1 - cdef int _set_leaf_node( + cdef inline int _set_leaf_node( self, SplitRecord* split_node, - Node* node + Node* node, + SIZE_t node_id, ) except -1 nogil: """Set leaf node data. @@ -1050,6 +1054,8 @@ cdef class BaseTree: The pointer to the record of the leaf node data. node : Node* The pointer to the node that will hold the leaf node. + node_id : SIZE_t + The index of the node. 
""" node.left_child = _TREE_LEAF node.right_child = _TREE_LEAF @@ -1125,11 +1131,11 @@ cdef class BaseTree: self.nodes[parent].right_child = node_id if is_leaf: - if self._set_leaf_node(split_node, node) != 1: + if self._set_leaf_node(split_node, node, node_id) != 1: with gil: raise RuntimeError else: - if self._set_split_node(split_node, node) != 1: + if self._set_split_node(split_node, node, node_id) != 1: with gil: raise RuntimeError node.missing_go_to_left = missing_go_to_left @@ -1170,11 +1176,11 @@ cdef class BaseTree: node.weighted_n_node_samples = weighted_n_node_samples if is_leaf: - if self._set_leaf_node(split_node, node) != 1: + if self._set_leaf_node(split_node, node, node_id) != 1: with gil: raise RuntimeError else: - if self._set_split_node(split_node, node) != 1: + if self._set_split_node(split_node, node, node_id) != 1: with gil: raise RuntimeError node.missing_go_to_left = missing_go_to_left