Skip to content

Commit

Permalink
Merge pull request #354 from dotsdl/repex-context-cleanup
Browse files Browse the repository at this point in the history
Cleanup contexts upon calculation completion, failure
  • Loading branch information
IAlibay authored Apr 22, 2023
2 parents 48a3214 + b73dfc6 commit 793b817
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 80 deletions.
53 changes: 32 additions & 21 deletions openfe/protocols/openmm_rfe/_rfe_utils/multistate.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ def __init__(self, *args, hybrid_factory=None, **kwargs):
self._hybrid_factory = hybrid_factory
super(HybridCompatibilityMixin, self).__init__(*args, **kwargs)

def setup(self, reporter, platform, lambda_protocol,
def setup(self, reporter, lambda_protocol,
temperature=298.15 * unit.kelvin, n_replicas=None,
endstates=True, minimization_steps=100):
endstates=True, minimization_steps=100,
minimization_platform="CPU"):
"""
Setup MultistateSampler based on the input lambda protocol and number
of replicas.
Expand All @@ -45,8 +46,6 @@ def setup(self, reporter, platform, lambda_protocol,
----------
reporter : OpenMM reporter
Simulation reporter to attach to each simulation replica.
platform : openmm.Platform
Platform to perform simulation on.
lambda_protocol : LambdaProtocol
The lambda protocol to be used for simulation. Default to a default
class creation of LambdaProtocol.
Expand All @@ -60,7 +59,9 @@ class creation of LambdaProtocol.
Whether or not to generate unsampled endstates (i.e. dispersion
correction).
minimization_steps : int
Number of steps to minimize states.
Number of steps to pre-minimize states.
minimization_platform : str
Platform to do the initial pre-minimization with.
Attributes
----------
Expand All @@ -84,8 +85,6 @@ class creation of LambdaProtocol.
thermodynamic_state_list = []
sampler_state_list = []

context_cache = cache.ContextCache(platform)

if n_replicas is None:
msg = (f"setting number of replicas to number of states: {n_states}")
warnings.warn(msg)
Expand Down Expand Up @@ -118,12 +117,14 @@ class creation of LambdaProtocol.

# now generating a sampler_state for each thermodyanmic state,
# with relaxed positions
context, context_integrator = context_cache.get_context(
compound_thermostate_copy)
# Note: remove once choderalab/openmmtools#672 is completed
minimize(compound_thermostate_copy, sampler_state,
max_iterations=minimization_steps)
max_iterations=minimization_steps,
platform_name=minimization_platform)
sampler_state_list.append(copy.deepcopy(sampler_state))

del compound_thermostate, sampler_state

# making sure number of sampler states equals n_replicas
if len(sampler_state_list) != n_replicas:
# picking roughly evenly spaced sampler states
Expand Down Expand Up @@ -261,8 +262,10 @@ def create_endstates(first_thermostate, last_thermostate):
return unsampled_endstates


def minimize(thermodynamic_state: states.ThermodynamicState, sampler_state: states.SamplerState,
max_iterations: int=100) -> states.SamplerState:
def minimize(thermodynamic_state: states.ThermodynamicState,
sampler_state: states.SamplerState,
max_iterations: int=100,
platform_name: str="CPU") -> states.SamplerState:
"""
Adapted from perses.dispersed.feptasks.minimize
Expand All @@ -277,20 +280,28 @@ def minimize(thermodynamic_state: states.ThermodynamicState, sampler_state: stat
The starting state at which to minimize the system.
max_iterations : int, optional, default 100
The maximum number of minimization steps. Default is 100.
platform_name : str
The OpenMM platform name to carry out the minimization with.
Returns
-------
sampler_state : openmmtools.states.SamplerState
The posititions and accompanying state following minimization
"""
integrator = openmm.VerletIntegrator(1.0) #we won't take any steps, so use a simple integrator
context, integrator = cache.global_context_cache.get_context(
# we won't take any steps, so use a simple integrator
integrator = openmm.VerletIntegrator(1.0)
platform = openmm.Platform.getPlatformByName(platform_name)
dummy_cache = cache.DummyContextCache(platform=platform)
context, integrator = dummy_cache.get_context(
thermodynamic_state, integrator
)
sampler_state.apply_to_context(
context, ignore_velocities=True
)
openmm.LocalEnergyMinimizer.minimize(
context, maxIterations=max_iterations
)
sampler_state.update_from_context(context)
try:
sampler_state.apply_to_context(
context, ignore_velocities=True
)
openmm.LocalEnergyMinimizer.minimize(
context, maxIterations=max_iterations
)
sampler_state.update_from_context(context)
finally:
del context, integrator, dummy_cache
140 changes: 81 additions & 59 deletions openfe/protocols/openmm_rfe/equil_rfe_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,20 +323,20 @@ def run(self, *, dry=False, verbose=True,

# a. check timestep correctness + that
# equilibration & production are divisible by n_steps
prototol_settings: RelativeHybridTopologyProtocolSettings = self._inputs['settings']
protocol_settings: RelativeHybridTopologyProtocolSettings = self._inputs['settings']
stateA = self._inputs['stateA']
stateB = self._inputs['stateB']
mapping = self._inputs['ligandmapping']

forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = prototol_settings.forcefield_settings
thermo_settings: settings.ThermoSettings = prototol_settings.thermo_settings
alchem_settings: AlchemicalSettings = prototol_settings.alchemical_settings
system_settings: SystemSettings = prototol_settings.system_settings
solvation_settings: SolvationSettings = prototol_settings.solvation_settings
sampler_settings: AlchemicalSamplerSettings = prototol_settings.alchemical_sampler_settings
sim_settings: SimulationSettings = prototol_settings.simulation_settings
timestep = prototol_settings.integrator_settings.timestep
mc_steps = prototol_settings.integrator_settings.n_steps.m
forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = protocol_settings.forcefield_settings
thermo_settings: settings.ThermoSettings = protocol_settings.thermo_settings
alchem_settings: AlchemicalSettings = protocol_settings.alchemical_settings
system_settings: SystemSettings = protocol_settings.system_settings
solvation_settings: SolvationSettings = protocol_settings.solvation_settings
sampler_settings: AlchemicalSamplerSettings = protocol_settings.alchemical_sampler_settings
sim_settings: SimulationSettings = protocol_settings.simulation_settings
timestep = protocol_settings.integrator_settings.timestep
mc_steps = protocol_settings.integrator_settings.n_steps.m

# is the timestep good for the mass?
if forcefield_settings.hydrogen_mass < 3.0:
Expand Down Expand Up @@ -536,9 +536,9 @@ def run(self, *, dry=False, verbose=True,
if 'solvent' in stateA.components:
hybrid_factory.hybrid_system.addForce(
openmm.MonteCarloBarostat(
prototol_settings.thermo_settings.pressure.to(unit.bar).m,
prototol_settings.thermo_settings.temperature.m,
prototol_settings.integrator_settings.barostat_frequency.m,
protocol_settings.thermo_settings.pressure.to(unit.bar).m,
protocol_settings.thermo_settings.temperature.m,
protocol_settings.integrator_settings.barostat_frequency.m,
)
)

Expand Down Expand Up @@ -572,24 +572,14 @@ def run(self, *, dry=False, verbose=True,
checkpoint_storage=shared_basepath / sim_settings.checkpoint_storage,
)

# 10. Get platform and context caches
# 10. Get platform
platform = _rfe_utils.compute.get_openmm_platform(
prototol_settings.engine_settings.compute_platform
)

# a. Create context caches (energy + sampler)
# Note: these needs to exist on the compute node
energy_context_cache = openmmtools.cache.ContextCache(
capacity=None, time_to_live=None, platform=platform,
)

sampler_context_cache = openmmtools.cache.ContextCache(
capacity=None, time_to_live=None, platform=platform,
protocol_settings.engine_settings.compute_platform
)

# 11. Set the integrator
# a. get integrator settings
integrator_settings = prototol_settings.integrator_settings
integrator_settings = protocol_settings.integrator_settings

# b. create langevin integrator
integrator = openmmtools.mcmc.LangevinSplittingDynamicsMove(
Expand Down Expand Up @@ -635,59 +625,91 @@ def run(self, *, dry=False, verbose=True,
sampler.setup(
n_replicas=sampler_settings.n_replicas,
reporter=reporter,
platform=platform,
lambda_protocol=lambdas,
temperature=to_openmm(prototol_settings.thermo_settings.temperature),
temperature=to_openmm(protocol_settings.thermo_settings.temperature),
endstates=alchem_settings.unsampled_endstates,
minimization_platform=platform.getName(),
)

sampler.energy_context_cache = energy_context_cache
sampler.sampler_context_cache = sampler_context_cache
try:
# Create context caches (energy + sampler)
energy_context_cache = openmmtools.cache.ContextCache(
capacity=None, time_to_live=None, platform=platform,
)

if not dry: # pragma: no-cover
# minimize
if verbose:
logger.info("minimizing systems")
sampler_context_cache = openmmtools.cache.ContextCache(
capacity=None, time_to_live=None, platform=platform,
)

sampler.minimize(max_iterations=sim_settings.minimization_steps)
sampler.energy_context_cache = energy_context_cache
sampler.sampler_context_cache = sampler_context_cache

# equilibrate
if verbose:
logger.info("equilibrating systems")
if not dry: # pragma: no-cover
# minimize
if verbose:
logger.info("minimizing systems")

sampler.equilibrate(int(equil_steps.m / mc_steps)) # type: ignore
sampler.minimize(max_iterations=sim_settings.minimization_steps)

# production
if verbose:
logger.info("running production phase")
# equilibrate
if verbose:
logger.info("equilibrating systems")

sampler.extend(int(prod_steps.m / mc_steps)) # type: ignore
sampler.equilibrate(int(equil_steps.m / mc_steps)) # type: ignore

# calculate estimate of results from this individual unit
ana = multistate.MultiStateSamplerAnalyzer(reporter)
est, _ = ana.get_free_energy()
est = (est[0, -1] * ana.kT).in_units_of(omm_unit.kilocalories_per_mole)
est = ensure_quantity(est, 'openff')
# production
if verbose:
logger.info("running production phase")

# close reporter when you're done
sampler.extend(int(prod_steps.m / mc_steps)) # type: ignore

# calculate estimate of results from this individual unit
ana = multistate.MultiStateSamplerAnalyzer(reporter)
est, _ = ana.get_free_energy()
est = (est[0, -1] * ana.kT).in_units_of(omm_unit.kilocalories_per_mole)
est = ensure_quantity(est, 'openff')

nc = shared_basepath / sim_settings.output_filename
chk = shared_basepath / sim_settings.checkpoint_storage
else:
# clean up the reporter file
fns = [shared_basepath / sim_settings.output_filename,
shared_basepath / sim_settings.checkpoint_storage]
for fn in fns:
os.remove(fn)
finally:
# close reporter when you're done, prevent file handle clashes
reporter.close()
del reporter

# clean up the analyzer
if not dry:
ana.clear()
del ana

# clear GPU contexts
# TODO: use cache.empty() calls when openmmtools #690 is resolved
# replace with above
for context in list(energy_context_cache._lru._data.keys()):
del energy_context_cache._lru._data[context]
for context in list(sampler_context_cache._lru._data.keys()):
del sampler_context_cache._lru._data[context]
# cautiously clear out the global context cache too
for context in list(
openmmtools.cache.global_context_cache._lru._data.keys()):
del openmmtools.cache.global_context_cache._lru._data[context]

del sampler_context_cache, energy_context_cache
if not dry:
del integrator, sampler

nc = shared_basepath / sim_settings.output_filename
chk = shared_basepath / sim_settings.checkpoint_storage
if not dry: # pragma: no-cover
return {
'nc': nc,
'last_checkpoint': chk,
'unit_estimate': est,
}
else:
# close reporter when you're done, prevent file handle clashes
reporter.close()

# clean up the reporter file
fns = [shared_basepath / sim_settings.output_filename,
shared_basepath / sim_settings.checkpoint_storage]
for fn in fns:
os.remove(fn)
return {'debug': {'sampler': sampler}}

def _execute(
Expand Down

0 comments on commit 793b817

Please sign in to comment.