Commit
Working version, ready to refactor.
ogkdmr committed Mar 11, 2024
1 parent d54563f commit 0d3348d
Showing 2 changed files with 88 additions and 47 deletions.
3 changes: 2 additions & 1 deletion ragamp/__init__.py
@@ -1,9 +1,10 @@
"""foobar package."""

# It is recommended to not write code in the __init__.py because it is easy
# to introduce import cycles and code becomes harder to search for.
from __future__ import annotations

import importlib.metadata as importlib_metadata
import sys

-__version__ = importlib_metadata.version('foobar')
+__version__ = importlib_metadata.version('ragamp')
132 changes: 86 additions & 46 deletions ragamp/pubmed_rag.py
@@ -21,16 +21,14 @@
import faiss
import torch
from llama_index.core import get_response_synthesizer
-from llama_index.core import PromptTemplate
from llama_index.core import SimpleDirectoryReader
from llama_index.core import VectorStoreIndex
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.indices.loading import load_index_from_storage
from llama_index.core.indices.prompt_helper import PromptHelper
from llama_index.core.node_parser import SemanticSplitterNodeParser
from llama_index.core.node_parser.interface import NodeParser
-from llama_index.core.postprocessor import SimilarityPostprocessor
-from llama_index.core.query_engine import RetrieverQueryEngine
+from llama_index.core.prompts import PromptTemplate
from llama_index.core.response_synthesizers import ResponseMode
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import BaseNode
@@ -44,10 +42,10 @@
logging.basicConfig(stream=sys.stdout, level=logging.INFO)


-PERSIST_DIR = '/rbstor/ac.ogokdemir/ragamp/all_lucid/faiss_index'
+PERSIST_DIR = '/rbstor/ac.ogokdemir/ragamp/indexes/lucid/faiss_index'
PAPERS_DIR = '/rbstor/ac.ogokdemir/md_outs'
-QUERY_DIR = '/home/ac.ogokdemir/ragamp/examples/lucid_queries.txt'
-OUTPUT_DIR = '/rbstor/ac.ogokdemir/ragamp/output/all_lucid'
+QUERY_DIR = '/home/ac.ogokdemir/ragamp/examples/hypo_ol_parsed.jsonl'
+OUTPUT_DIR = '/rbstor/ac.ogokdemir/ragamp/output/lucid_hypo_eval'
NODE_INFO_PATH = osp.join(OUTPUT_DIR, 'node_info.jsonl')

os.makedirs(OUTPUT_DIR, exist_ok=True)
@@ -60,14 +58,14 @@
)

# TODO: Move the generator and encoder creation to factory functions.
-# Create BaseGenerator and BaseEncoder interfaces
+# Create BaseGenerator and BaseEncoder interfaces; make them take a config.
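
# A minimal sketch of the factory refactor the TODO above describes;
# GeneratorConfig and get_generator are assumed names, not part of this commit:
from dataclasses import dataclass


@dataclass
class GeneratorConfig:
    """Assumed config object for building an LLM generator."""

    model_name: str
    context_window: int = 32000
    max_new_tokens: int = 4096


def get_generator(config: GeneratorConfig) -> HuggingFaceLLM:
    """Build a HuggingFaceLLM generator from a config (sketch only)."""
    return HuggingFaceLLM(
        model_name=config.model_name,
        tokenizer_name=config.model_name,
        context_window=config.context_window,
        max_new_tokens=config.max_new_tokens,
        model_kwargs={'quantization_config': quantization_config},
        device_map='auto',
    )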

# mistral7b = HuggingFaceLLM(
#     model_name="mistralai/Mistral-7B-Instruct-v0.1",
#     tokenizer_name="mistralai/Mistral-7B-Instruct-v0.1",
-#     query_wrapper_prompt=PromptTemplate(
-#         "<s>[INST] {query_str} [/INST] </s>\n",
-#     ),
+#     # query_wrapper_prompt=PromptTemplate(
+#     #     "<s>[INST] {query_str} [/INST] </s>\n",
+#     # ),
#     context_window=32000,
#     max_new_tokens=1024,
#     model_kwargs={"quantization_config": quantization_config},
@@ -81,21 +79,21 @@
#     device_map="auto",
# )


mixtral8x7b = HuggingFaceLLM(
-    model_name='mistralai/Mixtral-8x7B-v0.1',
-    tokenizer_name='mistralai/Mixtral-8x7B-v0.1',
+    model_name='mistralai/Mixtral-8x7B-Instruct-v0.1',
+    tokenizer_name='mistralai/Mixtral-8x7B-Instruct-v0.1',
+    context_window=32000,
+    max_new_tokens=4096,
    query_wrapper_prompt=PromptTemplate(
        '<s>[INST] {query_str} [/INST] </s>\n',
    ),
-    context_window=32000,
-    max_new_tokens=2048,
    model_kwargs={'quantization_config': quantization_config},
    # tokenizer_kwargs={},
    generate_kwargs={
        'temperature': 0.2,
        'top_k': 5,
        'top_p': 0.95,
-        'do_sample': True,
+        'do_sample': False,  # greedy decoding; the sampling params above are ignored
    },
    device_map='auto',
)
@@ -233,15 +231,20 @@ def chunk_encode_parallel(

# TODO: Refactor these into query_engine creation and inference.


+# Creating the query engine.
+retriever = VectorIndexRetriever(
+    index, embed_model=get_encoder(), similarity_top_k=30
+)

ldr_prompt_template_str = (
-    """You are a super smart AI that knows about science. You follow
+    """<s> [INST] You are a super smart AI that knows about science. You follow
    directions and you are always truthful and concise in your responses.
-    Below is an hypothesis submitted to your consideration.\n
-    """
+    Below is a hypothesis submitted for your consideration.\n"""
    '---------------------\n'
    'Hypothesis: {query_str}\n'
    '---------------------\n'
-    'Below is some context provided to assist you in your analysis.'
+    'Below is some context provided to assist you in your analysis.\n'
    '---------------------\n'
    'Context: {context_str}\n'
    '---------------------\n'
@@ -254,14 +257,13 @@ def chunk_encode_parallel(
    or proteins and their interactions. It is okay to speculate as long as
    you give reasons for your conjectures. Finally, please estimate to the
    best of your knowledge the likelihood of the hypothesis being true, and
-    please give step by step reasoning for your answers. \n"""
-    'Answer: '
+    please give step by step reasoning for your answers.
+    Answer: [/INST] </s>"""
+    ''
)

ldr_prompt_template = PromptTemplate(ldr_prompt_template_str)
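# Illustration with assumed values (not part of this file): the template's
# {query_str} and {context_str} placeholders resolve via PromptTemplate.format:
example_prompt = ldr_prompt_template.format(
    query_str='Gene X regulates pathway Y.',
    context_str='(retrieved passages would appear here)',
)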

-# Creating the query engine.
-retriever = VectorIndexRetriever(index, similarity_top_k=30)

prompt_helper = PromptHelper(
    context_window=32000,
@@ -272,32 +274,70 @@ def chunk_encode_parallel(
response_synthesizer = get_response_synthesizer(
    llm=mixtral8x7b,
    response_mode=ResponseMode.COMPACT,
-    prompt_helper=prompt_helper,
+    # prompt_helper=prompt_helper,
    text_qa_template=ldr_prompt_template,
)

-query_engine = RetrieverQueryEngine(
-    retriever=retriever,
-    node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=0.7)],
-    response_synthesizer=response_synthesizer,
-)
+# query_engine = RetrieverQueryEngine(
+#     retriever=retriever,
+#     node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=0.7)],
+#     response_synthesizer=response_synthesizer,
+# )

-# print(query_engine.query("What is the meaning of life?"))
+qe = index.as_query_engine(
+    llm=mixtral8x7b,
+    similarity_top_k=5,
+    response_mode=ResponseMode.COMPACT,
+    # prompt_helper=prompt_helper,
+    text_qa_template=ldr_prompt_template,
+    # response_synthesizer=response_synthesizer,
+    # node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=0.2)],
+)

+logging.info('Query engine ready, running inference')
+response = qe.query(
+    """The hypothesis in this study is that low doses of X-rays can induce
+    a repair mechanism in human lymphocytes. This mechanism reduces
+    the number of broken chromosome ends that can take part in aberration
+    formation, even in the case of high-LET radiation from radon. This
+    adaptive response is suggested to be a general phenomenon, not
+    specific to X-rays, and potentially applicable to other types of
+    radiation. The hypothesis is tested through experiments involving
+    different types of radiation, varying doses and timing of low-dose
+    X-ray exposure, and examining different cell types."""
+)

-with open(QUERY_DIR) as f:
-    queries = f.read().splitlines()
-
-query_2_response = {}
-
-for query in queries[:2]:
-    response = query_engine.query(query)
-    query_2_response[query] = {
-        'response': str(response.response),
-        'metadata': str(response.metadata),
-        'source_nodes': str(response.source_nodes),
-    }
-
-
-with open(osp.join(OUTPUT_DIR, 'query_responses.json'), 'w') as f:
-    json.dump(query_2_response, f)
+print(response.response)
+print('-' * 20)
+print(response.metadata)
+print('-' * 20)
+print(response.source_nodes)
+print('-' * 20)

+sys.exit()

+# # TODO: Make this query processing a generic function.
+
+# with open(QUERY_DIR) as f:
+#     queries = [json.loads(line) for line in f.readlines()]
+# for query in queries:
+#     question = query['output']  # e.g., hypothesis
+#     query_source_file = query['source']
+#     response_dict = query_engine.query(question)
+#     answer = response_dict.response  # type: ignore
+#     metadata = response_dict.metadata
+#     context_sources = response_dict.source_nodes
+
+#     out = {
+#         'query': question,
+#         'query_source_file': query_source_file,
+#         'response': answer,
+#         'metadata': metadata,
+#         'source_nodes': context_sources,
+#     }
+
+#     with open(
+#         osp.join(OUTPUT_DIR, f'{query_source_file}_hypo_ol_rag.json'), 'w'
+#     ) as f:
+#         json.dump(out, f)
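
# A sketch (assumed helper, not in this commit) of the generic
# query-processing function the TODO above calls for:

def run_queries(query_engine, queries: list[dict], output_dir: str) -> None:
    """Run each query through the engine and write one JSON file per query."""
    for query in queries:
        question = query['output']  # e.g., a hypothesis statement
        source_file = query['source']
        response = query_engine.query(question)
        out = {
            'query': question,
            'query_source_file': source_file,
            'response': str(response.response),
            'metadata': str(response.metadata),
            'source_nodes': str(response.source_nodes),
        }
        out_path = osp.join(output_dir, f'{source_file}_hypo_ol_rag.json')
        with open(out_path, 'w') as f:
            json.dump(out, f)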
