Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New kernel implementations, nominal model implementation, and training CLI #12

Merged
merged 117 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
117 commits
Select commit Hold shift + click to select a range
dc2949b
notebook: functional notebook for refactorization
laserkelvin Aug 15, 2024
b5cfb61
feat: added utility function to calculate number of parallel blocks
laserkelvin Aug 19, 2024
37d1652
feat: added unravel index function
laserkelvin Aug 19, 2024
1827344
refactor: modifying num block calculation with matching arithmetic
laserkelvin Aug 21, 2024
5e56f15
feat: added direct second order implementation
laserkelvin Aug 21, 2024
fa442b9
feat: adding automatically formatted kernels
laserkelvin Aug 21, 2024
b0b506f
style: ran ruff on modules
laserkelvin Aug 22, 2024
2b3e5b9
notebook: updated implementations to substitute powers
laserkelvin Aug 22, 2024
00fe5d9
Merge branch 'individual-l-embeds' of github.com:laserkelvin/EquiTrit…
laserkelvin Aug 22, 2024
998b3cc
refactor: constraining export to wrapper class for second order
laserkelvin Aug 22, 2024
72d7630
feat: implemented fifth order
laserkelvin Aug 22, 2024
b7e62a4
feat: exposing direct wrapper functions
laserkelvin Aug 22, 2024
a6d128c
refactor & feat: added standardized utility functions for running sph…
laserkelvin Aug 22, 2024
5a4f388
test: refactored unit test to be generalized
laserkelvin Aug 22, 2024
581526d
fix: correcting programmatic library importing mechanism
laserkelvin Aug 22, 2024
ffa0ce5
feat: implemented tenth order terms
laserkelvin Aug 22, 2024
17b0234
feat: exposing tenth order to direct namespace
laserkelvin Aug 22, 2024
15ca83b
feat & test: functioning tenth order done
laserkelvin Aug 22, 2024
ddb0138
chore: exposing utility function to direct namespace
laserkelvin Aug 22, 2024
b3a157a
feat: implemented fused second order kernel
laserkelvin Aug 22, 2024
bbf5313
docs: added short write up for direct
laserkelvin Aug 22, 2024
269807d
feat: implemented third order terms
laserkelvin Aug 24, 2024
768b44d
fix: correcting output allocation shape for third order
laserkelvin Aug 24, 2024
410ba7b
tests: updating unit test for third order
laserkelvin Aug 24, 2024
b6ff222
feat: implemented fourth order terms
laserkelvin Aug 24, 2024
1d4f998
test: added parametrized test for fourth order
laserkelvin Aug 24, 2024
e7fdfac
feat: added sixth order terms
laserkelvin Aug 24, 2024
e1dd9b9
test: added parameters for sixth order test
laserkelvin Aug 24, 2024
7b7c1b7
feat: added seventh order terms
laserkelvin Aug 24, 2024
6c32ab4
test: parameterized seventh order
laserkelvin Aug 24, 2024
ef59bd3
feat: added eighth order terms
laserkelvin Aug 24, 2024
07c06ad
tests: parameterized eighth order
laserkelvin Aug 24, 2024
ce09f8b
feat: added ninth order terms
laserkelvin Aug 24, 2024
812e9c3
fix: correcting eighth order backward stride
laserkelvin Aug 24, 2024
0988acf
test: parameterized ninth order tests
laserkelvin Aug 24, 2024
4ed1f98
feat: added utility function for calculating irreps shapes
laserkelvin Aug 26, 2024
6f97a72
fix: making assert check actually positive instead
laserkelvin Aug 26, 2024
aae4a7a
feat: added first order terms
laserkelvin Aug 26, 2024
ea428ab
feat: added hacky zeroth order and udpate tests
laserkelvin Aug 26, 2024
ef174d9
deps: added training dependencies
laserkelvin Aug 27, 2024
7243fe8
git: ignoring aux files
laserkelvin Aug 27, 2024
a5d9293
notebook: added self-contained notebook for testing baseline architec…
laserkelvin Aug 27, 2024
2932285
notebook: updating latest direct evaluation notebook
laserkelvin Aug 27, 2024
4fda13e
baseline notebook: equivariance test?
migalkin Aug 27, 2024
2e7b3d2
notebook: added e3nn equivariance checks
laserkelvin Aug 28, 2024
975d782
feat: transcribed block implementations from notebook
laserkelvin Aug 28, 2024
f54f278
feat: defining __all__ for blocks
laserkelvin Aug 28, 2024
6709e09
feat: defining model namespace
laserkelvin Aug 28, 2024
fcb314f
chore: adding explicit dependencies
laserkelvin Aug 28, 2024
c4f054e
notebook: updating notebook with fix to subsequent layers needing irr…
laserkelvin Aug 28, 2024
a4c7829
chore: bumping version to 0.2.0
laserkelvin Aug 28, 2024
cc7ab79
chore: adding jsonargparse as dependency
laserkelvin Aug 28, 2024
a60ab78
feat: added CLI script to train QM9
laserkelvin Aug 28, 2024
fc5136b
fix: resolving e3nn nn import
laserkelvin Aug 28, 2024
783245f
fix: correcting data loader import to pyg
laserkelvin Aug 28, 2024
95edbd4
fix: addressing config overwrite and data loader spawning
laserkelvin Aug 29, 2024
4674a92
config: added config files for QM9 experiments
laserkelvin Aug 29, 2024
4a66d62
git: ignoring logging and dataset folders
laserkelvin Aug 29, 2024
c35864a
chore: added requirements file for experiments
laserkelvin Aug 29, 2024
013c9ce
refactor: moving lightning components to equitriton module
laserkelvin Aug 29, 2024
366eeac
refactor: using imported lightning modules in train script
laserkelvin Aug 29, 2024
fcdf0ef
script: adding SMILES output to phate analysis
laserkelvin Aug 30, 2024
399ba69
feat: adding embedding function to model
laserkelvin Aug 30, 2024
fad8e4f
feat: updating __all__ definition for utils
laserkelvin Aug 30, 2024
08cf6a3
feat: added utility function for number of projections
laserkelvin Aug 30, 2024
fd01042
feat: adding utility functions for grabbing triton implementations
laserkelvin Aug 30, 2024
facb0d8
refactor: changing spherical harmonic interface with preallocation
laserkelvin Sep 2, 2024
5ae8198
refactor: making first and third order accept offset values
laserkelvin Sep 2, 2024
636602b
refactor: second order fwd with col offset
laserkelvin Sep 2, 2024
f45c4a1
refactor: second order backward with col offset
laserkelvin Sep 2, 2024
4c340d0
refactor: updating autograd.Function for second order
laserkelvin Sep 2, 2024
b58794e
refactor: making third order autograd Function accept output stride
laserkelvin Sep 2, 2024
39a0871
refactor: making second order autograd.Function consistent with third…
laserkelvin Sep 2, 2024
a57ae4e
refactor: making fourth order kernels accept stride and col offset
laserkelvin Sep 2, 2024
bc07c01
refactor: updating fourth order interface for consistency
laserkelvin Sep 2, 2024
6385833
refactor: making fifth order accept col offset and stride
laserkelvin Sep 2, 2024
79d840b
refactor: updating fifth order autograd.Function interface
laserkelvin Sep 2, 2024
12fd185
refactor: allowing sixth order to accept col offset and output stride
laserkelvin Sep 2, 2024
ebf13b6
refactor: updating sixth order to accept col offset
laserkelvin Sep 2, 2024
ad1aae9
refactor: updating seventh order to accept col offset and output stride
laserkelvin Sep 2, 2024
f481663
refactor: updating seventh order autograd.Function interface
laserkelvin Sep 2, 2024
9a8ca9e
refactor: eighth order with col offset and output stride
laserkelvin Sep 2, 2024
ad84c00
refactor: updating eighth order Function interface
laserkelvin Sep 2, 2024
690cd43
refactor: ninth order with col offset and output stride
laserkelvin Sep 2, 2024
c0eca7d
refactor: updating ninth order Function with new interface
laserkelvin Sep 2, 2024
0839101
refactor: tenth order with col offset and output stride
laserkelvin Sep 2, 2024
6830d4c
refactor: updating tenth order Function interface
laserkelvin Sep 2, 2024
db8f09e
refactor: making zeroth order consistent
laserkelvin Sep 2, 2024
bdd9cb1
refactor: adding coord ptr back to zeroth order for consistency
laserkelvin Sep 2, 2024
5e08128
refactor: updating zeroth order Function interface for consistency
laserkelvin Sep 2, 2024
1f25519
refactor: correcting zeroth order bwd for consistency
laserkelvin Sep 2, 2024
eb699b2
fix: making second order add to gradients instead of overwrite
laserkelvin Sep 2, 2024
86ba0bc
refactor: making third order add to gradients
laserkelvin Sep 2, 2024
378dc60
refactor: making fourth order add to gradients
laserkelvin Sep 2, 2024
2cc5f11
refactor: making fifth order add to gradients
laserkelvin Sep 2, 2024
27f95a5
refactor: making sixth order add to gradients
laserkelvin Sep 2, 2024
ae07fc6
refactor: making seventh order add to gradients
laserkelvin Sep 2, 2024
8ada103
refactor: making eighth order add gradients
laserkelvin Sep 2, 2024
126c536
refactor: ninth order adds grads
laserkelvin Sep 2, 2024
477dd46
refactor: tenth order adds grads
laserkelvin Sep 2, 2024
1c1b141
refactor: changing direct module namespace
laserkelvin Sep 2, 2024
6e9732e
refactor: using preallocated interface for production
laserkelvin Sep 2, 2024
1ff79fa
refactor: supporting triton_spherical_harmonics_function for backward…
laserkelvin Sep 2, 2024
d2d25fd
feat: added model property to count irrep shapes
laserkelvin Sep 3, 2024
bc1bd7e
feat: added embedding separation utility
laserkelvin Sep 3, 2024
a3292c5
refactor: adding splitting utility into __all__
laserkelvin Sep 3, 2024
d10b0b5
refactor: letting embedding split function take numpy arrays
laserkelvin Sep 3, 2024
1f7d932
configs: updating configs used in experiments
laserkelvin Oct 8, 2024
3de00b2
git: ignoring artifact download from wandb
laserkelvin Oct 8, 2024
f29b190
script: committing phate script used for experiments
laserkelvin Oct 8, 2024
647bf3f
docs: added docstrings for phate embedding script
laserkelvin Oct 8, 2024
481bdd9
style: ran sort on generate phate imports
laserkelvin Oct 8, 2024
377d706
docs: updated citation
laserkelvin Oct 8, 2024
09acc2e
docs: added note on new interface
laserkelvin Oct 8, 2024
db06712
notebook: cleaned up notebook for pedagogy
laserkelvin Oct 8, 2024
c010843
notebook: cleaned up notebook for pedagogy
laserkelvin Oct 8, 2024
cd8c083
docs: linking new kernels to notebook
laserkelvin Oct 8, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 56 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,46 @@ nodes and masking in the forward pass. The script `scripts/dynamic_shapes.py` wi
let you test the performance over a range of shapes; we encourage you to test it
before performing full-scale training/inference.

