From 0b0b90b75102226d27e3947b9766729325d7042c Mon Sep 17 00:00:00 2001
From: antoinebaker
Date: Fri, 6 Sep 2024 16:55:44 +0200
Subject: [PATCH] Make check_sample_weights_invariance cv-aware (#29796)

---
Reviewer notes (not part of the commit message):

* estimator_checks.py: when the estimator has a `cv` parameter,
  check_sample_weights_invariance now appends two extra copies of
  (X1, y1) to both datasets as groups 1 and 2 (with unit weights in
  sw2) and sets `cv` to explicit LeaveOneGroupOut splits, so the
  weighted / repeated samples all sit in group 0.
* The `check_sample_weights_invariance` xfail entry is removed in
  calibration.py, _logistic.py and at the end of _ridge.py; one with a
  new reason is added to the class defined just before
  RidgeClassifierCV in _ridge.py.
* NOTE(review): the `__sklearn_tags__` removed at the end of _ridge.py
  also set `tags.classifier_tags.multi_label = True`; confirm that tag
  is still set elsewhere (not visible in this patch).
* test_calibrated_classifier_cv_zeros_sample_weights_equivalence is
  deleted from test_calibration.py.

 sklearn/calibration.py            | 11 --------
 sklearn/linear_model/_logistic.py |  9 -------
 sklearn/linear_model/_ridge.py    | 19 +++++++------
 sklearn/tests/test_calibration.py | 44 -------------------------------
 sklearn/utils/estimator_checks.py | 22 +++++++++++++++-
 5 files changed, 30 insertions(+), 75 deletions(-)

diff --git a/sklearn/calibration.py b/sklearn/calibration.py
index 2d49d6b1c3521..8b053f5382782 100644
--- a/sklearn/calibration.py
+++ b/sklearn/calibration.py
@@ -540,17 +540,6 @@ def get_metadata_routing(self):
         )
         return router
 
-    def __sklearn_tags__(self):
-        tags = super().__sklearn_tags__()
-        tags._xfail_checks = {
-            "check_sample_weights_invariance": (
-                "Due to the cross-validation and sample ordering, removing a sample"
-                " is not strictly equal to putting is weight to zero. Specific unit"
-                " tests are added for CalibratedClassifierCV specifically."
-            ),
-        }
-        return tags
-
 
 def _fit_classifier_calibrator_pair(
     estimator,
diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py
index 2645ab9b81d18..b3ef71539e996 100644
--- a/sklearn/linear_model/_logistic.py
+++ b/sklearn/linear_model/_logistic.py
@@ -2269,15 +2269,6 @@ def get_metadata_routing(self):
         )
         return router
 
-    def __sklearn_tags__(self):
-        tags = super().__sklearn_tags__()
-        tags._xfail_checks = {
-            "check_sample_weights_invariance": (
-                "zero sample_weight is not equivalent to removing samples"
-            ),
-        }
-        return tags
-
     def _get_scorer(self):
         """Get the scorer based on the scoring method specified.
         The default scoring method is `accuracy`.
diff --git a/sklearn/linear_model/_ridge.py b/sklearn/linear_model/_ridge.py
index 8e7b39d9b4ba1..827366fab2a25 100644
--- a/sklearn/linear_model/_ridge.py
+++ b/sklearn/linear_model/_ridge.py
@@ -2682,6 +2682,15 @@ def fit(self, X, y, sample_weight=None, **params):
         super().fit(X, y, sample_weight=sample_weight, **params)
         return self
 
+    def __sklearn_tags__(self):
+        tags = super().__sklearn_tags__()
+        tags._xfail_checks = {
+            "check_sample_weights_invariance": (
+                "GridSearchCV does not forward the weights to the scorer by default."
+            ),
+        }
+        return tags
+
 
 class RidgeClassifierCV(_RidgeClassifierMixin, _BaseRidgeCV):
     """Ridge classifier with built-in cross-validation.
@@ -2891,13 +2900,3 @@ def fit(self, X, y, sample_weight=None, **params):
         target = Y if self.cv is None else y
         super().fit(X, target, sample_weight=sample_weight, **params)
         return self
-
-    def __sklearn_tags__(self):
-        tags = super().__sklearn_tags__()
-        tags.classifier_tags.multi_label = True
-        tags._xfail_checks = {
-            "check_sample_weights_invariance": (
-                "zero sample_weight is not equivalent to removing samples"
-            ),
-        }
-        return tags
diff --git a/sklearn/tests/test_calibration.py b/sklearn/tests/test_calibration.py
index b80083f3eac0d..d92512e42dc68 100644
--- a/sklearn/tests/test_calibration.py
+++ b/sklearn/tests/test_calibration.py
@@ -941,50 +941,6 @@ def fit(self, X, y, **fit_params):
     pc_clf.fit(X, y, sample_weight=sample_weight)
 
 
