TST remove _required_parameters and improve instance generation (scik…
adrinjalali authored Sep 6, 2024
1 parent eb29207 commit 95e9459
Showing 13 changed files with 205 additions and 141 deletions.
9 changes: 0 additions & 9 deletions doc/developers/develop.rst
@@ -562,15 +562,6 @@ for your estimator's tags. For example::
You can create a new subclass of :class:`~sklearn.utils.Tags` if you wish
to add new tags to the existing set.
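
For example, a subclass adding one custom tag might look like this (a sketch,
assuming the dataclass-based ``Tags`` of scikit-learn 1.6; the tag name is
hypothetical)::

    from dataclasses import dataclass

    from sklearn.utils import Tags

    @dataclass
    class MyTags(Tags):
        # Extra tag, intended for your own checks (hypothetical name).
        requires_sorted_input: bool = False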

In addition to the tags, estimators also need to declare any non-optional
parameters to ``__init__`` in the ``_required_parameters`` class attribute,
which is a list or tuple. If ``_required_parameters`` is only
``["estimator"]`` or ``["base_estimator"]``, then the estimator will be
instantiated with an instance of ``LogisticRegression`` (or
``Ridge`` if the estimator is a regressor) in the tests. The choice
of these two models is somewhat idiosyncratic but both should provide robust
closed-form solutions.
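
For reference, the removed mechanism amounted to this pattern (a sketch; the
wrapper class is hypothetical)::

    from sklearn.base import BaseEstimator, MetaEstimatorMixin

    class MyWrapper(MetaEstimatorMixin, BaseEstimator):
        # Pre-1.6: common tests read this attribute and, because it is
        # exactly ["estimator"], constructed the instance with
        # LogisticRegression() (or Ridge() for regressors).
        _required_parameters = ["estimator"]

        def __init__(self, estimator):
            self.estimator = estimator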

.. _developer_api_set_output:

Developer API for `set_output`
9 changes: 5 additions & 4 deletions sklearn/base.py
@@ -1037,9 +1037,12 @@ def fit_predict(self, X, y=None, **kwargs):
class MetaEstimatorMixin:
"""Mixin class for all meta estimators in scikit-learn.
This mixin defines the following functionality:

- define `_required_parameters` that specifies the mandatory `estimator` parameter.

This mixin is empty, and only exists to indicate that the estimator is a
meta-estimator.

.. versionchanged:: 1.6
    `_required_parameters` has been removed; the refactored common tests no
    longer use it.
Examples
--------
@@ -1061,8 +1064,6 @@ class MetaEstimatorMixin:
LogisticRegression()
"""

_required_parameters = ["estimator"]


class MultiOutputMixin:
"""Mixin to mark estimators that support multioutput."""
17 changes: 15 additions & 2 deletions sklearn/compose/_column_transformer.py
@@ -287,8 +287,6 @@ class ColumnTransformer(TransformerMixin, _BaseComposition):
:ref:`sphx_glr_auto_examples_compose_plot_column_transformer_mixed_types.py`.
"""

_required_parameters = ["transformers"]

_parameter_constraints: dict = {
"transformers": [list, Hidden(tuple)],
"remainder": [
@@ -1322,6 +1320,21 @@ def get_metadata_routing(self):

return router

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags._xfail_checks = {
"check_estimators_empty_data_messages": "FIXME",
"check_estimators_nan_inf": "FIXME",
"check_estimator_sparse_array": "FIXME",
"check_estimator_sparse_matrix": "FIXME",
"check_transformer_data_not_an_array": "FIXME",
"check_fit1d": "FIXME",
"check_fit2d_predict1d": "FIXME",
"check_complex_data": "FIXME",
"check_fit2d_1feature": "FIXME",
}
return tags
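
Third-party estimators can mark known-failing common checks the same way (a
sketch assuming the private `_xfail_checks` field used above; the check name
and reason are illustrative):

    from sklearn.base import BaseEstimator, TransformerMixin

    class MyTransformer(TransformerMixin, BaseEstimator):
        def fit(self, X, y=None):
            return self

        def transform(self, X):
            return X

        def __sklearn_tags__(self):
            tags = super().__sklearn_tags__()
            # Maps check name -> reason the failure is expected.
            tags._xfail_checks = {"check_fit2d_1feature": "not applicable"}
            return tags

Being private, `_xfail_checks` is an internal detail that may change between
releases; the "FIXME" reasons above flag checks to revisit rather than
permanent exemptions.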


def _check_X(X):
"""Use check_array only when necessary, e.g. on lists and other non-array-likes."""
2 changes: 0 additions & 2 deletions sklearn/decomposition/_dict_learning.py
@@ -1279,8 +1279,6 @@ class SparseCoder(_BaseSparseCoding, BaseEstimator):
[ 0., 1., 1., 0., 0.]])
"""

_required_parameters = ["dictionary"]

def __init__(
self,
dictionary,
6 changes: 0 additions & 6 deletions sklearn/ensemble/_base.py
@@ -4,7 +4,6 @@
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABCMeta, abstractmethod
from typing import List

import numpy as np
from joblib import effective_n_jobs
@@ -106,9 +105,6 @@ class BaseEnsemble(MetaEstimatorMixin, BaseEstimator, metaclass=ABCMeta):
The collection of fitted base estimators.
"""

# overwrite _required_parameters from MetaEstimatorMixin
_required_parameters: List[str] = []

@abstractmethod
def __init__(
self,
@@ -200,8 +196,6 @@ class _BaseHeterogeneousEnsemble(
appear in `estimators_`.
"""

_required_parameters = ["estimators"]

@property
def named_estimators(self):
"""Dictionary to access any fitted sub-estimators by name.
1 change: 0 additions & 1 deletion sklearn/model_selection/_classification_threshold.py
@@ -87,7 +87,6 @@ class BaseThresholdClassifier(ClassifierMixin, MetaEstimatorMixin, BaseEstimator
error.
"""

_required_parameters = ["estimator"]
_parameter_constraints: dict = {
"estimator": [
HasMethods(["fit", "predict_proba"]),
4 changes: 0 additions & 4 deletions sklearn/model_selection/_search.py
@@ -1532,8 +1532,6 @@ class GridSearchCV(BaseSearchCV):
'std_fit_time', 'std_score_time', 'std_test_score']
"""

_required_parameters = ["estimator", "param_grid"]

_parameter_constraints: dict = {
**BaseSearchCV._parameter_constraints,
"param_grid": [dict, list],
@@ -1913,8 +1911,6 @@ class RandomizedSearchCV(BaseSearchCV):
{'C': np.float64(2...), 'penalty': 'l1'}
"""

_required_parameters = ["estimator", "param_distributions"]

_parameter_constraints: dict = {
**BaseSearchCV._parameter_constraints,
"param_distributions": [dict, list],
7 changes: 3 additions & 4 deletions sklearn/model_selection/_search_successive_halving.py
@@ -378,6 +378,9 @@ def __sklearn_tags__(self):
"Fail during parameter check since min/max resources requires"
" more samples"
),
"check_estimators_nan_inf": "FIXME",
"check_classifiers_one_label_sample_weights": "FIXME",
"check_fit2d_1feature": "FIXME",
}
)
return tags
@@ -668,8 +671,6 @@ class HalvingGridSearchCV(BaseSuccessiveHalving):
{'max_depth': None, 'min_samples_split': 10, 'n_estimators': 9}
"""

_required_parameters = ["estimator", "param_grid"]

_parameter_constraints: dict = {
**BaseSuccessiveHalving._parameter_constraints,
"param_grid": [dict, list],
@@ -1018,8 +1019,6 @@ class HalvingRandomSearchCV(BaseSuccessiveHalving):
{'max_depth': None, 'min_samples_split': 10, 'n_estimators': 9}
"""

_required_parameters = ["estimator", "param_distributions"]

_parameter_constraints: dict = {
**BaseSuccessiveHalving._parameter_constraints,
"param_distributions": [dict, list],
13 changes: 9 additions & 4 deletions sklearn/pipeline.py
@@ -152,8 +152,6 @@ class Pipeline(_BaseComposition):
"""

# BaseEstimator interface
_required_parameters = ["steps"]

_parameter_constraints: dict = {
"steps": [list, Hidden(tuple)],
"memory": [None, str, HasMethods(["cache"])],
@@ -1427,8 +1425,6 @@ class FeatureUnion(TransformerMixin, _BaseComposition):
:ref:`sphx_glr_auto_examples_compose_plot_feature_union.py`.
"""

_required_parameters = ["transformer_list"]

def __init__(
self,
transformer_list,
@@ -1882,6 +1878,15 @@ def get_metadata_routing(self):

return router

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags._xfail_checks = {
"check_estimators_overwrite_params": "FIXME",
"check_estimators_nan_inf": "FIXME",
"check_dont_overwrite_parameters": "FIXME",
}
return tags


def make_union(*transformers, n_jobs=None, verbose=False):
"""Construct a :class:`FeatureUnion` from the given transformers.
24 changes: 8 additions & 16 deletions sklearn/tests/test_common.py
@@ -26,14 +26,15 @@
MeanShift,
SpectralClustering,
)
from sklearn.compose import ColumnTransformer
from sklearn.datasets import make_blobs
from sklearn.exceptions import ConvergenceWarning, FitFailedWarning

# make it possible to discover experimental estimators when calling `all_estimators`
from sklearn.experimental import (
enable_halving_search_cv, # noqa
enable_iterative_imputer, # noqa
)

# make it possible to discover experimental estimators when calling `all_estimators`
from sklearn.linear_model import LogisticRegression
from sklearn.manifold import TSNE, Isomap, LocallyLinearEmbedding
from sklearn.neighbors import (
@@ -43,7 +44,7 @@
RadiusNeighborsClassifier,
RadiusNeighborsRegressor,
)
from sklearn.pipeline import make_pipeline
from sklearn.pipeline import FeatureUnion, make_pipeline
from sklearn.preprocessing import (
FunctionTransformer,
MinMaxScaler,
@@ -54,11 +55,9 @@
from sklearn.utils import all_estimators
from sklearn.utils._tags import get_tags
from sklearn.utils._test_common.instance_generator import (
_generate_column_transformer_instances,
_generate_pipeline,
_generate_search_cv_instances,
_get_check_estimator_ids,
_set_checking_parameters,
_tested_estimators,
)
from sklearn.utils._testing import (
@@ -139,7 +138,6 @@ def test_estimators(estimator, check, request):
with ignore_warnings(
category=(FutureWarning, ConvergenceWarning, UserWarning, LinAlgWarning)
):
_set_checking_parameters(estimator)
check(estimator)


@@ -285,7 +283,6 @@ def check_field_types(tags, defaults):
"estimator", _tested_estimators(), ids=_get_check_estimator_ids
)
def test_check_n_features_in_after_fitting(estimator):
_set_checking_parameters(estimator)
check_n_features_in_after_fitting(estimator.__class__.__name__, estimator)


@@ -324,7 +321,8 @@ def _estimators_that_predict_in_fit():
"estimator", column_name_estimators, ids=_get_check_estimator_ids
)
def test_pandas_column_name_consistency(estimator):
_set_checking_parameters(estimator)
if isinstance(estimator, ColumnTransformer):
pytest.skip("ColumnTransformer is not tested here")
with ignore_warnings(category=(FutureWarning)):
with warnings.catch_warnings(record=True) as record:
check_dataframe_column_names_consistency(
@@ -360,7 +358,6 @@ def _include_in_get_feature_names_out_check(transformer):
"transformer", GET_FEATURES_OUT_ESTIMATORS, ids=_get_check_estimator_ids
)
def test_transformers_get_feature_names_out(transformer):
_set_checking_parameters(transformer)

with ignore_warnings(category=(FutureWarning)):
check_transformer_get_feature_names_out(
@@ -381,7 +378,6 @@
)
def test_estimators_get_feature_names_out_error(estimator):
estimator_name = estimator.__class__.__name__
_set_checking_parameters(estimator)
check_get_feature_names_out_error(estimator_name, estimator)


@@ -409,14 +405,14 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):
chain(
_tested_estimators(),
_generate_pipeline(),
_generate_column_transformer_instances(),
_generate_search_cv_instances(),
),
ids=_get_check_estimator_ids,
)
def test_check_param_validation(estimator):
if isinstance(estimator, FeatureUnion):
pytest.skip("FeatureUnion is not tested here")
name = estimator.__class__.__name__
_set_checking_parameters(estimator)
check_param_validation(name, estimator)


@@ -481,7 +477,6 @@ def test_set_output_transform(estimator):
f"Skipping check_set_output_transform for {name}: Does not support"
" set_output API"
)
_set_checking_parameters(estimator)
with ignore_warnings(category=(FutureWarning)):
check_set_output_transform(estimator.__class__.__name__, estimator)

@@ -505,7 +500,6 @@ def test_set_output_transform_configured(estimator, check_func):
f"Skipping {check_func.__name__} for {name}: Does not support"
" set_output API yet"
)
_set_checking_parameters(estimator)
with ignore_warnings(category=(FutureWarning)):
check_func(estimator.__class__.__name__, estimator)

@@ -523,8 +517,6 @@ def test_check_inplace_ensure_writeable(estimator):
else:
raise SkipTest(f"{name} doesn't require writeable input.")

_set_checking_parameters(estimator)

# The following estimators can work inplace only with certain settings
if name == "HDBSCAN":
estimator.set_params(metric="precomputed", algorithm="brute")
