Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rewrite for matmul when only one of the inputs has batched dimensions #558

Merged

Conversation

ricardoV94
Copy link
Member

Description

This rewrites replaces a batched matmul by a core matmul by raveling the batched dimensions of the batched input, doing a 2D matmul and then unravelling the batched dimensions of the output.

This sidesteps the Blockwise Dot operation and allows specialization into BLAS routines.

The idea was taken from these two discussions:
numpy/numpy#7569
numpy/numpy#8957

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@codecov-commenter
Copy link

Codecov Report

Merging #558 (d28b35f) into main (60a9566) will increase coverage by 0.00%.
The diff coverage is 95.45%.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main     #558   +/-   ##
=======================================
  Coverage   80.91%   80.92%           
=======================================
  Files         162      162           
  Lines       46436    46458   +22     
  Branches    11357    11364    +7     
=======================================
+ Hits        37576    37597   +21     
  Misses       6638     6638           
- Partials     2222     2223    +1     
Files Coverage Δ
pytensor/tensor/rewriting/math.py 87.18% <95.45%> (+0.11%) ⬆️

…ions

This rewrites replaces a batched matmul by a core matmul by raveling the batched dimensions of the batched input, doing a 2D matmul and then unravelling the batched dimensions of the output.

This sidesteps the Blockwise Dot operation and allows specialization into BLAS routines.

The idea was taken from these two discussions:
numpy/numpy#7569
numpy/numpy#8957
@ricardoV94 ricardoV94 force-pushed the batched_matmul_to_reshaped_dot_rewrite branch from d28b35f to b19c6ab Compare December 18, 2023 11:29
@ricardoV94 ricardoV94 merged commit c52154d into pymc-devs:main Dec 20, 2023
53 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants