
Rename Schema Inspection names in the torch schema module #1179

Merged
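In short, this PR renames the schema-inspection helpers in `merlin.models.torch.schema` — `input` → `input_schema`, `output` → `output_schema`, `features` → `feature_schema`, `targets` → `target_schema` — and re-exports the new names at the top level of `merlin.models.torch`. A minimal usage sketch of the new names (the `MLPBlock` construction and random input are only for illustration; the call patterns mirror the updated tests below):

```python
import torch
import merlin.models.torch as mm

block = mm.MLPBlock([32, 16])
mm.schema.trace(block, torch.randn(1, 3))  # record input/output schemas once

mm.input_schema(block)    # previously mm.schema.input(block)
mm.output_schema(block)   # previously mm.schema.output(block)
mm.feature_schema(block)  # previously mm.schema.features(block)
mm.target_schema(block)   # previously mm.schema.targets(block)
```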
9 changes: 9 additions & 0 deletions merlin/models/torch/__init__.py
@@ -31,6 +31,11 @@
from merlin.models.torch.router import RouterBlock
from merlin.models.torch.transforms.agg import Concat, Stack

input_schema = schema.input_schema
output_schema = schema.output_schema
target_schema = schema.target_schema
feature_schema = schema.feature_schema

__all__ = [
"Batch",
"BinaryOutput",
@@ -55,6 +60,10 @@
"Concat",
"Stack",
"schema",
"input_schema",
"output_schema",
"feature_schema",
"target_schema",
"DLRMBlock",
"DLRMModel",
]
6 changes: 3 additions & 3 deletions merlin/models/torch/batch.py
@@ -375,10 +375,10 @@ def sample_features(
return sample_batch(data, batch_size, shuffle).features


@schema.output.register_tensor(Batch)
@schema.output_schema.register_tensor(Batch)
def _(input):
output_schema = Schema()
output_schema += schema.output.tensors(input.features)
output_schema += schema.output.tensors(input.targets)
output_schema += schema.output_schema.tensors(input.features)
output_schema += schema.output_schema.tensors(input.targets)

return output_schema
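Reviewer note: with the `Batch` registration above, the output-schema dispatcher can be asked for the schema of an entire batch, combining the schemas of `features` and `targets`. A hedged sketch of that call — the `Batch` constructor arguments and the expected result are assumptions inferred from this diff, not verified output:

```python
import torch
import merlin.models.torch as mm

batch = mm.Batch(
    features={"user_id": torch.tensor([[1], [2]])},
    targets={"click": torch.tensor([[1.0], [0.0]])},
)

# Dispatches to the Batch registration above, which sums the schemas
# derived from batch.features and batch.targets.
combined = mm.schema.output_schema.tensors(batch)
print(combined.column_names)  # expected to contain "user_id" and "click"
```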
26 changes: 13 additions & 13 deletions merlin/models/torch/block.py
@@ -588,31 +588,31 @@ def set_pre(module: nn.Module, pre: BlockContainer):
return set_pre(module[0], pre)


@schema.input.register(BlockContainer)
@schema.input_schema.register(BlockContainer)
def _(module: BlockContainer, input: Schema):
return schema.input(module[0], input) if module else input
return schema.input_schema(module[0], input) if module else input


@schema.input.register(ParallelBlock)
@schema.input_schema.register(ParallelBlock)
def _(module: ParallelBlock, input: Schema):
if module.pre:
return schema.input(module.pre)
return schema.input_schema(module.pre)

out_schema = Schema()
for branch in module.branches.values():
out_schema += schema.input(branch, input)
out_schema += schema.input_schema(branch, input)

return out_schema


@schema.output.register(ParallelBlock)
@schema.output_schema.register(ParallelBlock)
def _(module: ParallelBlock, input: Schema):
if module.post:
return schema.output(module.post, input)
return schema.output_schema(module.post, input)

output = Schema()
for name, branch in module.branches.items():
branch_schema = schema.output(branch, input)
branch_schema = schema.output_schema(branch, input)

if len(branch_schema) == 1 and branch_schema.first.name == "output":
branch_schema = Schema([branch_schema.first.with_name(name)])
@@ -622,9 +622,9 @@ def _(module: ParallelBlock, input: Schema):
return output


@schema.output.register(BlockContainer)
@schema.output_schema.register(BlockContainer)
def _(module: BlockContainer, input: Schema):
return schema.output(module[-1], input) if module else input
return schema.output_schema(module[-1], input) if module else input


BlockT = TypeVar("BlockT", bound=BlockContainer)
@@ -720,13 +720,13 @@ def _extract_block(main, selection, route, name=None):
if isinstance(main, ParallelBlock):
return _extract_parallel(main, selection, route=route, name=name)

main_schema = schema.input(main)
route_schema = schema.input(route)
main_schema = schema.input_schema(main)
route_schema = schema.input_schema(route)

if main_schema == route_schema:
from merlin.models.torch.inputs.select import SelectFeatures

out_schema = schema.output(main, main_schema)
out_schema = schema.output_schema(main, main_schema)
if len(out_schema) == 1 and out_schema.first.name == "output":
out_schema = Schema([out_schema.first.with_name(name)])

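The `ParallelBlock` rules above compose per-branch schemas: the input schema is the union of the branch input schemas, and single-column branch outputs are renamed to the branch name. A small sketch of the resulting behaviour, mirroring `test_input_schema_pre` below (`PlusOne` is a toy module defined here for illustration):

```python
import torch
from torch import nn

import merlin.models.torch as mm
from merlin.models.torch.block import ParallelBlock

class PlusOne(nn.Module):
    """Toy module used only for this illustration."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

pb = ParallelBlock({"a": PlusOne(), "b": PlusOne()})
mm.schema.trace(pb, torch.randn(1, 3))

print(mm.input_schema(pb).column_names)   # expected: ["input"]
print(mm.output_schema(pb).column_names)  # expected: ["a", "b"], one column per branch
```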
12 changes: 6 additions & 6 deletions merlin/models/torch/blocks/mlp.py
@@ -4,7 +4,7 @@
from torch import nn

from merlin.models.torch.block import Block
from merlin.models.torch.schema import Schema, output
from merlin.models.torch.schema import Schema, output_schema
from merlin.models.torch.transforms.agg import Concat, MaybeAgg


@@ -84,8 +84,8 @@ def __init__(
super().__init__(*modules)


@output.register(nn.LazyLinear)
@output.register(nn.Linear)
@output.register(MLPBlock)
def _output_schema_block(module: nn.LazyLinear, input: Schema):
return output.tensors(torch.ones((1, module.out_features), dtype=float))
@output_schema.register(nn.LazyLinear)
@output_schema.register(nn.Linear)
@output_schema.register(MLPBlock)
def _output_schema_block(module: nn.LazyLinear, inputs: Schema):
return output_schema.tensors(torch.ones((1, module.out_features), dtype=float))
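Because `nn.Linear`, `nn.LazyLinear`, and `MLPBlock` are registered with the output-schema dispatcher directly, their output schema can be derived from `out_features` without a prior trace. A sketch of what that looks like — the assumption that no trace is needed, and the resulting column name, are inferred from the registrations in this diff:

```python
from torch import nn
import merlin.models.torch as mm

linear = nn.Linear(8, 4)
out = mm.output_schema(linear)  # resolved via the nn.Linear registration above
print(out.column_names)         # expected: ["output"]
```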
4 changes: 2 additions & 2 deletions merlin/models/torch/inputs/select.py
@@ -201,8 +201,8 @@ def forward(self, inputs, batch: Batch) -> Dict[str, torch.Tensor]:

@schema.extract.register(SelectKeys)
def _(main, selection, route, name=None):
main_schema = schema.input(main)
route_schema = schema.input(route)
main_schema = schema.input_schema(main)
route_schema = schema.input_schema(route)

diff = main_schema.excluding_by_name(route_schema.column_names)

46 changes: 23 additions & 23 deletions merlin/models/torch/schema.py
@@ -97,7 +97,7 @@ def __call__(self, module: nn.Module, inputs: Optional[Schema] = None) -> Schema
return super().__call__(module, inputs)
except NotImplementedError:
raise ValueError(
f"Could not get output schema of {module} " "please call mm.trace_schema first."
f"Could not get output schema of {module} " "please call `mm.schema.trace` first."
)

def trace(
@@ -127,7 +127,7 @@ def _func(module: nn.Module, input: Schema) -> Schema:

def __call__(self, module: nn.Module, inputs: Optional[Schema] = None) -> Schema:
try:
_inputs = input(module)
_inputs = input_schema(module)
inputs = _inputs
except ValueError:
pass
@@ -156,7 +156,7 @@ def __call__(self, module: nn.Module, inputs: Optional[Schema] = None) -> Schema
return super().__call__(module, inputs)
except NotImplementedError:
raise ValueError(
f"Could not get output schema of {module} " "please call mm.trace_schema first."
f"Could not get output schema of {module} " "please call `mm.schema.trace` first."
)

def trace(
@@ -165,7 +165,7 @@ def trace(
inputs: Union[torch.Tensor, Dict[str, torch.Tensor], Schema],
outputs: Union[torch.Tensor, Dict[str, torch.Tensor], Schema],
) -> Schema:
_input_schema = input.get_schema(inputs)
_input_schema = input_schema.get_schema(inputs)
_output_schema = self.get_schema(outputs)

try:
@@ -207,8 +207,8 @@ def extract(self, module: nn.Module, selection: Selection, route: nn.Module, nam
return fn(module, selection, route, name=name)


input = _InputSchemaDispatch("input_schema")
output = _OutputSchemaDispatch("output_schema")
input_schema = _InputSchemaDispatch("input_schema")
output_schema = _OutputSchemaDispatch("output_schema")
select = _SelectDispatch("selection")
extract = _ExtractDispatch("extract")

@@ -240,13 +240,13 @@ def _hook(mod: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor):
mod.__input_schemas = ()
mod.__output_schemas = ()

_input_schema = input.trace(mod, inputs[0])
_input_schema = input_schema.trace(mod, inputs[0])
if _input_schema not in mod.__input_schemas:
mod.__input_schemas += (_input_schema,)
mod.__output_schemas += (output.trace(mod, _input_schema, outputs),)
mod.__output_schemas += (output_schema.trace(mod, _input_schema, outputs),)

def add_hook(m):
custom_modules = list(output.dispatcher.registry.keys())
custom_modules = list(output_schema.dispatcher.registry.keys())
if m and isinstance(m, tuple(custom_modules[1:])):
return

@@ -261,7 +261,7 @@ def add_hook(m):
return module_out


def features(module: nn.Module) -> Schema:
def feature_schema(module: nn.Module) -> Schema:
"""Extract the feature schema from a PyTorch Module.

This function operates by applying the `get_feature_schema` method
@@ -293,7 +293,7 @@ def get_feature_schema(module):
return feature_schema


def targets(module: nn.Module) -> Schema:
def target_schema(module: nn.Module) -> Schema:
"""
Extract the target schema from a PyTorch Module.

@@ -484,7 +484,7 @@ def select(self, selection: Selection) -> "Selectable":
raise NotImplementedError()


@output.register_tensor(torch.Tensor)
@output_schema.register_tensor(torch.Tensor)
def _tensor_to_schema(input, name="output"):
kwargs = dict(dims=input.shape[1:], dtype=input.dtype)

@@ -494,13 +494,13 @@ def _tensor_to_schema(input, name="output"):
return Schema([ColumnSchema(name, **kwargs)])


@input.register_tensor(torch.Tensor)
@input_schema.register_tensor(torch.Tensor)
def _(input):
return _tensor_to_schema(input, "input")


@input.register_tensor(Dict[str, torch.Tensor])
@output.register_tensor(Dict[str, torch.Tensor])
@input_schema.register_tensor(Dict[str, torch.Tensor])
@output_schema.register_tensor(Dict[str, torch.Tensor])
def _(input):
output = Schema()
for k, v in sorted(input.items()):
@@ -509,14 +509,14 @@ def _(input):
return output


@input.register_tensor(Tuple[torch.Tensor])
@output.register_tensor(Tuple[torch.Tensor])
@input.register_tensor(Tuple[torch.Tensor, torch.Tensor])
@output.register_tensor(Tuple[torch.Tensor, torch.Tensor])
@input.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor])
@output.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor])
@input.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor])
@output.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor])
@input_schema.register_tensor(Tuple[torch.Tensor])
@output_schema.register_tensor(Tuple[torch.Tensor])
@input_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor])
@output_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor])
@input_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor])
@output_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor])
@input_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor])
@output_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor])
def _(input):
output = Schema()

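The renamed dispatchers stay extensible from user code: registering a rule for a custom module type follows the same pattern as the `BlockContainer`/`ParallelBlock` registrations in `block.py`. A hypothetical sketch — `AddOne` is an invented module, not part of the library:

```python
import torch
from torch import nn

from merlin.models.torch.schema import Schema, output_schema

class AddOne(nn.Module):
    """Toy shape- and dtype-preserving module used only for this illustration."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

@output_schema.register(AddOne)
def _(module: AddOne, inputs: Schema) -> Schema:
    # AddOne does not change shape or dtype, so its output schema is
    # simply the input schema it receives.
    return inputs
```

With such a rule in place, `output_schema(module, some_schema)` resolves through the registration rather than requiring a forward-pass trace.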
4 changes: 2 additions & 2 deletions tests/unit/torch/inputs/test_select.py
@@ -72,8 +72,8 @@ def test_forward(self):

outputs = mm.schema.trace(block, self.batch.features["session_id"], batch=self.batch)
assert len(outputs) == 5
assert mm.schema.input(block).column_names == ["input"]
assert mm.schema.features(block).column_names == [
assert mm.input_schema(block).column_names == ["input"]
assert mm.feature_schema(block).column_names == [
"user_id",
"country",
"user_age",
14 changes: 7 additions & 7 deletions tests/unit/torch/inputs/test_tabular.py
@@ -68,27 +68,27 @@ def test_extract_route_two_tower(self):
"item_recency",
"item_genres",
}
assert set(mm.schema.input(towers).column_names) == input_cols
assert mm.schema.output(towers).column_names == ["user", "item"]
assert set(mm.input_schema(towers).column_names) == input_cols
assert mm.output_schema(towers).column_names == ["user", "item"]

categorical = towers.select(Tags.CATEGORICAL)
outputs = module_utils.module_test(towers, self.batch)

assert mm.schema.extract(towers, Tags.CATEGORICAL)[1] == categorical
assert set(mm.schema.input(towers).column_names) == input_cols
assert mm.schema.output(towers).column_names == ["user", "item"]
assert set(mm.input_schema(towers).column_names) == input_cols
assert mm.output_schema(towers).column_names == ["user", "item"]

outputs = towers(self.batch.features)
assert outputs["user"].shape == (10, 10)
assert outputs["item"].shape == (10, 10)

new_inputs, route = mm.schema.extract(towers, Tags.USER)
assert mm.schema.output(new_inputs).column_names == ["user", "item"]
assert mm.output_schema(new_inputs).column_names == ["user", "item"]

assert "user" in new_inputs.branches
assert new_inputs.branches["user"][0].select_keys.column_names == ["user"]
assert "user" in route.branches
assert mm.schema.output(route).select_by_tag(Tags.EMBEDDING).column_names == ["user"]
assert mm.output_schema(route).select_by_tag(Tags.EMBEDDING).column_names == ["user"]

def test_extract_route_embeddings(self):
input_block = mm.TabularInputBlock(self.schema, init="defaults", agg="concat")
@@ -97,7 +97,7 @@ def test_extract_route_embeddings(self):
assert outputs.shape == (10, 107)

no_embs, emb_route = mm.schema.extract(input_block, Tags.CATEGORICAL)
output_schema = mm.schema.output(emb_route)
output_schema = mm.output_schema(emb_route)

assert len(output_schema.select_by_tag(Tags.USER)) == 3
assert len(output_schema.select_by_tag(Tags.ITEM)) == 3
4 changes: 2 additions & 2 deletions tests/unit/torch/models/test_base.py
@@ -191,15 +191,15 @@ def test_output_schema(self):
"b": torch.tensor([[5.0, 6.0], [7.0, 8.0]]),
}
outputs = mm.schema.trace(model, inputs)
schema = mm.schema.output(model)
schema = mm.output_schema(model)
for name in outputs:
assert name in schema.column_names
assert schema[name].dtype.name == str(outputs[name].dtype).split(".")[-1]

def test_no_output_schema(self):
model = mm.Model(PlusOne())
with pytest.raises(ValueError, match="Could not get output schema of PlusOne()"):
mm.schema.output(model)
mm.output_schema(model)

def test_train_classification_with_lightning_trainer(self, music_streaming_data, batch_size=16):
schema = music_streaming_data.schema.select_by_name(
12 changes: 6 additions & 6 deletions tests/unit/torch/test_block.py
@@ -57,7 +57,7 @@ def test_identity(self):
outputs = module_utils.module_test(block, inputs, batch=Batch(inputs))

assert torch.equal(inputs, outputs)
assert mm.schema.output(block) == mm.schema.output.tensors(inputs)
assert mm.output_schema(block) == mm.output_schema.tensors(inputs)

def test_insertion(self):
block = Block()
@@ -158,7 +158,7 @@ def test_schema_tracking(self):

inputs = torch.randn(1, 3)
outputs = mm.schema.trace(pb, inputs)
schema = mm.schema.output(pb)
schema = mm.output_schema(pb)

for name in outputs:
assert name in schema.column_names
@@ -258,9 +258,9 @@ def test_set_pre(self):
def test_input_schema_pre(self):
pb = ParallelBlock({"a": PlusOne(), "b": PlusOne()})
outputs = mm.schema.trace(pb, torch.randn(1, 3))
input_schema = mm.schema.input(pb)
input_schema = mm.input_schema(pb)
assert len(input_schema) == 1
assert len(mm.schema.output(pb)) == 2
assert len(mm.output_schema(pb)) == 2
assert len(outputs) == 2

pb2 = ParallelBlock({"a": PlusOne(), "b": PlusOne()})
@@ -270,8 +270,8 @@ def test_input_schema_pre(self):
assert get_pre(pb2)[0] == pb
pb2.append(pb)

assert input_schema == mm.schema.input(pb2)
assert mm.schema.output(pb2) == mm.schema.output(pb)
assert input_schema == mm.input_schema(pb2)
assert mm.output_schema(pb2) == mm.output_schema(pb)

def test_leaf(self):
block = ParallelBlock({"a": PlusOne()})
2 changes: 1 addition & 1 deletion tests/unit/torch/test_router.py
@@ -162,4 +162,4 @@ def test_nested(self):

outputs = module_utils.module_test(nested, self.batch.features)
assert list(outputs.keys()) == ["user_age"]
assert "user_age" in mm.schema.output(nested).column_names
assert "user_age" in mm.output_schema(nested).column_names