-@pytest.mark.parametrize("method", ["sigmoid", "isotonic"])
-@pytest.mark.parametrize("ensemble", [True, False])
-def test_calibrated_classifier_cv_zeros_sample_weights_equivalence(method, ensemble):
-    """Check that passing removing some sample from the dataset `X` is
-    equivalent to passing a `sample_weight` with a factor 0."""
-    X, y = load_iris(return_X_y=True)
-    # Scale the data to avoid any convergence issue
-    X = StandardScaler().fit_transform(X)
-    # Only use 2 classes and select samples such that 2-fold cross-validation
-    # split will lead to an equivalence with a `sample_weight` of 0
-    X = np.vstack((X[:40], X[50:90]))
-    y = np.hstack((y[:40], y[50:90]))
-    sample_weight = np.zeros_like(y)
-    sample_weight[::2] = 1
-
-    estimator = LogisticRegression()
-    calibrated_clf_without_weights = CalibratedClassifierCV(
-        estimator,
-        method=method,
-        ensemble=ensemble,
-        cv=2,
-    )
-    calibrated_clf_with_weights = clone(calibrated_clf_without_weights)
-
-    calibrated_clf_with_weights.fit(X, y, sample_weight=sample_weight)
-    calibrated_clf_without_weights.fit(X[::2], y[::2])
-
-    # Check that the underlying fitted estimators have the same coefficients
-    for est_with_weights, est_without_weights in zip(
-        calibrated_clf_with_weights.calibrated_classifiers_,
-        calibrated_clf_without_weights.calibrated_classifiers_,
-    ):
-        assert_allclose(
-            est_with_weights.estimator.coef_,
-            est_without_weights.estimator.coef_,
-        )
-
-    # Check that the predictions are the same
-    y_pred_with_weights = calibrated_clf_with_weights.predict_proba(X)
-    y_pred_without_weights = calibrated_clf_without_weights.predict_proba(X)
-
-    assert_allclose(y_pred_with_weights, y_pred_without_weights)
-
-
 def test_calibration_with_non_sample_aligned_fit_param(data):
     """Check that CalibratedClassifierCV does not enforce sample alignment
     for fit parameters."""
diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index 4680c987b4b3a..6da7c8eb1c7ff 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -36,7 +36,7 @@
 from ..linear_model._base import LinearClassifierMixin
 from ..metrics import accuracy_score, adjusted_rand_score, f1_score
 from ..metrics.pairwise import linear_kernel, pairwise_distances, rbf_kernel
-from ..model_selection import ShuffleSplit, train_test_split
+from ..model_selection import LeaveOneGroupOut, ShuffleSplit, train_test_split
 from ..model_selection._validation import _safe_split
 from ..pipeline import make_pipeline
 from ..preprocessing import StandardScaler, scale
@@ -1108,6 +1108,26 @@ def check_sample_weights_invariance(name, estimator_orig, kind="ones"):
     else:  # pragma: no cover
         raise ValueError
 
+    # when the estimator has an internal CV scheme
+    # we only use weights / repetitions in a specific CV group (here group=0)
+    if "cv" in estimator_orig.get_params():
+        groups2 = np.hstack(
+            [np.full_like(y2, 0), np.full_like(y1, 1), np.full_like(y1, 2)]
+        )
+        sw2 = np.hstack([sw2, np.ones_like(y1), np.ones_like(y1)])
+        X2 = np.vstack([X2, X1, X1])
+        y2 = np.hstack([y2, y1, y1])
+        splits2 = list(LeaveOneGroupOut().split(X2, groups=groups2))
+        estimator2.set_params(cv=splits2)
+
+        groups1 = np.hstack(
+            [np.full_like(y1, 0), np.full_like(y1, 1), np.full_like(y1, 2)]
+        )
+        X1 = np.vstack([X1, X1, X1])
+        y1 = np.hstack([y1, y1, y1])
+        splits1 = list(LeaveOneGroupOut().split(X1, groups=groups1))
+        estimator1.set_params(cv=splits1)
+
     y1 = _enforce_estimator_tags_y(estimator1, y1)
     y2 = _enforce_estimator_tags_y(estimator2, y2)
 