## Decoupled spherical harmonics kernels

We recently published a paper at the AI4Mat workshop at NeurIPS 2024, which as part
of that work, we went back into ``sympy`` to refactor the spherical harmonics up to $l=10$,
such that computations of a particular order are _independent_ from others. This allows
arbitrary orders to be freely composed without incurring a performance penalty, in
the case that one wishes to calculate $l=8$, but not $l=7$, for example.

Functionally, these kernels are intended to behave in the same way as their original
implementation, i.e. they still provide equivariant properties when used to map
cartesian point clouds. However, because of the aggressive refactoring and heavy use
of hard-coded literals, they may (or will) differ numerically from even the initial _EquiTriton_
kernels, particularly at higher orders.

> [!IMPORTANT]
> For the above reason, while the kernels can be drop-in replacements, we do not recommend
> using them from already trained models, at least without some testing on the user's part,
> as the results may differ. We have also not yet attempted to use these kernels as part of
> simulation-based workflows (i.e. molecular dynamics), however our training experiments do
> show that training indeed does converge.

To use the new set of decoupled kernels, the main `torch.autograd` binding is through
the `equitriton.sph_harm.direct.TritonSphericalHarmonic`:

```python
import torch
from equitriton.sph_harm.direct import TritonSphericalHarmonic

coords = torch.rand(100, 3)
sph_harm = TritonSphericalHarmonic.apply(
l_values=[0, 1, 2, 6, 10],
coords=coords
)
```

