Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Hybrid AutoODE #189

Open
wants to merge 81 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
81 commits
Select commit Hold shift + click to select a range
2bda02c
Add ODE solver
klane Jun 30, 2021
945321a
Added Lorenz63 data
amartyamukherjee Aug 18, 2021
b3cc643
Started working on Lorenz63 test
amartyamukherjee Aug 18, 2021
b7c8a6d
Set conda environment
amartyamukherjee Aug 18, 2021
b3a633c
Predicted Lorenz 63 data + Added Runge-Kutta 4
amartyamukherjee Aug 19, 2021
8ec55d1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2021
0fcf33e
Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks"
amartyamukherjee Aug 19, 2021
9580a7c
Cleaned up code
amartyamukherjee Aug 19, 2021
e48d98e
Added Rossler Attractor
amartyamukherjee Aug 20, 2021
3bdbd48
Added second order example
amartyamukherjee Aug 20, 2021
8ba7cba
Trained with Lorenz 63 data
amartyamukherjee Aug 20, 2021
8031085
Fixed bugs in Lorenz 63
amartyamukherjee Aug 20, 2021
0bbd5d0
Modified TensorDataset
amartyamukherjee Aug 20, 2021
934bdcc
Update Lorenz63Train.ipynb
amartyamukherjee Aug 20, 2021
3bbfc6b
Added a test for the Duffing Equation
amartyamukherjee Aug 20, 2021
d8e2e04
Improved the prediction of Lorenz63: random sample
amartyamukherjee Aug 20, 2021
3993762
Tested SEIR and Duffing with random samlpe
amartyamukherjee Aug 20, 2021
ab7c82f
Update SEIRTest_fitRandomSample.ipynb
amartyamukherjee Aug 20, 2021
ead6c7a
Added comments to ode.py
amartyamukherjee Aug 20, 2021
5e21e39
Update Lorenz63Train_fit.ipynb
amartyamukherjee Aug 20, 2021
adb65b3
Added README.md to examples/ode
amartyamukherjee Aug 20, 2021
b53ca5a
Increased prediction range to 10000
amartyamukherjee Aug 20, 2021
f428a79
Update README.md
amartyamukherjee Aug 20, 2021
3d89a9f
Update README.md
amartyamukherjee Aug 20, 2021
dd07a7f
Update README.md
amartyamukherjee Aug 20, 2021
b2a836a
Update README.md
amartyamukherjee Aug 20, 2021
dff1fd9
Added savemat
amartyamukherjee Aug 21, 2021
2343e46
Added .mat results
amartyamukherjee Aug 21, 2021
465349b
Added results and predictions as .mat files
amartyamukherjee Aug 21, 2021
388013e
Used RK4 to train the Lorenz63 model
amartyamukherjee Aug 23, 2021
69127b0
Used RK4 to train the SEIR model
amartyamukherjee Aug 23, 2021
4df4fbf
Duffing equation - Attempt to fix "w"
amartyamukherjee Aug 23, 2021
28105a8
Fixed error with "w" in Duffing Equation
amartyamukherjee Aug 23, 2021
a5757c4
Added tests for cosine
amartyamukherjee Aug 24, 2021
d4a88c4
FIxed comment
amartyamukherjee Aug 24, 2021
2c07f70
Trying more predictions for cos(wt)
amartyamukherjee Aug 24, 2021
1661cc4
Trained with Euler's method data
amartyamukherjee Aug 24, 2021
0108d71
Update Lorenz63_Euler_Train_fitRandomSample_RK4.ipynb
amartyamukherjee Aug 24, 2021
db86c56
Created datasets using sin
amartyamukherjee Aug 24, 2021
3b4b84a
Update a_Cos_wt_sin_dataset_fitRandomSample.ipynb
amartyamukherjee Aug 24, 2021
c42b1c2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 28, 2021
7453737
Fixed linting issues
amartyamukherjee Aug 28, 2021
6be5f79
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 28, 2021
6407f54
Merge branch 'main' into ode-example
amartyamukherjee Aug 28, 2021
114c855
Fixed linting issues
amartyamukherjee Aug 28, 2021
dc37944
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 28, 2021
739cc25
Update ode.py
amartyamukherjee Aug 28, 2021
c8d82a9
Used a DNN to train the Duffing Equation
amartyamukherjee Sep 9, 2021
3035b4e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2021
057d5d8
Merge branch 'main' into ode-example
amartyamukherjee Sep 9, 2021
8f1da0f
Merge branch 'main' into ode-example
amartyamukherjee Sep 26, 2021
fb547c2
Merge branch 'main' into ode-example
amartyamukherjee Oct 15, 2021
a0f3276
Used MSELoss to compare the predictions
amartyamukherjee Oct 22, 2021
a306193
Merge branch 'main' into ode-example
amartyamukherjee Dec 17, 2021
8371a8a
Add pre-commit fixes
pre-commit-ci[bot] Dec 17, 2021
01471ce
Modified ode.py to use pytorch_lightning.Trainer
amartyamukherjee Dec 18, 2021
6ab31c0
Add pre-commit fixes
pre-commit-ci[bot] Dec 18, 2021
5a082a5
Fixed pre-commit.ci issues
amartyamukherjee Dec 18, 2021
2e747eb
Fixed size error in ODESolver
amartyamukherjee Dec 18, 2021
00f2a5d
Add pre-commit fixes
pre-commit-ci[bot] Dec 18, 2021
56a8b07
Merge branch 'main' into ode-example
amartyamukherjee Dec 19, 2021
77f7a19
Included previous Duffing Equation builds
amartyamukherjee Dec 19, 2021
54397d1
Add pre-commit fixes
pre-commit-ci[bot] Dec 19, 2021
929c39b
Fixed linting issues
amartyamukherjee Dec 19, 2021
b79cc17
Add pre-commit fixes
pre-commit-ci[bot] Dec 19, 2021
4f13764
Fixed linting issues
amartyamukherjee Dec 19, 2021
3fd8444
Add pre-commit fixes
pre-commit-ci[bot] Dec 19, 2021
cccfa5a
Fixed linting issues
amartyamukherjee Dec 19, 2021
1d1cdf8
Fixed linting issues
amartyamukherjee Dec 19, 2021
581efa0
Changed ODESolver back to klane's model
amartyamukherjee Dec 20, 2021
499663e
Add pre-commit fixes
pre-commit-ci[bot] Dec 20, 2021
a3d95e9
Removed Lorenz63 and SEIR tests
amartyamukherjee Dec 20, 2021
65d0d80
Merge branch 'hybrid-ode-net' of https://github.com/amartyamukherjee/…
amartyamukherjee Dec 20, 2021
c6864fc
Changed hybridODENet to reflect original ODESolver
amartyamukherjee Dec 20, 2021
7f402ff
Create DuffingTest_HybridODE.ipynb
amartyamukherjee Dec 20, 2021
ca206fc
Attempt first example
amartyamukherjee Dec 20, 2021
5fdbacf
Add pre-commit fixes
pre-commit-ci[bot] Dec 20, 2021
bf14537
Update ode.py
amartyamukherjee May 11, 2022
52f1125
Merge branch 'main' into hybrid-ode-net
amartyamukherjee May 11, 2022
eb411c7
Update ode.py
amartyamukherjee May 11, 2022
cf9f9a4
Add pre-commit fixes
pre-commit-ci[bot] May 11, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ lightning_logs/
# pytest coverage files
.coverage*

