Skip to content

Commit

Permalink
Merge branch 'main' into add_pytorch_DLRM_example
Browse files Browse the repository at this point in the history
  • Loading branch information
rnyak authored Jul 11, 2023
2 parents 33379da + 8a9e5ea commit 458f3ce
Show file tree
Hide file tree
Showing 43 changed files with 2,480 additions and 143 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cpu-horovod.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,4 @@ jobs:
if [[ "${{ github.ref }}" != 'refs/heads/main' ]]; then
extra_pytest_markers="and changed"
fi
EXTRA_PYTEST_MARKERS="$extra_pytest_markers" MERLIN_BRANCH="$merlin_branch" COMPARE_BRANCH=${{ github.base_ref }} tox -e horovod-cpu
PYTEST_MARKERS="$extra_pytest_markers" MERLIN_BRANCH="$merlin_branch" COMPARE_BRANCH=${{ github.base_ref }} tox -e horovod-cpu
45 changes: 44 additions & 1 deletion .github/workflows/gpu-multi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,47 @@ jobs:
if [[ "${{ github.ref }}" != 'refs/heads/main' ]]; then
extra_pytest_markers="and changed"
fi
cd ${{ github.workspace }}; EXTRA_PYTEST_MARKERS=$extra_pytest_markers MERLIN_BRANCH=$branch COMPARE_BRANCH=${{ github.base_ref }} tox -e multi-gpu
cd ${{ github.workspace }}; PYTEST_MARKERS="multigpu $extra_pytest_markers" MERLIN_BRANCH=$branch COMPARE_BRANCH=${{ github.base_ref }} tox -e gpu,horovod-gpu
check-changes-torch:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: 3.8
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install GitPython
pip install . --no-deps
- name: Get changed backends
id: backend_check
run: |
echo "changed=$(python ci/get_changed_backends.py --backend torch --branch ${{github.base_ref}})" >> "$GITHUB_OUTPUT"
outputs:
needs_testing: ${{ steps.backend_check.outputs.changed }}

torch:
needs: check-changes-torch
if: ${{needs.check-changes-torch.outputs.needs_testing == 'true' || github.ref == 'refs/heads/main'}}
runs-on: 2GPU

steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Run tests
run: |
ref_type=${{ github.ref_type }}
branch=main
if [[ $ref_type == "tag"* ]]
then
git -c protocol.version=2 fetch --no-tags --prune --progress --no-recurse-submodules --depth=1 origin +refs/heads/release*:refs/remotes/origin/release*
branch=$(git branch -r --contains ${{ github.ref_name }} --list '*release*' --format "%(refname:short)" | sed -e 's/^origin\///')
fi
if [[ "${{ github.ref }}" != 'refs/heads/main' ]]; then
extra_pytest_markers="and changed"
fi
cd ${{ github.workspace }}; PYTEST_MARKERS="multigpu $extra_pytest_markers" MERLIN_BRANCH=$branch COMPARE_BRANCH=${{ github.base_ref }} tox -e gpu
2 changes: 1 addition & 1 deletion .github/workflows/gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
if [[ "${{ github.ref }}" != 'refs/heads/main' ]]; then
extra_pytest_markers="and changed"
fi
cd ${{ github.workspace }}; PYTEST_MARKERS="unit and not (examples or integration or notebook) $extra_pytest_markers" MERLIN_BRANCH=$branch COMPARE_BRANCH=${{ github.base_ref }} tox -e gpu
cd ${{ github.workspace }}; PYTEST_MARKERS="unit and not (examples or integration or notebook) and (singlegpu or not multigpu) $extra_pytest_markers" MERLIN_BRANCH=$branch COMPARE_BRANCH=${{ github.base_ref }} tox -e gpu
tests-examples:
runs-on: 1GPU
Expand Down
23 changes: 20 additions & 3 deletions merlin/models/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,23 @@

