From 351f14d94c9c94abc62374e0bdb308039088ee2f Mon Sep 17 00:00:00 2001
From: Adam Li
Date: Mon, 14 Aug 2023 13:35:48 -0400
Subject: [PATCH] [ENH v2] Add partial_fit to the correct branch for
 DecisionTreeClassifier (#54)

Supersedes: #50

Implements the partial_fit API for all classification decision trees.

---------

Signed-off-by: Adam Li
Co-authored-by: Haoyin Xu
---
 sklearn/tree/_classes.py | 201 ++++++++++++++++++++----
 sklearn/tree/_tree.pxd   |  29 +++-
 sklearn/tree/_tree.pyx   | 321 +++++++++++++++++++++++++++++++++++++--
 3 files changed, 506 insertions(+), 45 deletions(-)

diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py
index 26267a1355f6f..8783d45d0bfd7 100644
--- a/sklearn/tree/_classes.py
+++ b/sklearn/tree/_classes.py
@@ -11,12 +11,12 @@
 # Joly Arnaud
 # Fares Hedayati
 # Nelson Liu
+# Haoyin Xu
 #
 # License: BSD 3 clause
 
 import copy
 import numbers
-import warnings
 from abc import ABCMeta, abstractmethod
 from math import ceil
 from numbers import Integral, Real
@@ -35,7 +35,10 @@
 )
 from sklearn.utils import Bunch, check_random_state, compute_sample_weight
 from sklearn.utils._param_validation import Hidden, Interval, RealNotInt, StrOptions
-from sklearn.utils.multiclass import check_classification_targets
+from sklearn.utils.multiclass import (
+    _check_partial_fit_first_call,
+    check_classification_targets,
+)
 from sklearn.utils.validation import (
     _assert_all_finite_element_wise,
     _check_sample_weight,
@@ -237,6 +240,7 @@ def _fit(
         self,
         X,
         y,
+        classes=None,
         sample_weight=None,
         check_input=True,
         missing_values_in_feature_mask=None,
@@ -291,7 +295,6 @@ def _fit(
         is_classification = False
         if y is not None:
             is_classification = is_classifier(self)
-
             y = np.atleast_1d(y)
             expanded_class_weight = None
 
@@ -313,10 +316,28 @@ def _fit(
                 y_original = np.copy(y)
 
             y_encoded = np.zeros(y.shape, dtype=int)
-            for k in range(self.n_outputs_):
-                classes_k, y_encoded[:, k] = np.unique(y[:, k], return_inverse=True)
-                self.classes_.append(classes_k)
-                self.n_classes_.append(classes_k.shape[0])
+            if classes is not None:
+                classes = np.atleast_1d(classes)
+                if classes.ndim == 1:
+                    classes = np.array([classes])
+
+                for k in classes:
+                    self.classes_.append(np.array(k))
+                    self.n_classes_.append(np.array(k).shape[0])
+
+                for i in range(n_samples):
+                    for j in range(self.n_outputs_):
+                        y_encoded[i, j] = np.where(self.classes_[j] == y[i, j])[0][
+                            0
+                        ]
+            else:
+                for k in range(self.n_outputs_):
+                    classes_k, y_encoded[:, k] = np.unique(
+                        y[:, k], return_inverse=True
+                    )
+                    self.classes_.append(classes_k)
+                    self.n_classes_.append(classes_k.shape[0])
+
             y = y_encoded
 
             if self.class_weight is not None:
@@ -355,24 +376,8 @@
             if self.max_features == "auto":
                 if is_classification:
                     max_features = max(1, int(np.sqrt(self.n_features_in_)))
-                    warnings.warn(
-                        (
-                            "`max_features='auto'` has been deprecated in 1.1 "
-                            "and will be removed in 1.3. To keep the past behaviour, "
-                            "explicitly set `max_features='sqrt'`."
-                        ),
-                        FutureWarning,
-                    )
                 else:
                     max_features = self.n_features_in_
-                    warnings.warn(
-                        (
-                            "`max_features='auto'` has been deprecated in 1.1 "
-                            "and will be removed in 1.3. To keep the past behaviour, "
-                            "explicitly set `max_features=1.0'`."
-                        ),
-                        FutureWarning,
-                    )
             elif self.max_features == "sqrt":
                 max_features = max(1, int(np.sqrt(self.n_features_in_)))
             elif self.max_features == "log2":
@@ -538,7 +543,7 @@ def _build_tree(
 
         # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise
         if max_leaf_nodes < 0:
-            builder = DepthFirstTreeBuilder(
+            self.builder_ = DepthFirstTreeBuilder(
                 splitter,
                 min_samples_split,
                 min_samples_leaf,
@@ -548,7 +553,7 @@
                 self.store_leaf_values,
             )
         else:
-            builder = BestFirstTreeBuilder(
+            self.builder_ = BestFirstTreeBuilder(
                 splitter,
                 min_samples_split,
                 min_samples_leaf,
@@ -558,7 +563,9 @@
                 self.min_impurity_decrease,
                 self.store_leaf_values,
             )
-        builder.build(self.tree_, X, y, sample_weight, missing_values_in_feature_mask)
+        self.builder_.build(
+            self.tree_, X, y, sample_weight, missing_values_in_feature_mask
+        )
 
         if self.n_outputs_ == 1 and is_classifier(self):
             self.n_classes_ = self.n_classes_[0]
@@ -1119,6 +1126,9 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree):
         :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py`
         for basic usage of these attributes.
 
+    builder_ : TreeBuilder instance
+        The underlying TreeBuilder object.
+
     See Also
     --------
     DecisionTreeRegressor : A decision tree regressor.
@@ -1209,7 +1219,14 @@ def __init__(
         )
 
     @_fit_context(prefer_skip_nested_validation=True)
-    def fit(self, X, y, sample_weight=None, check_input=True):
+    def fit(
+        self,
+        X,
+        y,
+        sample_weight=None,
+        check_input=True,
+        classes=None,
+    ):
         """Build a decision tree classifier from the training set (X, y).
 
         Parameters
@@ -1233,6 +1250,11 @@ def fit(self, X, y, sample_weight=None, check_input=True):
             Allow to bypass several input checking.
             Don't use this parameter unless you know what you're doing.
 
+        classes : array-like of shape (n_classes,), default=None
+            List of all the classes that can possibly appear in the y vector.
+            Must be provided at the first call to partial_fit, can be omitted
+            in subsequent calls.
+
         Returns
         -------
         self : DecisionTreeClassifier
@@ -1243,9 +1265,112 @@ def fit(self, X, y, sample_weight=None, check_input=True):
             y,
             sample_weight=sample_weight,
             check_input=check_input,
+            classes=classes,
         )
         return self
 
+    def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True):
+        """Update a decision tree classifier from the training set (X, y).
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            The training input samples. Internally, it will be converted to
+            ``dtype=np.float32`` and if a sparse matrix is provided
+            to a sparse ``csc_matrix``.
+
+        y : array-like of shape (n_samples,) or (n_samples, n_outputs)
+            The target values (class labels) as integers or strings.
+
+        classes : array-like of shape (n_classes,), default=None
+            List of all the classes that can possibly appear in the y vector.
+            Must be provided at the first call to partial_fit, can be omitted
+            in subsequent calls.
+
+        sample_weight : array-like of shape (n_samples,), default=None
+            Sample weights. If None, then samples are equally weighted. Splits
+            that would create child nodes with net zero or negative weight are
+            ignored while searching for a split in each node. Splits are also
+            ignored if they would result in any single class carrying a
+            negative weight in either child node.
+
+        check_input : bool, default=True
+            Allow to bypass several input checking.
+            Don't use this parameter unless you know what you're doing.
+
+        Returns
+        -------
+        self : DecisionTreeClassifier
+            Fitted estimator.
+        """
+        self._validate_params()
+
+        # check whether this is the first call to partial_fit
+        first_call = _check_partial_fit_first_call(self, classes=classes)
+
+        # Fit if no tree exists yet
+        if first_call:
+            self.fit(
+                X,
+                y,
+                sample_weight=sample_weight,
+                check_input=check_input,
+                classes=classes,
+            )
+            return self
+
+        if check_input:
+            # Need to validate separately here.
+            # We can't pass multi_output=True because that would allow y to be
+            # csr.
+            check_X_params = dict(dtype=DTYPE, accept_sparse="csc")
+            check_y_params = dict(ensure_2d=False, dtype=None)
+            X, y = self._validate_data(
+                X, y, reset=False, validate_separately=(check_X_params, check_y_params)
+            )
+            if issparse(X):
+                X.sort_indices()
+
+                if X.indices.dtype != np.intc or X.indptr.dtype != np.intc:
+                    raise ValueError(
+                        "No support for np.int64 index based sparse matrices"
+                    )
+
+        if X.shape[1] != self.n_features_in_:
+            msg = "Number of features %d does not match previous data %d."
+            raise ValueError(msg % (X.shape[1], self.n_features_in_))
+
+        y = np.atleast_1d(y)
+
+        if y.ndim == 1:
+            # reshape is necessary to preserve data contiguity, which
+            # slicing with [:, np.newaxis] does not.
+            y = np.reshape(y, (-1, 1))
+
+        check_classification_targets(y)
+        y = np.copy(y)
+
+        classes = self.classes_
+        if self.n_outputs_ == 1:
+            classes = [classes]
+
+        y_encoded = np.zeros(y.shape, dtype=int)
+        for i in range(X.shape[0]):
+            for j in range(self.n_outputs_):
+                y_encoded[i, j] = np.where(classes[j] == y[i, j])[0][0]
+        y = y_encoded
+
+        if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
+            y = np.ascontiguousarray(y, dtype=DOUBLE)
+
+        # Update tree
+        self.builder_.initialize_node_queue(self.tree_, X, y, sample_weight)
+        self.builder_.build(self.tree_, X, y, sample_weight)
+
+        self._prune_tree()
+
+        return self
+
     def predict_proba(self, X, check_input=True):
         """Predict class probabilities of the input samples X.
 
@@ -1518,6 +1643,9 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree):
         :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py`
         for basic usage of these attributes.
 
+    builder_ : TreeBuilder instance
+        The underlying TreeBuilder object.
+
     See Also
     --------
     DecisionTreeClassifier : A decision tree classifier.
@@ -1600,7 +1728,14 @@ def __init__(
         )
 
     @_fit_context(prefer_skip_nested_validation=True)
-    def fit(self, X, y, sample_weight=None, check_input=True):
+    def fit(
+        self,
+        X,
+        y,
+        sample_weight=None,
+        check_input=True,
+        classes=None,
+    ):
         """Build a decision tree regressor from the training set (X, y).
 
         Parameters
@@ -1623,6 +1758,9 @@ def fit(self, X, y, sample_weight=None, check_input=True):
             Allow to bypass several input checking.
             Don't use this parameter unless you know what you're doing.
 
+        classes : array-like of shape (n_classes,), default=None
+            List of all the classes that can possibly appear in the y vector.
+
         Returns
         -------
         self : DecisionTreeRegressor
@@ -1634,6 +1772,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
             y,
             sample_weight=sample_weight,
            check_input=check_input,
+            classes=classes,
         )
         return self
 
@@ -1885,6 +2024,9 @@ class ExtraTreeClassifier(DecisionTreeClassifier):
         :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py`
         for basic usage of these attributes.
 
+    builder_ : TreeBuilder instance
+        The underlying TreeBuilder object.
+
     See Also
     --------
     ExtraTreeRegressor : An extremely randomized tree regressor.
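[Usage sketch for reviewers: illustrative only, not part of the diff. The
`load_iris` data and the interleaved batch split are arbitrary choices;
`classes` must enumerate every label that any future batch may contain.]

    from sklearn.datasets import load_iris
    from sklearn.tree import DecisionTreeClassifier

    X, y = load_iris(return_X_y=True)
    clf = DecisionTreeClassifier(random_state=0)

    # First call: no tree exists yet, so partial_fit defers to fit();
    # every class that can ever appear must be declared up front.
    clf.partial_fit(X[::2], y[::2], classes=[0, 1, 2])

    # Later calls route the new batch down the existing tree and regrow
    # the affected branches in place (no classes argument needed).
    clf.partial_fit(X[1::2], y[1::2])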
@@ -2147,6 +2289,9 @@ class ExtraTreeRegressor(DecisionTreeRegressor):
         :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py`
         for basic usage of these attributes.
 
+    builder_ : TreeBuilder instance
+        The underlying TreeBuilder object.
+
     See Also
     --------
     ExtraTreeClassifier : An extremely randomized tree classifier.
diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd
index dedd820c41e0f..3f95ab2abfd6a 100644
--- a/sklearn/tree/_tree.pxd
+++ b/sklearn/tree/_tree.pxd
@@ -5,6 +5,7 @@
 # Arnaud Joly
 # Jacob Schreiber
 # Nelson Liu
+# Haoyin Xu
 #
 # License: BSD 3 clause
 
@@ -52,6 +53,7 @@ cdef class BaseTree:
     # Generic Methods: These are generic methods used by any tree.
     cdef int _resize(self, SIZE_t capacity) except -1 nogil
     cdef int _resize_c(self, SIZE_t capacity=*) except -1 nogil
+
     cdef SIZE_t _add_node(
         self,
         SIZE_t parent,
@@ -63,6 +65,17 @@
         double weighted_n_node_samples,
         unsigned char missing_go_to_left
     ) except -1 nogil
+    cdef SIZE_t _update_node(
+        self,
+        SIZE_t parent,
+        bint is_left,
+        bint is_leaf,
+        SplitRecord* split_node,
+        double impurity,
+        SIZE_t n_node_samples,
+        double weighted_n_node_samples,
+        unsigned char missing_go_to_left
+    ) except -1 nogil
 
     # Python API methods: These are methods exposed to Python
     cpdef cnp.ndarray apply(self, object X)
@@ -80,12 +93,14 @@
     cdef int _set_split_node(
         self,
         SplitRecord* split_node,
-        Node* node
+        Node* node,
+        SIZE_t node_id,
     ) except -1 nogil
     cdef int _set_leaf_node(
         self,
         SplitRecord* split_node,
-        Node* node
+        Node* node,
+        SIZE_t node_id,
     ) except -1 nogil
     cdef DTYPE_t _compute_feature(
         self,
@@ -148,9 +163,19 @@ cdef class TreeBuilder:
     cdef double min_weight_leaf         # Minimum weight in a leaf
     cdef SIZE_t max_depth               # Maximal tree depth
     cdef double min_impurity_decrease   # Impurity threshold for early stopping
+    cdef object initial_roots           # Leaf nodes for streaming updates
 
     cdef unsigned char store_leaf_values    # Whether to store leaf values
 
+    cpdef initialize_node_queue(
+        self,
+        Tree tree,
+        object X,
+        const DOUBLE_t[:, ::1] y,
+        const DOUBLE_t[:] sample_weight=*,
+        const unsigned char[::1] missing_values_in_feature_mask=*,
+    )
+
     cpdef build(
         self,
         Tree tree,
diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx
index 492b5219fa18e..afa6a1b8b040b 100644
--- a/sklearn/tree/_tree.pyx
+++ b/sklearn/tree/_tree.pyx
@@ -12,6 +12,7 @@
 # Fares Hedayati
 # Jacob Schreiber
 # Nelson Liu
+# Haoyin Xu
 #
 # License: BSD 3 clause
 
@@ -91,6 +92,17 @@
 NODE_DTYPE = np.asarray(<Node[:1]>(&dummy)).dtype
 
 cdef class TreeBuilder:
     """Interface for different tree building strategies."""
 
+    cpdef initialize_node_queue(
+        self,
+        Tree tree,
+        object X,
+        const DOUBLE_t[:, ::1] y,
+        const DOUBLE_t[:] sample_weight=None,
+        const unsigned char[::1] missing_values_in_feature_mask=None,
+    ):
+        """Initialize the node queue for streaming updates."""
+        pass
+
     cpdef build(
         self,
         Tree tree,
@@ -165,7 +177,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
         double min_weight_leaf,
         SIZE_t max_depth,
         double min_impurity_decrease,
-        unsigned char store_leaf_values=False
+        unsigned char store_leaf_values=False,
+        cnp.ndarray initial_roots=None,
     ):
         self.splitter = splitter
         self.min_samples_split = min_samples_split
@@ -174,6 +187,71 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
         self.max_depth = max_depth
         self.min_impurity_decrease = min_impurity_decrease
         self.store_leaf_values = store_leaf_values
+        self.initial_roots = initial_roots
+
+    def __reduce__(self):
+        """Reduce re-implementation, for pickling."""
+        return (DepthFirstTreeBuilder, (self.splitter, self.min_samples_split,
+                                        self.min_samples_leaf,
+                                        self.min_weight_leaf,
+                                        self.max_depth,
+                                        self.min_impurity_decrease,
+                                        self.store_leaf_values,
+                                        self.initial_roots))
+
+    cpdef initialize_node_queue(
+        self,
+        Tree tree,
+        object X,
+        const DOUBLE_t[:, ::1] y,
+        const DOUBLE_t[:] sample_weight=None,
+        const unsigned char[::1] missing_values_in_feature_mask=None,
+    ):
+        """Initialize a list of roots."""
+        X, y, sample_weight = self._check_input(X, y, sample_weight)
+
+        # organize samples by decision paths
+        paths = tree.decision_path(X)
+        cdef int PARENT
+        cdef int CHILD
+        false_roots = {}
+        X_copy = {}
+        y_copy = {}
+        for i in range(X.shape[0]):
+            depth_i = paths[i].indices.shape[0] - 1
+            PARENT = depth_i - 1
+            CHILD = depth_i
+
+            if PARENT < 0:
+                parent_i = 0
+            else:
+                parent_i = paths[i].indices[PARENT]
+            child_i = paths[i].indices[CHILD]
+            left = 0
+            if tree.children_left[parent_i] == child_i:
+                left = 1
+
+            if (parent_i, left) in false_roots:
+                false_roots[(parent_i, left)][0] += 1
+                X_copy[(parent_i, left)].append(X[i])
+                y_copy[(parent_i, left)].append(y[i])
+            else:
+                false_roots[(parent_i, left)] = [1, depth_i]
+                X_copy[(parent_i, left)] = [X[i]]
+                y_copy[(parent_i, left)] = [y[i]]
+
+        X_list = []
+        y_list = []
+        for key, value in reversed(sorted(X_copy.items())):
+            X_list = X_list + value
+            y_list = y_list + y_copy[key]
+        cdef object X_new = np.array(X_list)
+        cdef cnp.ndarray y_new = np.array(y_list)
+
+        cdef Splitter splitter = self.splitter
+        splitter.init(X_new, y_new, sample_weight, missing_values_in_feature_mask)
+
+        self.initial_roots = false_roots
 
     cpdef build(
         self,
@@ -206,11 +284,14 @@
         cdef SIZE_t min_samples_split = self.min_samples_split
         cdef double min_impurity_decrease = self.min_impurity_decrease
 
-        # Recursive partition (without actual recursion)
-        splitter.init(X, y, sample_weight, missing_values_in_feature_mask)
+        cdef bint first = 0
+        if self.initial_roots is None:
+            # Recursive partition (without actual recursion)
+            splitter.init(X, y, sample_weight, missing_values_in_feature_mask)
+            first = 1
 
-        cdef SIZE_t start
-        cdef SIZE_t end
+        cdef SIZE_t start = 0
+        cdef SIZE_t end = 0
         cdef SIZE_t depth
         cdef SIZE_t parent
         cdef bint is_left
@@ -227,14 +308,33 @@
         cdef double middle_value
         cdef SIZE_t n_constant_features
         cdef bint is_leaf
-        cdef bint first = 1
-        cdef SIZE_t max_depth_seen = -1
+        cdef SIZE_t max_depth_seen = -1 if first else tree.max_depth
 
         cdef int rc = 0
 
        cdef stack[StackRecord] builder_stack
+        cdef stack[StackRecord] update_stack
         cdef StackRecord stack_record
 
-        with nogil:
+        if not first:
+            # push reached leaf nodes onto stack
+            for key, value in reversed(sorted(self.initial_roots.items())):
+                end += value[0]
+                update_stack.push({
+                    "start": start,
+                    "end": end,
+                    "depth": value[1],
+                    "parent": key[0],
+                    "is_left": key[1],
+                    "impurity": tree.impurity[key[0]],
+                    "n_constant_features": 0,
+                    "lower_bound": -INFINITY,
+                    "upper_bound": INFINITY,
+                })
+                start += value[0]
+            if rc == -1:
+                # got return code -1 - out-of-memory
+                raise MemoryError()
+        else:
             # push root node onto stack
             builder_stack.push({
                 "start": 0,
@@ -247,6 +347,135 @@
                 "lower_bound": -INFINITY,
                 "upper_bound": INFINITY,
             })
+            if rc == -1:
+                # got return code -1 - out-of-memory
+                raise MemoryError()
+
+        with nogil:
+            while not update_stack.empty():
+                stack_record = update_stack.top()
+                update_stack.pop()
+
+                start = stack_record.start
+                end = stack_record.end
+                depth = stack_record.depth
+                parent = stack_record.parent
+                is_left = stack_record.is_left
+                impurity = stack_record.impurity
+                n_constant_features = stack_record.n_constant_features
+                lower_bound = stack_record.lower_bound
+                upper_bound = stack_record.upper_bound
+
+                n_node_samples = end - start
+                splitter.node_reset(start, end, &weighted_n_node_samples)
+
+                is_leaf = (depth >= max_depth or
+                           n_node_samples < min_samples_split or
+                           n_node_samples < 2 * min_samples_leaf or
+                           weighted_n_node_samples < 2 * min_weight_leaf)
+
+                # impurity == 0 with tolerance due to rounding errors
+                is_leaf = is_leaf or impurity <= EPSILON
+
+                if not is_leaf:
+                    splitter.node_split(
+                        impurity,
+                        split_ptr,
+                        &n_constant_features,
+                        lower_bound,
+                        upper_bound
+                    )
+
+                    # assign local copy of SplitRecord to assign
+                    # pos, improvement, and impurity scores
+                    split = deref(split_ptr)
+
+                    # If EPSILON=0 in the below comparison, float precision
+                    # issues stop splitting, producing trees that are
+                    # dissimilar to v0.18
+                    is_leaf = (is_leaf or split.pos >= end or
+                               (split.improvement + EPSILON <
+                                min_impurity_decrease))
+
+                node_id = tree._update_node(parent, is_left, is_leaf,
+                                            split_ptr, impurity, n_node_samples,
+                                            weighted_n_node_samples,
+                                            split.missing_go_to_left)
+
+                if node_id == INTPTR_MAX:
+                    rc = -1
+                    break
+
+                # Store value for all nodes, to facilitate tree/model
+                # inspection and interpretation
+                splitter.node_value(tree.value + node_id * tree.value_stride)
+                if splitter.with_monotonic_cst:
+                    splitter.clip_node_value(tree.value + node_id * tree.value_stride, lower_bound, upper_bound)
+
+                if not is_leaf:
+                    if (
+                        not splitter.with_monotonic_cst or
+                        splitter.monotonic_cst[split.feature] == 0
+                    ):
+                        # Split on a feature with no monotonicity constraint
+
+                        # Current bounds must always be propagated to both children.
+                        # If a monotonic constraint is active, bounds are used in
+                        # node value clipping.
+                        left_child_min = right_child_min = lower_bound
+                        left_child_max = right_child_max = upper_bound
+                    elif splitter.monotonic_cst[split.feature] == 1:
+                        # Split on a feature with monotonic increase constraint
+                        left_child_min = lower_bound
+                        right_child_max = upper_bound
+
+                        # Lower bound for right child and upper bound for left child
+                        # are set to the same value.
+                        middle_value = splitter.criterion.middle_value()
+                        right_child_min = middle_value
+                        left_child_max = middle_value
+                    else:  # i.e. splitter.monotonic_cst[split.feature] == -1
+                        # Split on a feature with monotonic decrease constraint
+                        right_child_min = lower_bound
+                        left_child_max = upper_bound
+
+                        # Lower bound for left child and upper bound for right child
+                        # are set to the same value.
+                        middle_value = splitter.criterion.middle_value()
+                        left_child_min = middle_value
+                        right_child_max = middle_value
+
+                    # Push right child on stack
+                    builder_stack.push({
+                        "start": split.pos,
+                        "end": end,
+                        "depth": depth + 1,
+                        "parent": node_id,
+                        "is_left": 0,
+                        "impurity": split.impurity_right,
+                        "n_constant_features": n_constant_features,
+                        "lower_bound": right_child_min,
+                        "upper_bound": right_child_max,
+                    })
+
+                    # Push left child on stack
+                    builder_stack.push({
+                        "start": start,
+                        "end": split.pos,
+                        "depth": depth + 1,
+                        "parent": node_id,
+                        "is_left": 1,
+                        "impurity": split.impurity_left,
+                        "n_constant_features": n_constant_features,
+                        "lower_bound": left_child_min,
+                        "upper_bound": left_child_max,
+                    })
+                elif self.store_leaf_values and is_leaf:
+                    # copy leaf values to leaf_values array
+                    splitter.node_samples(tree.value_samples[node_id])
+
+                if depth > max_depth_seen:
+                    max_depth_seen = depth
 
             while not builder_stack.empty():
                 stack_record = builder_stack.top()
@@ -272,7 +501,7 @@
 
                 if first:
                     impurity = splitter.node_impurity()
-                    first = 0
+                    first = 0
 
                 # impurity == 0 with tolerance due to rounding errors
                 is_leaf = is_leaf or impurity <= EPSILON
@@ -388,6 +617,7 @@
 
         if rc == -1:
             raise MemoryError()
+        self.initial_roots = None
 
 # Best first builder ----------------------------------------------------------
 cdef struct FrontierRecord:
@@ -441,6 +671,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
         SIZE_t max_leaf_nodes,
         double min_impurity_decrease,
         unsigned char store_leaf_values=False,
+        cnp.ndarray initial_roots=None,
     ):
         self.splitter = splitter
         self.min_samples_split = min_samples_split
@@ -450,6 +681,17 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
         self.max_leaf_nodes = max_leaf_nodes
         self.min_impurity_decrease = min_impurity_decrease
         self.store_leaf_values = store_leaf_values
+        self.initial_roots = initial_roots
+
+    def __reduce__(self):
+        """Reduce re-implementation, for pickling."""
+        return (BestFirstTreeBuilder, (self.splitter, self.min_samples_split,
+                                       self.min_samples_leaf,
+                                       self.min_weight_leaf, self.max_depth,
+                                       self.max_leaf_nodes,
+                                       self.min_impurity_decrease,
+                                       self.store_leaf_values,
+                                       self.initial_roots))
 
     cpdef build(
         self,
@@ -776,10 +1018,11 @@ cdef class BaseTree:
             self.capacity = capacity
             return 0
 
-    cdef int _set_split_node(
+    cdef inline int _set_split_node(
         self,
         SplitRecord* split_node,
-        Node* node
+        Node* node,
+        SIZE_t node_id,
     ) except -1 nogil:
         """Set split node data.
 
@@ -789,16 +1032,19 @@
             The pointer to the record of the split node data.
         node : Node*
            The pointer to the node that will hold the split node.
+        node_id : SIZE_t
+            The index of the node.
         """
         # left_child and right_child will be set later for a split node
         node.feature = split_node.feature
         node.threshold = split_node.threshold
         return 1
 
-    cdef int _set_leaf_node(
+    cdef inline int _set_leaf_node(
         self,
         SplitRecord* split_node,
-        Node* node
+        Node* node,
+        SIZE_t node_id,
     ) except -1 nogil:
         """Set leaf node data.
 
@@ -808,6 +1054,8 @@
             The pointer to the record of the leaf node data.
         node : Node*
            The pointer to the node that will hold the leaf node.
+        node_id : SIZE_t
+            The index of the node.
""" node.left_child = _TREE_LEAF node.right_child = _TREE_LEAF @@ -883,11 +1131,11 @@ cdef class BaseTree: self.nodes[parent].right_child = node_id if is_leaf: - if self._set_leaf_node(split_node, node) != 1: + if self._set_leaf_node(split_node, node, node_id) != 1: with gil: raise RuntimeError else: - if self._set_split_node(split_node, node) != 1: + if self._set_split_node(split_node, node, node_id) != 1: with gil: raise RuntimeError node.missing_go_to_left = missing_go_to_left @@ -896,6 +1144,49 @@ cdef class BaseTree: return node_id + cdef SIZE_t _update_node( + self, + SIZE_t parent, + bint is_left, + bint is_leaf, + SplitRecord* split_node, + double impurity, + SIZE_t n_node_samples, + double weighted_n_node_samples, + unsigned char missing_go_to_left + ) except -1 nogil: + """Update a node on the tree. + + The updated node remains on the same position. + Returns (size_t)(-1) on error. + """ + cdef SIZE_t node_id + if is_left: + node_id = self.nodes[parent].left_child + else: + node_id = self.nodes[parent].right_child + + if node_id >= self.capacity: + if self._resize_c() != 0: + return INTPTR_MAX + + cdef Node* node = &self.nodes[node_id] + node.impurity = impurity + node.n_node_samples = n_node_samples + node.weighted_n_node_samples = weighted_n_node_samples + + if is_leaf: + if self._set_leaf_node(split_node, node, node_id) != 1: + with gil: + raise RuntimeError + else: + if self._set_split_node(split_node, node, node_id) != 1: + with gil: + raise RuntimeError + node.missing_go_to_left = missing_go_to_left + + return node_id + cpdef cnp.ndarray apply(self, object X): """Finds the terminal region (=leaf node) for each sample in X.""" if issparse(X):