forked from scikit-learn/scikit-learn
-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ENH v2] Add partial fit to the correct branch for decisiontreeclassi…
…fier (#54) Supersedes: #50 Implements partial_fit API for all classification decision trees. --------- Signed-off-by: Adam Li <[email protected]> Co-authored-by: Haoyin Xu <[email protected]>
- Loading branch information
Showing
3 changed files
with
506 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,12 +11,12 @@ | |
# Joly Arnaud <[email protected]> | ||
# Fares Hedayati <[email protected]> | ||
# Nelson Liu <[email protected]> | ||
# Haoyin Xu <[email protected]> | ||
# | ||
# License: BSD 3 clause | ||
|
||
import copy | ||
import numbers | ||
import warnings | ||
from abc import ABCMeta, abstractmethod | ||
from math import ceil | ||
from numbers import Integral, Real | ||
|
@@ -35,7 +35,10 @@ | |
) | ||
from sklearn.utils import Bunch, check_random_state, compute_sample_weight | ||
from sklearn.utils._param_validation import Hidden, Interval, RealNotInt, StrOptions | ||
from sklearn.utils.multiclass import check_classification_targets | ||
from sklearn.utils.multiclass import ( | ||
_check_partial_fit_first_call, | ||
check_classification_targets, | ||
) | ||
from sklearn.utils.validation import ( | ||
_assert_all_finite_element_wise, | ||
_check_sample_weight, | ||
|
@@ -237,6 +240,7 @@ def _fit( | |
self, | ||
X, | ||
y, | ||
classes=None, | ||
sample_weight=None, | ||
check_input=True, | ||
missing_values_in_feature_mask=None, | ||
|
@@ -291,7 +295,6 @@ def _fit( | |
is_classification = False | ||
if y is not None: | ||
is_classification = is_classifier(self) | ||
|
||
y = np.atleast_1d(y) | ||
expanded_class_weight = None | ||
|
||
|
@@ -313,10 +316,28 @@ def _fit( | |
y_original = np.copy(y) | ||
|
||
y_encoded = np.zeros(y.shape, dtype=int) | ||
for k in range(self.n_outputs_): | ||
classes_k, y_encoded[:, k] = np.unique(y[:, k], return_inverse=True) | ||
self.classes_.append(classes_k) | ||
self.n_classes_.append(classes_k.shape[0]) | ||
if classes is not None: | ||
classes = np.atleast_1d(classes) | ||
if classes.ndim == 1: | ||
classes = np.array([classes]) | ||
|
||
for k in classes: | ||
self.classes_.append(np.array(k)) | ||
self.n_classes_.append(np.array(k).shape[0]) | ||
|
||
for i in range(n_samples): | ||
for j in range(self.n_outputs_): | ||
y_encoded[i, j] = np.where(self.classes_[j] == y[i, j])[0][ | ||
0 | ||
] | ||
else: | ||
for k in range(self.n_outputs_): | ||
classes_k, y_encoded[:, k] = np.unique( | ||
y[:, k], return_inverse=True | ||
) | ||
self.classes_.append(classes_k) | ||
self.n_classes_.append(classes_k.shape[0]) | ||
|
||
y = y_encoded | ||
|
||
if self.class_weight is not None: | ||
|
@@ -355,24 +376,8 @@ def _fit( | |
if self.max_features == "auto": | ||
if is_classification: | ||
max_features = max(1, int(np.sqrt(self.n_features_in_))) | ||
warnings.warn( | ||
( | ||
"`max_features='auto'` has been deprecated in 1.1 " | ||
"and will be removed in 1.3. To keep the past behaviour, " | ||
"explicitly set `max_features='sqrt'`." | ||
), | ||
FutureWarning, | ||
) | ||
else: | ||
max_features = self.n_features_in_ | ||
warnings.warn( | ||
( | ||
"`max_features='auto'` has been deprecated in 1.1 " | ||
"and will be removed in 1.3. To keep the past behaviour, " | ||
"explicitly set `max_features=1.0'`." | ||
), | ||
FutureWarning, | ||
) | ||
elif self.max_features == "sqrt": | ||
max_features = max(1, int(np.sqrt(self.n_features_in_))) | ||
elif self.max_features == "log2": | ||
|
@@ -538,7 +543,7 @@ def _build_tree( | |
|
||
# Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise | ||
if max_leaf_nodes < 0: | ||
builder = DepthFirstTreeBuilder( | ||
self.builder_ = DepthFirstTreeBuilder( | ||
splitter, | ||
min_samples_split, | ||
min_samples_leaf, | ||
|
@@ -548,7 +553,7 @@ def _build_tree( | |
self.store_leaf_values, | ||
) | ||
else: | ||
builder = BestFirstTreeBuilder( | ||
self.builder_ = BestFirstTreeBuilder( | ||
splitter, | ||
min_samples_split, | ||
min_samples_leaf, | ||
|
@@ -558,7 +563,9 @@ def _build_tree( | |
self.min_impurity_decrease, | ||
self.store_leaf_values, | ||
) | ||
builder.build(self.tree_, X, y, sample_weight, missing_values_in_feature_mask) | ||
self.builder_.build( | ||
self.tree_, X, y, sample_weight, missing_values_in_feature_mask | ||
) | ||
|
||
if self.n_outputs_ == 1 and is_classifier(self): | ||
self.n_classes_ = self.n_classes_[0] | ||
|
@@ -1119,6 +1126,9 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): | |
:ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py` | ||
for basic usage of these attributes. | ||
builder_ : TreeBuilder instance | ||
The underlying TreeBuilder object. | ||
See Also | ||
-------- | ||
DecisionTreeRegressor : A decision tree regressor. | ||
|
@@ -1209,7 +1219,14 @@ def __init__( | |
) | ||
|
||
@_fit_context(prefer_skip_nested_validation=True) | ||
def fit(self, X, y, sample_weight=None, check_input=True): | ||
def fit( | ||
self, | ||
X, | ||
y, | ||
sample_weight=None, | ||
check_input=True, | ||
classes=None, | ||
): | ||
"""Build a decision tree classifier from the training set (X, y). | ||
Parameters | ||
|
@@ -1233,6 +1250,11 @@ def fit(self, X, y, sample_weight=None, check_input=True): | |
Allow to bypass several input checking. | ||
Don't use this parameter unless you know what you're doing. | ||
classes : array-like of shape (n_classes,), default=None | ||
List of all the classes that can possibly appear in the y vector. | ||
Must be provided at the first call to partial_fit, can be omitted | ||
in subsequent calls. | ||
Returns | ||
------- | ||
self : DecisionTreeClassifier | ||
|
@@ -1243,9 +1265,112 @@ def fit(self, X, y, sample_weight=None, check_input=True): | |
y, | ||
sample_weight=sample_weight, | ||
check_input=check_input, | ||
classes=classes, | ||
) | ||
return self | ||
|
||
def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True):
    """Update a decision tree classifier from the training set (X, y).

    Incrementally grows the existing tree with the new samples: on the
    first call this delegates to :meth:`fit`; on subsequent calls it
    re-enters the tree builder to extend the already-built tree.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        The training input samples. Internally, it will be converted to
        ``dtype=np.float32`` and if a sparse matrix is provided
        to a sparse ``csc_matrix``.

    y : array-like of shape (n_samples,) or (n_samples, n_outputs)
        The target values (class labels) as integers or strings.

    classes : array-like of shape (n_classes,), default=None
        List of all the classes that can possibly appear in the y vector.
        Must be provided at the first call to partial_fit, can be omitted
        in subsequent calls.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights. If None, then samples are equally weighted. Splits
        that would create child nodes with net zero or negative weight are
        ignored while searching for a split in each node. Splits are also
        ignored if they would result in any single class carrying a
        negative weight in either child node.

    check_input : bool, default=True
        Allow to bypass several input checking.
        Don't use this parameter unless you know what you're doing.

    Returns
    -------
    self : DecisionTreeClassifier
        Fitted estimator.

    Raises
    ------
    ValueError
        If the number of features differs from previous calls, if a
        sparse ``X`` uses int64 indices, or if ``y`` contains a label
        that was not declared in ``classes`` at the first call.
    """
    self._validate_params()

    # validate input parameters
    first_call = _check_partial_fit_first_call(self, classes=classes)

    # Fit if no tree exists yet
    if first_call:
        self.fit(
            X,
            y,
            sample_weight=sample_weight,
            check_input=check_input,
            classes=classes,
        )
        return self

    if check_input:
        # Need to validate separately here.
        # We can't pass multi_output=True because that would allow y to be
        # csr.
        check_X_params = dict(dtype=DTYPE, accept_sparse="csc")
        check_y_params = dict(ensure_2d=False, dtype=None)
        X, y = self._validate_data(
            X, y, reset=False, validate_separately=(check_X_params, check_y_params)
        )
        if issparse(X):
            X.sort_indices()

            # The Cython tree builder only supports 32-bit sparse indices.
            if X.indices.dtype != np.intc or X.indptr.dtype != np.intc:
                raise ValueError(
                    "No support for np.int64 index based sparse matrices"
                )

        if X.shape[1] != self.n_features_in_:
            msg = "Number of features %d does not match previous data %d."
            raise ValueError(msg % (X.shape[1], self.n_features_in_))

    y = np.atleast_1d(y)

    if y.ndim == 1:
        # reshape is necessary to preserve the data contiguity against vs
        # [:, np.newaxis] that does not.
        y = np.reshape(y, (-1, 1))

    check_classification_targets(y)
    y = np.copy(y)

    # Encode y against the classes fixed at the first call. For a single
    # output, wrap classes_ so both cases index as classes[output].
    classes = self.classes_
    if self.n_outputs_ == 1:
        classes = [classes]

    y_encoded = np.zeros(y.shape, dtype=int)
    for i in range(X.shape[0]):
        for j in range(self.n_outputs_):
            matches = np.where(classes[j] == y[i, j])[0]
            if matches.size == 0:
                # Previously this fell through to an opaque IndexError;
                # raise a descriptive error for unseen labels instead.
                raise ValueError(
                    f"Label {y[i, j]!r} in output {j} was not present in "
                    "the `classes` provided at the first call to "
                    "partial_fit."
                )
            y_encoded[i, j] = matches[0]
    y = y_encoded

    # The tree builder expects C-contiguous float64 targets.
    if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
        y = np.ascontiguousarray(y, dtype=DOUBLE)

    # Update tree
    self.builder_.initialize_node_queue(self.tree_, X, y, sample_weight)
    self.builder_.build(self.tree_, X, y, sample_weight)

    self._prune_tree()

    return self
|
||
def predict_proba(self, X, check_input=True): | ||
"""Predict class probabilities of the input samples X. | ||
|
@@ -1518,6 +1643,9 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): | |
:ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py` | ||
for basic usage of these attributes. | ||
builder_ : TreeBuilder instance | ||
The underlying TreeBuilder object. | ||
See Also | ||
-------- | ||
DecisionTreeClassifier : A decision tree classifier. | ||
|
@@ -1600,7 +1728,14 @@ def __init__( | |
) | ||
|
||
@_fit_context(prefer_skip_nested_validation=True) | ||
def fit(self, X, y, sample_weight=None, check_input=True): | ||
def fit( | ||
self, | ||
X, | ||
y, | ||
sample_weight=None, | ||
check_input=True, | ||
classes=None, | ||
): | ||
"""Build a decision tree regressor from the training set (X, y). | ||
Parameters | ||
|
@@ -1623,6 +1758,9 @@ def fit(self, X, y, sample_weight=None, check_input=True): | |
Allow to bypass several input checking. | ||
Don't use this parameter unless you know what you're doing. | ||
classes : array-like of shape (n_classes,), default=None | ||
List of all the classes that can possibly appear in the y vector. | ||
Returns | ||
------- | ||
self : DecisionTreeRegressor | ||
|
@@ -1634,6 +1772,7 @@ def fit(self, X, y, sample_weight=None, check_input=True): | |
y, | ||
sample_weight=sample_weight, | ||
check_input=check_input, | ||
classes=classes, | ||
) | ||
return self | ||
|
||
|
@@ -1885,6 +2024,9 @@ class ExtraTreeClassifier(DecisionTreeClassifier): | |
:ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py` | ||
for basic usage of these attributes. | ||
builder_ : TreeBuilder instance | ||
The underlying TreeBuilder object. | ||
See Also | ||
-------- | ||
ExtraTreeRegressor : An extremely randomized tree regressor. | ||
|
@@ -2147,6 +2289,9 @@ class ExtraTreeRegressor(DecisionTreeRegressor): | |
:ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py` | ||
for basic usage of these attributes. | ||
builder_ : TreeBuilder instance | ||
The underlying TreeBuilder object. | ||
See Also | ||
-------- | ||
ExtraTreeClassifier : An extremely randomized tree classifier. | ||
|
Oops, something went wrong.