diff --git a/scripts/benchmark.py b/scripts/benchmark.py index d3568f4..0193409 100644 --- a/scripts/benchmark.py +++ b/scripts/benchmark.py @@ -8,6 +8,7 @@ import torch import numpy as np import pandas as pd +import e3nn from e3nn.o3._spherical_harmonics import _spherical_harmonics from equitriton.sph_harm.bindings import * @@ -82,11 +83,14 @@ def e3nn_benchmark(tensor_shape: list[int], device: str | torch.device, l_max: i joint_tensor[..., 1].contiguous(), joint_tensor[..., 2].contiguous(), ) - output = _spherical_harmonics(l_max, x, y, z) + e3nn.set_optimization_defaults(jit_script_fx=False) + output = torch.compile(_spherical_harmonics, fullgraph=True, mode="max-autotune")(l_max, x, y, z) output.backward(gradient=torch.ones_like(output)) # delete references to ensure memory gets cleared del output del joint_tensor + e3nn.set_optimization_defaults(jit_script_fx=True) # Turn it back on to avoid any issues + @benchmark(num_steps=args.num_steps, warmup_fraction=args.warmup_fraction) @@ -131,4 +135,4 @@ def triton_benchmark(tensor_shape: list[int], device: str | torch.device, l_max: all_data.append(joint_results) df = pd.DataFrame(all_data) -df.to_csv(f"{args.device}_lmax{args.l_max}_results.csv", index=False) +df.to_csv(f"{args.device}_lmax{args.l_max}_results.csv", index=False) \ No newline at end of file diff --git a/scripts/measure_numerical_error.py b/scripts/measure_numerical_error.py index 91d8929..03f69f9 100644 --- a/scripts/measure_numerical_error.py +++ b/scripts/measure_numerical_error.py @@ -6,6 +6,7 @@ import torch import numpy as np +import e3nn from e3nn.o3._spherical_harmonics import _spherical_harmonics from equitriton.sph_harm.bindings import * @@ -73,7 +74,8 @@ def compare_e3nn_triton( joint_tensor[..., 1].contiguous(), joint_tensor[..., 2].contiguous(), ) - e3nn_output = _spherical_harmonics(l_max, x, y, z) + e3nn.set_optimization_defaults(jit_script_fx=False) + e3nn_output = torch.compile(_spherical_harmonics, fullgraph=True, mode="max-autotune")(l_max, x, y, z) e3nn_output.backward(gradient=torch.ones_like(e3nn_output)) e3nn_grad = joint_tensor.grad.detach().clone() joint_tensor.grad = None @@ -95,6 +97,7 @@ def compare_e3nn_triton( # delete intermediate tensors to make sure we don't leak del e3nn_output del triton_output + e3nn.set_optimization_defaults(jit_script_fx=True) # Turn it back on to avoid any issues return (signed_fwd_error, signed_bwd_error) diff --git a/scripts/profile_script.py b/scripts/profile_script.py index 9404177..1cdd0ae 100644 --- a/scripts/profile_script.py +++ b/scripts/profile_script.py @@ -7,6 +7,7 @@ import torch from torch.profiler import record_function +import e3nn from e3nn.o3._spherical_harmonics import _spherical_harmonics from equitriton.sph_harm.bindings import * @@ -74,13 +75,15 @@ def e3nn_benchmark(tensor_shape: list[int], device: str | torch.device, l_max: i joint_tensor[..., 1].contiguous(), joint_tensor[..., 2].contiguous(), ) + e3nn.set_optimization_defaults(jit_script_fx=False) with record_function("forward"): - output = _spherical_harmonics(l_max, x, y, z) + output = torch.compile(_spherical_harmonics, fullgraph=True, mode="max-autotune")(l_max, x, y, z) with record_function("backward"): output.backward(gradient=torch.ones_like(output)) # delete references to ensure memory gets cleared del output del joint_tensor + e3nn.set_optimization_defaults(jit_script_fx=True) # Turn it back on to avoid any issues @profile(