diff --git a/seamm_ase/seamm_ase.py b/seamm_ase/seamm_ase.py index 27adf7d..ed69206 100644 --- a/seamm_ase/seamm_ase.py +++ b/seamm_ase/seamm_ase.py @@ -2,11 +2,16 @@ __all__ = ["SEAMM_Calculator"] +import logging + from ase.calculators.calculator import ( Calculator as ASE_Calculator, all_changes as ASE_all_changes, + register_calculator_class, ) +logger = logging.getLogger(__name__) + class SEAMM_Calculator(ASE_Calculator): """Generic ASE calculator for SEAMM. @@ -60,7 +65,7 @@ def calculator( implemented_properties = ["energy", "forces"] nolabel = True - def __init__(self, step, **kwargs): + def __init__(self, step, calculator=None, name=None, configuration=None, **kwargs): """ Parameters ---------- @@ -71,8 +76,26 @@ def __init__(self, step, **kwargs): The keyword arguments are passed to the parent class. """ self.step = step + self.calculator = calculator # Method or function to call + self._name = name + self._configuration = configuration + super().__init__(**kwargs) + @property + def configuration(self): + """The configuration this calculator represents.""" + return self._configuration + + @property + def name(self): + """A name for this calculator.""" + if self._name is None and self.configuration is not None: + name = self.configuration.system.name + "/" + self.configuration.name + return name + else: + return self._name + def calculate( self, atoms=None, @@ -97,4 +120,25 @@ def calculate( """ super().calculate(atoms, properties, system_changes) - self.step.calculator(self, properties, system_changes) + logger.debug(f"SEAMM_Calculator.calculate {self.name} {properties=}") + logger.debug(f" {system_changes=}") + logger.debug(f" {atoms is None=}") + + if self.calculator is None: + self.step.calculate(self, properties, system_changes) + else: + self.calculator(self, properties, system_changes) + + def check_state(self, atoms, tol=1e-10): + """Check for any system changes since last calculation.""" + return super().check_state(atoms, tol=tol) + + def get_property(self, name, atoms=None, allow_calculation=True): + logger.debug(f"SEAMM_Calculator.get_property {self.name} {name=}") + + return super().get_property( + name, atoms=atoms, allow_calculation=allow_calculation + ) + + +register_calculator_class("seamm", SEAMM_Calculator)