Some bug fixes + 100% test coverage for block.py
marcromeyn committed Jul 3, 2023
1 parent ab0f635 commit 7c71ac8
Showing 2 changed files with 82 additions and 27 deletions.
31 changes: 21 additions & 10 deletions merlin/models/torch/block.py
@@ -17,7 +17,7 @@
import inspect
import textwrap
from copy import deepcopy
from typing import Dict, Optional, Tuple, TypeVar, Union
from typing import Dict, Final, Optional, Tuple, TypeVar, Union

import torch
from torch import nn
@@ -26,7 +26,7 @@
from merlin.models.torch.batch import Batch
from merlin.models.torch.container import BlockContainer, BlockContainerDict
from merlin.models.torch.registry import registry
from merlin.models.torch.utils.traversal_utils import TraversableMixin, leaf
from merlin.models.torch.utils.traversal_utils import TraversableMixin
from merlin.models.utils.registry import RegistryMixin
from merlin.schema import Schema

@@ -43,6 +43,7 @@ class Block(BlockContainer, RegistryMixin, TraversableMixin):
"""

registry = registry
is_block: Final[bool] = True

def __init__(self, *module: nn.Module, name: Optional[str] = None):
super().__init__(*module, name=name)
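
The new is_block flag is annotated with typing.Final, which torch.jit.script reads as a compile-time constant, likely so other code can detect Block instances where TorchScript rules out isinstance checks. A minimal sketch of the mechanism, independent of this commit (the Tagged class is illustrative only):

from typing import Final

import torch
from torch import nn


class Tagged(nn.Module):
    # TorchScript treats Final-annotated class attributes as constants.
    is_block: Final[bool] = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.is_block:  # folded to a constant when scripted
            return x + 1
        return x


scripted = torch.jit.script(Tagged())
assert torch.equal(scripted(torch.zeros(2)), torch.ones(2))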
@@ -350,10 +351,7 @@ def leaf(self) -> nn.Module:
raise ValueError("Cannot call leaf() on a ParallelBlock with multiple branches")

first = list(self.branches.values())[0]
if hasattr(first, "leaf"):
return first.leaf()

return leaf(first)
return first.leaf()

def __getitem__(self, idx: Union[slice, int, str]):
if isinstance(idx, str) and idx in self.branches:
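
leaf() on a ParallelBlock now delegates straight to the first branch's own leaf(), assuming every branch exposes one via TraversableMixin, instead of falling back to the free leaf() helper (whose import is dropped above). The underlying idea, as a standalone sketch (find_leaf is hypothetical, not the library's API):

from torch import nn


def find_leaf(module: nn.Module) -> nn.Module:
    # Descend through single-child containers until a childless module remains.
    children = list(module.children())
    if len(children) == 0:
        return module
    if len(children) > 1:
        raise ValueError("leaf is ambiguous for a module with multiple children")
    return find_leaf(children[0])


assert isinstance(find_leaf(nn.Sequential(nn.Sequential(nn.ReLU()))), nn.ReLU)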
@@ -451,26 +449,30 @@ def __init__(
self,
*module: nn.Module,
name: Optional[str] = None,
propagate_shortcut: bool = False,
shortcut_name: str = "shortcut",
output_name: str = "output",
):
super().__init__(*module, name=name)
self.shortcut_name = shortcut_name
self.output_name = output_name
self.propagate_shortcut = propagate_shortcut

def forward(
self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None
) -> Dict[str, torch.Tensor]:
if torch.jit.isinstance(inputs, Dict[str, torch.Tensor]):
if self.shortcut_name not in inputs:
raise ValueError(f"Shortcut name {self.shortcut_name} not found in inputs {inputs}")
raise RuntimeError(
f"Shortcut name {self.shortcut_name} not found in inputs {inputs}"
)
shortcut = inputs[self.shortcut_name]
else:
shortcut = inputs

output = inputs
for module in self.values:
if getattr(module, "accepts_dict", False) or hasattr(module, "values"):
if self.propagate_shortcut:
if torch.jit.isinstance(output, Dict[str, torch.Tensor]):
module_output = module(output, batch=batch)
else:
@@ -486,7 +488,7 @@ def forward(
elif torch.jit.isinstance(module_output, Dict[str, torch.Tensor]):
output = module_output[self.output_name]
else:
raise ValueError(
raise RuntimeError(
f"Module {module} must return a tensor or a dict ",
f"with key {self.output_name}",
)
@@ -495,7 +497,16 @@ def forward(
output, Dict[str, torch.Tensor]
):
output = output[self.output_name]
output = module(output, batch=batch)
_output = module(output, batch=batch)
if torch.jit.isinstance(_output, torch.Tensor) or torch.jit.isinstance(
_output, Dict[str, torch.Tensor]
):
output = _output
else:
raise RuntimeError(
f"Module {module} must return a tensor or a dict ",
f"with key {self.output_name}",
)

to_return = {self.shortcut_name: shortcut}
if torch.jit.isinstance(output, Dict[str, torch.Tensor]):
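
The new propagate_shortcut flag controls what nested dict-accepting modules see: with the default False only the running output is forwarded, while True hands them the full dict including the shortcut entry, as test_propagate_shortcut below exercises. A usage sketch under that reading (FromShortcut is an illustrative stand-in for the test's PlusOneShortcut):

from typing import Dict

import torch
from torch import nn

from merlin.models.torch.block import ShortcutBlock


class FromShortcut(nn.Module):
    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Sees the propagated dict, so it can read the shortcut directly.
        return inputs["shortcut"] + 1


block = ShortcutBlock(FromShortcut(), propagate_shortcut=True)
outputs = block(torch.zeros(2, 2))
# outputs == {"shortcut": <the raw input>, "output": <input + 1>}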
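Two TorchScript constraints shape the remaining edits. Assigning the submodule result to a temporary _output and narrowing it with torch.jit.isinstance gives the compiler a concrete type before it flows back into output, and the ValueError-to-RuntimeError switch likely exists because scripted modules raise torch.jit.Error, a RuntimeError subclass, so a single pytest.raises(RuntimeError) covers eager and scripted runs alike. A self-contained sketch of the narrowing pattern (narrow is illustrative):

from typing import Dict, List, Union

import torch


@torch.jit.script
def narrow(values: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> torch.Tensor:
    total = torch.zeros(1)
    for value in values:
        # torch.jit.isinstance refines the Union so each branch type-checks.
        if torch.jit.isinstance(value, torch.Tensor):
            total = total + value
        elif torch.jit.isinstance(value, Dict[str, torch.Tensor]):
            total = total + value["output"]
        else:
            raise RuntimeError("expected a tensor or a dict with key 'output'")
    return total
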
78 changes: 61 additions & 17 deletions tests/unit/torch/test_block.py
@@ -69,9 +69,6 @@ def test_insertion(self):

assert torch.equal(outputs, inputs + 2)

# block.append(PlusOne(), link="residual")
# assert isinstance(block[-1], link.Residual)

def test_copy(self):
block = Block(PlusOne())

@@ -95,19 +92,6 @@ def test_repeat(self):
with pytest.raises(ValueError, match="n must be greater than 0"):
block.repeat(0)

# def test_repeat_with_link(self):
# block = Block(PlusOne())

# repeated = block.repeat(2, link="residual")
# assert isinstance(repeated, Block)
# assert len(repeated) == 2
# assert isinstance(repeated[-1], link.Residual)

# inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
# outputs = module_utils.module_test(repeated, inputs)

# assert torch.equal(outputs, (inputs + 1) + (inputs + 1) + 1)

def test_from_registry(self):
@Block.registry.register("my_block")
class MyBlock(Block):
@@ -289,6 +273,22 @@ def test_input_schema_pre(self):
assert input_schema == mm.schema.input(pb2)
assert mm.schema.output(pb2) == mm.schema.output(pb)

def test_leaf(self):
block = ParallelBlock({"a": PlusOne()})

assert isinstance(block.leaf(), PlusOne)

block.branches["b"] = PlusOne()
with pytest.raises(ValueError):
block.leaf()

block.prepend(PlusOne())
with pytest.raises(ValueError):
block.leaf()

block = ParallelBlock({"a": nn.Sequential(PlusOne())})
assert isinstance(block.leaf(), PlusOne)


class TestResidualBlock:
def test_forward(self):
@@ -326,10 +326,54 @@ def test_nesting(self):
def test_convert(self):
block = Block(PlusOne())
shortcut = ShortcutBlock(*block)
nested = ShortcutBlock(ShortcutBlock(shortcut), propagate_shortcut=True)

assert isinstance(shortcut[0], PlusOne)
inputs = torch.rand(5, 5)
assert torch.equal(
module_utils.module_test(shortcut, inputs)["output"],
module_utils.module_test(ShortcutBlock(PlusOne()), inputs)["output"],
module_utils.module_test(nested, inputs)["output"],
)

def test_with_parallel(self):
parallel = ParallelBlock({"a": PlusOne(), "b": PlusOne()})
shortcut = ShortcutBlock(parallel)

inputs = torch.rand(5, 5)

outputs = shortcut(inputs)

outputs = module_utils.module_test(shortcut, inputs)
assert torch.equal(outputs["shortcut"], inputs)
assert torch.equal(outputs["a"], inputs + 1)
assert torch.equal(outputs["b"], inputs + 1)

def test_propagate_shortcut(self):
class PlusOneShortcut(nn.Module):
def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
return inputs["shortcut"] + 1

shortcut = ShortcutBlock(PlusOneShortcut(), propagate_shortcut=True)
shortcut = ShortcutBlock(shortcut, propagate_shortcut=True)
inputs = torch.rand(5, 5)
outputs = module_utils.module_test(shortcut, inputs)

assert torch.equal(outputs["output"], inputs + 1)

with pytest.raises(RuntimeError):
shortcut({"a": inputs})

def test_exception(self):
with_tuple = Block(PlusOneTuple())
shortcut = ShortcutBlock(with_tuple)

with pytest.raises(RuntimeError):
module_utils.module_test(shortcut, torch.rand(5, 5))

class PlusOneShortcutTuple(nn.Module):
def forward(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
return inputs["shortcut"] + 1, inputs["shortcut"]

shortcut_propagate = ShortcutBlock(PlusOneShortcutTuple(), propagate_shortcut=True)
with pytest.raises(RuntimeError):
module_utils.module_test(shortcut_propagate, torch.rand(5, 5))
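
These tests route everything through module_utils.module_test, which by all appearances runs a module both eagerly and under torch.jit.script and checks that the two agree; that would also explain asserting RuntimeError rather than ValueError above. A rough sketch of such a helper (an assumption about the repo's test utility, not its actual code):

import torch
from torch import nn


def module_test(module: nn.Module, inputs: torch.Tensor):
    # Run eagerly, then as TorchScript, and require identical outputs.
    eager_out = module(inputs)
    scripted_out = torch.jit.script(module)(inputs)
    if isinstance(eager_out, dict):
        for key, value in eager_out.items():
            assert torch.equal(value, scripted_out[key])
    else:
        assert torch.equal(eager_out, scripted_out)
    return eager_out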
