diff --git a/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py b/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py index efcf85dad..a269b1b5b 100644 --- a/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py +++ b/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py @@ -34,9 +34,10 @@ def __init__(self, *args, hybrid_factory=None, **kwargs): self._hybrid_factory = hybrid_factory super(HybridCompatibilityMixin, self).__init__(*args, **kwargs) - def setup(self, reporter, platform, lambda_protocol, + def setup(self, reporter, lambda_protocol, temperature=298.15 * unit.kelvin, n_replicas=None, - endstates=True, minimization_steps=100): + endstates=True, minimization_steps=100, + minimization_platform="CPU"): """ Setup MultistateSampler based on the input lambda protocol and number of replicas. @@ -45,8 +46,6 @@ def setup(self, reporter, platform, lambda_protocol, ---------- reporter : OpenMM reporter Simulation reporter to attach to each simulation replica. - platform : openmm.Platform - Platform to perform simulation on. lambda_protocol : LambdaProtocol The lambda protocol to be used for simulation. Default to a default class creation of LambdaProtocol. @@ -60,7 +59,9 @@ class creation of LambdaProtocol. Whether or not to generate unsampled endstates (i.e. dispersion correction). minimization_steps : int - Number of steps to minimize states. + Number of steps to pre-minimize states. + minimization_platform : str + Platform to do the initial pre-minimization with. Attributes ---------- @@ -84,8 +85,6 @@ class creation of LambdaProtocol. thermodynamic_state_list = [] sampler_state_list = [] - context_cache = cache.ContextCache(platform) - if n_replicas is None: msg = (f"setting number of replicas to number of states: {n_states}") warnings.warn(msg) @@ -118,12 +117,14 @@ class creation of LambdaProtocol. # now generating a sampler_state for each thermodyanmic state, # with relaxed positions - context, context_integrator = context_cache.get_context( - compound_thermostate_copy) + # Note: remove once choderalab/openmmtools#672 is completed minimize(compound_thermostate_copy, sampler_state, - max_iterations=minimization_steps) + max_iterations=minimization_steps, + platform_name=minimization_platform) sampler_state_list.append(copy.deepcopy(sampler_state)) + del compound_thermostate, sampler_state + # making sure number of sampler states equals n_replicas if len(sampler_state_list) != n_replicas: # picking roughly evenly spaced sampler states @@ -261,8 +262,10 @@ def create_endstates(first_thermostate, last_thermostate): return unsampled_endstates -def minimize(thermodynamic_state: states.ThermodynamicState, sampler_state: states.SamplerState, - max_iterations: int=100) -> states.SamplerState: +def minimize(thermodynamic_state: states.ThermodynamicState, + sampler_state: states.SamplerState, + max_iterations: int=100, + platform_name: str="CPU") -> states.SamplerState: """ Adapted from perses.dispersed.feptasks.minimize @@ -277,20 +280,28 @@ def minimize(thermodynamic_state: states.ThermodynamicState, sampler_state: stat The starting state at which to minimize the system. max_iterations : int, optional, default 100 The maximum number of minimization steps. Default is 100. + platform_name : str + The OpenMM platform name to carry out the minimization with. Returns ------- sampler_state : openmmtools.states.SamplerState The posititions and accompanying state following minimization """ - integrator = openmm.VerletIntegrator(1.0) #we won't take any steps, so use a simple integrator - context, integrator = cache.global_context_cache.get_context( + # we won't take any steps, so use a simple integrator + integrator = openmm.VerletIntegrator(1.0) + platform = openmm.Platform.getPlatformByName(platform_name) + dummy_cache = cache.DummyContextCache(platform=platform) + context, integrator = dummy_cache.get_context( thermodynamic_state, integrator ) - sampler_state.apply_to_context( - context, ignore_velocities=True - ) - openmm.LocalEnergyMinimizer.minimize( - context, maxIterations=max_iterations - ) - sampler_state.update_from_context(context) + try: + sampler_state.apply_to_context( + context, ignore_velocities=True + ) + openmm.LocalEnergyMinimizer.minimize( + context, maxIterations=max_iterations + ) + sampler_state.update_from_context(context) + finally: + del context, integrator, dummy_cache diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 2ddb5fce7..3ebf99251 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -323,20 +323,20 @@ def run(self, *, dry=False, verbose=True, # a. check timestep correctness + that # equilibration & production are divisible by n_steps - prototol_settings: RelativeHybridTopologyProtocolSettings = self._inputs['settings'] + protocol_settings: RelativeHybridTopologyProtocolSettings = self._inputs['settings'] stateA = self._inputs['stateA'] stateB = self._inputs['stateB'] mapping = self._inputs['ligandmapping'] - forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = prototol_settings.forcefield_settings - thermo_settings: settings.ThermoSettings = prototol_settings.thermo_settings - alchem_settings: AlchemicalSettings = prototol_settings.alchemical_settings - system_settings: SystemSettings = prototol_settings.system_settings - solvation_settings: SolvationSettings = prototol_settings.solvation_settings - sampler_settings: AlchemicalSamplerSettings = prototol_settings.alchemical_sampler_settings - sim_settings: SimulationSettings = prototol_settings.simulation_settings - timestep = prototol_settings.integrator_settings.timestep - mc_steps = prototol_settings.integrator_settings.n_steps.m + forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = protocol_settings.forcefield_settings + thermo_settings: settings.ThermoSettings = protocol_settings.thermo_settings + alchem_settings: AlchemicalSettings = protocol_settings.alchemical_settings + system_settings: SystemSettings = protocol_settings.system_settings + solvation_settings: SolvationSettings = protocol_settings.solvation_settings + sampler_settings: AlchemicalSamplerSettings = protocol_settings.alchemical_sampler_settings + sim_settings: SimulationSettings = protocol_settings.simulation_settings + timestep = protocol_settings.integrator_settings.timestep + mc_steps = protocol_settings.integrator_settings.n_steps.m # is the timestep good for the mass? if forcefield_settings.hydrogen_mass < 3.0: @@ -536,9 +536,9 @@ def run(self, *, dry=False, verbose=True, if 'solvent' in stateA.components: hybrid_factory.hybrid_system.addForce( openmm.MonteCarloBarostat( - prototol_settings.thermo_settings.pressure.to(unit.bar).m, - prototol_settings.thermo_settings.temperature.m, - prototol_settings.integrator_settings.barostat_frequency.m, + protocol_settings.thermo_settings.pressure.to(unit.bar).m, + protocol_settings.thermo_settings.temperature.m, + protocol_settings.integrator_settings.barostat_frequency.m, ) ) @@ -572,24 +572,14 @@ def run(self, *, dry=False, verbose=True, checkpoint_storage=shared_basepath / sim_settings.checkpoint_storage, ) - # 10. Get platform and context caches + # 10. Get platform platform = _rfe_utils.compute.get_openmm_platform( - prototol_settings.engine_settings.compute_platform - ) - - # a. Create context caches (energy + sampler) - # Note: these needs to exist on the compute node - energy_context_cache = openmmtools.cache.ContextCache( - capacity=None, time_to_live=None, platform=platform, - ) - - sampler_context_cache = openmmtools.cache.ContextCache( - capacity=None, time_to_live=None, platform=platform, + protocol_settings.engine_settings.compute_platform ) # 11. Set the integrator # a. get integrator settings - integrator_settings = prototol_settings.integrator_settings + integrator_settings = protocol_settings.integrator_settings # b. create langevin integrator integrator = openmmtools.mcmc.LangevinSplittingDynamicsMove( @@ -635,59 +625,91 @@ def run(self, *, dry=False, verbose=True, sampler.setup( n_replicas=sampler_settings.n_replicas, reporter=reporter, - platform=platform, lambda_protocol=lambdas, - temperature=to_openmm(prototol_settings.thermo_settings.temperature), + temperature=to_openmm(protocol_settings.thermo_settings.temperature), endstates=alchem_settings.unsampled_endstates, + minimization_platform=platform.getName(), ) - sampler.energy_context_cache = energy_context_cache - sampler.sampler_context_cache = sampler_context_cache + try: + # Create context caches (energy + sampler) + energy_context_cache = openmmtools.cache.ContextCache( + capacity=None, time_to_live=None, platform=platform, + ) - if not dry: # pragma: no-cover - # minimize - if verbose: - logger.info("minimizing systems") + sampler_context_cache = openmmtools.cache.ContextCache( + capacity=None, time_to_live=None, platform=platform, + ) - sampler.minimize(max_iterations=sim_settings.minimization_steps) + sampler.energy_context_cache = energy_context_cache + sampler.sampler_context_cache = sampler_context_cache - # equilibrate - if verbose: - logger.info("equilibrating systems") + if not dry: # pragma: no-cover + # minimize + if verbose: + logger.info("minimizing systems") - sampler.equilibrate(int(equil_steps.m / mc_steps)) # type: ignore + sampler.minimize(max_iterations=sim_settings.minimization_steps) - # production - if verbose: - logger.info("running production phase") + # equilibrate + if verbose: + logger.info("equilibrating systems") - sampler.extend(int(prod_steps.m / mc_steps)) # type: ignore + sampler.equilibrate(int(equil_steps.m / mc_steps)) # type: ignore - # calculate estimate of results from this individual unit - ana = multistate.MultiStateSamplerAnalyzer(reporter) - est, _ = ana.get_free_energy() - est = (est[0, -1] * ana.kT).in_units_of(omm_unit.kilocalories_per_mole) - est = ensure_quantity(est, 'openff') + # production + if verbose: + logger.info("running production phase") - # close reporter when you're done + sampler.extend(int(prod_steps.m / mc_steps)) # type: ignore + + # calculate estimate of results from this individual unit + ana = multistate.MultiStateSamplerAnalyzer(reporter) + est, _ = ana.get_free_energy() + est = (est[0, -1] * ana.kT).in_units_of(omm_unit.kilocalories_per_mole) + est = ensure_quantity(est, 'openff') + + nc = shared_basepath / sim_settings.output_filename + chk = shared_basepath / sim_settings.checkpoint_storage + else: + # clean up the reporter file + fns = [shared_basepath / sim_settings.output_filename, + shared_basepath / sim_settings.checkpoint_storage] + for fn in fns: + os.remove(fn) + finally: + # close reporter when you're done, prevent file handle clashes reporter.close() + del reporter + + # clean up the analyzer + if not dry: + ana.clear() + del ana + + # clear GPU contexts + # TODO: use cache.empty() calls when openmmtools #690 is resolved + # replace with above + for context in list(energy_context_cache._lru._data.keys()): + del energy_context_cache._lru._data[context] + for context in list(sampler_context_cache._lru._data.keys()): + del sampler_context_cache._lru._data[context] + # cautiously clear out the global context cache too + for context in list( + openmmtools.cache.global_context_cache._lru._data.keys()): + del openmmtools.cache.global_context_cache._lru._data[context] + + del sampler_context_cache, energy_context_cache + if not dry: + del integrator, sampler - nc = shared_basepath / sim_settings.output_filename - chk = shared_basepath / sim_settings.checkpoint_storage + if not dry: # pragma: no-cover return { 'nc': nc, 'last_checkpoint': chk, 'unit_estimate': est, } else: - # close reporter when you're done, prevent file handle clashes - reporter.close() - - # clean up the reporter file - fns = [shared_basepath / sim_settings.output_filename, - shared_basepath / sim_settings.checkpoint_storage] - for fn in fns: - os.remove(fn) return {'debug': {'sampler': sampler}} def _execute(