From f0f52913f6292bf6bd753b03ab5346aba6a5dbd8 Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Tue, 11 Jul 2023 16:38:35 +0200 Subject: [PATCH] Apply rename to initialize_from_schema to ContrastiveOutput --- merlin/models/torch/outputs/contrastive.py | 4 ++-- tests/unit/torch/outputs/test_constrastive.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/merlin/models/torch/outputs/contrastive.py b/merlin/models/torch/outputs/contrastive.py index 25a386e41e..733a3d69b4 100644 --- a/merlin/models/torch/outputs/contrastive.py +++ b/merlin/models/torch/outputs/contrastive.py @@ -74,7 +74,7 @@ def __init__( ) if schema: - self.setup_schema(schema) + self.initialize_from_schema(schema) self.init_hook_handle = self.register_forward_pre_hook(self.initialize) if not torch.jit.is_scripting(): @@ -121,7 +121,7 @@ def tie_weights( return self - def setup_schema(self, target: Union[ColumnSchema, Schema]): + def initialize_from_schema(self, target: Union[ColumnSchema, Schema]): """Set up the schema for the output. Parameters diff --git a/tests/unit/torch/outputs/test_constrastive.py b/tests/unit/torch/outputs/test_constrastive.py index aa0c1ba8bf..c616978f68 100644 --- a/tests/unit/torch/outputs/test_constrastive.py +++ b/tests/unit/torch/outputs/test_constrastive.py @@ -13,7 +13,7 @@ class TestContrastiveOutput: - def test_setup_schema(self, item_id_col_schema, user_id_col_schema): + def test_initialize_from_schema(self, item_id_col_schema, user_id_col_schema): contrastive = ContrastiveOutput() dot = ContrastiveOutput(schema=Schema([item_id_col_schema, user_id_col_schema])) @@ -23,10 +23,10 @@ def test_setup_schema(self, item_id_col_schema, user_id_col_schema): assert isinstance(target.to_call, CategoricalTarget) with pytest.raises(ValueError): - contrastive.setup_schema(1) + contrastive.initialize_from_schema(1) with pytest.raises(ValueError): - contrastive.setup_schema(Schema(["a", "b", "c"])) + contrastive.initialize_from_schema(Schema(["a", "b", "c"])) def test_outputs_without_downscore(self, item_id_col_schema): contrastive = ContrastiveOutput(item_id_col_schema, downscore_false_negatives=False)