Skip to content

Commit

Permalink
Added comments TODO.
Browse files Browse the repository at this point in the history
  • Loading branch information
Семенов Андрей Максимович committed Sep 27, 2024
1 parent 88a1721 commit 1ab86a9
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions rectools/models/sasrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def feature_catalogue(self) -> torch.Tensor:

def get_dense_item_features(self, items: torch.Tensor) -> torch.Tensor:
"""TODO"""
# TODO: optomize save in memory and to gpu for inference.
# TODO: Add the whole `feature_dense` to the right gpu device at once?
feature_dense = self.item_features.take(items.detach().cpu().numpy()).get_dense()
return torch.from_numpy(feature_dense).to(self.device)

Expand All @@ -125,7 +125,6 @@ def from_dataset(cls, dataset: Dataset, n_factors: int, dropout_rate: float) ->

if item_features is None:
explanation = """When `use_cat_features_embs` is True, the dataset must have item features."""
# warnings.warn(explanation) TODO
raise ValueError(explanation)

if not isinstance(item_features, SparseFeatures):
Expand Down Expand Up @@ -186,7 +185,7 @@ def __init__(
super().__init__()

if len(item_net_blocks) == 0:
raise ValueError("At least one type of net for processing items should be provided.")
raise ValueError("At least one type of net to calculate item embeddings should be provided.")

self.n_items = n_items
self.n_item_blocks = len(item_net_blocks)
Expand All @@ -195,7 +194,7 @@ def __init__(
def forward(self, items: torch.Tensor) -> torch.Tensor:
"""TODO"""
item_embs = []
# TODO: parallel
# TODO: Add functionality for parallel computing.
for idx_block in range(self.n_item_blocks):
item_emb = self.item_net_blocks[idx_block](items)
item_embs.append(item_emb)
Expand Down Expand Up @@ -229,7 +228,8 @@ def construct_nets_from_dataset(

item_net_blocks = []
for item_net in item_net_block_types:
item_net_blocks.append(item_net.from_dataset(dataset, n_factors, dropout_rate))
item_net_block = item_net.from_dataset(dataset, n_factors, dropout_rate)
item_net_blocks.append(item_net_block)

return cls(n_items, item_net_blocks)

Expand Down

0 comments on commit 1ab86a9

Please sign in to comment.