Skip to content

Commit

Permalink
Add DLRM Model
Browse files Browse the repository at this point in the history
  • Loading branch information
edknv committed Jul 2, 2023
1 parent 86d0a34 commit 28040dd
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 15 deletions.
2 changes: 2 additions & 0 deletions merlin/models/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from merlin.models.torch.inputs.select import SelectFeatures, SelectKeys
from merlin.models.torch.inputs.tabular import TabularInputBlock
from merlin.models.torch.models.base import Model
from merlin.models.torch.models.ranking import DLRMModel
from merlin.models.torch.outputs.base import ModelOutput
from merlin.models.torch.outputs.classification import BinaryOutput
from merlin.models.torch.outputs.regression import RegressionOutput
Expand Down Expand Up @@ -53,4 +54,5 @@
"Stack",
"schema",
"DLRMBlock",
"DLRMModel",
]
5 changes: 5 additions & 0 deletions merlin/models/torch/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,11 @@ def compute_loss(
else:
raise ValueError(f"Unknown 'predictions' type: {type(predictions)}")

if _targets.size() != _predictions.size():
_targets = _targets.view(_predictions.size())
if _targets.type() != _predictions.type():
_targets = _targets.type_as(_predictions)

results["loss"] = results["loss"] + model_out.loss(_predictions, _targets) / len(
model_outputs
)
Expand Down
75 changes: 75 additions & 0 deletions merlin/models/torch/models/ranking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from typing import Optional

from torch import nn

from merlin.models.torch.block import Block
from merlin.models.torch.blocks.dlrm import DLRMBlock
from merlin.models.torch.models.base import Model
from merlin.models.torch.outputs.tabular import TabularOutputBlock
from merlin.schema import Schema


def DLRMModel(
schema: Schema,
dim: int,
bottom_block: Block,
top_block: Optional[Block] = None,
interaction: Optional[nn.Module] = None,
output_block: Optional[Block] = None,
) -> Model:
"""
The Deep Learning Recommendation Model (DLRM) as proposed in Naumov, et al. [1]
Parameters
----------
schema : Schema
The schema to use for selection.
dim : int
The dimensionality of the output vectors.
bottom_block : Block
Block to pass the continuous features to.
Note that, the output dimensionality of this block must be equal to ``dim``.
top_block : Block, optional
An optional upper-level block of the model.
interaction : nn.Module, optional
Interaction module for DLRM.
If not provided, DLRMInteraction will be used by default.
output_block : Block, optional
The output block of the model, by default None.
If None, a TabularOutputBlock with schema and default initializations is used.
Returns
-------
Model
An instance of Model class representing the fully formed DLRM.
Example usage
-------------
>>> model = mm.DLRMModel(
... schema,
... dim=64,
... bottom_block=mm.MLPBlock([256, 64]),
... output_block=BinaryOutput(ColumnSchema("target")))
>>> trainer = pl.Trainer()
>>> model.initialize(dataloader)
>>> trainer.fit(model, dataloader)
References
----------
[1] Naumov, Maxim, et al. "Deep learning recommendation model for
personalization and recommendation systems." arXiv preprint arXiv:1906.00091 (2019).
"""
if output_block is None:
output_block = TabularOutputBlock(schema, init="defaults")

dlrm_body = DLRMBlock(
schema,
dim,
bottom_block,
top_block=top_block,
interaction=interaction,
)

model = Model(dlrm_body, output_block)

return model
2 changes: 1 addition & 1 deletion merlin/models/torch/utils/module_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def initialize(module, data: Union[Dataset, Loader, Batch], dtype=torch.float32)
if hasattr(module, "model_outputs"):
for model_out in module.model_outputs():
for metric in model_out.metrics:
metric.to(batch.device())
metric.to(device=batch.device())

from merlin.models.torch import schema

Expand Down
36 changes: 22 additions & 14 deletions tests/unit/torch/models/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
#
import pandas as pd
import pytest
import pytorch_lightning as pl
import torch
from torch import nn
from torchmetrics import AUROC, Accuracy, Precision, Recall

import merlin.models.torch as mm
from merlin.dataloader.torch import Loader
from merlin.io import Dataset
from merlin.models.torch.batch import Batch
from merlin.models.torch.batch import Batch, sample_batch
from merlin.models.torch.models.base import compute_loss
from merlin.models.torch.utils import module_utils
from merlin.schema import ColumnSchema
Expand Down Expand Up @@ -200,22 +201,29 @@ def test_no_output_schema(self):
with pytest.raises(ValueError, match="Could not get output schema of PlusOne()"):
mm.schema.output(model)

# def test_train_classification(self, music_streaming_data):
# schema = music_streaming_data.schema.without(["user_genres", "like", "item_genres"])
# music_streaming_data.schema = schema
def test_train_classification_with_lightning_trainer(self, music_streaming_data, batch_size=16):
schema = music_streaming_data.schema.select_by_name(
["item_id", "user_id", "user_age", "item_genres", "click"]
)
music_streaming_data.schema = schema

# model = mm.Model(
# mm.TabularInputBlock(schema),
# mm.MLPBlock([4, 2]),
# mm.BinaryOutput(schema.select_by_name("click").first),
# schema=schema,
# )
model = mm.Model(
mm.TabularInputBlock(schema, init="defaults"),
mm.MLPBlock([4, 2]),
mm.BinaryOutput(schema.select_by_name("click").first),
)

trainer = pl.Trainer(max_epochs=1, devices=1)

with Loader(music_streaming_data, batch_size=batch_size) as loader:
model.initialize(loader)
trainer.fit(model, loader)

# trainer = pl.Trainer(max_epochs=1)
assert trainer.logged_metrics["train_loss"] > 0.0
assert trainer.num_training_batches == 7 # 100 rows // 16 per batch + 1 for last batch

# with Loader(music_streaming_data, batch_size=16) as loader:
# model.initialize(loader)
# trainer.fit(model, loader)
batch = sample_batch(music_streaming_data, batch_size)
_ = module_utils.module_test(model, batch)


class TestComputeLoss:
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/torch/models/test_ranking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest
import pytorch_lightning as pl

import merlin.models.torch as mm
from merlin.dataloader.torch import Loader
from merlin.models.torch.batch import sample_batch
from merlin.models.torch.utils import module_utils
from merlin.schema import ColumnSchema


@pytest.mark.parametrize("output_block", [None, mm.BinaryOutput(ColumnSchema("click"))])
class TestDLRMModel:
def test_train_dlrm_with_lightning_loader(
self, music_streaming_data, output_block, dim=2, batch_size=16
):
schema = music_streaming_data.schema.select_by_name(
["item_id", "user_id", "user_age", "item_genres", "click"]
)
music_streaming_data.schema = schema

model = mm.DLRMModel(
schema,
dim=dim,
bottom_block=mm.MLPBlock([4, 2]),
top_block=mm.MLPBlock([4, 2]),
output_block=output_block,
)

trainer = pl.Trainer(max_epochs=1, devices=1)

with Loader(music_streaming_data, batch_size=batch_size) as train_loader:
model.initialize(train_loader)
trainer.fit(model, train_loader)

assert trainer.logged_metrics["train_loss"] > 0.0

batch = sample_batch(music_streaming_data, batch_size)
_ = module_utils.module_test(model, batch)

0 comments on commit 28040dd

Please sign in to comment.