From 7baa11e938a9f27bf867be395345ce004ef30497 Mon Sep 17 00:00:00 2001
From: Shruti Nath <51656807+snath-xoc@users.noreply.github.com>
Date: Fri, 6 Sep 2024 13:31:06 +0200
Subject: [PATCH] add sample weight to default scoring _log_reg_scoring_path
 (#29419)

Co-authored-by: Shruti Nath
Co-authored-by: Olivier Grisel
Co-authored-by: Omar Salman
---
 doc/whats_new/v1.6.rst                      |   4 +
 sklearn/linear_model/_logistic.py           |   9 +-
 sklearn/linear_model/tests/test_logistic.py | 170 +++++++++++++++-----
 3 files changed, 135 insertions(+), 48 deletions(-)

diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst
index 2021e9bb8ccc0..3a0d43d9e47fb 100644
--- a/doc/whats_new/v1.6.rst
+++ b/doc/whats_new/v1.6.rst
@@ -253,6 +253,10 @@ Changelog
 :mod:`sklearn.linear_model`
 ...........................
 
+- |Fix| :class:`linear_model.LogisticRegressionCV` corrects sample weight handling
+  for the calculation of test scores.
+  :pr:`29419` by :user:`Shruti Nath <snath-xoc>`.
+
 - |API| Deprecates `copy_X` in :class:`linear_model.TheilSenRegressor` as the parameter
   has no effect. `copy_X` will be removed in 1.8.
   :pr:`29105` by :user:`Adam Li <adam2392>`.

diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py
index 9dcf34e2a4aeb..2645ab9b81d18 100644
--- a/sklearn/linear_model/_logistic.py
+++ b/sklearn/linear_model/_logistic.py
@@ -737,9 +737,11 @@ def _log_reg_scoring_path(
     y_train = y[train]
     y_test = y[test]
 
+    sw_train, sw_test = None, None
     if sample_weight is not None:
         sample_weight = _check_sample_weight(sample_weight, X)
-        sample_weight = sample_weight[train]
+        sw_train = sample_weight[train]
+        sw_test = sample_weight[test]
 
     coefs, Cs, n_iter = _logistic_regression_path(
         X_train,
@@ -760,7 +762,7 @@ def _log_reg_scoring_path(
         random_state=random_state,
         check_input=False,
         max_squared_sum=max_squared_sum,
-        sample_weight=sample_weight,
+        sample_weight=sw_train,
     )
 
     log_reg = LogisticRegression(solver=solver, multi_class=multi_class)
@@ -794,12 +796,11 @@ def _log_reg_scoring_path(
         log_reg.intercept_ = 0.0
 
     if scoring is None:
-        scores.append(log_reg.score(X_test, y_test))
+        scores.append(log_reg.score(X_test, y_test, sample_weight=sw_test))
     else:
         score_params = score_params or {}
         score_params = _check_method_params(X=X, params=score_params, indices=test)
         scores.append(scoring(log_reg, X_test, y_test, **score_params))
-
     return coefs, Cs, np.array(scores), n_iter

diff --git a/sklearn/linear_model/tests/test_logistic.py b/sklearn/linear_model/tests/test_logistic.py
index f47b06a0cc4b1..c8c98c80f67c3 100644
--- a/sklearn/linear_model/tests/test_logistic.py
+++ b/sklearn/linear_model/tests/test_logistic.py
@@ -31,6 +31,7 @@
 from sklearn.metrics import get_scorer, log_loss
 from sklearn.model_selection import (
     GridSearchCV,
+    LeaveOneGroupOut,
     StratifiedKFold,
     cross_val_score,
     train_test_split,
@@ -775,86 +776,167 @@ def test_logistic_regressioncv_class_weights(weight, class_weight, global_random
     )
 
 
-def test_logistic_regression_sample_weights():
+@pytest.mark.parametrize("problem", ("single", "cv"))
+@pytest.mark.parametrize(
+    "solver", ("lbfgs", "liblinear", "newton-cg", "newton-cholesky", "sag", "saga")
+)
+def test_logistic_regression_sample_weights(problem, solver, global_random_seed):
+    n_samples_per_cv_group = 200
+    n_cv_groups = 3
+
     X, y = make_classification(
-        n_samples=20, n_features=5, n_informative=3, n_classes=2, random_state=0
+        n_samples=n_samples_per_cv_group * n_cv_groups,
+        n_features=5,
+        n_informative=3,
+        n_classes=2,
+        n_redundant=0,
+        random_state=global_random_seed,
     )
+    rng = np.random.RandomState(global_random_seed)
+    sw = np.ones(y.shape[0])
+
+    kw_weighted = {
+        "random_state": global_random_seed,
+        "fit_intercept": False,
+        "max_iter": 100_000 if solver.startswith("sag") else 1_000,
+        "tol": 1e-8,
+    }
+    kw_repeated = kw_weighted.copy()
+    sw[:n_samples_per_cv_group] = rng.randint(0, 5, size=n_samples_per_cv_group)
+    X_repeated = np.repeat(X, sw.astype(int), axis=0)
+    y_repeated = np.repeat(y, sw.astype(int), axis=0)
+
+    if problem == "single":
+        LR = LogisticRegression
+    elif problem == "cv":
+        LR = LogisticRegressionCV
+        # Split the samples into three contiguous CV groups; only the first
+        # group carries the random integer weights assigned above.
+        groups_weighted = np.concatenate(
+            [
+                np.full(n_samples_per_cv_group, 0),
+                np.full(n_samples_per_cv_group, 1),
+                np.full(n_samples_per_cv_group, 2),
+            ]
+        )
+        splits_weighted = list(LeaveOneGroupOut().split(X, groups=groups_weighted))
+        kw_weighted.update({"Cs": 100, "cv": splits_weighted})
+
+        groups_repeated = np.repeat(groups_weighted, sw.astype(int), axis=0)
+        splits_repeated = list(
+            LeaveOneGroupOut().split(X_repeated, groups=groups_repeated)
+        )
+        kw_repeated.update({"Cs": 100, "cv": splits_repeated})
+
+    clf_sw_weighted = LR(solver=solver, **kw_weighted)
+    clf_sw_repeated = LR(solver=solver, **kw_repeated)
+
+    if solver == "lbfgs":
+        # lbfgs has convergence issues on this data but that should not impact
+        # the quality of the results.
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", ConvergenceWarning)
+            clf_sw_weighted.fit(X, y, sample_weight=sw)
+            clf_sw_repeated.fit(X_repeated, y_repeated)
+    else:
+        clf_sw_weighted.fit(X, y, sample_weight=sw)
+        clf_sw_repeated.fit(X_repeated, y_repeated)
+
+    if problem == "cv":
+        assert_allclose(clf_sw_weighted.scores_[1], clf_sw_repeated.scores_[1])
+    assert_allclose(clf_sw_weighted.coef_, clf_sw_repeated.coef_, atol=1e-5)
+
+
+@pytest.mark.parametrize(
+    "solver", ("lbfgs", "newton-cg", "newton-cholesky", "sag", "saga")
+)
+def test_logistic_regression_solver_class_weights(solver, global_random_seed):
+    # Test that passing class_weight={0: 1, 1: 2} is the same as passing
+    # uniform class weights while doubling the sample weight of every
+    # sample from class 1.
+
+    X, y = make_classification(
+        n_samples=300,
+        n_features=5,
+        n_informative=3,
+        n_classes=2,
+        random_state=global_random_seed,
+    )
+    sample_weight = y + 1
-    for LR in [LogisticRegression, LogisticRegressionCV]:
-        kw = {"random_state": 42, "fit_intercept": False}
-        if LR is LogisticRegressionCV:
-            kw.update({"Cs": 3, "cv": 3})
-
-        # Test that passing sample_weight as ones is the same as
-        # not passing them at all (default None)
-        for solver in ["lbfgs", "liblinear"]:
-            clf_sw_none = LR(solver=solver, **kw)
-            clf_sw_ones = LR(solver=solver, **kw)
-            clf_sw_none.fit(X, y)
-            clf_sw_ones.fit(X, y, sample_weight=np.ones(y.shape[0]))
-            assert_allclose(clf_sw_none.coef_, clf_sw_ones.coef_, rtol=1e-4)
-
-        # Test that sample weights work the same with the lbfgs,
-        # newton-cg, newton-cholesky and 'sag' solvers
-        clf_sw_lbfgs = LR(**kw, tol=1e-5)
-        clf_sw_lbfgs.fit(X, y, sample_weight=sample_weight)
-        for solver in set(SOLVERS) - set(["lbfgs"]):
-            clf_sw = LR(solver=solver, tol=1e-10 if solver == "sag" else 1e-5, **kw)
-            # ignore convergence warning due to small dataset with sag
-            with ignore_warnings():
-                clf_sw.fit(X, y, sample_weight=sample_weight)
-            assert_allclose(clf_sw_lbfgs.coef_, clf_sw.coef_, rtol=1e-4)
-
-        # Test that passing class_weight as [1,2] is the same as
-        # passing class weight = [1,1] but adjusting sample weights
-        # to be 2 for all instances of class 2
-        for solver in ["lbfgs", "liblinear"]:
-            clf_cw_12 = LR(solver=solver, class_weight={0: 1, 1: 2}, **kw)
-            clf_cw_12.fit(X, y)
-            clf_sw_12 = LR(solver=solver, **kw)
-            clf_sw_12.fit(X, y, sample_weight=sample_weight)
-            assert_allclose(clf_cw_12.coef_, clf_sw_12.coef_, rtol=1e-4)
+    kw_weighted = {
+        "random_state": global_random_seed,
+        "fit_intercept": False,
+        "max_iter": 100_000,
+        "tol": 1e-8,
+    }
+    clf_cw_12 = LogisticRegression(
+        solver=solver, class_weight={0: 1, 1: 2}, **kw_weighted
+    )
+    clf_cw_12.fit(X, y)
+    clf_sw_12 = LogisticRegression(solver=solver, **kw_weighted)
+    clf_sw_12.fit(X, y, sample_weight=sample_weight)
+    assert_allclose(clf_cw_12.coef_, clf_sw_12.coef_, atol=1e-6)
+
+
+def test_sample_and_class_weight_equivalence_liblinear(global_random_seed):
     # Test the above equivalence for the l1 penalty, and for the l2 penalty
     # with dual=True, since the patched liblinear code paths are different.
+ + X, y = make_classification( + n_samples=300, + n_features=5, + n_informative=3, + n_classes=2, + random_state=global_random_seed, + ) + + sample_weight = y + 1 + clf_cw = LogisticRegression( solver="liblinear", fit_intercept=False, class_weight={0: 1, 1: 2}, penalty="l1", - tol=1e-5, - random_state=42, + max_iter=10_000, + tol=1e-12, + random_state=global_random_seed, ) clf_cw.fit(X, y) clf_sw = LogisticRegression( solver="liblinear", fit_intercept=False, penalty="l1", - tol=1e-5, - random_state=42, + max_iter=10_000, + tol=1e-12, + random_state=global_random_seed, ) clf_sw.fit(X, y, sample_weight) - assert_array_almost_equal(clf_cw.coef_, clf_sw.coef_, decimal=4) + assert_allclose(clf_cw.coef_, clf_sw.coef_, atol=1e-10) clf_cw = LogisticRegression( solver="liblinear", fit_intercept=False, class_weight={0: 1, 1: 2}, penalty="l2", + max_iter=10_000, + tol=1e-12, dual=True, - random_state=42, + random_state=global_random_seed, ) clf_cw.fit(X, y) clf_sw = LogisticRegression( solver="liblinear", fit_intercept=False, penalty="l2", + max_iter=10_000, + tol=1e-12, dual=True, - random_state=42, + random_state=global_random_seed, ) clf_sw.fit(X, y, sample_weight) - assert_array_almost_equal(clf_cw.coef_, clf_sw.coef_, decimal=4) + assert_allclose(clf_cw.coef_, clf_sw.coef_, atol=1e-10) def _compute_class_weight_dictionary(y):
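
Note (editorial, not part of the patch): the following standalone sketch
illustrates the behavior this fix guarantees. With the default scorer,
LogisticRegressionCV now applies the held-out fold's sample weights when
computing test scores, so weighting a sample by an integer k behaves like
repeating that sample k times. Everything in the sketch (dataset sizes, seed,
Cs=10, the three LeaveOneGroupOut groups, all variable names) is an arbitrary
choice mirroring the new test, not something mandated by the patch, and it
assumes a scikit-learn build that includes this change.

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegressionCV
    from sklearn.model_selection import LeaveOneGroupOut

    rng = np.random.RandomState(0)
    X, y = make_classification(
        n_samples=600, n_features=5, n_informative=3, n_redundant=0, random_state=0
    )

    # Random integer weights on the first third of the rows; weight 0 drops a row.
    sw = np.ones(y.shape[0])
    sw[:200] = rng.randint(0, 5, size=200)

    # Three contiguous CV groups used as leave-one-group-out splits.
    groups = np.repeat([0, 1, 2], 200)
    cv_weighted = list(LeaveOneGroupOut().split(X, groups=groups))

    # Materialize the weights by repeating each row sw[i] times.
    X_rep = np.repeat(X, sw.astype(int), axis=0)
    y_rep = np.repeat(y, sw.astype(int), axis=0)
    cv_rep = list(
        LeaveOneGroupOut().split(X_rep, groups=np.repeat(groups, sw.astype(int)))
    )

    common = dict(Cs=10, fit_intercept=False, tol=1e-8, max_iter=10_000)
    clf_w = LogisticRegressionCV(cv=cv_weighted, **common).fit(X, y, sample_weight=sw)
    clf_r = LogisticRegressionCV(cv=cv_rep, **common).fit(X_rep, y_rep)

    # With the fix, the per-fold test scores match between the two fits;
    # previously the default scoring ignored the test-fold weights, which
    # could also change the selected C_.
    np.testing.assert_allclose(clf_w.scores_[1], clf_r.scores_[1])
    np.testing.assert_allclose(clf_w.coef_, clf_r.coef_, atol=1e-5)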
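A similar sketch for the rewritten class-weight tests (same caveats: editorial,
arbitrary parameter values): under uniform class weights, giving every class-1
row a sample weight of 2 is equivalent to passing class_weight={0: 1, 1: 2}.

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    X, y = make_classification(
        n_samples=300, n_features=5, n_informative=3, random_state=0
    )
    sample_weight = y + 1  # weight 1 for class 0, weight 2 for class 1

    kw = dict(solver="lbfgs", fit_intercept=False, max_iter=100_000, tol=1e-8)
    clf_cw = LogisticRegression(class_weight={0: 1, 1: 2}, **kw).fit(X, y)
    clf_sw = LogisticRegression(**kw).fit(X, y, sample_weight=sample_weight)

    # The two parameterizations produce the same coefficients.
    np.testing.assert_allclose(clf_cw.coef_, clf_sw.coef_, atol=1e-6)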