Skip to content

Commit

Permalink
Removing Link in favour of some new Blocks like: ResidualBlock & Shor…
Browse files Browse the repository at this point in the history
…tcutBlock
  • Loading branch information
marcromeyn committed Jul 1, 2023
1 parent 91774db commit ddcf3bc
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 18 deletions.
49 changes: 31 additions & 18 deletions merlin/models/torch/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,38 +459,51 @@ def __init__(
self.output_name = output_name

def forward(
self, inputs: torch.Tensor, batch: Optional[Batch] = None
self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None
) -> Dict[str, torch.Tensor]:
shortcut, output = inputs, inputs
if torch.jit.isinstance(inputs, Dict[str, torch.Tensor]):
if self.shortcut_name not in inputs:
raise ValueError(f"Shortcut name {self.shortcut_name} not found in inputs {inputs}")
shortcut = inputs[self.shortcut_name]
else:
shortcut = inputs

output = inputs
for module in self.values:
if getattr(module, "accepts_dict", False):
module_output = module(self._create_dict(shortcut, output), batch=batch)
if getattr(module, "accepts_dict", False) or hasattr(module, "values"):
if torch.jit.isinstance(output, Dict[str, torch.Tensor]):
module_output = module(output, batch=batch)
else:
to_pass: Dict[str, torch.Tensor] = {
self.shortcut_name: shortcut,
self.output_name: torch.jit.annotate(torch.Tensor, output),
}

module_output = module(to_pass, batch=batch)

if torch.jit.isinstance(module_output, torch.Tensor):
output = module_output
elif isinstance(module_output, Dict[str, torch.Tensor]):
elif torch.jit.isinstance(module_output, Dict[str, torch.Tensor]):
output = module_output[self.output_name]
else:
raise ValueError(
f"Module {module} must return a tensor or a dict ",
f"with key {self.output_name}",
)
else:
if torch.jit.isinstance(inputs, Dict[str, torch.Tensor]) and torch.jit.isinstance(
output, Dict[str, torch.Tensor]
):
output = output[self.output_name]
output = module(output, batch=batch)

return self._create_dict(shortcut, output)

def _create_dict(self, shortcut: torch.Tensor, output: torch.Tensor) -> Dict[str, torch.Tensor]:
return {self.shortcut_name: shortcut, self.output_name: output}


class CrossBlock(Block):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
x0 = inputs
current = inputs
for module in self.values:
current = x0 * module(current) + current
to_return = {self.shortcut_name: shortcut}
if torch.jit.isinstance(output, Dict[str, torch.Tensor]):
to_return.update(output)
else:
to_return[self.output_name] = output

return current
return to_return


def get_pre(module: nn.Module) -> BlockContainer:
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/torch/test_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,3 +314,11 @@ def test_forward(self):
assert "shortcut" in output_dict
assert torch.allclose(output_dict["output"], conv(input_tensor))
assert torch.allclose(output_dict["shortcut"], input_tensor)

def test_nesting(self):
inputs = torch.rand(5, 5)
shortcut = ShortcutBlock(ShortcutBlock(PlusOne()))
output = module_utils.module_test(shortcut, inputs)

assert torch.equal(output["shortcut"], inputs)
assert torch.equal(output["output"], inputs + 1)
2 changes: 2 additions & 0 deletions tests/unit/torch/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def setup_method(self):
def test_init(self):
assert isinstance(self.block_container, BlockContainer)
assert self.block_container._name == "test_container"
assert self.block_container != ""

def test_append(self):
module = nn.Linear(20, 30)
Expand Down Expand Up @@ -147,6 +148,7 @@ def test_init(self):
assert isinstance(self.container, BlockContainerDict)
assert self.container._get_name() == "test"
assert isinstance(self.container.unwrap()["test"], BlockContainer)
assert self.container != ""

def test_empty(self):
container = BlockContainerDict()
Expand Down

0 comments on commit ddcf3bc

Please sign in to comment.