Skip to content

Commit

Permalink
Making input- and output-block lazy to allow for nesting (#1184)
Browse files Browse the repository at this point in the history
* Making schema optional in TabularInputBlock to allow for nesting

* Adding tests to ensure correct behaviour

* Fixing failing test

---------

Co-authored-by: edknv <[email protected]>
  • Loading branch information
marcromeyn and edknv authored Jul 6, 2023
1 parent 9922f25 commit 58ce8c1
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 24 deletions.
25 changes: 15 additions & 10 deletions merlin/models/torch/inputs/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,26 @@ class TabularInputBlock(RouterBlock):

def __init__(
self,
schema: Schema,
schema: Optional[Schema] = None,
init: Optional[Union[str, Initializer]] = None,
agg: Optional[Union[str, nn.Module]] = None,
):
self.init = init
self.agg = agg
super().__init__(schema)

def setup_schema(self, schema: Schema):
super().setup_schema(schema)
self.schema: Schema = self.selectable.schema
if init:
if isinstance(init, str):
init = self.initializers.get(init)
if not init:
raise ValueError(f"Initializer {init} not found.")

init(self)
if agg:
self.append(Block.parse(agg))
if self.init:
if isinstance(self.init, str):
self.init = self.initializers.get(self.init)
if not self.init:
raise ValueError(f"Initializer {self.init} not found.")

self.init(self)
if self.agg:
self.append(Block.parse(self.agg))

@classmethod
def register_init(cls, name: str):
Expand Down
24 changes: 14 additions & 10 deletions merlin/models/torch/outputs/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,22 +49,26 @@ class TabularOutputBlock(RouterBlock):

def __init__(
self,
schema: Schema,
schema: Optional[Schema] = None,
init: Optional[Union[str, Initializer]] = None,
selection: Optional[Selection] = Tags.TARGET,
):
if selection:
schema = select(schema, selection)

self.selection = selection
self.init = init
super().__init__(schema, prepend_routing_module=False)

def setup_schema(self, schema: Schema):
if self.selection:
schema = select(schema, self.selection)
super().setup_schema(schema)
self.schema: Schema = self.selectable.schema
if init:
if isinstance(init, str):
init = self.initializers.get(init)
if not init:
raise ValueError(f"Initializer {init} not found.")
if self.init:
if isinstance(self.init, str):
self.init = self.initializers.get(self.init)
if not self.init:
raise ValueError(f"Initializer {self.init} not found.")

init(self)
self.init(self)

@classmethod
def register_init(cls, name: str):
Expand Down
14 changes: 10 additions & 4 deletions merlin/models/torch/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,16 @@ class RouterBlock(ParallelBlock):

def __init__(self, selectable: schema.Selectable, prepend_routing_module: bool = True):
super().__init__()
self.prepend_routing_module = prepend_routing_module
if isinstance(selectable, Schema):
from merlin.models.torch.inputs.select import SelectKeys
self.setup_schema(selectable)
else:
self.selectable: schema.Selectable = selectable

selectable = SelectKeys(selectable)
def setup_schema(self, schema: Schema):
from merlin.models.torch.inputs.select import SelectKeys

self.selectable: schema.Selectable = selectable
self.prepend_routing_module = prepend_routing_module
self.selectable = SelectKeys(schema)

def add_route(
self,
Expand Down Expand Up @@ -88,6 +91,9 @@ def add_route(
The router block with the new route added.
"""

if self.selectable is None:
raise ValueError(f"{self} has nothing to select from, so cannot add route.")

routing_module = schema.select(self.selectable, selection)
if not routing_module:
return self
Expand Down
18 changes: 18 additions & 0 deletions tests/unit/torch/inputs/test_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,21 @@ def test_extract_double_nesting(self):
no_user_id, user_id_route = mm.schema.extract(input_block, Tags.USER_ID)

assert no_user_id

def test_nesting(self):
    """Nest a schema-less ``TabularInputBlock`` inside one that has a schema.

    The inner block is built with only ``init="defaults"`` and is attached
    through ``add_route`` with an identity selection. Every continuous and
    categorical column of ``self.schema`` must then show up in the outputs,
    and each categorical output must have shape ``(10, embedding_dim)``.
    """
    outer = mm.TabularInputBlock(self.schema)
    nested = mm.TabularInputBlock(init="defaults")
    outer.add_route(lambda schema: schema, nested)

    result = module_utils.module_test(outer, self.batch)

    continuous = mm.schema.select(self.schema, Tags.CONTINUOUS)
    for column in continuous.column_names:
        assert column in result

    categorical = mm.schema.select(self.schema, Tags.CATEGORICAL)
    for column in categorical.column_names:
        assert column in result
        # The leading 10 is presumably the batch size of self.batch -- confirm
        # against the fixture, which is not visible here.
        embedding_dim = infer_embedding_dim(self.schema.select_by_name(column))
        assert result[column].shape == (10, embedding_dim)
10 changes: 10 additions & 0 deletions tests/unit/torch/outputs/test_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,13 @@ def test_no_route_for_non_existent_tag(self):
outputs.add_route(Tags.CATEGORICAL)

assert not outputs

def test_nesting(self):
    """Nest a schema-less ``TabularOutputBlock`` under the ``Tags.TARGET`` route.

    The inner block is created with only ``init="defaults"``. After running
    the outer block on a random ``(10, 10)`` tensor, the three target names
    checked below must all be present in the outputs.
    """
    outer = mm.TabularOutputBlock(self.schema)
    nested = mm.TabularOutputBlock(init="defaults")
    outer.add_route(Tags.TARGET, nested)

    result = module_utils.module_test(outer, torch.rand(10, 10))

    for target in ("play_percentage", "click", "like"):
        assert target in result
9 changes: 9 additions & 0 deletions tests/unit/torch/test_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,12 @@ def test_nested(self):
outputs = module_utils.module_test(nested, self.batch.features)
assert list(outputs.keys()) == ["user_age"]
assert "user_age" in mm.output_schema(nested).column_names

def test_exceptions(self):
    """Check two configurations in which ``add_route`` must raise ``ValueError``.

    1. A router built from ``None`` has nothing to select from.
    2. A router built from a schema with ``prepend_routing_module=False``
       is also expected to reject ``add_route(Tags.CONTINUOUS)``. The code
       path that raises here is not visible in this view.
    """
    without_selectable = mm.RouterBlock(None)
    with pytest.raises(ValueError):
        without_selectable.add_route(Tags.CONTINUOUS)

    without_routing_module = mm.RouterBlock(self.schema, prepend_routing_module=False)
    with pytest.raises(ValueError):
        without_routing_module.add_route(Tags.CONTINUOUS)

0 comments on commit 58ce8c1

Please sign in to comment.