Jax is slower than NumPy #11

certik opened this issue Feb 16, 2023 · 1 comment

@certik (Contributor) commented Feb 16, 2023

With #10, I get the following timings with NumPy on my Apple M1 Max:

$ time python gpt2.py "Alan Turing theorized that computers would one day become" -n 40
generating: 100%|███████████████████████████████| 40/40 [00:18<00:00,  2.13it/s]
 the most powerful machines on the planet.

The computer is a machine that can perform complex calculations, and it can perform these calculations in a way that is very similar to the human brain.

python gpt2.py "Alan Turing theorized that computers would one day become" -n  115.74s user 1.71s system 559% cpu 20.993 total

And Jax:

$ time python gpt2.py "Alan Turing theorized that computers would one day become" -n 40
generating: 100%|███████████████████████████████| 40/40 [00:21<00:00,  1.85it/s]
 the most powerful machines on the planet.

The computer is a machine that can perform complex calculations, and it can perform these calculations in a way that is very similar to the human brain.

python gpt2.py "Alan Turing theorized that computers would one day become" -n  28.86s user 1.91s system 127% cpu 24.115 total

So Jax is slower. According to htop, Jax uses roughly 1.3 CPU cores, while NumPy uses almost 6. Is NumPy automatically parallel on macOS?
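
For reference, one way to check whether NumPy is going through a multithreaded BLAS is to inspect its build config and the live thread pools (threadpoolctl is an extra package, not part of the environment below):

```python
import numpy as np
from threadpoolctl import threadpool_info  # pip install threadpoolctl

np.show_config()  # on this env it should report openblas
for pool in threadpool_info():
    # OpenBLAS defaults to one thread per core, which would explain
    # the ~6 cores htop shows for the NumPy run
    print(pool["internal_api"], pool["num_threads"])
```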

Here is my Conda environment:

$ conda env export
name: pico
channels:
  - conda-forge
dependencies:
  - appdirs=1.4.4=pyh9f0ad1d_0
  - appnope=0.1.3=pyhd8ed1ab_0
  - asttokens=2.2.1=pyhd8ed1ab_0
  - backcall=0.2.0=pyh9f0ad1d_0
  - backports=1.0=pyhd8ed1ab_3
  - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
  - brotlipy=0.7.0=py39h02fc5c5_1005
  - bzip2=1.0.8=h3422bc3_4
  - c-ares=1.18.1=h3422bc3_0
  - ca-certificates=2022.12.7=h4653dfc_0
  - cffi=1.15.1=py39h7e6b969_3
  - cryptography=39.0.1=py39he2a39a8_0
  - decorator=5.1.1=pyhd8ed1ab_0
  - executing=1.2.0=pyhd8ed1ab_0
  - idna=3.4=pyhd8ed1ab_0
  - ipython=8.10.0=pyhd1c38e8_0
  - jax=0.4.3=pyhd8ed1ab_0
  - jaxlib=0.4.3=cpu_py39h99d3290_1
  - jedi=0.18.2=pyhd8ed1ab_0
  - libabseil=20220623.0=cxx17_h28b99d4_6
  - libblas=3.9.0=16_osxarm64_openblas
  - libcblas=3.9.0=16_osxarm64_openblas
  - libcxx=14.0.6=h2692d47_0
  - libffi=3.4.2=h3422bc3_5
  - libgfortran=5.0.0=11_3_0_hd922786_27
  - libgfortran5=11.3.0=hdaf2cc0_27
  - libgrpc=1.51.1=hb15be72_1
  - liblapack=3.9.0=16_osxarm64_openblas
  - libopenblas=0.3.21=openmp_hc731615_3
  - libprotobuf=3.21.12=hb5ab8b9_0
  - libsqlite=3.40.0=h76d750c_0
  - libzlib=1.2.13=h03a7124_4
  - llvm-openmp=15.0.7=h7cfbb63_0
  - matplotlib-inline=0.1.6=pyhd8ed1ab_0
  - ncurses=6.3=h07bb92c_1
  - openssl=3.0.8=h03a7124_0
  - opt_einsum=3.3.0=pyhd8ed1ab_1
  - packaging=23.0=pyhd8ed1ab_0
  - parso=0.8.3=pyhd8ed1ab_0
  - pexpect=4.8.0=pyh1a96a4e_2
  - pickleshare=0.7.5=py_1003
  - pip=23.0=pyhd8ed1ab_0
  - pooch=1.6.0=pyhd8ed1ab_0
  - prompt-toolkit=3.0.36=pyha770c72_0
  - ptyprocess=0.7.0=pyhd3deb0d_0
  - pure_eval=0.2.2=pyhd8ed1ab_0
  - pycparser=2.21=pyhd8ed1ab_0
  - pygments=2.14.0=pyhd8ed1ab_0
  - pyopenssl=23.0.0=pyhd8ed1ab_0
  - pysocks=1.7.1=pyha2e5f31_6
  - python=3.9.16=hea58f1e_0_cpython
  - python_abi=3.9=3_cp39
  - re2=2023.02.01=hb7217d7_0
  - readline=8.1.2=h46ed386_0
  - scipy=1.10.0=py39h18313fe_2
  - setuptools=67.1.0=pyhd8ed1ab_0
  - six=1.16.0=pyh6c4a22f_0
  - stack_data=0.6.2=pyhd8ed1ab_0
  - tk=8.6.12=he1e0b03_0
  - traitlets=5.9.0=pyhd8ed1ab_0
  - tzdata=2022g=h191b570_0
  - urllib3=1.26.14=pyhd8ed1ab_0
  - wcwidth=0.2.6=pyhd8ed1ab_0
  - wheel=0.38.4=pyhd8ed1ab_0
  - xz=5.2.6=h57fd34a_0
  - zlib=1.2.13=h03a7124_4
  - pip:
    - absl-py==1.4.0
    - astunparse==1.6.3
    - cachetools==5.3.0
    - certifi==2022.12.7
    - charset-normalizer==2.0.12
    - fire==0.5.0
    - flatbuffers==23.1.21
    - gast==0.4.0
    - google-auth==2.16.0
    - google-auth-oauthlib==0.4.6
    - google-pasta==0.2.0
    - grpcio==1.51.1
    - h5py==3.8.0
    - importlib-metadata==6.0.0
    - keras==2.11.0
    - libclang==15.0.6.1
    - markdown==3.4.1
    - markupsafe==2.1.2
    - numpy==1.24.1
    - oauthlib==3.2.2
    - protobuf==3.19.6
    - pyasn1==0.4.8
    - pyasn1-modules==0.2.8
    - regex==2017.4.5
    - requests==2.27.1
    - requests-oauthlib==1.3.1
    - rsa==4.9
    - tensorboard==2.11.2
    - tensorboard-data-server==0.6.1
    - tensorboard-plugin-wit==1.8.1
    - tensorflow-estimator==2.11.0
    - tensorflow-macos==2.11.0
    - termcolor==2.2.0
    - tqdm==4.64.0
    - typing-extensions==4.4.0
    - werkzeug==2.2.2
    - wrapt==1.14.1
    - zipp==3.13.0
prefix: /Users/ondrej/mambaforge/envs/pico

@jaymody (Owner) commented Feb 16, 2023

Curious. I would've expected jax to be faster, given that it executes asynchronously, which should effectively make this line parallel: `out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]` (numpy would execute it sequentially, since each call is eager and blocking).
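
To make that head-level parallelism explicit rather than relying on async dispatch, one option would be to vmap attention over the head axis. A minimal sketch, assuming qkv_heads and causal_mask come from gpt2.py and the per-head lists stack into (n_head, seq_len, head_dim) arrays:

```python
import jax
import jax.numpy as jnp

def attention(q, k, v, mask):
    # scaled dot-product attention for a single head, as in gpt2.py
    return jax.nn.softmax(q @ k.T / jnp.sqrt(q.shape[-1]) + mask) @ v

# stack the per-head lists into (n_head, seq_len, head_dim) arrays,
# then batch attention over the leading head axis in one call
q, k, v = (jnp.stack(h) for h in qkv_heads)
out_heads = jax.vmap(attention, in_axes=(0, 0, 0, None))(q, k, v, causal_mask)
```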

Not sure how jax handles multiple CPUs. I know you can manually expose multiple CPU devices with the environment variable `export XLA_FLAGS="--xla_force_host_platform_device_count=8"`, but that didn't yield a speedup for me.
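
If it helps, my understanding is that this flag only carves the host into multiple XLA devices (mainly useful for pmap-style experiments); it doesn't by itself multithread individual ops. A quick check:

```python
import os
# must be set before jax is imported
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
print(jax.devices())  # expect 8 CPU devices listed
```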

Relevant link: https://jax.readthedocs.io/en/latest/faq.html#is-jax-faster-than-numpy
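
The FAQ also notes that async dispatch can make naive timings misleading; for micro-benchmarks you'd want something like:

```python
import time
import jax.numpy as jnp

x = jnp.ones((1024, 1024))

start = time.perf_counter()
# block_until_ready() waits for the computation to actually finish,
# so the timing covers the compute, not just the dispatch
(x @ x).block_until_ready()
print(f"matmul took {time.perf_counter() - start:.4f}s")
```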
