Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

murcko scaffold split #18

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/beignet.subsets.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# beignet subsets

::: beignet.subsets.murcko_scaffold_split
5 changes: 5 additions & 0 deletions src/beignet/subsets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ._murcko_scaffold_split import murcko_scaffold_split

__all__ = [
"murcko_scaffold_split",
]
122 changes: 122 additions & 0 deletions src/beignet/subsets/_murcko_scaffold_split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import math
import random
from collections import defaultdict
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make explicit.

from typing import Sequence

from torch.utils.data import Dataset, Subset

try:
from rdkit import Chem
from rdkit.Chem.Scaffolds.MurckoScaffold import GetScaffoldForMol

_RDKit_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
_RDKit_AVAILABLE = False
Chem, MurckoScaffoldSmiles = None, None


def murcko_scaffold_split(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add generator: Generator arg.

dataset: Dataset,
smiles: Sequence[str],
test_size: float | int,
*,
seed: int = 0xDEADBEEF,
shuffle: bool = True,
include_chirality: bool = False,
) -> tuple[Subset, Subset]:
Copy link
Collaborator

@0x00b1 0x00b1 May 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a better design is returning List[Subset] based on a lengths parameter, e.g., see torch.utils.data.random_split.

"""
Creates datasets subsets with disjoint Murcko scaffolds based
on provided SMILES strings.

Note that for datasets that are small or not highly diverse,
the final test set may be smaller than the specified test_size.

Parameters
----------
dataset : Dataset
The dataset to split.
smiles : Sequence[str]
A list of SMILES strings.
test_size : float | int
The size of the test set. If float, should be between 0.0 and 1.0.
If int, should be between 0 and len(smiles).
seed : int, optional
The random seed to use for shuffling, by default 0xDEADBEEF
shuffle : bool, optional
Whether to shuffle the indices, by default True
include_chirality : bool, optional
Whether to include chirality in the scaffold, by default False

Returns
-------
tuple[Subset, Subset]
The train and test subsets.

References
----------
- Bemis, G. W., & Murcko, M. A. (1996). The properties of known drugs.
1. Molecular frameworks. Journal of medicinal chemistry, 39(15), 2887–2893.
https://doi.org/10.1021/jm9602928
- "RDKit: Open-source cheminformatics. https://www.rdkit.org"
"""
train_idx, test_idx = _murcko_scaffold_split_indices(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inline?

smiles,
test_size,
seed=seed,
shuffle=shuffle,
include_chirality=include_chirality,
)
return Subset(dataset, train_idx), Subset(dataset, test_idx)


def _murcko_scaffold_split_indices(
smiles: list[str],
test_size: float | int,
*,
seed: int = 0xDEADBEEF,
shuffle: bool = True,
include_chirality: bool = False,
) -> tuple[list[int], list[int]]:
"""
Get train and test indices based on Murcko scaffolds."""
if not _RDKit_AVAILABLE:
raise ImportError(
"This function requires RDKit to be installed (pip install rdkit)"
)

if (
isinstance(test_size, int) and (test_size <= 0 or test_size >= len(smiles))
) or (isinstance(test_size, float) and (test_size <= 0 or test_size >= 1)):
raise ValueError(
f"Test_size should be a float in (0, 1) or and int < {len(smiles)}."
)

if isinstance(test_size, float):
test_size = math.ceil(len(smiles) * test_size)

scaffolds = defaultdict(list)

for ind, s in enumerate(smiles):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoid abbreviations, e.g., use index and sequence.

mol = Chem.MolFromSmiles(s)
if mol is not None:
scaffold = Chem.MolToSmiles(
GetScaffoldForMol(mol), isomericSmiles=include_chirality
)
scaffolds[scaffold].append(ind)

train_idx = []
test_idx = []

if shuffle:
if seed is not None:
random.Random(seed).shuffle(scaffolds)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use generator.

else:
random.shuffle(scaffolds)

for index_list in scaffolds.values():
if len(test_idx) + len(index_list) <= test_size:
test_idx = [*test_idx, *index_list]
else:
train_idx.extend(index_list)

return train_idx, test_idx
71 changes: 71 additions & 0 deletions tests/beignet/subsets/test__murcko_scaffold_split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from importlib.util import find_spec
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make explicit.

from unittest.mock import MagicMock, patch
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from unittest.mock import MagicMock
import unittest.mock

Use unittest.mock.patch explicitly.


import pytest
from beignet.subsets._murcko_scaffold_split import (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from beignet.subsets import (...)

_murcko_scaffold_split_indices,
murcko_scaffold_split,
)
from torch.utils.data import Dataset, Subset

_RDKit_AVAILABLE = find_spec("rdkit") is not None


@pytest.mark.skipif(not _RDKit_AVAILABLE, reason="RDKit is not available")
@patch("beignet.subsets._murcko_scaffold_split._murcko_scaffold_split_indices")
def test_murcko_scaffold_split(mock__murcko_scaffold_split_indices):
mock__murcko_scaffold_split_indices.return_value = ([0], [1])

mock_dataset = MagicMock(spec=Dataset)

train_dataset, test_dataset = murcko_scaffold_split(
dataset=mock_dataset,
smiles=["C", "C"],
test_size=0.5,
shuffle=False,
seed=0,
)

assert isinstance(train_dataset, Subset)
assert isinstance(test_dataset, Subset)
assert train_dataset.indices == [0]
assert test_dataset.indices == [1]


@pytest.mark.skipif(not _RDKit_AVAILABLE, reason="RDKit is not available")
@pytest.mark.parametrize(
"test_size, expected_train_idx, expected_test_idx",
[
pytest.param(0.5, [2, 3], [0, 1], id="test_size is float"),
pytest.param(2, [2, 3], [0, 1], id="test_size is int"),
],
)
def test__murcko_scaffold_split_indices(
test_size, expected_train_idx, expected_test_idx
):
smiles = ["C1CCCCC1", "C1CCCCC1", "CCO", "CCO"]

train_idx, test_idx = _murcko_scaffold_split_indices(
smiles,
test_size=test_size,
)
assert train_idx == expected_train_idx
assert test_idx == expected_test_idx


@pytest.mark.skipif(not _RDKit_AVAILABLE, reason="RDKit is not available")
@pytest.mark.parametrize(
"smiles, test_size",
[
pytest.param(["CCO"], 1.2, id="test_size is float > 1"),
pytest.param(["CCO"], -1, id="test_size is negative"),
pytest.param(["CCO"], 0, id="test_size is 0"),
pytest.param(["CCO"], 5, id="test_size > len(smiles)"),
],
)
def test__murcko_scaffold_split_indices_invalid_inputs(smiles, test_size):
with pytest.raises(ValueError):
_murcko_scaffold_split_indices(
smiles,
test_size=test_size,
)