diff --git a/CITATION.cff b/CITATION.cff
index d67648c..b2f6422 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -6,5 +6,5 @@ authors:
   orcid: https://orcid.org/0000-0001-5299-1983
 license: MIT
 repository-code: https://github.com/ogkdmr/ragamp
-title: RagAmpq
+title: RagAmp
 url: https://ogkdmr.github.io/ragamp/
diff --git a/examples/antimicrobial_peptides.txt b/examples/antimicrobial_peptides.txt
new file mode 100644
index 0000000..bba6e77
--- /dev/null
+++ b/examples/antimicrobial_peptides.txt
@@ -0,0 +1,47 @@
+Amoebapore A
+BACTENECIN 5
+CCL20
+DEFB118
+Drosomycin
+Eotaxin2
+Gm cecropin A
+Human alphasynuclein
+Human granulysin
+Microcin B
+Microcin S
+NLP31
+Amoebapore B
+BACTENECIN 7
+CXCL2
+DEFB24
+Drosomycin2
+Eotaxin3
+Gm cecropin B
+Human beta defensin 2
+Human histatin 9
+Microcin C7
+Microcin V
+Peptide 2
+Amoebapore C
+CAP18
+CXCL3
+Defensin 1
+Drosophila cecropin B
+EP2
+Gm cecropin C
+Human beta defensin 3
+Human TC2
+Microcin L
+NLP27
+Peptide 5
+Bactenecin
+Cathepsin G
+CXCL6
+Dermcidin
+Elafin
+FGG
+Gm defensinlike peptide
+Human beta defensin 4
+LL23
+Microcin M
+NLP29
diff --git a/pyproject.toml b/pyproject.toml
index b4057cc..0fe81c3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,8 +19,18 @@ classifiers = [
     "Programming Language :: Python :: 3 :: Only",
     "Programming Language :: Python :: Implementation :: CPython",
 ]
+
 dependencies = [
-    "requests",
+    "requests >= 2.31.0",
+    "transformers >= 4.37.1",
+    "tokenizers >= 0.15.1",
+    "llama-index >= 0.9.36",
+    "langchain >= 0.1.3",
+    "llama-index-llms-huggingface",
+    "llama-index-readers-file",
+    "bitsandbytes >= 0.42.0",
+    "sentence_transformers >= 2.2.2",
+    "llama-index-embeddings-langchain"
 ]

 [project.urls]
diff --git a/ragamp/process_json_result.py b/ragamp/process_json_result.py
new file mode 100644
index 0000000..7a6ee78
--- /dev/null
+++ b/ragamp/process_json_result.py
@@ -0,0 +1,13 @@
+"""Initial code for reading the content of the JSON-formatted RAG response."""
+
+from __future__ import annotations
+
+import json
+
+with open('data/query_responses_strains.json') as f:
+    q2r = json.load(f)
+    for k, v in q2r.items():
+        print('Query AMP: ', k)
+        print()
+        print(' Response: ', v)
+        print()
diff --git a/ragamp/pubmed_rag.py b/ragamp/pubmed_rag.py
new file mode 100644
index 0000000..6e9b675
--- /dev/null
+++ b/ragamp/pubmed_rag.py
@@ -0,0 +1,104 @@
+"""Code for building and querying a RAG vector store index using an LLM.
+
+This module contains code for querying a vector store index using a
+language model and generating responses. It uses the HuggingFace library
+for the language model and tokenizer, and the llama_index library for
+vector store index operations. The module also includes code for creating
+and loading the index from storage, as well as saving the query responses
+to a JSON file.
+""" + +from __future__ import annotations + +import json +import logging +import os +import os.path as osp +import sys + +import torch +from langchain.embeddings.huggingface import HuggingFaceBgeEmbeddings +from llama_index.core import SimpleDirectoryReader +from llama_index.core import VectorStoreIndex +from llama_index.core.indices.loading import load_index_from_storage +from llama_index.core.prompts.base import PromptTemplate +from llama_index.core.storage.storage_context import StorageContext +from llama_index.llms.huggingface import HuggingFaceLLM +from tqdm import tqdm + +# os.environ["HF_HOME"] = "/lus/eagle/projects/LUCID/ogokdemir/hf_cache" +from transformers import BitsAndBytesConfig + +logging.basicConfig(stream=sys.stdout, level=logging.INFO) +logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout)) + + +quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_quant_type='nf4', + bnb_4bit_use_double_quant=True, +) + +llm = HuggingFaceLLM( + model_name='mistralai/Mistral-7B-Instruct-v0.1', + tokenizer_name='mistralai/Mistral-7B-Instruct-v0.1', + query_wrapper_prompt=PromptTemplate( + '<s>[INST] {query_str} [/INST] </s>\n', + ), + context_window=3900, + max_new_tokens=256, + model_kwargs={'quantization_config': quantization_config}, + # tokenizer_kwargs={}, + generate_kwargs={ + 'temperature': 0.2, + 'top_k': 5, + 'top_p': 0.95, + 'do_sample': True, + }, + device_map='auto', +) + +encoder = HuggingFaceBgeEmbeddings( + model_name='pritamdeka/S-PubMedBert-MS-MARCO', +) + +PERSIST_DIR = '/lus/eagle/projects/LUCID/ogokdemir/ragamp/indexes/amp_index/' +AMP_PAPERS_DIR = '/lus/eagle/projects/candle_aesp/ogokdemir/pdfwf_runs/AmpParsedDocs/md_outs/' # noqa +QUERY_AMPS_DIR = '/home/ogokdemir/ragamp/examples/antimicrobial_peptides.txt' +OUTPUT_DIR = '/lus/eagle/projects/LUCID/ogokdemir/ragamp/outputs/amp_output/template_4.json' # noqa + +if not osp.exists(PERSIST_DIR): + logging.info('Creating index from scratch') + documents = SimpleDirectoryReader(AMP_PAPERS_DIR).load_data() + index = VectorStoreIndex.from_documents( + documents, + embed_model=encoder, + insert_batch_size=16384, + show_progress=True, + ) + index.storage_context.persist(PERSIST_DIR) +else: + logging.info('Loading index from storage') + storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR) + index = load_index_from_storage(storage_context, embed_model=encoder) + +query_engine = index.as_query_engine(llm=llm) +logging.info('Query engine ready, running inference') + +with open(QUERY_AMPS_DIR) as f: + amps = f.read().splitlines() + +q2r = {} +for amp in tqdm(amps, desc='Querying', total=len(amps)): + query = f'What cellular processes does {amp} disrupt?' + response = query_engine.query(query) + q2r[amp] = str(response) + + +os.makedirs(osp.dirname(OUTPUT_DIR), exist_ok=True) + +with open(OUTPUT_DIR, 'w') as f: + json.dump(q2r, f) + +# TODO: Move dataloading and encoding to functions and parallelize them. +# TODO: Once that is done, build the index directly from the embeddings.