Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/train sampler cuda memory #15

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
38 changes: 30 additions & 8 deletions espfit/app/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def output_directory_path(self, value):
os.makedirs(value, exist_ok=True)


def load_traj(self, reference_pdb='solvated.pdb', trajectory_netcdf='traj.nc', stride=1, input_directory_path=None):
def load_traj(self, reference_pdb='solvated.pdb', trajectory_netcdf='traj.nc', stride=1, exclude_n_frames=0, input_directory_path=None):
"""Load MD trajectory.

Parameters
Expand All @@ -71,6 +71,11 @@ def load_traj(self, reference_pdb='solvated.pdb', trajectory_netcdf='traj.nc', s

stride : int, optional
Stride to load the trajectory. Default is 1.

exclude_n_frames : float or int, optional
Exclude the first n frames from the trajectory. Default is 0.
If float, it must satisfy 0 < n < 1 and the first n fraction of frames will be excluded.
If int, the first n frames will be excluded.

input_directory_path : str, optional
Input directory path. Default is None.
Expand All @@ -79,6 +84,10 @@ def load_traj(self, reference_pdb='solvated.pdb', trajectory_netcdf='traj.nc', s
Returns
-------
None

Notes
-----
* Should `stride` and `exclude_n_frames` be set as class attributes?
"""
import mdtraj

Expand Down Expand Up @@ -114,6 +123,17 @@ def load_traj(self, reference_pdb='solvated.pdb', trajectory_netcdf='traj.nc', s
else:
self.traj = traj

# Exclude the first n frames from the trajectory
if isinstance(exclude_n_frames, int):
if 0 < exclude_n_frames <= traj.n_frames:
self.traj = self.traj[exclude_n_frames:]
elif isinstance(exclude_n_frames, float):
if 0 < exclude_n_frames < 1:
exclude_n_frames = int(exclude_n_frames * traj.n_frames)
self.traj = self.traj[exclude_n_frames:]
else:
raise ValueError(f"Invalid exclude_n_frames: {exclude_n_frames}. Expected int or float.")


class RNASystem(BaseDataLoader):
"""RNA system class to compute experimental observables from MD simulations.
Expand Down Expand Up @@ -175,7 +195,7 @@ def radian_to_degree(self, a):
return a


def compute_jcouplings(self, weights=None, couplings=None, residues=None):
def compute_jcouplings(self, weights=None, couplings=['H1H2', 'H2H3', 'H3H4'], residues=None):
"""Compute J-couplings from MD trajectory.

Parameters
Expand All @@ -184,7 +204,7 @@ def compute_jcouplings(self, weights=None, couplings=None, residues=None):
Weights to compute the J-couplings. Default is None.

couplings : list of str, optional
Name of the couplings to compute. Default is None.
Name of the couplings to compute. Default is ['H1H2', 'H2H3', 'H3H4'].
If a list of couplings to be chosen from [H1H2,H2H3,H3H4,1H5P,2H5P,C4Pb,1H5H4,2H5H4,H3P,C4Pe,H1C2/4,H1C6/8]
is provided, only the selected couplings will be computed. Otherwise, all couplings will be computed.

Expand All @@ -208,7 +228,7 @@ def compute_jcouplings(self, weights=None, couplings=None, residues=None):
"""
import barnaba as bb

_logger.info("Compute J-couplings from MD trajectory")
_logger.debug("Compute J-couplings from MD trajectory")

if couplings is not None:
# Check if the provided coupling names are valid
Expand All @@ -228,23 +248,25 @@ def compute_jcouplings(self, weights=None, couplings=None, residues=None):
for i, resname in enumerate(resname_list):
_values = values[:,i,:] # Coupling values of i-th residue
values_by_names = dict()
for j, coupling_name in enumerate(couplings):
for j, coupling_name in enumerate(couplings):
avg_raw = np.round(_values[:,j].mean(), 5) # e.g. mean value of H1H2 coupling of i-th residue
std_raw = np.round(_values[:,j].std(), 5) # e.g. standard deviation of H1H2 coupling of i-th residue
avg_raw = replace_nan_with_none(avg_raw)
std_raw = replace_nan_with_none(std_raw)
if weights is not None:
arr = _values[:,j] * weights
#_logger.info(f'non-weighted: {_values[:,j]}')
#_logger.info(f'weights: {weights}')
#_logger.info(f'weighted: {arr}')
_logger.info(f'weights:\n{weights}')
_logger.info(f'non-weighted observalble:\n{_values[:,j]}')
_logger.info(f'weighted observable:\n{arr}')
avg = np.round(arr.mean(), 5)
std = np.round(arr.std(), 5)
avg = replace_nan_with_none(avg)
std = replace_nan_with_none(std)
else:
_logger.info(f'No weights are provided. Using predicted values without reweighting.')
avg = avg_raw
std = std_raw
_logger.info(f'J-scalar ({coupling_name}-{resname}): avg={avg}, std={std}')
values_by_names[coupling_name] = {'avg': avg, 'std': std, 'avg_raw': avg_raw, 'std_raw': std_raw}
coupling_dict[resname] = values_by_names

Expand Down
Loading
Loading