[Feature]: Adding a Deep Nearest Class Means Classifier model to Flair #3531

Open
sheldon-roberts opened this issue Aug 19, 2024 · 0 comments
Labels
feature A new feature

Comments

@sheldon-roberts
Contributor

Problem statement

Flair has a decoder (PrototypicalDecoder) inspired by the paper Prototypical Networks for Few-shot Learning, but there is a notable difference in how the prototypes are being calculated.

The original paper states that we should "take a class’s prototype to be the mean of its support set in the embedding space". The PrototypicalDecoder, however, treats class prototypes as learnable model parameters that are simply updated during backpropagation. This approach has some drawbacks:

  1. It compromises important theoretical properties, such as the equivalence to mixture density estimation on the support set.
  2. Randomly initializing prototypes fails to leverage the knowledge captured in pre-trained embeddings, potentially slowing down convergence and reducing performance.
  3. It performs poorly on few-shot classification and incremental learning tasks compared to the class-means approach.
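
To make the class-means idea concrete, here is a minimal, framework-free sketch of computing each prototype as the mean of its support embeddings and classifying a query by its nearest prototype. The function names (`class_means`, `nearest_class`) and the toy 2-D vectors are hypothetical, not part of Flair; in practice the embeddings would come from a Flair embedding model.

```python
import math
from collections import defaultdict


def class_means(embeddings, labels):
    """Return {label: mean embedding} computed over the support set."""
    sums = {}
    counts = defaultdict(int)
    for vec, label in zip(embeddings, labels):
        if label not in sums:
            sums[label] = [0.0] * len(vec)
        # Accumulate the per-dimension sum for this class.
        sums[label] = [s + v for s, v in zip(sums[label], vec)]
        counts[label] += 1
    return {
        label: [s / counts[label] for s in total]
        for label, total in sums.items()
    }


def nearest_class(query, prototypes):
    """Return the label whose prototype is closest in Euclidean distance."""
    return min(prototypes, key=lambda label: math.dist(query, prototypes[label]))


# Toy support set (assumed values, for illustration only).
support = [[0.0, 0.0], [2.0, 0.0], [10.0, 10.0], [12.0, 10.0]]
support_labels = ["A", "A", "B", "B"]

prototypes = class_means(support, support_labels)
predicted = nearest_class([1.5, 0.5], prototypes)
```

Here the prototypes depend only on the embeddings, so nothing has to be randomly initialized or learned for them.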

TL;DR
It would be nice to have a model in Flair that uses prototypes such that each class is represented by the mean of its examples.

Solution

I want to add a model to Flair that uses a class-mean update rule like a Prototypical Network, but forgoes the episodic training in order to remain compatible with the existing model trainers. The paper "Deep Nearest Class Mean Classifiers" defines update rules that do exactly that. I have experimented with this approach and produced models that significantly outperform the PrototypicalDecoder on certain few-shot training tasks.
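
As a rough illustration of maintaining class means across ordinary mini-batches rather than episodes, the sketch below keeps a cumulative running mean per class. This is one simple update rule chosen for illustration, not necessarily the exact rule from the paper or the one used in my experiments; `RunningClassMeans` is a hypothetical name, and the sketch ignores the fact that embeddings drift as the encoder is trained.

```python
class RunningClassMeans:
    """Keeps one mean vector per class, updated batch by batch."""

    def __init__(self):
        self.means = {}   # label -> current mean embedding
        self.counts = {}  # label -> number of examples seen so far

    def update(self, embeddings, labels):
        """Fold a mini-batch of (embedding, label) pairs into the means."""
        for vec, label in zip(embeddings, labels):
            if label not in self.means:
                self.means[label] = [0.0] * len(vec)
                self.counts[label] = 0
            self.counts[label] += 1
            n = self.counts[label]
            # Incremental mean: new_mean = old_mean + (x - old_mean) / n
            self.means[label] = [
                m + (v - m) / n for m, v in zip(self.means[label], vec)
            ]


# Two mini-batches with assumed toy embeddings.
tracker = RunningClassMeans()
tracker.update([[0.0, 0.0], [10.0, 10.0]], ["A", "B"])
tracker.update([[2.0, 0.0], [12.0, 10.0]], ["A", "B"])
```

Because the update only needs the current batch, it fits a standard training loop without constructing support and query episodes.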

This model has been very useful for me, and I'm sure others would benefit too if it were supported by Flair.

Additional Context

No response
