From 58ce8c15fef5de2e034bb6619201049285b1914d Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Thu, 6 Jul 2023 13:05:31 +0200 Subject: [PATCH] Making input- and output-block lazy to allow for nesting (#1184) * Making schema optional in TabularInputBlock to allow for nesting * Adding tests to ensure correct behaviour * Fixing failing test --------- Co-authored-by: edknv <109497216+edknv@users.noreply.github.com> --- merlin/models/torch/inputs/tabular.py | 25 ++++++++++++++---------- merlin/models/torch/outputs/tabular.py | 24 +++++++++++++---------- merlin/models/torch/router.py | 14 +++++++++---- tests/unit/torch/inputs/test_tabular.py | 18 +++++++++++++++++ tests/unit/torch/outputs/test_tabular.py | 10 ++++++++++ tests/unit/torch/test_router.py | 9 +++++++++ 6 files changed, 76 insertions(+), 24 deletions(-) diff --git a/merlin/models/torch/inputs/tabular.py b/merlin/models/torch/inputs/tabular.py index 0d73656e20..dc6ae4473e 100644 --- a/merlin/models/torch/inputs/tabular.py +++ b/merlin/models/torch/inputs/tabular.py @@ -51,21 +51,26 @@ class TabularInputBlock(RouterBlock): def __init__( self, - schema: Schema, + schema: Optional[Schema] = None, init: Optional[Union[str, Initializer]] = None, agg: Optional[Union[str, nn.Module]] = None, ): + self.init = init + self.agg = agg super().__init__(schema) + + def setup_schema(self, schema: Schema): + super().setup_schema(schema) self.schema: Schema = self.selectable.schema - if init: - if isinstance(init, str): - init = self.initializers.get(init) - if not init: - raise ValueError(f"Initializer {init} not found.") - - init(self) - if agg: - self.append(Block.parse(agg)) + if self.init: + if isinstance(self.init, str): + self.init = self.initializers.get(self.init) + if not self.init: + raise ValueError(f"Initializer {self.init} not found.") + + self.init(self) + if self.agg: + self.append(Block.parse(self.agg)) @classmethod def register_init(cls, name: str): diff --git a/merlin/models/torch/outputs/tabular.py b/merlin/models/torch/outputs/tabular.py index bab73d9be8..dfc74eb059 100644 --- a/merlin/models/torch/outputs/tabular.py +++ b/merlin/models/torch/outputs/tabular.py @@ -49,22 +49,26 @@ class TabularOutputBlock(RouterBlock): def __init__( self, - schema: Schema, + schema: Optional[Schema] = None, init: Optional[Union[str, Initializer]] = None, selection: Optional[Selection] = Tags.TARGET, ): - if selection: - schema = select(schema, selection) - + self.selection = selection + self.init = init super().__init__(schema, prepend_routing_module=False) + + def setup_schema(self, schema: Schema): + if self.selection: + schema = select(schema, self.selection) + super().setup_schema(schema) self.schema: Schema = self.selectable.schema - if init: - if isinstance(init, str): - init = self.initializers.get(init) - if not init: - raise ValueError(f"Initializer {init} not found.") + if self.init: + if isinstance(self.init, str): + self.init = self.initializers.get(self.init) + if not self.init: + raise ValueError(f"Initializer {self.init} not found.") - init(self) + self.init(self) @classmethod def register_init(cls, name: str): diff --git a/merlin/models/torch/router.py b/merlin/models/torch/router.py index 087068b6e0..bfa7cbc1ec 100644 --- a/merlin/models/torch/router.py +++ b/merlin/models/torch/router.py @@ -49,13 +49,16 @@ class RouterBlock(ParallelBlock): def __init__(self, selectable: schema.Selectable, prepend_routing_module: bool = True): super().__init__() + self.prepend_routing_module = prepend_routing_module if isinstance(selectable, Schema): - from merlin.models.torch.inputs.select import SelectKeys + self.setup_schema(selectable) + else: + self.selectable: schema.Selectable = selectable - selectable = SelectKeys(selectable) + def setup_schema(self, schema: Schema): + from merlin.models.torch.inputs.select import SelectKeys - self.selectable: schema.Selectable = selectable - self.prepend_routing_module = prepend_routing_module + self.selectable = SelectKeys(schema) def add_route( self, @@ -88,6 +91,9 @@ def add_route( The router block with the new route added. """ + if self.selectable is None: + raise ValueError(f"{self} has nothing to select from, so cannot add route.") + routing_module = schema.select(self.selectable, selection) if not routing_module: return self diff --git a/tests/unit/torch/inputs/test_tabular.py b/tests/unit/torch/inputs/test_tabular.py index e81fe44ce4..da8acc90b4 100644 --- a/tests/unit/torch/inputs/test_tabular.py +++ b/tests/unit/torch/inputs/test_tabular.py @@ -131,3 +131,21 @@ def test_extract_double_nesting(self): no_user_id, user_id_route = mm.schema.extract(input_block, Tags.USER_ID) assert no_user_id + + def test_nesting(self): + input_block = mm.TabularInputBlock(self.schema) + input_block.add_route( + lambda schema: schema, + mm.TabularInputBlock(init="defaults"), + ) + outputs = module_utils.module_test(input_block, self.batch) + + for name in mm.schema.select(self.schema, Tags.CONTINUOUS).column_names: + assert name in outputs + + for name in mm.schema.select(self.schema, Tags.CATEGORICAL).column_names: + assert name in outputs + assert outputs[name].shape == ( + 10, + infer_embedding_dim(self.schema.select_by_name(name)), + ) diff --git a/tests/unit/torch/outputs/test_tabular.py b/tests/unit/torch/outputs/test_tabular.py index 3ed04b4725..12b045879a 100644 --- a/tests/unit/torch/outputs/test_tabular.py +++ b/tests/unit/torch/outputs/test_tabular.py @@ -69,3 +69,13 @@ def test_no_route_for_non_existent_tag(self): outputs.add_route(Tags.CATEGORICAL) assert not outputs + + def test_nesting(self): + output_block = mm.TabularOutputBlock(self.schema) + output_block.add_route(Tags.TARGET, mm.TabularOutputBlock(init="defaults")) + + outputs = module_utils.module_test(output_block, torch.rand(10, 10)) + + assert "play_percentage" in outputs + assert "click" in outputs + assert "like" in outputs diff --git a/tests/unit/torch/test_router.py b/tests/unit/torch/test_router.py index 76459ea8db..10e57d7d74 100644 --- a/tests/unit/torch/test_router.py +++ b/tests/unit/torch/test_router.py @@ -163,3 +163,12 @@ def test_nested(self): outputs = module_utils.module_test(nested, self.batch.features) assert list(outputs.keys()) == ["user_age"] assert "user_age" in mm.output_schema(nested).column_names + + def test_exceptions(self): + router = mm.RouterBlock(None) + with pytest.raises(ValueError): + router.add_route(Tags.CONTINUOUS) + + router = mm.RouterBlock(self.schema, prepend_routing_module=False) + with pytest.raises(ValueError): + router.add_route(Tags.CONTINUOUS)