From b73dfc673ce0d3b945cda91236fdaa86992ebfe9 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Sat, 22 Apr 2023 08:10:13 +0100 Subject: [PATCH] pass platform through to pre-minimization --- .../openmm_rfe/_rfe_utils/multistate.py | 21 +++++++++++++------ .../protocols/openmm_rfe/equil_rfe_methods.py | 1 + 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py b/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py index cb546a8d9..a269b1b5b 100644 --- a/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py +++ b/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py @@ -36,7 +36,8 @@ def __init__(self, *args, hybrid_factory=None, **kwargs): def setup(self, reporter, lambda_protocol, temperature=298.15 * unit.kelvin, n_replicas=None, - endstates=True, minimization_steps=100): + endstates=True, minimization_steps=100, + minimization_platform="CPU"): """ Setup MultistateSampler based on the input lambda protocol and number of replicas. @@ -58,7 +59,9 @@ class creation of LambdaProtocol. Whether or not to generate unsampled endstates (i.e. dispersion correction). minimization_steps : int - Number of steps to minimize states. + Number of steps to pre-minimize states. + minimization_platform : str + Platform to do the initial pre-minimization with. Attributes ---------- @@ -114,8 +117,10 @@ class creation of LambdaProtocol. # now generating a sampler_state for each thermodyanmic state, # with relaxed positions + # Note: remove once choderalab/openmmtools#672 is completed minimize(compound_thermostate_copy, sampler_state, - max_iterations=minimization_steps) + max_iterations=minimization_steps, + platform_name=minimization_platform) sampler_state_list.append(copy.deepcopy(sampler_state)) del compound_thermostate, sampler_state @@ -257,8 +262,10 @@ def create_endstates(first_thermostate, last_thermostate): return unsampled_endstates -def minimize(thermodynamic_state: states.ThermodynamicState, sampler_state: states.SamplerState, - max_iterations: int=100) -> states.SamplerState: +def minimize(thermodynamic_state: states.ThermodynamicState, + sampler_state: states.SamplerState, + max_iterations: int=100, + platform_name: str="CPU") -> states.SamplerState: """ Adapted from perses.dispersed.feptasks.minimize @@ -273,6 +280,8 @@ def minimize(thermodynamic_state: states.ThermodynamicState, sampler_state: stat The starting state at which to minimize the system. max_iterations : int, optional, default 100 The maximum number of minimization steps. Default is 100. + platform_name : str + The OpenMM platform name to carry out the minimization with. Returns ------- @@ -281,7 +290,7 @@ def minimize(thermodynamic_state: states.ThermodynamicState, sampler_state: stat """ # we won't take any steps, so use a simple integrator integrator = openmm.VerletIntegrator(1.0) - platform = openmm.Platform.getPlatformByName('CPU') + platform = openmm.Platform.getPlatformByName(platform_name) dummy_cache = cache.DummyContextCache(platform=platform) context, integrator = dummy_cache.get_context( thermodynamic_state, integrator diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 148ed422c..3ebf99251 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -628,6 +628,7 @@ def run(self, *, dry=False, verbose=True, lambda_protocol=lambdas, temperature=to_openmm(protocol_settings.thermo_settings.temperature), endstates=alchem_settings.unsampled_endstates, + minimization_platform=platform.getName(), ) try: