From 575b8d1b2cfcd89c401383dec3487a0586d8c95f Mon Sep 17 00:00:00 2001 From: mikesokolovv Date: Tue, 13 Aug 2024 11:11:41 +0300 Subject: [PATCH 01/12] additional model configs added --- rectools/models/base.py | 15 ++++++++++++++ rectools/models/ease.py | 25 ++++++++++++++++++++++- rectools/models/implicit_als.py | 15 +------------- rectools/models/pure_svd.py | 34 +++++++++++++++++++++++++++++++- rectools/models/random.py | 23 +++++++++++++++++++++- tests/models/test_ease.py | 25 +++++++++++++++++++++++ tests/models/test_pure_svd.py | 35 +++++++++++++++++++++++++++++++++ tests/models/test_random.py | 23 ++++++++++++++++++++++ 8 files changed, 178 insertions(+), 17 deletions(-) diff --git a/rectools/models/base.py b/rectools/models/base.py index 98c9e46c..2339177b 100644 --- a/rectools/models/base.py +++ b/rectools/models/base.py @@ -19,6 +19,7 @@ import numpy as np import pandas as pd import typing_extensions as tpe +from pydantic import PlainSerializer from pydantic_core import PydanticSerializationError from rectools import AnyIds, Columns, InternalIds @@ -40,6 +41,20 @@ RecoTriplet_T = tp.TypeVar("RecoTriplet_T", InternalRecoTriplet, SemiInternalRecoTriplet, RecoTriplet) +def _serialize_random_state(rs: tp.Optional[tp.Union[None, int, np.random.RandomState]]) -> tp.Union[None, int]: + if rs is None or isinstance(rs, int): + return rs + + # NOBUG: We can add serialization using get/set_state, but it's not human readable + raise TypeError("`random_state` must be ``None`` or have ``int`` type to convert it to simple type") + + +RandomState = tpe.Annotated[ + tp.Union[None, int, np.random.RandomState], + PlainSerializer(func=_serialize_random_state, when_used="json"), +] + + class ModelConfig(BaseConfig): """Base model config.""" diff --git a/rectools/models/ease.py b/rectools/models/ease.py index 36e90135..212db172 100644 --- a/rectools/models/ease.py +++ b/rectools/models/ease.py @@ -17,17 +17,26 @@ import typing as tp import numpy as np +import typing_extensions as tpe from scipy import sparse from rectools import InternalIds from rectools.dataset import Dataset +from rectools.models.base import ModelConfig from rectools.types import InternalIdsArray from .base import ModelBase, Scores from .rank import Distance, ImplicitRanker -class EASEModel(ModelBase): +class EASEModelConfig(ModelConfig): + """Config for `EASE` model.""" + + regularization: float = 500.0 + num_threads: int = 1 + + +class EASEModel(ModelBase[EASEModelConfig]): """ Embarrassingly Shallow Autoencoders for Sparse Data model. @@ -51,17 +60,31 @@ class EASEModel(ModelBase): recommends_for_warm = False recommends_for_cold = False + config_class = EASEModelConfig + def __init__( self, regularization: float = 500.0, num_threads: int = 1, verbose: int = 0, ): + self._config = self._make_config(regularization, num_threads, verbose) + super().__init__(verbose=verbose) self.weight: np.ndarray self.regularization = regularization self.num_threads = num_threads + def _make_config(self, regularization: float, num_threads: int, verbose: int) -> EASEModelConfig: + return EASEModelConfig(regularization=regularization, num_threads=num_threads, verbose=verbose) + + def _get_config(self) -> EASEModelConfig: + return self._config + + @classmethod + def _from_config(cls, config: EASEModelConfig) -> tpe.Self: + return cls(regularization=config.regularization, num_threads=config.num_threads, verbose=config.verbose) + def _fit(self, dataset: Dataset) -> None: # type: ignore ui_csr = dataset.get_user_item_matrix(include_weights=True) diff --git a/rectools/models/implicit_als.py b/rectools/models/implicit_als.py index 6fcc1548..6cc55768 100644 --- a/rectools/models/implicit_als.py +++ b/rectools/models/implicit_als.py @@ -32,6 +32,7 @@ from rectools.utils.config import BaseConfig from rectools.utils.misc import get_class_or_function_full_path, import_object +from .base import RandomState from .rank import Distance from .vector import Factors, VectorModel @@ -68,20 +69,6 @@ def _serialize_alternating_least_squares_class( ] -def _serialize_random_state(rs: tp.Optional[tp.Union[None, int, np.random.RandomState]]) -> tp.Union[None, int]: - if rs is None or isinstance(rs, int): - return rs - - # NOBUG: We can add serialization using get/set_state, but it's not human readable - raise TypeError("`random_state` must be ``None`` or have ``int`` type to convert it to simple type") - - -RandomState = tpe.Annotated[ - tp.Union[None, int, np.random.RandomState], - PlainSerializer(func=_serialize_random_state, when_used="json"), -] - - class AlternatingLeastSquaresParams(tpe.TypedDict): """Params for implicit `AlternatingLeastSquares` model.""" diff --git a/rectools/models/pure_svd.py b/rectools/models/pure_svd.py index 9ef9f874..c9ca79e4 100644 --- a/rectools/models/pure_svd.py +++ b/rectools/models/pure_svd.py @@ -17,15 +17,26 @@ import typing as tp import numpy as np +import typing_extensions as tpe from scipy.sparse.linalg import svds from rectools.dataset import Dataset from rectools.exceptions import NotFittedError +from rectools.models.base import ModelConfig from rectools.models.rank import Distance from rectools.models.vector import Factors, VectorModel -class PureSVDModel(VectorModel): +class PureSVDModelConfig(ModelConfig): + """Config for `PureSVD` model.""" + + factors: int = 10 + tol: float = 0 + maxiter: tp.Optional[int] = None + random_state: tp.Optional[int] = None + + +class PureSVDModel(VectorModel[PureSVDModelConfig]): """ PureSVD matrix factorization model. @@ -51,6 +62,8 @@ class PureSVDModel(VectorModel): u2i_dist = Distance.DOT i2i_dist = Distance.COSINE + config_class = PureSVDModelConfig + def __init__( self, factors: int = 10, @@ -59,6 +72,7 @@ def __init__( random_state: tp.Optional[int] = None, verbose: int = 0, ): + self._config = self._make_config(factors, tol, maxiter, random_state, verbose) super().__init__(verbose=verbose) self.factors = factors @@ -69,6 +83,24 @@ def __init__( self.user_factors: np.ndarray self.item_factors: np.ndarray + def _make_config( + self, factors: int, tol: float, maxiter: tp.Optional[int], random_state: tp.Optional[int], verbose: int + ) -> PureSVDModelConfig: + return PureSVDModelConfig(factors=factors, tol=tol, maxiter=maxiter, random_state=random_state, verbose=verbose) + + def _get_config(self) -> PureSVDModelConfig: + return self._config + + @classmethod + def _from_config(cls, config: PureSVDModelConfig) -> tpe.Self: + return cls( + factors=config.factors, + tol=config.tol, + maxiter=config.maxiter, + random_state=config.random_state, + verbose=config.verbose, + ) + def _fit(self, dataset: Dataset) -> None: # type: ignore ui_csr = dataset.get_user_item_matrix(include_weights=True) diff --git a/rectools/models/random.py b/rectools/models/random.py index df84f2b6..8a9820af 100644 --- a/rectools/models/random.py +++ b/rectools/models/random.py @@ -18,10 +18,12 @@ import typing as tp import numpy as np +import typing_extensions as tpe from tqdm.auto import tqdm from rectools import InternalIds from rectools.dataset import Dataset +from rectools.models.base import ModelConfig from rectools.types import AnyIdsArray, InternalId, InternalIdsArray from rectools.utils import fast_isin_for_sorted_test_elements @@ -50,7 +52,13 @@ def sample(self, n: int) -> np.ndarray: return sampled -class RandomModel(ModelBase): +class RandomModelConfig(ModelConfig): + """Config for `Random` model.""" + + random_state: tp.Optional[int] = None + + +class RandomModel(ModelBase[RandomModelConfig]): """ Model generating random recommendations. @@ -70,13 +78,26 @@ class RandomModel(ModelBase): recommends_for_warm = False recommends_for_cold = True + config_class = RandomModelConfig + def __init__(self, random_state: tp.Optional[int] = None, verbose: int = 0): + self._config = self._make_config(random_state, verbose) super().__init__(verbose=verbose) self.random_state = random_state self.random_gen = _RandomGen(random_state) self.all_item_ids: np.ndarray + def _make_config(self, random_state: tp.Optional[int], verbose: int) -> RandomModelConfig: + return RandomModelConfig(random_state=random_state, verbose=verbose) + + def _get_config(self) -> RandomModelConfig: + return self._config + + @classmethod + def _from_config(cls, config: RandomModelConfig) -> tpe.Self: + return cls(random_state=config.random_state, verbose=config.verbose) + def _fit(self, dataset: Dataset) -> None: # type: ignore self.all_item_ids = dataset.item_id_map.internal_ids diff --git a/tests/models/test_ease.py b/tests/models/test_ease.py index 2eb75d05..84eef131 100644 --- a/tests/models/test_ease.py +++ b/tests/models/test_ease.py @@ -27,6 +27,31 @@ class TestEASEModel: + def test_from_config(self) -> None: + config = { + "regularization": 500, + "num_threads": 1, + "verbose": 1, + } + model = EASEModel.from_config(config) + assert model.num_threads == 1 + assert model.verbose == 1 + assert model.regularization == 500 + + def test_get_config(self) -> None: + model = EASEModel( + regularization=500, + num_threads=1, + verbose=1, + ) + config = model.get_config() + expected = { + "regularization": 500, + "num_threads": 1, + "verbose": 1, + } + assert config == expected + @pytest.fixture def dataset(self) -> Dataset: return DATASET diff --git a/tests/models/test_pure_svd.py b/tests/models/test_pure_svd.py index 43c145a3..b33af705 100644 --- a/tests/models/test_pure_svd.py +++ b/tests/models/test_pure_svd.py @@ -29,6 +29,41 @@ class TestPureSVDModel: + + def test_from_config(self) -> None: + config = { + "factors": 100, + "tol": 0, + "maxiter": 100, + "random_state": 32, + "verbose": 0, + } + model = PureSVDModel.from_config(config) + assert model.factors == 100 + assert model.tol == 0 + assert model.maxiter == 100 + assert model.random_state == 32 + assert model.verbose == 0 + + @pytest.mark.parametrize("random_state", (None, 42)) + def test_get_config(self, random_state: tp.Optional[int]) -> None: + model = PureSVDModel( + factors=100, + tol=1, + maxiter=100, + random_state=random_state, + verbose=1, + ) + config = model.get_config() + expected = { + "factors": 100, + "tol": 1, + "maxiter": 100, + "random_state": random_state, + "verbose": 1, + } + assert config == expected + @pytest.fixture def dataset(self) -> Dataset: return DATASET diff --git a/tests/models/test_random.py b/tests/models/test_random.py index 618b3741..90975c35 100644 --- a/tests/models/test_random.py +++ b/tests/models/test_random.py @@ -28,6 +28,29 @@ class TestRandomSampler: + + def test_from_config(self) -> None: + config = { + "random_state": 32, + "verbose": 0, + } + model = RandomModel.from_config(config) + assert model.random_state == 32 + assert model.verbose == 0 + + @pytest.mark.parametrize("random_state", (None, 42)) + def test_get_config(self, random_state: tp.Optional[int]) -> None: + model = RandomModel( + random_state=random_state, + verbose=1, + ) + config = model.get_config() + expected = { + "random_state": random_state, + "verbose": 1, + } + assert config == expected + def test_sample_small_n(self) -> None: gen = _RandomGen(42) sampler = _RandomSampler(np.arange(10), gen) From 7cc9e363624dadf1ca3de55b9a3ef4d222fb6841 Mon Sep 17 00:00:00 2001 From: mikesokolovv Date: Tue, 13 Aug 2024 11:38:39 +0300 Subject: [PATCH 02/12] chengelog updated --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6be04eba..e2968e84 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased ### Added +- Configs for EASE, Random, PureSVD models ([#178](https://github.com/MobileTeleSystems/RecTools/pull/178)) - Configs for implicit models ([#167](https://github.com/MobileTeleSystems/RecTools/pull/167)) From 5384b8ba698f3a21a9e64f774e112bdc10618526 Mon Sep 17 00:00:00 2001 From: mikesokolovv Date: Tue, 13 Aug 2024 18:53:54 +0300 Subject: [PATCH 03/12] make_config removed, tests added --- rectools/models/ease.py | 6 +-- rectools/models/pure_svd.py | 14 +++--- rectools/models/random.py | 6 +-- tests/models/test_ease.py | 72 +++++++++++++++++---------- tests/models/test_pure_svd.py | 93 ++++++++++++++++++++++------------- tests/models/test_random.py | 65 +++++++++++++++--------- tests/models/utils.py | 27 ++++++++++ 7 files changed, 182 insertions(+), 101 deletions(-) diff --git a/rectools/models/ease.py b/rectools/models/ease.py index 212db172..3139a72e 100644 --- a/rectools/models/ease.py +++ b/rectools/models/ease.py @@ -68,18 +68,14 @@ def __init__( num_threads: int = 1, verbose: int = 0, ): - self._config = self._make_config(regularization, num_threads, verbose) super().__init__(verbose=verbose) self.weight: np.ndarray self.regularization = regularization self.num_threads = num_threads - def _make_config(self, regularization: float, num_threads: int, verbose: int) -> EASEModelConfig: - return EASEModelConfig(regularization=regularization, num_threads=num_threads, verbose=verbose) - def _get_config(self) -> EASEModelConfig: - return self._config + return EASEModelConfig(regularization=self.regularization, num_threads=self.num_threads, verbose=self.verbose) @classmethod def _from_config(cls, config: EASEModelConfig) -> tpe.Self: diff --git a/rectools/models/pure_svd.py b/rectools/models/pure_svd.py index c9ca79e4..9984bcff 100644 --- a/rectools/models/pure_svd.py +++ b/rectools/models/pure_svd.py @@ -72,7 +72,6 @@ def __init__( random_state: tp.Optional[int] = None, verbose: int = 0, ): - self._config = self._make_config(factors, tol, maxiter, random_state, verbose) super().__init__(verbose=verbose) self.factors = factors @@ -83,13 +82,14 @@ def __init__( self.user_factors: np.ndarray self.item_factors: np.ndarray - def _make_config( - self, factors: int, tol: float, maxiter: tp.Optional[int], random_state: tp.Optional[int], verbose: int - ) -> PureSVDModelConfig: - return PureSVDModelConfig(factors=factors, tol=tol, maxiter=maxiter, random_state=random_state, verbose=verbose) - def _get_config(self) -> PureSVDModelConfig: - return self._config + return PureSVDModelConfig( + factors=self.factors, + tol=self.tol, + maxiter=self.maxiter, + random_state=self.random_state, + verbose=self.verbose, + ) @classmethod def _from_config(cls, config: PureSVDModelConfig) -> tpe.Self: diff --git a/rectools/models/random.py b/rectools/models/random.py index 8a9820af..3b3ed4e9 100644 --- a/rectools/models/random.py +++ b/rectools/models/random.py @@ -81,18 +81,14 @@ class RandomModel(ModelBase[RandomModelConfig]): config_class = RandomModelConfig def __init__(self, random_state: tp.Optional[int] = None, verbose: int = 0): - self._config = self._make_config(random_state, verbose) super().__init__(verbose=verbose) self.random_state = random_state self.random_gen = _RandomGen(random_state) self.all_item_ids: np.ndarray - def _make_config(self, random_state: tp.Optional[int], verbose: int) -> RandomModelConfig: - return RandomModelConfig(random_state=random_state, verbose=verbose) - def _get_config(self) -> RandomModelConfig: - return self._config + return RandomModelConfig(random_state=self.random_state, verbose=self.verbose) @classmethod def _from_config(cls, config: RandomModelConfig) -> tpe.Self: diff --git a/tests/models/test_ease.py b/tests/models/test_ease.py index 84eef131..0f87e473 100644 --- a/tests/models/test_ease.py +++ b/tests/models/test_ease.py @@ -23,35 +23,14 @@ from rectools.models import EASEModel from .data import DATASET, INTERACTIONS -from .utils import assert_second_fit_refits_model +from .utils import ( + assert_default_config_and_default_model_params_are_the_same, + assert_get_config_and_from_config_compatibility, + assert_second_fit_refits_model, +) class TestEASEModel: - def test_from_config(self) -> None: - config = { - "regularization": 500, - "num_threads": 1, - "verbose": 1, - } - model = EASEModel.from_config(config) - assert model.num_threads == 1 - assert model.verbose == 1 - assert model.regularization == 500 - - def test_get_config(self) -> None: - model = EASEModel( - regularization=500, - num_threads=1, - verbose=1, - ) - config = model.get_config() - expected = { - "regularization": 500, - "num_threads": 1, - "verbose": 1, - } - assert config == expected - @pytest.fixture def dataset(self) -> Dataset: return DATASET @@ -245,3 +224,44 @@ def test_i2i_with_warm_and_cold_items(self, item_features: tp.Optional[pd.DataFr dataset=dataset, k=2, ) + + +class TestEASEModelConfiguration: + def test_from_config(self) -> None: + config = { + "regularization": 500, + "num_threads": 1, + "verbose": 1, + } + model = EASEModel.from_config(config) + assert model.num_threads == 1 + assert model.verbose == 1 + assert model.regularization == 500 + + def test_get_config(self) -> None: + model = EASEModel( + regularization=500, + num_threads=1, + verbose=1, + ) + config = model.get_config() + expected = { + "regularization": 500, + "num_threads": 1, + "verbose": 1, + } + assert config == expected + + def test_get_config_and_from_config_compatibility(self) -> None: + initial_config = { + "regularization": 500, + "num_threads": 1, + "verbose": 1, + } + model = EASEModel() + assert_get_config_and_from_config_compatibility(model, DATASET, initial_config) + + def test_default_config_and_default_model_params_are_the_same(self) -> None: + default_config: tp.Dict[str, int] = {} + model = EASEModel() + assert_default_config_and_default_model_params_are_the_same(model, default_config) diff --git a/tests/models/test_pure_svd.py b/tests/models/test_pure_svd.py index b33af705..47d70f02 100644 --- a/tests/models/test_pure_svd.py +++ b/tests/models/test_pure_svd.py @@ -25,45 +25,15 @@ from rectools.models.utils import recommend_from_scores from .data import DATASET, INTERACTIONS -from .utils import assert_second_fit_refits_model +from .utils import ( + assert_default_config_and_default_model_params_are_the_same, + assert_get_config_and_from_config_compatibility, + assert_second_fit_refits_model, +) class TestPureSVDModel: - def test_from_config(self) -> None: - config = { - "factors": 100, - "tol": 0, - "maxiter": 100, - "random_state": 32, - "verbose": 0, - } - model = PureSVDModel.from_config(config) - assert model.factors == 100 - assert model.tol == 0 - assert model.maxiter == 100 - assert model.random_state == 32 - assert model.verbose == 0 - - @pytest.mark.parametrize("random_state", (None, 42)) - def test_get_config(self, random_state: tp.Optional[int]) -> None: - model = PureSVDModel( - factors=100, - tol=1, - maxiter=100, - random_state=random_state, - verbose=1, - ) - config = model.get_config() - expected = { - "factors": 100, - "tol": 1, - "maxiter": 100, - "random_state": random_state, - "verbose": 1, - } - assert config == expected - @pytest.fixture def dataset(self) -> Dataset: return DATASET @@ -287,3 +257,56 @@ def test_i2i_with_warm_and_cold_items(self, item_features: tp.Optional[pd.DataFr dataset=dataset, k=2, ) + + +class TestPureSVDModelConfiguration: + + def test_from_config(self) -> None: + config = { + "factors": 100, + "tol": 0, + "maxiter": 100, + "random_state": 32, + "verbose": 0, + } + model = PureSVDModel.from_config(config) + assert model.factors == 100 + assert model.tol == 0 + assert model.maxiter == 100 + assert model.random_state == 32 + assert model.verbose == 0 + + @pytest.mark.parametrize("random_state", (None, 42)) + def test_get_config(self, random_state: tp.Optional[int]) -> None: + model = PureSVDModel( + factors=100, + tol=1, + maxiter=100, + random_state=random_state, + verbose=1, + ) + config = model.get_config() + expected = { + "factors": 100, + "tol": 1, + "maxiter": 100, + "random_state": random_state, + "verbose": 1, + } + assert config == expected + + def test_get_config_and_from_config_compatibility(self) -> None: + initial_config = { + "factors": 2, + "tol": 0, + "maxiter": 100, + "random_state": 32, + "verbose": 0, + } + model = PureSVDModel() + assert_get_config_and_from_config_compatibility(model, DATASET, initial_config) + + def test_default_config_and_default_model_params_are_the_same(self) -> None: + default_config: tp.Dict[str, int] = {} + model = PureSVDModel() + assert_default_config_and_default_model_params_are_the_same(model, default_config) diff --git a/tests/models/test_random.py b/tests/models/test_random.py index 90975c35..cc75f493 100644 --- a/tests/models/test_random.py +++ b/tests/models/test_random.py @@ -24,33 +24,15 @@ from rectools.models.random import _RandomGen, _RandomSampler from .data import DATASET, INTERACTIONS -from .utils import assert_second_fit_refits_model +from .utils import ( + assert_default_config_and_default_model_params_are_the_same, + assert_get_config_and_from_config_compatibility, + assert_second_fit_refits_model, +) class TestRandomSampler: - def test_from_config(self) -> None: - config = { - "random_state": 32, - "verbose": 0, - } - model = RandomModel.from_config(config) - assert model.random_state == 32 - assert model.verbose == 0 - - @pytest.mark.parametrize("random_state", (None, 42)) - def test_get_config(self, random_state: tp.Optional[int]) -> None: - model = RandomModel( - random_state=random_state, - verbose=1, - ) - config = model.get_config() - expected = { - "random_state": random_state, - "verbose": 1, - } - assert config == expected - def test_sample_small_n(self) -> None: gen = _RandomGen(42) sampler = _RandomSampler(np.arange(10), gen) @@ -201,3 +183,40 @@ def test_i2i(self, filter_itself: bool, whitelist: tp.Optional[tp.List[tp.Any]]) def test_second_fit_refits_model(self, dataset: Dataset) -> None: model = RandomModel(random_state=1) assert_second_fit_refits_model(model, dataset) + + +class TestRandomModelConfiguration: + def test_from_config(self) -> None: + config = { + "random_state": 32, + "verbose": 0, + } + model = RandomModel.from_config(config) + assert model.random_state == 32 + assert model.verbose == 0 + + @pytest.mark.parametrize("random_state", (None, 42)) + def test_get_config(self, random_state: tp.Optional[int]) -> None: + model = RandomModel( + random_state=random_state, + verbose=1, + ) + config = model.get_config() + expected = { + "random_state": random_state, + "verbose": 1, + } + assert config == expected + + def test_get_config_and_from_config_compatibility(self) -> None: + initial_config = { + "random_state": 32, + "verbose": 0, + } + model = RandomModel() + assert_get_config_and_from_config_compatibility(model, DATASET, initial_config) + + def test_default_config_and_default_model_params_are_the_same(self) -> None: + default_config: tp.Dict[str, int] = {} + model = RandomModel() + assert_default_config_and_default_model_params_are_the_same(model, default_config) diff --git a/tests/models/utils.py b/tests/models/utils.py index 4d321975..dcc715fd 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -15,6 +15,7 @@ import typing as tp from copy import deepcopy +import numpy as np import pandas as pd from rectools.dataset import Dataset @@ -47,3 +48,29 @@ def assert_second_fit_refits_model( reco_i2i_1 = model_1.recommend_to_items(dataset.item_id_map.external_ids, dataset, k, False) reco_i2i_2 = model_2.recommend_to_items(dataset.item_id_map.external_ids, dataset, k, False) pd.testing.assert_frame_equal(reco_i2i_1, reco_i2i_2, atol=0.001) + + +def assert_default_config_and_default_model_params_are_the_same( + model: ModelBase, default_config: tp.Dict[str, tp.Any] +) -> None: + model_from_config = model.from_config(default_config) + model_from_params = model + assert model_from_config.get_config() == model_from_params.get_config() + + +def assert_get_config_and_from_config_compatibility( + model: ModelBase, dataset: Dataset, initial_config: tp.Dict[str, tp.Any] +) -> None: + def get_reco(model: ModelBase) -> pd.DataFrame: + return model.fit(dataset).recommend(users=np.array([10, 20]), dataset=dataset, k=2, filter_viewed=False) + + model_1 = model.from_config(initial_config) + reco_1 = get_reco(model_1) + config_1 = model_1.get_config() + + model_2 = model.from_config(config_1) + reco_2 = get_reco(model_2) + config_2 = model_1.get_config() + + assert config_1 == config_2 + pd.testing.assert_frame_equal(reco_1, reco_2) From aecd3654c672e74e53eb442f5bcad0e4145ef18e Mon Sep 17 00:00:00 2001 From: mikesokolovv Date: Wed, 14 Aug 2024 11:20:20 +0300 Subject: [PATCH 04/12] implicit models tests changed --- tests/models/test_implicit_als.py | 28 ++++++++++------------------ tests/models/test_implicit_knn.py | 28 ++++++++++------------------ tests/models/utils.py | 9 ++++----- 3 files changed, 24 insertions(+), 41 deletions(-) diff --git a/tests/models/test_implicit_als.py b/tests/models/test_implicit_als.py index 7694da58..da5dd34b 100644 --- a/tests/models/test_implicit_als.py +++ b/tests/models/test_implicit_als.py @@ -34,7 +34,11 @@ from rectools.models.utils import recommend_from_scores from .data import DATASET -from .utils import assert_second_fit_refits_model +from .utils import ( + assert_default_config_and_default_model_params_are_the_same, + assert_get_config_and_from_config_compatibility, + assert_second_fit_refits_model, +) @pytest.mark.filterwarnings("ignore:Converting sparse features to dense") @@ -451,28 +455,16 @@ def test_custom_model_class(self) -> None: @pytest.mark.parametrize("simple_types", (False, True)) def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> None: - def get_reco(model: ImplicitALSWrapperModel) -> pd.DataFrame: - return model.fit(DATASET).recommend(users=[10, 20], dataset=DATASET, k=2, filter_viewed=False) - initial_config = { "model": { "params": {"factors": 16, "num_threads": 2, "iterations": 3, "random_state": 42}, }, "verbose": 1, } - - model_1 = ImplicitALSWrapperModel.from_config(initial_config) - reco_1 = get_reco(model_1) - config_1 = model_1.get_config(simple_types=simple_types) - - model_2 = ImplicitALSWrapperModel.from_config(config_1) - reco_2 = get_reco(model_2) - config_2 = model_2.get_config(simple_types=simple_types) - - assert config_1 == config_2 - pd.testing.assert_frame_equal(reco_1, reco_2) + model = ImplicitALSWrapperModel(model=AlternatingLeastSquares()) + assert_get_config_and_from_config_compatibility(model, DATASET, initial_config, simple_types) def test_default_config_and_default_model_params_are_the_same(self) -> None: - model_from_config = ImplicitALSWrapperModel.from_config({"model": {}}) - model_from_params = ImplicitALSWrapperModel(model=AlternatingLeastSquares()) - assert model_from_config.get_config() == model_from_params.get_config() + default_config: tp.Dict[str, tp.Any] = {"model": {}} + model = ImplicitALSWrapperModel(model=AlternatingLeastSquares()) + assert_default_config_and_default_model_params_are_the_same(model, default_config) diff --git a/tests/models/test_implicit_knn.py b/tests/models/test_implicit_knn.py index d8eac2b9..732e7808 100644 --- a/tests/models/test_implicit_knn.py +++ b/tests/models/test_implicit_knn.py @@ -24,7 +24,11 @@ from rectools.models import ImplicitItemKNNWrapperModel from .data import DATASET, INTERACTIONS -from .utils import assert_second_fit_refits_model +from .utils import ( + assert_default_config_and_default_model_params_are_the_same, + assert_get_config_and_from_config_compatibility, + assert_second_fit_refits_model, +) class TestImplicitItemKNNWrapperModel: @@ -309,9 +313,6 @@ def test_to_config( @pytest.mark.parametrize("simple_types", (False, True)) def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> None: - def get_reco(model: ImplicitItemKNNWrapperModel) -> pd.DataFrame: - return model.fit(DATASET).recommend(users=np.array([10, 20]), dataset=DATASET, k=2, filter_viewed=False) - initial_config = { "model": { "cls": TFIDFRecommender, @@ -319,19 +320,10 @@ def get_reco(model: ImplicitItemKNNWrapperModel) -> pd.DataFrame: }, "verbose": 1, } - - model_1 = ImplicitItemKNNWrapperModel.from_config(initial_config) - reco_1 = get_reco(model_1) - config_1 = model_1.get_config(simple_types=simple_types) - - model_2 = ImplicitItemKNNWrapperModel.from_config(config_1) - reco_2 = get_reco(model_2) - config_2 = model_2.get_config(simple_types=simple_types) - - assert config_1 == config_2 - pd.testing.assert_frame_equal(reco_1, reco_2) + model = ImplicitItemKNNWrapperModel(model=ItemItemRecommender()) + assert_get_config_and_from_config_compatibility(model, DATASET, initial_config, simple_types) def test_default_config_and_default_model_params_are_the_same(self) -> None: - model_from_config = ImplicitItemKNNWrapperModel.from_config({"model": {}}) - model_from_params = ImplicitItemKNNWrapperModel(model=ItemItemRecommender()) - assert model_from_config.get_config() == model_from_params.get_config() + default_config: tp.Dict[str, tp.Any] = {"model": {}} + model = ImplicitItemKNNWrapperModel(model=ItemItemRecommender()) + assert_default_config_and_default_model_params_are_the_same(model, default_config) diff --git a/tests/models/utils.py b/tests/models/utils.py index dcc715fd..f270c3cb 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -54,23 +54,22 @@ def assert_default_config_and_default_model_params_are_the_same( model: ModelBase, default_config: tp.Dict[str, tp.Any] ) -> None: model_from_config = model.from_config(default_config) - model_from_params = model - assert model_from_config.get_config() == model_from_params.get_config() + assert model_from_config.get_config() == model.get_config() def assert_get_config_and_from_config_compatibility( - model: ModelBase, dataset: Dataset, initial_config: tp.Dict[str, tp.Any] + model: ModelBase, dataset: Dataset, initial_config: tp.Dict[str, tp.Any], simple_types: tp.Optional[bool] = None ) -> None: def get_reco(model: ModelBase) -> pd.DataFrame: return model.fit(dataset).recommend(users=np.array([10, 20]), dataset=dataset, k=2, filter_viewed=False) model_1 = model.from_config(initial_config) reco_1 = get_reco(model_1) - config_1 = model_1.get_config() + config_1 = model_1.get_config(simple_types=simple_types) if simple_types is not None else model_1.get_config() model_2 = model.from_config(config_1) reco_2 = get_reco(model_2) - config_2 = model_1.get_config() + config_2 = model_2.get_config(simple_types=simple_types) if simple_types is not None else model_2.get_config() assert config_1 == config_2 pd.testing.assert_frame_equal(reco_1, reco_2) From f6840d899e66476d505c5714b6652e92dfe4e954 Mon Sep 17 00:00:00 2001 From: mikesokolovv Date: Wed, 14 Aug 2024 11:55:13 +0300 Subject: [PATCH 05/12] simple_types added to all model tests --- tests/models/test_ease.py | 5 +++-- tests/models/test_pure_svd.py | 5 +++-- tests/models/test_random.py | 5 +++-- tests/models/utils.py | 6 +++--- 4 files changed, 12 insertions(+), 9 deletions(-) diff --git a/tests/models/test_ease.py b/tests/models/test_ease.py index 0f87e473..0f90de38 100644 --- a/tests/models/test_ease.py +++ b/tests/models/test_ease.py @@ -252,14 +252,15 @@ def test_get_config(self) -> None: } assert config == expected - def test_get_config_and_from_config_compatibility(self) -> None: + @pytest.mark.parametrize("simple_types", (False, True)) + def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> None: initial_config = { "regularization": 500, "num_threads": 1, "verbose": 1, } model = EASEModel() - assert_get_config_and_from_config_compatibility(model, DATASET, initial_config) + assert_get_config_and_from_config_compatibility(model, DATASET, initial_config, simple_types) def test_default_config_and_default_model_params_are_the_same(self) -> None: default_config: tp.Dict[str, int] = {} diff --git a/tests/models/test_pure_svd.py b/tests/models/test_pure_svd.py index 47d70f02..14598ad5 100644 --- a/tests/models/test_pure_svd.py +++ b/tests/models/test_pure_svd.py @@ -295,7 +295,8 @@ def test_get_config(self, random_state: tp.Optional[int]) -> None: } assert config == expected - def test_get_config_and_from_config_compatibility(self) -> None: + @pytest.mark.parametrize("simple_types", (False, True)) + def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> None: initial_config = { "factors": 2, "tol": 0, @@ -304,7 +305,7 @@ def test_get_config_and_from_config_compatibility(self) -> None: "verbose": 0, } model = PureSVDModel() - assert_get_config_and_from_config_compatibility(model, DATASET, initial_config) + assert_get_config_and_from_config_compatibility(model, DATASET, initial_config, simple_types) def test_default_config_and_default_model_params_are_the_same(self) -> None: default_config: tp.Dict[str, int] = {} diff --git a/tests/models/test_random.py b/tests/models/test_random.py index cc75f493..f55fa6b5 100644 --- a/tests/models/test_random.py +++ b/tests/models/test_random.py @@ -208,13 +208,14 @@ def test_get_config(self, random_state: tp.Optional[int]) -> None: } assert config == expected - def test_get_config_and_from_config_compatibility(self) -> None: + @pytest.mark.parametrize("simple_types", (False, True)) + def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> None: initial_config = { "random_state": 32, "verbose": 0, } model = RandomModel() - assert_get_config_and_from_config_compatibility(model, DATASET, initial_config) + assert_get_config_and_from_config_compatibility(model, DATASET, initial_config, simple_types) def test_default_config_and_default_model_params_are_the_same(self) -> None: default_config: tp.Dict[str, int] = {} diff --git a/tests/models/utils.py b/tests/models/utils.py index f270c3cb..ec531b55 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -58,18 +58,18 @@ def assert_default_config_and_default_model_params_are_the_same( def assert_get_config_and_from_config_compatibility( - model: ModelBase, dataset: Dataset, initial_config: tp.Dict[str, tp.Any], simple_types: tp.Optional[bool] = None + model: ModelBase, dataset: Dataset, initial_config: tp.Dict[str, tp.Any], simple_types: bool ) -> None: def get_reco(model: ModelBase) -> pd.DataFrame: return model.fit(dataset).recommend(users=np.array([10, 20]), dataset=dataset, k=2, filter_viewed=False) model_1 = model.from_config(initial_config) reco_1 = get_reco(model_1) - config_1 = model_1.get_config(simple_types=simple_types) if simple_types is not None else model_1.get_config() + config_1 = model_1.get_config(simple_types=simple_types) model_2 = model.from_config(config_1) reco_2 = get_reco(model_2) - config_2 = model_2.get_config(simple_types=simple_types) if simple_types is not None else model_2.get_config() + config_2 = model_2.get_config(simple_types=simple_types) assert config_1 == config_2 pd.testing.assert_frame_equal(reco_1, reco_2) From bad21cfbe00400b4e19d14009363980c423778ab Mon Sep 17 00:00:00 2001 From: mikesokolovv Date: Tue, 3 Sep 2024 14:44:06 +0300 Subject: [PATCH 06/12] pop and pop cat configs added --- rectools/models/popular.py | 116 +++++++++++------ rectools/models/popular_in_category.py | 45 ++++++- tests/models/test_popular.py | 73 ++++++++++- tests/models/test_popular_in_category.py | 155 ++++++++++++++++++++++- 4 files changed, 349 insertions(+), 40 deletions(-) diff --git a/rectools/models/popular.py b/rectools/models/popular.py index 746fc9e4..3c7dc5e3 100644 --- a/rectools/models/popular.py +++ b/rectools/models/popular.py @@ -20,6 +20,7 @@ import numpy as np import pandas as pd +import typing_extensions as tpe from tqdm.auto import tqdm from rectools import Columns, InternalIds @@ -28,7 +29,7 @@ from rectools.types import InternalIdsArray from rectools.utils import fast_isin_for_sorted_test_elements -from .base import FixedColdRecoModelMixin, ModelBase, Scores, ScoresArray +from .base import FixedColdRecoModelMixin, ModelBase, ModelConfig_T, Scores, ScoresArray from .utils import get_viewed_item_ids @@ -51,7 +52,54 @@ class PopularModelConfig(ModelConfig): inverse: bool = False -class PopularModel(FixedColdRecoModelMixin, ModelBase): +class PopularModelBaseMixin(ModelBase[ModelConfig_T]): + """Mixin for models based on popularity.""" + + def __init__( + self, + popularity: tp.Literal["n_users", "n_interactions", "mean_weight", "sum_weight"] = "n_users", + period: tp.Optional[timedelta] = None, + begin_from: tp.Optional[datetime] = None, + add_cold: bool = False, + inverse: bool = False, + verbose: int = 0, + ): + super().__init__(verbose=verbose) + try: + self.popularity = Popularity(popularity) + except ValueError: + possible_values = {item.value for item in Popularity.__members__.values()} + raise ValueError(f"`popularity` must be one of the {possible_values}. Got {popularity}.") + + if period is not None and begin_from is not None: + raise ValueError("Only one of `period` and `begin_from` can be set") + self.period = period + self.begin_from = begin_from + + self.add_cold = add_cold + self.inverse = inverse + + def _filter_interactions(self, interactions: pd.DataFrame) -> pd.DataFrame: + if self.begin_from is not None: + interactions = interactions.loc[interactions[Columns.Datetime] >= self.begin_from] + elif self.period is not None: + begin_from = interactions[Columns.Datetime].max() - self.period + interactions = interactions.loc[interactions[Columns.Datetime] >= begin_from] + return interactions + + def _get_groupby_col_and_agg_func(self, popularity: Popularity) -> tp.Tuple[str, str]: + if popularity == Popularity.N_USERS: + return Columns.User, "nunique" + if popularity == Popularity.N_INTERACTIONS: + return Columns.User, "count" + if popularity == Popularity.MEAN_WEIGHT: + return Columns.Weight, "mean" + if popularity == Popularity.SUM_WEIGHT: + return Columns.Weight, "sum" + raise ValueError(f"Unexpected popularity {popularity}") + + +class PopularModel(FixedColdRecoModelMixin, PopularModelBaseMixin[PopularModelConfig]): """ Model generating recommendations based on popularity of items. @@ -87,6 +135,8 @@ class PopularModel(FixedColdRecoModelMixin, ModelBase): recommends_for_warm = False recommends_for_cold = True + config_class = PopularModelConfig + def __init__( self, popularity: tp.Literal["n_users", "n_interactions", "mean_weight", "sum_weight"] = "n_users", @@ -96,31 +146,37 @@ def __init__( inverse: bool = False, verbose: int = 0, ): - super().__init__(verbose=verbose) - - try: - self.popularity = Popularity(popularity) - except ValueError: - possible_values = {item.value for item in Popularity.__members__.values()} - raise ValueError(f"`popularity` must be one of the {possible_values}. Got {popularity}.") - - if period is not None and begin_from is not None: - raise ValueError("Only one of `period` and `begin_from` can be set") - self.period = period - self.begin_from = begin_from - - self.add_cold = add_cold - self.inverse = inverse + super().__init__( + popularity=popularity, + period=period, + begin_from=begin_from, + add_cold=add_cold, + inverse=inverse, + verbose=verbose, + ) self.popularity_list: tp.Tuple[InternalIdsArray, ScoresArray] - def _filter_interactions(self, interactions: pd.DataFrame) -> pd.DataFrame: - if self.begin_from is not None: - interactions = interactions.loc[interactions[Columns.Datetime] >= self.begin_from] - elif self.period is not None: - begin_from = interactions[Columns.Datetime].max() - self.period - interactions = interactions.loc[interactions[Columns.Datetime] >= begin_from] - return interactions + def _get_config(self) -> PopularModelConfig: + return PopularModelConfig( + popularity=self.popularity, + period=self.period, + begin_from=self.begin_from, + add_cold=self.add_cold, + inverse=self.inverse, + verbose=self.verbose, + ) + + @classmethod + def _from_config(cls, config: PopularModelConfig) -> tpe.Self: + return cls( + popularity=config.popularity.value, + period=config.period, + begin_from=config.begin_from, + add_cold=config.add_cold, + inverse=config.inverse, + verbose=config.verbose, + ) def _fit(self, dataset: Dataset) -> None: # type: ignore interactions = self._filter_interactions(dataset.interactions.df) @@ -141,18 +197,6 @@ def _fit(self, dataset: Dataset) -> None: # type: ignore self.popularity_list = (items, scores) - @classmethod - def _get_groupby_col_and_agg_func(cls, popularity: Popularity) -> tp.Tuple[str, str]: - if popularity == Popularity.N_USERS: - return Columns.User, "nunique" - if popularity == Popularity.N_INTERACTIONS: - return Columns.User, "count" - if popularity == Popularity.MEAN_WEIGHT: - return Columns.Weight, "mean" - if popularity == Popularity.SUM_WEIGHT: - return Columns.Weight, "sum" - raise ValueError(f"Unexpected popularity {popularity}") - def _recommend_u2i( self, user_ids: InternalIdsArray, diff --git a/rectools/models/popular_in_category.py b/rectools/models/popular_in_category.py index e860295f..23e6ffa9 100644 --- a/rectools/models/popular_in_category.py +++ b/rectools/models/popular_in_category.py @@ -21,13 +21,14 @@ import numpy as np import pandas as pd +import typing_extensions as tpe from rectools import Columns, InternalIds from rectools.dataset import Dataset, Interactions, features from rectools.types import InternalIdsArray from .base import Scores -from .popular import PopularModel +from .popular import FixedColdRecoModelMixin, PopularModel, PopularModelBaseMixin, PopularModelConfig class MixingStrategy(Enum): @@ -44,7 +45,16 @@ class RatioStrategy(Enum): PROPORTIONAL = "proportional" -class PopularInCategoryModel(PopularModel): +class PopularInCategoryModelConfig(PopularModelConfig): + """Config for `PopularInCategoryModel`.""" + + category_feature: str + n_categories: tp.Optional[int] = None + mixing_strategy: MixingStrategy = MixingStrategy.ROTATE + ratio_strategy: RatioStrategy = RatioStrategy.PROPORTIONAL + + +class PopularInCategoryModel(FixedColdRecoModelMixin, PopularModelBaseMixin[PopularInCategoryModelConfig]): """ Model generating recommendations based on popularity of items. @@ -98,6 +108,8 @@ class PopularInCategoryModel(PopularModel): recommends_for_warm = False recommends_for_cold = True + config_class = PopularInCategoryModelConfig + def __init__( self, category_feature: str, @@ -144,6 +156,35 @@ def __init__( possible_values = {item.value for item in RatioStrategy.__members__.values()} raise ValueError(f"`ratio_strategy` must be one of the {possible_values}. Got {ratio_strategy}.") + def _get_config(self) -> PopularInCategoryModelConfig: + return PopularInCategoryModelConfig( + category_feature=self.category_feature, + n_categories=self.n_categories, + mixing_strategy=self.mixing_strategy, + ratio_strategy=self.ratio_strategy, + popularity=self.popularity, + period=self.period, + begin_from=self.begin_from, + add_cold=self.add_cold, + inverse=self.inverse, + verbose=self.verbose, + ) + + @classmethod + def _from_config(cls, config: PopularInCategoryModelConfig) -> tpe.Self: + return cls( + category_feature=config.category_feature, + n_categories=config.n_categories, + mixing_strategy=config.mixing_strategy.value, + ratio_strategy=config.ratio_strategy.value, + popularity=config.popularity.value, + period=config.period, + begin_from=config.begin_from, + add_cold=config.add_cold, + inverse=config.inverse, + verbose=config.verbose, + ) + def _check_category_feature(self, dataset: Dataset) -> None: if not dataset.item_features: raise ValueError( diff --git a/tests/models/test_popular.py b/tests/models/test_popular.py index e1ab4e8b..5568541d 100644 --- a/tests/models/test_popular.py +++ b/tests/models/test_popular.py @@ -22,7 +22,14 @@ from rectools import Columns from rectools.dataset import Dataset, IdMap, Interactions from rectools.models import PopularModel -from tests.models.utils import assert_second_fit_refits_model +from rectools.models.popular import Popularity +from tests.models.utils import ( + assert_default_config_and_default_model_params_are_the_same, + assert_get_config_and_from_config_compatibility, + assert_second_fit_refits_model, +) + +from .data import DATASET class TestPopularModel: @@ -212,3 +219,67 @@ def test_i2i( def test_second_fit_refits_model(self, dataset: Dataset) -> None: model = PopularModel() assert_second_fit_refits_model(model, dataset) + + +class TestPopularModelConfiguration: + + def test_from_config(self) -> None: + config = { + "popularity": "n_interactions", + "period": timedelta(days=7), + "begin_from": None, + "add_cold": True, + "inverse": True, + "verbose": 0, + } + model = PopularModel.from_config(config) + assert model.popularity.value == "n_interactions" + assert model.period == timedelta(days=7) + assert model.begin_from is None + assert model.add_cold is True + assert model.inverse is True + assert model.verbose == 0 + + @pytest.mark.parametrize("begin_from", (None, datetime(2021, 11, 23))) + @pytest.mark.parametrize("popularity", ("mean_weight", "sum_weight")) + def test_get_config( + self, + popularity: tp.Literal["n_users", "n_interactions", "mean_weight", "sum_weight"], + begin_from: tp.Optional[datetime], + ) -> None: + model = PopularModel( + popularity=popularity, + period=None, + begin_from=begin_from, + add_cold=False, + inverse=False, + verbose=1, + ) + config = model.get_config() + expected = { + "popularity": Popularity(popularity), + "period": None, + "begin_from": begin_from, + "add_cold": False, + "inverse": False, + "verbose": 1, + } + assert config == expected + + @pytest.mark.parametrize("simple_types", (False, True)) + def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> None: + initial_config = { + "popularity": "n_users", + "period": None, + "begin_from": None, + "add_cold": True, + "inverse": False, + "verbose": 0, + } + model = PopularModel() + assert_get_config_and_from_config_compatibility(model, DATASET, initial_config, simple_types) + + def test_default_config_and_default_model_params_are_the_same(self) -> None: + default_config: tp.Dict[str, int] = {} + model = PopularModel() + assert_default_config_and_default_model_params_are_the_same(model, default_config) diff --git a/tests/models/test_popular_in_category.py b/tests/models/test_popular_in_category.py index 59f3b02d..2505af61 100644 --- a/tests/models/test_popular_in_category.py +++ b/tests/models/test_popular_in_category.py @@ -22,7 +22,13 @@ from rectools import Columns from rectools.dataset import Dataset from rectools.models import PopularInCategoryModel -from tests.models.utils import assert_second_fit_refits_model +from rectools.models.popular import Popularity +from rectools.models.popular_in_category import MixingStrategy, RatioStrategy +from tests.models.utils import ( + assert_default_config_and_default_model_params_are_the_same, + assert_get_config_and_from_config_compatibility, + assert_second_fit_refits_model, +) @pytest.mark.filterwarnings("ignore") @@ -444,3 +450,150 @@ def test_second_fit_refits_model( n_categories=n_categories, ) assert_second_fit_refits_model(model, dataset) + + +class TestPopularInCategoryModelConfiguration: + @pytest.fixture + def interactions_df(self) -> pd.DataFrame: + interactions_df = pd.DataFrame( + [ + [70, 11, 1, "2021-11-30"], + [70, 12, 1, "2021-11-30"], + [10, 11, 1, "2021-11-30"], + [10, 12, 1, "2021-11-29"], + [10, 13, 9, "2021-11-28"], + [20, 11, 1, "2021-11-27"], + [20, 14, 2, "2021-11-26"], + [20, 14, 1, "2021-11-25"], + [20, 14, 1, "2021-11-25"], + [20, 14, 1, "2021-11-25"], + [20, 14, 1, "2021-11-25"], + [20, 14, 1, "2021-11-25"], + [30, 11, 1, "2021-11-24"], + [30, 12, 1, "2021-11-23"], + [30, 14, 1, "2021-11-23"], + [30, 15, 5, "2021-11-21"], + [30, 15, 5, "2021-11-21"], + [40, 11, 1, "2021-11-20"], + [40, 12, 1, "2021-11-19"], + [50, 12, 1, "2021-11-19"], + [60, 12, 1, "2021-11-19"], + ], + columns=Columns.Interactions, + ) + return interactions_df + + @pytest.fixture + def item_features_df(self) -> pd.DataFrame: + item_features_df = pd.DataFrame( + { + "id": [11, 11, 12, 12, 13, 13, 14, 14, 14], + "feature": ["f1", "f2", "f1", "f2", "f1", "f2", "f1", "f2", "f3"], + "value": [100, "a", 100, "b", 100, "b", 200, "c", 1], + } + ) + return item_features_df + + @pytest.fixture + def dataset(self, interactions_df: pd.DataFrame, item_features_df: pd.DataFrame) -> Dataset: + user_features_df = pd.DataFrame( + { + "id": [10, 50], + "feature": ["f1", "f1"], + "value": [1, 1], + } + ) + dataset = Dataset.construct( + interactions_df=interactions_df, + user_features_df=user_features_df, + item_features_df=item_features_df, + cat_item_features=["f2", "f1"], + ) + return dataset + + def test_from_config(self) -> None: + config = { + "category_feature": "f1", + "n_categories": 2, + "mixing_strategy": "group", + "ratio_strategy": "equal", + "popularity": "n_interactions", + "period": timedelta(days=7), + "begin_from": None, + "add_cold": True, + "inverse": True, + "verbose": 0, + } + model = PopularInCategoryModel.from_config(config) + assert model.category_feature == "f1" + assert model.n_categories == 2 + assert model.mixing_strategy == MixingStrategy("group") + assert model.ratio_strategy == RatioStrategy("equal") + assert model.popularity == Popularity("n_interactions") + assert model.period == timedelta(days=7) + assert model.begin_from is None + assert model.add_cold is True + assert model.inverse is True + assert model.verbose == 0 + + @pytest.mark.parametrize("begin_from", (None, datetime(2021, 11, 23))) + @pytest.mark.parametrize("popularity", ("mean_weight", "sum_weight")) + def test_get_config( + self, + popularity: tp.Literal["n_users", "n_interactions", "mean_weight", "sum_weight"], + begin_from: tp.Optional[datetime], + ) -> None: + model = PopularInCategoryModel( + category_feature="f2", + n_categories=3, + mixing_strategy="rotate", + ratio_strategy="proportional", + popularity=popularity, + period=None, + begin_from=begin_from, + add_cold=False, + inverse=False, + verbose=1, + ) + config = model.get_config() + expected = { + "category_feature": "f2", + "n_categories": 3, + "mixing_strategy": MixingStrategy("rotate"), + "ratio_strategy": RatioStrategy("proportional"), + "popularity": Popularity(popularity), + "period": None, + "begin_from": begin_from, + "add_cold": False, + "inverse": False, + "verbose": 1, + } + assert config == expected + + @pytest.mark.parametrize("category_feature", ("f1", "f2")) + @pytest.mark.parametrize("simple_types", (False, True)) + def test_get_config_and_from_config_compatibility( + self, + dataset: Dataset, + category_feature: str, + simple_types: bool, + ) -> None: + initial_config = { + "category_feature": category_feature, + "n_categories": 2, + "mixing_strategy": "group", + "ratio_strategy": "equal", + "popularity": "n_users", + "period": None, + "begin_from": None, + "add_cold": True, + "inverse": False, + "verbose": 0, + } + model = PopularInCategoryModel(category_feature) + assert_get_config_and_from_config_compatibility(model, dataset, initial_config, simple_types) + + def test_default_config_and_default_model_params_are_the_same(self) -> None: + default_config: tp.Dict[str, str] = {"category_feature": "f2"} + model = PopularInCategoryModel(category_feature="f2") + assert_default_config_and_default_model_params_are_the_same(model, default_config) From cad4b59634bcba34f74ba82d34f54e7d9e76aa35 Mon Sep 17 00:00:00 2001 From: mikesokolovv Date: Wed, 4 Sep 2024 15:14:15 +0300 Subject: [PATCH 07/12] mypy error ignore added --- rectools/models/popular.py | 21 +++++++++----------- rectools/models/popular_in_category.py | 16 +++------------ tests/model_selection/test_cross_validate.py | 4 ++-- 3 files changed, 14 insertions(+), 27 deletions(-) diff --git a/rectools/models/popular.py b/rectools/models/popular.py index 3c7dc5e3..164f39ce 100644 --- a/rectools/models/popular.py +++ b/rectools/models/popular.py @@ -65,25 +65,22 @@ def __init__( verbose: int = 0, ): super().__init__(verbose=verbose) - try: - self.popularity = Popularity(popularity) - except ValueError: - possible_values = {item.value for item in Popularity.__members__.values()} - raise ValueError(f"`popularity` must be one of the {possible_values}. Got {popularity}.") - if period is not None and begin_from is not None: raise ValueError("Only one of `period` and `begin_from` can be set") + self.popularity = Popularity(popularity) self.period = period self.begin_from = begin_from self.add_cold = add_cold self.inverse = inverse - def _filter_interactions(self, interactions: pd.DataFrame) -> pd.DataFrame: - if self.begin_from is not None: - interactions = interactions.loc[interactions[Columns.Datetime] >= self.begin_from] - elif self.period is not None: - begin_from = interactions[Columns.Datetime].max() - self.period + def _filter_interactions( + self, interactions: pd.DataFrame, period: tp.Optional[timedelta], begin_from: tp.Optional[datetime] + ) -> pd.DataFrame: + if begin_from is not None: + interactions = interactions.loc[interactions[Columns.Datetime] >= begin_from] + elif period is not None: + begin_from = interactions[Columns.Datetime].max() - period interactions = interactions.loc[interactions[Columns.Datetime] >= begin_from] return interactions @@ -179,7 +176,7 @@ def _from_config(cls, config: PopularModelConfig) -> tpe.Self: ) def _fit(self, dataset: Dataset) -> None: # type: ignore - interactions = self._filter_interactions(dataset.interactions.df) + interactions = self._filter_interactions(dataset.interactions.df, self.period, self.begin_from) col, func = self._get_groupby_col_and_agg_func(self.popularity) items_scores = interactions.groupby(Columns.Item)[col].agg(func).sort_values(ascending=False) diff --git a/rectools/models/popular_in_category.py b/rectools/models/popular_in_category.py index 23e6ffa9..f81ff5db 100644 --- a/rectools/models/popular_in_category.py +++ b/rectools/models/popular_in_category.py @@ -133,6 +133,8 @@ def __init__( ) self.category_feature = category_feature + self.mixing_strategy = MixingStrategy(mixing_strategy) + self.ratio_strategy = RatioStrategy(ratio_strategy) self.category_columns: tp.List[int] = [] self.category_interactions: tp.Dict[int, pd.DataFrame] = {} self.category_scores: pd.Series @@ -144,18 +146,6 @@ def __init__( else: raise ValueError(f"`n_categories` must be a positive number. Got {n_categories}") - try: - self.mixing_strategy = MixingStrategy(mixing_strategy) - except ValueError: - possible_values = {item.value for item in MixingStrategy.__members__.values()} - raise ValueError(f"`mixing_strategy` must be one of the {possible_values}. Got {mixing_strategy}.") - - try: - self.ratio_strategy = RatioStrategy(ratio_strategy) - except ValueError: - possible_values = {item.value for item in RatioStrategy.__members__.values()} - raise ValueError(f"`ratio_strategy` must be one of the {possible_values}. Got {ratio_strategy}.") - def _get_config(self) -> PopularInCategoryModelConfig: return PopularInCategoryModelConfig( category_feature=self.category_feature, @@ -241,7 +231,7 @@ def _fit(self, dataset: Dataset) -> None: # type: ignore self.n_effective_categories = 0 self._check_category_feature(dataset) - interactions = self._filter_interactions(dataset.interactions.df) + interactions = self._filter_interactions(dataset.interactions.df, self.period, self.begin_from) self._calc_category_scores(dataset, interactions) self._define_categories_for_analysis() diff --git a/tests/model_selection/test_cross_validate.py b/tests/model_selection/test_cross_validate.py index d5d9dd87..acaa6cac 100644 --- a/tests/model_selection/test_cross_validate.py +++ b/tests/model_selection/test_cross_validate.py @@ -209,7 +209,7 @@ def test_happy_path( dataset=self.dataset, splitter=splitter, metrics=self.metrics, - models=self.models, + models=self.models, # type: ignore k=2, filter_viewed=False, items_to_recommend=items_to_recommend, @@ -426,7 +426,7 @@ def test_happy_path_with_intersection( dataset=self.dataset, splitter=splitter, metrics=self.metrics_intersection, - models=self.models, + models=self.models, # type: ignore k=2, filter_viewed=False, ref_models=ref_models, From 1142093d1d2c79592861e5bee5651f8b689210ba Mon Sep 17 00:00:00 2001 From: Mike <78963317+mikesokolovv@users.noreply.github.com> Date: Wed, 4 Sep 2024 15:21:31 +0300 Subject: [PATCH 08/12] Configs for simple models (#178) (#11) - Added model configs for `EASE`, `Random` and `PureSVD` models - Moved `RandomState` serialization to models/base.py From 9559385b679dfe030131490d8596efbb59704fc6 Mon Sep 17 00:00:00 2001 From: mikesokolovv Date: Wed, 4 Sep 2024 16:34:35 +0300 Subject: [PATCH 09/12] changelog updated --- CHANGELOG.md | 1 + rectools/models/popular.py | 16 +++++++--------- rectools/models/popular_in_category.py | 2 +- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e2968e84..e42b5885 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased ### Added +- Configs for Popular, PopularInCategory models ([#188](https://github.com/MobileTeleSystems/RecTools/pull/188)) - Configs for EASE, Random, PureSVD models ([#178](https://github.com/MobileTeleSystems/RecTools/pull/178)) - Configs for implicit models ([#167](https://github.com/MobileTeleSystems/RecTools/pull/167)) diff --git a/rectools/models/popular.py b/rectools/models/popular.py index 164f39ce..6f018403 100644 --- a/rectools/models/popular.py +++ b/rectools/models/popular.py @@ -74,14 +74,12 @@ def __init__( self.add_cold = add_cold self.inverse = inverse - def _filter_interactions( - self, interactions: pd.DataFrame, period: tp.Optional[timedelta], begin_from: tp.Optional[datetime] - ) -> pd.DataFrame: - if begin_from is not None: - interactions = interactions.loc[interactions[Columns.Datetime] >= begin_from] - elif period is not None: - begin_from = interactions[Columns.Datetime].max() - period - interactions = interactions.loc[interactions[Columns.Datetime] >= begin_from] + def _filter_interactions(self, interactions: pd.DataFrame) -> pd.DataFrame: + if self.begin_from is not None: + interactions = interactions.loc[interactions[Columns.Datetime] >= self.begin_from] + elif self.period is not None: + self.begin_from = interactions[Columns.Datetime].max() - self.period + interactions = interactions.loc[interactions[Columns.Datetime] >= self.begin_from] return interactions def _get_groupby_col_and_agg_func(self, popularity: Popularity) -> tp.Tuple[str, str]: @@ -176,7 +174,7 @@ def _from_config(cls, config: PopularModelConfig) -> tpe.Self: ) def _fit(self, dataset: Dataset) -> None: # type: ignore - interactions = self._filter_interactions(dataset.interactions.df, self.period, self.begin_from) + interactions = self._filter_interactions(dataset.interactions.df) col, func = self._get_groupby_col_and_agg_func(self.popularity) items_scores = interactions.groupby(Columns.Item)[col].agg(func).sort_values(ascending=False) diff --git a/rectools/models/popular_in_category.py b/rectools/models/popular_in_category.py index f81ff5db..7856003c 100644 --- a/rectools/models/popular_in_category.py +++ b/rectools/models/popular_in_category.py @@ -231,7 +231,7 @@ def _fit(self, dataset: Dataset) -> None: # type: ignore self.n_effective_categories = 0 self._check_category_feature(dataset) - interactions = self._filter_interactions(dataset.interactions.df, self.period, self.begin_from) + interactions = self._filter_interactions(dataset.interactions.df) self._calc_category_scores(dataset, interactions) self._define_categories_for_analysis() From 9e41f9503030624d6d632f07b3909f9ad4450c8a Mon Sep 17 00:00:00 2001 From: mikesokolovv Date: Wed, 11 Sep 2024 12:37:53 +0300 Subject: [PATCH 10/12] added timedelta serialization, PR fixes --- rectools/models/popular.py | 94 +++-- rectools/models/popular_in_category.py | 28 +- tests/model_selection/test_cross_validate.py | 6 +- tests/models/test_ease.py | 3 +- tests/models/test_implicit_als.py | 3 +- tests/models/test_implicit_knn.py | 5 +- tests/models/test_popular.py | 177 +++++++-- tests/models/test_popular_in_category.py | 379 +++++++++++-------- tests/models/test_pure_svd.py | 3 +- tests/models/test_random.py | 3 +- tests/models/utils.py | 2 +- 11 files changed, 444 insertions(+), 259 deletions(-) diff --git a/rectools/models/popular.py b/rectools/models/popular.py index 6f018403..c281c402 100644 --- a/rectools/models/popular.py +++ b/rectools/models/popular.py @@ -21,6 +21,7 @@ import numpy as np import pandas as pd import typing_extensions as tpe +from pydantic import PlainSerializer, PlainValidator from tqdm.auto import tqdm from rectools import Columns, InternalIds @@ -29,7 +30,7 @@ from rectools.types import InternalIdsArray from rectools.utils import fast_isin_for_sorted_test_elements -from .base import FixedColdRecoModelMixin, ModelBase, ModelConfig_T, Scores, ScoresArray +from .base import FixedColdRecoModelMixin, ModelBase, Scores, ScoresArray from .utils import get_viewed_item_ids @@ -42,44 +43,68 @@ class Popularity(Enum): SUM_WEIGHT = "sum_weight" +def _serialize_timedelta(td: tp.Optional[tp.Union[None, dict, timedelta]]) -> tp.Optional[timedelta]: + if isinstance(td, dict): + return timedelta( + days=td.get("days", 0), + seconds=td.get("seconds", 0), + microseconds=td.get("microseconds", 0), + milliseconds=td.get("milliseconds", 0), + minutes=td.get("minutes", 0), + hours=td.get("hours", 0), + weeks=td.get("weeks", 0), + ) + return td + + +def _deserialize_timedelta(td: tp.Optional[timedelta]) -> tp.Optional[dict]: + if td is None: + return td + return {"days": td.days, "seconds": td.seconds, "microseconds": td.microseconds} + + +TimeDelta = tpe.Annotated[ + tp.Union[None, timedelta, dict], + PlainValidator(func=_serialize_timedelta), + PlainSerializer(func=_deserialize_timedelta), +] + + class PopularModelConfig(ModelConfig): """Config for `PopularModel`.""" popularity: Popularity = Popularity.N_USERS - period: tp.Optional[timedelta] = None - begin_from: tp.Optional[datetime] = None + period: TimeDelta = None + begin_from: tp.Optional[tp.Union[datetime, str]] = None add_cold: bool = False inverse: bool = False -class PopularModelBaseMixin(ModelBase[ModelConfig_T]): +class PopularModelMixin: """Mixin for models based on popularity.""" - def __init__( + def _validate_popular_model_attributes( self, - popularity: tp.Literal["n_users", "n_interactions", "mean_weight", "sum_weight"] = "n_users", - period: tp.Optional[timedelta] = None, - begin_from: tp.Optional[datetime] = None, - add_cold: bool = False, - inverse: bool = False, - verbose: int = 0, - ): - super().__init__(verbose=verbose) + popularity: tp.Literal["n_users", "n_interactions", "mean_weight", "sum_weight"], + period: TimeDelta, + begin_from: tp.Optional[tp.Union[datetime, str]], + ) -> None: + try: + self.popularity = Popularity(popularity) + except ValueError: + possible_values = {item.value for item in Popularity.__members__.values()} + raise ValueError(f"`popularity` must be one of the {possible_values}. Got {popularity}.") if period is not None and begin_from is not None: raise ValueError("Only one of `period` and `begin_from` can be set") - self.popularity = Popularity(popularity) - self.period = period - self.begin_from = begin_from - - self.add_cold = add_cold - self.inverse = inverse - def _filter_interactions(self, interactions: pd.DataFrame) -> pd.DataFrame: - if self.begin_from is not None: - interactions = interactions.loc[interactions[Columns.Datetime] >= self.begin_from] - elif self.period is not None: - self.begin_from = interactions[Columns.Datetime].max() - self.period - interactions = interactions.loc[interactions[Columns.Datetime] >= self.begin_from] + def _filter_interactions( + self, interactions: pd.DataFrame, period: TimeDelta, begin_from: tp.Optional[tp.Union[datetime, str]] + ) -> pd.DataFrame: + if begin_from is not None: + interactions = interactions.loc[interactions[Columns.Datetime] >= begin_from] + elif period is not None: + begin_from = interactions[Columns.Datetime].max() - period + interactions = interactions.loc[interactions[Columns.Datetime] >= begin_from] return interactions def _get_groupby_col_and_agg_func(self, popularity: Popularity) -> tp.Tuple[str, str]: @@ -94,7 +119,7 @@ def _get_groupby_col_and_agg_func(self, popularity: Popularity) -> tp.Tuple[str, raise ValueError(f"Unexpected popularity {popularity}") -class PopularModel(FixedColdRecoModelMixin, PopularModelBaseMixin[PopularModelConfig]): +class PopularModel(FixedColdRecoModelMixin, PopularModelMixin, ModelBase[PopularModelConfig]): """ Model generating recommendations based on popularity of items. @@ -135,20 +160,21 @@ class PopularModel(FixedColdRecoModelMixin, PopularModelBaseMixin[PopularModelCo def __init__( self, popularity: tp.Literal["n_users", "n_interactions", "mean_weight", "sum_weight"] = "n_users", - period: tp.Optional[timedelta] = None, - begin_from: tp.Optional[datetime] = None, + period: TimeDelta = None, + begin_from: tp.Optional[tp.Union[datetime, str]] = None, add_cold: bool = False, inverse: bool = False, verbose: int = 0, ): super().__init__( - popularity=popularity, - period=period, - begin_from=begin_from, - add_cold=add_cold, - inverse=inverse, verbose=verbose, ) + self._validate_popular_model_attributes(popularity, period, begin_from) + self.period = period + self.begin_from = begin_from + + self.add_cold = add_cold + self.inverse = inverse self.popularity_list: tp.Tuple[InternalIdsArray, ScoresArray] @@ -174,7 +200,7 @@ def _from_config(cls, config: PopularModelConfig) -> tpe.Self: ) def _fit(self, dataset: Dataset) -> None: # type: ignore - interactions = self._filter_interactions(dataset.interactions.df) + interactions = self._filter_interactions(dataset.interactions.df, self.period, self.begin_from) col, func = self._get_groupby_col_and_agg_func(self.popularity) items_scores = interactions.groupby(Columns.Item)[col].agg(func).sort_values(ascending=False) diff --git a/rectools/models/popular_in_category.py b/rectools/models/popular_in_category.py index 7856003c..136ed912 100644 --- a/rectools/models/popular_in_category.py +++ b/rectools/models/popular_in_category.py @@ -16,7 +16,7 @@ import typing as tp import warnings -from datetime import datetime, timedelta +from datetime import datetime from enum import Enum import numpy as np @@ -27,8 +27,8 @@ from rectools.dataset import Dataset, Interactions, features from rectools.types import InternalIdsArray -from .base import Scores -from .popular import FixedColdRecoModelMixin, PopularModel, PopularModelBaseMixin, PopularModelConfig +from .base import ModelBase, Scores +from .popular import FixedColdRecoModelMixin, PopularModel, PopularModelConfig, PopularModelMixin, TimeDelta class MixingStrategy(Enum): @@ -54,7 +54,9 @@ class PopularInCategoryModelConfig(PopularModelConfig): ratio_strategy: RatioStrategy = RatioStrategy.PROPORTIONAL -class PopularInCategoryModel(FixedColdRecoModelMixin, PopularModelBaseMixin[PopularInCategoryModelConfig]): +class PopularInCategoryModel( + FixedColdRecoModelMixin, PopularModelMixin, ModelBase[PopularInCategoryModelConfig] +): # pylint: disable=too-many-instance-attributes """ Model generating recommendations based on popularity of items. @@ -117,21 +119,23 @@ def __init__( mixing_strategy: tp.Literal["rotate", "group"] = "rotate", ratio_strategy: tp.Literal["proportional", "equal"] = "proportional", popularity: tp.Literal["n_users", "n_interactions", "mean_weight", "sum_weight"] = "n_users", - period: tp.Optional[timedelta] = None, - begin_from: tp.Optional[datetime] = None, + period: TimeDelta = None, + begin_from: tp.Optional[tp.Union[datetime, str]] = None, add_cold: bool = False, inverse: bool = False, verbose: int = 0, ): super().__init__( - popularity=popularity, - period=period, - begin_from=begin_from, - add_cold=add_cold, - inverse=inverse, verbose=verbose, ) + self._validate_popular_model_attributes(popularity, period, begin_from) + self.period = period + self.begin_from = begin_from + + self.add_cold = add_cold + self.inverse = inverse + self.category_feature = category_feature self.mixing_strategy = MixingStrategy(mixing_strategy) self.ratio_strategy = RatioStrategy(ratio_strategy) @@ -231,7 +235,7 @@ def _fit(self, dataset: Dataset) -> None: # type: ignore self.n_effective_categories = 0 self._check_category_feature(dataset) - interactions = self._filter_interactions(dataset.interactions.df) + interactions = self._filter_interactions(dataset.interactions.df, self.period, self.begin_from) self._calc_category_scores(dataset, interactions) self._define_categories_for_analysis() diff --git a/tests/model_selection/test_cross_validate.py b/tests/model_selection/test_cross_validate.py index acaa6cac..b7bc374d 100644 --- a/tests/model_selection/test_cross_validate.py +++ b/tests/model_selection/test_cross_validate.py @@ -168,7 +168,7 @@ def setup_method(self) -> None: "intersection": Intersection(1), } - self.models = { + self.models: tp.Dict[str, ModelBase] = { "popular": PopularModel(), "random": RandomModel(random_state=42), } @@ -209,7 +209,7 @@ def test_happy_path( dataset=self.dataset, splitter=splitter, metrics=self.metrics, - models=self.models, # type: ignore + models=self.models, k=2, filter_viewed=False, items_to_recommend=items_to_recommend, @@ -426,7 +426,7 @@ def test_happy_path_with_intersection( dataset=self.dataset, splitter=splitter, metrics=self.metrics_intersection, - models=self.models, # type: ignore + models=self.models, k=2, filter_viewed=False, ref_models=ref_models, diff --git a/tests/models/test_ease.py b/tests/models/test_ease.py index 0f90de38..9ea04f94 100644 --- a/tests/models/test_ease.py +++ b/tests/models/test_ease.py @@ -259,8 +259,7 @@ def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> N "num_threads": 1, "verbose": 1, } - model = EASEModel() - assert_get_config_and_from_config_compatibility(model, DATASET, initial_config, simple_types) + assert_get_config_and_from_config_compatibility(EASEModel, DATASET, initial_config, simple_types) def test_default_config_and_default_model_params_are_the_same(self) -> None: default_config: tp.Dict[str, int] = {} diff --git a/tests/models/test_implicit_als.py b/tests/models/test_implicit_als.py index da5dd34b..3ee309a0 100644 --- a/tests/models/test_implicit_als.py +++ b/tests/models/test_implicit_als.py @@ -461,8 +461,7 @@ def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> N }, "verbose": 1, } - model = ImplicitALSWrapperModel(model=AlternatingLeastSquares()) - assert_get_config_and_from_config_compatibility(model, DATASET, initial_config, simple_types) + assert_get_config_and_from_config_compatibility(ImplicitALSWrapperModel, DATASET, initial_config, simple_types) def test_default_config_and_default_model_params_are_the_same(self) -> None: default_config: tp.Dict[str, tp.Any] = {"model": {}} diff --git a/tests/models/test_implicit_knn.py b/tests/models/test_implicit_knn.py index 732e7808..db0efd53 100644 --- a/tests/models/test_implicit_knn.py +++ b/tests/models/test_implicit_knn.py @@ -320,8 +320,9 @@ def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> N }, "verbose": 1, } - model = ImplicitItemKNNWrapperModel(model=ItemItemRecommender()) - assert_get_config_and_from_config_compatibility(model, DATASET, initial_config, simple_types) + assert_get_config_and_from_config_compatibility( + ImplicitItemKNNWrapperModel, DATASET, initial_config, simple_types + ) def test_default_config_and_default_model_params_are_the_same(self) -> None: default_config: tp.Dict[str, tp.Any] = {"model": {}} diff --git a/tests/models/test_popular.py b/tests/models/test_popular.py index 5568541d..ebaf8381 100644 --- a/tests/models/test_popular.py +++ b/tests/models/test_popular.py @@ -22,7 +22,7 @@ from rectools import Columns from rectools.dataset import Dataset, IdMap, Interactions from rectools.models import PopularModel -from rectools.models.popular import Popularity +from rectools.models.popular import Popularity, TimeDelta from tests.models.utils import ( assert_default_config_and_default_model_params_are_the_same, assert_get_config_and_from_config_compatibility, @@ -222,62 +222,165 @@ def test_second_fit_refits_model(self, dataset: Dataset) -> None: class TestPopularModelConfiguration: - - def test_from_config(self) -> None: + @pytest.mark.parametrize("begin_from", (None, datetime(2021, 11, 23), "2021-11-23T10:20:30.400+02:30")) + @pytest.mark.parametrize( + "period", + ( + None, + timedelta(days=7), + { + "days": 7, + "seconds": 123, + "microseconds": 12345, + "milliseconds": 32, + "minutes": 2, + "weeks": 7, + }, + ), + ) + def test_from_config(self, period: TimeDelta, begin_from: tp.Optional[tp.Union[datetime, str]]) -> None: config = { "popularity": "n_interactions", - "period": timedelta(days=7), - "begin_from": None, + "period": period, + "begin_from": begin_from, "add_cold": True, "inverse": True, "verbose": 0, } - model = PopularModel.from_config(config) - assert model.popularity.value == "n_interactions" - assert model.period == timedelta(days=7) - assert model.begin_from is None - assert model.add_cold is True - assert model.inverse is True - assert model.verbose == 0 + if period is not None and begin_from is not None: + with pytest.raises(ValueError): + model = PopularModel.from_config(config) + else: + model = PopularModel.from_config(config) + assert model.popularity.value == "n_interactions" + serialized_period = ( + timedelta( + days=period.get("days", 0), + seconds=period.get("seconds", 0), + microseconds=period.get("microseconds", 0), + milliseconds=period.get("milliseconds", 0), + minutes=period.get("minutes", 0), + hours=period.get("hours", 0), + weeks=period.get("weeks", 0), + ) + if isinstance(period, dict) + else period + ) + assert model.period == serialized_period + assert model.begin_from == begin_from + assert model.add_cold is True + assert model.inverse is True + assert model.verbose == 0 - @pytest.mark.parametrize("begin_from", (None, datetime(2021, 11, 23))) + @pytest.mark.parametrize("begin_from", (None, datetime(2021, 11, 23), "2021-11-23T10:20:30.400")) + @pytest.mark.parametrize( + "period", + ( + None, + timedelta(days=7), + { + "days": 7, + "seconds": 123, + "microseconds": 12345, + "milliseconds": 32, + "minutes": 2, + "hours": 10, + "weeks": 7, + }, + ), + ) @pytest.mark.parametrize("popularity", ("mean_weight", "sum_weight")) def test_get_config( self, popularity: tp.Literal["n_users", "n_interactions", "mean_weight", "sum_weight"], - begin_from: tp.Optional[datetime], + period: TimeDelta, + begin_from: tp.Optional[tp.Union[datetime, str]], ) -> None: - model = PopularModel( - popularity=popularity, - period=None, - begin_from=begin_from, - add_cold=False, - inverse=False, - verbose=1, - ) - config = model.get_config() - expected = { - "popularity": Popularity(popularity), - "period": None, - "begin_from": begin_from, - "add_cold": False, - "inverse": False, - "verbose": 1, - } - assert config == expected + if period is not None and begin_from is not None: + with pytest.raises(ValueError): + model = PopularModel( + popularity=popularity, + period=period, + begin_from=begin_from, + add_cold=False, + inverse=False, + verbose=1, + ) + else: + model = PopularModel( + popularity=popularity, + period=period, + begin_from=begin_from, + add_cold=False, + inverse=False, + verbose=1, + ) + config = model.get_config() + pre_serialized_period = ( + timedelta( + days=period.get("days", 0), + seconds=period.get("seconds", 0), + microseconds=period.get("microseconds", 0), + milliseconds=period.get("milliseconds", 0), + minutes=period.get("minutes", 0), + hours=period.get("hours", 0), + weeks=period.get("weeks", 0), + ) + if isinstance(period, dict) + else period + ) + serialized_period = ( + { + "days": pre_serialized_period.days, + "seconds": pre_serialized_period.seconds, + "microseconds": pre_serialized_period.microseconds, + } + if pre_serialized_period is not None + else pre_serialized_period + ) + expected = { + "popularity": Popularity(popularity), + "period": serialized_period, + "begin_from": begin_from, + "add_cold": False, + "inverse": False, + "verbose": 1, + } + assert config == expected + @pytest.mark.parametrize("begin_from", (None, datetime(2021, 11, 23), "2021-11-23T10:20:30.400")) + @pytest.mark.parametrize( + "period", + ( + None, + timedelta(days=7), + { + "days": 7, + "seconds": 123, + "milliseconds": 32, + "minutes": 2, + "hours": 10, + "weeks": 7, + }, + ), + ) @pytest.mark.parametrize("simple_types", (False, True)) - def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> None: + def test_get_config_and_from_config_compatibility( + self, period: TimeDelta, begin_from: tp.Optional[tp.Union[datetime, str]], simple_types: bool + ) -> None: initial_config = { "popularity": "n_users", - "period": None, - "begin_from": None, + "period": period, + "begin_from": begin_from, "add_cold": True, "inverse": False, "verbose": 0, } - model = PopularModel() - assert_get_config_and_from_config_compatibility(model, DATASET, initial_config, simple_types) + if period is not None and begin_from is not None: + with pytest.raises(ValueError): + PopularModel(period=period, begin_from=begin_from) + else: + assert_get_config_and_from_config_compatibility(PopularModel, DATASET, initial_config, simple_types) def test_default_config_and_default_model_params_are_the_same(self) -> None: default_config: tp.Dict[str, int] = {} diff --git a/tests/models/test_popular_in_category.py b/tests/models/test_popular_in_category.py index 2505af61..686801f1 100644 --- a/tests/models/test_popular_in_category.py +++ b/tests/models/test_popular_in_category.py @@ -22,7 +22,7 @@ from rectools import Columns from rectools.dataset import Dataset from rectools.models import PopularInCategoryModel -from rectools.models.popular import Popularity +from rectools.models.popular import Popularity, TimeDelta from rectools.models.popular_in_category import MixingStrategy, RatioStrategy from tests.models.utils import ( assert_default_config_and_default_model_params_are_the_same, @@ -31,66 +31,69 @@ ) -@pytest.mark.filterwarnings("ignore") -class TestPopularInCategoryModel: - @pytest.fixture - def interactions_df(self) -> pd.DataFrame: - interactions_df = pd.DataFrame( - [ - [70, 11, 1, "2021-11-30"], - [70, 12, 1, "2021-11-30"], - [10, 11, 1, "2021-11-30"], - [10, 12, 1, "2021-11-29"], - [10, 13, 9, "2021-11-28"], - [20, 11, 1, "2021-11-27"], - [20, 14, 2, "2021-11-26"], - [20, 14, 1, "2021-11-25"], - [20, 14, 1, "2021-11-25"], - [20, 14, 1, "2021-11-25"], - [20, 14, 1, "2021-11-25"], - [20, 14, 1, "2021-11-25"], - [30, 11, 1, "2021-11-24"], - [30, 12, 1, "2021-11-23"], - [30, 14, 1, "2021-11-23"], - [30, 15, 5, "2021-11-21"], - [30, 15, 5, "2021-11-21"], - [40, 11, 1, "2021-11-20"], - [40, 12, 1, "2021-11-19"], - [50, 12, 1, "2021-11-19"], - [60, 12, 1, "2021-11-19"], - ], - columns=Columns.Interactions, - ) - return interactions_df +@pytest.fixture(name="interactions_df") +def custom_interactions_df() -> pd.DataFrame: + interactions_df = pd.DataFrame( + [ + [70, 11, 1, "2021-11-30"], + [70, 12, 1, "2021-11-30"], + [10, 11, 1, "2021-11-30"], + [10, 12, 1, "2021-11-29"], + [10, 13, 9, "2021-11-28"], + [20, 11, 1, "2021-11-27"], + [20, 14, 2, "2021-11-26"], + [20, 14, 1, "2021-11-25"], + [20, 14, 1, "2021-11-25"], + [20, 14, 1, "2021-11-25"], + [20, 14, 1, "2021-11-25"], + [20, 14, 1, "2021-11-25"], + [30, 11, 1, "2021-11-24"], + [30, 12, 1, "2021-11-23"], + [30, 14, 1, "2021-11-23"], + [30, 15, 5, "2021-11-21"], + [30, 15, 5, "2021-11-21"], + [40, 11, 1, "2021-11-20"], + [40, 12, 1, "2021-11-19"], + [50, 12, 1, "2021-11-19"], + [60, 12, 1, "2021-11-19"], + ], + columns=Columns.Interactions, + ) + return interactions_df - @pytest.fixture - def item_features_df(self) -> pd.DataFrame: - item_features_df = pd.DataFrame( - { - "id": [11, 11, 12, 12, 13, 13, 14, 14, 14], - "feature": ["f1", "f2", "f1", "f2", "f1", "f2", "f1", "f2", "f3"], - "value": [100, "a", 100, "b", 100, "b", 200, "c", 1], - } - ) - return item_features_df - @pytest.fixture - def dataset(self, interactions_df: pd.DataFrame, item_features_df: pd.DataFrame) -> Dataset: - user_features_df = pd.DataFrame( - { - "id": [10, 50], - "feature": ["f1", "f1"], - "value": [1, 1], - } - ) - dataset = Dataset.construct( - interactions_df=interactions_df, - user_features_df=user_features_df, - item_features_df=item_features_df, - cat_item_features=["f2", "f1"], - ) - return dataset +@pytest.fixture(name="item_features_df") +def custom_item_features_df() -> pd.DataFrame: + item_features_df = pd.DataFrame( + { + "id": [11, 11, 12, 12, 13, 13, 14, 14, 14], + "feature": ["f1", "f2", "f1", "f2", "f1", "f2", "f1", "f2", "f3"], + "value": [100, "a", 100, "b", 100, "b", 200, "c", 1], + } + ) + return item_features_df + + +@pytest.fixture(name="dataset") +def custom_dataset(interactions_df: pd.DataFrame, item_features_df: pd.DataFrame) -> Dataset: + user_features_df = pd.DataFrame( + { + "id": [10, 50], + "feature": ["f1", "f1"], + "value": [1, 1], + } + ) + dataset = Dataset.construct( + interactions_df=interactions_df, + user_features_df=user_features_df, + item_features_df=item_features_df, + cat_item_features=["f2", "f1"], + ) + return dataset + +@pytest.mark.filterwarnings("ignore") +class TestPopularInCategoryModel: @classmethod def assert_reco( cls, @@ -453,129 +456,176 @@ def test_second_fit_refits_model( class TestPopularInCategoryModelConfiguration: - @pytest.fixture - def interactions_df(self) -> pd.DataFrame: - interactions_df = pd.DataFrame( - [ - [70, 11, 1, "2021-11-30"], - [70, 12, 1, "2021-11-30"], - [10, 11, 1, "2021-11-30"], - [10, 12, 1, "2021-11-29"], - [10, 13, 9, "2021-11-28"], - [20, 11, 1, "2021-11-27"], - [20, 14, 2, "2021-11-26"], - [20, 14, 1, "2021-11-25"], - [20, 14, 1, "2021-11-25"], - [20, 14, 1, "2021-11-25"], - [20, 14, 1, "2021-11-25"], - [20, 14, 1, "2021-11-25"], - [30, 11, 1, "2021-11-24"], - [30, 12, 1, "2021-11-23"], - [30, 14, 1, "2021-11-23"], - [30, 15, 5, "2021-11-21"], - [30, 15, 5, "2021-11-21"], - [40, 11, 1, "2021-11-20"], - [40, 12, 1, "2021-11-19"], - [50, 12, 1, "2021-11-19"], - [60, 12, 1, "2021-11-19"], - ], - columns=Columns.Interactions, - ) - return interactions_df - - @pytest.fixture - def item_features_df(self) -> pd.DataFrame: - item_features_df = pd.DataFrame( - { - "id": [11, 11, 12, 12, 13, 13, 14, 14, 14], - "feature": ["f1", "f2", "f1", "f2", "f1", "f2", "f1", "f2", "f3"], - "value": [100, "a", 100, "b", 100, "b", 200, "c", 1], - } - ) - return item_features_df - - @pytest.fixture - def dataset(self, interactions_df: pd.DataFrame, item_features_df: pd.DataFrame) -> Dataset: - user_features_df = pd.DataFrame( + @pytest.mark.parametrize("begin_from", (None, datetime(2021, 11, 23), "2021-11-23T10:20:30.400+02:30")) + @pytest.mark.parametrize( + "period", + ( + None, + timedelta(days=7), { - "id": [10, 50], - "feature": ["f1", "f1"], - "value": [1, 1], - } - ) - dataset = Dataset.construct( - interactions_df=interactions_df, - user_features_df=user_features_df, - item_features_df=item_features_df, - cat_item_features=["f2", "f1"], - ) - return dataset - - def test_from_config(self) -> None: + "days": 7, + "seconds": 123, + "microseconds": 12345, + "milliseconds": 32, + "minutes": 2, + "weeks": 7, + }, + ), + ) + def test_from_config(self, period: TimeDelta, begin_from: tp.Optional[tp.Union[datetime, str]]) -> None: config = { "category_feature": "f1", "n_categories": 2, "mixing_strategy": "group", "ratio_strategy": "equal", "popularity": "n_interactions", - "period": timedelta(days=7), - "begin_from": None, + "period": period, + "begin_from": begin_from, "add_cold": True, "inverse": True, "verbose": 0, } - model = PopularInCategoryModel.from_config(config) - assert model.category_feature == "f1" - assert model.n_categories == 2 - assert model.mixing_strategy == MixingStrategy("group") - assert model.ratio_strategy == RatioStrategy("equal") - assert model.popularity == Popularity("n_interactions") - assert model.period == timedelta(days=7) - assert model.begin_from is None - assert model.add_cold is True - assert model.inverse is True - assert model.verbose == 0 - - @pytest.mark.parametrize("begin_from", (None, datetime(2021, 11, 23))) + if period is not None and begin_from is not None: + with pytest.raises(ValueError): + model = PopularInCategoryModel.from_config(config) + else: + model = PopularInCategoryModel.from_config(config) + assert model.category_feature == "f1" + assert model.n_categories == 2 + assert model.mixing_strategy == MixingStrategy("group") + assert model.ratio_strategy == RatioStrategy("equal") + assert model.popularity == Popularity("n_interactions") + serialized_period = ( + timedelta( + days=period.get("days", 0), + seconds=period.get("seconds", 0), + microseconds=period.get("microseconds", 0), + milliseconds=period.get("milliseconds", 0), + minutes=period.get("minutes", 0), + hours=period.get("hours", 0), + weeks=period.get("weeks", 0), + ) + if isinstance(period, dict) + else period + ) + assert model.period == serialized_period + assert model.begin_from == begin_from + assert model.add_cold is True + assert model.inverse is True + assert model.verbose == 0 + + @pytest.mark.parametrize("begin_from", (None, datetime(2021, 11, 23), "2021-11-23 10:20:30.400")) + @pytest.mark.parametrize( + "period", + ( + None, + timedelta(days=7), + { + "days": 7, + "seconds": 123, + "microseconds": 12345, + "milliseconds": 32, + "minutes": 2, + "hours": 10, + "weeks": 7, + }, + ), + ) @pytest.mark.parametrize("popularity", ("mean_weight", "sum_weight")) def test_get_config( self, popularity: tp.Literal["n_users", "n_interactions", "mean_weight", "sum_weight"], - begin_from: tp.Optional[datetime], + period: TimeDelta, + begin_from: tp.Optional[tp.Union[datetime, str]], ) -> None: - model = PopularInCategoryModel( - category_feature="f2", - n_categories=3, - mixing_strategy="rotate", - ratio_strategy="proportional", - popularity=popularity, - period=None, - begin_from=begin_from, - add_cold=False, - inverse=False, - verbose=1, - ) - config = model.get_config() - expected = { - "category_feature": "f2", - "n_categories": 3, - "mixing_strategy": MixingStrategy("rotate"), - "ratio_strategy": RatioStrategy("proportional"), - "popularity": Popularity(popularity), - "period": None, - "begin_from": begin_from, - "add_cold": False, - "inverse": False, - "verbose": 1, - } - assert config == expected + if period is not None and begin_from is not None: + with pytest.raises(ValueError): + model = PopularInCategoryModel( + category_feature="f2", + n_categories=3, + mixing_strategy="rotate", + ratio_strategy="proportional", + popularity=popularity, + period=period, + begin_from=begin_from, + add_cold=False, + inverse=False, + verbose=1, + ) + else: + model = PopularInCategoryModel( + category_feature="f2", + n_categories=3, + mixing_strategy="rotate", + ratio_strategy="proportional", + popularity=popularity, + period=period, + begin_from=begin_from, + add_cold=False, + inverse=False, + verbose=1, + ) + config = model.get_config() + pre_serialized_period = ( + timedelta( + days=period.get("days", 0), + seconds=period.get("seconds", 0), + microseconds=period.get("microseconds", 0), + milliseconds=period.get("milliseconds", 0), + minutes=period.get("minutes", 0), + hours=period.get("hours", 0), + weeks=period.get("weeks", 0), + ) + if isinstance(period, dict) + else period + ) + serialized_period = ( + { + "days": pre_serialized_period.days, + "seconds": pre_serialized_period.seconds, + "microseconds": pre_serialized_period.microseconds, + } + if pre_serialized_period is not None + else pre_serialized_period + ) + expected = { + "category_feature": "f2", + "n_categories": 3, + "mixing_strategy": MixingStrategy("rotate"), + "ratio_strategy": RatioStrategy("proportional"), + "popularity": Popularity(popularity), + "period": serialized_period, + "begin_from": begin_from, + "add_cold": False, + "inverse": False, + "verbose": 1, + } + assert config == expected + @pytest.mark.parametrize("begin_from", (None, datetime(2021, 11, 23), "2021-11-23T10:20:30.400")) + @pytest.mark.parametrize( + "period", + ( + None, + timedelta(days=7), + { + "days": 7, + "seconds": 123, + "milliseconds": 32, + "minutes": 2, + "hours": 10, + "weeks": 7, + }, + ), + ) @pytest.mark.parametrize("category_feature", ("f1", "f2")) @pytest.mark.parametrize("simple_types", (False, True)) def test_get_config_and_from_config_compatibility( self, dataset: Dataset, category_feature: str, + period: TimeDelta, + begin_from: tp.Optional[tp.Union[datetime, str]], simple_types: bool, ) -> None: initial_config = { @@ -584,14 +634,19 @@ def test_get_config_and_from_config_compatibility( "mixing_strategy": "group", "ratio_strategy": "equal", "popularity": "n_users", - "period": None, - "begin_from": None, + "period": period, + "begin_from": begin_from, "add_cold": True, "inverse": False, "verbose": 0, } - model = PopularInCategoryModel(category_feature) - assert_get_config_and_from_config_compatibility(model, dataset, initial_config, simple_types) + if period is not None and begin_from is not None: + with pytest.raises(ValueError): + PopularInCategoryModel(category_feature=category_feature, period=period, begin_from=begin_from) + else: + assert_get_config_and_from_config_compatibility( + PopularInCategoryModel, dataset, initial_config, simple_types + ) def test_default_config_and_default_model_params_are_the_same(self) -> None: default_config: tp.Dict[str, str] = {"category_feature": "f2"} diff --git a/tests/models/test_pure_svd.py b/tests/models/test_pure_svd.py index 14598ad5..7842c131 100644 --- a/tests/models/test_pure_svd.py +++ b/tests/models/test_pure_svd.py @@ -304,8 +304,7 @@ def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> N "random_state": 32, "verbose": 0, } - model = PureSVDModel() - assert_get_config_and_from_config_compatibility(model, DATASET, initial_config, simple_types) + assert_get_config_and_from_config_compatibility(PureSVDModel, DATASET, initial_config, simple_types) def test_default_config_and_default_model_params_are_the_same(self) -> None: default_config: tp.Dict[str, int] = {} diff --git a/tests/models/test_random.py b/tests/models/test_random.py index f55fa6b5..373ee9fe 100644 --- a/tests/models/test_random.py +++ b/tests/models/test_random.py @@ -214,8 +214,7 @@ def test_get_config_and_from_config_compatibility(self, simple_types: bool) -> N "random_state": 32, "verbose": 0, } - model = RandomModel() - assert_get_config_and_from_config_compatibility(model, DATASET, initial_config, simple_types) + assert_get_config_and_from_config_compatibility(RandomModel, DATASET, initial_config, simple_types) def test_default_config_and_default_model_params_are_the_same(self) -> None: default_config: tp.Dict[str, int] = {} diff --git a/tests/models/utils.py b/tests/models/utils.py index ec531b55..92f2757d 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -58,7 +58,7 @@ def assert_default_config_and_default_model_params_are_the_same( def assert_get_config_and_from_config_compatibility( - model: ModelBase, dataset: Dataset, initial_config: tp.Dict[str, tp.Any], simple_types: bool + model: tp.Type[ModelBase], dataset: Dataset, initial_config: tp.Dict[str, tp.Any], simple_types: bool ) -> None: def get_reco(model: ModelBase) -> pd.DataFrame: return model.fit(dataset).recommend(users=np.array([10, 20]), dataset=dataset, k=2, filter_viewed=False) From 61d87a04bd5858bbbf6141cd4362ccdb981baa8d Mon Sep 17 00:00:00 2001 From: mikesokolovv Date: Thu, 12 Sep 2024 14:21:56 +0300 Subject: [PATCH 11/12] fixed typing and tests --- rectools/models/popular.py | 67 ++++++++++++----------- rectools/models/popular_in_category.py | 23 ++++++-- tests/models/test_popular.py | 70 ++++++++---------------- tests/models/test_popular_in_category.py | 69 +++++++---------------- 4 files changed, 96 insertions(+), 133 deletions(-) diff --git a/rectools/models/popular.py b/rectools/models/popular.py index c281c402..b4c3b576 100644 --- a/rectools/models/popular.py +++ b/rectools/models/popular.py @@ -43,30 +43,25 @@ class Popularity(Enum): SUM_WEIGHT = "sum_weight" -def _serialize_timedelta(td: tp.Optional[tp.Union[None, dict, timedelta]]) -> tp.Optional[timedelta]: +def _deserialize_timedelta(td: tp.Union[dict, timedelta]) -> timedelta: if isinstance(td, dict): - return timedelta( - days=td.get("days", 0), - seconds=td.get("seconds", 0), - microseconds=td.get("microseconds", 0), - milliseconds=td.get("milliseconds", 0), - minutes=td.get("minutes", 0), - hours=td.get("hours", 0), - weeks=td.get("weeks", 0), - ) + return timedelta(**td) return td -def _deserialize_timedelta(td: tp.Optional[timedelta]) -> tp.Optional[dict]: - if td is None: - return td - return {"days": td.days, "seconds": td.seconds, "microseconds": td.microseconds} +def _serialize_timedelta(td: timedelta) -> dict: + serialized_td = { + key: value + for key, value in {"days": td.days, "seconds": td.seconds, "microseconds": td.microseconds}.items() + if value != 0 + } + return serialized_td TimeDelta = tpe.Annotated[ - tp.Union[None, timedelta, dict], - PlainValidator(func=_serialize_timedelta), - PlainSerializer(func=_deserialize_timedelta), + timedelta, + PlainValidator(func=_deserialize_timedelta), + PlainSerializer(func=_serialize_timedelta), ] @@ -74,8 +69,8 @@ class PopularModelConfig(ModelConfig): """Config for `PopularModel`.""" popularity: Popularity = Popularity.N_USERS - period: TimeDelta = None - begin_from: tp.Optional[tp.Union[datetime, str]] = None + period: tp.Optional[TimeDelta] = None + begin_from: tp.Optional[datetime] = None add_cold: bool = False inverse: bool = False @@ -83,22 +78,27 @@ class PopularModelConfig(ModelConfig): class PopularModelMixin: """Mixin for models based on popularity.""" - def _validate_popular_model_attributes( - self, + @classmethod + def _validate_popularity( + cls, popularity: tp.Literal["n_users", "n_interactions", "mean_weight", "sum_weight"], - period: TimeDelta, - begin_from: tp.Optional[tp.Union[datetime, str]], ) -> None: - try: - self.popularity = Popularity(popularity) - except ValueError: - possible_values = {item.value for item in Popularity.__members__.values()} + possible_values = {item.value for item in Popularity.__members__.values()} + if popularity not in possible_values: raise ValueError(f"`popularity` must be one of the {possible_values}. Got {popularity}.") + + @classmethod + def _validate_time_attributes( + cls, + period: tp.Optional[TimeDelta], + begin_from: tp.Optional[datetime], + ) -> None: if period is not None and begin_from is not None: raise ValueError("Only one of `period` and `begin_from` can be set") + @classmethod def _filter_interactions( - self, interactions: pd.DataFrame, period: TimeDelta, begin_from: tp.Optional[tp.Union[datetime, str]] + cls, interactions: pd.DataFrame, period: tp.Optional[TimeDelta], begin_from: tp.Optional[datetime] ) -> pd.DataFrame: if begin_from is not None: interactions = interactions.loc[interactions[Columns.Datetime] >= begin_from] @@ -107,7 +107,8 @@ def _filter_interactions( interactions = interactions.loc[interactions[Columns.Datetime] >= begin_from] return interactions - def _get_groupby_col_and_agg_func(self, popularity: Popularity) -> tp.Tuple[str, str]: + @classmethod + def _get_groupby_col_and_agg_func(cls, popularity: Popularity) -> tp.Tuple[str, str]: if popularity == Popularity.N_USERS: return Columns.User, "nunique" if popularity == Popularity.N_INTERACTIONS: @@ -160,8 +161,8 @@ class PopularModel(FixedColdRecoModelMixin, PopularModelMixin, ModelBase[Popular def __init__( self, popularity: tp.Literal["n_users", "n_interactions", "mean_weight", "sum_weight"] = "n_users", - period: TimeDelta = None, - begin_from: tp.Optional[tp.Union[datetime, str]] = None, + period: tp.Optional[TimeDelta] = None, + begin_from: tp.Optional[datetime] = None, add_cold: bool = False, inverse: bool = False, verbose: int = 0, @@ -169,7 +170,9 @@ def __init__( super().__init__( verbose=verbose, ) - self._validate_popular_model_attributes(popularity, period, begin_from) + self._validate_popularity(popularity) + self.popularity = Popularity(popularity) + self._validate_time_attributes(period, begin_from) self.period = period self.begin_from = begin_from diff --git a/rectools/models/popular_in_category.py b/rectools/models/popular_in_category.py index 136ed912..acf30e8d 100644 --- a/rectools/models/popular_in_category.py +++ b/rectools/models/popular_in_category.py @@ -28,7 +28,7 @@ from rectools.types import InternalIdsArray from .base import ModelBase, Scores -from .popular import FixedColdRecoModelMixin, PopularModel, PopularModelConfig, PopularModelMixin, TimeDelta +from .popular import FixedColdRecoModelMixin, PopularModel, PopularModelConfig, PopularModelMixin, Popularity, TimeDelta class MixingStrategy(Enum): @@ -119,8 +119,8 @@ def __init__( mixing_strategy: tp.Literal["rotate", "group"] = "rotate", ratio_strategy: tp.Literal["proportional", "equal"] = "proportional", popularity: tp.Literal["n_users", "n_interactions", "mean_weight", "sum_weight"] = "n_users", - period: TimeDelta = None, - begin_from: tp.Optional[tp.Union[datetime, str]] = None, + period: tp.Optional[TimeDelta] = None, + begin_from: tp.Optional[datetime] = None, add_cold: bool = False, inverse: bool = False, verbose: int = 0, @@ -129,7 +129,9 @@ def __init__( verbose=verbose, ) - self._validate_popular_model_attributes(popularity, period, begin_from) + self._validate_popularity(popularity) + self.popularity = Popularity(popularity) + self._validate_time_attributes(period, begin_from) self.period = period self.begin_from = begin_from @@ -137,8 +139,17 @@ def __init__( self.inverse = inverse self.category_feature = category_feature - self.mixing_strategy = MixingStrategy(mixing_strategy) - self.ratio_strategy = RatioStrategy(ratio_strategy) + try: + self.mixing_strategy = MixingStrategy(mixing_strategy) + except ValueError: + possible_values = {item.value for item in MixingStrategy.__members__.values()} + raise ValueError(f"`mixing_strategy` must be one of the {possible_values}. Got {mixing_strategy}.") + + try: + self.ratio_strategy = RatioStrategy(ratio_strategy) + except ValueError: + possible_values = {item.value for item in RatioStrategy.__members__.values()} + raise ValueError(f"`ratio_strategy` must be one of the {possible_values}. Got {ratio_strategy}.") self.category_columns: tp.List[int] = [] self.category_interactions: tp.Dict[int, pd.DataFrame] = {} self.category_scores: pd.Series diff --git a/tests/models/test_popular.py b/tests/models/test_popular.py index ebaf8381..308b088f 100644 --- a/tests/models/test_popular.py +++ b/tests/models/test_popular.py @@ -22,7 +22,7 @@ from rectools import Columns from rectools.dataset import Dataset, IdMap, Interactions from rectools.models import PopularModel -from rectools.models.popular import Popularity, TimeDelta +from rectools.models.popular import Popularity from tests.models.utils import ( assert_default_config_and_default_model_params_are_the_same, assert_get_config_and_from_config_compatibility, @@ -238,7 +238,9 @@ class TestPopularModelConfiguration: }, ), ) - def test_from_config(self, period: TimeDelta, begin_from: tp.Optional[tp.Union[datetime, str]]) -> None: + def test_from_config( + self, period: tp.Optional[tp.Union[timedelta, dict]], begin_from: tp.Optional[tp.Union[datetime, str]] + ) -> None: config = { "popularity": "n_interactions", "period": period, @@ -253,48 +255,28 @@ def test_from_config(self, period: TimeDelta, begin_from: tp.Optional[tp.Union[d else: model = PopularModel.from_config(config) assert model.popularity.value == "n_interactions" - serialized_period = ( - timedelta( - days=period.get("days", 0), - seconds=period.get("seconds", 0), - microseconds=period.get("microseconds", 0), - milliseconds=period.get("milliseconds", 0), - minutes=period.get("minutes", 0), - hours=period.get("hours", 0), - weeks=period.get("weeks", 0), - ) - if isinstance(period, dict) - else period - ) + serialized_period = timedelta(**period) if isinstance(period, dict) else period assert model.period == serialized_period - assert model.begin_from == begin_from + serialiazed_begin_from = datetime.fromisoformat(begin_from) if isinstance(begin_from, str) else begin_from + assert model.begin_from == serialiazed_begin_from assert model.add_cold is True assert model.inverse is True assert model.verbose == 0 - @pytest.mark.parametrize("begin_from", (None, datetime(2021, 11, 23), "2021-11-23T10:20:30.400")) + @pytest.mark.parametrize("begin_from", (None, datetime(2021, 11, 23))) @pytest.mark.parametrize( "period", ( None, timedelta(days=7), - { - "days": 7, - "seconds": 123, - "microseconds": 12345, - "milliseconds": 32, - "minutes": 2, - "hours": 10, - "weeks": 7, - }, ), ) @pytest.mark.parametrize("popularity", ("mean_weight", "sum_weight")) def test_get_config( self, popularity: tp.Literal["n_users", "n_interactions", "mean_weight", "sum_weight"], - period: TimeDelta, - begin_from: tp.Optional[tp.Union[datetime, str]], + period: tp.Optional[timedelta], + begin_from: tp.Optional[datetime], ) -> None: if period is not None and begin_from is not None: with pytest.raises(ValueError): @@ -316,27 +298,18 @@ def test_get_config( verbose=1, ) config = model.get_config() - pre_serialized_period = ( - timedelta( - days=period.get("days", 0), - seconds=period.get("seconds", 0), - microseconds=period.get("microseconds", 0), - milliseconds=period.get("milliseconds", 0), - minutes=period.get("minutes", 0), - hours=period.get("hours", 0), - weeks=period.get("weeks", 0), - ) - if isinstance(period, dict) - else period - ) serialized_period = ( { - "days": pre_serialized_period.days, - "seconds": pre_serialized_period.seconds, - "microseconds": pre_serialized_period.microseconds, + key: value + for key, value in { + "days": period.days, + "seconds": period.seconds, + "microseconds": period.microseconds, + }.items() + if value != 0 } - if pre_serialized_period is not None - else pre_serialized_period + if isinstance(period, timedelta) + else period ) expected = { "popularity": Popularity(popularity), @@ -366,7 +339,10 @@ def test_get_config( ) @pytest.mark.parametrize("simple_types", (False, True)) def test_get_config_and_from_config_compatibility( - self, period: TimeDelta, begin_from: tp.Optional[tp.Union[datetime, str]], simple_types: bool + self, + period: tp.Optional[timedelta], + begin_from: tp.Optional[datetime], + simple_types: bool, ) -> None: initial_config = { "popularity": "n_users", diff --git a/tests/models/test_popular_in_category.py b/tests/models/test_popular_in_category.py index 686801f1..4f64c7b5 100644 --- a/tests/models/test_popular_in_category.py +++ b/tests/models/test_popular_in_category.py @@ -22,7 +22,7 @@ from rectools import Columns from rectools.dataset import Dataset from rectools.models import PopularInCategoryModel -from rectools.models.popular import Popularity, TimeDelta +from rectools.models.popular import Popularity from rectools.models.popular_in_category import MixingStrategy, RatioStrategy from tests.models.utils import ( assert_default_config_and_default_model_params_are_the_same, @@ -472,7 +472,9 @@ class TestPopularInCategoryModelConfiguration: }, ), ) - def test_from_config(self, period: TimeDelta, begin_from: tp.Optional[tp.Union[datetime, str]]) -> None: + def test_from_config( + self, period: tp.Optional[tp.Union[timedelta, dict]], begin_from: tp.Optional[tp.Union[datetime, str]] + ) -> None: config = { "category_feature": "f1", "n_categories": 2, @@ -495,48 +497,28 @@ def test_from_config(self, period: TimeDelta, begin_from: tp.Optional[tp.Union[d assert model.mixing_strategy == MixingStrategy("group") assert model.ratio_strategy == RatioStrategy("equal") assert model.popularity == Popularity("n_interactions") - serialized_period = ( - timedelta( - days=period.get("days", 0), - seconds=period.get("seconds", 0), - microseconds=period.get("microseconds", 0), - milliseconds=period.get("milliseconds", 0), - minutes=period.get("minutes", 0), - hours=period.get("hours", 0), - weeks=period.get("weeks", 0), - ) - if isinstance(period, dict) - else period - ) + serialized_period = timedelta(**period) if isinstance(period, dict) else period assert model.period == serialized_period - assert model.begin_from == begin_from + serialiazed_begin_from = datetime.fromisoformat(begin_from) if isinstance(begin_from, str) else begin_from + assert model.begin_from == serialiazed_begin_from assert model.add_cold is True assert model.inverse is True assert model.verbose == 0 - @pytest.mark.parametrize("begin_from", (None, datetime(2021, 11, 23), "2021-11-23 10:20:30.400")) + @pytest.mark.parametrize("begin_from", (None, datetime(2021, 11, 23))) @pytest.mark.parametrize( "period", ( None, timedelta(days=7), - { - "days": 7, - "seconds": 123, - "microseconds": 12345, - "milliseconds": 32, - "minutes": 2, - "hours": 10, - "weeks": 7, - }, ), ) @pytest.mark.parametrize("popularity", ("mean_weight", "sum_weight")) def test_get_config( self, popularity: tp.Literal["n_users", "n_interactions", "mean_weight", "sum_weight"], - period: TimeDelta, - begin_from: tp.Optional[tp.Union[datetime, str]], + period: tp.Optional[timedelta], + begin_from: tp.Optional[datetime], ) -> None: if period is not None and begin_from is not None: with pytest.raises(ValueError): @@ -566,27 +548,18 @@ def test_get_config( verbose=1, ) config = model.get_config() - pre_serialized_period = ( - timedelta( - days=period.get("days", 0), - seconds=period.get("seconds", 0), - microseconds=period.get("microseconds", 0), - milliseconds=period.get("milliseconds", 0), - minutes=period.get("minutes", 0), - hours=period.get("hours", 0), - weeks=period.get("weeks", 0), - ) - if isinstance(period, dict) - else period - ) serialized_period = ( { - "days": pre_serialized_period.days, - "seconds": pre_serialized_period.seconds, - "microseconds": pre_serialized_period.microseconds, + key: value + for key, value in { + "days": period.days, + "seconds": period.seconds, + "microseconds": period.microseconds, + }.items() + if value != 0 } - if pre_serialized_period is not None - else pre_serialized_period + if isinstance(period, timedelta) + else period ) expected = { "category_feature": "f2", @@ -624,8 +597,8 @@ def test_get_config_and_from_config_compatibility( self, dataset: Dataset, category_feature: str, - period: TimeDelta, - begin_from: tp.Optional[tp.Union[datetime, str]], + period: tp.Optional[timedelta], + begin_from: tp.Optional[datetime], simple_types: bool, ) -> None: initial_config = { From 2de6563d48b6db71272b1903d4ed614b074c3426 Mon Sep 17 00:00:00 2001 From: mikesokolovv Date: Fri, 13 Sep 2024 12:39:24 +0300 Subject: [PATCH 12/12] Fixed tests, typing, naming. --- rectools/models/popular.py | 20 ++- rectools/models/popular_in_category.py | 11 +- tests/models/test_popular.py | 168 ++++++++---------- tests/models/test_popular_in_category.py | 212 ++++++++++------------- 4 files changed, 187 insertions(+), 224 deletions(-) diff --git a/rectools/models/popular.py b/rectools/models/popular.py index b4c3b576..29708b10 100644 --- a/rectools/models/popular.py +++ b/rectools/models/popular.py @@ -75,16 +75,21 @@ class PopularModelConfig(ModelConfig): inverse: bool = False +PopularityOptions = tp.Literal["n_users", "n_interactions", "mean_weight", "sum_weight"] + + class PopularModelMixin: """Mixin for models based on popularity.""" @classmethod def _validate_popularity( cls, - popularity: tp.Literal["n_users", "n_interactions", "mean_weight", "sum_weight"], - ) -> None: - possible_values = {item.value for item in Popularity.__members__.values()} - if popularity not in possible_values: + popularity: PopularityOptions, + ) -> Popularity: + try: + return Popularity(popularity) + except ValueError: + possible_values = {item.value for item in Popularity.__members__.values()} raise ValueError(f"`popularity` must be one of the {possible_values}. Got {popularity}.") @classmethod @@ -160,8 +165,8 @@ class PopularModel(FixedColdRecoModelMixin, PopularModelMixin, ModelBase[Popular def __init__( self, - popularity: tp.Literal["n_users", "n_interactions", "mean_weight", "sum_weight"] = "n_users", - period: tp.Optional[TimeDelta] = None, + popularity: PopularityOptions = "n_users", + period: tp.Optional[timedelta] = None, begin_from: tp.Optional[datetime] = None, add_cold: bool = False, inverse: bool = False, @@ -170,8 +175,7 @@ def __init__( super().__init__( verbose=verbose, ) - self._validate_popularity(popularity) - self.popularity = Popularity(popularity) + self.popularity = self._validate_popularity(popularity) self._validate_time_attributes(period, begin_from) self.period = period self.begin_from = begin_from diff --git a/rectools/models/popular_in_category.py b/rectools/models/popular_in_category.py index acf30e8d..4f6416c4 100644 --- a/rectools/models/popular_in_category.py +++ b/rectools/models/popular_in_category.py @@ -16,7 +16,7 @@ import typing as tp import warnings -from datetime import datetime +from datetime import datetime, timedelta from enum import Enum import numpy as np @@ -28,7 +28,7 @@ from rectools.types import InternalIdsArray from .base import ModelBase, Scores -from .popular import FixedColdRecoModelMixin, PopularModel, PopularModelConfig, PopularModelMixin, Popularity, TimeDelta +from .popular import FixedColdRecoModelMixin, PopularModel, PopularModelConfig, PopularModelMixin, PopularityOptions class MixingStrategy(Enum): @@ -118,8 +118,8 @@ def __init__( n_categories: tp.Optional[int] = None, mixing_strategy: tp.Literal["rotate", "group"] = "rotate", ratio_strategy: tp.Literal["proportional", "equal"] = "proportional", - popularity: tp.Literal["n_users", "n_interactions", "mean_weight", "sum_weight"] = "n_users", - period: tp.Optional[TimeDelta] = None, + popularity: PopularityOptions = "n_users", + period: tp.Optional[timedelta] = None, begin_from: tp.Optional[datetime] = None, add_cold: bool = False, inverse: bool = False, @@ -129,8 +129,7 @@ def __init__( verbose=verbose, ) - self._validate_popularity(popularity) - self.popularity = Popularity(popularity) + self.popularity = self._validate_popularity(popularity) self._validate_time_attributes(period, begin_from) self.period = period self.begin_from = begin_from diff --git a/tests/models/test_popular.py b/tests/models/test_popular.py index 308b088f..fd419c1a 100644 --- a/tests/models/test_popular.py +++ b/tests/models/test_popular.py @@ -222,24 +222,33 @@ def test_second_fit_refits_model(self, dataset: Dataset) -> None: class TestPopularModelConfiguration: - @pytest.mark.parametrize("begin_from", (None, datetime(2021, 11, 23), "2021-11-23T10:20:30.400+02:30")) @pytest.mark.parametrize( - "period", + "begin_from,period,expected_begin_from,expected_period", ( - None, - timedelta(days=7), - { - "days": 7, - "seconds": 123, - "microseconds": 12345, - "milliseconds": 32, - "minutes": 2, - "weeks": 7, - }, + (None, timedelta(days=7), None, timedelta(days=7)), + (datetime(2021, 11, 23), None, datetime(2021, 11, 23), None), + ("2021-11-23T10:20:30.400", None, datetime(2021, 11, 23, 10, 20, 30, 400000), None), + ( + None, + { + "days": 7, + "seconds": 123, + "microseconds": 12345, + "milliseconds": 32, + "minutes": 2, + "weeks": 7, + }, + None, + timedelta(days=56, seconds=243, microseconds=44345), + ), ), ) def test_from_config( - self, period: tp.Optional[tp.Union[timedelta, dict]], begin_from: tp.Optional[tp.Union[datetime, str]] + self, + period: tp.Optional[tp.Union[timedelta, dict]], + begin_from: tp.Optional[tp.Union[datetime, str]], + expected_begin_from: tp.Optional[datetime], + expected_period: tp.Optional[dict], ) -> None: config = { "popularity": "n_interactions", @@ -249,95 +258,74 @@ def test_from_config( "inverse": True, "verbose": 0, } - if period is not None and begin_from is not None: - with pytest.raises(ValueError): - model = PopularModel.from_config(config) - else: - model = PopularModel.from_config(config) - assert model.popularity.value == "n_interactions" - serialized_period = timedelta(**period) if isinstance(period, dict) else period - assert model.period == serialized_period - serialiazed_begin_from = datetime.fromisoformat(begin_from) if isinstance(begin_from, str) else begin_from - assert model.begin_from == serialiazed_begin_from - assert model.add_cold is True - assert model.inverse is True - assert model.verbose == 0 + model = PopularModel.from_config(config) + assert model.popularity.value == "n_interactions" + assert model.period == expected_period + assert model.begin_from == expected_begin_from + assert model.add_cold is True + assert model.inverse is True + assert model.verbose == 0 - @pytest.mark.parametrize("begin_from", (None, datetime(2021, 11, 23))) @pytest.mark.parametrize( - "period", + "begin_from,period,expected_period", ( - None, - timedelta(days=7), + ( + None, + timedelta(weeks=2, days=7, hours=23, milliseconds=12345), + {"days": 21, "microseconds": 345000, "seconds": 82812}, + ), + (datetime(2021, 11, 23, 10, 20, 30, 400000), None, None), ), ) - @pytest.mark.parametrize("popularity", ("mean_weight", "sum_weight")) def test_get_config( self, - popularity: tp.Literal["n_users", "n_interactions", "mean_weight", "sum_weight"], period: tp.Optional[timedelta], begin_from: tp.Optional[datetime], + expected_period: tp.Optional[timedelta], ) -> None: - if period is not None and begin_from is not None: - with pytest.raises(ValueError): - model = PopularModel( - popularity=popularity, - period=period, - begin_from=begin_from, - add_cold=False, - inverse=False, - verbose=1, - ) - else: - model = PopularModel( - popularity=popularity, - period=period, - begin_from=begin_from, - add_cold=False, - inverse=False, - verbose=1, - ) - config = model.get_config() - serialized_period = ( - { - key: value - for key, value in { - "days": period.days, - "seconds": period.seconds, - "microseconds": period.microseconds, - }.items() - if value != 0 - } - if isinstance(period, timedelta) - else period - ) - expected = { - "popularity": Popularity(popularity), - "period": serialized_period, - "begin_from": begin_from, - "add_cold": False, - "inverse": False, - "verbose": 1, - } - assert config == expected + model = PopularModel( + popularity="n_users", + period=period, + begin_from=begin_from, + add_cold=False, + inverse=False, + verbose=1, + ) + config = model.get_config() + expected = { + "popularity": Popularity("n_users"), + "period": expected_period, + "begin_from": begin_from, + "add_cold": False, + "inverse": False, + "verbose": 1, + } + assert config == expected - @pytest.mark.parametrize("begin_from", (None, datetime(2021, 11, 23), "2021-11-23T10:20:30.400")) @pytest.mark.parametrize( - "period", + "begin_from,period,simple_types", ( - None, - timedelta(days=7), - { - "days": 7, - "seconds": 123, - "milliseconds": 32, - "minutes": 2, - "hours": 10, - "weeks": 7, - }, + ( + None, + timedelta(weeks=1, days=2, hours=3, minutes=4, seconds=5, milliseconds=6000, microseconds=70000), + True, + ), + (datetime(2021, 11, 23), None, False), + ("2021-11-23T10:20:30.400", None, True), + ( + None, + { + "days": 7, + "seconds": 123, + "microseconds": 12345, + "milliseconds": 32, + "minutes": 2, + "weeks": 7, + }, + False, + ), ), ) - @pytest.mark.parametrize("simple_types", (False, True)) def test_get_config_and_from_config_compatibility( self, period: tp.Optional[timedelta], @@ -352,11 +340,7 @@ def test_get_config_and_from_config_compatibility( "inverse": False, "verbose": 0, } - if period is not None and begin_from is not None: - with pytest.raises(ValueError): - PopularModel(period=period, begin_from=begin_from) - else: - assert_get_config_and_from_config_compatibility(PopularModel, DATASET, initial_config, simple_types) + assert_get_config_and_from_config_compatibility(PopularModel, DATASET, initial_config, simple_types) def test_default_config_and_default_model_params_are_the_same(self) -> None: default_config: tp.Dict[str, int] = {} diff --git a/tests/models/test_popular_in_category.py b/tests/models/test_popular_in_category.py index 4f64c7b5..3d0a6ffa 100644 --- a/tests/models/test_popular_in_category.py +++ b/tests/models/test_popular_in_category.py @@ -31,8 +31,8 @@ ) -@pytest.fixture(name="interactions_df") -def custom_interactions_df() -> pd.DataFrame: +@pytest.fixture(name="interactions_df") # https://github.com/pylint-dev/pylint/issues/6531 +def _interactions_df() -> pd.DataFrame: interactions_df = pd.DataFrame( [ [70, 11, 1, "2021-11-30"], @@ -63,7 +63,7 @@ def custom_interactions_df() -> pd.DataFrame: @pytest.fixture(name="item_features_df") -def custom_item_features_df() -> pd.DataFrame: +def _item_features_df() -> pd.DataFrame: item_features_df = pd.DataFrame( { "id": [11, 11, 12, 12, 13, 13, 14, 14, 14], @@ -75,7 +75,7 @@ def custom_item_features_df() -> pd.DataFrame: @pytest.fixture(name="dataset") -def custom_dataset(interactions_df: pd.DataFrame, item_features_df: pd.DataFrame) -> Dataset: +def _dataset(interactions_df: pd.DataFrame, item_features_df: pd.DataFrame) -> Dataset: user_features_df = pd.DataFrame( { "id": [10, 50], @@ -456,24 +456,33 @@ def test_second_fit_refits_model( class TestPopularInCategoryModelConfiguration: - @pytest.mark.parametrize("begin_from", (None, datetime(2021, 11, 23), "2021-11-23T10:20:30.400+02:30")) @pytest.mark.parametrize( - "period", + "begin_from,period,expected_begin_from,expected_period", ( - None, - timedelta(days=7), - { - "days": 7, - "seconds": 123, - "microseconds": 12345, - "milliseconds": 32, - "minutes": 2, - "weeks": 7, - }, + (None, timedelta(days=7), None, timedelta(days=7)), + (datetime(2021, 11, 23), None, datetime(2021, 11, 23), None), + ("2021-11-23T10:20:30.400", None, datetime(2021, 11, 23, 10, 20, 30, 400000), None), + ( + None, + { + "days": 7, + "seconds": 123, + "microseconds": 12345, + "milliseconds": 32, + "minutes": 2, + "weeks": 7, + }, + None, + timedelta(days=56, seconds=243, microseconds=44345), + ), ), ) def test_from_config( - self, period: tp.Optional[tp.Union[timedelta, dict]], begin_from: tp.Optional[tp.Union[datetime, str]] + self, + period: tp.Optional[tp.Union[timedelta, dict]], + begin_from: tp.Optional[tp.Union[datetime, str]], + expected_begin_from: tp.Optional[datetime], + expected_period: tp.Optional[dict], ) -> None: config = { "category_feature": "f1", @@ -487,122 +496,95 @@ def test_from_config( "inverse": True, "verbose": 0, } - if period is not None and begin_from is not None: - with pytest.raises(ValueError): - model = PopularInCategoryModel.from_config(config) - else: - model = PopularInCategoryModel.from_config(config) - assert model.category_feature == "f1" - assert model.n_categories == 2 - assert model.mixing_strategy == MixingStrategy("group") - assert model.ratio_strategy == RatioStrategy("equal") - assert model.popularity == Popularity("n_interactions") - serialized_period = timedelta(**period) if isinstance(period, dict) else period - assert model.period == serialized_period - serialiazed_begin_from = datetime.fromisoformat(begin_from) if isinstance(begin_from, str) else begin_from - assert model.begin_from == serialiazed_begin_from - assert model.add_cold is True - assert model.inverse is True - assert model.verbose == 0 - - @pytest.mark.parametrize("begin_from", (None, datetime(2021, 11, 23))) + model = PopularInCategoryModel.from_config(config) + assert model.category_feature == "f1" + assert model.n_categories == 2 + assert model.mixing_strategy == MixingStrategy("group") + assert model.ratio_strategy == RatioStrategy("equal") + assert model.popularity == Popularity("n_interactions") + assert model.period == expected_period + assert model.begin_from == expected_begin_from + assert model.add_cold is True + assert model.inverse is True + assert model.verbose == 0 + @pytest.mark.parametrize( - "period", + "begin_from,period,expected_period", ( - None, - timedelta(days=7), + ( + None, + timedelta(weeks=2, days=7, hours=23, milliseconds=12345), + {"days": 21, "microseconds": 345000, "seconds": 82812}, + ), + (datetime(2021, 11, 23, 10, 20, 30, 400000), None, None), ), ) - @pytest.mark.parametrize("popularity", ("mean_weight", "sum_weight")) def test_get_config( self, - popularity: tp.Literal["n_users", "n_interactions", "mean_weight", "sum_weight"], period: tp.Optional[timedelta], begin_from: tp.Optional[datetime], + expected_period: tp.Optional[timedelta], ) -> None: - if period is not None and begin_from is not None: - with pytest.raises(ValueError): - model = PopularInCategoryModel( - category_feature="f2", - n_categories=3, - mixing_strategy="rotate", - ratio_strategy="proportional", - popularity=popularity, - period=period, - begin_from=begin_from, - add_cold=False, - inverse=False, - verbose=1, - ) - else: - model = PopularInCategoryModel( - category_feature="f2", - n_categories=3, - mixing_strategy="rotate", - ratio_strategy="proportional", - popularity=popularity, - period=period, - begin_from=begin_from, - add_cold=False, - inverse=False, - verbose=1, - ) - config = model.get_config() - serialized_period = ( - { - key: value - for key, value in { - "days": period.days, - "seconds": period.seconds, - "microseconds": period.microseconds, - }.items() - if value != 0 - } - if isinstance(period, timedelta) - else period - ) - expected = { - "category_feature": "f2", - "n_categories": 3, - "mixing_strategy": MixingStrategy("rotate"), - "ratio_strategy": RatioStrategy("proportional"), - "popularity": Popularity(popularity), - "period": serialized_period, - "begin_from": begin_from, - "add_cold": False, - "inverse": False, - "verbose": 1, - } - assert config == expected + model = PopularInCategoryModel( + category_feature="f2", + n_categories=3, + mixing_strategy="rotate", + ratio_strategy="proportional", + popularity="n_users", + period=period, + begin_from=begin_from, + add_cold=False, + inverse=False, + verbose=1, + ) + config = model.get_config() + expected = { + "category_feature": "f2", + "n_categories": 3, + "mixing_strategy": MixingStrategy("rotate"), + "ratio_strategy": RatioStrategy("proportional"), + "popularity": Popularity("n_users"), + "period": expected_period, + "begin_from": begin_from, + "add_cold": False, + "inverse": False, + "verbose": 1, + } + assert config == expected - @pytest.mark.parametrize("begin_from", (None, datetime(2021, 11, 23), "2021-11-23T10:20:30.400")) @pytest.mark.parametrize( - "period", + "begin_from,period,simple_types", ( - None, - timedelta(days=7), - { - "days": 7, - "seconds": 123, - "milliseconds": 32, - "minutes": 2, - "hours": 10, - "weeks": 7, - }, + ( + None, + timedelta(weeks=1, days=2, hours=3, minutes=4, seconds=5, milliseconds=6000, microseconds=70000), + True, + ), + (datetime(2021, 11, 23), None, False), + ("2021-11-23T10:20:30.400", None, True), + ( + None, + { + "days": 7, + "seconds": 123, + "microseconds": 12345, + "milliseconds": 32, + "minutes": 2, + "weeks": 7, + }, + False, + ), ), ) - @pytest.mark.parametrize("category_feature", ("f1", "f2")) - @pytest.mark.parametrize("simple_types", (False, True)) def test_get_config_and_from_config_compatibility( self, dataset: Dataset, - category_feature: str, period: tp.Optional[timedelta], begin_from: tp.Optional[datetime], simple_types: bool, ) -> None: initial_config = { - "category_feature": category_feature, + "category_feature": "f1", "n_categories": 2, "mixing_strategy": "group", "ratio_strategy": "equal", @@ -613,13 +595,7 @@ def test_get_config_and_from_config_compatibility( "inverse": False, "verbose": 0, } - if period is not None and begin_from is not None: - with pytest.raises(ValueError): - PopularInCategoryModel(category_feature=category_feature, period=period, begin_from=begin_from) - else: - assert_get_config_and_from_config_compatibility( - PopularInCategoryModel, dataset, initial_config, simple_types - ) + assert_get_config_and_from_config_compatibility(PopularInCategoryModel, dataset, initial_config, simple_types) def test_default_config_and_default_model_params_are_the_same(self) -> None: default_config: tp.Dict[str, str] = {"category_feature": "f2"}