Skip to content

Commit

Permalink
Added description for new classes
Browse files Browse the repository at this point in the history
  • Loading branch information
Семенов Андрей Максимович committed Sep 27, 2024
1 parent 6392532 commit 88a1721
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions rectools/models/sasrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ def forward(self, sessions: torch.Tensor, timeline_mask: torch.Tensor) -> torch.


class CatFeaturesItemNet(ItemNetBase):
"""Сlass for all category features embeddings. TODO"""
"""
Base class for all category item features embeddings. To use more complicated logic then just id embeddings inherit
from this class and pass your custom ItemNet to your model params.
"""

def __init__(
self,
Expand Down Expand Up @@ -135,7 +138,7 @@ def from_dataset(cls, dataset: Dataset, n_factors: int, dropout_rate: float) ->
class IdEmbeddingsItemNet(ItemNetBase):
"""
Base class for item embeddings. To use more complicated logic then just id embeddings inherit
from this class and pass your custom ItemNet to your model params
from this class and pass your custom ItemNet to your model params.
"""

def __init__(self, n_factors: int, n_items: int, dropout_rate: float):
Expand Down Expand Up @@ -168,7 +171,11 @@ def from_dataset(cls, dataset: Dataset, n_factors: int, dropout_rate: float) ->


class ItemNetConstructor(ItemNetBase):
"""TODO"""
"""
Base class constructor for ItemNet, taking as input a sequence of ItemNetBase nets,
including custom ItemNetBase nets.
Constructs item's embedding based on aggregation of its embeddings from the passed networks.
"""

def __init__(
self,
Expand Down

0 comments on commit 88a1721

Please sign in to comment.