diff --git a/docs/tutorials/load-from-wandb.ipynb b/docs/tutorials/load-from-wandb.ipynb
new file mode 100644
index 0000000..1aefd8f
--- /dev/null
+++ b/docs/tutorials/load-from-wandb.ipynb
@@ -0,0 +1,322 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from safe.sample import SAFEDesign"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/Users/emmanuel.noutahi/miniconda3/envs/safe/lib/python3.12/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
+      "  warnings.warn(\n"
+     ]
+    }
+   ],
+   "source": [
+    "model = SAFEDesign.load_default()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Upload models to wandb\n",
+    "\n",
+    "SAFE models can be uploaded to wandb with the `upload_to_wandb` function. You can set the `SAFE_WANDB_PROJECT` environment variable to save all of your models to that project.\n",
+    "\n",
+    "Make sure that you are logged in to your wandb account:\n",
+    "\n",
+    "```bash\n",
+    "wandb login --relogin $WANDB_API_KEY\n",
+    "```"
+   ]
+  },
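+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Note: with this change, `upload_to_wandb` is also re-exported from the top-level `safe` package (see the `safe/__init__.py` diff below), so the following shorter import is equivalent to the one used in the next cell:\n",
+    "\n",
+    "```python\n",
+    "from safe import upload_to_wandb\n",
+    "```"
+   ]
+  },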
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from safe.io import upload_to_wandb"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "env: WANDB_SILENT=False\n",
+      "env: SAFE_WANDB_PROJECT=safe-models\n",
+      "[2024-09-10 13:42:46,004] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to mps (auto detect)\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "W0910 13:42:46.257000 8343047168 torch/distributed/elastic/multiprocessing/redirects.py:28] NOTE: Redirects are currently not supported in Windows or MacOs.\n",
+      "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n",
+      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmaclandrol\u001b[0m (\u001b[33mvalencelabs\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "wandb version 0.17.9 is available! To upgrade, please run:\n",
+       " $ pip install wandb --upgrade"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "Tracking run with wandb version 0.16.6"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "Run data is saved locally in /Users/emmanuel.noutahi/Code/safe/nb/wandb/run-20240910_134247-72wmn5st"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "Syncing run absurd-disco-1 to Weights & Biases (docs)"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       " View project at https://wandb.ai/valencelabs/safe-models"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       " View run at https://wandb.ai/valencelabs/safe-models/runs/72wmn5st"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (/var/folders/rl/wwcfdj4x0pg293bfqszl970r0000gq/T/tmpy8r3mrzg)... Done. 0.6s\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a9466f1c60dc43d18aceca8b74f753f4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "VBox(children=(Label(value='333.221 MB of 333.221 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       " View run absurd-disco-1 at: https://wandb.ai/valencelabs/safe-models/runs/72wmn5st\n",
+       " View project at: https://wandb.ai/valencelabs/safe-models\n",
+       "Synced 6 W&B file(s), 0 media file(s), 5 artifact file(s) and 0 other file(s)"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "Find logs at: ./wandb/run-20240910_134247-72wmn5st/logs"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "%env WANDB_SILENT=False\n",
+    "%env SAFE_WANDB_PROJECT=safe-models\n",
+    "\n",
+    "upload_to_wandb(model.model, model.tokenizer, artifact_name=\"default-safe-zinc\", slicer=\"BRICS/Partition\", aliases=[\"paper\"])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Loading models from wandb"
+   ]
+  },
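+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "`SAFEDesign.load_from_wandb` takes an artifact path of the form `entity/project/artifact:version`. If the project is omitted, it falls back to the `SAFE_WANDB_PROJECT` environment variable (default `safe-models`); if the version is omitted, `:latest` is used. When `SAFE_MODEL_ROOT` is set, artifacts are cached under that directory. For example, given `SAFE_WANDB_PROJECT=safe-models`, the two calls below should be equivalent:\n",
+    "\n",
+    "```python\n",
+    "designer = SAFEDesign.load_from_wandb(\"safe-models/default-safe-zinc:latest\")\n",
+    "designer = SAFEDesign.load_from_wandb(\"default-safe-zinc\")\n",
+    "```"
+   ]
+  },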
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "env: SAFE_MODEL_ROOT=/Users/emmanuel.noutahi/.cache/wandb/safe/\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[34m\u001b[1mwandb\u001b[0m:   Downloading large artifact default-safe-zinc:latest, 333.22MB. 5 files... \n",
+      "\u001b[34m\u001b[1mwandb\u001b[0m:   5 of 5 files downloaded.  \n",
+      "Done. 0:0:0.7\n"
+     ]
+    }
+   ],
+   "source": [
+    "%env SAFE_MODEL_ROOT=/Users/emmanuel.noutahi/.cache/wandb/safe/\n",
+    "designer = SAFEDesign.load_from_wandb(\"safe-models/default-safe-zinc\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "96de1483b3894961a8ec2690df4f8ace",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/1 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "UserWarning: `num_beams` is set to 1. However, `early_stopping` is set to `True` -- this flag is only used in beam-based generation modes. You should set `num_beams>1` or unset `early_stopping`.\n",
+      "  warnings.warn(\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "['C[C@]1(C(=O)N2CCC[C@@H](NC(=O)C#CC3CC3)CC2)CCNC1=O',\n",
+       " 'CN(C(=O)CN1CC[NH+](C[C@@H](O)Cn2cc([N+](=O)[O-])cn2)CC1)c1ccccc1',\n",
+       " 'CC[C@@H](C)C[C@@H]([NH3+])C(=O)N(CC)C[C@@H]1CCOC1',\n",
+       " 'Cc1nnc2n1C[C@H](CNC(=O)Nc1cc(Cl)ccc1Cl)CC2',\n",
+       " 'CCc1cccc(CC)c1NC(=O)[C@H](C)OC(=O)CCc1nc2ccccc2o1',\n",
+       " 'Cc1cc(OC[C@H](O)C[NH2+]C[C@@H]2C[C@H](O)CN2Cc2ccccc2)ccc1F',\n",
+       " 'Cc1c(Cl)cccc1N=C(O)CN=C(O)COC(=O)c1csc(-c2ccccc2)n1',\n",
+       " 'CCc1nc(CCNC(=O)N[C@@H]2CCc3nnnn3CC2)cs1',\n",
+       " 'C[C@@]1(C(=O)N[C@H]2CCCCCN(C(=O)c3cc(C4CC4)no3)C2)C=CCC1',\n",
+       " 'Cc1cc(-c2cc(-c3cnn(C)c3)c3c(N)ncnc3n2)ccc1F']"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "designer.de_novo_generation(10)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "safe",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/safe/__init__.py b/safe/__init__.py
index 4bbdbc3..1ad2f92 100644
--- a/safe/__init__.py
+++ b/safe/__init__.py
@@ -4,3 +4,4 @@
 from .sample import SAFEDesign
 from .tokenizer import SAFETokenizer, split
 from .viz import to_image
+from .io import upload_to_wandb
diff --git a/safe/io.py b/safe/io.py
new file mode 100644
index 0000000..4986008
--- /dev/null
+++ b/safe/io.py
@@ -0,0 +1,82 @@
+from typing import Optional, List
+
+import tempfile
+import os
+import contextlib
+import torch
+import wandb
+import fsspec
+
+from transformers import PreTrainedModel, is_torch_available
+from transformers.processing_utils import PushToHubMixin
+
+
+def upload_to_wandb(
+    model: PreTrainedModel,
+    tokenizer,
+    artifact_name: str,
+    wandb_project_name: Optional[str] = "safe-models",
+    artifact_type: str = "model",
+    slicer: Optional[str] = None,
+    aliases: Optional[List[str]] = None,
+    **init_args,
+):
+    """
+    Uploads a model and tokenizer to a specified Weights and Biases (wandb) project.
+
+    Args:
+        model (PreTrainedModel): The model to be uploaded (instance of PreTrainedModel).
+        tokenizer: The tokenizer associated with the model.
+        artifact_name (str): The name of the wandb artifact to create.
+        wandb_project_name (Optional[str]): The name of the wandb project. Defaults to 'safe-models'.
+            The `SAFE_WANDB_PROJECT` environment variable takes precedence when set.
+        artifact_type (str): The type of artifact (e.g., 'model'). Defaults to 'model'.
+        slicer (Optional[str]): Optional metadata field that can store a slicing method.
+        aliases (Optional[List[str]]): List of aliases to assign to this artifact version.
+        **init_args: Additional arguments to pass into `wandb.init()`.
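+
+    Example:
+        A minimal usage sketch; the artifact name, slicer value, and aliases below are illustrative:
+
+        >>> from safe import SAFEDesign
+        >>> designer = SAFEDesign.load_default()
+        >>> upload_to_wandb(
+        ...     designer.model,
+        ...     designer.tokenizer,
+        ...     artifact_name="my-safe-model",
+        ...     slicer="BRICS",
+        ...     aliases=["dev"],
+        ... )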
+    """
+
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        # Paths to save model and tokenizer
+        model_path = tokenizer_path = tmpdirname
+        architecture_file = os.path.join(tmpdirname, "architecture.txt")
+        with fsspec.open(architecture_file, "w+") as f:
+            f.write(str(model))
+
+        model.save_pretrained(model_path)
+        with contextlib.suppress(Exception):
+            tokenizer.save_pretrained(tokenizer_path)
+            tokenizer.save(os.path.join(tokenizer_path, "tokenizer.json"))
+
+        info_dict = {"slicer": slicer}
+        model_config = None
+        if hasattr(model, "config") and model.config is not None:
+            model_config = (
+                model.config.to_dict() if not isinstance(model.config, dict) else model.config
+            )
+            info_dict.update(model_config)
+
+        if hasattr(model, "peft_config") and model.peft_config is not None:
+            info_dict.update({"peft_config": model.peft_config})
+
+        with contextlib.suppress(Exception):
+            info_dict["model/num_parameters"] = model.num_parameters()
+
+        init_args.setdefault("config", info_dict)
+        run = wandb.init(project=os.getenv("SAFE_WANDB_PROJECT", wandb_project_name), **init_args)
+
+        artifact = wandb.Artifact(
+            name=artifact_name,
+            type=artifact_type,
+            metadata={
+                "model_config": model_config,
+                "num_parameters": info_dict.get("model/num_parameters"),
+                "initial_model": True,
+            },
+        )
+
+        # Add model and tokenizer directories to the artifact
+        artifact.add_dir(tmpdirname)
+        run.log_artifact(artifact, aliases=aliases)
+
+        # Finish the wandb run
+        run.finish()
diff --git a/safe/sample.py b/safe/sample.py
index a44698c..e25ff09 100644
--- a/safe/sample.py
+++ b/safe/sample.py
@@ -9,7 +9,8 @@
 import datamol as dm
 import torch
 from loguru import logger
+from pathlib import Path
 from tqdm.auto import tqdm
 from transformers import GenerationConfig
 from transformers.generation import DisjunctiveConstraint, PhrasalConstraint
@@ -59,6 +63,7 @@ def __init__(
             safe_encoder: custom safe encoder to use
             verbose: whether to print out logging information during generation
         """
+
         if isinstance(model, (str, os.PathLike)):
             model = SAFEDoubleHeadsModel.from_pretrained(model)
 
@@ -82,9 +87,78 @@ def __init__(
         self.verbose = verbose
         self.safe_encoder = safe_encoder or sf.SAFEConverter()
 
+    @classmethod
+    def load_from_wandb(
+        cls, artifact_path: str, device: Optional[str] = None, verbose: bool = True, **kwargs: Any
+    ) -> "SAFEDesign":
+        """
+        Load a SAFE model and tokenizer from a Weights and Biases (wandb) artifact. If the
+        `SAFE_MODEL_ROOT` environment variable is set, the artifact is downloaded under that
+        directory; otherwise wandb's default cache location is used.
+
+        Args:
+            artifact_path: The path to the wandb artifact in the format `entity/project/artifact:version`.
+            device: The device where the model should be loaded ('cpu' or 'cuda'). If None, the model is left on its current device.
+            verbose: Whether to print out logging information during generation.
+            kwargs: Any additional arguments to pass to the `SAFEDesign` constructor.
+
+        Returns:
+            SAFEDesign: An instance of SAFEDesign with the model, tokenizer, and generation config loaded from wandb.
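+
+        Example:
+            A minimal sketch; the project and artifact names below are illustrative:
+
+            >>> designer = SAFEDesign.load_from_wandb("safe-models/my-safe-model:latest")
+            >>> smiles = designer.de_novo_generation(10)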
+        """
+        import wandb
+
+        # Strip the optional "wandb://" scheme so both URIs and plain paths are accepted
+        artifact_path = artifact_path.replace("wandb://", "")
+
+        # Parse the artifact path to extract the project and artifact name.
+        # Accepted forms: "entity/project/name:version", "project/name", and "name";
+        # a missing project falls back to SAFE_WANDB_PROJECT, a missing version to ":latest".
+        parts = artifact_path.split("/", 1)
+        if len(parts) > 1:
+            project_name, artifact_name = parts
+        else:
+            project_name = os.getenv("SAFE_WANDB_PROJECT", "safe-models")
+            artifact_name = artifact_path
+
+        if ":" not in artifact_name:
+            artifact_name += ":latest"
+
+        artifact_path = f"{project_name}/{artifact_name}"
+
+        # Check if the SAFE_MODEL_ROOT environment variable is defined
+        cache_path = os.getenv("SAFE_MODEL_ROOT", None)
+        if cache_path is not None:
+            # Ensure the cache path exists
+            cache_path = Path(cache_path)
+            cache_path.mkdir(parents=True, exist_ok=True)
+            artifact_subfolder = artifact_path.replace("/", "_").replace(":", "_")
+            cache_dir = cache_path / artifact_subfolder
+            cache_path = cache_dir.as_posix()
+
+        api = wandb.Api()
+        # Download the artifact from wandb to the cache directory
+        artifact = api.artifact(artifact_path, type="model")
+        artifact_dir = artifact.download(root=cache_path)
+
+        # Load the model, tokenizer, and generation config from the artifact directory
+        model = SAFEDoubleHeadsModel.from_pretrained(artifact_dir)
+        tokenizer = SAFETokenizer.from_pretrained(artifact_dir)
+        gen_config = GenerationConfig.from_pretrained(artifact_dir)
+
+        # Move model to the specified device if provided
+        if device is not None:
+            model = model.to(device)
+
+        return cls(
+            model=model,
+            tokenizer=tokenizer,
+            generation_config=gen_config,
+            verbose=verbose,
+            **kwargs,
+        )
+
     @classmethod
     def load_default(
-        cls, verbose: bool = False, model_dir: Optional[str] = None, device: str = None
+        cls,
+        model_dir: Optional[str] = None,
+        device: Optional[str] = None,
+        verbose: bool = False,
+        **kwargs: Any,
     ) -> "SAFEDesign":
         """Load default SAFEGenerator model
@@ -93,6 +167,7 @@
             model_dir: Optional path to model folder to use instead of the default one.
                 If provided the tokenizer should be in the model_dir named as `tokenizer.json`
             device: optional device where to move the model
+            kwargs: any additional arguments to pass to the init function
         """
         if model_dir is None or not model_dir:
             model_dir = cls._DEFAULT_MODEL_PATH
@@ -101,7 +176,13 @@
         gen_config = GenerationConfig.from_pretrained(model_dir)
         if device is not None:
             model = model.to(device)
-        return cls(model=model, tokenizer=tokenizer, generation_config=gen_config, verbose=verbose)
+        return cls(
+            model=model,
+            tokenizer=tokenizer,
+            generation_config=gen_config,
+            verbose=verbose,
+            **kwargs,
+        )
 
     def linker_generation(
         self,
diff --git a/tests/test_notebooks.py b/tests/test_notebooks.py
index 802135d..355c795 100644
--- a/tests/test_notebooks.py
+++ b/tests/test_notebooks.py
@@ -12,7 +12,7 @@
 NOTEBOOK_PATHS = list(filter(lambda x: x.name not in DISABLE_NOTEBOOKS, NOTEBOOK_PATHS))
 
 # Discard some notebooks
-NOTEBOOKS_TO_DISCARD = ["extracting-representation-molfeat.ipynb"]
+NOTEBOOKS_TO_DISCARD = ["extracting-representation-molfeat.ipynb", "load-from-wandb.ipynb"]
 
 NOTEBOOK_PATHS = list(filter(lambda x: x.name not in NOTEBOOKS_TO_DISCARD, NOTEBOOK_PATHS))