From 7ab4546a313d77fb119878af84bab2f1eb051fc7 Mon Sep 17 00:00:00 2001 From: Daria <93913290+blondered@users.noreply.github.com> Date: Fri, 30 Aug 2024 15:32:58 +0300 Subject: [PATCH] Feature/sasrec (#185) Added SasRecModel --- examples/sasrec_metrics_comp.ipynb | 1175 ++++++++++++++++++++++++++++ rectools/models/sasrec.py | 716 +++++++++++++++++ 2 files changed, 1891 insertions(+) create mode 100644 examples/sasrec_metrics_comp.ipynb create mode 100644 rectools/models/sasrec.py diff --git a/examples/sasrec_metrics_comp.ipynb b/examples/sasrec_metrics_comp.ipynb new file mode 100644 index 00000000..74d4a66f --- /dev/null +++ b/examples/sasrec_metrics_comp.ipynb @@ -0,0 +1,1175 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/data/home/dmtikhono1/git_project/sasrec/RecTools/examples\n" + ] + } + ], + "source": [ + "!pwd" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append(\"/data/home/dmtikhono1/git_project/sasrec/RecTools/\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "execution": { + "iopub.execute_input": "2024-06-27T14:29:22.494979Z", + "iopub.status.busy": "2024-06-27T14:29:22.494423Z", + "iopub.status.idle": "2024-06-27T14:29:22.812073Z", + "shell.execute_reply": "2024-06-27T14:29:22.811404Z", + "shell.execute_reply.started": "2024-06-27T14:29:22.494879Z" + } + }, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "import pandas as pd\n", + "from rectools import Columns\n", + "import numpy as np\n", + "import logging\n", + "import os\n", + "import torch\n", + "from lightning_fabric import seed_everything\n", + "\n", + "from rectools.models import ImplicitALSWrapperModel\n", + "from implicit.als import AlternatingLeastSquares\n", + "from rectools.models.sasrec import SasRecModel\n", + "\n", + "from rectools.metrics import MAP, calc_metrics, MeanInvUserFreq, Serendipity\n", + "from rectools.dataset import Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n", + "os.environ[\"OPENBLAS_NUM_THREADS\"] = \"1\"\n", + "\n", + "logging.basicConfig()\n", + "logging.getLogger().setLevel(logging.INFO)\n", + "\n", + "logger = logging.getLogger()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Data" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# %%time\n", + "# !wget -q https://github.com/irsafilo/KION_DATASET/raw/f69775be31fa5779907cf0a92ddedb70037fb5ae/data_original.zip -O data_original.zip\n", + "# !unzip -o data_original.zip\n", + "# !rm data_original.zip" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "DATA_PATH = Path(\"data_original\")\n", + "\n", + "interactions = (\n", + " pd.read_csv(DATA_PATH / 'interactions.csv', parse_dates=[\"last_watch_dt\"])\n", + " .rename(columns={\"last_watch_dt\": \"datetime\"})\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Split dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "interactions[Columns.Weight] = np.where(interactions['watched_pct'] > 10, 3, 1)\n", + "\n", + "# Split to train / test\n", + "max_date = 
interactions[Columns.Datetime].max()\n", + "train = interactions[interactions[Columns.Datetime] < max_date - pd.Timedelta(days=7)].copy()\n", + "test = interactions[interactions[Columns.Datetime] >= max_date - pd.Timedelta(days=7)].copy()\n", + "train.drop(train.query(\"total_dur < 300\").index, inplace=True)\n", + "\n", + "# drop items with less than 20 interactions in train\n", + "items = train[\"item_id\"].value_counts()\n", + "items = items[items >= 20]\n", + "items = items.index.to_list()\n", + "train = train[train[\"item_id\"].isin(items)]\n", + " \n", + "# drop users with less than 2 interactions in train\n", + "users = train[\"user_id\"].value_counts()\n", + "users = users[users >= 2]\n", + "users = users.index.to_list()\n", + "train = train[(train[\"user_id\"].isin(users))]\n", + "\n", + "# leave item features for items only from train\n", + "# items = train[\"item_id\"].drop_duplicates().to_list()\n", + "users = train[\"user_id\"].drop_duplicates().to_list()\n", + "\n", + "# drop cold users from test\n", + "test_users = test[Columns.User].unique()\n", + "cold_users = set(test[Columns.User]) - set(train[Columns.User])\n", + "test.drop(test[test[Columns.User].isin(cold_users)].index, inplace=True)\n", + "\n", + "catalog=train[Columns.Item].unique()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = Dataset.construct(\n", + " interactions_df=train,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# sasrec" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Seed set to 32\n" + ] + }, + { + "data": { + "text/plain": [ + "32" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "RANDOM_SEED = 32\n", + "torch.use_deterministic_algorithms(True)\n", + "seed_everything(RANDOM_SEED, workers=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "factors=128\n", + "session_maxlen=32\n", + "model = SasRecModel(\n", + " factors=factors, # 50\n", + " n_blocks=2,\n", + " n_heads=1,\n", + " dropout_rate=0.2,\n", + " use_pos_emb=True,\n", + " session_maxlen=session_maxlen,\n", + " lr=1e-3,\n", + " batch_size=128,\n", + " epochs=5,\n", + " device=\"cuda:1\",\n", + " loss=\"softmax\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:rectools.models.sasrec:training epoch 1\n", + "INFO:rectools.models.sasrec:training epoch 2\n", + "INFO:rectools.models.sasrec:training epoch 3\n", + "INFO:rectools.models.sasrec:training epoch 4\n", + "INFO:rectools.models.sasrec:training epoch 5\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 4min 50s, sys: 8.14 s, total: 4min 58s\n", + "Wall time: 4min 53s\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "%%time\n", + "model.fit(dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/data/home/dmtikhono1/git_project/sasrec/RecTools/rectools/models/sasrec.py:522: UserWarning: 91202 target users were considered cold\n", + " because of 
missing known items\n", + " interactions[Columns.User] = dataset.user_id_map.convert_to_external(interactions[Columns.User])\n", + "/data/home/dmtikhono1/git_project/sasrec/RecTools/rectools/models/base.py:403: UserWarning: \n", + " Model `` doesn't support recommendations for cold users,\n", + " but some of given users are cold: they are not in the `dataset.user_id_map`\n", + " \n", + " warnings.warn(explanation)\n", + "100%|██████████| 740/740 [00:02<00:00, 267.59it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 2min 15s, sys: 11min 24s, total: 13min 40s\n", + "Wall time: 22 s\n" + ] + } + ], + "source": [ + "%%time\n", + "recs = model.recommend(\n", + " users = test_users, \n", + " dataset = dataset,\n", + " k = 10,\n", + " filter_viewed = True,\n", + " on_unsupported_targets=\"warn\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "metrics_name = {\n", + " 'MAP': MAP,\n", + " 'MIUF': MeanInvUserFreq,\n", + " 'Serendipity': Serendipity\n", + " \n", + "\n", + "}\n", + "metrics = {}\n", + "for metric_name, metric in metrics_name.items():\n", + " for k in (1, 5, 10):\n", + " metrics[f'{metric_name}@{k}'] = metric(k=k)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "recs[\"item_id\"] = recs[\"item_id\"].apply(str)\n", + "test[\"item_id\"] = test[\"item_id\"].astype(str)\n", + "features_results = []\n", + "metric_values = calc_metrics(metrics, recs[[\"user_id\", \"item_id\", \"rank\"]], test, train, catalog)\n", + "metric_values[\"model\"] = \"sasrec\"\n", + "features_results.append(metric_values)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idscorerank
575550377932.7551871
575551378292.6235832
5755523152972.6182093
575553337842.3957074
5755543148991.9945785
...............
224955109754437342.1089716
2249561097544138652.0898627
2249571097544144312.0583028
224958109754441511.9439509
2249591097544152971.94186410
\n", + "

947050 rows × 4 columns

\n", + "
" + ], + "text/plain": [ + " user_id item_id score rank\n", + "575550 3 7793 2.755187 1\n", + "575551 3 7829 2.623583 2\n", + "575552 3 15297 2.618209 3\n", + "575553 3 3784 2.395707 4\n", + "575554 3 14899 1.994578 5\n", + "... ... ... ... ...\n", + "224955 1097544 3734 2.108971 6\n", + "224956 1097544 13865 2.089862 7\n", + "224957 1097544 14431 2.058302 8\n", + "224958 1097544 4151 1.943950 9\n", + "224959 1097544 15297 1.941864 10\n", + "\n", + "[947050 rows x 4 columns]" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# major recommend\n", + "recs.sort_values([\"user_id\", \"rank\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'MAP@1': 0.04896729054820606,\n", + " 'MAP@5': 0.08284725776567772,\n", + " 'MAP@10': 0.09202214080523476,\n", + " 'MIUF@1': 18.824620072061013,\n", + " 'MIUF@5': 18.824620072061013,\n", + " 'MIUF@10': 18.824620072061013,\n", + " 'Serendipity@1': 0.10074441687344914,\n", + " 'Serendipity@5': 0.06064590171647837,\n", + " 'Serendipity@10': 0.04443191713787037,\n", + " 'model': 'sasrec'}]" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "features_results" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Item to item" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "target_items = [13865, 4457, 15297]" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 1.76 s, sys: 2.4 s, total: 4.16 s\n", + "Wall time: 1.14 s\n" + ] + } + ], + "source": [ + "%%time\n", + "recs = model.recommend_to_items(\n", + " target_items = target_items, \n", + " dataset = dataset,\n", + " k = 10,\n", + " filter_itself = True,\n", + " items_to_recommend=None, #white_list,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
target_item_iditem_idscorerank
01386597280.7533471
11386541510.7402392
21386537340.7162843
31386568090.6731164
4138651420.6504365
51386518440.6465566
61386575710.6458287
713865152970.6247718
81386586360.6231939
913865104400.58220610
10445797280.6961661
11445737340.6718792
1244571420.6664783
13445786360.6639684
14445768090.6427035
15445741510.6301686
16445718440.6252827
17445775710.6186418
18445744360.6098939
19445726570.58072910
201529737340.7100781
211529797280.6907392
2215297104400.6703693
231529768090.6404654
24152971420.6385145
251529726570.6268806
2615297138650.6247717
271529786360.6097698
281529741510.6017069
291529718440.58179910
\n", + "
" + ], + "text/plain": [ + " target_item_id item_id score rank\n", + "0 13865 9728 0.753347 1\n", + "1 13865 4151 0.740239 2\n", + "2 13865 3734 0.716284 3\n", + "3 13865 6809 0.673116 4\n", + "4 13865 142 0.650436 5\n", + "5 13865 1844 0.646556 6\n", + "6 13865 7571 0.645828 7\n", + "7 13865 15297 0.624771 8\n", + "8 13865 8636 0.623193 9\n", + "9 13865 10440 0.582206 10\n", + "10 4457 9728 0.696166 1\n", + "11 4457 3734 0.671879 2\n", + "12 4457 142 0.666478 3\n", + "13 4457 8636 0.663968 4\n", + "14 4457 6809 0.642703 5\n", + "15 4457 4151 0.630168 6\n", + "16 4457 1844 0.625282 7\n", + "17 4457 7571 0.618641 8\n", + "18 4457 4436 0.609893 9\n", + "19 4457 2657 0.580729 10\n", + "20 15297 3734 0.710078 1\n", + "21 15297 9728 0.690739 2\n", + "22 15297 10440 0.670369 3\n", + "23 15297 6809 0.640465 4\n", + "24 15297 142 0.638514 5\n", + "25 15297 2657 0.626880 6\n", + "26 15297 13865 0.624771 7\n", + "27 15297 8636 0.609769 8\n", + "28 15297 4151 0.601706 9\n", + "29 15297 1844 0.581799 10" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "recs" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "ename": "ValueError", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[25], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m() \u001b[38;5;66;03m# skip updating cells below\u001b[39;00m\n", + "\u001b[0;31mValueError\u001b[0m: " + ] + } + ], + "source": [ + "raise ValueError() # skip updating cells below" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ALS" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "users = pd.read_csv(DATA_PATH / 'users.csv')\n", + "items = pd.read_csv(DATA_PATH / 'items.csv')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Process user features to the form of a flatten dataframe\n", + "users.fillna('Unknown', inplace=True)\n", + "users = users.loc[users[Columns.User].isin(train[Columns.User])].copy()\n", + "user_features_frames = []\n", + "for feature in [\"sex\", \"age\", \"income\"]:\n", + " feature_frame = users.reindex(columns=[Columns.User, feature])\n", + " feature_frame.columns = [\"id\", \"value\"]\n", + " feature_frame[\"feature\"] = feature\n", + " user_features_frames.append(feature_frame)\n", + "user_features = pd.concat(user_features_frames)\n", + "\n", + "# Process item features to the form of a flatten dataframe\n", + "items = items.loc[items[Columns.Item].isin(train[Columns.Item])].copy()\n", + "items[\"genre\"] = items[\"genres\"].str.lower().str.replace(\", \", \",\", regex=False).str.split(\",\")\n", + "genre_feature = items[[\"item_id\", \"genre\"]].explode(\"genre\")\n", + "genre_feature.columns = [\"id\", \"value\"]\n", + "genre_feature[\"feature\"] = \"genre\"\n", + "content_feature = items.reindex(columns=[Columns.Item, \"content_type\"])\n", + "content_feature.columns = [\"id\", \"value\"]\n", + "content_feature[\"feature\"] = \"content_type\"\n", + "item_features = pd.concat((genre_feature, content_feature))\n", + "\n", + "candidate_items = interactions['item_id'].drop_duplicates().astype(int)\n", 
+ "test[\"user_id\"] = test[\"user_id\"].astype(int)\n", + "test[\"item_id\"] = test[\"item_id\"].astype(int)\n", + "catalog=train[Columns.Item].unique()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset_no_features = Dataset.construct(\n", + " interactions_df=train,\n", + ")\n", + "\n", + "dataset_full_features = Dataset.construct(\n", + " interactions_df=train,\n", + " user_features_df=user_features,\n", + " cat_user_features=[\"sex\", \"age\", \"income\"],\n", + " item_features_df=item_features,\n", + " cat_item_features=[\"genre\", \"content_type\"],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "K_RECOS = 10\n", + "NUM_THREADS = 32\n", + "RANDOM_STATE = 32\n", + "ITERATIONS = 10\n", + "\n", + "def make_base_model(factors: int, regularization: float, alpha: float, fit_features_together: bool=False):\n", + " return ImplicitALSWrapperModel(\n", + " AlternatingLeastSquares(\n", + " factors=factors,\n", + " regularization=regularization,\n", + " alpha=alpha,\n", + " random_state=RANDOM_STATE,\n", + " use_gpu=False,\n", + " num_threads = NUM_THREADS,\n", + " iterations=ITERATIONS),\n", + " fit_features_together = fit_features_together,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/data/home/maspirina1/Tasks/RecTools/.venv/lib/python3.8/site-packages/implicit/cpu/als.py:95: RuntimeWarning: OpenBLAS is configured to use 64 threads. It is highly recommended to disable its internal threadpool by setting the environment variable 'OPENBLAS_NUM_THREADS=1' or by calling 'threadpoolctl.threadpool_limits(1, \"blas\")'. 
Having OpenBLAS use a threadpool can lead to severe performance issues here.\n", + " check_blas_config()\n" + ] + } + ], + "source": [ + "n_factors = 128\n", + "regularization = 0.5\n", + "alpha = 10\n", + "\n", + "model = make_base_model(factors=n_factors, regularization=regularization, alpha=alpha)\n", + "model.fit(dataset_no_features)\n", + "recos = model.recommend(\n", + " users=test_users.astype(int),\n", + " dataset=dataset_no_features,\n", + " k=K_RECOS,\n", + " filter_viewed=True,\n", + ")\n", + "metric_values = calc_metrics(metrics, recos, test, train, catalog)\n", + "metric_values[\"model\"] = \"no_features_factors_128_alpha_10_reg_0.5\"\n", + "features_results.append(metric_values)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/data/home/maspirina1/Tasks/RecTools/rectools/dataset/features.py:424: UserWarning: Converting sparse features to dense array may cause MemoryError\n", + " warnings.warn(\"Converting sparse features to dense array may cause MemoryError\")\n" + ] + } + ], + "source": [ + "model = make_base_model(factors = n_factors, regularization=regularization, alpha=alpha, fit_features_together=True)\n", + "model.fit(dataset_full_features)\n", + "recos = model.recommend(\n", + " users=test_users.astype(int),\n", + " dataset=dataset_full_features,\n", + " k=K_RECOS,\n", + " filter_viewed=True,\n", + ")\n", + "metric_values = calc_metrics(metrics, recos, test, train, catalog)\n", + "metric_values[\"model\"] = \"full_features_factors_128_fit_together_True\"\n", + "features_results.append(metric_values)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
MAP@1MAP@5MAP@10MIUF@1MIUF@5MIUF@10Serendipity@1Serendipity@5Serendipity@10
model
sasrec0.0475790.0810930.09032218.82462018.82462018.8246200.0981680.0599830.044268
full_features_factors_128_fit_together_True0.0338490.0565330.0624864.3395145.3380826.0441690.0004290.0004600.000459
no_features_factors_128_alpha_10_reg_0.50.0155300.0284660.0328206.6038476.9432177.1465070.0010470.0009040.000815
\n", + "
" + ], + "text/plain": [ + " MAP@1 MAP@5 MAP@10 \\\n", + "model \n", + "sasrec 0.047579 0.081093 0.090322 \n", + "full_features_factors_128_fit_together_True 0.033849 0.056533 0.062486 \n", + "no_features_factors_128_alpha_10_reg_0.5 0.015530 0.028466 0.032820 \n", + "\n", + " MIUF@1 MIUF@5 MIUF@10 \\\n", + "model \n", + "sasrec 18.824620 18.824620 18.824620 \n", + "full_features_factors_128_fit_together_True 4.339514 5.338082 6.044169 \n", + "no_features_factors_128_alpha_10_reg_0.5 6.603847 6.943217 7.146507 \n", + "\n", + " Serendipity@1 Serendipity@5 \\\n", + "model \n", + "sasrec 0.098168 0.059983 \n", + "full_features_factors_128_fit_together_True 0.000429 0.000460 \n", + "no_features_factors_128_alpha_10_reg_0.5 0.001047 0.000904 \n", + "\n", + " Serendipity@10 \n", + "model \n", + "sasrec 0.044268 \n", + "full_features_factors_128_fit_together_True 0.000459 \n", + "no_features_factors_128_alpha_10_reg_0.5 0.000815 " + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "features_df = (\n", + " pd.DataFrame(features_results)\n", + " .set_index(\"model\")\n", + " .sort_values(by=[\"MAP@10\", \"Serendipity@10\"], ascending=False)\n", + ")\n", + "features_df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/rectools/models/sasrec.py b/rectools/models/sasrec.py new file mode 100644 index 00000000..192d1c29 --- /dev/null +++ b/rectools/models/sasrec.py @@ -0,0 +1,716 @@ +import logging +import typing as tp +import warnings +from copy import deepcopy +from typing import List, Tuple + +import numpy as np +import pandas as pd +import torch +import tqdm +import typing_extensions as tpe +from torch import nn +from torch.utils.data import DataLoader +from torch.utils.data import Dataset as TorchDataset + +from rectools import Columns, ExternalIds +from rectools.dataset import Dataset, Interactions +from rectools.dataset.identifiers import IdMap +from rectools.models.base import ErrorBehaviour, InternalRecoTriplet, ModelBase +from rectools.models.rank import Distance, ImplicitRanker +from rectools.types import InternalIdsArray + +PADDING_VALUE = "PAD" + +logger = logging.getLogger(__name__) # TODO: remove + +# #### -------------- Net blocks -------------- #### # + + +class ItemNetBase(nn.Module): + """Base class ItemNet. Used only for type hinting.""" + + def forward(self, items: torch.Tensor) -> torch.Tensor: + """TODO""" + raise NotImplementedError() + + @classmethod + def from_dataset(cls, dataset: Dataset, *args: tp.Any, **kwargs: tp.Any) -> tpe.Self: + """TODO""" + raise NotImplementedError() + + def get_all_embeddings(self) -> torch.Tensor: + """TODO""" + raise NotImplementedError() + + +class TransformerLayersBase(nn.Module): + """Base class for transformer layers. Used only for type hinting.""" + + def forward(self, seqs: torch.Tensor, timeline_mask: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: + """Forward""" + raise NotImplementedError() + + +class IdEmbeddingsItemNet(ItemNetBase): + """ + Base class for item embeddings. 
To use more complicated logic than just id embeddings, inherit
+    from this class and pass your custom ItemNet to your model params.
+    """
+
+    def __init__(self, n_factors: int, n_items: int, dropout_rate: float):
+        super().__init__()
+
+        self.n_items = n_items
+        self.item_emb = nn.Embedding(
+            num_embeddings=n_items,
+            embedding_dim=n_factors,
+            padding_idx=0,
+        )
+        self.drop_layer = nn.Dropout(dropout_rate)
+
+    def forward(self, items: torch.Tensor) -> torch.Tensor:
+        """TODO"""
+        item_embs = self.item_emb(items)
+        item_embs = self.drop_layer(item_embs)
+        return item_embs
+
+    @property
+    def catalogue(self) -> torch.Tensor:
+        """TODO"""
+        return torch.arange(0, self.n_items, device=self.item_emb.weight.device)
+
+    def get_all_embeddings(self) -> torch.Tensor:
+        """TODO"""
+        return self.forward(self.catalogue)
+
+    @classmethod
+    def from_dataset(cls, dataset: Dataset, n_factors: int, dropout_rate: float) -> tpe.Self:
+        """TODO"""
+        n_items = dataset.item_id_map.size
+        return cls(n_factors, n_items, dropout_rate)
+
+
+class PointWiseFeedForward(nn.Module):
+    """TODO"""
+
+    def __init__(self, n_factors: int, n_factors_ff: int, dropout_rate: float) -> None:
+        """TODO"""
+        super().__init__()
+        self.ff_linear1 = nn.Linear(n_factors, n_factors_ff)
+        self.ff_dropout1 = torch.nn.Dropout(dropout_rate)
+        self.ff_relu = torch.nn.ReLU()
+        self.ff_linear2 = nn.Linear(n_factors_ff, n_factors)
+        self.ff_dropout2 = torch.nn.Dropout(dropout_rate)
+
+    def forward(self, seqs: torch.Tensor) -> torch.Tensor:
+        """TODO"""
+        output = self.ff_relu(self.ff_dropout1(self.ff_linear1(seqs)))
+        fin = self.ff_dropout2(self.ff_linear2(output))
+        return fin
+
+
+class SasRecTransformerLayers(TransformerLayersBase):
+    """Exactly the SASRec authors' architecture, but built on the torch MultiheadAttention implementation."""
+
+    def __init__(
+        self,
+        n_blocks: int,
+        n_factors: int,
+        n_heads: int,
+        dropout_rate: float,
+    ):
+        super().__init__()
+        self.n_blocks = n_blocks
+        self.multi_head_attn = nn.ModuleList(
+            [torch.nn.MultiheadAttention(n_factors, n_heads, dropout_rate, batch_first=True) for _ in range(n_blocks)]
+        )  # TODO: the original architecture used another version of MHA
+        self.q_layer_norm = nn.ModuleList([nn.LayerNorm(n_factors) for _ in range(n_blocks)])
+        self.ff_layer_norm = nn.ModuleList([nn.LayerNorm(n_factors) for _ in range(n_blocks)])
+        self.feed_forward = nn.ModuleList(
+            [PointWiseFeedForward(n_factors, n_factors, dropout_rate) for _ in range(n_blocks)]
+        )
+        self.last_layernorm = torch.nn.LayerNorm(n_factors, eps=1e-8)
+
+    def forward(self, seqs: torch.Tensor, timeline_mask: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
+        """TODO"""
+        for i in range(self.n_blocks):
+            q = self.q_layer_norm[i](seqs)
+            mha_output, _ = self.multi_head_attn[i](q, seqs, seqs, attn_mask=attn_mask, need_weights=False)
+            seqs = q + mha_output
+            ff_input = self.ff_layer_norm[i](seqs)
+            seqs = self.feed_forward[i](ff_input)
+            seqs += ff_input
+            seqs *= timeline_mask
+
+        seqs = self.last_layernorm(seqs)
+
+        return seqs
+
+
+class PreLNTransformerLayers(TransformerLayersBase):
+    """
+    Based on https://arxiv.org/pdf/2002.04745.
+    On the Kion open dataset it didn't change metrics (they even got a bit worse),
+    but let's keep it for now.
+    """
+
+    def __init__(
+        self,
+        n_blocks: int,
+        n_factors: int,
+        n_heads: int,
+        dropout_rate: float,
+    ):
+        super().__init__()
+        self.n_blocks = n_blocks
+        self.multi_head_attn = nn.ModuleList(
+            [torch.nn.MultiheadAttention(n_factors, n_heads, dropout_rate, batch_first=True) for _ in range(n_blocks)]
+        )
+        self.mha_layer_norm = 
nn.ModuleList([nn.LayerNorm(n_factors) for _ in range(n_blocks)]) + self.mha_dropout = nn.Dropout(dropout_rate) + self.ff_layer_norm = nn.ModuleList([nn.LayerNorm(n_factors) for _ in range(n_blocks)]) + self.feed_forward = nn.ModuleList( + [PointWiseFeedForward(n_factors, n_factors, dropout_rate) for _ in range(n_blocks)] + ) + + def forward(self, seqs: torch.Tensor, timeline_mask: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: + """TODO""" + for i in range(self.n_blocks): + mha_input = self.mha_layer_norm[i](seqs) + mha_output, _ = self.multi_head_attn[i]( + mha_input, mha_input, mha_input, attn_mask=attn_mask, need_weights=False + ) + mha_output = self.mha_dropout(mha_output) + seqs = seqs + mha_output + ff_input = self.ff_layer_norm[i](seqs) + ff_output = self.feed_forward[i](ff_input) + seqs = seqs + ff_output + seqs *= timeline_mask + + return seqs + + +class LearnableInversePositionalEncoding(torch.nn.Module): + """TODO""" + + def __init__(self, use_pos_emb: bool, session_maxlen: int, n_factors: int): + super().__init__() + self.pos_emb = torch.nn.Embedding(session_maxlen, n_factors) if use_pos_emb else None + + def forward(self, sessions: torch.Tensor, timeline_mask: torch.Tensor) -> torch.Tensor: + """TODO""" + batch_size, session_maxlen, _ = sessions.shape + + if self.pos_emb is not None: + # Inverse positions are appropriate for variable length sequences across different batches + # They are equal to absolute positions for fixed sequence length across different batches + positions = torch.tile( + torch.arange(session_maxlen - 1, -1, -1), (batch_size, 1) + ) # [batch_size, session_maxlen] + sessions += self.pos_emb(positions.to(sessions.device)) + + # TODO: do we need to fill padding embeds in sessions to all zeros + # or should we use the learnt padding embedding? Should we make it an option for user to decide? + sessions *= timeline_mask # [batch_size, session_maxlen, n_factors] + + return sessions + + +# #### -------------- Session Encoder -------------- #### # + + +class TransformerBasedSessionEncoder(torch.nn.Module): + """TODO""" + + def __init__( + self, + n_blocks: int, + n_factors: int, + n_heads: int, + session_maxlen: int, + dropout_rate: float, + use_pos_emb: bool = True, # TODO: add pos_encoding_type option for user to pass + use_causal_attn: bool = True, + transformer_layers_type: tp.Type[TransformerLayersBase] = SasRecTransformerLayers, + item_net_type: tp.Type[ItemNetBase] = IdEmbeddingsItemNet, + ) -> None: + super().__init__() + + self.item_model: ItemNetBase + self.pos_encoding = LearnableInversePositionalEncoding(use_pos_emb, session_maxlen, n_factors) + self.emb_dropout = torch.nn.Dropout(dropout_rate) + self.transformer_layers = transformer_layers_type( + n_blocks=n_blocks, + n_factors=n_factors, + n_heads=n_heads, + dropout_rate=dropout_rate, + ) + self.use_causal_attn = use_causal_attn + self.item_net_type = item_net_type + self.n_factors = n_factors + self.dropout_rate = dropout_rate + + def costruct_item_net(self, dataset: Dataset) -> None: + """TODO""" + self.item_model = self.item_net_type.from_dataset(dataset, self.n_factors, self.dropout_rate) + + def encode_sessions(self, sessions: torch.Tensor, item_embs: torch.Tensor) -> torch.Tensor: + """ + Pass user history through item embeddings and transformer blocks. + + Returns + ------- + torch.Tensor. 
[batch_size, session_maxlen, factors] + + """ + session_maxlen = sessions.shape[1] + attn_mask = None + if self.use_causal_attn: + attn_mask = ~torch.tril( + torch.ones((session_maxlen, session_maxlen), dtype=torch.bool, device=sessions.device) + ) + timeline_mask = (sessions != 0).unsqueeze(-1) # [batch_size, session_maxlen, 1] + seqs = item_embs[sessions] # [batch_size, session_maxlen, n_factors] + seqs = self.pos_encoding(seqs, timeline_mask) + seqs = self.emb_dropout(seqs) + seqs = self.transformer_layers(seqs, timeline_mask, attn_mask) + return seqs + + def forward( + self, + sessions: torch.Tensor, # [batch_size, session_maxlen] + ) -> torch.Tensor: + """TODO""" + item_embs = self.item_model.get_all_embeddings() # [n_items + 1, n_factors] + session_embs = self.encode_sessions(sessions, item_embs) # [batch_size, session_maxlen, n_factors] + logits = session_embs @ item_embs.T # [batch_size, session_maxlen, n_items + 1] + return logits + + +# #### -------------- Trainer -------------- #### # + + +class Trainer: + """TODO""" + + def __init__( + self, + lr: float, + epochs: int, + device: torch.device, + loss: str = "softmax", + ): + """TODO""" + self.model: TransformerBasedSessionEncoder + self.optimizer: torch.optim.Adam + self.lr = lr + self.epochs = epochs + self.device = device + self.loss_func = self._init_loss_func(loss) # TODO: move loss func to `SasRec` class + + def fit( + self, + model: TransformerBasedSessionEncoder, + fit_dataloader: DataLoader, + ) -> None: + """TODO""" + self.model = model + self.optimizer = self._init_optimizers() + self.model.to(self.device) + + self.xavier_normal_init(self.model) + self.model.train() # enable model training + + # self.model.item_model.to_device(self.device) + + epoch_start_idx = 1 + + # ce_criterion = torch.nn.CrossEntropyLoss() + # https://github.com/NVIDIA/pix2pixHD/issues/9 how could an old bug appear again... 
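+        # Each batch from the dataloader is a triplet of [batch_size, session_maxlen]
+        # tensors: x holds input item ids, y holds target ids shifted one position
+        # ahead, and w holds per-position interaction weights (zeros at padding).
+        # See SasRecDataPreparator._collate_fn_train for how they are built.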
+ + for epoch in range(epoch_start_idx, self.epochs + 1): + logger.info("training epoch %s", epoch) + for x, y, w in fit_dataloader: + x = x.to(self.device) # [batch_size, session_maxlen] + y = y.to(self.device) # [batch_size, session_maxlen] + w = w.to(self.device) # [batch_size, session_maxlen] + + self.train_step(x, y, w) + + def train_step(self, x: torch.Tensor, y: torch.Tensor, w: torch.Tensor) -> None: + """TODO""" + self.optimizer.zero_grad() + logits = self.model(x) # [batch_size, session_maxlen, n_items + 1] + # We are using CrossEntropyLoss with a multi-dimensional case + + # Logits must be passed in form of [batch_size, n_items + 1, session_maxlen], + # where n_items + 1 is number of classes + + # Target label indexes must be passed in a form of [batch_size, session_maxlen] + # (`0` index for "PAD" ix excluded from loss) + + # Loss output will have a shape of [batch_size, session_maxlen] + # and will have zeros for every `0` target label + loss = self.loss_func(logits.transpose(1, 2), y) # [batch_size, session_maxlen] + loss = loss * w + n = (loss > 0).to(loss.dtype) + loss = torch.sum(loss) / torch.sum(n) + loss.backward() + self.optimizer.step() + + def _init_optimizers(self) -> torch.optim.Adam: + optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(0.9, 0.98)) + return optimizer + + def _init_loss_func(self, loss: str) -> nn.CrossEntropyLoss: + + if loss == "softmax": + return nn.CrossEntropyLoss(ignore_index=0, reduction="none") + raise ValueError(f"loss {loss} is not supported") + + def xavier_normal_init(self, model: nn.Module) -> None: + """TODO""" + for _, param in model.named_parameters(): + try: + torch.nn.init.xavier_normal_(param.data) + except ValueError: + pass + + +# #### -------------- Data Processor -------------- #### # + + +class SequenceDataset(TorchDataset): + """TODO""" + + def __init__(self, sessions: List[List[int]], weights: List[List[float]]): + self.sessions = sessions + self.weights = weights + + def __len__(self) -> int: + return len(self.sessions) + + def __getitem__(self, index: int) -> Tuple[List[int], List[float]]: + session = self.sessions[index] # [session_len] + weights = self.weights[index] # [session_len] + return session, weights + + @classmethod + def from_interactions( + cls, + interactions: pd.DataFrame, + ) -> "SequenceDataset": + """TODO""" + sessions = ( + interactions.sort_values(Columns.Datetime) + .groupby(Columns.User, sort=True)[[Columns.Item, Columns.Weight]] + .agg(list) + ) + sessions, weights = ( + sessions[Columns.Item].to_list(), + sessions[Columns.Weight].to_list(), + ) + + return cls(sessions=sessions, weights=weights) + + +class SasRecDataPreparator: + """TODO""" + + def __init__( + self, + session_maxlen: int, + batch_size: int, + item_extra_tokens: tp.Sequence[tp.Hashable] = (PADDING_VALUE,), + shuffle_train: bool = True, # not shuffling train dataloader hurts performance + train_min_user_interactions: int = 2, + ) -> None: + self.session_maxlen = session_maxlen + self.batch_size = batch_size + self.item_extra_tokens = item_extra_tokens + self.shuffle_train = shuffle_train + self.train_min_user_interactions = train_min_user_interactions + self.item_id_map: IdMap + # TODO: add SequenceDatasetType for fit and recommend + + @property + def n_item_extra_tokens(self) -> int: + """TODO""" + return len(self.item_extra_tokens) + + def get_known_item_ids(self) -> np.ndarray: + """TODO""" + return self.item_id_map.get_external_sorted_by_internal()[self.n_item_extra_tokens :] + + def 
get_known_items_sorted_internal_ids(self) -> np.ndarray: + """TODO""" + return self.item_id_map.get_sorted_internal()[self.n_item_extra_tokens :] + + def process_dataset_train(self, dataset: Dataset) -> Dataset: + """TODO""" + interactions = dataset.get_raw_interactions() + + # Filter interactions + user_stats = interactions[Columns.User].value_counts() + users = user_stats[user_stats >= self.train_min_user_interactions].index + interactions = interactions[(interactions[Columns.User].isin(users))] + interactions = interactions.sort_values(Columns.Datetime).groupby(Columns.User).tail(self.session_maxlen + 1) + + # Construct dataset + # TODO: user features and item features are dropped for now + user_id_map = IdMap.from_values(interactions[Columns.User].values) + item_id_map = IdMap.from_values(self.item_extra_tokens) + item_id_map = item_id_map.add_ids(interactions[Columns.Item]) + interactions = Interactions.from_raw(interactions, user_id_map, item_id_map) + dataset = Dataset(user_id_map, item_id_map, interactions) + + self.item_id_map = dataset.item_id_map + return dataset + + def _collate_fn_train( + self, + batch: List[Tuple[List[int], List[float]]], + ) -> Tuple[torch.LongTensor, torch.LongTensor, torch.FloatTensor]: + """ + Truncate each session from right to keep (session_maxlen+1) last items. + Do left padding until (session_maxlen+1) is reached. + Split to `x`, `y`, and `yw`. + """ + batch_size = len(batch) + x = np.zeros((batch_size, self.session_maxlen)) + y = np.zeros((batch_size, self.session_maxlen)) + yw = np.zeros((batch_size, self.session_maxlen)) + for i, (ses, ses_weights) in enumerate(batch): + x[i, -len(ses) + 1 :] = ses[:-1] # ses: [session_len] -> x[i]: [session_maxlen] + y[i, -len(ses) + 1 :] = ses[1:] # ses: [session_len] -> y[i]: [session_maxlen] + yw[i, -len(ses) + 1 :] = ses_weights[1:] # ses_weights: [session_len] -> yw[i]: [session_maxlen] + return torch.LongTensor(x), torch.LongTensor(y), torch.FloatTensor(yw) + + def get_dataloader_train(self, processed_dataset: Dataset) -> DataLoader: + """TODO""" + sequence_dataset = SequenceDataset.from_interactions(processed_dataset.interactions.df) + train_dataloader = DataLoader( + sequence_dataset, collate_fn=self._collate_fn_train, batch_size=self.batch_size, shuffle=self.shuffle_train + ) + return train_dataloader + + def transform_dataset_u2i(self, dataset: Dataset, users: ExternalIds) -> Dataset: + """ + Filter out interactions and adapt id maps. + Final dataset will consist only of model known items during fit and only of required + (and supported) target users for recommendations. + All users beyond target users for recommendations are dropped. + All target users that do not have at least one known item in interactions are dropped. 
+ Final user_id_map is an enumerated list of supported (filtered) target users + Final item_id_map is model item_id_map constructed during training + """ + # Filter interactions in dataset internal ids + interactions = dataset.interactions.df + users_internal = dataset.user_id_map.convert_to_internal(users, strict=False) + items_internal = dataset.item_id_map.convert_to_internal(self.get_known_item_ids(), strict=False) + interactions = interactions[interactions[Columns.User].isin(users_internal)] # todo: fast_isin + interactions = interactions[interactions[Columns.Item].isin(items_internal)] + + # Convert to external ids + interactions[Columns.Item] = dataset.item_id_map.convert_to_external(interactions[Columns.Item]) + interactions[Columns.User] = dataset.user_id_map.convert_to_external(interactions[Columns.User]) + + # Prepare new user id mapping + rec_user_id_map = IdMap.from_values(interactions[Columns.User]) + + # Construct dataset + # TODO: For now features are dropped because model doesn't support them + n_filtered = len(users) - rec_user_id_map.size + if n_filtered > 0: + explanation = f"""{n_filtered} target users were considered cold + because of missing known items""" + warnings.warn(explanation) + filtered_interactions = Interactions.from_raw(interactions, rec_user_id_map, self.item_id_map) + filtered_dataset = Dataset(rec_user_id_map, self.item_id_map, filtered_interactions) + return filtered_dataset + + def transform_dataset_i2i(self, dataset: Dataset) -> Dataset: + """ + Filter out interactions and adapt id maps. + Final dataset will consist only of model known items during fit. + Final user_id_map is the same as dataset original + Final item_id_map is model item_id_map constructed during training + """ + # TODO: optimize by filtering in internal ids + # TODO: For now features are dropped because model doesn't support them + interactions = dataset.get_raw_interactions() + interactions = interactions[interactions[Columns.Item].isin(self.get_known_item_ids())] + filtered_interactions = Interactions.from_raw(interactions, dataset.user_id_map, self.item_id_map) + filtered_dataset = Dataset(dataset.user_id_map, self.item_id_map, filtered_interactions) + return filtered_dataset + + def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> torch.LongTensor: + """Right truncation, left padding to session_maxlen""" + x = np.zeros((len(batch), self.session_maxlen)) + for i, (ses, _) in enumerate(batch): + x[i, -len(ses) :] = ses[-self.session_maxlen :] + return torch.LongTensor(x) + + def get_dataloader_recommend(self, dataset: Dataset) -> DataLoader: + """TODO""" + sequence_dataset = SequenceDataset.from_interactions(dataset.interactions.df) + recommend_dataloader = DataLoader( + sequence_dataset, batch_size=self.batch_size, collate_fn=self._collate_fn_recommend, shuffle=False + ) + return recommend_dataloader + + +# #### -------------- SASRec Model -------------- #### # + + +class SasRecModel(ModelBase): # pylint: disable=too-many-instance-attributes + """TODO""" + + def __init__( + self, + session_maxlen: int, + lr: float, + batch_size: int, + epochs: int, + device: str, + n_blocks: int, + n_factors: int, + n_heads: int, + dropout_rate: float, + use_pos_emb: bool = True, + loss: str = "softmax", + verbose: int = 0, + cpu_n_threads: int = 0, + transformer_layers_type: tp.Type[TransformerLayersBase] = SasRecTransformerLayers, # SASRec authors net + item_net_type: tp.Type[ItemNetBase] = IdEmbeddingsItemNet, # item embeddings on ids + ): + 
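+        # Three cooperating parts: TransformerBasedSessionEncoder (item embeddings,
+        # positional encoding, causal transformer blocks), Trainer (the optimization
+        # loop) and SasRecDataPreparator (left-padded sessions of `session_maxlen`
+        # items plus the train/recommend dataloaders).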
super().__init__(verbose=verbose) + self.device = torch.device(device) + self.n_threads = cpu_n_threads + self.model: TransformerBasedSessionEncoder + self._model = TransformerBasedSessionEncoder( + n_blocks=n_blocks, + n_factors=n_factors, + n_heads=n_heads, + session_maxlen=session_maxlen, + dropout_rate=dropout_rate, + use_pos_emb=use_pos_emb, + use_causal_attn=True, + transformer_layers_type=transformer_layers_type, + item_net_type=item_net_type, + ) + self.trainer = Trainer( # TODO: move to lightning trainer and add option to pass initialized trainer + lr=lr, + epochs=epochs, + device=self.device, + loss=loss, + ) + self.data_preparator = SasRecDataPreparator(session_maxlen, batch_size) # TODO: add data_preparator_type + self.u2i_dist = Distance.DOT + self.i2i_dist = Distance.COSINE + + def _fit( + self, + dataset: Dataset, + ) -> None: + processed_dataset = self.data_preparator.process_dataset_train(dataset) + train_dataloader = self.data_preparator.get_dataloader_train(processed_dataset) + + self.model = deepcopy(self._model) # TODO: check that it works + self.model.costruct_item_net(processed_dataset) + + self.trainer.fit(self.model, train_dataloader) + self.model = self.trainer.model + + def _custom_transform_dataset_u2i( + self, dataset: Dataset, users: ExternalIds, on_unsupported_targets: ErrorBehaviour + ) -> Dataset: + return self.data_preparator.transform_dataset_u2i(dataset, users) + + def _custom_transform_dataset_i2i( + self, dataset: Dataset, target_items: ExternalIds, on_unsupported_targets: ErrorBehaviour + ) -> Dataset: + return self.data_preparator.transform_dataset_i2i(dataset) + + def _recommend_u2i( + self, + user_ids: InternalIdsArray, + dataset: Dataset, # [n_rec_users x n_items + 1] + k: int, + filter_viewed: bool, + sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray], # model_internal + ) -> InternalRecoTriplet: + + if sorted_item_ids_to_recommend is None: # TODO: move to _get_sorted_item_ids_to_recommend + sorted_item_ids_to_recommend = self.data_preparator.get_known_items_sorted_internal_ids() # model internal + + self.model = self.model.eval() + self.model.to(self.device) + + # Dataset has already been filtered and adapted to known item_id_map + recommend_dataloader = self.data_preparator.get_dataloader_recommend(dataset) + + session_embs = [] + item_embs = self.model.item_model.get_all_embeddings() # [n_items + 1, n_factors] + with torch.no_grad(): + for x_batch in tqdm.tqdm(recommend_dataloader): # TODO: from tqdm.auto import tqdm. 
Also check `verbose`
+                x_batch = x_batch.to(self.device)  # [batch_size, session_maxlen]
+                encoded = self.model.encode_sessions(x_batch, item_embs)[:, -1, :]  # [batch_size, n_factors]
+                encoded = encoded.detach().cpu().numpy()
+                session_embs.append(encoded)
+
+        user_embs = np.concatenate(session_embs, axis=0)
+        user_embs = user_embs[user_ids]
+        item_embs_np = item_embs.detach().cpu().numpy()
+
+        ranker = ImplicitRanker(
+            self.u2i_dist,
+            user_embs,  # [n_rec_users, n_factors]
+            item_embs_np,  # [n_items + 1, n_factors]
+        )
+        if filter_viewed:
+            user_items = dataset.get_user_item_matrix(include_weights=False)
+            ui_csr_for_filter = user_items[user_ids]
+        else:
+            ui_csr_for_filter = None
+
+        # TODO: when filter_viewed is not needed and the user has a GPU, torch dot and topk should be faster
+
+        user_ids_indices, all_reco_ids, all_scores = ranker.rank(
+            subject_ids=np.arange(user_embs.shape[0]),  # n_rec_users
+            k=k,
+            filter_pairs_csr=ui_csr_for_filter,  # [n_rec_users x n_items + 1]
+            sorted_object_whitelist=sorted_item_ids_to_recommend,  # model_internal
+            num_threads=self.n_threads,
+        )
+        all_target_ids = user_ids[user_ids_indices]
+
+        return all_target_ids, all_reco_ids, all_scores  # n_rec_users, model_internal, scores
+
+    def _recommend_i2i(
+        self,
+        target_ids: InternalIdsArray,  # model internal
+        dataset: Dataset,
+        k: int,
+        sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray],
+    ) -> InternalRecoTriplet:
+        if sorted_item_ids_to_recommend is None:
+            sorted_item_ids_to_recommend = self.data_preparator.get_known_items_sorted_internal_ids()
+        item_embs = self.model.item_model.get_all_embeddings().detach().cpu().numpy()  # [n_items + 1, n_factors]
+
+        # TODO: i2i recos do not need viewed-items filtering, and the user usually has a GPU.
+        # Should we use torch dot and topk? It should be faster
+
+        ranker = ImplicitRanker(
+            self.i2i_dist,
+            item_embs,  # [n_items + 1, n_factors]
+            item_embs,  # [n_items + 1, n_factors]
+        )
+        return ranker.rank(
+            subject_ids=target_ids,  # model internal
+            k=k,
+            filter_pairs_csr=None,
+            sorted_object_whitelist=sorted_item_ids_to_recommend,  # model internal
+            num_threads=0,
+        )
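
A minimal usage sketch of the new model, mirroring the notebook above. It assumes
an interactions frame that already carries the rectools Columns.User, Columns.Item,
Columns.Datetime and Columns.Weight columns; the file name, device string and
hyperparameter values are illustrative, not tuned.

    import pandas as pd

    from rectools import Columns
    from rectools.dataset import Dataset
    from rectools.models.sasrec import SasRecModel

    # Implicit-feedback log with user_id, item_id, datetime, weight columns
    interactions = pd.read_csv("interactions.csv", parse_dates=[Columns.Datetime])
    dataset = Dataset.construct(interactions_df=interactions)

    model = SasRecModel(
        session_maxlen=32,  # keep only the last 32 items of each user session
        lr=1e-3,
        batch_size=128,
        epochs=5,
        device="cuda:0",  # or "cpu"
        n_blocks=2,
        n_factors=128,
        n_heads=1,
        dropout_rate=0.2,
    )
    model.fit(dataset)

    # u2i: cold users are not supported; "warn" reports and skips them
    recos = model.recommend(
        users=interactions[Columns.User].unique(),
        dataset=dataset,
        k=10,
        filter_viewed=True,
        on_unsupported_targets="warn",
    )

    # i2i: nearest items by cosine distance between learned item embeddings
    similar = model.recommend_to_items(
        target_items=interactions[Columns.Item].unique()[:3],
        dataset=dataset,
        k=10,
        filter_itself=True,
    )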