Skip to content

Commit

Permalink
Merge pull request #206 from datamol-io/circle_grid
Browse files Browse the repository at this point in the history
Better control of the circular layout in `circle_grid`
  • Loading branch information
hadim authored Jun 22, 2023
2 parents 3ff8561 + e01067a commit 639e684
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 21 deletions.
6 changes: 4 additions & 2 deletions datamol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
"ChemicalReaction": "datamol.types",
"Atom": "datamol.types",
"Bond": "datamol.types",
"ColorTuple": "datamol.types",
"DatamolColor": "datamol.types",
"RDKitColor": "datamol.types",
# utils
"parallelized": "datamol.utils",
"parallelized_with_batches": "datamol.utils",
Expand Down Expand Up @@ -209,7 +210,8 @@ def __dir__():
from .types import ChemicalReaction
from .types import Atom
from .types import Bond
from .types import ColorTuple
from .types import DatamolColor
from .types import RDKitColor

from . import utils

Expand Down
3 changes: 2 additions & 1 deletion datamol/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@
Atom: TypeAlias = Chem.rdchem.Atom
Bond: TypeAlias = Chem.rdchem.Bond

ColorTuple = Union[Tuple[float, float, float, float], Tuple[float, float, float]]
RDKitColor = Union[Tuple[float, float, float, float], Tuple[float, float, float]]
DatamolColor = Union[RDKitColor, str]
43 changes: 36 additions & 7 deletions datamol/viz/_circle_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
from .utils import drawer_to_image
from .utils import prepare_mol_for_drawing
from .utils import image_to_file
from datamol.types import ColorTuple
from .utils import to_rdkit_color

from datamol.types import DatamolColor
from datamol.types import Mol
from datamol.types import RDKitColor

import datamol as dm

