Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

scripts to prepare denali dataset for training #118

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions scripts/denali-dataset/make_chembl_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import glob
import os
import numpy as np
from chembl_webresource_client.new_client import new_client


def _fetch_canonical_smiles(chembl_ids, max_consecutive_errors=10):
    """Return ``{chembl_id: canonical_smiles}`` for the requested ChEMBL ids.

    Records are pulled one at a time from the ``new_client.molecule`` query.
    A failing fetch is reported and retried, but after
    ``max_consecutive_errors`` failures in a row the last error is re-raised
    so a persistently failing client cannot spin forever.  Records without
    structure information are skipped.
    """
    if not chembl_ids:
        # Nothing to look up; do not send an empty ``__in`` filter.
        return {}

    query = new_client.molecule.filter(
        molecule_chembl_id__in=chembl_ids
    ).only(['molecule_chembl_id', 'molecule_structures'])

    smiles_by_id = {}
    consecutive_errors = 0
    while True:
        try:
            result = next(query)
        except StopIteration:
            break
        except Exception as err:  # best effort: report, then try the next record
            consecutive_errors += 1
            if consecutive_errors >= max_consecutive_errors:
                raise
            print("ChEMBL fetch failed (%s); retrying" % (err,))
            continue
        consecutive_errors = 0

        structures = result['molecule_structures']
        if not structures:
            # Some records carry no structure block at all.
            continue
        smiles = structures.get('canonical_smiles')
        if smiles is not None:
            smiles_by_id[result['molecule_chembl_id']] = smiles
    return smiles_by_id


def _read_xyz(xyz_file):
    """Parse one xyz file into ``(multiplicity, charge, species, coords)``.

    The first line is ignored.  The second line must hold exactly two
    whitespace-separated fields, returned as strings in file order.  Every
    following non-blank line supplies a species label followed by three
    float coordinates.  ``coords`` is returned as a numpy array.
    """
    species = []
    coords = []
    with open(xyz_file, "r") as f:
        next(f, "")  # first line is not used
        # Raises ValueError if the second line is missing or malformed.
        multiplicity, charge = next(f, "").split()
        for raw_line in f:
            fields = raw_line.split()
            if not fields:
                # Tolerate blank (typically trailing) lines.
                continue
            species.append(fields[0])
            coords.append(
                [float(fields[1]), float(fields[2]), float(fields[3])]
            )
    return multiplicity, charge, species, np.array(coords)


