
Allow finetuning of mistral models using the HuggingFace Flax LM classes #97

Open
TheodoreGalanos opened this issue Sep 13, 2021 · 2 comments
Labels: first issue (Good first issue for familiarizing yourself with the codebase)

Comments

@TheodoreGalanos

It would be amazing if we could load and finetune the models on TPUs using the Flax LM classes in HF. In my experience, this makes training and generation on TPUs very straightforward, along of course with taking advantage of their compute.

I have tried to load a mistral checkpoint with the following code:
model = FlaxAutoModelForCausalLM.from_pretrained("alias/arwen-x21-checkpoint-400000", from_pt=True, pad_token_id=50256)
This seems to work. The model loads, I can access its properties, and can even generate text.
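For reference, generation worked with roughly the following (the tokenizer choice is my assumption; I just reuse the standard gpt2 tokenizer, since as far as I know the mistral models share the GPT-2 vocabulary):

from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

# load the checkpoint converted from PyTorch, as above
model = FlaxAutoModelForCausalLM.from_pretrained("alias/arwen-x21-checkpoint-400000", from_pt=True, pad_token_id=50256)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# tokenize a prompt as numpy arrays (what the Flax model expects) and sample a continuation
inputs = tokenizer("Hello, my name is", return_tensors="np")
outputs = model.generate(inputs["input_ids"], max_length=32, do_sample=True)
print(tokenizer.decode(outputs.sequences[0], skip_special_tokens=True))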

However, once I try to fine-tune it, using (more or less) the code here: https://github.com/huggingface/transformers/blob/master/examples/flax/language-modeling/run_clm_flax.py, it takes about 10 minutes to compile and then about 5 minutes per step (for reference, these should be roughly 2 minutes and a few seconds, respectively, for gpt2-medium).

Finally, it would be nice if the changes in mistral models were somehow included when loading the model in HF (I am actually not 100% sure that this doesn't already happen). Specifically, I'm thinking of this line here:

scale_factor = 1 / ((float(v.size(-1)) ** 0.5) * self.layer_num)
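For context, this is roughly where that factor enters the attention computation (a simplified sketch of the idea, not the exact mistral code):

import torch

def attn_weights(q, k, v, layer_num):
    # standard GPT-2 divides by sqrt(head_dim); the mistral change additionally
    # divides by the (1-indexed) layer number to keep fp16 logits from overflowing
    scale_factor = 1 / ((float(v.size(-1)) ** 0.5) * layer_num)
    return torch.matmul(q, k.transpose(-1, -2)) * scale_factor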

Hope this makes sense. Thank you in advance!

Best,
Theodore.

@siddk
Contributor

siddk commented Sep 13, 2021

Hey Theodore - so we're definitely working on pushing the Mistral-specific operation changes (like the one you mentioned) to Transformers proper, as a flag in the GPT-2 Model class. This should happen by the end of the week (or at least, we'll have a PR in transformers you can use!).
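For reference, once it lands, toggling it would look something like this (the flag name here is just a placeholder on my end, not the final API):

from transformers import GPT2LMHeadModel

# hypothetical flag name; the real one will be whatever ships in the transformers PR
model = GPT2LMHeadModel.from_pretrained("gpt2-medium", scale_attn_by_inverse_layer_idx=True)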

As for why the Flax code is running slower - that's super interesting, and I don't have a good answer! It could be some weird interaction between the way we handle the upcasting code and the defaults in the run_clm_flax.py script. It would be great if you could do some digging (or create an issue/PR!), as we're not too familiar with Flax ourselves; otherwise, I'll take a look when I can!

@TheodoreGalanos
Author

Hello,

Bumping this real quick. I haven't checked in a while, so excuse me if this has already been done, but has it? :)

Would love to finetune some mistral models on TPUs.

@dlwh added the first issue label Jul 18, 2022