Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
hadim committed Oct 27, 2023
1 parent 7af6cfd commit af88fcc
Show file tree
Hide file tree
Showing 4 changed files with 268 additions and 374 deletions.
90 changes: 46 additions & 44 deletions docs/tutorials/design-with-safe.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,47 +7,49 @@
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"!!! info\n",
"\n",
" Pretrained checkpoint for autoregressive only models are available at : `gs://valence-experiments/safe/generative-model/`"
"%autoreload 2\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
"\n",
"\n",
"import safe as sf\n",
"import datamol as dm\n",
"from safe.trainer.data_utils import get_dataset\n",
"from safe import SAFEDesign"
"import datamol as dm\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"ename": "HFValidationError",
"evalue": "Repo id must be in the form 'repo_name' or 'namespace/repo_name': '/home/hadim/.cache/safe/default_model'. Use `repo_type` argument if needed.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mHFValidationError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m/home/hadim/Code/valence/Libs/safe/docs/tutorials/design-with-safe.ipynb Cell 3\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> <a href='vscode-notebook-cell:/home/hadim/Code/valence/Libs/safe/docs/tutorials/design-with-safe.ipynb#W4sZmlsZQ%3D%3D?line=0'>1</a>\u001b[0m designer \u001b[39m=\u001b[39m sf\u001b[39m.\u001b[39;49mSAFEDesign\u001b[39m.\u001b[39;49mload_default(verbose\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n",
"File \u001b[0;32m~/Code/valence/Libs/safe/safe/sample.py:89\u001b[0m, in \u001b[0;36mSAFEDesign.load_default\u001b[0;34m(cls, verbose, model_dir, device)\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[39mif\u001b[39;00m model_dir \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mor\u001b[39;00m \u001b[39mnot\u001b[39;00m model_dir:\n\u001b[1;32m 88\u001b[0m model_dir \u001b[39m=\u001b[39m \u001b[39mcls\u001b[39m\u001b[39m.\u001b[39m_DEFAULT_MODEL_PATH\n\u001b[0;32m---> 89\u001b[0m model \u001b[39m=\u001b[39m SAFEDoubleHeadsModel\u001b[39m.\u001b[39;49mfrom_pretrained(model_dir)\n\u001b[1;32m 90\u001b[0m tokenizer \u001b[39m=\u001b[39m SAFETokenizer\u001b[39m.\u001b[39mload(os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(model_dir, \u001b[39m\"\u001b[39m\u001b[39mtokenizer.json\u001b[39m\u001b[39m\"\u001b[39m))\n\u001b[1;32m 91\u001b[0m gen_config \u001b[39m=\u001b[39m GenerationConfig\u001b[39m.\u001b[39mfrom_pretrained(model_dir)\n",
"File \u001b[0;32m~/local/micromamba/envs/safe/lib/python3.11/site-packages/transformers/modeling_utils.py:2507\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 2504\u001b[0m \u001b[39mif\u001b[39;00m commit_hash \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 2505\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(config, PretrainedConfig):\n\u001b[1;32m 2506\u001b[0m \u001b[39m# We make a call to the config file first (which may be absent) to get the commit hash as soon as possible\u001b[39;00m\n\u001b[0;32m-> 2507\u001b[0m resolved_config_file \u001b[39m=\u001b[39m cached_file(\n\u001b[1;32m 2508\u001b[0m pretrained_model_name_or_path,\n\u001b[1;32m 2509\u001b[0m CONFIG_NAME,\n\u001b[1;32m 2510\u001b[0m cache_dir\u001b[39m=\u001b[39;49mcache_dir,\n\u001b[1;32m 2511\u001b[0m force_download\u001b[39m=\u001b[39;49mforce_download,\n\u001b[1;32m 2512\u001b[0m resume_download\u001b[39m=\u001b[39;49mresume_download,\n\u001b[1;32m 2513\u001b[0m proxies\u001b[39m=\u001b[39;49mproxies,\n\u001b[1;32m 2514\u001b[0m local_files_only\u001b[39m=\u001b[39;49mlocal_files_only,\n\u001b[1;32m 2515\u001b[0m token\u001b[39m=\u001b[39;49mtoken,\n\u001b[1;32m 2516\u001b[0m revision\u001b[39m=\u001b[39;49mrevision,\n\u001b[1;32m 2517\u001b[0m subfolder\u001b[39m=\u001b[39;49msubfolder,\n\u001b[1;32m 2518\u001b[0m _raise_exceptions_for_missing_entries\u001b[39m=\u001b[39;49m\u001b[39mFalse\u001b[39;49;00m,\n\u001b[1;32m 2519\u001b[0m _raise_exceptions_for_connection_errors\u001b[39m=\u001b[39;49m\u001b[39mFalse\u001b[39;49;00m,\n\u001b[1;32m 2520\u001b[0m )\n\u001b[1;32m 2521\u001b[0m commit_hash \u001b[39m=\u001b[39m extract_commit_hash(resolved_config_file, commit_hash)\n\u001b[1;32m 2522\u001b[0m \u001b[39melse\u001b[39;00m:\n",
"File \u001b[0;32m~/local/micromamba/envs/safe/lib/python3.11/site-packages/transformers/utils/hub.py:429\u001b[0m, in \u001b[0;36mcached_file\u001b[0;34m(path_or_repo_id, filename, cache_dir, force_download, resume_download, proxies, token, revision, local_files_only, subfolder, repo_type, user_agent, _raise_exceptions_for_missing_entries, _raise_exceptions_for_connection_errors, _commit_hash, **deprecated_kwargs)\u001b[0m\n\u001b[1;32m 426\u001b[0m user_agent \u001b[39m=\u001b[39m http_user_agent(user_agent)\n\u001b[1;32m 427\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 428\u001b[0m \u001b[39m# Load from URL or cache if already cached\u001b[39;00m\n\u001b[0;32m--> 429\u001b[0m resolved_file \u001b[39m=\u001b[39m hf_hub_download(\n\u001b[1;32m 430\u001b[0m path_or_repo_id,\n\u001b[1;32m 431\u001b[0m filename,\n\u001b[1;32m 432\u001b[0m subfolder\u001b[39m=\u001b[39;49m\u001b[39mNone\u001b[39;49;00m \u001b[39mif\u001b[39;49;00m \u001b[39mlen\u001b[39;49m(subfolder) \u001b[39m==\u001b[39;49m \u001b[39m0\u001b[39;49m \u001b[39melse\u001b[39;49;00m subfolder,\n\u001b[1;32m 433\u001b[0m repo_type\u001b[39m=\u001b[39;49mrepo_type,\n\u001b[1;32m 434\u001b[0m revision\u001b[39m=\u001b[39;49mrevision,\n\u001b[1;32m 435\u001b[0m cache_dir\u001b[39m=\u001b[39;49mcache_dir,\n\u001b[1;32m 436\u001b[0m user_agent\u001b[39m=\u001b[39;49muser_agent,\n\u001b[1;32m 437\u001b[0m force_download\u001b[39m=\u001b[39;49mforce_download,\n\u001b[1;32m 438\u001b[0m proxies\u001b[39m=\u001b[39;49mproxies,\n\u001b[1;32m 439\u001b[0m resume_download\u001b[39m=\u001b[39;49mresume_download,\n\u001b[1;32m 440\u001b[0m token\u001b[39m=\u001b[39;49mtoken,\n\u001b[1;32m 441\u001b[0m local_files_only\u001b[39m=\u001b[39;49mlocal_files_only,\n\u001b[1;32m 442\u001b[0m )\n\u001b[1;32m 443\u001b[0m \u001b[39mexcept\u001b[39;00m GatedRepoError \u001b[39mas\u001b[39;00m e:\n\u001b[1;32m 444\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mEnvironmentError\u001b[39;00m(\n\u001b[1;32m 445\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mYou are trying to access a gated repo.\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39mMake sure to request access at \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 446\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mhttps://huggingface.co/\u001b[39m\u001b[39m{\u001b[39;00mpath_or_repo_id\u001b[39m}\u001b[39;00m\u001b[39m and pass a token having permission to this repo either \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 447\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mby logging in with `huggingface-cli login` or by passing `token=<your_token>`.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 448\u001b[0m ) \u001b[39mfrom\u001b[39;00m \u001b[39me\u001b[39;00m\n",
"File \u001b[0;32m~/local/micromamba/envs/safe/lib/python3.11/site-packages/huggingface_hub/utils/_validators.py:110\u001b[0m, in \u001b[0;36mvalidate_hf_hub_args.<locals>._inner_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[39mfor\u001b[39;00m arg_name, arg_value \u001b[39min\u001b[39;00m chain(\n\u001b[1;32m 106\u001b[0m \u001b[39mzip\u001b[39m(signature\u001b[39m.\u001b[39mparameters, args), \u001b[39m# Args values\u001b[39;00m\n\u001b[1;32m 107\u001b[0m kwargs\u001b[39m.\u001b[39mitems(), \u001b[39m# Kwargs values\u001b[39;00m\n\u001b[1;32m 108\u001b[0m ):\n\u001b[1;32m 109\u001b[0m \u001b[39mif\u001b[39;00m arg_name \u001b[39min\u001b[39;00m [\u001b[39m\"\u001b[39m\u001b[39mrepo_id\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mfrom_id\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mto_id\u001b[39m\u001b[39m\"\u001b[39m]:\n\u001b[0;32m--> 110\u001b[0m validate_repo_id(arg_value)\n\u001b[1;32m 112\u001b[0m \u001b[39melif\u001b[39;00m arg_name \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mtoken\u001b[39m\u001b[39m\"\u001b[39m \u001b[39mand\u001b[39;00m arg_value \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 113\u001b[0m has_token \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n",
"File \u001b[0;32m~/local/micromamba/envs/safe/lib/python3.11/site-packages/huggingface_hub/utils/_validators.py:158\u001b[0m, in \u001b[0;36mvalidate_repo_id\u001b[0;34m(repo_id)\u001b[0m\n\u001b[1;32m 155\u001b[0m \u001b[39mraise\u001b[39;00m HFValidationError(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mRepo id must be a string, not \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mtype\u001b[39m(repo_id)\u001b[39m}\u001b[39;00m\u001b[39m: \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00mrepo_id\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 157\u001b[0m \u001b[39mif\u001b[39;00m repo_id\u001b[39m.\u001b[39mcount(\u001b[39m\"\u001b[39m\u001b[39m/\u001b[39m\u001b[39m\"\u001b[39m) \u001b[39m>\u001b[39m \u001b[39m1\u001b[39m:\n\u001b[0;32m--> 158\u001b[0m \u001b[39mraise\u001b[39;00m HFValidationError(\n\u001b[1;32m 159\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mRepo id must be in the form \u001b[39m\u001b[39m'\u001b[39m\u001b[39mrepo_name\u001b[39m\u001b[39m'\u001b[39m\u001b[39m or \u001b[39m\u001b[39m'\u001b[39m\u001b[39mnamespace/repo_name\u001b[39m\u001b[39m'\u001b[39m\u001b[39m:\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 160\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00mrepo_id\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m. Use `repo_type` argument if needed.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 161\u001b[0m )\n\u001b[1;32m 163\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m REPO_ID_REGEX\u001b[39m.\u001b[39mmatch(repo_id):\n\u001b[1;32m 164\u001b[0m \u001b[39mraise\u001b[39;00m HFValidationError(\n\u001b[1;32m 165\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mRepo id must use alphanumeric chars or \u001b[39m\u001b[39m'\u001b[39m\u001b[39m-\u001b[39m\u001b[39m'\u001b[39m\u001b[39m, \u001b[39m\u001b[39m'\u001b[39m\u001b[39m_\u001b[39m\u001b[39m'\u001b[39m\u001b[39m, \u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39m\u001b[39m'\u001b[39m\u001b[39m, \u001b[39m\u001b[39m'\u001b[39m\u001b[39m--\u001b[39m\u001b[39m'\u001b[39m\u001b[39m and \u001b[39m\u001b[39m'\u001b[39m\u001b[39m..\u001b[39m\u001b[39m'\u001b[39m\u001b[39m are\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 166\u001b[0m \u001b[39m\"\u001b[39m\u001b[39m forbidden, \u001b[39m\u001b[39m'\u001b[39m\u001b[39m-\u001b[39m\u001b[39m'\u001b[39m\u001b[39m and \u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39m\u001b[39m'\u001b[39m\u001b[39m cannot start or end the name, max length is 96:\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 167\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00mrepo_id\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 168\u001b[0m )\n",
"\u001b[0;31mHFValidationError\u001b[0m: Repo id must be in the form 'repo_name' or 'namespace/repo_name': '/home/hadim/.cache/safe/default_model'. Use `repo_type` argument if needed."
]
}
],
"source": [
"designer = SAFEDesign.load_default(verbose=True)"
"designer = sf.SAFEDesign.load_default(verbose=True)\n"
]
},
{
Expand Down Expand Up @@ -182,7 +184,7 @@
}
],
"source": [
"dm.to_image(dm.to_mol(candidate_mol))"
"dm.to_image(dm.to_mol(candidate_mol))\n"
]
},
{
Expand All @@ -191,7 +193,7 @@
"metadata": {},
"outputs": [],
"source": [
"N_SAMPLES = 100"
"N_SAMPLES = 100\n"
]
},
{
Expand All @@ -217,7 +219,7 @@
}
],
"source": [
"generated = designer.de_novo_generation(sanitize=True, n_samples_per_trial=N_SAMPLES)"
"generated = designer.de_novo_generation(sanitize=True, n_samples_per_trial=N_SAMPLES)\n"
]
},
{
Expand Down Expand Up @@ -1631,7 +1633,7 @@
}
],
"source": [
"dm.to_image(generated[:20])"
"dm.to_image(generated[:20])\n"
]
},
{
Expand Down Expand Up @@ -1708,7 +1710,7 @@
}
],
"source": [
"dm.to_image(scaffold)"
"dm.to_image(scaffold)\n"
]
},
{
Expand All @@ -1725,7 +1727,7 @@
}
],
"source": [
"generated = designer.scaffold_decoration(scaffold=scaffold, n_samples_per_trial=N_SAMPLES, n_trials=2, sanitize=True, do_not_fragment_further=True)"
"generated = designer.scaffold_decoration(scaffold=scaffold, n_samples_per_trial=N_SAMPLES, n_trials=2, sanitize=True, do_not_fragment_further=True)\n"
]
},
{
Expand Down Expand Up @@ -5994,7 +5996,7 @@
}
],
"source": [
"dm.viz.lasso_highlight_image([dm.to_mol(x) for x in generated[:20]], dm.from_smarts(scaffold))"
"dm.viz.lasso_highlight_image([dm.to_mol(x) for x in generated[:20]], dm.from_smarts(scaffold))\n"
]
},
{
Expand Down Expand Up @@ -6059,7 +6061,7 @@
}
],
"source": [
"dm.to_image(superstructure)"
"dm.to_image(superstructure)\n"
]
},
{
Expand All @@ -6077,7 +6079,7 @@
],
"source": [
"generated = designer.super_structure(core=superstructure, n_samples_per_trial=N_SAMPLES, n_trials=1, sanitize=True, do_not_fragment_further=False, attachment_point_depth=3)\n",
"#generated"
"#generated\n"
]
},
{
Expand Down Expand Up @@ -7225,7 +7227,7 @@
}
],
"source": [
"dm.to_image(generated[:20])"
"dm.to_image(generated[:20])\n"
]
},
{
Expand Down Expand Up @@ -7277,7 +7279,7 @@
}
],
"source": [
"dm.to_image(motif)"
"dm.to_image(motif)\n"
]
},
{
Expand All @@ -7295,7 +7297,7 @@
],
"source": [
"# let's make some long sequence\n",
"generated = designer.motif_extension(motif=motif, n_samples_per_trial=N_SAMPLES, n_trials=1, sanitize=True, do_not_fragment_further=False, min_length=25, max_length=80)"
"generated = designer.motif_extension(motif=motif, n_samples_per_trial=N_SAMPLES, n_trials=1, sanitize=True, do_not_fragment_further=False, min_length=25, max_length=80)\n"
]
},
{
Expand Down Expand Up @@ -8451,7 +8453,7 @@
}
],
"source": [
"dm.to_image(generated[:20])"
"dm.to_image(generated[:20])\n"
]
},
{
Expand Down Expand Up @@ -8541,7 +8543,7 @@
}
],
"source": [
"dm.to_image(side_chains)"
"dm.to_image(side_chains)\n"
]
},
{
Expand Down Expand Up @@ -10006,7 +10008,7 @@
],
"source": [
"generated = designer.scaffold_morphing(side_chains=side_chains, n_samples_per_trial=N_SAMPLES, n_trials=1, sanitize=True, do_not_fragment_further=False, random_seed=100)\n",
"dm.to_image(generated[:20])"
"dm.to_image(generated[:20])\n"
]
},
{
Expand Down Expand Up @@ -10112,7 +10114,7 @@
}
],
"source": [
"dm.to_image(linker_generation)"
"dm.to_image(linker_generation)\n"
]
},
{
Expand Down Expand Up @@ -12320,7 +12322,7 @@
],
"source": [
"generated = designer.linker_generation(*linker_generation, n_samples_per_trial=N_SAMPLES, n_trials=1, sanitize=True, do_not_fragment_further=False, random_seed=100)\n",
"dm.to_image(generated[:20])"
"dm.to_image(generated[:20])\n"
]
}
],
Expand All @@ -12340,7 +12342,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.11.6"
},
"orig_nbformat": 4
},
Expand Down
Loading

0 comments on commit af88fcc

Please sign in to comment.