[WIP] Free energy fitting #54

Draft
wants to merge 49 commits into
base: main
Changes from 22 commits
79d125f
Create run on openff-1.2 training dataset.ipynb
maxentile Oct 22, 2020
e3454ac
use offmol_indices, handle case of no propers / impropers, and adhere…
maxentile Oct 22, 2020
d3f236d
allow to .sum(dim=1) even when there are no impropers
maxentile Oct 22, 2020
10268d5
allow impropers to have length 0 in valencemodel
maxentile Oct 24, 2020
9c350ee
add gbsa port
maxentile Oct 25, 2020
cf14a20
oof, avoid pytorch in-place modification of input argument tensors
maxentile Oct 25, 2020
0c20732
gahh, careless variable-name typo
maxentile Oct 25, 2020
32f7d2d
wip demo notebook for fitting to a hydration free energy calculation …
maxentile Oct 25, 2020
dfce74c
port @proteneer's GBSA implementation instead
maxentile Oct 25, 2020
6a02086
repeat fitting-to-free-energies notebook with less-likely-to-be-buggy…
maxentile Oct 25, 2020
b5fa1e9
add reference implementation from bayes-implicit-solvent
maxentile Oct 29, 2020
b77794c
refactor gbsa_obc2_energy into a function in openmm unit system, and …
maxentile Oct 29, 2020
eb73857
increase descriptiveness in gbsa implementation, add thorough shape a…
maxentile Oct 29, 2020
745c6ac
import FreeSolv database v0.52
maxentile Oct 29, 2020
a88fe91
remove tensor-shape-printing statements
maxentile Oct 29, 2020
3773dd5
create pandas dataframe with serialized openmm systems for freesolv s…
maxentile Oct 29, 2020
22d314f
ooooof. fix silly mistake in fitting-to-free-energies notebook
maxentile Oct 29, 2020
dfdeac0
also handle cases like methane where len(propers) == 0
maxentile Oct 29, 2020
5866029
save also xyz coordinates from brief md
maxentile Oct 29, 2020
47931bd
remove **kwargs to try to play nice with torchscript jit
maxentile Oct 29, 2020
d812bdd
remove temporary assert statements in _gbsa_obc2_energy_omm
maxentile Oct 29, 2020
fc2fbe7
must be a remaining sign-flip error -- seems like it's unable to make…
maxentile Oct 29, 2020
b564fe6
add missing conversion from nm/(proton_charge**2) to kJ/mol
maxentile Oct 30, 2020
08a2bcf
update gbsa docstring
maxentile Oct 30, 2020
d5ae749
update freesolv-fitting notebook
maxentile Oct 30, 2020
bacb705
re-run demo notebook with increased stepsize and decreased network si…
maxentile Oct 30, 2020
6d4c831
ipynb --> py
maxentile Oct 30, 2020
9ffa304
refine fit_freesolv.py script
maxentile Oct 30, 2020
7115efb
add pdf figures from fit_freesolv
maxentile Oct 30, 2020
c63bc05
notebook reporting on element coverage in freesolv
maxentile Oct 31, 2020
6db4016
oops, forgot nitrogen!
maxentile Oct 31, 2020
9857f52
notebook fitting to {C, H, O} mini-freesolv
maxentile Oct 31, 2020
8e50eec
oodles o' vacuum samples
maxentile Oct 31, 2020
1f1f520
set openmm_cpu_threads to 1
maxentile Oct 31, 2020
38bbf03
Create fit to {C, H, O, N, Cl} subset of freesolv (n=529).ipynb
maxentile Oct 31, 2020
7de7b6c
oops fix plot labels
maxentile Oct 31, 2020
2643694
merge freesolv vacuum sample records
maxentile Oct 31, 2020
d24b5bf
git lfs track freesolv_vacuum_samples.npz (279MB)
maxentile Oct 31, 2020
e6d31d9
add xyz column to freesolv_with_samples.h5
maxentile Oct 31, 2020
98f424f
update {C, H, O} subset experiment to use thorough equilibrium sampling
maxentile Oct 31, 2020
a38812d
add experiment script for k-fold cv
maxentile Nov 2, 2020
ac80468
oops, don't indent all the relevant stuff out of the training loop!
maxentile Nov 2, 2020
7c76087
git lfs track each of the K=10-fold CV trajectories
maxentile Nov 2, 2020
f56999c
add notebook to plot k-fold cv results
maxentile Nov 2, 2020
e01ab88
add a horizontal line depicting RMSE of FreeSolv's explicit-solvent c…
maxentile Nov 2, 2020
0a06d43
Add todos
maxentile Sep 3, 2021
3c7eee4
Add SAGEConv
maxentile Sep 3, 2021
fad31a5
Address GraphSAGE todo
maxentile Sep 3, 2021
ccea6ba
Update PDF figures
maxentile Sep 3, 2021
645 changes: 645 additions & 0 deletions data/freesolv/database.txt

Large diffs are not rendered by default.

119 changes: 119 additions & 0 deletions espaloma/mm/implicit.py
@@ -0,0 +1,119 @@
import torch

torch.set_default_dtype(torch.float64)

from espaloma.units import DISTANCE_UNIT, ENERGY_UNIT

from simtk import unit

distance_to_nm = (1.0 * DISTANCE_UNIT).value_in_unit(unit.nanometer)
energy_from_kjmol = (1.0 * unit.kilojoule_per_mole).value_in_unit(ENERGY_UNIT)


def step(x):
    """Return 1.0 where x >= 0 and 0.0 elsewhere."""
    return 1.0 * (x >= 0)


def _gbsa_obc2_energy_omm(
        distance_matrix,
        radii, scales, charges,
        alpha=0.8, beta=0.0, gamma=2.909125,
        dielectric_offset=0.009,
        surface_tension=28.3919551,
        solute_dielectric=1.0,
        solvent_dielectric=78.5,
        probe_radius=0.14
):
    """
    Assume everything is given in OpenMM units
    ported from jax/numpy implementation here:
    https://github.com/openforcefield/bayes-implicit-solvent/blob/067239fcbb8af28eb6310d702804887662692ec2/bayes_implicit_solvent/gb_models/jax_gb_models.py#L13-L60

    with corrections and refinements by Yutong Zhao here
    https://github.com/proteneer/timemachine/blob/417f4b0b1181b638935518532c78c380b03d7d19/timemachine/potentials/gbsa.py#L1-L111
    """

    N = len(charges)
    eye = torch.eye(N, dtype=distance_matrix.dtype)

    r = distance_matrix + eye
    or1 = radii.reshape((N, 1)) - dielectric_offset
    or2 = radii.reshape((1, N)) - dielectric_offset
    sr2 = scales.reshape((1, N)) * or2

    L = torch.max(or1, abs(r - sr2))
    U = r + sr2

    I = 1 / L - 1 / U + 0.25 * (r - sr2 ** 2 / r) * (
            1 / (U ** 2) - 1 / (L ** 2)) + 0.5 * torch.log(
        L / U) / r
    # handle the interior case
    I = torch.where(or1 < (sr2 - r), I + 2 * (1 / or1 - 1 / L), I)
    I = step(r + sr2 - or1) * 0.5 * I  # note the extra 0.5 here
    I -= torch.diag(torch.diag(I))
    I = torch.sum(I, dim=1)

    # okay, next compute born radii
    offset_radius = radii - dielectric_offset

    psi = I * offset_radius

    psi_coefficient = alpha
    psi2_coefficient = beta
    psi3_coefficient = gamma

    psi_term = (psi_coefficient * psi) - (psi2_coefficient * psi ** 2) + (
            psi3_coefficient * psi ** 3)

    B = 1 / (1 / offset_radius - torch.tanh(psi_term) / radii)

    E = 0.0
    # single particle
    # ACE
    E += torch.sum(
        surface_tension * (radii + probe_radius) ** 2 * (radii / B) ** 6)

    # on-diagonal
    E += torch.sum(-0.5 * (
            1 / solute_dielectric - 1 / solvent_dielectric) * charges ** 2 / B)

    # particle pair
    # note: np.outer --> torch.ger
    f = torch.sqrt(r ** 2 + torch.ger(B, B) * torch.exp(
        -r ** 2 / (4 * torch.ger(B, B))))
    charge_products = torch.ger(charges, charges)

    ixns = - (
Member

Isn't this missing a -138.935485 conversion from nm/(proton_charge**2) to kJ/mol? The docstring says "everything is in OpenMM native units".

Member

It looks like you might pre-multiply the charges by sqrt(138.935485)? If so, you should probably document that in the docstring.

Member Author

Gahh -- you're right -- I had dropped this in the current conversion! Thank you for catching this. Charges are not assumed to be premultiplied by sqrt(138.935485); will clarify the docstring...

(This conversion was present but poorly labeled in the numpy/jax implementation in bayes-implicit-solvent.)
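
For readers following this thread: the factor under discussion is the Coulomb constant e^2 N_A / (4 pi eps0) expressed in kJ nm / (mol e^2). In the jax reference implementation included in this PR it appears as the `screening` argument and multiplies both the on-diagonal and the pairwise charge terms. The sketch below is not taken from the PR; it recomputes the constant from CODATA 2018 values and applies it to a single pairwise GB term. The names `coulomb_constant_kj_nm_per_mol` and `gb_pair_energy` are hypothetical, chosen only for illustration.

```python
import math

# CODATA 2018 values
ELEMENTARY_CHARGE = 1.602176634e-19  # C
AVOGADRO = 6.02214076e23             # 1/mol
EPSILON_0 = 8.8541878128e-12         # F/m


def coulomb_constant_kj_nm_per_mol():
    """Return e^2 * N_A / (4 pi eps0) in kJ nm / (mol e^2)."""
    joule_metre = ELEMENTARY_CHARGE ** 2 / (4 * math.pi * EPSILON_0)
    # J*m per pair -> kJ*nm per mole: times N_A, m -> nm (1e9), J -> kJ (1e-3)
    return joule_metre * AVOGADRO * 1e9 / 1e3


def gb_pair_energy(q_i, q_j, r, born_i, born_j,
                   solute_dielectric=1.0, solvent_dielectric=78.5):
    """One off-diagonal GB term in kJ/mol (charges in e, lengths in nm)."""
    bb = born_i * born_j
    f = math.sqrt(r ** 2 + bb * math.exp(-r ** 2 / (4 * bb)))
    prefactor = 1 / solute_dielectric - 1 / solvent_dielectric
    return -coulomb_constant_kj_nm_per_mol() * prefactor * q_i * q_j / f


k = coulomb_constant_kj_nm_per_mol()
energy = gb_pair_energy(0.5, -0.5, r=0.3, born_i=0.15, born_j=0.2)
print(k, energy)
```

The computed constant agrees with the 138.935485 quoted in the thread to within about 1e-4; the small difference presumably reflects which edition of the physical constants was used.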

            1 / solute_dielectric - 1 / solvent_dielectric) * charge_products / f

    E += torch.sum(torch.triu(ixns, diagonal=1))
    return E  # E is in kJ/mol at this point


def gbsa_obc2_energy(
        distance_matrix_in_bohr,
        radii_in_bohr, scales, charges,
        alpha=0.8, beta=0.0, gamma=2.909125,
        dielectric_offset=0.009,
        surface_tension=28.3919551,
        solute_dielectric=1.0,
        solvent_dielectric=78.5,
        probe_radius=0.14
):
    # convert distances and radii into units of nanometers before proceeding
    distance_matrix = distance_matrix_in_bohr * distance_to_nm
    radii = radii_in_bohr * distance_to_nm

    E = _gbsa_obc2_energy_omm(
        distance_matrix,
        radii, scales, charges,
        alpha, beta, gamma,
        dielectric_offset=dielectric_offset,
        surface_tension=surface_tension,
        solute_dielectric=solute_dielectric,
        solvent_dielectric=solvent_dielectric,
        probe_radius=probe_radius,
    )

    return E * energy_from_kjmol  # return E in espaloma energy unit
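
The wrapper above converts units only at the boundary: inputs go from espaloma's distance unit to nanometers, the OpenMM-unit function runs, and the result is converted back. A minimal sketch of that pattern in plain Python follows. It assumes the distance unit is the bohr (as the argument names suggest) and, purely for illustration, that the energy unit is the hartree; the diff does not confirm the latter. `with_unit_boundary` and `toy_energy` are hypothetical names, and the conversion factors are CODATA 2018 values.

```python
BOHR_IN_NM = 0.0529177210903             # 1 bohr expressed in nm
HARTREE_IN_KJ_PER_MOL = 2625.4996394799  # 1 hartree expressed in kJ/mol


def with_unit_boundary(energy_fn_nm_kjmol):
    """Wrap a function that takes nm and returns kJ/mol so that it
    takes bohr and returns hartree (assumed units, see lead-in)."""
    def wrapped(distances_in_bohr):
        distances_in_nm = [d * BOHR_IN_NM for d in distances_in_bohr]
        energy_kjmol = energy_fn_nm_kjmol(distances_in_nm)
        return energy_kjmol / HARTREE_IN_KJ_PER_MOL
    return wrapped


def toy_energy(distances_in_nm):
    """Stand-in for the real energy function: sum of 1/r, read as kJ/mol."""
    return sum(1.0 / d for d in distances_in_nm)


wrapped_energy = with_unit_boundary(toy_energy)
value = wrapped_energy([1.0, 2.0])
print(value)
```

Note that in the PR's wrapper only `distance_matrix` and `radii` are converted; `dielectric_offset` and `probe_radius` are passed through unchanged, so callers are presumably expected to supply them in nanometers.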
69 changes: 69 additions & 0 deletions espaloma/mm/tests/test_implicit.py
@@ -0,0 +1,69 @@


import jax.numpy as np


def step(x):
    # return 1.0 where x >= 0, else 0.0
    return 1.0 * (x >= 0)


def compute_OBC_energy_bayes_implicit(distance_matrix, radii, scales, charges,
                                      offset=0.009, screening=138.935484,
                                      surface_tension=28.3919551,
                                      solvent_dielectric=78.5,
                                      solute_dielectric=1.0,
                                      ):
    """From https://github.com/openforcefield/bayes-implicit-solvent/blob/46936da65ed93ed33f0f97362a1dea12f9170758/bayes_implicit_solvent/gb_models/jax_gb_models.py

    in turn based on https://github.com/openmm/openmm/blob/master/platforms/reference/src/SimTKReference/ReferenceObc.cpp
    """
    N = len(radii)
    # print(type(distance_matrix))
    eye = np.eye(N, dtype=distance_matrix.dtype)
    # print(type(eye))
    r = distance_matrix + eye  # so I don't have divide-by-zero nonsense
    or1 = radii.reshape((N, 1)) - offset
    or2 = radii.reshape((1, N)) - offset
    sr2 = scales.reshape((1, N)) * or2

    L = np.maximum(or1, abs(r - sr2))
    U = r + sr2
    I = step(r + sr2 - or1) * 0.5 * (
            1 / L - 1 / U + 0.25 * (r - sr2 ** 2 / r) * (
            1 / (U ** 2) - 1 / (L ** 2)) + 0.5 * np.log(
        L / U) / r)

    I -= np.diag(np.diag(I))
    I = np.sum(I, axis=1)

    # okay, next compute born radii
    offset_radius = radii - offset
    psi = I * offset_radius
    psi_coefficient = 0.8
    psi2_coefficient = 0
    psi3_coefficient = 2.909125

    psi_term = (psi_coefficient * psi) + (psi2_coefficient * psi ** 2) + (
            psi3_coefficient * psi ** 3)

    B = 1 / (1 / offset_radius - np.tanh(psi_term) / radii)

    # finally, compute the three energy terms
    E = 0.0

    # single particle
    E += np.sum(surface_tension * (radii + 0.14) ** 2 * (radii / B) ** 6)
    E += np.sum(-0.5 * screening * (
            1 / solute_dielectric - 1 / solvent_dielectric) * charges ** 2 / B)

    # particle pair
    f = np.sqrt(
        r ** 2 + np.outer(B, B) * np.exp(-r ** 2 / (4 * np.outer(B, B))))
    charge_products = np.outer(charges, charges)

    E += np.sum(np.triu(-screening * (
            1 / solute_dielectric - 1 / solvent_dielectric) * charge_products / f,
                        k=1))

    return E
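
One detail that a direct comparison against this reference would not exercise: the torch port in espaloma/mm/implicit.py subtracts the psi**2 term (`- psi2_coefficient * psi ** 2`), while the reference above adds it with `psi2_coefficient` hard-coded to 0. With the default beta = 0 the two agree; they diverge for any nonzero beta. The sketch below illustrates this with a scalar Born-radius formula in plain Python. `born_radius` and its `beta_sign` argument are hypothetical, introduced only to switch between the two conventions.

```python
import math


def born_radius(psi, radius, offset=0.009,
                alpha=0.8, beta=0.0, gamma=2.909125, beta_sign=-1.0):
    """Scalar Born radius from psi.

    beta_sign=-1.0 mirrors the torch port; beta_sign=+1.0 mirrors the
    sign used in the jax reference.
    """
    offset_radius = radius - offset
    psi_term = alpha * psi + beta_sign * beta * psi ** 2 + gamma * psi ** 3
    return 1 / (1 / offset_radius - math.tanh(psi_term) / radius)


psi, radius = 0.4, 0.15

# default beta = 0: both sign conventions give the same Born radius
port_default = born_radius(psi, radius, beta_sign=-1.0)
ref_default = born_radius(psi, radius, beta_sign=+1.0)

# nonzero beta: the conventions no longer agree
port_beta = born_radius(psi, radius, beta=0.8, beta_sign=-1.0)
ref_beta = born_radius(psi, radius, beta=0.8, beta_sign=+1.0)

print(port_default, ref_default, port_beta, ref_beta)
```

A regression test that only uses the defaults would therefore pass regardless of which sign is intended; testing with a nonzero beta against an independent implementation would be needed to pin it down.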
58 changes: 36 additions & 22 deletions espaloma/redux/energy.py
@@ -6,30 +6,37 @@


 def compute_bonds(
-    xyz: torch.Tensor, params: ParameterizedSystem, indices: Indices
-) -> torch.Tensor:
+        xyz: torch.Tensor,
+        params: ParameterizedSystem,
+        indices: Indices) -> torch.Tensor:
     a, b = xyz[:, indices.bonds[:, 0]], xyz[:, indices.bonds[:, 1]]
     distance = esp.mm.geometry.distance(a, b)
     k, eq = params.bonds[:, 0], params.bonds[:, 1]
     return esp.mm.bond.harmonic_bond(distance, k, eq)


 def compute_angles(
-    xyz: torch.Tensor, params: ParameterizedSystem, indices: Indices
-) -> torch.Tensor:
-    a, b, c = (
-        xyz[:, indices.angles[:, 0]],
-        xyz[:, indices.angles[:, 1]],
-        xyz[:, indices.angles[:, 2]],
-    )
+        xyz: torch.Tensor,
+        params: ParameterizedSystem,
+        indices: Indices) -> torch.Tensor:
+    # TODO: be less verbose about this
+    a = xyz[:, indices.angles[:, 0]]
+    b = xyz[:, indices.angles[:, 1]]
+    c = xyz[:, indices.angles[:, 2]]
+
     angles = esp.mm.geometry.angle(a, b, c)
     k, eq = params.angles[:, 0], params.angles[:, 1]
     return esp.mm.angle.harmonic_angle(angles, k, eq)


 def compute_propers(
-    xyz: torch.Tensor, params: ParameterizedSystem, indices: Indices
-) -> torch.Tensor:
+        xyz: torch.Tensor,
+        params: ParameterizedSystem,
+        indices: Indices) -> torch.Tensor:
+    # it's possible there are no proper torsions in the system (e.g. h2o)
+    if len(indices.propers) == 0:
+        return torch.tensor([[0.0]])
+
     # TODO: reduce code duplication with compute_impropers
     a, b = xyz[:, indices.propers[:, 0]], xyz[:, indices.propers[:, 1]]
     c, d = xyz[:, indices.propers[:, 2]], xyz[:, indices.propers[:, 3]]
@@ -39,8 +46,13 @@ def compute_propers(


 def compute_impropers(
-    xyz: torch.Tensor, params: ParameterizedSystem, indices: Indices
-) -> torch.Tensor:
+        xyz: torch.Tensor,
+        params: ParameterizedSystem,
+        indices: Indices) -> torch.Tensor:
+    # it's possible there are no improper torsions in the system (e.g. nh4)
+    if len(indices.impropers) == 0:
+        return torch.tensor([[0.0]])
+
     # TODO: reduce code duplication with compute_propers
     a, b = xyz[:, indices.impropers[:, 0]], xyz[:, indices.impropers[:, 1]]
     c, d = xyz[:, indices.impropers[:, 2]], xyz[:, indices.impropers[:, 3]]
@@ -50,13 +62,15 @@ def compute_impropers(


 def compute_valence_energy(
-    offmol: Molecule, xyz: torch.Tensor, params: ParameterizedSystem
-) -> torch.Tensor:
+        offmol: Molecule,
+        xyz: torch.Tensor,
+        params: ParameterizedSystem) -> torch.Tensor:
     indices = Indices(offmol)
-    harmonic_terms = compute_bonds(xyz, params, indices).sum(
-        1
-    ) + compute_angles(xyz, params, indices).sum(1)
-    torsion_terms = compute_propers(xyz, params, indices).sum(
-        1
-    ) + compute_impropers(xyz, params, indices).sum(1)
-    return harmonic_terms + torsion_terms
+
+    bonds = compute_bonds(xyz, params, indices).sum(1)
+    angles = compute_angles(xyz, params, indices).sum(1)
+    propers = compute_propers(xyz, params, indices).sum(1)
+    impropers = compute_impropers(xyz, params, indices).sum(1)
+    valence_energy = bonds + angles + propers + impropers
+
+    return valence_energy
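
The empty-torsion guards in this diff return `torch.tensor([[0.0]])`, which keeps `.sum(1)` valid and then relies on broadcasting when the per-term energies are added. The sketch below, which is not part of the PR, shows the shapes involved, along with one possible alternative: summing over an empty trailing dimension also yields zeros with the per-snapshot shape. `n_snapshots` and the tensor values are illustrative assumptions.

```python
import torch

n_snapshots = 5

# stand-in for per-bond energies of shape (n_snapshots, n_bonds)
bonds = torch.ones(n_snapshots, 3).sum(1)         # shape (5,)

# placeholder returned by the guards when there are no torsions
placeholder = torch.tensor([[0.0]]).sum(1)        # shape (1,)
total = bonds + placeholder                       # broadcasts to shape (5,)

# alternative: an empty (n_snapshots, 0) tensor also sums to zeros
empty_terms = torch.zeros(n_snapshots, 0).sum(1)  # shape (5,)
total_alt = bonds + empty_terms

print(total.shape, total_alt.shape)
```

One caveat with the placeholder approach is that it creates a fresh tensor with the default dtype on the default device, so it may need adjusting if `xyz` lives on a different device.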