Adding MMOE & PLE #1173
Conversation
The API is very clean. I debugged the tests to better understand the implementation, and it seems OK. I've just added some optional suggestions.
```python
def test_init_with_outputs(self):
    outputs = mm.ParallelBlock({"a": mm.BinaryOutput(), "b": mm.BinaryOutput()})
    outputs.prepend_for_each(mm.MLPBlock([2]))
    outputs.prepend(MMOEBlock(mm.MLPBlock([2, 2]), 2, outputs))
```
The API is very clean! But building MTL models from the output to the input works, though it might be counter-intuitive. Will that be the pattern for MTL models?
That said, I understand the challenges of building such MMOE models from the inputs: the gates depend on the number of experts, and the number of gates and towers depends on the number of outputs.
Yeah, it might make sense to add an `MMOEOutputs` or something that does this under the hood.
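For concreteness, a minimal sketch of what such a helper could look like, assuming the `ParallelBlock` / `prepend` / `prepend_for_each` / `MMOEBlock` API from the test above; the name `MMOEOutputs`, its signature, and the import path are hypothetical:

```python
import merlin.models.torch as mm  # assumed import path for this PR's PyTorch API

def MMOEOutputs(expert_block, num_experts, tower_block, outputs):
    """Hypothetical helper: build the MMOE head from the task outputs under the hood."""
    block = mm.ParallelBlock(outputs)
    # One tower per task; the test suggests the block is copied per branch.
    block.prepend_for_each(tower_block)
    # Experts + per-task gates; MMOEBlock reads the number of gates from `block`.
    block.prepend(mm.MMOEBlock(expert_block, num_experts, block))
    return block

# Usage, mirroring the test:
head = MMOEOutputs(
    expert_block=mm.MLPBlock([2, 2]),
    num_experts=2,
    tower_block=mm.MLPBlock([2]),
    outputs={"a": mm.BinaryOutput(), "b": mm.BinaryOutput()},
)
```

This would keep the output-to-input wiring as an internal detail, so users only state the experts, the towers, and the task outputs.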
```python
outputs : ParallelBlock
    The output block.
shared_gate : bool, optional
    If true, use a shared gate for all tasks. Defaults to False.
```
I think this argument name and its docstring are misleading. When `num_shared_experts > 0`, all task gates will use the shared experts, right? This `shared_gate` argument seems instead to be responsible for making the output of the shared experts (the `shortcut` and `experts` keys) available from one CGCBlock to the next in the PLE architecture. If that is the case, we could try to clarify that.
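If that reading is right, the key flow could be illustrated with a toy stand-in (this is not the PR's `CGCBlock`; it fakes the expert math and only shows which dictionary keys each layer emits):

```python
import torch

def toy_cgc(inputs: dict, shared_gate: bool) -> dict:
    """Toy stand-in for a CGC layer: shows only how the "experts"/"shortcut"
    keys could flow between stacked layers, not the real gating math."""
    # Per-task outputs (placeholder transform instead of real expert mixing).
    out = {k: v + 1.0 for k, v in inputs.items() if k not in ("experts", "shortcut")}
    if shared_gate:
        # Re-emit the shared-expert output so the next layer can consume it.
        out["experts"] = inputs["experts"] * 0.5
        out["shortcut"] = inputs["shortcut"]
    return out

x = torch.ones(2, 4)
state = {"a": x, "b": x, "experts": x, "shortcut": x}
state = toy_cgc(state, shared_gate=True)    # intermediate PLE layer
state = toy_cgc(state, shared_gate=False)   # final layer: task outputs only
print(sorted(state))  # ['a', 'b']
```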
"outputs": outputs, | ||
} | ||
super().__init__(*CGCBlock(shared_gate=True, **cgc_kwargs).repeat(depth - 1)) | ||
self.append(CGCBlock(**cgc_kwargs)) |
Good trick to avoid outputting the shared experts in the last layer.
Goals ⚽
This PR introduces mixture-of-experts (MMOE) and PLE/CGC blocks. With these we should be able to write a PyTorch version of the multi-task blogpost.
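As a rough usage sketch of the new blocks, mirroring the test in this conversation (the import path and the `MMOEBlock` export location are assumptions, the dimensions are illustrative, and input/model wiring is omitted):

```python
import merlin.models.torch as mm  # assumed import path
from merlin.models.torch import MMOEBlock  # assumed export location

# Two binary tasks share a mixture of 4 experts; each task gets its own
# gate (inside MMOEBlock) and its own MLP tower.
outputs = mm.ParallelBlock({"click": mm.BinaryOutput(), "like": mm.BinaryOutput()})
outputs.prepend_for_each(mm.MLPBlock([64]))                     # per-task towers
outputs.prepend(MMOEBlock(mm.MLPBlock([128, 64]), 4, outputs))  # experts + gates
```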