From eedd8efba56a22a9e89c6a40ff6fce43d1d0bac5 Mon Sep 17 00:00:00 2001
From: ogkdmr
Date: Sun, 10 Mar 2024 16:52:53 -0500
Subject: [PATCH] Updates to VectorStore and Query Engine

1. Upgraded the vector store from the default JSON store to Faiss. Result: it
   now loads much faster from storage.
2. Implemented a custom query engine with the following:
   - A custom prompt template in line with Rick's.
   - A prompt helper that facilitates compact response synthesis.
   - Increased the number of retrieved documents to 30 per query.
   - Increased the similarity cutoff threshold to 0.7.
   - Switched the response mode from 'refine' to 'compact' for faster inference.
---
 pyproject.toml       |   7 +-
 ragamp/pubmed_rag.py | 170 +++++++++++++++++++++++++++++++++----------
 2 files changed, 136 insertions(+), 41 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index a24a6d7..ae5f2b6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,7 +32,8 @@ dependencies = [
     "sentence_transformers >= 2.2.2",
     "llama-index-embeddings-langchain",
     "llama-index-embeddings-huggingface",
-    "llama-index-vector-stores-faiss"
+    "llama-index-vector-stores-faiss",
+    "faiss-gpu >= 1.7.0",
 ]
 
 
@@ -149,14 +150,14 @@ select = [
     # ruff-specific
     "RUF",
 ]
-extend-ignore = []
+extend-ignore = ["COM812", "ISC001"]
 
 [tool.ruff.lint.flake8-pytest-style]
 parametrize-values-type = "tuple"
 
 [tool.ruff.lint.flake8-quotes]
 inline-quotes = "single"
-multiline-quotes = "single"
+multiline-quotes = "double"
 
 [tool.ruff.lint.isort]
 force-single-line = true
diff --git a/ragamp/pubmed_rag.py b/ragamp/pubmed_rag.py
index d272452..2a25650 100644
--- a/ragamp/pubmed_rag.py
+++ b/ragamp/pubmed_rag.py
@@ -18,31 +18,36 @@ from concurrent.futures import as_completed
 from concurrent.futures import ThreadPoolExecutor
 
+import faiss
 import torch
+from llama_index.core import get_response_synthesizer
+from llama_index.core import PromptTemplate
 from llama_index.core import SimpleDirectoryReader
 from llama_index.core import VectorStoreIndex
 from llama_index.core.base.embeddings.base import BaseEmbedding
 from llama_index.core.indices.loading import load_index_from_storage
+from llama_index.core.indices.prompt_helper import PromptHelper
 from llama_index.core.node_parser import SemanticSplitterNodeParser
 from llama_index.core.node_parser.interface import NodeParser
-from llama_index.core.prompts.base import PromptTemplate
+from llama_index.core.postprocessor import SimilarityPostprocessor
+from llama_index.core.query_engine import RetrieverQueryEngine
+from llama_index.core.response_synthesizers import ResponseMode
+from llama_index.core.retrievers import VectorIndexRetriever
 from llama_index.core.schema import BaseNode
 from llama_index.core.schema import Document
-from llama_index.core.storage.docstore import SimpleDocumentStore
-from llama_index.core.storage.index_store import SimpleIndexStore
 from llama_index.core.storage.storage_context import StorageContext
-from llama_index.core.vector_stores import SimpleVectorStore
 from llama_index.embeddings.huggingface import HuggingFaceEmbedding
 from llama_index.llms.huggingface import HuggingFaceLLM
+from llama_index.vector_stores.faiss import FaissVectorStore
 from transformers import BitsAndBytesConfig
 
 logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
-PERSIST_DIR = '/home/ac.ogokdemir/lucid_index'
+PERSIST_DIR = '/rbstor/ac.ogokdemir/ragamp/all_lucid/faiss_index'
 PAPERS_DIR = '/rbstor/ac.ogokdemir/md_outs'
-QUERY_DIR = '/home/ogokdemir/ragamp/examples/lucid_queries.txt'
-OUTPUT_DIR = '/rbstor/ac.ogokdemir/ragamp/output/lucid/'
+QUERY_DIR = '/home/ac.ogokdemir/ragamp/examples/lucid_queries.txt'
+OUTPUT_DIR = '/rbstor/ac.ogokdemir/ragamp/output/all_lucid'
 NODE_INFO_PATH = osp.join(OUTPUT_DIR, 'node_info.jsonl')
 os.makedirs(OUTPUT_DIR, exist_ok=True)
@@ -54,7 +59,8 @@
     bnb_4bit_use_double_quant=True,
 )
 
-# TODO: Move the generator and encoder factories out.
+# TODO: Move the generator and encoder creation to factory functions.
+# Create BaseGenerator and BaseEncoder interfaces
 
 # mistral7b = HuggingFaceLLM(
 #     model_name="mistralai/Mistral-7B-Instruct-v0.1",
@@ -82,7 +88,7 @@
         '[INST] {query_str} [/INST] \n',
     ),
     context_window=32000,
-    max_new_tokens=1024,
+    max_new_tokens=2048,
     model_kwargs={'quantization_config': quantization_config},
     # tokenizer_kwargs={},
     generate_kwargs={
@@ -118,7 +124,15 @@ def get_splitter(encoder: BaseEmbedding) -> NodeParser:
 
 
 def chunk_encode_unit(device: int, docs: list[Document]) -> list[BaseNode]:
-    """Encode documents using the given embedding model."""
+    """Encode a list of documents on the given GPU.
+
+    Args:
+        device (int): Index of the GPU that will run this unit of work.
+        docs (list[Document]): Documents to chunk and encode.
+
+    Returns:
+        list[BaseNode]: Encoded nodes.
+    """
     # create the encoder
     encoder = get_encoder(device)
     splitter = get_splitter(encoder)
@@ -129,7 +143,16 @@ def chunk_encode_parallel(
     docs: list[Document],
     num_workers: int = 8,
 ) -> list[BaseNode]:
-    """Encode documents in parallel using the given embedding model."""
+    """Encode documents in parallel using the given embedding model.
+
+    Args:
+        docs (list[Document]): Documents to chunk and encode.
+        num_workers (int, optional): Number of GPU workers. Defaults to 8.
+
+    Returns:
+        list[BaseNode]: Encoded nodes.
+
+    """
     batches = [(i, docs[i::num_workers]) for i in range(num_workers)]
     with ThreadPoolExecutor(max_workers=num_workers) as executor:
         futures = [
@@ -141,16 +164,44 @@ def chunk_encode_parallel(
     return [node for result in results for node in result]
 
 
+# TODO: Put this in a create index from scratch function.
 if not osp.exists(PERSIST_DIR):
     logging.info('Creating index from scratch')
-
-    start = time.time()
-    documents = SimpleDirectoryReader(PAPERS_DIR).load_data()
-    end = time.time()
-
-    logging.info(f'Loaded documents in {end - start} seconds.')
-
+    logging.info('Starting to load the documents.')
+
+    load_start = time.time()
+    documents = SimpleDirectoryReader(PAPERS_DIR).load_data(show_progress=True)
+    load_end = time.time()
+    logging.info(
+        f"""Finished loading the documents in {load_end - load_start} seconds.
+        Starting to chunk and encode.""",
+    )
+    chunk_start = time.time()
     nodes = chunk_encode_parallel(documents, num_workers=8)
+    chunk_end = time.time()
+
+    logging.info(
+        f"""Finished encoding in {chunk_end - chunk_start} seconds,
+        creating Faiss index.""",
+    )
+    embed_dim = 768  # for pubmedbert
+    faiss_index = faiss.IndexFlatL2(embed_dim)
+    vector_store = FaissVectorStore(faiss_index)
+    storage_context = StorageContext.from_defaults(vector_store=vector_store)
+    pack_start = time.time()
+    index = VectorStoreIndex(
+        nodes,
+        embed_model=get_encoder(),
+        insert_batch_size=16384,
+        use_async=True,
+        storage_context=storage_context,
+        show_progress=True,
+    )
+    pack_end = time.time()
+    logging.info(
+        f"""Finished packing the index in
+        {pack_end - pack_start} seconds.""",
+    )
 
     # Code for visually inspecting the success of semantic chunking.
     with open(NODE_INFO_PATH, 'w') as f:
@@ -162,39 +213,76 @@ def chunk_encode_parallel(
             }
             f.write(json.dumps(node_info) + '\n')
 
-    index = VectorStoreIndex(
-        nodes,
-        embed_model=get_encoder(),  # for now,this has to be serial
-        insert_batch_size=16384,
-        show_progress=True,
-        use_async=True,
-    )
-
     os.makedirs(PERSIST_DIR)
     index.storage_context.persist(PERSIST_DIR)
     logging.info(f'Saved the new index to {PERSIST_DIR}')
 
+# TODO: Put this in a "load_index_from_storage" function.
 else:
     logging.info('Loading index from storage')
+
+    vector_store = FaissVectorStore.from_persist_dir(PERSIST_DIR)
     storage_context = StorageContext.from_defaults(
-        docstore=SimpleDocumentStore.from_persist_dir(PERSIST_DIR),
-        vector_store=SimpleVectorStore.from_persist_dir(
-            PERSIST_DIR,
-            namespace='default',
-        ),
-        index_store=SimpleIndexStore.from_persist_dir(PERSIST_DIR),
+        vector_store=vector_store,
+        persist_dir=PERSIST_DIR,
     )
+
    index = load_index_from_storage(storage_context, embed_model=get_encoder())
     logging.info(f'Loaded the index from {PERSIST_DIR}.')
 
-# TODO: Add these query engine information in a config file.
-query_engine = index.as_query_engine(
+# TODO: Refactor these into query_engine creation and inference.
+
+ldr_prompt_template_str = (
+    """You are a super smart AI that knows about science. You follow
+    directions and you are always truthful and concise in your responses.
+    Below is a hypothesis submitted for your consideration.\n
+    """
+    '---------------------\n'
+    'Hypothesis: {query_str}\n'
+    '---------------------\n'
+    'Below is some context provided to assist you in your analysis.\n'
+    '---------------------\n'
+    'Context: {context_str}\n'
+    '---------------------\n'
+    """Based on your background knowledge and the context provided, please
+    determine if this hypothesis could have some connection to low dose
+    radiation biology. If the answer is yes, please generate one or more
+    specific conjectures about biological mechanisms that could relate low
+    dose radiation to the effects detailed in the hypothesis. Please be
+    as specific as possible, by naming specific biological pathways, genes
+    or proteins and their interactions. It is okay to speculate as long as
+    you give reasons for your conjectures. Finally, please estimate to the
+    best of your knowledge the likelihood of the hypothesis being true, and
+    please give step-by-step reasoning for your answers. \n"""
+    'Answer: '
+)
+
+ldr_prompt_template = PromptTemplate(ldr_prompt_template_str)
+
+# Creating the query engine.
+retriever = VectorIndexRetriever(index, similarity_top_k=30)
+
+prompt_helper = PromptHelper(
+    context_window=32000,
+    num_output=2048,
+    chunk_overlap_ratio=0,
+)
+
+response_synthesizer = get_response_synthesizer(
     llm=mixtral8x7b,
-    similarity_top_k=10,
-    similarity_threshold=0.5,
+    response_mode=ResponseMode.COMPACT,
+    prompt_helper=prompt_helper,
+    text_qa_template=ldr_prompt_template,
 )
+query_engine = RetrieverQueryEngine(
+    retriever=retriever,
+    node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=0.7)],
+    response_synthesizer=response_synthesizer,
+)
+
+
 logging.info('Query engine ready, running inference')
 
 with open(QUERY_DIR) as f:
@@ -202,8 +290,14 @@ def chunk_encode_parallel(
 query_2_response = {}
 
-for query in queries:
-    query_2_response[query] = query_engine.query(query)
+for query in queries[:2]:
+    response = query_engine.query(query)
+    query_2_response[query] = {
+        'response': str(response.response),
+        'metadata': str(response.metadata),
+        'source_nodes': str(response.source_nodes),
+    }
+
 with open(osp.join(OUTPUT_DIR, 'query_responses.json'), 'w') as f:
     json.dump(query_2_response, f)
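
The TODOs left in this patch (factory functions for the encoder/generator, a create-index-from-scratch function, and query-engine creation) all point at the same refactor. Below is a minimal sketch of what that could look like, reusing only the llama-index calls already present in this diff; the helper names build_query_engine and load_faiss_index and their default arguments are hypothetical, not something this patch adds.

# Sketch only: helper names and defaults are hypothetical, not part of the patch.
# Every call mirrors one already used in ragamp/pubmed_rag.py above.
from llama_index.core import get_response_synthesizer
from llama_index.core.indices.loading import load_index_from_storage
from llama_index.core.indices.prompt_helper import PromptHelper
from llama_index.core.postprocessor import SimilarityPostprocessor
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import ResponseMode
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.storage.storage_context import StorageContext
from llama_index.vector_stores.faiss import FaissVectorStore


def load_faiss_index(persist_dir, embed_model):
    """Reload the persisted Faiss-backed index, as the else branch above does."""
    vector_store = FaissVectorStore.from_persist_dir(persist_dir)
    storage_context = StorageContext.from_defaults(
        vector_store=vector_store,
        persist_dir=persist_dir,
    )
    return load_index_from_storage(storage_context, embed_model=embed_model)


def build_query_engine(index, llm, text_qa_template, top_k=30, cutoff=0.7):
    """Assemble the retriever/synthesizer pipeline this patch builds at module level."""
    retriever = VectorIndexRetriever(index, similarity_top_k=top_k)
    prompt_helper = PromptHelper(
        context_window=32000,
        num_output=2048,
        chunk_overlap_ratio=0,
    )
    response_synthesizer = get_response_synthesizer(
        llm=llm,
        response_mode=ResponseMode.COMPACT,
        prompt_helper=prompt_helper,
        text_qa_template=text_qa_template,
    )
    return RetrieverQueryEngine(
        retriever=retriever,
        node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=cutoff)],
        response_synthesizer=response_synthesizer,
    )

With helpers along these lines, the module-level if/else would reduce to building or loading the index and then calling build_query_engine(index, mixtral8x7b, ldr_prompt_template), which would also make it straightforward to move the query-engine settings into a config file later.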