Add Softmax kernel in Triton. Use softmax kernel and argmax in Llama generation.py. + Small changes #11

Open · wants to merge 8 commits into base: main
Conversation

@catherinelee274 catherinelee274 commented Sep 24, 2024

  • Update .gitignore
  • Add README.md
  • Add README.md for pytests
  • Add fused softmax kernel and use in generation.py
  • Added pytest for softmax
  • Use triton.argmax in generation.py
  • Add line dist.destroy_process_group() to remove warning during benchmarking
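For reference, the fused kernel's per-row work is a numerically stable softmax, and generation then takes the argmax over the vocabulary. A pure-Python sketch of those two row-wise computations (this is a reference for what each Triton program instance computes per row, not the kernel itself):

```python
import math

def softmax(row):
    # Subtract the row max before exponentiating so large logits
    # don't overflow exp(); the result is unchanged mathematically.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    s = sum(exps)
    return [e / s for e in exps]

def argmax(row):
    # Index of the largest element, mirroring argmax over the last dim.
    return max(range(len(row)), key=lambda i: row[i])
```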

Results from calling `python3 main.py llama_chat_completion --benchmark --ckpt_dir <model_checkpoint_path> --tokenizer_path <model_tokenizer_path>`:

With No Changes:

|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| kernel                  | kernel_path                                                                                                              | triton                 | non_triton             | triton-non_triton       |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| chat_completion         | chat_completion                                                                                                          |     23.363035631999992 |      23.20719621399985 |     0.15583941800014145 |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| chat_completion         | chat_completion.chat_completion                                                                                          |     15.086765727000056 |     15.037501877000068 |     0.04926384999998845 |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| generate                | chat_completion.chat_completion.generate                                                                                 |     15.085606602000098 |     15.036371463999785 |     0.04923513800031287 |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| softmax                 | chat_completion.chat_completion.generate.softmax                                                                         | 3.286413995026158e-05  | 3.2326578084264846e-05 | 5.375618659967339e-07   |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| transformer_forward     | chat_completion.chat_completion.generate.transformer_forward                                                             |    0.02763666868558919 |     0.0275613940689651 | 7.527461662408877e-05   |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| RMSNorm                 | chat_completion.chat_completion.generate.transformer_forward.RMSNorm                                                     | 5.827654970598593e-05  | 5.889426978790332e-05  | -6.177200819173836e-07  |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| transform_block_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward                                     |  0.0008483634264688104 |  0.0008461694928391175 | 2.1939336296929223e-06  |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| RMSNorm                 | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.RMSNorm                             | 5.88654680211981e-05   | 5.911337566830836e-05  | -2.4790764711026176e-07 |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| attention_forward       | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward                   |   0.000491074871831931 |  0.0004896370724516646 | 1.437799380266374e-06   |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| apply_rotary_emb        | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.apply_rotary_emb  | 8.806513197629571e-05  | 8.784457036059364e-05  | 2.2056161570207199e-07  |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| attention               | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention         | 0.00017238609349734332 | 0.00017186252022342258 | 5.235732739207336e-07   |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| matmul                  | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.matmul  | 4.327956193072774e-05  | 4.3119774561757945e-05 | 1.597873689697973e-07   |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| softmax                 | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.softmax | 2.88178547136412e-05   | 2.8733377409596218e-05 | 8.4477304044983e-08     |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| feed_forward_forward    | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.feed_forward_forward                | 0.00012714412176862338 | 0.00012642639452434758 | 7.177272442757999e-07   |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| precompute_freqs_cis    | chat_completion.precompute_freqs_cis                                                                                     |  0.0003592040002331487 | 0.00034214400011478574 | 1.706000011836295e-05   |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|

With just the softmax kernel:

|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| kernel                  | kernel_path                                                                                                              | triton                 | non_triton             | triton-non_triton       |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| chat_completion         | chat_completion                                                                                                          |     23.529098738999892 |                        |                         |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| chat_completion         | chat_completion.chat_completion                                                                                          |     15.229127281999808 |     23.885352947999763 |      -8.656225665999955 |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| generate                | chat_completion.chat_completion.generate                                                                                 |     15.228001847000087 |     15.719125695999992 |      -0.491123848999905 |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| softmax                 | chat_completion.chat_completion.generate.softmax                                                                         | 3.3104419885907815e-05 |     15.717982729999676 |      -15.71794962557979 |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| transformer_forward     | chat_completion.chat_completion.generate.transformer_forward                                                             |   0.028188293081122473 |   0.028127436847877902 | 6.0856233244570984e-05  |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| RMSNorm                 | chat_completion.chat_completion.generate.transformer_forward.RMSNorm                                                     | 5.8095263698646935e-05 | 5.832987626748923e-05  | -2.3461256884229476e-07 |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| transform_block_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward                                     |  0.0008650784662146882 |  0.0008631522671130063 | 1.9261991016819406e-06  |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| RMSNorm                 | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.RMSNorm                             | 5.8953483520847516e-05 | 5.920526296552413e-05  | -2.517794446766161e-07  |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| attention_forward       | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward                   |   0.000502045900099974 |  0.0005004565054525663 | 1.5893946474076093e-06  |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| apply_rotary_emb        | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.apply_rotary_emb  | 8.903067526494468e-05  | 8.906119377373757e-05  | -3.05185087928929e-08   |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| attention               | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention         | 0.00017652477433859322 | 0.00017632209013441347 | 2.0268420417975867e-07  |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| matmul                  | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.matmul  | 4.470880419631436e-05  | 4.464734606365697e-05  | 6.145813265738515e-08   |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| softmax                 | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.softmax | 2.94319926492972e-05   | 2.9444132737926863e-05 | -1.2140088629663967e-08 |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| feed_forward_forward    | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.feed_forward_forward                | 0.00013109435198944578 | 0.00013054154848626072 | 5.528035031850662e-07   |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| precompute_freqs_cis    | chat_completion.precompute_freqs_cis                                                                                     |  0.0003498339997349831 |  0.0003586540001379035 | -8.820000402920414e-06  |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|

With softmax and argmax:

|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| kernel                  | kernel_path                                                                                                              | triton                 | non_triton             | triton-non_triton       |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| chat_completion         | chat_completion                                                                                                          |     23.316155643000002 |                        |                         |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| chat_completion         | chat_completion.chat_completion                                                                                          |     15.026287104999938 |       23.7322098059999 |      -8.705922700999963 |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| generate                | chat_completion.chat_completion.generate                                                                                 |     15.025166173999878 |     15.588987290000205 |     -0.5638211160003266 |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| softmax                 | chat_completion.chat_completion.generate.softmax                                                                         | 3.194667139574176e-05  |     15.587871648000146 |      -15.58783970132875 |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| transformer_forward     | chat_completion.chat_completion.generate.transformer_forward                                                             |    0.02752115360446059 |   0.027770376318450807 |   -0.000249222713990218 |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| RMSNorm                 | chat_completion.chat_completion.generate.transformer_forward.RMSNorm                                                     | 5.7826257591456e-05    | 5.821012373749101e-05  | -3.838661460350075e-07  |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| transform_block_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward                                     |  0.0008447297065781728 |   0.000852438977687303 | -7.709271109130212e-06  |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| RMSNorm                 | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.RMSNorm                             | 5.789487043606993e-05  | 5.838420017865719e-05  | -4.893297425872609e-07  |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| attention_forward       | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward                   |  0.0004889343329106125 |  0.0004935158465387692 | -4.581513628156746e-06  |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| apply_rotary_emb        | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.apply_rotary_emb  | 8.793432188105586e-05  | 8.822047401022832e-05  | -2.861521291724629e-07  |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| attention               | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention         | 0.00017178861625323266 | 0.00017370242444190314 | -1.9138081886704806e-06 |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| matmul                  | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.matmul  | 4.3013189750603796e-05 | 4.373048187156608e-05  | -7.172921209622871e-07  |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| softmax                 | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.softmax | 2.8741690918662413e-05 | 2.8887915059976048e-05 | -1.462241413136349e-07  |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| feed_forward_forward    | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.feed_forward_forward                | 0.00012659588773971398 | 0.00012815826540009848 | -1.5623776603845016e-06 |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
| precompute_freqs_cis    | chat_completion.precompute_freqs_cis                                                                                     |  0.0003504739997879369 |  0.0003422939998927177 | 8.179999895219225e-06   |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|

@catherinelee274 changed the title from "Add Softmax kernel in triton. Use softmax and argmax in llama generation." to "Add Softmax kernel in Triton. Use softmax kernel and argmax in Llama generation.py. + Small changes" on Sep 24, 2024
@catherinelee274 marked this pull request as ready for review on September 24, 2024
import pandas as pd


def compare_benchmarks(benchmarks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
Collaborator:
any reason you are deleting this?

Comment on lines +196 to +199
if self.use_triton:
    probs = triton_softmax(logits[:, -1])
else:
    probs = self.Math.softmax(logits[:, -1] / temperature, dim=-1)
Collaborator:

there is a mathOps file in this directory to abstract this away from users. Lets use that instead and allow for proper benchmarking (see decorator on the functions there)
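One way to read this suggestion: route both backends through a single dispatch object so callers never branch on `use_triton` themselves, and so the benchmarking decorator wraps one call site. A hypothetical sketch (class and method names are illustrative, not the repo's actual mathOps API; note that in the diff as written the Triton branch skips the division by `temperature`, which a shared wrapper would also fix):

```python
import math

class MathOps:
    """Hypothetical dispatch layer: the backend (Triton kernel vs.
    eager reference) is chosen in one place, behind one method."""

    def __init__(self, use_triton: bool):
        self.use_triton = use_triton

    def softmax(self, row, temperature: float = 1.0):
        # Temperature scaling happens before dispatch, so both
        # backends see identically scaled logits.
        scaled = [x / temperature for x in row]
        if self.use_triton:
            return self._triton_softmax(scaled)
        return self._reference_softmax(scaled)

    def _triton_softmax(self, row):
        # Stand-in: the real path would launch the fused Triton kernel.
        return self._reference_softmax(row)

    @staticmethod
    def _reference_softmax(row):
        # Numerically stable row-wise softmax.
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        s = sum(exps)
        return [e / s for e in exps]
```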

Comment on lines +203 to +206
if self.use_triton:
    next_token = self.triton.language.argmax(logits[:, -1], axis=-1)
else:
    next_token = self.Math.argmax(logits[:, -1], dim=-1)
Collaborator:

same as above

import pytest
from kernels.fused_softmax import triton_softmax

@pytest.mark.parametrize("input_size", [(1024, 1024), (512, 512), (2048, 512)])
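The hunk cuts off at the decorator. A property-style check is one way the body might exercise those shapes; this sketch uses a pure-Python reference in place of `triton_softmax` (the real test presumably compares against `torch.softmax`, which is not assumed here):

```python
import math
import random

def reference_softmax(row):
    # Numerically stable softmax used as the oracle.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    s = sum(exps)
    return [e / s for e in exps]

def check_softmax(softmax_fn, n_rows, n_cols, tol=1e-6):
    # Properties a parametrized pytest could assert per input size:
    # each row sums to 1, all probabilities are positive, and the
    # largest logit keeps the largest probability.
    random.seed(0)
    for _ in range(n_rows):
        row = [random.uniform(-5.0, 5.0) for _ in range(n_cols)]
        out = softmax_fn(row)
        assert abs(sum(out) - 1.0) < tol
        assert all(p > 0 for p in out)
        assert out.index(max(out)) == row.index(max(row))
```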
Collaborator:

Thanks for adding tests!
