Skip to content

Commit

Permalink
Allow passing mol object to bust (#12)
Browse files Browse the repository at this point in the history
CLI:
- `bust` method now also accepts a single path or a single RDKit molecule object.
- Bug fix for #11

Energy ratio check and distance geometry check:
- Instead of raising an error, check runs and returns NA for molecule objects that do not contain a conformation.
  • Loading branch information
maabuu authored Sep 6, 2023
1 parent 1b265d3 commit fde8e89
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 17 deletions.
4 changes: 2 additions & 2 deletions docs/source/api_notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
"metadata": {},
"outputs": [],
"source": [
"pred_file = Path(\"inputs/generated_molecules.sdf\") # generated molecules\n",
"true_file = Path(\"inputs/crystal_ligand.sdf\") # generated molecules\n",
"pred_file = Path(\"inputs/generated_molecules.sdf\") # predicted or generated molecules\n",
"true_file = Path(\"inputs/crystal_ligand.sdf\") # \"ground truth\" molecules\n",
"cond_file = Path(\"inputs/protein.pdb\") # conditioning molecule"
]
},
Expand Down
2 changes: 1 addition & 1 deletion posebusters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@
"check_volume_overlap",
]

__version__ = "0.2.4"
__version__ = "0.2.5"
8 changes: 5 additions & 3 deletions posebusters/modules/distance_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def check_geometry(
"""Use RDKit distance geometry bounds to check the geometry of a molecule.
Args:
mol_pred: Predicted molecule (docked ligand) with exactly one conformer.
mol_pred: Predicted molecule (docked ligand). Only the first conformer will be checked.
threshold_bad_bond_length: Bond length threshold in relative percentage. 0.2 means that bonds may be up to 20%
longer than DG bounds. Defaults to 0.2.
threshold_clash: Threshold for how much overlap constitutes a clash. 0.2 means that the two atoms may be up to
Expand All @@ -91,10 +91,12 @@ def check_geometry(
PoseBusters results dictionary.
"""
mol = deepcopy(mol_pred)
assert mol.GetNumConformers() == 1, "Molecule must have exactly one conformer."

results = _empty_results.copy()

if mol.GetNumConformers() == 0:
logger.warning("Molecule does not have a conformer.")
return {"results": results}

if mol.GetNumAtoms() == 1:
logger.warning(f"Molecule has only {mol.GetNumAtoms()} atoms.")
results[col_angles_result] = True
Expand Down
2 changes: 2 additions & 0 deletions posebusters/modules/energy_ratio.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def check_energy_ratio(
PoseBusters results dictionary.
"""
mol_pred = deepcopy(mol_pred)

try:
assert mol_pred.GetNumConformers() > 0, "Molecule does not have a conformer."
SanitizeMol(mol_pred)
AddHs(mol_pred, addCoords=True)
except Exception:
Expand Down
1 change: 1 addition & 0 deletions posebusters/modules/flatness.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def check_flatness(
mol = deepcopy(mol_pred)
# if mol cannot be sanitized, then rdkit may not find substructures
try:
assert mol_pred.GetNumConformers() > 0, "Molecule does not have a conformer."
SanitizeMol(mol)
except Exception:
return _empty_results
Expand Down
18 changes: 13 additions & 5 deletions posebusters/posebusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,15 @@ def __init__(self, config: str | dict[str, Any] = "redock", top_n: int | None =

def bust(
self,
mol_pred: Iterable[Mol | Path],
mol_pred: Iterable[Mol | Path] | Mol | Path,
mol_true: Mol | Path | None = None,
mol_cond: Mol | Path | None = None,
full_report: bool = False,
) -> pd.DataFrame:
"""Run all tests on one molecule.
"""Run tests on one or more molecules.
Args:
mol_pred: Generated molecule, e.g. docked ligand, with one or more poses.
mol_pred: Generated molecule(s), e.g. de-novo generated molecule or docked ligand, with one or more poses.
mol_true: True molecule, e.g. crystal ligand, with one or more poses.
mol_cond: Conditioning molecule, e.g. protein.
full_report: Whether to include all columns in the output or only the boolean ones specified in the config.
Expand All @@ -83,16 +83,21 @@ def bust(
Returns:
Pandas dataframe with results.
"""
mol_pred = [mol_pred] if isinstance(mol_pred, (Mol, Path, str)) else mol_pred

columns = ["mol_pred", "mol_true", "mol_cond"]
self.file_paths = pd.DataFrame([[mol_pred, mol_true, mol_cond] for mol_pred in mol_pred], columns=columns)

results_gen = self._run()

df = pd.concat([_dataframe_from_output(d, self.config, full_report=full_report) for d in results_gen])
df.index.names = ["file", "molecule"]
df.columns = [c.lower().replace(" ", "_") for c in df.columns]

return df

def bust_table(self, mol_table: pd.DataFrame, full_report: bool = False) -> pd.DataFrame:
"""Run all tests on multiple molecules provided in pandas dataframe as paths or rdkit molecule objects.
"""Run tests on molecules provided in pandas dataframe as paths or rdkit molecule objects.
Args:
mol_table: Pandas dataframe with columns "mol_pred", "mol_true", "mol_cond" containing paths to molecules.
Expand All @@ -102,10 +107,13 @@ def bust_table(self, mol_table: pd.DataFrame, full_report: bool = False) -> pd.D
Pandas dataframe with results.
"""
self.file_paths = mol_table

results_gen = self._run()

df = pd.concat([_dataframe_from_output(d, self.config, full_report=full_report) for d in results_gen])
df.index.names = ["file", "molecule"]
df.columns = [c.lower().replace(" ", "_") for c in df.columns]

return df

def _run(self) -> Generator[dict, None, None]:
Expand Down Expand Up @@ -173,7 +181,7 @@ def _initialize_modules(self) -> None:
def _get_name(mol: Mol, i: int) -> str:
if mol is None:
return f"invalid_mol_at_pos_{i}"
elif mol.GetProp("_Name") == "":
elif not mol.HasProp("_Name") or mol.GetProp("_Name") == "":
return f"mol_at_pos_{i}"
else:
return mol.GetProp("_Name")
Expand Down
6 changes: 5 additions & 1 deletion posebusters/tools/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,16 @@ def safe_supply_mols(path: Path, load_all=True, sanitize=True, **load_params) ->
Returns:
Molecule object or None if loading failed.
"""
path = Path(path)
if isinstance(path, Mol):
yield path
return None

path = Path(path)
if path.suffix == ".sdf":
pass
elif path.suffix in {".mol", ".mol2"}:
yield safe_load_mol(path, sanitize=True, **load_params)
return None
else:
raise ValueError(f"Molecule file {path} has unknown format. Only .sdf, .mol and .mol2 are supported.")

Expand Down
36 changes: 31 additions & 5 deletions tests/test_posebusters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import math

from rdkit.Chem.rdmolfiles import MolFromSmiles

from posebusters import PoseBusters

mols_table = "tests/conftest/sample_bust_docks_table.csv"
Expand All @@ -24,30 +26,54 @@

def test_bust_redocks_1ia1() -> None:
posebusters = PoseBusters("redock")
list(posebusters.bust([mol_pred_1ia1], mol_true_1ia1, mol_cond_1ia1))
df = posebusters.bust([mol_pred_1ia1], mol_true_1ia1, mol_cond_1ia1)
assert df.all(axis=1).values[0]


def test_bust_redocks_1w1p() -> None:
posebusters = PoseBusters("redock")
list(posebusters.bust([mol_pred_1w1p], mol_true_1w1p, mol_cond_1w1p))
df = posebusters.bust([mol_pred_1w1p], mol_true_1w1p, mol_cond_1w1p)
assert df.all(axis=1).values[0]


def test_bust_docks() -> None:
posebusters = PoseBusters("dock")
list(posebusters.bust([mol_pred_1ia1], mol_cond=mol_cond_1w1p))
df = posebusters.bust([mol_pred_1ia1], mol_cond=mol_cond_1ia1)
assert df.all(axis=1).values[0]


def test_bust_mols() -> None:
posebusters = PoseBusters("mol")
list(posebusters.bust([mol_pred_1ia1]))

# pass one not in list
df = posebusters.bust(mol_pred_1ia1)
assert df.all(axis=1).values[0]

# pass list
df = posebusters.bust([mol_pred_1ia1])
assert df.all(axis=1).values[0]


def test_bust_mol_rdkit() -> None:
posebusters = PoseBusters(config="mol")
mol = MolFromSmiles("C")

df = posebusters.bust(mol)
assert df.all(axis=1).values[0]

df = posebusters.bust([mol])
assert df.all(axis=1).values[0]


def test_bust_mols_hydrogen() -> None:
posebusters = PoseBusters("mol")
list(posebusters.bust([mol_single_h]))
df = posebusters.bust([mol_single_h])
assert df.sum(axis=1).values[0] >= 8 # energy ratio test fails


def test_bust_mols_consistency() -> None:
# check that running the same molecule twice gives the same result

posebusters = PoseBusters("mol")
result_2 = posebusters.bust([mol_conf_2])

Expand Down

0 comments on commit fde8e89

Please sign in to comment.