[WIP] Free energy fitting #54

Draft: wants to merge 49 commits into main
Conversation

maxentile (Member)

Translating numerical demonstrations from https://github.com/openforcefield/bayes-implicit-solvent#differentiable-atom-typing-experiments , upgrading to use message-passing rather than fingerprints + feedforward model.

Hiccup: porting the autodiff-friendly implementation of GBSA OBC energy from Jax to PyTorch wasn't as simple as replacing np. with torch. -- I need to track down a likely unit bug I introduced during the conversion, and pass an OpenMM consistency assertion, before merging.

maxentile and others added 22 commits October 22, 2020 15:03
… to 79 character lines

oops, my IDE had been on 120 character lines this whole time -- switching to 79 character lines so future black passes don't turn things into quadruply nested messes
…with a graph-net in the loop

it's fishy that the initial hydration free energy prediction is so poor

I suspect I may have made a unit mistake in my numpy/jax --> pytorch port
…another function that converts to espaloma unit system
…ssertions

Thanks to @yuanqing-wang for carefully stepping through this with me

Co-Authored-By: Yuanqing Wang <[email protected]>
the line to compute the reduced work was written as if "solv_energies" was "valence_energies + gbsa_energies" but of course it was just "gbsa_energies"...
in column `quick_xyz` -- will shortly replace this with a column `xyz` with more thorough parsley 1.2 vacuum sampling
# pairwise effective interaction distance f_ij (Still et al. GB equation)
f = torch.sqrt(r ** 2 + torch.ger(B, B) * torch.exp(
    -r ** 2 / (4 * torch.ger(B, B))))
charge_products = torch.ger(charges, charges)
assert f.shape == (N, N)
assert charge_products.shape == (N, N)

ixns = - (
Member

Isn't this missing a -138.935485 conversion from nm/(proton_charge**2) to kJ/mol? The docstring says "everything is in OpenMM native units".

Member

It looks like you might pre-multiply the charges by sqrt(138.935485)? If so, you should probably document that in the docstring.

maxentile (Member Author)

Gahh -- you're right -- I had dropped this in the current conversion! Thank you for catching this. Charges are not assumed to be premultiplied by sqrt(138.935485); will clarify the docstring...

(This conversion was present but poorly labeled in the numpy/jax implementation in bayes-implicit-solvent.)
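For reference, a minimal sketch of where the conversion belongs, assuming charges in units of the proton charge and distances in nm (variable names follow the snippet above; the solvent-dielectric prefactor of the full GB expression is omitted here):

```python
# Coulomb constant in kJ/mol * nm / e^2, the value used throughout this thread.
COULOMB_CONSTANT = 138.935485

# With charges NOT premultiplied by sqrt(138.935485), the pairwise interaction
# term must carry the conversion explicitly to yield energies in kJ/mol:
ixns = -COULOMB_CONSTANT * charge_products / f
```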

maxentile and others added 4 commits October 30, 2020 11:55
* confirmed that it can overfit to a small subset of FreeSolv! 🎉
* RMSE on the whole of FreeSolv hasn't yet matched the quality of OBC2 🙏
…ze: validation set RMSE 1.8 kcal/mol

* increased step size: 1e-3 rather than 1e-4
* decreased layer and node dimensions from 128 to 32
@maxentile (Member Author)

Incorporating the missing unit conversion John identified, this now appears to be passing integration checks in the demo notebook. A graph-net is used to emit (n_atoms, 2) per-particle GBSA parameters given an input molecular graph. These per-particle parameters are then passed into a PyTorch GBSA implementation (along with cached vacuum samples) to produce one-sided EXP estimates of hydration free energy. A loss function is defined in terms of the estimated vs. experimental hydration free energies, differentiated w.r.t. the graph-net parameters, and optimized using Adam.
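For concreteness, a minimal sketch of one training step under this description. All names here (`graph_net`, `gbsa_energy`, the cached snapshots and vacuum energies) are hypothetical stand-ins for the notebook's actual objects, and reduced (kT) units are assumed:

```python
import torch

def exp_hydration_free_energy(u_solv, u_vac):
    """One-sided EXP (Zwanzig) estimate: -ln < exp(-(u_solv - u_vac)) >_vacuum."""
    delta_u = u_solv - u_vac
    n_snapshots = torch.tensor(float(len(delta_u)))
    return -(torch.logsumexp(-delta_u, dim=0) - torch.log(n_snapshots))

optimizer = torch.optim.Adam(graph_net.parameters(), lr=1e-3)

def training_step(mols, snapshots, vacuum_energies, expt_delta_gs):
    """MSE loss between EXP-estimated and experimental hydration free energies,
    differentiated w.r.t. the graph-net parameters and stepped with Adam."""
    loss = 0.0
    for mol, xyz, u_vac, expt in zip(mols, snapshots, vacuum_energies, expt_delta_gs):
        params = graph_net(mol)                    # (n_atoms, 2) GBSA parameters
        u_solv = u_vac + gbsa_energy(params, xyz)  # add the GB solvation term
        pred = exp_hydration_free_energy(u_solv, u_vac)
        loss = loss + (pred - expt) ** 2
    loss = loss / len(mols)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return float(loss)
```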

Can this procedure overfit a GBSA-parameter-emitting graph-net to a small random subsample of FreeSolv (N=10)? Yes:
[figure: predicted vs. experimental hydration free energies, N=10 overfitting run]

Can this procedure fit a graph-net to a random half of FreeSolv (N=321) and generalize to the other half (N=321)? Tentatively yes:
[figure: training vs. validation performance for the 50/50 FreeSolv split]

A few more important refinements and unit tests are needed before this is ready for final review and merge, but this is now significantly less fishy than it was yesterday.

+ a few formatting and documentation enhancements
* define random seed once at top of file rather than before each step
* remove verbose flag
* use same learning rate, n_iterations, n_mols_per_batch, n_snapshots_per_mol for both trajectories
* add training / validation curves for early-stopping
* add bootstrapped rmses to final scatterplots
@maxentile (Member Author)

To address the concern about elements that appear only a handful of times in FreeSolv, see this notebook, which counts the number of molecules in FreeSolv containing each element.

element    # of molecules containing it
C          639
H          629
O          344
N          169
Cl         114
S          40
F          35
Br         25
P          14
I          12

A related question is: if we filter the molecules to retain only certain subsets of elements, how many molecules do we retain?

Enumerating one sequence of element subsets (including elements in descending order of "popularity"):

elements                                 # molecules     coverage
{C, H}                                   103             16.0%
{C, H, O}                                300             46.7%
{C, H, N, O}                             431             67.1%
{C, Cl, H, N, O}                         529             82.4%
{C, Cl, H, N, O, S}                      559             87.1%
{C, Cl, F, H, N, O, S}                   591             92.1%
{Br, C, Cl, F, H, N, O, S}               616             96.0%
{Br, C, Cl, F, H, N, O, P, S}            630             98.1%
{Br, C, Cl, F, H, I, N, O, P, S}         642             100.0%
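For reference, one way a table like this could be generated (a sketch, not the notebook's actual code; `freesolv_smiles` is a hypothetical list of FreeSolv SMILES strings):

```python
from rdkit import Chem

def element_set(smiles):
    """Set of element symbols in a molecule, including implicit hydrogens."""
    mol = Chem.AddHs(Chem.MolFromSmiles(smiles))
    return {atom.GetSymbol() for atom in mol.GetAtoms()}

def coverage(smiles_list, allowed_elements):
    """Count and fraction of molecules containing only the allowed elements."""
    kept = [s for s in smiles_list if element_set(s) <= set(allowed_elements)]
    return len(kept), len(kept) / len(smiles_list)

# e.g. n_kept, frac = coverage(freesolv_smiles, {"C", "H", "O"})
```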

A not-so-challenging subset of FreeSolv -- that should be free of the infrequently-occurring-element concern -- is the collection of molecules containing only {C, H, O}. This demo notebook fits a GB-parameter-emitting graph-net on this set in about 40 CPU minutes.

Training and validation RMSE are reported every epoch for this "mini-FreeSolv" subset:
[figure: training/validation RMSE per epoch on the {C, H, O} subset]

The same plot, zoomed in on the y range 0.5-2.5 kcal/mol:
[figure: zoomed training/validation RMSE curves]

In this run, the lowest validation-set RMSE happened to be encountered at the very last epoch, but that wouldn't be expected in general, due to noise in the gradient estimates (and especially if run longer).

Plotting predicted vs. reference scatter plots for training and validation sets at that last epoch (labeled with RMSE +/- 95% bootstrapped CI):
[figure: predicted vs. reference scatter plots for training and validation sets, labeled with RMSE +/- 95% bootstrapped CI]
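A minimal sketch of the bootstrapped-CI computation behind labels like these, assuming a simple percentile bootstrap (not necessarily the notebook's exact procedure):

```python
import numpy as np

def bootstrapped_rmse_ci(predicted, reference, n_bootstrap=1000, ci=0.95):
    """RMSE with a percentile-bootstrap confidence interval."""
    predicted, reference = np.asarray(predicted), np.asarray(reference)
    n = len(predicted)
    rmses = []
    for _ in range(n_bootstrap):
        idx = np.random.randint(0, n, size=n)  # resample with replacement
        rmses.append(np.sqrt(np.mean((predicted[idx] - reference[idx]) ** 2)))
    lower, upper = np.percentile(rmses, [100 * (1 - ci) / 2, 100 * (1 + ci) / 2])
    rmse = np.sqrt(np.mean((predicted - reference) ** 2))
    return rmse, (lower, upper)
```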

Similar plots could easily be generated for every other "mini-FreeSolv" subset enumerated above. If there's an apparent difference between the more restricted vs. the more complete subsets, that might be suggestive of difficulty arising from sparsely sampled elements / chemical environments.

run 10 ns of MD per molecule (rather than the measly 0.01 ns per molecule in 5866029)
@jchodera (Member)

Looks great!

How about we run with this for the next bioRxiv update (and thesis) and revisit compound splitting on a larger set (maybe including N and Cl) in the next iteration (after thesis submission)?

@maxentile (Member Author)

> Looks great!

Thanks!

> How about we run with this for the next bioRxiv update (and thesis) and revisit compound splitting on a larger set (maybe including N and Cl) in the next iteration (after thesis submission)?

Sounds good -- compound splitting is subtle and not the primary focus of this demonstration.

Because it was convenient (change one line, wait 30 minutes), I re-ran the notebook on the {C, H, O, N, Cl} FreeSolv subset (n=529) to get a preview.

Noting one observation for when we return to this:

In this run, the validation loss increased for ~10 epochs before decreasing again. Early-stopping requires the user to pre-specify a "patience" parameter (how many iterations without improvement to tolerate before stopping), and this example suggests it might be better to choose a "patience" >= 10 epochs. Will sync with @yuanqing-wang about how this patience parameter is currently selected.
[figure: training/validation RMSE per epoch on the {C, H, O, N, Cl} subset, with validation loss rising for ~10 epochs before decreasing]
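A minimal sketch of early stopping with a patience parameter, to make the trade-off concrete (`epoch_fn`, `validate_fn`, and `snapshot_parameters` are hypothetical stand-ins for the notebook's training and validation steps):

```python
def train_with_early_stopping(n_epochs, patience=10):
    """Stop once validation RMSE hasn't improved for `patience` epochs in a row."""
    best_rmse, best_state, epochs_without_improvement = float("inf"), None, 0
    for epoch in range(n_epochs):
        epoch_fn()                    # one epoch of Adam updates
        val_rmse = validate_fn()      # RMSE on the held-out split
        if val_rmse < best_rmse:
            best_rmse, epochs_without_improvement = val_rmse, 0
            best_state = snapshot_parameters()  # keep the best model so far
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break                 # tolerate at most `patience` stale epochs
    return best_state, best_rmse
```

With patience < 10, a run like the one above would have stopped during the transient rise in validation loss, before the later improvement.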

using more thorough vacuum MD, specified here 8e50eec
anecdotally, this appears to increase the training-set vs. validation-set error gap, suggesting that insufficient equilibrium sampling might have made the validation-set performance reported in #54 (comment) look more favorable than it should!
@maxentile (Member Author)

To home in on the version of these results that will be reported in the bioRxiv update (and thesis), I re-ran the notebook from #54 (comment) on the updated equilibrium snapshots cached from more thorough vacuum MD.

These updated results should supersede the earlier results.

Anecdotally (based on one run with snapshots from short MD vs. one run with snapshots from thorough MD), this update appears to have increased the training-set vs. validation-set error gap, suggesting that insufficient equilibrium sampling might have made the validation-set performance reported in #54 (comment) look more favorable than it should.

[figure: updated training/validation results using snapshots from thorough vacuum MD]

@jchodera (Member)

Interesting finding! But I agree that this behavior is much closer to what I would have expected from training vs. validation error.
Let's run with this for the bioRxiv/thesis update!

@maxentile (Member Author)

Repeated this, but with 10x longer optimization trajectories, and with KFold 90%/10% splitting rather than a single 50%/50% split.

Reporting final training and validation set performance for each of the 10 splits.

[figures: final training- and validation-set performance for each of the 10 folds]
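A minimal sketch of this protocol using scikit-learn's KFold (`molecules` and `train_and_evaluate` are hypothetical stand-ins for the notebook's dataset and fit-and-report step):

```python
from sklearn.model_selection import KFold

kfold = KFold(n_splits=10, shuffle=True, random_state=0)  # 90%/10% splits
fold_results = []
for train_idx, val_idx in kfold.split(molecules):
    train_rmse, val_rmse = train_and_evaluate(
        [molecules[i] for i in train_idx],
        [molecules[i] for i in val_idx],
    )
    fold_results.append((train_rmse, val_rmse))  # one point per fold
```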

@jchodera (Member) commented Nov 2, 2020

Is this train/validate/test with early stopping, or is it just train/validate with 10% of the dataset split out and no early stopping (with cross-validation over the 10% held-out sets intended to be representative of the test set error)?

Are we concerned at all with the experimental strategies being vastly different between the different experiments in the paper for no particular reason?

@maxentile (Member Author)

> Is this train/validate/test with early stopping

No early stopping.

> is it just train/validate with 10% of the dataset split out and no early stopping (with cross-validation over the 10% held-out sets intended to be representative of the test set error)?

Correct. I'm not aiming to do any hyperparameter selection informed by this experiment, just aiming to report on the repeatability / variability of the training procedure if the dataset were slightly different, and to report an estimate of the generalization error of this specific procedure on the chemistry represented by this specific dataset.

In the previous plot, I showed just a single 50/50 split. Would that plot look different if the random seed were different? The way to measure that is to repeat the procedure multiple times with different random splits and report all of the results. The ideal would be to approach leave-one-out (run the procedure 300 times, once on each of the 300 subsets of size n - 1 = 299). K-fold is a common compromise.

> Are we concerned at all with the experimental strategies being vastly different between the different experiments in the paper for no particular reason?

John, I think the different approaches in progress partly reflect differing goals -- here I'm picking a single hyperparameter choice, and aiming to report on the variability / repeatability of the training procedure.

The valence-fitting experiments, I think, are still highly sensitive to various hyperparameter choices, and the goal of those ongoing experiments is still to select good hyperparameters.

Experiments constructed to simultaneously select hyperparameters and estimate the generalization error once hyperparameters are selected must take care to do nested cross-validation or use a held-out test set that is only ever consulted once.

@jchodera (Member) commented Nov 2, 2020

Thanks for the clear explanations! Let's make sure the experimental section describes the motivation and conception of this design, both in the presentation of results and Detailed Methods! Those subtleties will be lost on the reader unless we make them explicit.

…alculations vs. experiment on the {C, H, O} subset
@maxentile (Member Author)

Noting here a few more to-dos (of undecided priority level) that I think would help shore up and contextualize these results:

  • Compare to the baseline of optimizing only the 6 continuous parameters in this model with a fixed (elemental) atom-typing scheme, to measure how much improvement is due to learned "chemical perception" vs. just learned parameters. So far, the results in this PR do not address this question, but only demonstrate feasibility of performing the optimization with learned "chemical perception" in the loop.
  • Compare to the baseline of using the graph model to predict the hydration free energy directly from the chemical graph (rather than predicting physical simulation parameters that in turn imply a hydration free energy), as @yuanqing-wang suggested to me on Monday. So far, the results in this PR do not address the question of whether incorporating a physical model in this task improves predictive performance relative to the fully "black-box" version of the approach.
  • Incorporate MBAR reweighting rather than forward-EXP reweighting (as in https://gist.github.com/maxentile/1568531f2f39b5a84e263a1ab8d963b5#file-sample_parameters_using_autograd_and_pymbar-py-L197-L260, neutromeratio, and openff-evaluator) to reduce concern about reweighting estimator reliability. Further, store and report asymptotic uncertainty of the reweighting estimator to confirm its reliability. So far, the results in this PR assume (plausibly, based on prior experience) that the forward-EXP estimator is reliable for this specific task, but this assumption should either be avoided or its validity quantified. (Additionally, using MBAR in place of EXP should reduce the gradient estimator variance, which may have an impact on the behavior of the stochastic optimizer.) A sketch of the EXP estimate with its asymptotic uncertainty follows this list.
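A minimal sketch of the one-sided EXP estimate with a first-order (delta-method) asymptotic standard error, in reduced units (pymbar provides equivalent estimators; this standalone version just illustrates the quantity to store and report):

```python
import numpy as np

def exp_estimate_with_uncertainty(delta_u):
    """One-sided EXP estimate of a reduced free energy difference, with a
    first-order (delta-method) asymptotic standard error.
    delta_u: reduced potential differences (u_solv - u_vac) over N snapshots."""
    delta_u = np.asarray(delta_u)
    boltzmann = np.exp(-delta_u)
    mean_b = boltzmann.mean()
    delta_f = -np.log(mean_b)
    # Var(-ln x_bar) ~= Var(x) / (N * x_bar^2) to first order
    stderr = boltzmann.std(ddof=1) / (np.sqrt(len(boltzmann)) * mean_b)
    return delta_f, stderr
```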

Observations:
* n=10 overfitting seems to achieve a lower error than previously
* 50/50 train/validate seems to initialize and optimize at a higher error than in the first version