def run(path):
    """Collect the ``CHEMBL*`` entries under ``path`` into one pickled dict.

    For every ``CHEMBL*`` entry found in ``path`` a record is built with:

    * ``path`` and ``chembl_id`` (the part of the entry name before the
      first underscore),
    * ``canonical_smiles`` when ChEMBL returned one for that id,
    * ``species``, ``charge``, ``multiplicity`` taken from the last xyz file
      read (``[]`` / ``None`` / ``None`` when the entry has no xyz files),
    * ``coordinates``: array built from the per-file coordinate arrays,
    * ``sample_ids``: xyz file names without the ``.xyz`` extension,
    * ``energies``: zeros, then filled from ``denali_labels.csv`` in the
      current working directory (column 3 is matched against the entry
      name, column 1 against the sample id, column 9 is the stored value).

    The result is written to ``denali_dataset_dict.pkl`` in the current
    working directory.

    Parameters
    ----------
    path : str
        Directory containing the ``CHEMBL*`` entries.
    """
    import csv
    import pickle

    chembl_ids_denali = {}
    for xyz_path in glob.glob(os.path.join(path, "CHEMBL*")):
        tail = os.path.basename(xyz_path)
        chembl_ids_denali[tail] = {
            "path": xyz_path,
            "chembl_id": tail.split("_")[0],
        }

    chembl_ids = sorted(
        {entry["chembl_id"] for entry in chembl_ids_denali.values()}
    )
    chemblid_smiles = _fetch_canonical_smiles(chembl_ids)

    # entry name -> {sample id -> position in that entry's arrays}
    sample_index = {}
    for key, entry in chembl_ids_denali.items():
        smiles = chemblid_smiles.get(entry["chembl_id"])
        if smiles is not None:
            entry["canonical_smiles"] = smiles

        coordinates = []
        sample_ids = []
        # Reset per entry so values never leak over from the previous one.
        species, charge, multiplicity = [], None, None
        # Sorted so the sample order is reproducible between runs.
        for xyz_file in sorted(glob.glob(os.path.join(entry["path"], "*.xyz"))):
            # splitext removes only the extension; str.rstrip(".xyz") would
            # also eat trailing 'x', 'y', 'z' and '.' characters of the stem.
            stem = os.path.splitext(os.path.basename(xyz_file))[0]
            sample_ids.append(stem)
            multiplicity, charge, species, coords = _read_xyz(xyz_file)
            coordinates.append(coords)

        entry['species'] = species
        entry['coordinates'] = np.array(coordinates)
        entry['charge'] = charge
        entry['multiplicity'] = multiplicity
        entry['sample_ids'] = sample_ids
        entry['energies'] = np.zeros(len(sample_ids))

        positions = {}
        for position, sample_id in enumerate(sample_ids):
            positions.setdefault(sample_id, position)
        sample_index[key] = positions

    with open("denali_labels.csv", "r", newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        for row in reader:
            if len(row) < 10:
                # Blank or truncated row: nothing usable in it.
                continue
            position = sample_index.get(row[3], {}).get(row[1])
            if position is not None:
                chembl_ids_denali[row[3]]['energies'][position] = float(row[9])

    with open("denali_dataset_dict.pkl", "wb") as f:
        pickle.dump(chembl_ids_denali, f)


if __name__ == "__main__":
    # Script entry point: the first command-line argument is handed to run()
    # as the directory to scan.
    import sys

    run(sys.argv[1])
205 changes: 205 additions & 0 deletions scripts/denali-dataset/match_smiles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
import pickle
import torch
import espaloma as esp
import numpy as np
from openeye import oechem


def run(path, u_thres=0.1):
    """Build and save one espaloma graph per ``*_conformers`` entry.

    The pickled dictionary at ``path`` is loaded and every entry whose key
    ends in ``_conformers`` and that carries a ``canonical_smiles`` is
    processed:

    1. the lowest-energy snapshot is used to build the graph with
       ``infer_mol_from_coordinates``;
    2. every other snapshot whose energy is within ``u_thres`` of the
       minimum is kept if ``check_offeq_graph`` accepts it;
    3. the kept coordinates (first two axes swapped) are stored under
       ``g.nodes['n1'].data['xyz']`` and the kept energies (with a leading
       axis of length one) under ``g.nodes['g'].data['u_ref']``;
    4. the graph is saved to ``denali/<key>``.

    Any failure for an entry is printed and appended to
    ``./denali_smiles_errors.dat``; processing then continues with the next
    entry.

    Parameters
    ----------
    path : str
        Path of the pickled dataset dictionary.
    u_thres : float
        Width of the energy window above the minimum energy, in the same
        units as the stored energies.
    """
    import os

    # NOTE: unpickling executes arbitrary code; only load trusted files.
    with open(path, "rb") as f:
        denali_data = pickle.load(f)

    # The graphs are written below this directory, so make sure it exists.
    os.makedirs("denali", exist_ok=True)

    for key, entry in denali_data.items():
        if key.split("_")[-1] != "conformers":
            continue
        if 'canonical_smiles' not in entry:
            continue

        # Bound before the try block so the error log always reports the
        # SMILES of the entry that actually failed.
        smiles = entry['canonical_smiles']
        try:
            xs = entry['coordinates']
            us = entry['energies']
            species = entry['species']

            idx_ref = us.argmin()
            u_max = us[idx_ref] + u_thres
            ok_idxs = [idx for idx in range(len(xs)) if us[idx] <= u_max]

            g = infer_mol_from_coordinates(
                xs[idx_ref], species, smiles_ref=smiles
            )

            # The reference snapshot always comes first.
            final_idxs = [idx_ref]
            for idx in ok_idxs:
                if idx == idx_ref:
                    continue
                if check_offeq_graph(xs[idx], species, smiles_ref=smiles):
                    final_idxs.append(idx)

            g.nodes['n1'].data['xyz'] = torch.tensor(
                xs[final_idxs, :, :]
            ).transpose(1, 0)
            g.nodes['g'].data['u_ref'] = torch.tensor(us[None, final_idxs])
            g.save("denali/%s" % (key))
        except Exception as ex:
            print(ex)
            with open("./denali_smiles_errors.dat", "a") as error_file:
                error_file.write(f"{key}, {smiles}, {ex}\n")


def _oemol_from_species(species, coordinates):
    """Return an ``OEGraphMol`` with one atom per species entry and 3D coordinates set.

    No bonds are created here.  ``species`` must be all strings (element
    symbols) or all integers; anything else raises ``RuntimeError``.
    """
    mol = oechem.OEGraphMol()

    if all(isinstance(symbol, str) for symbol in species):
        for symbol in species:
            mol.NewAtom(getattr(oechem, "OEElemNo_" + symbol))
    elif all(isinstance(symbol, int) for symbol in species):
        for symbol in species:
            mol.NewAtom(
                getattr(
                    oechem, "OEElemNo_" + oechem.OEGetAtomicSymbol(symbol)
                )
            )
    else:
        raise RuntimeError(
            "The species can only be all strings or all integers."
        )

    mol.SetCoords(coordinates.reshape([-1]))
    mol.SetDimension(3)
    return mol


def _print_mismatch_diagnostics(species, coordinates, mol, ref_mol):
    """Print atomic numbers and per-step valences to help debug a SMILES mismatch."""
    print([atom.GetAtomicNum() for atom in mol.GetAtoms()])
    print([atom.GetAtomicNum() for atom in ref_mol.GetAtoms()])

    def valences(m):
        return [(atom.GetAtomicNum(), atom.GetValence()) for atom in m.GetAtoms()]

    # Redo the perception on a fresh molecule, printing after every step.
    tmp_mol = _oemol_from_species(species, coordinates)
    print(valences(tmp_mol))
    oechem.OEDetermineConnectivity(tmp_mol)
    print(valences(tmp_mol))
    oechem.OEFindRingAtomsAndBonds(tmp_mol)
    print(valences(tmp_mol))
    oechem.OEPerceiveBondOrders(tmp_mol)
    # NOTE(review): a reviewer asked "Do you need to explicitly perceive
    # aromaticity after this?" -- still open, confirm.
    print(valences(tmp_mol))
    print([oechem.OECheckAtomValence(atom) for atom in tmp_mol.GetAtoms()])


def infer_mol_from_coordinates(
    coordinates,
    species,
    smiles_ref=None,
    coordinates_unit="angstrom",
):
    """Infer a molecule from 3D coordinates and wrap it in an espaloma graph.

    Connectivity, ring membership and bond orders are perceived by OpenEye
    from the coordinates.  When ``smiles_ref`` is given, the SMILES of the
    perceived molecule must equal the SMILES regenerated from the reference.

    Parameters
    ----------
    coordinates : array-like
        Atomic coordinates; flattened with ``reshape([-1])`` before being
        handed to OpenEye.
    species : sequence of str or sequence of int
        One entry per atom, either all element symbols or all integers.
    smiles_ref : str, optional
        Reference SMILES to validate against.  ``None`` skips the check.
    coordinates_unit : str or simtk unit
        Unit of ``coordinates``; a string is looked up on ``simtk.unit``.

    Returns
    -------
    esp.Graph
        Graph built from the perceived molecule.

    Raises
    ------
    RuntimeError
        If ``species`` mixes types or holds neither strings nor integers.
    ValueError
        If ``smiles_ref`` cannot be parsed.
    AssertionError
        If the perceived SMILES differs from the reference SMILES.
    """
    # local import
    from simtk import unit
    from simtk.unit import Quantity

    if isinstance(coordinates_unit, str):
        coordinates_unit = getattr(unit, coordinates_unit)

    # make sure we have the coordinates
    # in the unit system
    coordinates = Quantity(coordinates, coordinates_unit).value_in_unit(
        unit.angstrom  # to make openeye happy
    )

    mol = _oemol_from_species(species, coordinates)
    oechem.OEDetermineConnectivity(mol)
    oechem.OEFindRingAtomsAndBonds(mol)
    oechem.OEPerceiveBondOrders(mol)

    if smiles_ref is not None:
        ims = oechem.oemolistream()
        ims.SetFormat(oechem.OEFormat_SMI)
        ims.openstring(smiles_ref)
        ref_mol = next(ims.GetOEMols(), None)
        if ref_mol is None:
            raise ValueError(
                "Could not parse reference SMILES %r" % (smiles_ref,)
            )

        smiles_can = oechem.OEMolToSmiles(mol)
        # Canonicalise the *reference* molecule; comparing the perceived
        # molecule with itself would make this check a no-op.
        smiles_ref = oechem.OEMolToSmiles(ref_mol)
        # NOTE(review): stereo is not perceived from the 3D coordinates here,
        # so a stereo-annotated reference may never match -- confirm whether
        # stereo should be perceived or a non-isomeric SMILES compared.
        if smiles_ref != smiles_can:
            _print_mismatch_diagnostics(species, coordinates, mol, ref_mol)
            # Raised explicitly so the check survives ``python -O``.
            raise AssertionError(
                "SMILES different. Input is %s, ref is %s"
                % (
                    smiles_can,
                    smiles_ref,
                )
            )

    from openff.toolkit.topology import Molecule

    _mol = Molecule.from_openeye(mol, allow_undefined_stereo=True)
    g = esp.Graph(_mol)

    return g


def check_offeq_graph(
    coordinates,
    species,
    smiles_ref=None,
    coordinates_unit="angstrom",
):
    """Tell whether a snapshot still perceives to the reference molecule.

    A molecule is built from ``species`` and ``coordinates``; connectivity,
    rings and bond orders are perceived by OpenEye; and its SMILES is
    compared with the SMILES regenerated from ``smiles_ref``.

    Parameters
    ----------
    coordinates : array-like
        Atomic coordinates; flattened with ``reshape([-1])`` before being
        handed to OpenEye.
    species : sequence of str or sequence of int
        One entry per atom, either all element symbols or all integers.
    smiles_ref : str, optional
        Reference SMILES.  With ``None`` there is nothing to compare
        against and ``True`` is returned.
    coordinates_unit : str or simtk unit
        Unit of ``coordinates``; a string is looked up on ``simtk.unit``.

    Returns
    -------
    bool
        ``True`` when the SMILES agree (or no reference was given),
        ``False`` when they differ or the reference cannot be parsed.

    Raises
    ------
    RuntimeError
        If ``species`` mixes types or holds neither strings nor integers.
    """
    # local import
    from simtk import unit
    from simtk.unit import Quantity

    if isinstance(coordinates_unit, str):
        coordinates_unit = getattr(unit, coordinates_unit)

    # make sure we have the coordinates
    # in the unit system
    coordinates = Quantity(coordinates, coordinates_unit).value_in_unit(
        unit.angstrom  # to make openeye happy
    )

    # initialize molecule
    mol = oechem.OEGraphMol()

    if all(isinstance(symbol, str) for symbol in species):
        for symbol in species:
            mol.NewAtom(getattr(oechem, "OEElemNo_" + symbol))
    elif all(isinstance(symbol, int) for symbol in species):
        for symbol in species:
            mol.NewAtom(
                getattr(
                    oechem, "OEElemNo_" + oechem.OEGetAtomicSymbol(symbol)
                )
            )
    else:
        raise RuntimeError(
            "The species can only be all strings or all integers."
        )

    mol.SetCoords(coordinates.reshape([-1]))
    mol.SetDimension(3)
    oechem.OEDetermineConnectivity(mol)
    oechem.OEFindRingAtomsAndBonds(mol)
    oechem.OEPerceiveBondOrders(mol)

    if smiles_ref is None:
        return True

    ims = oechem.oemolistream()
    ims.SetFormat(oechem.OEFormat_SMI)
    ims.openstring(smiles_ref)
    ref_mol = next(ims.GetOEMols(), None)
    if ref_mol is None:
        # An unreadable reference cannot confirm the snapshot.
        return False

    # Compare against the reference molecule; comparing the perceived
    # molecule with itself would accept every snapshot.
    # NOTE(review): stereo is not perceived from the 3D coordinates here, so
    # a stereo-annotated reference may never match -- confirm.
    return oechem.OEMolToSmiles(ref_mol) == oechem.OEMolToSmiles(mol)


if __name__ == "__main__":
    # Script entry point: the first command-line argument is handed to run()
    # as the path of the pickled dataset; u_thres keeps its default.
    import sys

    run(sys.argv[1])
32 changes: 32 additions & 0 deletions scripts/denali-dataset/transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import espaloma as esp

def run(in_path, out_path, u_threshold=0.1):
    """Filter the snapshots of a saved graph by energy and re-save it.

    The graph at ``in_path`` is loaded and passed through
    ``subtract_nonbonded_force`` (with ``subtract_charges=True``).  Only
    snapshots whose ``u_ref`` is strictly below the minimum ``u_ref`` plus
    ``u_threshold`` are kept.  The snapshot count is printed before and
    after filtering, and the graph is written to ``out_path`` only when
    more than one snapshot remains.
    """
    graph = esp.Graph.load(in_path)
    from espaloma.data.md import subtract_nonbonded_force
    graph = subtract_nonbonded_force(graph, subtract_charges=True)

    # number of snapshots and the lowest reference energy
    n_snapshots = graph.nodes['n1'].data['xyz'].shape[1]
    lowest_u = graph.nodes['g'].data['u_ref'].min().item()

    print(n_snapshots)

    # indices of the snapshots that fall inside the energy window
    kept = []
    for snapshot in range(n_snapshots):
        u_snapshot = graph.nodes['g'].data['u_ref'][:, snapshot].item()
        if u_snapshot < lowest_u + u_threshold:
            kept.append(snapshot)

    graph.nodes['n1'].data['xyz'] = graph.nodes['n1'].data['xyz'][:, kept, :]
    graph.nodes['g'].data['u_ref'] = graph.nodes['g'].data['u_ref'][:, kept]

    n_kept = len(kept)

    print(n_kept)
    if n_kept > 1:
        graph.save(out_path)

if __name__ == "__main__":
    # Script entry point: the first two command-line arguments are handed to
    # run() as the input and output graph paths.
    import sys

    run(sys.argv[1], sys.argv[2])