Skip to content

Commit

Permalink
pass platform through to pre-minimization
Browse files Browse the repository at this point in the history
  • Loading branch information
IAlibay committed Apr 22, 2023
1 parent 9abe3c9 commit b73dfc6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
21 changes: 15 additions & 6 deletions openfe/protocols/openmm_rfe/_rfe_utils/multistate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def __init__(self, *args, hybrid_factory=None, **kwargs):

def setup(self, reporter, lambda_protocol,
temperature=298.15 * unit.kelvin, n_replicas=None,
endstates=True, minimization_steps=100):
endstates=True, minimization_steps=100,
minimization_platform="CPU"):
"""
Setup MultistateSampler based on the input lambda protocol and number
of replicas.
Expand All @@ -58,7 +59,9 @@ class creation of LambdaProtocol.
Whether or not to generate unsampled endstates (i.e. dispersion
correction).
minimization_steps : int
Number of steps to minimize states.
Number of steps to pre-minimize states.
minimization_platform : str
Platform to do the initial pre-minimization with.
Attributes
----------
Expand Down Expand Up @@ -114,8 +117,10 @@ class creation of LambdaProtocol.

# now generating a sampler_state for each thermodyanmic state,
# with relaxed positions
# Note: remove once choderalab/openmmtools#672 is completed
minimize(compound_thermostate_copy, sampler_state,
max_iterations=minimization_steps)
max_iterations=minimization_steps,
platform_name=minimization_platform)
sampler_state_list.append(copy.deepcopy(sampler_state))

del compound_thermostate, sampler_state
Expand Down Expand Up @@ -257,8 +262,10 @@ def create_endstates(first_thermostate, last_thermostate):
return unsampled_endstates


def minimize(thermodynamic_state: states.ThermodynamicState, sampler_state: states.SamplerState,
max_iterations: int=100) -> states.SamplerState:
def minimize(thermodynamic_state: states.ThermodynamicState,
sampler_state: states.SamplerState,
max_iterations: int=100,
platform_name: str="CPU") -> states.SamplerState:
"""
Adapted from perses.dispersed.feptasks.minimize
Expand All @@ -273,6 +280,8 @@ def minimize(thermodynamic_state: states.ThermodynamicState, sampler_state: stat
The starting state at which to minimize the system.
max_iterations : int, optional, default 100
The maximum number of minimization steps. Default is 100.
platform_name : str
The OpenMM platform name to carry out the minimization with.
Returns
-------
Expand All @@ -281,7 +290,7 @@ def minimize(thermodynamic_state: states.ThermodynamicState, sampler_state: stat
"""
# we won't take any steps, so use a simple integrator
integrator = openmm.VerletIntegrator(1.0)
platform = openmm.Platform.getPlatformByName('CPU')
platform = openmm.Platform.getPlatformByName(platform_name)
dummy_cache = cache.DummyContextCache(platform=platform)
context, integrator = dummy_cache.get_context(
thermodynamic_state, integrator
Expand Down
1 change: 1 addition & 0 deletions openfe/protocols/openmm_rfe/equil_rfe_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,7 @@ def run(self, *, dry=False, verbose=True,
lambda_protocol=lambdas,
temperature=to_openmm(protocol_settings.thermo_settings.temperature),
endstates=alchem_settings.unsampled_endstates,
minimization_platform=platform.getName(),
)

try:
Expand Down

0 comments on commit b73dfc6

Please sign in to comment.