from merlin.models.torch import schema
from merlin.models.torch.batch import Batch, Sequence
from merlin.models.torch.block import Block, ParallelBlock, ResidualBlock, ShortcutBlock
from merlin.models.torch.block import (
Block,
ParallelBlock,
ResidualBlock,
ShortcutBlock,
repeat,
repeat_parallel,
repeat_parallel_like,
)
from merlin.models.torch.blocks.dlrm import DLRMBlock
from merlin.models.torch.blocks.experts import CGCBlock, MMOEBlock, PLEBlock
from merlin.models.torch.blocks.mlp import MLPBlock
from merlin.models.torch.inputs.embedding import EmbeddingTable, EmbeddingTables
from merlin.models.torch.inputs.select import SelectFeatures, SelectKeys
from merlin.models.torch.inputs.tabular import TabularInputBlock
from merlin.models.torch.models.base import Model
from merlin.models.torch.models.ranking import DLRMModel
from merlin.models.torch.models.base import Model, MultiLoader
from merlin.models.torch.models.ranking import DCNModel, DLRMModel
from merlin.models.torch.outputs.base import ModelOutput
from merlin.models.torch.outputs.classification import (
BinaryOutput,
Expand All @@ -48,6 +57,7 @@
"DLRMBlock",
"MLPBlock",
"Model",
"MultiLoader",
"EmbeddingTable",
"EmbeddingTables",
"ParallelBlock",
Expand All @@ -66,6 +76,9 @@
"Concat",
"Stack",
"schema",
"repeat",
"repeat_parallel",
"repeat_parallel_like",
"CategoricalOutput",
"CategoricalTarget",
"EmbeddingTablePrediction",
Expand All @@ -75,4 +88,8 @@
"target_schema",
"DLRMBlock",
"DLRMModel",
"DCNModel",
"MMOEBlock",
"PLEBlock",
"CGCBlock",
]
11 changes: 0 additions & 11 deletions merlin/models/torch/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@

from merlin.dataloader.torch import Loader
from merlin.io import Dataset
from merlin.models.torch import schema
from merlin.schema import Schema


