Updates to VectorStore and Query Engine
1. Upgraded the vector store from the default JSON-backed store to Faiss.
Result: the index now loads much faster from storage. (A minimal
build/persist/load sketch follows below.)

2. Implemented a custom query engine with the following:
	- A custom prompt template in line with Rick's.
	- A prompt helper that facilitates compact response synthesis.
	- Increased the number of retrieved documents to 30 per query.
	- Increased the similarity cutoff threshold to 0.7.
	- Switched the response mode from 'refine' to 'compact' for faster inference. (A wiring sketch for this engine also follows below.)
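
For reference, here is a minimal, self-contained sketch of the Faiss build/persist/load cycle this commit adopts. The persist path, the toy document, and the NeuML/pubmedbert-base-embeddings encoder are illustrative assumptions, not the repo's values; the real pipeline in ragamp/pubmed_rag.py uses its own get_encoder(), whose 768-dimensional output sets the Faiss index width.

# Sketch only: assumes llama-index >= 0.10 with the Faiss integration installed
# (pip install llama-index llama-index-vector-stores-faiss
#  llama-index-embeddings-huggingface faiss-cpu).
import faiss
from llama_index.core import Document
from llama_index.core import StorageContext
from llama_index.core import VectorStoreIndex
from llama_index.core import load_index_from_storage
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore

PERSIST_DIR = './faiss_index_demo'  # hypothetical path, not the repo's
EMBED_DIM = 768  # PubMedBERT-style embedding width

# Hypothetical stand-in for the repo's get_encoder().
embed_model = HuggingFaceEmbedding(model_name='NeuML/pubmedbert-base-embeddings')

# Build: wrap a flat L2 Faiss index in a llama-index vector store.
faiss_index = faiss.IndexFlatL2(EMBED_DIM)
vector_store = FaissVectorStore(faiss_index=faiss_index)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
docs = [Document(text='Low dose radiation may modulate DNA repair pathways.')]
index = VectorStoreIndex.from_documents(
    docs,
    storage_context=storage_context,
    embed_model=embed_model,
)

# Persist: writes the Faiss binary alongside the docstore/index-store JSON.
index.storage_context.persist(PERSIST_DIR)

# Reload: no re-embedding happens here, which is why startup gets much faster.
vector_store = FaissVectorStore.from_persist_dir(PERSIST_DIR)
storage_context = StorageContext.from_defaults(
    vector_store=vector_store,
    persist_dir=PERSIST_DIR,
)
index = load_index_from_storage(storage_context, embed_model=embed_model)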
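
And a sketch of the custom query engine wiring from item 2, reusing `index` from the sketch above. The one-line template stands in for the full low-dose-radiation prompt in the diff; in the repo the synthesizer also receives llm=mixtral8x7b. The numbers (top-k 30, cutoff 0.7, 32k context window, 2048 output tokens) mirror the committed code.

from llama_index.core import PromptTemplate
from llama_index.core import get_response_synthesizer
from llama_index.core.indices.prompt_helper import PromptHelper
from llama_index.core.postprocessor import SimilarityPostprocessor
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import ResponseMode
from llama_index.core.retrievers import VectorIndexRetriever

# Stand-in for the full hypothesis-analysis template in pubmed_rag.py.
qa_template = PromptTemplate(
    'Context: {context_str}\nHypothesis: {query_str}\nAnswer: ',
)

# Pull 30 candidate chunks per query from the Faiss index.
retriever = VectorIndexRetriever(index=index, similarity_top_k=30)

# 'compact' packs as many retrieved chunks into each LLM call as the context
# window allows, so it needs fewer calls than 'refine'.
response_synthesizer = get_response_synthesizer(
    response_mode=ResponseMode.COMPACT,  # in the repo: also llm=mixtral8x7b
    prompt_helper=PromptHelper(
        context_window=32000,
        num_output=2048,
        chunk_overlap_ratio=0,
    ),
    text_qa_template=qa_template,
)

# Drop retrieved chunks that score below the similarity cutoff.
query_engine = RetrieverQueryEngine(
    retriever=retriever,
    node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=0.7)],
    response_synthesizer=response_synthesizer,
)
print(query_engine.query('Does gene X link low dose radiation to fibrosis?'))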
ogkdmr committed Mar 10, 2024
1 parent f26cc55 commit eedd8ef
Showing 2 changed files with 136 additions and 41 deletions.
7 changes: 4 additions & 3 deletions pyproject.toml
@@ -32,7 +32,8 @@ dependencies = [
"sentence_transformers >= 2.2.2",
"llama-index-embeddings-langchain",
"llama-index-embeddings-huggingface",
"llama-index-vector-stores-faiss"
"llama-index-vector-stores-faiss",
"faiss-gpu >= 1.7.0",

]

@@ -149,14 +150,14 @@ select = [
# ruff-specific
"RUF",
]
-extend-ignore = []
+extend-ignore = ["COM812", "ISC001"]

[tool.ruff.lint.flake8-pytest-style]
parametrize-values-type = "tuple"

[tool.ruff.lint.flake8-quotes]
inline-quotes = "single"
multiline-quotes = "single"
multiline-quotes = "double"

[tool.ruff.lint.isort]
force-single-line = true
170 changes: 132 additions & 38 deletions ragamp/pubmed_rag.py
@@ -18,31 +18,36 @@
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor

+import faiss
import torch
+from llama_index.core import get_response_synthesizer
+from llama_index.core import PromptTemplate
from llama_index.core import SimpleDirectoryReader
from llama_index.core import VectorStoreIndex
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.indices.loading import load_index_from_storage
+from llama_index.core.indices.prompt_helper import PromptHelper
from llama_index.core.node_parser import SemanticSplitterNodeParser
from llama_index.core.node_parser.interface import NodeParser
-from llama_index.core.prompts.base import PromptTemplate
+from llama_index.core.postprocessor import SimilarityPostprocessor
+from llama_index.core.query_engine import RetrieverQueryEngine
+from llama_index.core.response_synthesizers import ResponseMode
+from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import BaseNode
from llama_index.core.schema import Document
-from llama_index.core.storage.docstore import SimpleDocumentStore
-from llama_index.core.storage.index_store import SimpleIndexStore
from llama_index.core.storage.storage_context import StorageContext
-from llama_index.core.vector_stores import SimpleVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceLLM
+from llama_index.vector_stores.faiss import FaissVectorStore
from transformers import BitsAndBytesConfig

logging.basicConfig(stream=sys.stdout, level=logging.INFO)


-PERSIST_DIR = '/home/ac.ogokdemir/lucid_index'
+PERSIST_DIR = '/rbstor/ac.ogokdemir/ragamp/all_lucid/faiss_index'
PAPERS_DIR = '/rbstor/ac.ogokdemir/md_outs'
-QUERY_DIR = '/home/ogokdemir/ragamp/examples/lucid_queries.txt'
-OUTPUT_DIR = '/rbstor/ac.ogokdemir/ragamp/output/lucid/'
+QUERY_DIR = '/home/ac.ogokdemir/ragamp/examples/lucid_queries.txt'
+OUTPUT_DIR = '/rbstor/ac.ogokdemir/ragamp/output/all_lucid'
NODE_INFO_PATH = osp.join(OUTPUT_DIR, 'node_info.jsonl')

os.makedirs(OUTPUT_DIR, exist_ok=True)
@@ -54,7 +59,8 @@
bnb_4bit_use_double_quant=True,
)

-# TODO: Move the generator and encoder factories out.
+# TODO: Move the generator and encoder creation to factory functions.
+# Create BaseGenerator and BaseEncoder interfaces.

# mistral7b = HuggingFaceLLM(
# model_name="mistralai/Mistral-7B-Instruct-v0.1",
@@ -82,7 +88,7 @@
'<s>[INST] {query_str} [/INST] </s>\n',
),
context_window=32000,
-    max_new_tokens=1024,
+    max_new_tokens=2048,
model_kwargs={'quantization_config': quantization_config},
# tokenizer_kwargs={},
generate_kwargs={
@@ -118,7 +124,15 @@ def get_splitter(encoder: BaseEmbedding) -> NodeParser:


def chunk_encode_unit(device: int, docs: list[Document]) -> list[BaseNode]:
"""Encode documents using the given embedding model."""
"""Takes list of documents and a GPU index and runs encoding on that GPU.
Args:
device (int): the GPU that will run this unit of work.
docs (list[Document]): list of documents
Returns:
list[BaseNode]: list of nodes
"""
# create the encoder
encoder = get_encoder(device)
splitter = get_splitter(encoder)
@@ -129,7 +143,16 @@ def chunk_encode_parallel(
docs: list[Document],
num_workers: int = 8,
) -> list[BaseNode]:
"""Encode documents in parallel using the given embedding model."""
"""Encode documents in parallel using the given embedding model.
Args:
docs (list[Document]): list of documents
num_workers (int, optional): Number of GPUs on your system.
Returns:
list[BaseNode]: list of nodes
"""
batches = [(i, docs[i::num_workers]) for i in range(num_workers)]
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [
@@ -141,16 +164,44 @@ def chunk_encode_parallel(
return [node for result in results for node in result]


+# TODO: Put this in a create-index-from-scratch function.
if not osp.exists(PERSIST_DIR):
logging.info('Creating index from scratch')

-    start = time.time()
-    documents = SimpleDirectoryReader(PAPERS_DIR).load_data()
-    end = time.time()
-
-    logging.info(f'Loaded documents in {end - start} seconds.')

+    logging.info('Starting to load the documents.')

+    load_start = time.time()
+    documents = SimpleDirectoryReader(PAPERS_DIR).load_data(show_progress=True)
+    load_end = time.time()
+    logging.info(
+        f'Finished loading the documents in {load_end - load_start} seconds. '
+        'Starting to chunk and encode.',
+    )
+    chunk_start = time.time()
+    nodes = chunk_encode_parallel(documents, num_workers=8)
+    chunk_end = time.time()
+
+    logging.info(
+        f'Finished encoding in {chunk_end - chunk_start} seconds, '
+        'creating the Faiss index.',
+    )
+    embed_dim = 768  # for pubmedbert
+    faiss_index = faiss.IndexFlatL2(embed_dim)
+    vector_store = FaissVectorStore(faiss_index=faiss_index)
+    storage_context = StorageContext.from_defaults(vector_store=vector_store)
+    pack_start = time.time()
+    index = VectorStoreIndex(
+        nodes,
+        embed_model=get_encoder(),
+        insert_batch_size=16384,
+        use_async=True,
+        storage_context=storage_context,
+        show_progress=True,
+    )
+    pack_end = time.time()
+    logging.info(
+        f'Finished packing the index in {pack_end - pack_start} seconds.',
+    )

# Code for visually inspecting the success of semantic chunking.
with open(NODE_INFO_PATH, 'w') as f:
@@ -162,48 +213,91 @@ def chunk_encode_parallel(
}
f.write(json.dumps(node_info) + '\n')

-    index = VectorStoreIndex(
-        nodes,
-        embed_model=get_encoder(),  # for now, this has to be serial
-        insert_batch_size=16384,
-        show_progress=True,
-        use_async=True,
-    )

os.makedirs(PERSIST_DIR)
index.storage_context.persist(PERSIST_DIR)
logging.info(f'Saved the new index to {PERSIST_DIR}')

+# TODO: Put this in a "load_index_from_storage" function.
else:
logging.info('Loading index from storage')

+    vector_store = FaissVectorStore.from_persist_dir(PERSIST_DIR)
storage_context = StorageContext.from_defaults(
-        docstore=SimpleDocumentStore.from_persist_dir(PERSIST_DIR),
-        vector_store=SimpleVectorStore.from_persist_dir(
-            PERSIST_DIR,
-            namespace='default',
-        ),
-        index_store=SimpleIndexStore.from_persist_dir(PERSIST_DIR),
+        vector_store=vector_store,
+        persist_dir=PERSIST_DIR,
)

index = load_index_from_storage(storage_context, embed_model=get_encoder())

logging.info(f'Loaded the index from {PERSIST_DIR}.')

-# TODO: Add this query engine information to a config file.
-query_engine = index.as_query_engine(
+# TODO: Refactor these into query_engine creation and inference functions.

+ldr_prompt_template_str = (
+    'You are a super smart AI that knows about science. You follow '
+    'directions and you are always truthful and concise in your responses. '
+    'Below is a hypothesis submitted for your consideration.\n'
+    '---------------------\n'
+    'Hypothesis: {query_str}\n'
+    '---------------------\n'
+    'Below is some context provided to assist you in your analysis.\n'
+    '---------------------\n'
+    'Context: {context_str}\n'
+    '---------------------\n'
+    'Based on your background knowledge and the context provided, please '
+    'determine whether this hypothesis could have some connection to low '
+    'dose radiation biology. If the answer is yes, please generate one or '
+    'more specific conjectures about biological mechanisms that could '
+    'relate low dose radiation to the effects detailed in the hypothesis. '
+    'Please be as specific as possible, naming specific biological '
+    'pathways, genes, or proteins and their interactions. It is okay to '
+    'speculate as long as you give reasons for your conjectures. Finally, '
+    'please estimate, to the best of your knowledge, the likelihood of the '
+    'hypothesis being true, and give step-by-step reasoning for your '
+    'answers.\n'
+    'Answer: '
+)

+ldr_prompt_template = PromptTemplate(ldr_prompt_template_str)

+# Creating the query engine.
+retriever = VectorIndexRetriever(index, similarity_top_k=30)

+prompt_helper = PromptHelper(
+    context_window=32000,
+    num_output=2048,
+    chunk_overlap_ratio=0,
+)

+# Note: retrieval breadth and the similarity cutoff are configured on the
+# retriever and node postprocessor; get_response_synthesizer() does not
+# accept similarity_top_k or similarity_threshold.
+response_synthesizer = get_response_synthesizer(
+    llm=mixtral8x7b,
+    response_mode=ResponseMode.COMPACT,
+    prompt_helper=prompt_helper,
+    text_qa_template=ldr_prompt_template,
+)

+query_engine = RetrieverQueryEngine(
+    retriever=retriever,
+    node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=0.7)],
+    response_synthesizer=response_synthesizer,
+)


logging.info('Query engine ready, running inference')

with open(QUERY_DIR) as f:
queries = f.read().splitlines()

query_2_response = {}

-for query in queries:
-    query_2_response[query] = query_engine.query(query)
+for query in queries[:2]:
+    response = query_engine.query(query)
+    query_2_response[query] = {
+        'response': str(response.response),
+        'metadata': str(response.metadata),
+        'source_nodes': str(response.source_nodes),
+    }


with open(osp.join(OUTPUT_DIR, 'query_responses.json'), 'w') as f:
json.dump(query_2_response, f)
