Skip to content

Commit

Permalink
lightning module (#187)
Browse files Browse the repository at this point in the history
Added lightning module
  • Loading branch information
spirinamayya authored and Семенов Андрей Максимович committed Sep 30, 2024
1 parent 1ab86a9 commit 245aada
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions rectools/models/sasrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,6 @@ def get_all_embeddings(self) -> torch.Tensor:
"""TODO"""
raise NotImplementedError()

@classmethod
def construct_nets_from_dataset(cls, dataset: Dataset, *args: tp.Any, **kwargs: tp.Any) -> tpe.Self:
"""TODO"""
raise NotImplementedError()

@property
def device(self) -> torch.device:
"""TODO"""
Expand Down Expand Up @@ -216,7 +211,7 @@ def get_all_embeddings(self) -> torch.Tensor:
return self.forward(self.catalogue)

@classmethod
def construct_nets_from_dataset(
def from_dataset(
cls,
dataset: Dataset,
n_factors: int,
Expand Down Expand Up @@ -398,7 +393,7 @@ def __init__(

def construct_item_net(self, dataset: Dataset) -> None:
"""TODO"""
self.item_model = ItemNetConstructor.construct_nets_from_dataset(
self.item_model = ItemNetConstructor.from_dataset(
dataset, self.n_factors, self.dropout_rate, self.item_net_block_types
)

Expand Down Expand Up @@ -550,6 +545,7 @@ def process_dataset_train(self, dataset: Dataset) -> Dataset:
item_features = None
if dataset.item_features is not None:
item_features = dataset.item_features
# TODO: remove assumption on SparseFeatures and add Dense Features support
if not isinstance(item_features, SparseFeatures):
raise ValueError("`item_features` in `dataset` must be `SparseFeatures` instance.")

Expand All @@ -560,13 +556,13 @@ def process_dataset_train(self, dataset: Dataset) -> Dataset:

dtype = sorted_item_features.values.dtype
n_features = sorted_item_features.values.shape[1]
pad_item_features = sparse.csr_matrix((self.n_item_extra_tokens, n_features), dtype=dtype)
extra_token_feature_values = sparse.csr_matrix((self.n_item_extra_tokens, n_features), dtype=dtype)

extra_tokens_feature_values: sparse.scr_matrix = sparse.vstack(
[pad_item_features.toarray(), sorted_item_features.values], format="csr"
full_feature_values: sparse.scr_matrix = sparse.vstack(
[extra_token_feature_values, sorted_item_features.values], format="csr"
)

item_features = SparseFeatures.from_iterables(values=extra_tokens_feature_values, names=item_features.names)
item_features = SparseFeatures.from_iterables(values=full_feature_values, names=item_features.names)

interactions = Interactions.from_raw(interactions, user_id_map, item_id_map)

Expand Down

0 comments on commit 245aada

Please sign in to comment.