diff --git a/medchem/structural/_common.py b/medchem/structural/_common.py index 5a62305..1a5135a 100644 --- a/medchem/structural/_common.py +++ b/medchem/structural/_common.py @@ -165,8 +165,9 @@ def __call__( progress: bool = False, progress_leave: bool = False, scheduler: str = "auto", + batch_size: Optional[int] = None, keep_details: bool = False, - ): + ) -> pd.DataFrame: """Run alert evaluation on this list of molecule and return the full dataframe Args: @@ -175,6 +176,7 @@ def __call__( progress: whether to show progress or not. progress_leave: whether to leave the progress bar or not. scheduler: which scheduler to use. If "auto", will use "processes" if `len(mols) > 500` else "threads". + batch_size: batch size to use for parallelization. keep_details: whether to keep the details of the evaluation or not. """ @@ -183,18 +185,31 @@ def __call__( scheduler = "processes" # pragma: no cover else: scheduler = "threads" - - results = dm.parallelized( - functools.partial(self._evaluate, keep_details=keep_details), - mols, - progress=progress, - n_jobs=n_jobs, - scheduler=scheduler, - tqdm_kwargs=dict( - desc="Common alerts filtering", - leave=progress_leave, - ), - ) + if batch_size: + results = dm.parallelized_with_batches( + lambda batch: [functools.partial(self._evaluate, keep_details=keep_details)(mol) for mol in batch], + mols, + progress=progress, + n_jobs=n_jobs, + scheduler=scheduler, + batch_size=batch_size, + tqdm_kwargs=dict( + desc="Common alerts filtering", + leave=progress_leave, + ), + ) + else: + results = dm.parallelized( + functools.partial(self._evaluate, keep_details=keep_details), + mols, + progress=progress, + n_jobs=n_jobs, + scheduler=scheduler, + tqdm_kwargs=dict( + desc="Common alerts filtering", + leave=progress_leave, + ), + ) results = pd.DataFrame(results) return results diff --git a/pyproject.toml b/pyproject.toml index dc2cbcd..248aaa5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,17 +87,20 @@ omit = ["medchem/__init__.py", "medchem/_version.py"] output = "coverage.xml" [tool.ruff] +line-length = 110 +target-version = "py311" +extend-exclude = ["*.ipynb"] # Exclude Jupyter notebooks + +[tool.ruff.lint] ignore = [ "E501", # Never enforce `E501` (line length violations). ] -line-length = 110 -target-version = "py311" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "__init__.py" = [ "F401", # imported but unused "E402", # Module level import not at top of file ] -[tool.ruff.pycodestyle] +[tool.ruff.lint.pycodestyle] max-doc-length = 150