From eff55db5498c4be0ba7112188f74d1565b89d9c5 Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Mon, 3 Jul 2023 10:30:32 +0200 Subject: [PATCH] Remove link, and introduce ShortcutBlock + ResidualBlock (#1170) * Removing Link in favour of some new Blocks like: ResidualBlock & ShortcutBlock * Removing Link in favour of some new Blocks like: ResidualBlock & ShortcutBlock * Add conversion test * Some bug fixes + 100% test-coverage for block.py * Improve doc-strings * Remove un-used isblock again --- merlin/models/torch/__init__.py | 4 +- merlin/models/torch/block.py | 196 +++++++++++++++++++++++++---- merlin/models/torch/blocks/dlrm.py | 40 +++--- merlin/models/torch/container.py | 97 +++----------- merlin/models/torch/link.py | 83 ------------ tests/unit/torch/test_block.py | 131 ++++++++++++++++--- tests/unit/torch/test_container.py | 50 +------- tests/unit/torch/test_link.py | 76 ----------- 8 files changed, 332 insertions(+), 345 deletions(-) delete mode 100644 merlin/models/torch/link.py delete mode 100644 tests/unit/torch/test_link.py diff --git a/merlin/models/torch/__init__.py b/merlin/models/torch/__init__.py index d2326af5e9..988897ef44 100644 --- a/merlin/models/torch/__init__.py +++ b/merlin/models/torch/__init__.py @@ -16,7 +16,7 @@ from merlin.models.torch import schema from merlin.models.torch.batch import Batch, Sequence -from merlin.models.torch.block import Block, ParallelBlock +from merlin.models.torch.block import Block, ParallelBlock, ResidualBlock, ShortcutBlock from merlin.models.torch.blocks.dlrm import DLRMBlock from merlin.models.torch.blocks.mlp import MLPBlock from merlin.models.torch.inputs.embedding import EmbeddingTable, EmbeddingTables @@ -45,9 +45,11 @@ "ParallelBlock", "Sequence", "RegressionOutput", + "ResidualBlock", "RouterBlock", "SelectKeys", "SelectFeatures", + "ShortcutBlock", "TabularInputBlock", "Concat", "Stack", diff --git a/merlin/models/torch/block.py b/merlin/models/torch/block.py index 44d6212909..42dede5b9b 100644 --- a/merlin/models/torch/block.py +++ b/merlin/models/torch/block.py @@ -25,9 +25,8 @@ from merlin.models.torch import schema from merlin.models.torch.batch import Batch from merlin.models.torch.container import BlockContainer, BlockContainerDict -from merlin.models.torch.link import Link, LinkType from merlin.models.torch.registry import registry -from merlin.models.torch.utils.traversal_utils import TraversableMixin, leaf +from merlin.models.torch.utils.traversal_utils import TraversableMixin from merlin.models.utils.registry import RegistryMixin from merlin.schema import Schema @@ -41,8 +40,6 @@ class Block(BlockContainer, RegistryMixin, TraversableMixin): Variable length argument list of PyTorch modules to be contained in the block. name : Optional[str], default = None The name of the block. If None, no name is assigned. - track_schema : bool, default = True - If True, the schema of the output tensors are tracked. """ registry = registry @@ -73,7 +70,7 @@ def forward( return inputs - def repeat(self, n: int = 1, link: Optional[LinkType] = None, name=None) -> "Block": + def repeat(self, n: int = 1, name=None) -> "Block": """ Creates a new block by repeating the current block `n` times. Each repetition is a deep copy of the current block. 
@@ -97,9 +94,6 @@ def repeat(self, n: int = 1, link: Optional[LinkType] = None, name=None) -> "Blo raise ValueError("n must be greater than 0") repeats = [self.copy() for _ in range(n - 1)] - if link: - parsed_link = Link.parse(link) - repeats = [parsed_link.copy().setup_link(repeat) for repeat in repeats] return Block(self, *repeats, name=name) @@ -221,7 +215,7 @@ def forward( return outputs - def append(self, module: nn.Module, link: Optional[LinkType] = None): + def append(self, module: nn.Module): """Appends a module to the post-processing stage. Parameters @@ -235,7 +229,7 @@ def append(self, module: nn.Module, link: Optional[LinkType] = None): The current object itself. """ - self.post.append(module, link=link) + self.post.append(module) return self @@ -244,7 +238,7 @@ def prepend(self, module: nn.Module): return self - def append_to(self, name: str, module: nn.Module, link: Optional[LinkType] = None): + def append_to(self, name: str, module: nn.Module): """Appends a module to a specified branch. Parameters @@ -260,11 +254,11 @@ def append_to(self, name: str, module: nn.Module, link: Optional[LinkType] = Non The current object itself. """ - self.branches[name].append(module, link=link) + self.branches[name].append(module) return self - def prepend_to(self, name: str, module: nn.Module, link: Optional[LinkType] = None): + def prepend_to(self, name: str, module: nn.Module): """Prepends a module to a specified branch. Parameters @@ -279,11 +273,11 @@ def prepend_to(self, name: str, module: nn.Module, link: Optional[LinkType] = No ParallelBlock The current object itself. """ - self.branches[name].prepend(module, link=link) + self.branches[name].prepend(module) return self - def append_for_each(self, module: nn.Module, shared=False, link: Optional[LinkType] = None): + def append_for_each(self, module: nn.Module, shared=False): """Appends a module to each branch. Parameters @@ -300,11 +294,11 @@ def append_for_each(self, module: nn.Module, shared=False, link: Optional[LinkTy The current object itself. """ - self.branches.append_for_each(module, shared=shared, link=link) + self.branches.append_for_each(module, shared=shared) return self - def prepend_for_each(self, module: nn.Module, shared=False, link: Optional[LinkType] = None): + def prepend_for_each(self, module: nn.Module, shared=False): """Prepends a module to each branch. Parameters @@ -321,7 +315,7 @@ def prepend_for_each(self, module: nn.Module, shared=False, link: Optional[LinkT The current object itself. """ - self.branches.prepend_for_each(module, shared=shared, link=link) + self.branches.prepend_for_each(module, shared=shared) return self @@ -356,10 +350,7 @@ def leaf(self) -> nn.Module: raise ValueError("Cannot call leaf() on a ParallelBlock with multiple branches") first = list(self.branches.values())[0] - if hasattr(first, "leaf"): - return first.leaf() - - return leaf(first) + return first.leaf() def __getitem__(self, idx: Union[slice, int, str]): if isinstance(idx, str) and idx in self.branches: @@ -415,6 +406,167 @@ def __repr__(self) -> str: return self._get_name() + branches +class ResidualBlock(Block): + """ + A block that applies each contained module sequentially on the input + and performs a residual connection after each module. + + Parameters + ---------- + *module : nn.Module + Variable length argument list of PyTorch modules to be contained in the block. + name : Optional[str], default = None + The name of the block. If None, no name is assigned. 
+ + """ + + def forward(self, inputs: torch.Tensor, batch: Optional[Batch] = None): + """ + Forward pass through the block. Applies each contained module sequentially on the input. + + Parameters + ---------- + inputs : Union[torch.Tensor, Dict[str, torch.Tensor]] + The input data as a tensor or a dictionary of tensors. + batch : Optional[Batch], default = None + Optional batch of data. If provided, it is used by the `module`s. + + Returns + ------- + torch.Tensor or Dict[str, torch.Tensor] + The output of the block after processing the input. + """ + shortcut, outputs = inputs, inputs + for module in self.values: + outputs = shortcut + module(outputs, batch=batch) + + return outputs + + +class ShortcutBlock(Block): + """ + A block with a 'shortcut' or a 'skip connection'. + + The shortcut tensor can be propagated through the layers of the module or not, + depending on the value of `propagate_shortcut` argument: + If `propagate_shortcut` is True, the shortcut tensor is passed through + each layer of the module. + If `propagate_shortcut` is False, the shortcut tensor is only used as part of + the final output dictionary. + + Example usage:: + >>> shortcut = mm.ShortcutBlock(nn.Identity()) + >>> shortcut(torch.ones(1, 1)) + {'shortcut': tensor([[1.]]), 'output': tensor([[1.]])} + + Parameters + ---------- + *module : nn.Module + Variable length argument list of PyTorch modules to be contained in the block. + name : str, optional + The name of the module, by default None. + propagate_shortcut : bool, optional + If True, propagates the shortcut tensor through the layers of this block, by default False. + shortcut_name : str, optional + The name to use for the shortcut tensor, by default "shortcut". + output_name : str, optional + The name to use for the output tensor, by default "output". + """ + + def __init__( + self, + *module: nn.Module, + name: Optional[str] = None, + propagate_shortcut: bool = False, + shortcut_name: str = "shortcut", + output_name: str = "output", + ): + super().__init__(*module, name=name) + self.shortcut_name = shortcut_name + self.output_name = output_name + self.propagate_shortcut = propagate_shortcut + + def forward( + self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None + ) -> Dict[str, torch.Tensor]: + """ + Defines the forward propagation of the module. + + Parameters + ---------- + inputs : Union[torch.Tensor, Dict[str, torch.Tensor]] + The input tensor or a dictionary of tensors. + batch : Batch, optional + A batch of inputs, by default None. + + Returns + ------- + Dict[str, torch.Tensor] + The output tensor as a dictionary. + + Raises + ------ + RuntimeError + If the shortcut name is not found in the input dictionary, or + if the module does not return a tensor or a dictionary with a key 'output_name'. 
+ """ + + if torch.jit.isinstance(inputs, Dict[str, torch.Tensor]): + if self.shortcut_name not in inputs: + raise RuntimeError( + f"Shortcut name {self.shortcut_name} not found in inputs {inputs}" + ) + shortcut = inputs[self.shortcut_name] + else: + shortcut = inputs + + output = inputs + for module in self.values: + if self.propagate_shortcut: + if torch.jit.isinstance(output, Dict[str, torch.Tensor]): + module_output = module(output, batch=batch) + else: + to_pass: Dict[str, torch.Tensor] = { + self.shortcut_name: shortcut, + self.output_name: torch.jit.annotate(torch.Tensor, output), + } + + module_output = module(to_pass, batch=batch) + + if torch.jit.isinstance(module_output, torch.Tensor): + output = module_output + elif torch.jit.isinstance(module_output, Dict[str, torch.Tensor]): + output = module_output[self.output_name] + else: + raise RuntimeError( + f"Module {module} must return a tensor or a dict ", + f"with key {self.output_name}", + ) + else: + if torch.jit.isinstance(inputs, Dict[str, torch.Tensor]) and torch.jit.isinstance( + output, Dict[str, torch.Tensor] + ): + output = output[self.output_name] + _output = module(output, batch=batch) + if torch.jit.isinstance(_output, torch.Tensor) or torch.jit.isinstance( + _output, Dict[str, torch.Tensor] + ): + output = _output + else: + raise RuntimeError( + f"Module {module} must return a tensor or a dict ", + f"with key {self.output_name}", + ) + + to_return = {self.shortcut_name: shortcut} + if torch.jit.isinstance(output, Dict[str, torch.Tensor]): + to_return.update(output) + else: + to_return[self.output_name] = output + + return to_return + + def get_pre(module: nn.Module) -> BlockContainer: if hasattr(module, "pre"): return module.pre diff --git a/merlin/models/torch/blocks/dlrm.py b/merlin/models/torch/blocks/dlrm.py index a24e4d1f71..3b638ada08 100644 --- a/merlin/models/torch/blocks/dlrm.py +++ b/merlin/models/torch/blocks/dlrm.py @@ -1,13 +1,13 @@ -from typing import Dict, Optional +from typing import Dict, Optional, Union import torch from torch import nn +from merlin.models.torch.batch import Batch from merlin.models.torch.block import Block from merlin.models.torch.inputs.embedding import EmbeddingTables from merlin.models.torch.inputs.tabular import TabularInputBlock -from merlin.models.torch.link import Link -from merlin.models.torch.transforms.agg import MaybeAgg, Stack +from merlin.models.torch.transforms.agg import Stack from merlin.models.utils.doc_utils import docstring_parameter from merlin.schema import Schema, Tags @@ -77,7 +77,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: return interactions_flat -class ShortcutConcatContinuous(Link): +class InteractionBlock(Block): """ A shortcut connection that concatenates continuous input features and intermediate outputs. @@ -85,13 +85,28 @@ class ShortcutConcatContinuous(Link): When there's no continuous input, the intermediate output is returned. 
""" - def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor: - intermediate_output = self.output(inputs) + def __init__( + self, + *module: nn.Module, + name: Optional[str] = None, + prepend_agg: bool = True, + ): + if prepend_agg: + module = (Stack(dim=1),) + module + super().__init__(*module, name=name) - if "continuous" in inputs: - return torch.cat((inputs["continuous"], intermediate_output), dim=1) + def forward( + self, inputs: Union[Dict[str, torch.Tensor], torch.Tensor], batch: Optional[Batch] = None + ) -> torch.Tensor: + outputs = inputs + for module in self.values: + outputs = module(outputs, batch) - return intermediate_output + if torch.jit.isinstance(inputs, Dict[str, torch.Tensor]): + if "continuous" in inputs: + return torch.cat((inputs["continuous"], outputs), dim=1) + + return outputs @docstring_parameter(dlrm_reference=_DLRM_REF) @@ -131,11 +146,6 @@ def __init__( interaction: Optional[nn.Module] = None, ): super().__init__(DLRMInputBlock(schema, dim, bottom_block)) - - self.append( - Block(MaybeAgg(Stack(dim=1)), interaction or DLRMInteraction()), - link=ShortcutConcatContinuous(), - ) - + self.append(InteractionBlock(interaction or DLRMInteraction())) if top_block: self.append(top_block) diff --git a/merlin/models/torch/container.py b/merlin/models/torch/container.py index ed185092ac..e289c694fa 100644 --- a/merlin/models/torch/container.py +++ b/merlin/models/torch/container.py @@ -21,7 +21,6 @@ from torch import nn from torch._jit_internal import _copy_to_script_wrapper -from merlin.models.torch.link import Link, LinkType from merlin.models.torch.utils import torchscript_utils @@ -47,52 +46,37 @@ def __init__(self, *inputs: nn.Module, name: Optional[str] = None): self._name: str = name - def append(self, module: nn.Module, link: Optional[Link] = None): + def append(self, module: nn.Module): """Appends a given module to the end of the list. Parameters ---------- module : nn.Module The PyTorch module to be appended. - link : Optional[LinkType] - The link to use for the module. If None, no link is used. - This can either be a Module or a string, options are: - - "residual": Adds a residual connection to the module. - - "shortcut": Adds a shortcut connection to the module. - - "shortcut-concat": Adds a shortcut connection by concatenating - the input and output. Returns ------- self """ - _module = self._check_link(module, link=link) - self.values.append(self.wrap_module(_module)) + self.values.append(self.wrap_module(module)) return self - def prepend(self, module: nn.Module, link: Optional[Link] = None): + def prepend(self, module: nn.Module): """Prepends a given module to the beginning of the list. Parameters ---------- module : nn.Module The PyTorch module to be prepended. - link : Optional[LinkType] - The link to use for the module. If None, no link is used. - This can either be a Module or a string, options are: - - "residual": Adds a residual connection to the module. - - "shortcut": Adds a shortcut connection to the module. - - "shortcut-concat": Adds a shortcut connection by concatenating - the input and output. Returns ------- self """ - return self.insert(0, module, link=link) + return self.insert(0, module) - def insert(self, index: int, module: nn.Module, link: Optional[Link] = None): + def insert(self, index: int, module: nn.Module): """Inserts a given module at the specified index. Parameters @@ -101,20 +85,12 @@ def insert(self, index: int, module: nn.Module, link: Optional[Link] = None): The index at which the module is to be inserted. 
module : nn.Module The PyTorch module to be inserted. - link : Optional[LinkType] - The link to use for the module. If None, no link is used. - This can either be a Module or a string, options are: - - "residual": Adds a residual connection to the module. - - "shortcut": Adds a shortcut connection to the module. - - "shortcut-concat": Adds a shortcut connection by concatenating - the input and output. Returns ------- self """ - _module = self._check_link(module, link=link) - self.values.insert(index, self.wrap_module(_module)) + self.values.insert(index, self.wrap_module(module)) return self @@ -193,15 +169,6 @@ def __repr__(self) -> str: def _get_name(self) -> str: return super()._get_name() if self._name is None else self._name - def _check_link(self, module: nn.Module, link: Optional[LinkType] = None) -> nn.Module: - if link: - linked_module: Link = Link.parse(link) - linked_module.setup_link(module) - - return linked_module - - return module - class BlockContainerDict(nn.ModuleDict): """A container class for PyTorch `nn.Module` that allows for manipulation and traversal @@ -232,9 +199,7 @@ def __init__( super().__init__(modules) self._name: str = name - def append_to( - self, name: str, module: nn.Module, link: Optional[LinkType] = None - ) -> "BlockContainerDict": + def append_to(self, name: str, module: nn.Module) -> "BlockContainerDict": """Appends a module to a specified name. Parameters @@ -243,13 +208,6 @@ def append_to( The name of the branch. module : nn.Module The module to append. - link : Optional[LinkType] - The link to use for the module. If None, no link is used. - This can either be a Module or a string, options are: - - "residual": Adds a residual connection to the module. - - "shortcut": Adds a shortcut connection to the module. - - "shortcut-concat": Adds a shortcut connection by concatenating - the input and output. Returns ------- @@ -257,13 +215,11 @@ def append_to( The current object itself. """ - self._modules[name].append(module, link=link) + self._modules[name].append(module) return self - def prepend_to( - self, name: str, module: nn.Module, link: Optional[LinkType] = None - ) -> "BlockContainerDict": + def prepend_to(self, name: str, module: nn.Module) -> "BlockContainerDict": """Prepends a module to a specified name. Parameters @@ -272,13 +228,6 @@ def prepend_to( The name of the branch. module : nn.Module The module to prepend. - link : Optional[LinkType] - The link to use for the module. If None, no link is used. - This can either be a Module or a string, options are: - - "residual": Adds a residual connection to the module. - - "shortcut": Adds a shortcut connection to the module. - - "shortcut-concat": Adds a shortcut connection by concatenating - the input and output. Returns ------- @@ -286,11 +235,9 @@ def prepend_to( The current object itself. """ - self._modules[name].prepend(module, link=link) + self._modules[name].prepend(module) - def append_for_each( - self, module: nn.Module, shared=False, link: Optional[LinkType] = None - ) -> "BlockContainerDict": + def append_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDict": """Appends a module to each branch. Parameters @@ -300,13 +247,6 @@ def append_for_each( shared : bool, default=False If True, the same module is shared across all elements. Otherwise a deep copy of the module is used in each element. - link : Optional[LinkType] - The link to use for the module. If None, no link is used. 
- This can either be a Module or a string, options are: - - "residual": Adds a residual connection to the module. - - "shortcut": Adds a shortcut connection to the module. - - "shortcut-concat": Adds a shortcut connection by concatenating - the input and output. Returns ------- @@ -316,13 +256,11 @@ def append_for_each( for branch in self.values(): _module = module if shared else deepcopy(module) - branch.append(_module, link=link) + branch.append(_module) return self - def prepend_for_each( - self, module: nn.Module, shared=False, link: Optional[LinkType] = None - ) -> "BlockContainerDict": + def prepend_for_each(self, module: nn.Module, shared=False) -> "BlockContainerDict": """Prepends a module to each branch. Parameters @@ -332,13 +270,6 @@ def prepend_for_each( shared : bool, default=False If True, the same module is shared across all elements. Otherwise a deep copy of the module is used in each element. - link : Optional[LinkType] - The link to use for the module. If None, no link is used. - This can either be a Module or a string, options are: - - "residual": Adds a residual connection to the module. - - "shortcut": Adds a shortcut connection to the module. - - "shortcut-concat": Adds a shortcut connection by concatenating - the input and output. Returns ------- @@ -347,7 +278,7 @@ def prepend_for_each( """ for branch in self.values(): _module = module if shared else deepcopy(module) - branch.prepend(_module, link=link) + branch.prepend(_module) return self diff --git a/merlin/models/torch/link.py b/merlin/models/torch/link.py deleted file mode 100644 index f490aeec7f..0000000000 --- a/merlin/models/torch/link.py +++ /dev/null @@ -1,83 +0,0 @@ -import copy -from typing import Dict, Optional, Union - -import torch -from torch import nn - -from merlin.models.torch.registry import TorchRegistryMixin - -LinkType = Union[str, "Link"] - - -class Link(nn.Module, TorchRegistryMixin): - """Base class for different types of network links. - - This is typically used as part of a `Block` to connect different modules. - - Some examples of links are: - - `residual`: Adds the input to the output of the module. - - `shortcut`: Outputs a dictionary with the output of the module and the input. - - `shortcut-concat`: Concatenates the input and the output of the module. - - """ - - def __init__(self, output: Optional[nn.Module] = None): - super().__init__() - - if output is not None: - self.setup_link(output) - - def setup_link(self, output: nn.Module) -> "Link": - """ - Setup function for the link. - - Parameters - ---------- - output : nn.Module - The output module for the link. - - Returns - ------- - Link - The updated Link instance. - """ - - self.output = output - - return self - - def copy(self) -> "Link": - """ - Returns a copy of the link. - - Returns - ------- - Link - The copied link. 
- """ - return copy.deepcopy(self) - - -@Link.registry.register("residual") -class Residual(Link): - """Adds the input to the output of the module.""" - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x + self.output(x) - - -@Link.registry.register("shortcut") -class Shortcut(Link): - """Outputs a dictionary with the output of the module and the input.""" - - def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: - return {"output": self.output(x), "shortcut": x} - - -@Link.registry.register("shortcut-concat") -class ShortcutConcat(Link): - """Concatenates the input and the output of the module.""" - - def forward(self, x: torch.Tensor) -> torch.Tensor: - intermediate_output = self.output(x) - return torch.cat((x, intermediate_output), dim=1) diff --git a/tests/unit/torch/test_block.py b/tests/unit/torch/test_block.py index a2d2d9b627..ea36aaa412 100644 --- a/tests/unit/torch/test_block.py +++ b/tests/unit/torch/test_block.py @@ -20,9 +20,15 @@ from torch import nn import merlin.models.torch as mm -from merlin.models.torch import link from merlin.models.torch.batch import Batch -from merlin.models.torch.block import Block, ParallelBlock, get_pre, set_pre +from merlin.models.torch.block import ( + Block, + ParallelBlock, + ResidualBlock, + ShortcutBlock, + get_pre, + set_pre, +) from merlin.models.torch.container import BlockContainer, BlockContainerDict from merlin.models.torch.utils import module_utils from merlin.schema import Tags @@ -63,9 +69,6 @@ def test_insertion(self): assert torch.equal(outputs, inputs + 2) - block.append(PlusOne(), link="residual") - assert isinstance(block[-1], link.Residual) - def test_copy(self): block = Block(PlusOne()) @@ -89,19 +92,6 @@ def test_repeat(self): with pytest.raises(ValueError, match="n must be greater than 0"): block.repeat(0) - def test_repeat_with_link(self): - block = Block(PlusOne()) - - repeated = block.repeat(2, link="residual") - assert isinstance(repeated, Block) - assert len(repeated) == 2 - assert isinstance(repeated[-1], link.Residual) - - inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) - outputs = module_utils.module_test(repeated, inputs) - - assert torch.equal(outputs, (inputs + 1) + (inputs + 1) + 1) - def test_from_registry(self): @Block.registry.register("my_block") class MyBlock(Block): @@ -282,3 +272,108 @@ def test_input_schema_pre(self): assert input_schema == mm.schema.input(pb2) assert mm.schema.output(pb2) == mm.schema.output(pb) + + def test_leaf(self): + block = ParallelBlock({"a": PlusOne()}) + + assert isinstance(block.leaf(), PlusOne) + + block.branches["b"] = PlusOne() + with pytest.raises(ValueError): + block.leaf() + + block.prepend(PlusOne()) + with pytest.raises(ValueError): + block.leaf() + + block = ParallelBlock({"a": nn.Sequential(PlusOne())}) + assert isinstance(block.leaf(), PlusOne) + + +class TestResidualBlock: + def test_forward(self): + input_tensor = torch.randn(1, 3, 64, 64) + conv = nn.Conv2d(3, 3, kernel_size=3, padding=1) + residual = ResidualBlock(conv) + + output_tensor = module_utils.module_test(residual, input_tensor) + expected_tensor = input_tensor + conv(input_tensor) + + assert torch.allclose(output_tensor, expected_tensor) + + +class TestShortcutBlock: + def test_forward(self): + input_tensor = torch.randn(1, 3, 64, 64) + conv = nn.Conv2d(3, 3, kernel_size=3, padding=1) + shortcut = ShortcutBlock(conv) + + output_dict = module_utils.module_test(shortcut, input_tensor) + + assert "output" in output_dict + assert "shortcut" in output_dict + assert 
torch.allclose(output_dict["output"], conv(input_tensor)) + assert torch.allclose(output_dict["shortcut"], input_tensor) + + def test_nesting(self): + inputs = torch.rand(5, 5) + shortcut = ShortcutBlock(ShortcutBlock(PlusOne())) + output = module_utils.module_test(shortcut, inputs) + + assert torch.equal(output["shortcut"], inputs) + assert torch.equal(output["output"], inputs + 1) + + def test_convert(self): + block = Block(PlusOne()) + shortcut = ShortcutBlock(*block) + nested = ShortcutBlock(ShortcutBlock(shortcut), propagate_shortcut=True) + + assert isinstance(shortcut[0], PlusOne) + inputs = torch.rand(5, 5) + assert torch.equal( + module_utils.module_test(shortcut, inputs)["output"], + module_utils.module_test(nested, inputs)["output"], + ) + + def test_with_parallel(self): + parallel = ParallelBlock({"a": PlusOne(), "b": PlusOne()}) + shortcut = ShortcutBlock(parallel) + + inputs = torch.rand(5, 5) + + outputs = shortcut(inputs) + + outputs = module_utils.module_test(shortcut, inputs) + assert torch.equal(outputs["shortcut"], inputs) + assert torch.equal(outputs["a"], inputs + 1) + assert torch.equal(outputs["b"], inputs + 1) + + def test_propagate_shortcut(self): + class PlusOneShortcut(nn.Module): + def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor: + return inputs["shortcut"] + 1 + + shortcut = ShortcutBlock(PlusOneShortcut(), propagate_shortcut=True) + shortcut = ShortcutBlock(shortcut, propagate_shortcut=True) + inputs = torch.rand(5, 5) + outputs = module_utils.module_test(shortcut, inputs) + + assert torch.equal(outputs["output"], inputs + 1) + + with pytest.raises(RuntimeError): + shortcut({"a": inputs}) + + def test_exception(self): + with_tuple = Block(PlusOneTuple()) + shortcut = ShortcutBlock(with_tuple) + + with pytest.raises(RuntimeError): + module_utils.module_test(shortcut, torch.rand(5, 5)) + + class PlusOneShortcutTuple(nn.Module): + def forward(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + return inputs["shortcut"] + 1, inputs["shortcut"] + + shortcut_propagate = ShortcutBlock(PlusOneShortcutTuple(), propagate_shortcut=True) + with pytest.raises(RuntimeError): + module_utils.module_test(shortcut_propagate, torch.rand(5, 5)) diff --git a/tests/unit/torch/test_container.py b/tests/unit/torch/test_container.py index 4479cc70de..4c8b14be9f 100644 --- a/tests/unit/torch/test_container.py +++ b/tests/unit/torch/test_container.py @@ -15,13 +15,11 @@ # import pytest -import torch import torch.nn as nn import merlin.models.torch as mm -from merlin.models.torch import link from merlin.models.torch.container import BlockContainer, BlockContainerDict -from merlin.models.torch.utils import module_utils, torchscript_utils +from merlin.models.torch.utils import torchscript_utils from merlin.schema import Tags @@ -32,6 +30,7 @@ def setup_method(self): def test_init(self): assert isinstance(self.block_container, BlockContainer) assert self.block_container._name == "test_container" + assert self.block_container != "" def test_append(self): module = nn.Linear(20, 30) @@ -39,16 +38,6 @@ def test_append(self): assert len(self.block_container) == 1 assert self.block_container != BlockContainer(name="test_container") - def test_append_link(self): - module = nn.Linear(20, 20) - self.block_container.append(module, link="residual") - assert len(self.block_container) == 1 - - inputs = torch.randn(1, 20) - outputs = module_utils.module_test(self.block_container[0], inputs) - - assert torch.equal(inputs + module(inputs), outputs) - def 
test_prepend(self): module1 = nn.Linear(20, 30) module2 = nn.Linear(30, 40) @@ -57,16 +46,6 @@ def test_prepend(self): assert len(self.block_container) == 2 assert isinstance(self.block_container[0], nn.Linear) - def test_prepend_link(self): - module = nn.Linear(20, 20) - self.block_container.prepend(module, link="residual") - assert len(self.block_container) == 1 - - inputs = torch.randn(1, 20) - outputs = module_utils.module_test(self.block_container[0], inputs) - - assert torch.equal(inputs + module(inputs), outputs) - def test_insert(self): module1 = nn.Linear(20, 30) module2 = nn.Linear(30, 40) @@ -75,16 +54,6 @@ def test_insert(self): assert len(self.block_container) == 2 assert isinstance(self.block_container[0], nn.Linear) - def test_insert_link(self): - module = nn.Linear(20, 20) - self.block_container.insert(0, module, link="residual") - assert len(self.block_container) == 1 - - inputs = torch.randn(1, 20) - outputs = module_utils.module_test(self.block_container[0], inputs) - - assert torch.equal(inputs + module(inputs), outputs) - def test_len(self): module = nn.Linear(20, 30) self.block_container.append(module) @@ -179,6 +148,7 @@ def test_init(self): assert isinstance(self.container, BlockContainerDict) assert self.container._get_name() == "test" assert isinstance(self.container.unwrap()["test"], BlockContainer) + assert self.container != "" def test_empty(self): container = BlockContainerDict() @@ -192,16 +162,10 @@ def test_append_to(self): self.container.append_to("test", self.module) assert "test" in self.container._modules - self.container.append_to("test", self.module, link="residual") - assert isinstance(self.container["test"][-1], link.Residual) - def test_prepend_to(self): self.container.prepend_to("test", self.module) assert "test" in self.container._modules - self.container.prepend_to("test", self.module, link="residual") - assert isinstance(self.container["test"][0], link.Residual) - def test_append_for_each(self): container = BlockContainerDict({"a": nn.Module(), "b": nn.Module()}) @@ -216,10 +180,6 @@ def test_append_for_each(self): assert len(container["b"]) == 3 assert container["a"][-1] == container["b"][-1] - container.append_for_each(to_add, link="residual") - assert isinstance(container["a"][-1], link.Residual) - assert isinstance(container["b"][-1], link.Residual) - def test_prepend_for_each(self): container = BlockContainerDict({"a": nn.Module(), "b": nn.Module()}) @@ -233,7 +193,3 @@ def test_prepend_for_each(self): assert len(container["a"]) == 3 assert len(container["b"]) == 3 assert container["a"][0] == container["b"][0] - - container.prepend_for_each(to_add, link="residual") - assert isinstance(container["a"][0], link.Residual) - assert isinstance(container["b"][0], link.Residual) diff --git a/tests/unit/torch/test_link.py b/tests/unit/torch/test_link.py deleted file mode 100644 index 5514154184..0000000000 --- a/tests/unit/torch/test_link.py +++ /dev/null @@ -1,76 +0,0 @@ -# -# Copyright (c) 2023, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# - -import torch -from torch import nn - -from merlin.models.torch.link import Link, Residual, Shortcut, ShortcutConcat -from merlin.models.torch.utils import module_utils - - -class TestResidual: - def test_forward(self): - input_tensor = torch.randn(1, 3, 64, 64) - conv = nn.Conv2d(3, 3, kernel_size=3, padding=1) - residual = Residual(conv) - - output_tensor = module_utils.module_test(residual, input_tensor) - expected_tensor = input_tensor + conv(input_tensor) - - assert torch.allclose(output_tensor, expected_tensor) - - def test_from_registry(self): - residual = Link.parse("residual") - - assert isinstance(residual, Residual) - - -class TestShortcut: - def test_forward(self): - input_tensor = torch.randn(1, 3, 64, 64) - conv = nn.Conv2d(3, 3, kernel_size=3, padding=1) - shortcut = Shortcut(conv) - - output_dict = module_utils.module_test(shortcut, input_tensor) - - assert "output" in output_dict - assert "shortcut" in output_dict - assert torch.allclose(output_dict["output"], conv(input_tensor)) - assert torch.allclose(output_dict["shortcut"], input_tensor) - - def test_from_registry(self): - shortcut = Link.parse("shortcut") - - assert isinstance(shortcut, Shortcut) - - -class TestShortcutConcat: - def test_forward(self): - input_tensor = torch.randn(1, 3, 64, 64) - conv = nn.Conv2d( - 3, 10, kernel_size=3, padding=1 - ) # Output channels are different for concatenation - shortcut_concat = ShortcutConcat(conv) - - output_tensor = module_utils.module_test(shortcut_concat, input_tensor) - expected_tensor = torch.cat((input_tensor, conv(input_tensor)), dim=1) - - assert torch.allclose(output_tensor, expected_tensor) - - def test_from_registry(self): - shortcut_concat = Link.parse("shortcut-concat") - - assert isinstance(shortcut_concat, ShortcutConcat)
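
For reference, a minimal usage sketch of the two blocks this patch introduces, in the spirit of the `ShortcutBlock` docstring example above. This assumes the patch is applied and the package is importable as `merlin.models.torch`; the `nn.Identity` modules are placeholders standing in for whatever layers previously relied on the removed `link="residual"` / `link="shortcut"` options.

```python
import torch
from torch import nn

import merlin.models.torch as mm

# ResidualBlock replaces link="residual": the block's input is added back
# to the output of each contained module.
residual = mm.ResidualBlock(nn.Identity())
inputs = torch.ones(2, 3)
out = residual(inputs)
# With nn.Identity as the wrapped module, out == inputs + inputs
assert torch.equal(out, inputs * 2)

# ShortcutBlock replaces link="shortcut": the original input is carried
# alongside the module output in a dictionary.
shortcut = mm.ShortcutBlock(nn.Identity())
result = shortcut(torch.ones(1, 1))
# {'shortcut': tensor([[1.]]), 'output': tensor([[1.]])}
assert torch.equal(result["shortcut"], result["output"])
```

As in the unit tests added by this patch, the blocks can also be composed with existing containers, e.g. `mm.ShortcutBlock(mm.ParallelBlock({"a": nn.Identity()}))` or appended to a `Block` via `block.append(mm.ResidualBlock(module))`, which is the pattern that supersedes `block.append(module, link=...)`.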