Skip to content

Commit

Permalink
Merge pull request #222 from kkovary/221-max_mols-fix
Browse files Browse the repository at this point in the history
pass max_mols to Draw.MolsToGridImage
  • Loading branch information
maclandrol authored Feb 2, 2024
2 parents 4fbb047 + ec2d34d commit cc3cc36
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 35 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/code-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:

- name: Install black
run: |
pip install black>=23
pip install black>=24
- name: Lint
run: black --check .
Expand Down
2 changes: 1 addition & 1 deletion binder/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ dependencies:
- pytest >=6.0
- pytest-cov
- pytest-xdist
- black >=23
- black >=24
- jupyterlab
- mypy
- codecov
Expand Down
41 changes: 15 additions & 26 deletions datamol/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,18 +66,15 @@ def open_datamol_data_file(


@overload
def freesolv(as_df: Literal[True] = True) -> pd.DataFrame:
...
def freesolv(as_df: Literal[True] = True) -> pd.DataFrame: ...


@overload
def freesolv(as_df: Literal[False] = False) -> List[Mol]:
...
def freesolv(as_df: Literal[False] = False) -> List[Mol]: ...


@overload
def freesolv(as_df: bool = True) -> Union[List[Mol], pd.DataFrame]:
...
def freesolv(as_df: bool = True) -> Union[List[Mol], pd.DataFrame]: ...


def freesolv(as_df: bool = True) -> Union[List[Mol], pd.DataFrame]:
Expand All @@ -102,18 +99,17 @@ def freesolv(as_df: bool = True) -> Union[List[Mol], pd.DataFrame]:


@overload
def cdk2(as_df: Literal[True] = True, mol_column: Optional[str] = "mol") -> pd.DataFrame:
...
def cdk2(as_df: Literal[True] = True, mol_column: Optional[str] = "mol") -> pd.DataFrame: ...


@overload
def cdk2(as_df: Literal[False] = False, mol_column: Optional[str] = "mol") -> List[Mol]:
...
def cdk2(as_df: Literal[False] = False, mol_column: Optional[str] = "mol") -> List[Mol]: ...


@overload
def cdk2(as_df: bool = True, mol_column: Optional[str] = "mol") -> Union[List[Mol], pd.DataFrame]:
...
def cdk2(
as_df: bool = True, mol_column: Optional[str] = "mol"
) -> Union[List[Mol], pd.DataFrame]: ...


def cdk2(as_df: bool = True, mol_column: Optional[str] = "mol"):
Expand All @@ -130,20 +126,17 @@ def cdk2(as_df: bool = True, mol_column: Optional[str] = "mol"):


@overload
def solubility(as_df: Literal[True] = True, mol_column: Optional[str] = "mol") -> pd.DataFrame:
...
def solubility(as_df: Literal[True] = True, mol_column: Optional[str] = "mol") -> pd.DataFrame: ...


@overload
def solubility(as_df: Literal[False] = False, mol_column: Optional[str] = "mol") -> List[Mol]:
...
def solubility(as_df: Literal[False] = False, mol_column: Optional[str] = "mol") -> List[Mol]: ...


@overload
def solubility(
as_df: bool = True, mol_column: Optional[str] = "mol"
) -> Union[List[Mol], pd.DataFrame]:
...
) -> Union[List[Mol], pd.DataFrame]: ...


def solubility(as_df: bool = True, mol_column: Optional[str] = "mol"):
Expand Down Expand Up @@ -184,13 +177,11 @@ def solubility(as_df: bool = True, mol_column: Optional[str] = "mol"):


@overload
def chembl_drugs(as_df: Literal[True] = True) -> pd.DataFrame:
...
def chembl_drugs(as_df: Literal[True] = True) -> pd.DataFrame: ...


@overload
def chembl_drugs(as_df: Literal[False] = False) -> List[Mol]:
...
def chembl_drugs(as_df: Literal[False] = False) -> List[Mol]: ...


def chembl_drugs(as_df: bool = True) -> Union[List[Mol], pd.DataFrame]:
Expand All @@ -210,13 +201,11 @@ def chembl_drugs(as_df: bool = True) -> Union[List[Mol], pd.DataFrame]:


@overload
def chembl_samples(as_df: Literal[True] = True) -> pd.DataFrame:
...
def chembl_samples(as_df: Literal[True] = True) -> pd.DataFrame: ...


@overload
def chembl_samples(as_df: Literal[False] = False) -> List[Mol]:
...
def chembl_samples(as_df: Literal[False] = False) -> List[Mol]: ...


def chembl_samples(as_df: bool = True) -> Union[List[Mol], pd.DataFrame]:
Expand Down
9 changes: 3 additions & 6 deletions datamol/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,7 @@ def read_sdf(
max_num_mols: Optional[int] = ...,
discard_invalid: bool = ...,
n_jobs: Optional[int] = ...,
) -> List[Mol]:
...
) -> List[Mol]: ...


@overload
Expand All @@ -134,8 +133,7 @@ def read_sdf(
max_num_mols: Optional[int] = ...,
discard_invalid: bool = ...,
n_jobs: Optional[int] = ...,
) -> pd.DataFrame:
...
) -> pd.DataFrame: ...


@overload
Expand All @@ -152,8 +150,7 @@ def read_sdf(
max_num_mols: Optional[int] = ...,
discard_invalid: bool = ...,
n_jobs: Optional[int] = ...,
) -> Union[List[Mol], pd.DataFrame]:
...
) -> Union[List[Mol], pd.DataFrame]: ...


def read_sdf(
Expand Down
15 changes: 15 additions & 0 deletions datamol/viz/_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Tuple
from typing import Optional
from typing import Any
from loguru import logger

from rdkit.Chem import Draw

Expand All @@ -22,6 +23,7 @@ def to_image(
highlight_bond: Optional[List[List[int]]] = None,
outfile: Optional[str] = None,
max_mols: int = 32,
max_mols_ipython: int = 50,
copy: bool = True,
indices: bool = False,
bond_indices: bool = False,
Expand All @@ -44,6 +46,7 @@ def to_image(
highlight_bond: The bonds to highlight.
outfile: Path where to save the image (local or remote path).
max_mols: The maximum number of molecules to display.
max_mols_ipython: The maximum number of molecules to display when running within an IPython environment.
copy: Whether to copy the molecules or not.
indices: Whether to draw the atom indices.
bond_indices: Whether to draw the bond indices.
Expand Down Expand Up @@ -120,6 +123,18 @@ def to_image(
else:
_kwargs[k] = v

# Check if we are in a Jupyter notebook or IPython display context
# If so, conditionally add the maxMols argument
in_notebook = dm.viz.utils.is_ipython_session()

if in_notebook:
_kwargs["maxMols"] = max_mols_ipython
if max_mols > max_mols_ipython:
logger.warning(
f"You have set max_mols to {max_mols}, which is higher than max_mols_ipython ({max_mols_ipython}). "
"Consider increasing max_mols_ipython if you want to display all molecules in an IPython environment."
)

image = Draw.MolsToGridImage(
mols,
legends=legends,
Expand Down
2 changes: 1 addition & 1 deletion env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ dependencies:
- pytest >=6.0
- pytest-cov
- pytest-xdist
- black >=23
- black >=24
- ruff
- jupyterlab
- mypy
Expand Down

0 comments on commit cc3cc36

Please sign in to comment.