Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unifold on custom a3m MSA files #146

Open
FranceCosta opened this issue May 3, 2024 · 0 comments
Open

Unifold on custom a3m MSA files #146

FranceCosta opened this issue May 3, 2024 · 0 comments

Comments

@FranceCosta
Copy link

Hi!

I think UniFold symmetry is a great tool!

I am running an AlphaFold-UniFold sym pipeline that consists of the following steps:

  1. AlphaFold is run on a proteome to find homodimers;
  2. UniFold symmetry is run with different circular symmetries on the candidates that show homodimeric interactions.

To save computational resources and time, I have created a script that runs UniFold symmetry on MSA a3m files generated by AlphaFold. It is based on your Colab implementation. I would like to make sure what I am doing makes sense. Do you think this is valid?

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
    Customised script to run UniFold with custom .a3m file 
    (from https://colab.research.google.com/github/dptech-corp/Uni-Fold/blob/main/notebooks/unifold.ipynb)
    Example: python bin/run_custom_msa.py \
        --MSA output/UP000246259/MSA/A0A1B5FPA6/bfd_uniref_hits.a3m \
        --homooligomers_num 10 --output output/UP000246259/unifold/A0A1B5FPA6/C10
    Francesco Costa [email protected] 
    EMBL_EBI 01/04/2024
"""

import numpy as np
from typing import Dict, List, Union, Any
import argparse
from Bio import SeqIO
import os
import pickle
import gzip

from unifold.msa import templates
from unifold.msa import pipeline
from unifold.data.protein import PDB_CHAIN_IDS
from unifold.msa import parsers
from unifold.data.utils import compress_features
from unifold.colab.model import colab_inference

def createDir(dirPath: str) -> str:
    """Make sure the directory ``dirPath`` exists and hand the path back.

    Missing parent directories are created as well; an already existing
    directory is left untouched. Returning the path lets this function be
    used directly as an argparse ``type=`` converter.
    """
    target = dirPath
    os.makedirs(target, exist_ok=True)
    return target

# Command-line interface. The parser is built at import time; main() calls
# parser.parse_args() on it.
parser = argparse.ArgumentParser(
    description="Create UniFold-like input files for unifold sym when provided with custom .a3m"
)

# Alignment whose first record is used as the query sequence.
parser.add_argument(
    "--MSA",
    required=True,
    help="Path to MSA file in .a3m format",
    type=str,
)

# main() turns this number N into the symmetry group string "C<N>".
parser.add_argument(
    "--homooligomers_num",
    required=True,
    help="Number of homo-oligomeric chains N; the symmetry group C<N> is used",
    type=int,
)

# createDir creates the directory while the arguments are being parsed.
parser.add_argument(
    "--output",
    required=True,
    help="Output dir (created if it does not exist)",
    type=createDir,
)

# NOTE(review): the default is a site-specific absolute path; users on other
# systems must pass --params_dir explicitly.
parser.add_argument(
    "--params_dir",
    required=False,
    help="UniFold params dir",
    type=str,
    default="/nfs/research/agb/research/francesco/software/Uni-Fold/",
)

# The remaining options are forwarded unchanged to colab_inference() under
# the same keyword names.
parser.add_argument(
    "--max_recycling_iters",
    required=False,
    help="Value of max_recycling_iters forwarded to colab_inference",
    type=int,
    default=3,
)

parser.add_argument(
    "--num_ensembles",
    required=False,
    help="Value of num_ensembles forwarded to colab_inference",
    type=int,
    default=2,
)

parser.add_argument(
    "--times",
    required=False,
    help="Value of times forwarded to colab_inference",
    type=int,
    default=1,
)

parser.add_argument(
    "--manual_seed",
    required=False,
    help="Value of manual_seed forwarded to colab_inference",
    type=int,
    default=42,
)

def main():
    """Build UniFold features from a custom .a3m file and run symmetric inference.

    Steps:
      1. Take the first record of ``--MSA`` as the (single) query sequence.
      2. Build sequence, MSA and null-template features, pass them through
         ``compress_features`` and pickle them, gzip-compressed, to
         ``<output>/<chain_id>.feature.pkl.gz``.
      3. Write ``<output>/chains.txt`` containing the chain id ``A``.
      4. Call ``colab_inference`` with symmetry group ``C<homooligomers_num>``.

    Raises:
        ValueError: if the MSA file contains no records.
    """
    # Prepare
    args = parser.parse_args()

    # Only the first record is needed, so stop after it instead of parsing
    # the whole alignment; the explicit handle is closed right away.
    with open(args.MSA, "r") as msa_handle:
        first_record = next(SeqIO.parse(msa_handle, "fasta"), None)
    if first_record is None:
        raise ValueError(f"No sequences found in MSA file: {args.MSA}")
    unique_sequences = [str(first_record.seq)]

    is_multimer = False
    symmetry_group = f"C{args.homooligomers_num}"
    target_id = "run"

    # Get features (the paired MSA is returned but not used by this script)
    unpaired_msa, _paired_msa, template_results = get_msa_and_templates(
        a3m_file=args.MSA,
        unique_sequences=unique_sequences,
        homooligomers_num=args.homooligomers_num,
    )

    for idx, seq in enumerate(unique_sequences):
        chain_id = PDB_CHAIN_IDS[idx]
        sequence_features = pipeline.make_sequence_features(
            sequence=seq, description=f'>seq {chain_id}', num_res=len(seq)
        )
        monomer_msa = parsers.parse_a3m(unpaired_msa[idx])
        msa_features = pipeline.make_msa_features([monomer_msa])
        template_features = template_results[idx]
        feature_dict = {**sequence_features, **msa_features, **template_features}
        feature_dict = compress_features(feature_dict)
        features_output_path = os.path.join(
            args.output, "{}.feature.pkl.gz".format(chain_id)
        )
        # Close the gzip stream explicitly so the file is complete on disk
        # before inference reads it back.
        with gzip.GzipFile(features_output_path, "wb") as features_file:
            pickle.dump(feature_dict, features_file, protocol=4)

    # Create chain file
    with open(os.path.join(args.output, "chains.txt"), "w") as f:
        f.write("A")

    # Run prediction; features are read from, and results written to, the
    # same output directory. The return value is not used here.
    colab_inference(
        target_id=target_id,
        data_dir=args.output,
        param_dir=args.params_dir,
        output_dir=args.output,
        symmetry_group=symmetry_group,
        is_multimer=is_multimer,
        max_recycling_iters=args.max_recycling_iters,
        num_ensembles=args.num_ensembles,
        times=args.times,
        manual_seed=args.manual_seed,
        device="cuda:0",
    )

def get_msa_and_templates(a3m_file:str, unique_sequences:list, homooligomers_num:int):
    """Assemble MSA and template inputs from a precomputed .a3m file.

    Inspired by the original UniFold ``get_msa_and_templates()`` and keeps
    the same three-part return value.

    Args:
        a3m_file: Path to the alignment file in a3m format.
        unique_sequences: Query sequences. The first one is used for the
            paired entries and every one gets a null template.
        homooligomers_num: Number of paired entries to generate.

    Returns:
        A tuple ``(unpaired_msa, paired_msa, template_results)``:
          * ``unpaired_msa``: one-element list holding the full text of
            ``a3m_file`` with NUL characters removed.
          * ``paired_msa``: ``homooligomers_num`` strings of the form
            ``">{101 + i}\\n{first query sequence}\\n"``.
          * ``template_results``: one ``get_null_template`` result per
            entry of ``unique_sequences``.

    Raises:
        IndexError: if ``unique_sequences`` is empty while
            ``homooligomers_num`` is positive.
    """
    # Paired entries are numbered starting from this id.
    first_paired_id = 101

    # Read the whole alignment, dropping stray NUL characters; the context
    # manager guarantees the file handle is closed.
    with open(a3m_file, "r") as a3m_handle:
        a3m_text = a3m_handle.read().replace("\x00", "")
    a3m_lines = [a3m_text]

    # One paired entry per chain, each repeating the first query sequence.
    paired_a3m_lines = [
        ">" + str(first_paired_id + i) + "\n" + unique_sequences[0] + "\n"
        for i in range(homooligomers_num)
    ]

    # Get null templates (for the moment)
    template_features = [get_null_template(seq) for seq in unique_sequences]

    # This corresponds to unpaired_msa, paired_msa, template_results
    return (a3m_lines, paired_a3m_lines, template_features)


def get_null_template(
    query_sequence: Union[List[str], str], num_temp: int = 1
) -> Dict[str, Any]:
    """Build placeholder ("null") template features for a query.

    Args:
        query_sequence: A single sequence string, or a list of sequence
            strings whose lengths are summed.
        num_temp: Number of identical placeholder templates to emit.

    Returns:
        A dict with the keys ``template_all_atom_positions``,
        ``template_all_atom_masks``, ``template_sequence``,
        ``template_aatype``, ``template_domain_names`` and
        ``template_sum_probs``. Atom positions, atom masks and sum-probs are
        all zeros; the amino-acid type is the one-hot encoding of an
        all-``"A"`` sequence of the query length; sequence and domain names
        are the bytes ``b"none"``. Each entry has ``num_temp`` as its leading
        dimension / list length.
    """
    if isinstance(query_sequence, str):
        ln = len(query_sequence)
    else:
        ln = sum(len(s) for s in query_sequence)

    residue_constants = templates.residue_constants
    num_atom_types = residue_constants.atom_type_num

    # Placeholder template: poly-"A" sequence with zeroed coordinates/masks.
    placeholder_sequence = "A" * ln
    zero_positions = np.zeros((ln, num_atom_types, 3))
    zero_masks = np.zeros((ln, num_atom_types))
    # presumably a (ln, num_classes) one-hot matrix, given the 3-axis tiling
    # below; verify against sequence_to_onehot
    placeholder_aatype = residue_constants.sequence_to_onehot(
        placeholder_sequence, residue_constants.HHBLITS_AA_TO_ID
    )

    return {
        "template_all_atom_positions": np.tile(
            zero_positions[None], [num_temp, 1, 1, 1]
        ),
        "template_all_atom_masks": np.tile(zero_masks[None], [num_temp, 1, 1]),
        "template_sequence": [b"none"] * num_temp,
        "template_aatype": np.tile(
            np.array(placeholder_aatype)[None], [num_temp, 1, 1]
        ),
        "template_domain_names": [b"none"] * num_temp,
        "template_sum_probs": np.zeros([num_temp], dtype=np.float32),
    }

# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()

Thank you for your time

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant