diff --git a/README.md b/README.md
index 62af505..603d938 100644
--- a/README.md
+++ b/README.md
@@ -5,6 +5,13 @@ Democratizing Biomolecular Interaction Modeling
![](docs/boltz1_pred_figure.png)
+Boltz-1:
+
+Democratizing Biomolecular Interaction Modeling
+
+
+![](docs/boltz1_pred_figure.png)
+
Boltz-1 is an open-source model which predicts the 3D structure of proteins, RNA, DNA and small molecules; it handles modified residues, covalent ligands and glycans, as well as conditioning the generation on pocket residues.
For more information about the model, see our [technical report](https://gcorso.github.io/assets/boltz1.pdf).
@@ -24,6 +31,97 @@ cd boltz; pip install -e .
```
> Note: we recommend installing boltz in a fresh python environment
+## Inference on Colab
+
+You can run inference using the Boltz-1 model directly in Google Colab. Click the link below to open the notebook:
+
+[Google Colab](https://colab.research.google.com/drive/1B_xSLii8ZMvp4K4XL2zslQVbKIjWV-mp?usp=sharing)
+
+## Inference
+
+You can run inference using Boltz-1 with:
+
+```
+boltz predict input_path
+```
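+
+For example, a run could look like the sketch below (paths are placeholders; the `--out_dir` and `--cache` options are the same ones used by the Colab notebook in this repository):
+
+```
+boltz predict example.fasta --out_dir ./results --cache ./weights
+```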
+
+Boltz currently accepts three input formats:
+
+1. A FASTA file, for most use cases
+
+2. A comprehensive YAML schema, for more complex use cases
+
+3. A directory containing files of the above formats, for batched processing
+
+To see all available options, run `boltz predict --help`. For more information on these input formats, see our [prediction instructions](docs/prediction.md); a small example is sketched below.
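+
+As a minimal FASTA sketch (assuming the `>CHAIN_ID|ENTITY_TYPE|MSA_PATH` header layout that the Colab notebook in this repository writes, where the MSA path applies to protein chains; the protein sequence and SMILES below are the notebook's default examples), a protein chain plus a small-molecule ligand could be specified as:
+
+```
+>A|protein|./A.a3m
+PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASKK
+>B|smiles
+N[C@@H](Cc1ccc(O)cc1)C(=O)O
+```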
+
+## Training
+
+If you're interested in retraining the model, see our [training instructions](docs/training.md).
+
+## Contributing
+
+We welcome external contributions and are eager to engage with the community. Connect with us on our [Slack channel](https://join.slack.com/t/boltz-community/shared_invite/zt-2uexwkemv-Tqt9E747hVkE0VOWlgOcIw) to discuss advancements, share insights, and foster collaboration around Boltz-1.
+
+## Coming very soon
+
+- [x] Auto-generated MSAs using MMseqs2
+- [x] More examples
+- [ ] Support for custom paired MSA
+- [ ] Confidence model checkpoint
+- [ ] Pocket conditioning support
+- [ ] Full data processing pipeline
+- [x] Colab notebook for inference
+- [ ] Kernel integration
+
+## License
+
+Our model and code are released under the MIT License and can be freely used for both academic and commercial purposes.
+
+
+## Cite
+
+If you use this code or the models in your research, please cite the following paper:
+
+```bibtex
+@article{wohlwend2024boltz1,
+ author = {Wohlwend, Jeremy and Corso, Gabriele and Passaro, Saro and Reveiz, Mateo and Leidal, Ken and Swiderski, Wojtek and Portnoi, Tally and Chinn, Itamar and Silterra, Jacob and Jaakkola, Tommi and Barzilay, Regina},
+ title = {Boltz-1: Democratizing Biomolecular Interaction Modeling},
+ year = {2024},
+ doi = {10.1101/2024.11.19.624167},
+ journal = {bioRxiv}
+}
+```
+
+In addition, if you use the automatic MSA generation, please cite:
+
+```bibtex
+@article{mirdita2022colabfold,
+ title={ColabFold: making protein folding accessible to all},
+ author={Mirdita, Milot and Sch{\"u}tze, Konstantin and Moriwaki, Yoshitaka and Heo, Lim and Ovchinnikov, Sergey and Steinegger, Martin},
+ journal={Nature methods},
+ year={2022},
+}
+```
+Boltz-1 is an open-source model which predicts the 3D structure of proteins, RNA, DNA and small molecules; it handles modified residues, covalent ligands and glycans, as well as conditioning the generation on pocket residues.
+
+For more information about the model, see our [technical report](https://gcorso.github.io/assets/boltz1.pdf).
+
+## Installation
+Install boltz with PyPI (recommended):
+
+```
+pip install boltz
+```
+
+or directly from GitHub for daily updates:
+
+```
+git clone https://github.com/jwohlwend/boltz.git
+cd boltz; pip install -e .
+```
+> Note: we recommend installing boltz in a fresh python environment
+
## Inference
You can run inference using Boltz-1 with:
@@ -58,7 +156,7 @@ We welcome external contributions and are eager to engage with the community. Co
- [ ] Confidence model checkpoint
- [ ] Pocket conditioning support
- [ ] Full data processing pipeline
-- [ ] Colab notebook for inference
+- [x] Colab notebook for inference
- [ ] Kernel integration
## License
diff --git a/inference.ipynb b/inference.ipynb
new file mode 100644
index 0000000..d5bb947
--- /dev/null
+++ b/inference.ipynb
@@ -0,0 +1,345 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "machine_shape": "hm",
+ "gpuType": "T4"
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title Input protein sequence(s), then hit `Runtime` -> `Run all`\n",
+ "from google.colab import files\n",
+ "import os\n",
+ "import re\n",
+ "import hashlib\n",
+ "import random\n",
+ "import requests\n",
+ "from string import ascii_uppercase\n",
+ "\n",
+ "# Function to add a hash to the jobname\n",
+ "def add_hash(x, y):\n",
+ " return x + \"_\" + hashlib.sha1(y.encode()).hexdigest()[:5]\n",
+ "\n",
+ "# User inputs\n",
+ "query_sequence = 'PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASKK' #@param {type:\"string\"}\n",
+        "#@markdown - Use \`:\` to specify inter-protein chain breaks for **modeling complexes** (supports homo- and hetero-oligomers). For example **PI...SK:PI...SK** for a homodimer\n",
+ "ligand_input = 'N[C@@H](Cc1ccc(O)cc1)C(=O)O' #@param {type:\"string\"}\n",
+        "#@markdown - Use \`:\` to specify multiple ligands as SMILES strings\n",
+ "ligand_input_ccd = 'SAH' #@param {type:\"string\"}\n",
+ "#@markdown - Use `:` to specify multiple ligands as CCD codes (three-letter codes)\n",
+ "ligand_input_common_name = '' #@param {type:\"string\"}\n",
+ "#@markdown - Use `:` to specify multiple ligands with their common name (e.g. Aspirin; SMILES fetched from [PubChem](https://pubchem.ncbi.nlm.nih.gov) API)\n",
+ "dna_input = '' #@param {type:\"string\"}\n",
+ "#@markdown - Use `:` to specify multiple DNA sequences\n",
+ "jobname = 'test' #@param {type:\"string\"}\n",
+ "\n",
+ "# Clean up the query sequence and jobname\n",
+ "query_sequence = \"\".join(query_sequence.split())\n",
+ "ligand_input = \"\".join(ligand_input.split())\n",
+ "ligand_input_ccd = \"\".join(ligand_input_ccd.split())\n",
+ "ligand_input_common_name = \"\".join(ligand_input_common_name.split())\n",
+ "dna_input = \"\".join(dna_input.split())\n",
+ "basejobname = \"\".join(jobname.split())\n",
+ "basejobname = re.sub(r'\\W+', '', basejobname)\n",
+ "jobname = add_hash(basejobname, query_sequence)\n",
+ "\n",
+ "# Check if a directory with jobname exists\n",
+ "def check(folder):\n",
+ " return not os.path.exists(folder)\n",
+ "\n",
+ "if not check(jobname):\n",
+ " n = 0\n",
+ " while not check(f\"{jobname}_{n}\"):\n",
+ " n += 1\n",
+ " jobname = f\"{jobname}_{n}\"\n",
+ "\n",
+ "# Make directory to save results\n",
+ "os.makedirs(jobname, exist_ok=True)\n",
+ "\n",
+ "# Split sequences on chain breaks\n",
+ "protein_sequences = query_sequence.strip().split(':') if query_sequence.strip() else []\n",
+ "ligand_sequences = ligand_input.strip().split(':') if ligand_input.strip() else []\n",
+ "ligand_sequences_ccd = ligand_input_ccd.strip().split(':') if ligand_input_ccd.strip() else []\n",
+ "ligand_sequences_common_name = ligand_input_common_name.strip().split(':') if ligand_input_common_name.strip() else []\n",
+ "dna_sequences = dna_input.strip().split(':') if dna_input.strip() else []\n",
+ "\n",
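+        "# Map a ligand's common name to a canonical SMILES string via the PubChem PUG REST API (autocomplete lookup, then property query)\n",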
+ "def get_smiles(compound_name):\n",
+ " autocomplete_url = f\"https://pubchem.ncbi.nlm.nih.gov/rest/autocomplete/compound/{compound_name}/json?limit=1\"\n",
+ " autocomplete_response = requests.get(autocomplete_url)\n",
+ " if autocomplete_response.status_code != 200:\n",
+ " return None\n",
+ "\n",
+ " autocomplete_data = autocomplete_response.json()\n",
+ " if autocomplete_data.get(\"status\", {}).get(\"code\") != 0 or autocomplete_data.get(\"total\", 0) == 0:\n",
+ " return None\n",
+ "\n",
+ " suggested_compound = autocomplete_data.get(\"dictionary_terms\", {}).get(\"compound\", [])\n",
+ " if not suggested_compound:\n",
+ " return None\n",
+ " suggested_compound_name = suggested_compound[0]\n",
+ "\n",
+ " smiles_url = f\"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{suggested_compound_name}/property/CanonicalSMILES/JSON\"\n",
+ " smiles_response = requests.get(smiles_url)\n",
+ " if smiles_response.status_code != 200:\n",
+ " return None\n",
+ "\n",
+ " smiles_data = smiles_response.json()\n",
+ " properties = smiles_data.get(\"PropertyTable\", {}).get(\"Properties\", [])\n",
+ " if len(properties) == 0:\n",
+ " return None\n",
+ "\n",
+ " return properties[0].get(\"CanonicalSMILES\")\n",
+ "\n",
+ "smiles_cache = {}\n",
+ "for name in ligand_sequences_common_name:\n",
+ " if name not in smiles_cache:\n",
+ " smiles_cache[name] = get_smiles(name)\n",
+ " if smiles_cache[name] is not None:\n",
+ " print(f\"Mapped compound {name} to {smiles_cache[name]}\")\n",
+ "\n",
+ " if smiles_cache[name] is not None:\n",
+ " ligand_sequences.append(smiles_cache[name])\n",
+ "\n",
+ "# Initialize chain labels starting from 'A'\n",
+ "chain_labels = iter(ascii_uppercase)\n",
+ "\n",
+ "fasta_entries = []\n",
+ "csv_entries = []\n",
+ "chain_label_to_seq_id = {}\n",
+ "seq_to_seq_id = {}\n",
+ "seq_id_counter = 0 # Counter for unique sequences\n",
+ "\n",
+ "# Process protein sequences\n",
+ "for seq in protein_sequences:\n",
+ " seq = seq.strip()\n",
+ " if not seq:\n",
+ " continue # Skip empty sequences\n",
+ " chain_label = next(chain_labels)\n",
+ " # Check if sequence has been seen before\n",
+ " if seq in seq_to_seq_id:\n",
+ " seq_id = seq_to_seq_id[seq]\n",
+ " else:\n",
+ " seq_id = f\"{jobname}_{seq_id_counter}\"\n",
+ " seq_to_seq_id[seq] = seq_id\n",
+ " seq_id_counter += 1\n",
+ " # For CSV file (for ColabFold), add only unique sequences\n",
+ " csv_entries.append((seq_id, seq))\n",
+ " chain_label_to_seq_id[chain_label] = seq_id\n",
+ " # For FASTA file\n",
+ " msa_path = os.path.join(jobname, f\"{seq_id}.a3m\")\n",
+ " header = f\">{chain_label}|protein|{msa_path}\"\n",
+ " sequence = seq\n",
+ " fasta_entries.append((header, sequence))\n",
+ "\n",
+ "# Process ligand sequences (assumed to be SMILES strings)\n",
+ "for lig in ligand_sequences:\n",
+ " lig = lig.strip()\n",
+ " if not lig:\n",
+ " continue # Skip empty ligands\n",
+ " chain_label = next(chain_labels)\n",
+ " lig_type = 'smiles'\n",
+ " header = f\">{chain_label}|{lig_type}\"\n",
+ " sequence = lig\n",
+ " fasta_entries.append((header, sequence))\n",
+ "\n",
+ "# Process DNA sequences (NO MSA is generated)\n",
+ "for seq in dna_sequences:\n",
+ " seq = seq.strip()\n",
+ " if not seq:\n",
+ " continue # Skip empty sequences\n",
+ " chain_label = next(chain_labels)\n",
+ " lig_type = 'DNA'\n",
+ " header = f\">{chain_label}|{lig_type}\"\n",
+ " sequence = seq\n",
+ " fasta_entries.append((header, sequence))\n",
+ "\n",
+ "# Process ligand sequences (CCD codes)\n",
+ "for lig in ligand_sequences_ccd:\n",
+ " lig = lig.strip()\n",
+ " if not lig:\n",
+ " continue # Skip empty ligands\n",
+ " chain_label = next(chain_labels)\n",
+ " lig_type = 'ccd'\n",
+ " header = f\">{chain_label}|{lig_type}\"\n",
+ " sequence = lig.upper() # Ensure CCD codes are uppercase\n",
+ " fasta_entries.append((header, sequence))\n",
+ "\n",
+ "# Write the CSV file for ColabFold\n",
+ "queries_path = os.path.join(jobname, f\"{jobname}.csv\")\n",
+ "with open(queries_path, \"w\") as text_file:\n",
+ " text_file.write(\"id,sequence\\n\")\n",
+ " for seq_id, seq in csv_entries:\n",
+ " text_file.write(f\"{seq_id},{seq}\\n\")\n",
+ "\n",
+ "# Write the FASTA file\n",
+ "queries_fasta = os.path.join(jobname, f\"{jobname}.fasta\")\n",
+ "with open(queries_fasta, 'w') as f:\n",
+ " for header, sequence in fasta_entries:\n",
+ " f.write(f\"{header}\\n{sequence}\\n\")\n",
+ "\n",
+ "# Optionally, print the output for verification\n",
+ "#print(f\"Generated FASTA file '{queries_fasta}':\\n\")\n",
+ "#for header, sequence in fasta_entries:\n",
+ "# print(f\"{header}\\n{sequence}\\n\")"
+ ],
+ "metadata": {
+ "id": "AcYvVeDESi2a"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "4eXNO1JJHYrB",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "37bcda2e-8e7d-4f0e-8d20-f433b3fa324f"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "CPU times: user 235 ms, sys: 53.8 ms, total: 289 ms\n",
+ "Wall time: 1min 4s\n"
+ ]
+ }
+ ],
+ "source": [
+ "#@title Install dependencies\n",
+ "%%time\n",
+ "import os\n",
+ "\n",
+ "if not os.path.isfile(\"BOLZ_READY\"):\n",
+ " os.system(\"apt-get install -y aria2\")\n",
+ " os.system(\"pip install -q --no-warn-conflicts boltz\")\n",
+ " os.system(\"mkdir weights\")\n",
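+        " # Download the Boltz-1 model weights and the CCD dictionary from the ColabFold mirror\n",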
+ " os.system(\"aria2c -d weights -x8 -s8 https://colabfold.steineggerlab.workers.dev/boltz1.ckpt\")\n",
+ " os.system(\"aria2c -d weights -x8 -s8 https://colabfold.steineggerlab.workers.dev/ccd.pkl\")\n",
+ " os.system(\"touch BOLZ_READY\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Generate MSA with ColabFold"
+ ],
+ "metadata": {
+ "id": "2rtccHy3E0sg"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "if not os.path.isfile(\"COLABFOLD_READY\"):\n",
+ " print(\"installing colabfold...\")\n",
+ " os.system(\"pip install -q --no-warn-conflicts 'colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold'\")\n",
+ " if os.environ.get('TPU_NAME', False) != False:\n",
+ " os.system(\"pip uninstall -y jax jaxlib\")\n",
+ " os.system(\"pip install --no-warn-conflicts --upgrade dm-haiku==0.0.10 'jax[cuda12_pip]'==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\")\n",
+ " os.system(\"ln -s /usr/local/lib/python3.*/dist-packages/colabfold colabfold\")\n",
+ " os.system(\"ln -s /usr/local/lib/python3.*/dist-packages/alphafold alphafold\")\n",
+ " os.system(\"touch COLABFOLD_READY\")"
+ ],
+ "metadata": {
+ "id": "bMhGBWZUEm-8"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
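+        "# Query the ColabFold MSA server for the unique protein sequences only (--msa-only skips structure prediction); .a3m files are written into the job folder\n",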
+ "!colabfold_batch \"{queries_path}\" \"{jobname}\" --msa-only"
+ ],
+ "metadata": {
+ "id": "4aFDR4IhRe6y"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Predict structure"
+ ],
+ "metadata": {
+ "id": "UpznJl7JE-AS"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
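+        "# Run Boltz-1 on the FASTA generated above, using the weights downloaded to ./weights as the cache directory\n",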
+ "!boltz predict --out_dir \"{jobname}\" \"{jobname}/{jobname}.fasta\" --cache weights"
+ ],
+ "metadata": {
+ "id": "bgaBXxXtIAu9"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title Download results\n",
+ "# Import necessary modules\n",
+ "import os\n",
+ "import zipfile\n",
+ "from google.colab import files\n",
+ "import glob\n",
+ "\n",
+ "# Ensure 'jobname' variable is defined\n",
+ "# jobname = 'test_abcde' # Uncomment and set if not already defined\n",
+ "\n",
+ "# Name of the zip file\n",
+ "zip_filename = f\"results_{jobname}.zip\"\n",
+ "\n",
+ "# Create a zip file and add the specified files without preserving directory structure\n",
+ "with zipfile.ZipFile(zip_filename, 'w') as zipf:\n",
+ " coverage_png_files = glob.glob(os.path.join(jobname, '*_coverage.png'))\n",
+ " a3m_files = glob.glob(os.path.join(jobname, '*.a3m'))\n",
+ " for file in coverage_png_files + a3m_files:\n",
+ " arcname = os.path.basename(file) # Use only the file name\n",
+ " zipf.write(file, arcname=arcname)\n",
+ "\n",
+ " cif_files = glob.glob(os.path.join(jobname, f'boltz_results_{jobname}', 'predictions', jobname, '*.cif'))\n",
+ " for file in cif_files:\n",
+ " arcname = os.path.basename(file) # Use only the file name\n",
+ " zipf.write(file, arcname=arcname)\n",
+ "\n",
+ " hparams_file = os.path.join(jobname, f'boltz_results_{jobname}', 'lightning_logs', 'version_0', 'hparams.yaml')\n",
+ " if os.path.exists(hparams_file):\n",
+ " arcname = os.path.basename(hparams_file) # Use only the file name\n",
+ " zipf.write(hparams_file, arcname=arcname)\n",
+ " else:\n",
+ " print(f\"Warning: {hparams_file} not found.\")\n",
+ "\n",
+ "# Download the zip file\n",
+ "files.download(zip_filename)"
+ ],
+ "metadata": {
+ "id": "jdSBSTOpaULF"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}