@torch.jit.script
Expand Down Expand Up @@ -373,12 +371,3 @@ def sample_features(
"""

return sample_batch(data, batch_size, shuffle).features


@schema.output_schema.register_tensor(Batch)
def _(input):
output_schema = Schema()
output_schema += schema.output_schema.tensors(input.features)
output_schema += schema.output_schema.tensors(input.targets)

return output_schema
87 changes: 79 additions & 8 deletions merlin/models/torch/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import inspect
import textwrap
from copy import deepcopy
from typing import Dict, Optional, Tuple, TypeVar, Union
from typing import Dict, Optional, Protocol, Tuple, TypeVar, Union, runtime_checkable

import torch
from torch import nn
Expand All @@ -31,6 +31,12 @@
from merlin.schema import Schema


@runtime_checkable
class HasKeys(Protocol):
    """Structural protocol for objects that expose a ``keys()`` method.

    Satisfied by ``dict``, ``ParallelBlock``, and other mapping-like
    objects; ``@runtime_checkable`` allows ``isinstance(x, HasKeys)``
    checks against it.
    """

    def keys(self):
        ...


class Block(BlockContainer, RegistryMixin, TraversableMixin):
"""A base-class that calls it's modules sequentially.
Expand Down Expand Up @@ -87,15 +93,13 @@ def repeat(self, n: int = 1, name=None) -> "Block":
Block
The new block created by repeating the current block `n` times.
"""
if not isinstance(n, int):
raise TypeError("n must be an integer")
return repeat(self, n, name=name)

if n < 1:
raise ValueError("n must be greater than 0")
def repeat_parallel(self, n: int = 1, name=None) -> "ParallelBlock":
return repeat_parallel(self, n, name=name)

repeats = [self.copy() for _ in range(n - 1)]

return Block(self, *repeats, name=name)
def repeat_parallel_like(self, like: HasKeys, name=None) -> "ParallelBlock":
return repeat_parallel_like(self, like, name=name)

def copy(self) -> "Block":
"""
Expand Down Expand Up @@ -342,6 +346,9 @@ def replace(self, pre=None, branches=None, post=None) -> "ParallelBlock":

return output

def keys(self):
    """Return the names of the parallel branches.

    This makes ``ParallelBlock`` satisfy the ``HasKeys`` protocol, so a
    ``ParallelBlock`` can be passed as the ``like`` argument of
    ``repeat_parallel_like``.
    """
    return self.branches.keys()

def leaf(self) -> nn.Module:
if self.pre:
raise ValueError("Cannot call leaf() on a ParallelBlock with a pre-processing stage")
Expand Down Expand Up @@ -567,6 +574,70 @@ def forward(
return to_return


def _validate_n(n: int) -> None:
if not isinstance(n, int):
raise TypeError("n must be an integer")

if n < 1:
raise ValueError("n must be greater than 0")


def repeat(module: nn.Module, n: int = 1, name=None) -> Block:
    """
    Build a ``Block`` that applies ``module`` sequentially ``n`` times.

    The first repetition is ``module`` itself; the remaining ``n - 1``
    repetitions are copies (``module.copy()`` when the module defines
    one, otherwise ``copy.deepcopy``).

    Parameters
    ----------
    module : nn.Module
        The module to be repeated.
    n : int
        The number of times to repeat the module.
    name : Optional[str], default = None
        The name for the new block. If None, no name is assigned.

    Returns
    -------
    Block
        The new block created by repeating ``module`` `n` times.
    """
    _validate_n(n)

    def _clone(m: nn.Module) -> nn.Module:
        # Prefer the module's own copy() when it provides one.
        return m.copy() if hasattr(m, "copy") else deepcopy(m)

    extra_repeats = [_clone(module) for _ in range(n - 1)]

    return Block(module, *extra_repeats, name=name)


def repeat_parallel(module: nn.Module, n: int = 1, agg=None) -> ParallelBlock:
    """
    Create a ``ParallelBlock`` with ``n`` branches built from ``module``.

    Branch ``"0"`` is the original module itself; branches ``"1"`` ..
    ``"n-1"`` are copies (``module.copy()`` when available, otherwise
    ``copy.deepcopy``), mirroring the behaviour of
    ``repeat_parallel_like``.

    Parameters
    ----------
    module : nn.Module
        The module to replicate across parallel branches.
    n : int
        The number of parallel branches to create.
    agg : optional
        An aggregation block appended after the parallel branches.

    Returns
    -------
    ParallelBlock
        The new parallel block with ``n`` branches.
    """
    # NOTE(review): Block.repeat_parallel calls this with name=..., which this
    # signature does not accept — confirm the intended keyword (name vs agg).
    _validate_n(n)

    # Keep the original module as the first branch and copy the rest.
    # The previous comprehension iterated `for n in range(n)` — shadowing
    # `n` and overwriting branch "0", so the original module was discarded.
    branches = {"0": module}
    for i in range(1, n):
        branches[str(i)] = module.copy() if hasattr(module, "copy") else deepcopy(module)

    output = ParallelBlock(branches)
    if agg:
        output.append(Block.parse(agg))

    return output


def repeat_parallel_like(module: nn.Module, like: HasKeys, agg=None) -> ParallelBlock:
    """
    Create a ``ParallelBlock`` with one branch per key of ``like``.

    The first key maps to ``module`` itself; every subsequent key maps to
    a copy (``module.copy()`` when available, otherwise ``copy.deepcopy``).

    Parameters
    ----------
    module : nn.Module
        The module to replicate across branches.
    like : HasKeys
        Any object exposing ``keys()`` (e.g. a dict or another
        ``ParallelBlock``) whose keys name the branches.
    agg : optional
        An aggregation block appended after the parallel branches.

    Returns
    -------
    ParallelBlock
        The new parallel block keyed like ``like``.
    """
    def _clone(m: nn.Module) -> nn.Module:
        # Prefer the module's own copy() when it provides one.
        return m.copy() if hasattr(m, "copy") else deepcopy(m)

    branches = {}
    for index, key in enumerate(like.keys()):
        branches[str(key)] = module if index == 0 else _clone(module)

    output = ParallelBlock(branches)
    if agg:
        output.append(Block.parse(agg))

    return output


def get_pre(module: nn.Module) -> BlockContainer:
if hasattr(module, "pre"):
return module.pre
Expand Down
5 changes: 4 additions & 1 deletion merlin/models/torch/blocks/cross.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torch import nn
from torch.nn.modules.lazy import LazyModuleMixin

from merlin.models.torch.batch import Batch
from merlin.models.torch.block import Block
from merlin.models.torch.transforms.agg import Concat
from merlin.models.utils.doc_utils import docstring_parameter
Expand Down Expand Up @@ -127,7 +128,9 @@ def with_low_rank(cls, depth: int, low_rank: nn.Module) -> "CrossBlock":

return cls(*(Block(deepcopy(low_rank), *block) for block in cls.with_depth(depth)))

def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
def forward(
self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None
) -> torch.Tensor:
"""Forward-pass of the cross-block.
Parameters
Expand Down
Loading

0 comments on commit 458f3ce

Please sign in to comment.