[ENH v2] Add partial_fit to the correct branch for DecisionTreeClassifier (#54)

Supersedes: #50 

Implements the partial_fit API for all classification decision trees.
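
A minimal usage sketch of the new API (illustrative data; assumes this
fork's build of scikit-learn is installed in place of upstream):

import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.RandomState(0)
X1, y1 = rng.rand(50, 4), rng.randint(0, 3, size=50)  # first batch
X2, y2 = rng.rand(50, 4), rng.randint(0, 3, size=50)  # second batch

clf = DecisionTreeClassifier(random_state=0)
clf.partial_fit(X1, y1, classes=[0, 1, 2])  # first call must pass all classes
clf.partial_fit(X2, y2)  # subsequent calls may omit classes
clf.predict_proba(X2[:5])  # tree now reflects both batches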

---------

Signed-off-by: Adam Li <[email protected]>
Co-authored-by: Haoyin Xu <[email protected]>
adam2392 and PSSF23 committed Sep 8, 2023
1 parent f042577 commit 351f14d
Showing 3 changed files with 506 additions and 45 deletions.
201 changes: 173 additions & 28 deletions sklearn/tree/_classes.py
@@ -11,12 +11,12 @@
# Joly Arnaud <[email protected]>
# Fares Hedayati <[email protected]>
# Nelson Liu <[email protected]>
# Haoyin Xu <[email protected]>
#
# License: BSD 3 clause

import copy
import numbers
import warnings
from abc import ABCMeta, abstractmethod
from math import ceil
from numbers import Integral, Real
@@ -35,7 +35,10 @@
)
from sklearn.utils import Bunch, check_random_state, compute_sample_weight
from sklearn.utils._param_validation import Hidden, Interval, RealNotInt, StrOptions
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.multiclass import (
_check_partial_fit_first_call,
check_classification_targets,
)
from sklearn.utils.validation import (
_assert_all_finite_element_wise,
_check_sample_weight,
@@ -237,6 +240,7 @@ def _fit(
self,
X,
y,
classes=None,
sample_weight=None,
check_input=True,
missing_values_in_feature_mask=None,
@@ -291,7 +295,6 @@ def _fit(
is_classification = False
if y is not None:
is_classification = is_classifier(self)

y = np.atleast_1d(y)
expanded_class_weight = None

@@ -313,10 +316,28 @@
y_original = np.copy(y)

y_encoded = np.zeros(y.shape, dtype=int)
for k in range(self.n_outputs_):
classes_k, y_encoded[:, k] = np.unique(y[:, k], return_inverse=True)
self.classes_.append(classes_k)
self.n_classes_.append(classes_k.shape[0])
if classes is not None:
classes = np.atleast_1d(classes)
if classes.ndim == 1:
classes = np.array([classes])

for k in classes:
self.classes_.append(np.array(k))
self.n_classes_.append(np.array(k).shape[0])

for i in range(n_samples):
for j in range(self.n_outputs_):
y_encoded[i, j] = np.where(self.classes_[j] == y[i, j])[0][
0
]
else:
for k in range(self.n_outputs_):
classes_k, y_encoded[:, k] = np.unique(
y[:, k], return_inverse=True
)
self.classes_.append(classes_k)
self.n_classes_.append(classes_k.shape[0])

y = y_encoded
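# Worked example of the fixed-classes encoding above: with
# classes=[2, 5, 9], a label y[i, k] == 5 is encoded as 1 (its index
# in the provided class list), so a batch that happens to lack some
# classes still gets codes consistent with later partial_fit batches.
# For multi-output y, one class array may be given per output.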

if self.class_weight is not None:
@@ -355,24 +376,8 @@ def _fit(
if self.max_features == "auto":
if is_classification:
max_features = max(1, int(np.sqrt(self.n_features_in_)))
warnings.warn(
(
"`max_features='auto'` has been deprecated in 1.1 "
"and will be removed in 1.3. To keep the past behaviour, "
"explicitly set `max_features='sqrt'`."
),
FutureWarning,
)
else:
max_features = self.n_features_in_
warnings.warn(
(
"`max_features='auto'` has been deprecated in 1.1 "
"and will be removed in 1.3. To keep the past behaviour, "
"explicitly set `max_features=1.0'`."
),
FutureWarning,
)
elif self.max_features == "sqrt":
max_features = max(1, int(np.sqrt(self.n_features_in_)))
elif self.max_features == "log2":
@@ -538,7 +543,7 @@ def _build_tree(

# Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise
if max_leaf_nodes < 0:
builder = DepthFirstTreeBuilder(
self.builder_ = DepthFirstTreeBuilder(
splitter,
min_samples_split,
min_samples_leaf,
@@ -548,7 +553,7 @@
self.store_leaf_values,
)
else:
builder = BestFirstTreeBuilder(
self.builder_ = BestFirstTreeBuilder(
splitter,
min_samples_split,
min_samples_leaf,
@@ -558,7 +563,9 @@
self.min_impurity_decrease,
self.store_leaf_values,
)
builder.build(self.tree_, X, y, sample_weight, missing_values_in_feature_mask)
self.builder_.build(
self.tree_, X, y, sample_weight, missing_values_in_feature_mask
)
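# Note: the builder is kept as self.builder_ (previously a local
# variable) so that partial_fit can reuse it later to grow this same
# tree on new batches of data.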

if self.n_outputs_ == 1 and is_classifier(self):
self.n_classes_ = self.n_classes_[0]
@@ -1119,6 +1126,9 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree):
:ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py`
for basic usage of these attributes.
builder_ : TreeBuilder instance
The underlying TreeBuilder object.
See Also
--------
DecisionTreeRegressor : A decision tree regressor.
@@ -1209,7 +1219,14 @@ def __init__(
)

@_fit_context(prefer_skip_nested_validation=True)
def fit(self, X, y, sample_weight=None, check_input=True):
def fit(
self,
X,
y,
sample_weight=None,
check_input=True,
classes=None,
):
"""Build a decision tree classifier from the training set (X, y).
Parameters
@@ -1233,6 +1250,11 @@ def fit(self, X, y, sample_weight=None, check_input=True):
Allows bypassing several input checks.
Don't use this parameter unless you know what you're doing.
classes : array-like of shape (n_classes,), default=None
List of all the classes that can possibly appear in the y vector.
Must be provided at the first call to partial_fit; it can be omitted
in subsequent calls.
Returns
-------
self : DecisionTreeClassifier
@@ -1243,9 +1265,112 @@ def fit(self, X, y, sample_weight=None, check_input=True):
y,
sample_weight=sample_weight,
check_input=check_input,
classes=classes,
)
return self

def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True):
"""Update a decision tree classifier from the training set (X, y).
Parameters
----------
X : {array-like, sparse matrix} of shape (n_samples, n_features)
The training input samples. Internally, it will be converted to
``dtype=np.float32`` and if a sparse matrix is provided
to a sparse ``csc_matrix``.
y : array-like of shape (n_samples,) or (n_samples, n_outputs)
The target values (class labels) as integers or strings.
classes : array-like of shape (n_classes,), default=None
List of all the classes that can possibly appear in the y vector.
Must be provided at the first call to partial_fit; it can be omitted
in subsequent calls.
sample_weight : array-like of shape (n_samples,), default=None
Sample weights. If None, then samples are equally weighted. Splits
that would create child nodes with net zero or negative weight are
ignored while searching for a split in each node. Splits are also
ignored if they would result in any single class carrying a
negative weight in either child node.
check_input : bool, default=True
Allows bypassing several input checks.
Don't use this parameter unless you know what you're doing.
Returns
-------
self : DecisionTreeClassifier
Fitted estimator.
"""
self._validate_params()

# validate input parameters
first_call = _check_partial_fit_first_call(self, classes=classes)
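# _check_partial_fit_first_call returns True only on the first call
# (no classes_ set yet); it then validates ``classes`` and stores them
# as self.classes_, which is why later calls may omit the argument.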

# Fit if no tree exists yet
if first_call:
self.fit(
X,
y,
sample_weight=sample_weight,
check_input=check_input,
classes=classes,
)
return self

if check_input:
# Need to validate separately here.
# We can't pass multi_output=True because that would allow y to be
# csr.
check_X_params = dict(dtype=DTYPE, accept_sparse="csc")
check_y_params = dict(ensure_2d=False, dtype=None)
X, y = self._validate_data(
X, y, reset=False, validate_separately=(check_X_params, check_y_params)
)
if issparse(X):
X.sort_indices()

if X.indices.dtype != np.intc or X.indptr.dtype != np.intc:
raise ValueError(
"No support for np.int64 index based sparse matrices"
)

if X.shape[1] != self.n_features_in_:
msg = "Number of features %d does not match previous data %d."
raise ValueError(msg % (X.shape[1], self.n_features_in_))

y = np.atleast_1d(y)

if y.ndim == 1:
# reshape is necessary to preserve the data contiguity; using
# [:, np.newaxis] would not.
y = np.reshape(y, (-1, 1))

check_classification_targets(y)
y = np.copy(y)

classes = self.classes_
if self.n_outputs_ == 1:
classes = [classes]

y_encoded = np.zeros(y.shape, dtype=int)
for i in range(X.shape[0]):
for j in range(self.n_outputs_):
y_encoded[i, j] = np.where(classes[j] == y[i, j])[0][0]
y = y_encoded

if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
y = np.ascontiguousarray(y, dtype=DOUBLE)

# Update tree
self.builder_.initialize_node_queue(self.tree_, X, y, sample_weight)
self.builder_.build(self.tree_, X, y, sample_weight)
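# initialize_node_queue comes from this fork's TreeBuilder (not
# upstream scikit-learn); as the name suggests, it seeds the builder's
# work queue from the already-fitted tree_, so the subsequent build()
# call extends the existing tree with the new batch rather than
# growing a new tree from scratch.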

self._prune_tree()

return self

def predict_proba(self, X, check_input=True):
"""Predict class probabilities of the input samples X.
@@ -1518,6 +1643,9 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree):
:ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py`
for basic usage of these attributes.
builder_ : TreeBuilder instance
The underlying TreeBuilder object.
See Also
--------
DecisionTreeClassifier : A decision tree classifier.
@@ -1600,7 +1728,14 @@ def __init__(
)

@_fit_context(prefer_skip_nested_validation=True)
def fit(self, X, y, sample_weight=None, check_input=True):
def fit(
self,
X,
y,
sample_weight=None,
check_input=True,
classes=None,
):
"""Build a decision tree regressor from the training set (X, y).
Parameters
@@ -1623,6 +1758,9 @@ def fit(self, X, y, sample_weight=None, check_input=True):
Allow to bypass several input checking.
Don't use this parameter unless you know what you're doing.
classes : array-like of shape (n_classes,), default=None
List of all the classes that can possibly appear in the y vector.
Accepted for API consistency with the classifier; unused when
fitting a regressor.
Returns
-------
self : DecisionTreeRegressor
@@ -1634,6 +1772,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
y,
sample_weight=sample_weight,
check_input=check_input,
classes=classes,
)
return self

@@ -1885,6 +2024,9 @@ class ExtraTreeClassifier(DecisionTreeClassifier):
:ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py`
for basic usage of these attributes.
builder_ : TreeBuilder instance
The underlying TreeBuilder object.
See Also
--------
ExtraTreeRegressor : An extremely randomized tree regressor.
@@ -2147,6 +2289,9 @@ class ExtraTreeRegressor(DecisionTreeRegressor):
:ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py`
for basic usage of these attributes.
builder_ : TreeBuilder instance
The underlying TreeBuilder object.
See Also
--------
ExtraTreeClassifier : An extremely randomized tree classifier.