The improvements to performance are expected to come from (1) decoupling of each spherical
harmonic order, and (2) pre-allocation of an output tensor as to avoid using `torch.cat`,
which calculates each order followed by copying. See the "Direct spherical harmonics evaluation"
notebook in the notebooks folder for derivation.

### Development and usage on Intel XPU

Development on Intel XPUs such as the Data Center GPU Max Series 1550 requires
Expand Down Expand Up @@ -131,7 +171,9 @@ contributions will be licensed under this license.

Citation
--------
If you find this repo useful, please consider citing the corresponding paper:
If you find this repo useful, please consider citing the respective papers.

For the original EquiTriton implementation, please use/read the following citation:

```bibtex
@inproceedings{lee2024scaling,
Expand All @@ -141,4 +183,16 @@ If you find this repo useful, please consider citing the corresponding paper:
year={2024},
url={https://openreview.net/forum?id=ftK00FO5wq}
}
```
```

For the refactored spherical harmonics up to $l=10$, and subsequent PHATE embedding analysis, see:

```bibtex
@inproceedings{lee2024deconstructing,
title={Deconstructing equivariant representations in molecular systems},
author={Kin Long Kelvin Lee and Mikhail Galkin and Santiago Miret},
booktitle={AI for Accelerated Materials Design - NeurIPS 2024},
year={2024},
url={https://openreview.net/forum?id=pshyLoyzRn}
}
```
2 changes: 2 additions & 0 deletions notebooks/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
qm9_data/
lightning_logs/
Loading