Skip to content

Commit

Permalink
Make check_sample_weights_invariance cv-aware (scikit-learn#29796)
Browse files Browse the repository at this point in the history
  • Loading branch information
antoinebaker authored Sep 6, 2024
1 parent 7baa11e commit 0b0b90b
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 75 deletions.
11 changes: 0 additions & 11 deletions sklearn/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,17 +540,6 @@ def get_metadata_routing(self):
)
return router

def __sklearn_tags__(self):
    """Return the parent's tags with an entry added to ``_xfail_checks``.

    The tags are obtained from ``super().__sklearn_tags__()``; their
    ``_xfail_checks`` attribute is then replaced by a dict mapping
    ``"check_sample_weights_invariance"`` to the reason string below.
    """
    # Reason recorded for `check_sample_weights_invariance`.
    xfail_reason = (
        "Due to the cross-validation and sample ordering, removing a sample"
        " is not strictly equal to putting is weight to zero. Specific unit"
        " tests are added for CalibratedClassifierCV specifically."
    )
    tags = super().__sklearn_tags__()
    tags._xfail_checks = {"check_sample_weights_invariance": xfail_reason}
    return tags


def _fit_classifier_calibrator_pair(
estimator,
Expand Down
9 changes: 0 additions & 9 deletions sklearn/linear_model/_logistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2269,15 +2269,6 @@ def get_metadata_routing(self):
)
return router

def __sklearn_tags__(self):
    """Return the parent's tags with an entry added to ``_xfail_checks``.

    The tags are obtained from ``super().__sklearn_tags__()``; their
    ``_xfail_checks`` attribute is then replaced by a dict mapping
    ``"check_sample_weights_invariance"`` to the reason string below.
    """
    # Reason recorded for `check_sample_weights_invariance`.
    xfail_reason = "zero sample_weight is not equivalent to removing samples"
    tags = super().__sklearn_tags__()
    tags._xfail_checks = {"check_sample_weights_invariance": xfail_reason}
    return tags

def _get_scorer(self):
"""Get the scorer based on the scoring method specified.
The default scoring method is `accuracy`.
Expand Down
19 changes: 9 additions & 10 deletions sklearn/linear_model/_ridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2682,6 +2682,15 @@ def fit(self, X, y, sample_weight=None, **params):
super().fit(X, y, sample_weight=sample_weight, **params)
return self

def __sklearn_tags__(self):
    """Return the parent's tags with an entry added to ``_xfail_checks``.

    The tags are obtained from ``super().__sklearn_tags__()``; their
    ``_xfail_checks`` attribute is then replaced by a dict mapping
    ``"check_sample_weights_invariance"`` to the reason string below.
    """
    # Reason recorded for `check_sample_weights_invariance`.
    xfail_reason = (
        "GridSearchCV does not forward the weights to the scorer by default."
    )
    tags = super().__sklearn_tags__()
    tags._xfail_checks = {"check_sample_weights_invariance": xfail_reason}
    return tags


class RidgeClassifierCV(_RidgeClassifierMixin, _BaseRidgeCV):
"""Ridge classifier with built-in cross-validation.
Expand Down Expand Up @@ -2891,13 +2900,3 @@ def fit(self, X, y, sample_weight=None, **params):
target = Y if self.cv is None else y
super().fit(X, target, sample_weight=sample_weight, **params)
return self

def __sklearn_tags__(self):
    """Return the parent's tags, adjusted for this estimator.

    Starting from ``super().__sklearn_tags__()``, this sets
    ``classifier_tags.multi_label`` to ``True`` and replaces
    ``_xfail_checks`` by a dict mapping
    ``"check_sample_weights_invariance"`` to the reason string below.
    """
    # Reason recorded for `check_sample_weights_invariance`.
    xfail_reason = "zero sample_weight is not equivalent to removing samples"
    tags = super().__sklearn_tags__()
    tags.classifier_tags.multi_label = True
    tags._xfail_checks = {"check_sample_weights_invariance": xfail_reason}
    return tags
44 changes: 0 additions & 44 deletions sklearn/tests/test_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,50 +941,6 @@ def fit(self, X, y, **fit_params):
pc_clf.fit(X, y, sample_weight=sample_weight)


