Skip to content

Commit

Permalink
Merge branch 'main' into torch/experts
Browse files Browse the repository at this point in the history
  • Loading branch information
marcromeyn authored Jul 11, 2023
2 parents ee8c419 + 190cd48 commit a3a39a9
Show file tree
Hide file tree
Showing 9 changed files with 1,185 additions and 0 deletions.
15 changes: 15 additions & 0 deletions merlin/models/torch/outputs/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import merlin.dtypes as md
from merlin.models.torch import schema
from merlin.models.torch.batch import Batch
from merlin.models.torch.inputs.embedding import EmbeddingTable
from merlin.models.torch.outputs.base import ModelOutput
from merlin.schema import ColumnSchema, Schema, Tags
Expand Down Expand Up @@ -288,6 +289,12 @@ def embeddings(self) -> nn.Parameter:
"""
return self.linear.weight.t()

def should_apply_contrastive(self, batch: Optional[Batch]) -> bool:
if batch is not None and batch.targets and self.training:
return True

return False


class EmbeddingTablePrediction(nn.Module):
"""Prediction of a categorical feature using weight-sharing [1] with an embedding table.
Expand Down Expand Up @@ -318,6 +325,7 @@ def __init__(self, table: EmbeddingTable, selection: Optional[schema.Selection]
self.num_classes = table.num_embeddings
self.col_schema = table.input_schema.first
self.col_name = self.col_schema.name
self.target_name = self.col_name
self.bias = nn.Parameter(
torch.zeros(self.num_classes, dtype=torch.float32, device=self.embeddings().device)
)
Expand Down Expand Up @@ -368,6 +376,7 @@ def add_selection(self, selection: schema.Selection):
self.col_name = self.col_schema.name
self.num_classes = self.col_schema.int_domain.max + 1
self.output_schema = categorical_output_schema(self.col_schema, self.num_classes)
self.target_name = self.col_name

return self

Expand Down Expand Up @@ -399,6 +408,12 @@ def embedding_lookup(self, inputs: torch.Tensor) -> torch.Tensor:
"""
return self.table({self.col_name: inputs})[self.col_name]

def should_apply_contrastive(self, batch: Optional[Batch]) -> bool:
if batch is not None and batch.targets and self.training:
return True

return False


def categorical_output_schema(target: ColumnSchema, num_classes: int) -> Schema:
"""Return the output schema given the target column schema."""
Expand Down
Loading

0 comments on commit a3a39a9

Please sign in to comment.