Expand All @@ -27,12 +30,15 @@ def circle_grid(
ring_scaler: float = 1.0,
align: Optional[Union[Mol, str, bool]] = None,
use_svg: bool = True,
ring_color: Optional[ColorTuple] = None,
ring_color: Optional[DatamolColor] = None,
ring_mol_start_angles_degrees: Optional[List[float]] = None,
center_mol_highlight_atoms: Optional[List[int]] = None,
center_mol_highlight_bonds: Optional[List[int]] = None,
ring_mol_highlight_atoms: Optional[List[List[int]]] = None,
ring_mol_highlight_bonds: Optional[List[List[int]]] = None,
outfile: Optional[str] = None,
kekulize: bool = True,
layout_random_seed: Optional[int] = 19,
**kwargs: Any,
):
"""Show molecules in concentric rings, with one molecule at the center
Expand All @@ -54,6 +60,10 @@ def circle_grid(
ring_mol_highlight_atoms: List of list of atom indices to highlight for molecules at each level of the concentric rings
ring_mol_highlight_bonds: List of list of bond indices to highlight for molecules at each level of the concentric rings
ring_color: Color of the concentric rings. Set to None to not draw any ring.
ring_mol_start_angles_degrees: List of angles in degrees to start drawing the molecules at each level of the concentric
rings. If None then a random position will be used.
kekulize: Whether to kekulize the molecules before drawing.
layout_random_seed: Random seed for the layout of the molecules. Set to None for no seed.
outfile: Optional path to the save the output file.
**kwargs: Additional arguments to pass to the drawing function. See RDKit
documentation related to `MolDrawOptions` for more details at
Expand All @@ -71,10 +81,13 @@ def circle_grid(
align=align,
use_svg=use_svg,
ring_color=ring_color,
ring_mol_start_angles_degrees=ring_mol_start_angles_degrees,
center_mol_highlight_atoms=center_mol_highlight_atoms,
center_mol_highlight_bonds=center_mol_highlight_bonds,
ring_mol_highlight_atoms=ring_mol_highlight_atoms,
ring_mol_highlight_bonds=ring_mol_highlight_bonds,
kekulize=kekulize,
layout_random_seed=layout_random_seed,
**kwargs,
)
return grid(outfile=outfile)
Expand All @@ -92,12 +105,14 @@ def __init__(
align: Optional[Union[Mol, str, bool]] = None,
use_svg: bool = True,
line_width: Optional[float] = None,
ring_color: Optional[ColorTuple] = None,
ring_color: Optional[DatamolColor] = None,
ring_mol_start_angles_degrees: Optional[List[float]] = None,
center_mol_highlight_atoms: Optional[List[int]] = None,
center_mol_highlight_bonds: Optional[List[int]] = None,
ring_mol_highlight_atoms: Optional[List[List[int]]] = None,
ring_mol_highlight_bonds: Optional[List[List[int]]] = None,
kekulize: bool = True,
layout_random_seed: Optional[int] = 19,
**kwargs: Any,
):
"""Show molecules in concentric rings, with one molecule at the center
Expand All @@ -121,7 +136,10 @@ def __init__(
ring_mol_highlight_atoms: List of list of atom indices to highlight for molecules at each level of the concentric rings
ring_mol_highlight_bonds: List of list of bond indices to highlight for molecules at each level of the concentric rings
ring_color: Color of the concentric rings. Set to None to not draw any ring.
kekulize: Whether to kekulize the molecules before drawing
ring_mol_start_angles_degrees: List of angles in degrees to start drawing the molecules at each level of the concentric
rings. If None then a random position will be used.
kekulize: Whether to kekulize the molecules before drawing.
layout_random_seed: Random seed for the layout of the molecules. Set to None for no seed.
**kwargs: Additional arguments to pass to the drawing function. See RDKit
documentation related to `MolDrawOptions` for more details at
https://www.rdkit.org/docs/source/rdkit.Chem.Draw.rdMolDraw2D.html.
Expand All @@ -141,11 +159,14 @@ def __init__(
self.use_svg = use_svg
self.line_width = line_width
self.ring_color = ring_color
self.ring_mol_start_angles_degrees = ring_mol_start_angles_degrees
self.ring_color_rdkit: Optional[RDKitColor] = to_rdkit_color(ring_color)
self.ring_mol_highlight_atoms = ring_mol_highlight_atoms
self.ring_mol_highlight_bonds = ring_mol_highlight_bonds
self.center_mol_highlight_atoms = center_mol_highlight_atoms
self.center_mol_highlight_bonds = center_mol_highlight_bonds
self.kekulize = kekulize
self.layout_random_seed = layout_random_seed
self._global_legend_size = 0
if self.legend is not None:
self._global_legend_size = max(25, self.margin)
Expand Down Expand Up @@ -276,12 +297,20 @@ def draw(self):
highlight_atom=self.center_mol_highlight_atoms,
highlight_bond=self.center_mol_highlight_bonds,
)

rng = random.Random(self.layout_random_seed)

# draw the ring mols
self.draw_options.scalingFactor *= self.ring_scaler
for i, mols in enumerate(self.ring_mols):
radius = radius_list[i]
ni = len(mols)
rand_unit = random.random() * 2 * math.pi

if self.ring_mol_start_angles_degrees is not None:
rand_unit = np.deg2rad(self.ring_mol_start_angles_degrees[i])
else:
rand_unit = rng.random() * 2 * math.pi

for k, mol in enumerate(mols):
center_x = radius * math.cos(2 * k * math.pi / ni + rand_unit) + self.midpoint.x
center_y = radius * math.sin(2 * k * math.pi / ni + rand_unit) + self.midpoint.y
Expand Down Expand Up @@ -323,8 +352,8 @@ def _draw_circles(self):
for _, radius in enumerate(full_range):
radius += self.margin // 2
if radius > self.margin:
if self.ring_color is not None:
self.canvas.SetColour(self.ring_color)
if self.ring_color_rdkit is not None:
self.canvas.SetColour(self.ring_color_rdkit)
self.canvas.DrawArc(self.midpoint, radius, 0, 360, rawCoords=True)
radius_list.append(radius + radius_step)
return radius_list
Expand Down
27 changes: 18 additions & 9 deletions datamol/viz/_lasso_highlight.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@
import numpy as np
import datamol as dm

from datamol.types import ColorTuple
from datamol.types import RDKitColor
from datamol.types import DatamolColor

from .utils import drawer_to_image
from .utils import prepare_mol_for_drawing
from .utils import to_rdkit_color


def _angle_to_coord(center: np.ndarray, angle: float, radius: float) -> np.ndarray:
Expand Down Expand Up @@ -185,7 +188,7 @@ def _draw_substructurematch(
rel_radius: float = 0.3,
rel_width: float = 0.5,
line_width: int = 2,
color: Optional[ColorTuple] = None,
color: Optional[RDKitColor] = None,
offset: Optional[Tuple[int, int]] = None,
) -> None:
"""Draws the substructure defined by (atom-) `indices`, as lasso-highlight onto `canvas`.
Expand All @@ -208,7 +211,7 @@ def _draw_substructurematch(
# # Default color is gray.
if not color:
color = (0.5, 0.5, 0.5, 1)
canvas.SetColour(color)
canvas.SetColour(tuple(color))

# Selects first conformer and calculates the mean bond length
conf = mol.GetConformer(0)
Expand Down Expand Up @@ -311,7 +314,7 @@ def _draw_multi_matches(
r_min: float = 0.3,
r_dist: float = 0.13,
relative_bond_width: float = 0.5,
color_list: Optional[List[ColorTuple]] = None,
color_list: Optional[List[DatamolColor]] = None,
line_width: int = 2,
offset: Optional[Tuple[int, int]] = None,
):
Expand All @@ -334,7 +337,6 @@ def _draw_multi_matches(
else:
_color_list = color_list

print(len(_color_list), len(indices_set_lists))
if len(_color_list) < len(indices_set_lists):
colors_to_add = []
for i in range(len(indices_set_lists) - len(_color_list)):
Expand Down Expand Up @@ -366,7 +368,7 @@ def _draw_multi_matches(
match_atoms,
rel_radius=ar,
rel_width=max(relative_bond_width, ar),
color=color,
color=to_rdkit_color(color),
line_width=line_width,
offset=offset,
)
Expand Down Expand Up @@ -394,7 +396,7 @@ def lasso_highlight_image(
r_min: float = 0.3,
r_dist: float = 0.13,
relative_bond_width: float = 0.5,
color_list: Optional[List[ColorTuple]] = None,
color_list: Optional[List[DatamolColor]] = None,
line_width: int = 2,
scale_padding: float = 1.0,
verbose: bool = False,
Expand Down Expand Up @@ -438,6 +440,7 @@ def lasso_highlight_image(
for i, search_mol in enumerate(search_molecules):
if isinstance(search_mol, str):
search_molecules[i] = dm.from_smarts(search_mol)

if search_molecules[i] is None or not isinstance(search_molecules[i], dm.Mol):
raise ValueError(
f"Please enter valid search molecules or smarts: {search_molecules[i]}"
Expand All @@ -462,11 +465,17 @@ def lasso_highlight_image(
## Step 1: setup drawer and canvas
if use_svg:
drawer = rdMolDraw2D.MolDraw2DSVG(
mol_size[0] * n_cols, mol_size[1] * n_rows, mol_size[0], mol_size[1]
mol_size[0] * n_cols,
mol_size[1] * n_rows,
mol_size[0],
mol_size[1],
)
else:
drawer = rdMolDraw2D.MolDraw2DCairo(
mol_size[0] * n_cols, mol_size[1] * n_rows, mol_size[0], mol_size[1]
mol_size[0] * n_cols,
mol_size[1] * n_rows,
mol_size[0],
mol_size[1],
)

# Setting the drawing options
Expand Down
29 changes: 27 additions & 2 deletions datamol/viz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,16 @@
import fsspec

from rdkit.Chem import Draw
from matplotlib import colors as mcolors

import PIL.Image
import PIL.PngImagePlugin

import datamol as dm

from datamol.types import RDKitColor
from datamol.types import DatamolColor


def prepare_mol_for_drawing(mol: Optional[dm.Mol], kekulize: bool = True) -> Optional[dm.Mol]:
"""Prepare the molecule before drawing to avoid any error due to unsanitized molecule
Expand Down Expand Up @@ -94,7 +100,14 @@ def drawer_to_image(drawer: Draw.rdMolDraw2D.MolDraw2D):


def image_to_file(
image: Union[str, PIL.PngImagePlugin.PngImageFile, bytes], outfile, as_svg: bool = False
image: Union[
str,
PIL.PngImagePlugin.PngImageFile,
bytes,
PIL.Image.Image,
],
outfile,
as_svg: bool = False,
):
"""Save image to file. The image can be either a PNG or SVG depending
Expand All @@ -115,7 +128,19 @@ def image_to_file(
else:
if isinstance(image, PIL.PngImagePlugin.PngImageFile): # type: ignore
# in a terminal process
image.save(f)
image.save(f) # type: ignore
else:
# in a jupyter kernel process
f.write(image.data) # type: ignore


def to_rdkit_color(color: Optional[DatamolColor]) -> Optional[RDKitColor]:
"""If required convert a datamol color (rgb, rgba or hex string) to an RDKit
color (rgb or rgba).
Args:
color: A datamol color: hex, rgb, rgba or None.
"""
if isinstance(color, str):
return mcolors.to_rgba(color) # type: ignore
return color
38 changes: 38 additions & 0 deletions tests/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,44 @@ def test_circle_grid(tmp_path):
)


@pytest.mark.skipif(
not dm.is_greater_than_current_rdkit_version("2023.03"),
reason="Circle Grid requires rdkit>2022.09",
)
def test_circle_grid_with_hex_color(tmp_path):
mol = dm.to_mol("CC(=O)OC1=CC=CC=C1C(=O)O")
dm.viz.circle_grid(
mol,
[
[dm.to_mol("CCC"), dm.to_mol("CCCCCCC")],
[dm.to_mol("CCCO"), dm.to_mol("CCCCCCCO")],
],
ring_color="#ff1472",
layout_random_seed=None,
)


@pytest.mark.skipif(
not dm.is_greater_than_current_rdkit_version("2023.03"),
reason="Circle Grid requires rdkit>2022.09",
)
def test_circle_grid_with_angle_start(tmp_path):
mol = dm.to_mol("CC(=O)OC1=CC=CC=C1C(=O)O")
dm.viz.circle_grid(
mol,
[
[dm.to_mol("CCC"), dm.to_mol("CCCCCCC"), dm.to_mol("CCCCCO")],
[
dm.to_mol("CCCO"),
],
],
# ring_color=(0, 0, 0, 0.5),
ring_color="#ff1472aa",
layout_random_seed=19,
ring_mol_start_angles_degrees=[90, 90],
)


def test_to_image_align():
# Get a list of molecules
data = dm.data.freesolv()
Expand Down
9 changes: 9 additions & 0 deletions tests/test_viz_lasso_highlight.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,12 @@ def test_atom_indices_list():
search_molecules=None,
atom_indices=[4, 5, 6],
)


def test_with_hex_color():
dm.viz.lasso_highlight_image(
"CC(N)Cc1c[nH]c2ccc3c(c12)CCCO3",
search_molecules=None,
atom_indices=[4, 5, 6],
color_list=["#ff1472"],
)

0 comments on commit 639e684

Please sign in to comment.