@pytest.mark.parametrize("method", ["sigmoid", "isotonic"])
@pytest.mark.parametrize("ensemble", [True, False])
def test_calibrated_classifier_cv_zeros_sample_weights_equivalence(method, ensemble):
    """Check that removing samples from the dataset `X` is equivalent to
    giving those samples a `sample_weight` of 0.

    One `CalibratedClassifierCV` is fitted on the full data with every other
    weight set to zero, and a clone is fitted on every other sample without
    weights. Both the coefficients of the underlying fitted estimators and
    the predicted probabilities are compared.
    """
    X, y = load_iris(return_X_y=True)
    # Scale the data to avoid any convergence issue
    X = StandardScaler().fit_transform(X)
    # Only use 2 classes and select samples such that 2-fold cross-validation
    # split will lead to an equivalence with a `sample_weight` of 0
    X = np.concatenate([X[:40], X[50:90]], axis=0)
    y = np.concatenate([y[:40], y[50:90]])
    # Weight 1 on even positions, 0 on odd positions, same dtype as `y`.
    sample_weight = (np.arange(y.shape[0]) % 2 == 0).astype(y.dtype)

    clf_subsampled = CalibratedClassifierCV(
        LogisticRegression(),
        method=method,
        ensemble=ensemble,
        cv=2,
    )
    clf_weighted = clone(clf_subsampled)

    clf_weighted.fit(X, y, sample_weight=sample_weight)
    clf_subsampled.fit(X[::2], y[::2])

    # Check that the underlying fitted estimators have the same coefficients
    fitted_pairs = zip(
        clf_weighted.calibrated_classifiers_,
        clf_subsampled.calibrated_classifiers_,
    )
    for weighted_pair, subsampled_pair in fitted_pairs:
        assert_allclose(
            weighted_pair.estimator.coef_,
            subsampled_pair.estimator.coef_,
        )

    # Check that the predictions are the same
    proba_weighted = clf_weighted.predict_proba(X)
    proba_subsampled = clf_subsampled.predict_proba(X)

    assert_allclose(proba_weighted, proba_subsampled)


def test_calibration_with_non_sample_aligned_fit_param(data):
"""Check that CalibratedClassifierCV does not enforce sample alignment
for fit parameters."""
Expand Down
22 changes: 21 additions & 1 deletion sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from ..linear_model._base import LinearClassifierMixin
from ..metrics import accuracy_score, adjusted_rand_score, f1_score
from ..metrics.pairwise import linear_kernel, pairwise_distances, rbf_kernel
from ..model_selection import ShuffleSplit, train_test_split
from ..model_selection import LeaveOneGroupOut, ShuffleSplit, train_test_split
from ..model_selection._validation import _safe_split
from ..pipeline import make_pipeline
from ..preprocessing import StandardScaler, scale
Expand Down Expand Up @@ -1108,6 +1108,26 @@ def check_sample_weights_invariance(name, estimator_orig, kind="ones"):
else: # pragma: no cover
raise ValueError

# when the estimator has an internal CV scheme
# we only use weights / repetitions in a specific CV group (here group=0)
if "cv" in estimator_orig.get_params():
groups2 = np.hstack(
[np.full_like(y2, 0), np.full_like(y1, 1), np.full_like(y1, 2)]
)
sw2 = np.hstack([sw2, np.ones_like(y1), np.ones_like(y1)])
X2 = np.vstack([X2, X1, X1])
y2 = np.hstack([y2, y1, y1])
splits2 = list(LeaveOneGroupOut().split(X2, groups=groups2))
estimator2.set_params(cv=splits2)

groups1 = np.hstack(
[np.full_like(y1, 0), np.full_like(y1, 1), np.full_like(y1, 2)]
)
X1 = np.vstack([X1, X1, X1])
y1 = np.hstack([y1, y1, y1])
splits1 = list(LeaveOneGroupOut().split(X1, groups=groups1))
estimator1.set_params(cv=splits1)

y1 = _enforce_estimator_tags_y(estimator1, y1)
y2 = _enforce_estimator_tags_y(estimator2, y2)

Expand Down

0 comments on commit 0b0b90b

Please sign in to comment.