diff --git a/examples/lucid_queries.txt b/examples/lucid_queries.txt
new file mode 100644
index 0000000..180b2c9
--- /dev/null
+++ b/examples/lucid_queries.txt
@@ -0,0 +1,10 @@
+How does the process of photosynthesis differ between C3, C4, and CAM plants?
+Can you explain the role of mitochondria in apoptosis and describe the biochemical pathways involved?
+What are the differences between classical and operant conditioning, and how do they relate to learning theory?
+How does the Cori cycle work, and what is its role in maintaining energy balance in the body?
+Can you explain the concept of dark matter and dark energy, and discuss their significance in the current understanding of the universe?
+What are the key differences between eukaryotic and prokaryotic cells, and why are these distinctions important in cell biology?
+How does the endocrine system regulate homeostasis in the human body, and what are some examples of hormones and their functions?
+Can you describe the process of meiosis and discuss its role in sexual reproduction and genetic diversity?
+What is the difference between a hypothesis, a theory, and a law in scientific terminology, and can you give an example of each?
+How does the Doppler effect apply to sound and light waves, and what are some real-world examples where this phenomenon is observed?
diff --git a/pyproject.toml b/pyproject.toml
index 0fe81c3..a24a6d7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,7 +30,10 @@ dependencies = [
     "llama-index-readers-file",
     "bitsandbytes >= 0.42.0",
     "sentence_transformers >= 2.2.2",
-    "llama-index-embeddings-langchain"
+    "llama-index-embeddings-langchain",
+    "llama-index-embeddings-huggingface",
+    "llama-index-vector-stores-faiss"
+
 ]
 
 [project.urls]
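Note that the new llama-index-vector-stores-faiss dependency is not exercised anywhere in this change; the script below still persists the default SimpleVectorStore. If the intent is to back the index with FAISS, a minimal sketch could look like the following. This is hypothetical wiring, not part of this commit: EMBED_DIM assumes the 768-dimensional output of S-PubMedBert-MS-MARCO, and `nodes` / `get_encoder` come from ragamp/pubmed_rag.py below.

    import faiss
    from llama_index.core import StorageContext, VectorStoreIndex
    from llama_index.vector_stores.faiss import FaissVectorStore

    EMBED_DIM = 768  # assumption: embedding width of S-PubMedBert-MS-MARCO

    faiss_index = faiss.IndexFlatL2(EMBED_DIM)  # exact (brute-force) L2 search
    vector_store = FaissVectorStore(faiss_index=faiss_index)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)

    index = VectorStoreIndex(
        nodes,  # produced by chunk_encode_parallel in pubmed_rag.py below
        embed_model=get_encoder(),
        storage_context=storage_context,
    )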
diff --git a/ragamp/pubmed_rag.py b/ragamp/pubmed_rag.py
index 6e9b675..d272452 100644
--- a/ragamp/pubmed_rag.py
+++ b/ragamp/pubmed_rag.py
@@ -14,24 +14,39 @@ import os
 import os.path as osp
 import sys
+import time
+from concurrent.futures import as_completed
+from concurrent.futures import ThreadPoolExecutor
 
 import torch
-from langchain.embeddings.huggingface import HuggingFaceBgeEmbeddings
 from llama_index.core import SimpleDirectoryReader
 from llama_index.core import VectorStoreIndex
+from llama_index.core.base.embeddings.base import BaseEmbedding
 from llama_index.core.indices.loading import load_index_from_storage
+from llama_index.core.node_parser import SemanticSplitterNodeParser
+from llama_index.core.node_parser.interface import NodeParser
 from llama_index.core.prompts.base import PromptTemplate
+from llama_index.core.schema import BaseNode
+from llama_index.core.schema import Document
+from llama_index.core.storage.docstore import SimpleDocumentStore
+from llama_index.core.storage.index_store import SimpleIndexStore
 from llama_index.core.storage.storage_context import StorageContext
+from llama_index.core.vector_stores import SimpleVectorStore
+from llama_index.embeddings.huggingface import HuggingFaceEmbedding
 from llama_index.llms.huggingface import HuggingFaceLLM
-from tqdm import tqdm
-
-# os.environ["HF_HOME"] = "/lus/eagle/projects/LUCID/ogokdemir/hf_cache"
 from transformers import BitsAndBytesConfig
 
 logging.basicConfig(stream=sys.stdout, level=logging.INFO)
-logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
 
+PERSIST_DIR = '/home/ac.ogokdemir/lucid_index'
+PAPERS_DIR = '/rbstor/ac.ogokdemir/md_outs'
+QUERY_DIR = '/home/ogokdemir/ragamp/examples/lucid_queries.txt'
+OUTPUT_DIR = '/rbstor/ac.ogokdemir/ragamp/output/lucid/'
+NODE_INFO_PATH = osp.join(OUTPUT_DIR, 'node_info.jsonl')
+
+os.makedirs(OUTPUT_DIR, exist_ok=True)
+
 quantization_config = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_compute_dtype=torch.float16,
@@ -39,14 +54,35 @@ bnb_4bit_use_double_quant=True,
 )
 
-llm = HuggingFaceLLM(
-    model_name='mistralai/Mistral-7B-Instruct-v0.1',
-    tokenizer_name='mistralai/Mistral-7B-Instruct-v0.1',
+# TODO: Move the generator and encoder factories out.
+
+# mistral7b = HuggingFaceLLM(
+#     model_name="mistralai/Mistral-7B-Instruct-v0.1",
+#     tokenizer_name="mistralai/Mistral-7B-Instruct-v0.1",
+#     query_wrapper_prompt=PromptTemplate(
+#         "[INST] {query_str} [/INST] \n",
+#     ),
+#     context_window=32000,
+#     max_new_tokens=1024,
+#     model_kwargs={"quantization_config": quantization_config},
+#     # tokenizer_kwargs={},
+#     generate_kwargs={
+#         "temperature": 0.2,
+#         "top_k": 5,
+#         "top_p": 0.95,
+#         "do_sample": True,
+#     },
+#     device_map="auto",
+# )
+
+mixtral8x7b = HuggingFaceLLM(
+    model_name='mistralai/Mixtral-8x7B-v0.1',
+    tokenizer_name='mistralai/Mixtral-8x7B-v0.1',
     query_wrapper_prompt=PromptTemplate(
         '[INST] {query_str} [/INST] \n',
     ),
-    context_window=3900,
-    max_new_tokens=256,
+    context_window=32000,
+    max_new_tokens=1024,
     model_kwargs={'quantization_config': quantization_config},
     # tokenizer_kwargs={},
     generate_kwargs={
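The TODO above flags that the generator should be built by a factory, the same way the encoder is built by get_encoder below. One possible shape, mirroring get_encoder, is sketched here; get_generator is hypothetical and not in this commit, the argument values are copied from the mixtral8x7b block above, and quantization_config is the module-level object defined in this script.

    from llama_index.core.prompts.base import PromptTemplate
    from llama_index.llms.huggingface import HuggingFaceLLM

    def get_generator(
        model_name: str = 'mistralai/Mixtral-8x7B-v0.1',
    ) -> HuggingFaceLLM:
        """Build the quantized generator LLM for the query engine."""
        return HuggingFaceLLM(
            model_name=model_name,
            tokenizer_name=model_name,
            query_wrapper_prompt=PromptTemplate('[INST] {query_str} [/INST] \n'),
            context_window=32000,
            max_new_tokens=1024,
            model_kwargs={'quantization_config': quantization_config},
            generate_kwargs={
                'temperature': 0.2,
                'top_k': 5,
                'top_p': 0.95,
                'do_sample': True,
            },
            device_map='auto',
        )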
@@ -58,47 +94,116 @@ device_map='auto',
 )
 
-encoder = HuggingFaceBgeEmbeddings(
-    model_name='pritamdeka/S-PubMedBert-MS-MARCO',
-)
-PERSIST_DIR = '/lus/eagle/projects/LUCID/ogokdemir/ragamp/indexes/amp_index/'
-AMP_PAPERS_DIR = '/lus/eagle/projects/candle_aesp/ogokdemir/pdfwf_runs/AmpParsedDocs/md_outs/'  # noqa
-QUERY_AMPS_DIR = '/home/ogokdemir/ragamp/examples/antimicrobial_peptides.txt'
-OUTPUT_DIR = '/lus/eagle/projects/LUCID/ogokdemir/ragamp/outputs/amp_output/template_4.json'  # noqa
+# An indexer pairs an encoder (embedding model) with a chunker (node parser).
+
+
+def get_encoder(device: int = 0) -> BaseEmbedding:
+    """Get the encoder for the vector store index."""
+    return HuggingFaceEmbedding(
+        model_name='pritamdeka/S-PubMedBert-MS-MARCO',
+        tokenizer_name='pritamdeka/S-PubMedBert-MS-MARCO',
+        max_length=512,
+        embed_batch_size=64,
+        cache_folder=os.environ.get('HF_HOME'),
+        device=f'cuda:{device}',
+    )
+
+
+def get_splitter(encoder: BaseEmbedding) -> NodeParser:
+    """Get the semantic splitter that chunks documents with the encoder."""
+    return SemanticSplitterNodeParser(
+        buffer_size=1,
+        include_metadata=True,
+        embed_model=encoder,
+    )
+
+
+def chunk_encode_unit(device: int, docs: list[Document]) -> list[BaseNode]:
+    """Chunk documents into nodes with an encoder pinned to one GPU."""
+    # Build a device-pinned encoder and a splitter that embeds with it.
+    encoder = get_encoder(device)
+    splitter = get_splitter(encoder)
+    return splitter.get_nodes_from_documents(docs)
+
+
+def chunk_encode_parallel(
+    docs: list[Document],
+    num_workers: int = 8,
+) -> list[BaseNode]:
+    """Chunk and encode documents in parallel, one worker per GPU."""
+    batches = [(i, docs[i::num_workers]) for i in range(num_workers)]
+    with ThreadPoolExecutor(max_workers=num_workers) as executor:
+        futures = [
+            executor.submit(chunk_encode_unit, device, docs)
+            for device, docs in batches
+        ]
+        results = [future.result() for future in as_completed(futures)]
+
+    return [node for result in results for node in result]
+
 
 if not osp.exists(PERSIST_DIR):
     logging.info('Creating index from scratch')
-    documents = SimpleDirectoryReader(AMP_PAPERS_DIR).load_data()
-    index = VectorStoreIndex.from_documents(
-        documents,
-        embed_model=encoder,
+
+    start = time.time()
+    documents = SimpleDirectoryReader(PAPERS_DIR).load_data()
+    end = time.time()
+
+    logging.info(f'Loaded documents in {end - start} seconds.')
+
+    nodes = chunk_encode_parallel(documents, num_workers=8)
+
+    # Dump node info for visual inspection of the semantic chunking.
+    with open(NODE_INFO_PATH, 'w') as f:
+        for rank, node in enumerate(nodes, 1):
+            node_info = {
+                'rank': rank,
+                'content': node.get_content(),
+                'metadata': node.get_metadata_str(),
+            }
+            f.write(json.dumps(node_info) + '\n')
+
+    index = VectorStoreIndex(
+        nodes,
+        embed_model=get_encoder(),  # for now, this has to be serial
         insert_batch_size=16384,
         show_progress=True,
+        use_async=True,
     )
+
+    os.makedirs(PERSIST_DIR)
     index.storage_context.persist(PERSIST_DIR)
+    logging.info(f'Saved the new index to {PERSIST_DIR}')
+
 else:
     logging.info('Loading index from storage')
-    storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
-    index = load_index_from_storage(storage_context, embed_model=encoder)
+    storage_context = StorageContext.from_defaults(
+        docstore=SimpleDocumentStore.from_persist_dir(PERSIST_DIR),
+        vector_store=SimpleVectorStore.from_persist_dir(
+            PERSIST_DIR,
+            namespace='default',
+        ),
+        index_store=SimpleIndexStore.from_persist_dir(PERSIST_DIR),
+    )
+    index = load_index_from_storage(storage_context, embed_model=get_encoder())
 
-query_engine = index.as_query_engine(llm=llm)
-logging.info('Query engine ready, running inference')
+    logging.info(f'Loaded the index from {PERSIST_DIR}.')
 
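Since node_info.jsonl exists purely so the semantic chunking can be eyeballed, a small reader along these lines spot-checks the first few nodes. The preview_nodes helper is hypothetical and not part of this commit; the record keys match what the script writes above.

    import json

    def preview_nodes(path: str, limit: int = 5) -> None:
        """Print the first `limit` records of a node_info.jsonl dump."""
        with open(path) as f:
            for line in f:
                record = json.loads(line)
                if record['rank'] > limit:
                    break
                print(f"--- node {record['rank']} ---")
                print(record['metadata'])
                print(record['content'][:300])  # first 300 chars of the chunk

    preview_nodes('/rbstor/ac.ogokdemir/ragamp/output/lucid/node_info.jsonl')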
-with open(QUERY_AMPS_DIR) as f:
-    amps = f.read().splitlines()
+# TODO: Move these query engine settings into a config file.
+query_engine = index.as_query_engine(
+    llm=mixtral8x7b,
+    similarity_top_k=10,
+    similarity_threshold=0.5,
+)
 
-q2r = {}
-for amp in tqdm(amps, desc='Querying', total=len(amps)):
-    query = f'What cellular processes does {amp} disrupt?'
-    response = query_engine.query(query)
-    q2r[amp] = str(response)
+logging.info('Query engine ready, running inference')
 
+with open(QUERY_DIR) as f:
+    queries = f.read().splitlines()
 
-os.makedirs(osp.dirname(OUTPUT_DIR), exist_ok=True)
+query_2_response = {}
 
-with open(OUTPUT_DIR, 'w') as f:
-    json.dump(q2r, f)
+for query in queries:
+    query_2_response[query] = str(query_engine.query(query))
 
-# TODO: Move dataloading and encoding to functions and parallelize them.
-# TODO: Once that is done, build the index directly from the embeddings.
+with open(osp.join(OUTPUT_DIR, 'query_responses.json'), 'w') as f:
+    json.dump(query_2_response, f)
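For the config-file TODO above, one option is a small dataclass hydrated from JSON, so similarity_top_k and similarity_threshold stop being hard-coded. This is a sketch only: QueryEngineConfig, load_config, and the config path are assumptions, not part of this commit.

    import json
    from dataclasses import dataclass

    @dataclass
    class QueryEngineConfig:
        """Retrieval settings currently hard-coded in pubmed_rag.py."""
        similarity_top_k: int = 10
        similarity_threshold: float = 0.5

    def load_config(path: str) -> QueryEngineConfig:
        """Read a JSON config, keeping the defaults for any missing keys."""
        with open(path) as f:
            return QueryEngineConfig(**json.load(f))

    # Usage sketch:
    # cfg = load_config('configs/query_engine.json')  # hypothetical path
    # query_engine = index.as_query_engine(
    #     llm=mixtral8x7b,
    #     similarity_top_k=cfg.similarity_top_k,
    #     similarity_threshold=cfg.similarity_threshold,
    # )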