-
Notifications
You must be signed in to change notification settings - Fork 50
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding ContrastiveOutput #1191
Adding ContrastiveOutput #1191
Conversation
Documentation preview |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks good to me. Good job!
assert isinstance(dot.to_call, DotProduct) | ||
|
||
target = ContrastiveOutput(schema=Schema([item_id_col_schema])) | ||
assert isinstance(target.to_call, CategoricalTarget) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is neat! Same interface for contrastive for two-tower like architectures and for sampled softmax.
from merlin.models.torch.block import registry | ||
|
||
|
||
class LogUniformSampler(object): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@marcromeyn Have you checked the LogUniformSampler class I created in T4Rec based on this one?
I adds some additional options like returning unique samples and also provides the probs for the items for logQ correction (considering whether returning unique samples or popularity biased).
It matches the implementation of sampling and probs from tf.random.log_uniform_candidate_sampler
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did not, yours looks much better! I can port that
self.false_negative_score = false_negative_score | ||
|
||
@classmethod | ||
def with_weight_tying( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great syntax sugar!
Goals ⚽
This PR builds on the
CategoricalOutput
and addsContrastiveOutput
. This can be used with dot-product, categorical-prediction or weight-tying.Implementation Details 🚧
The contrastive part of
CategoricalOutput
doesn't work with torch-script, this is fine since it's during training only.