# Sphinx
docs/build/
.vscode/
examples/ode/lorenz63/.ipynb_checkpoints/

# build artifacts
build/
dist/
Expand Down
732 changes: 732 additions & 0 deletions examples/ode/DuffingEquation/DuffingTest_DNN.ipynb

Large diffs are not rendered by default.

398 changes: 398 additions & 0 deletions examples/ode/DuffingEquation/DuffingTest_HybridODE.ipynb

Large diffs are not rendered by default.

Large diffs are not rendered by default.

19 changes: 19 additions & 0 deletions examples/ode/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
In this folder, we train the ODESolver model on three different ODEs. Our goal is to estimate all the parameters here.

Lorenz63:

![alt text](https://wikimedia.org/api/rest_v1/media/math/render/svg/7928004d58943529a7be774575a62ca436a82a7f)

Parameters to estimate: $\sigma, \rho, \beta$

<!-- Duffing equation:

![alt text](https://wikimedia.org/api/rest_v1/media/math/render/svg/4881d84893e137772068573bb1218fc1e2b295cd)

Parameters to estimate: $\alpha, \beta, \gamma, \delta, \omega$ -->

SEIR:

![alt text](https://miro.medium.com/max/1056/1*dXCHv_pSYiMG90efXiFNPQ.png)

Parameters to estimate: $\alpha, \beta, \gamma$
58 changes: 58 additions & 0 deletions torchts/nn/models/hybridode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import torch
from torch import nn

from torchts.nn.models.ode import ODESolver


class HybridODENet(ODESolver):
def __init__(
self,
ode,
dnns,
init_vars,
init_coeffs,
dt,
solver="euler",
outvar=None,
**kwargs,
):
super().__init__(ode, init_vars, init_coeffs, dt, solver, outvar, **kwargs)

if ode.keys() != init_vars.keys():
raise ValueError("Inconsistent keys in ode and init_vars")

if solver == "euler":
self.solver = self.euler
else:
raise ValueError(f"Unrecognized solver {solver}")

for name, value in init_coeffs.items():
self.register_parameter(name, nn.Parameter(torch.tensor(value)))

self.ode = ode
self.dnns = dnns
self.var_names = ode.keys()
self.init_vars = {
name: torch.tensor(value, device=self.device)
for name, value in init_vars.items()
}
self.coeffs = {name: param for name, param in self.named_parameters()}
self.outvar = self.var_names if outvar is None else outvar
self.dt = dt

def euler(self, nt):
pred = {name: value.unsqueeze(0) for name, value in self.init_vars.items()}

for n in range(nt - 1):
# create dictionary containing values from previous time step
prev_val = {var: pred[var][[n]] for var in self.var_names}

for var in self.var_names:
new_val = (
prev_val[var]
+ self.ode[var](prev_val, self.coeffs, self.dnns) * self.dt
)
pred[var] = torch.cat([pred[var], new_val])

# reformat output to contain desired (observed) variables
return torch.stack([pred[var] for var in self.outvar], dim=1)
10 changes: 10 additions & 0 deletions torchts/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,13 @@ def sliding_window(tensor, lags, horizon=1, dim=0, step=1):
x, y = data[:, [lag - 1 for lag in lags]], data[:, -1]

return x, y


def generate_ode_dataset(data, num_steps):
n = data.shape[0]
y_2d = data[:num_steps, :]
y = y_2d.view(1, y_2d.shape[0], y_2d.shape[1])
for i in range(1, n - num_steps):
y_2d = data[i : i + num_steps, :]
y = torch.cat((y, y_2d.view(1, y_2d.shape[0], y_2d.shape[1])), dim=0)
return data[: n - num_steps, :], y