Skip to content

Commit

Permalink
Adding MMOE & PLE (#1173)
Browse files Browse the repository at this point in the history
* First pass over MMOEBlock & PLEBlock

* Adding some simple tests for MMOEBlock

* Adding some doc-strings

* Fixing failing tests

* Increase test-coverage

* Improving doc-strings

* Fixing failing tests
  • Loading branch information
marcromeyn authored Jul 11, 2023
1 parent 190cd48 commit d2113e8
Show file tree
Hide file tree
Showing 5 changed files with 525 additions and 10 deletions.
17 changes: 16 additions & 1 deletion merlin/models/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,17 @@

from merlin.models.torch import schema
from merlin.models.torch.batch import Batch, Sequence
from merlin.models.torch.block import Block, ParallelBlock, ResidualBlock, ShortcutBlock
from merlin.models.torch.block import (
Block,
ParallelBlock,
ResidualBlock,
ShortcutBlock,
repeat,
repeat_parallel,
repeat_parallel_like,
)
from merlin.models.torch.blocks.dlrm import DLRMBlock
from merlin.models.torch.blocks.experts import CGCBlock, MMOEBlock, PLEBlock
from merlin.models.torch.blocks.mlp import MLPBlock
from merlin.models.torch.inputs.embedding import EmbeddingTable, EmbeddingTables
from merlin.models.torch.inputs.select import SelectFeatures, SelectKeys
Expand Down Expand Up @@ -67,6 +76,9 @@
"Concat",
"Stack",
"schema",
"repeat",
"repeat_parallel",
"repeat_parallel_like",
"CategoricalOutput",
"CategoricalTarget",
"EmbeddingTablePrediction",
Expand All @@ -77,4 +89,7 @@
"DLRMBlock",
"DLRMModel",
"DCNModel",
"MMOEBlock",
"PLEBlock",
"CGCBlock",
]
87 changes: 79 additions & 8 deletions merlin/models/torch/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import inspect
import textwrap
from copy import deepcopy
from typing import Dict, Optional, Tuple, TypeVar, Union
from typing import Dict, Optional, Protocol, Tuple, TypeVar, Union, runtime_checkable

import torch
from torch import nn
Expand All @@ -31,6 +31,12 @@
from merlin.schema import Schema


@runtime_checkable
class HasKeys(Protocol):
def keys(self):
...


class Block(BlockContainer, RegistryMixin, TraversableMixin):
"""A base-class that calls it's modules sequentially.
Expand Down Expand Up @@ -87,15 +93,13 @@ def repeat(self, n: int = 1, name=None) -> "Block":
Block
The new block created by repeating the current block `n` times.
"""
if not isinstance(n, int):
raise TypeError("n must be an integer")
return repeat(self, n, name=name)

if n < 1:
raise ValueError("n must be greater than 0")
def repeat_parallel(self, n: int = 1, name=None) -> "ParallelBlock":
return repeat_parallel(self, n, name=name)

repeats = [self.copy() for _ in range(n - 1)]

return Block(self, *repeats, name=name)
def repeat_parallel_like(self, like: HasKeys, name=None) -> "ParallelBlock":
return repeat_parallel_like(self, like, name=name)

def copy(self) -> "Block":
"""
Expand Down Expand Up @@ -342,6 +346,9 @@ def replace(self, pre=None, branches=None, post=None) -> "ParallelBlock":

return output

def keys(self):
return self.branches.keys()

def leaf(self) -> nn.Module:
if self.pre:
raise ValueError("Cannot call leaf() on a ParallelBlock with a pre-processing stage")
Expand Down Expand Up @@ -567,6 +574,70 @@ def forward(
return to_return


def _validate_n(n: int) -> None:
if not isinstance(n, int):
raise TypeError("n must be an integer")

if n < 1:
raise ValueError("n must be greater than 0")


def repeat(module: nn.Module, n: int = 1, name=None) -> Block:
"""
Creates a new block by repeating the current block `n` times.
Each repetition is a deep copy of the current block.
Parameters
----------
module: nn.Module
The module to be repeated.
n : int
The number of times to repeat the current block.
name : Optional[str], default = None
The name for the new block. If None, no name is assigned.
Returns
-------
Block
The new block created by repeating the current block `n` times.
"""
_validate_n(n)

repeats = [module.copy() if hasattr(module, "copy") else deepcopy(module) for _ in range(n - 1)]

return Block(module, *repeats, name=name)


def repeat_parallel(module: nn.Module, n: int = 1, agg=None) -> ParallelBlock:
_validate_n(n)

branches = {"0": module}
branches.update(
{str(n): module.copy() if hasattr(module, "copy") else deepcopy(module) for n in range(n)}
)

output = ParallelBlock(branches)
if agg:
output.append(Block.parse(agg))

return output


def repeat_parallel_like(module: nn.Module, like: HasKeys, agg=None) -> ParallelBlock:
branches = {}
for i, key in enumerate(like.keys()):
if i == 0:
branches[str(key)] = module
else:
branches[str(key)] = module.copy() if hasattr(module, "copy") else deepcopy(module)

output = ParallelBlock(branches)
if agg:
output.append(Block.parse(agg))

return output


def get_pre(module: nn.Module) -> BlockContainer:
if hasattr(module, "pre"):
return module.pre
Expand Down
Loading

0 comments on commit d2113e8

Please sign in to comment.