Merge pull request #26 from hongjin-su/main
Update reranker
Muennighoff authored Apr 12, 2024
2 parents 01b2ed4 + ccf8f67 commit 47b7fe6
Showing 1 changed file with 254 additions and 56 deletions.
310 changes: 254 additions & 56 deletions scripts/AbsTaskRetrieval.py
@@ -1,11 +1,10 @@
"""
To reproduce the reranking experiments that use GritLM for embedding and subsequent reranking, replace the AbsTaskRetrieval file in MTEB with this file.
"""
import copy
import json
import logging
from time import time
from typing import Dict, List

import numpy as np
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Transformer, WordEmbeddings
import os
@@ -17,6 +16,191 @@

DRES_METHODS = ["encode_queries", "encode_corpus"]

TEMPLATES = {
    "ArguAna": "<|user|>\n"
        "Provided two debate paragraphs, check if they are about the same topic, but contain counter-arguments.\n\n"
        "Paragraph 1: {query}\n"
        "Paragraph 2: {passage}\n\n"
        "Answer with yes if paragraph 1 and paragraph 2 are about the same topic, but contain counter-arguments; Answer with no otherwise.\n"
        "<|assistant|>\n"
        "Answer:",
}

# All remaining tasks share the same pointwise relevance-judgment prompt, so it
# is defined once instead of being repeated verbatim per task.
_RELEVANCE_TEMPLATE = (
    "<|user|>\n"
    "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n"
    "Query: {query}\n"
    "Passage: {passage}\n\n"
    "Answer with yes if the passage is relevant to the query, and no otherwise.\n"
    "<|assistant|>\n"
    "Answer:"
)

for _task in [
    "SciFact", "NFCorpus", "CQADupstackAndroidRetrieval", "ClimateFEVER",
    "CQADupstackEnglishRetrieval", "CQADupstackGamingRetrieval",
    "CQADupstackGisRetrieval", "CQADupstackMathematicaRetrieval",
    "CQADupstackPhysicsRetrieval", "CQADupstackProgrammersRetrieval",
    "CQADupstackStatsRetrieval", "CQADupstackTexRetrieval",
    "CQADupstackUnixRetrieval", "CQADupstackWebmastersRetrieval",
    "CQADupstackWordpressRetrieval", "DBPedia", "FEVER", "FiQA2018",
    "HotpotQA", "MSMARCO", "NQ", "QuoraRetrieval", "SCIDOCS", "TRECCOVID",
    "Touche2020",
]:
    TEMPLATES[_task] = _RELEVANCE_TEMPLATE
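# For example, TEMPLATES["SciFact"].format(query="...", passage="...") yields the
# full chat prompt, ending in "Answer:" so that the first generated token is the
# yes/no relevance judgment scored during reranking below.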

class AbsTaskRetrieval(AbsTask):
"""
Abstract class for re-ranking experiments.
@@ -46,6 +230,7 @@ def evaluate(
score_function="cos_sim",
**kwargs
):
task_name = kwargs['task_name']
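# Keep a handle on the embedding model so it can be moved off the GPU once
# retrieval is done and the reranker needs the memory.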
sgpt2_model = model
try:
from beir.retrieval.evaluation import EvaluateRetrieval
@@ -67,7 +252,7 @@ def evaluate(
corpus_chunk_size=corpus_chunk_size if corpus_chunk_size is not None else 50000,
**kwargs,
)

else:
# Distributed (multi-GPU)
from beir.retrieval.search.dense import (
@@ -82,67 +267,80 @@

retriever = EvaluateRetrieval(model, score_function=score_function) # or "cos_sim" or "dot"
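# --- Stage 1: dense retrieval with the embedding model. ---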
start_time = time()
# if os.path.isfile('SciFact/SciFact.json'):
# with open('SciFact/SciFact.json') as f:
# results = json.load(f)
# else:
# with open(f'qrels/{task_name}.json') as f:
# results = json.load(f)
results = retriever.retrieve(corpus, queries)
# with open('SciFact/SciFact.json','w') as f:
# json.dump(results,f,indent=2)
end_time = time()
sgpt2_model = sgpt2_model.to('cpu')
logger.info("Time taken to retrieve: {:.2f} seconds".format(end_time - start_time))
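# --- Stage 2 (optional): rerank the top-k retrieved passages with a generative
# model. Each passage is scored independently by the logit the model assigns to
# "yes" when asked whether the passage is relevant to the query. ---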
model_rerank = kwargs.get('model_rerank', None)
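# Old (pre-update) listwise RankGPT-style prompt; it is overwritten immediately
# below by the per-task pointwise template, which is what this commit uses.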
template = "<|user|>\nI will provide you with {num} passages, each indicated by a numerical identifier []. " \
"Rank the passages based on their relevance to the search query {query}.\n\n{passages}\n\n" \
"Search Query: {query}.\n\n" \
"Rank the {num} passages above based on their relevance to the search query. All the passages " \
"should be included and listed using identifiers, in descending order of relevance. " \
"The output format should be [] > [] > ..., e.g., [4] > [2] > ... " \
"Only respond with the ranking results, do not say any word or explain.\n<|assistant|>\n"
template = TEMPLATES[task_name]
if model_rerank is not None:
if not os.path.isdir(f"rank_cache_new/{task_name}"):
os.makedirs(f"rank_cache_new/{task_name}",exist_ok=True)
model_rerank.tokenizer.pad_token_id = model_rerank.tokenizer.eos_token_id
model_rerank = model_rerank.cuda()
top_k = kwargs['top_k']
# step_size = kwargs.get('step_size', 2)
# window_size = kwargs.get('window_size', -1)
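# GritLM embeds with bidirectional attention but generates with causal
# attention, so the bidirectional flag is unset for reranking and restored
# after the loop.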
os.environ.pop("BIDIRECTIONAL_ATTN", None)
print("BIDIRECTIONAL_ATTN", os.getenv("BIDIRECTIONAL_ATTN", False))
for qid, doc_ids in tqdm(results.items(), desc='reranking'):
# if os.path.isfile(f"SciFact/{qid}.json"):
# with open(f"SciFact/{qid}.json") as f:
# rerank_orders = json.load(f)
# else:
doc_ids = sorted(doc_ids.items(),key=lambda x:x[1],reverse=True)
cur_query = queries[qid]
num = 0
passages = ''
cur_prompt = None
all_ids = {}
scores = []
old_orders = []
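# Pack as many top-ranked passages as fit within a ~1900-token prompt budget;
# all_ids maps each numeric identifier back to its document id.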
while len(model_rerank.tokenizer(template.format(num=num, query=cur_query, passages=passages), return_tensors="pt")["input_ids"][0])<1900:
cur_prompt = template.format(num=num, query=cur_query, passages=passages)
passages += f"[{num}] {corpus[doc_ids[num][0]]['title'] + ' ' + corpus[doc_ids[num][0]]['text']}\n"
old_orders.append(doc_ids[num][0])
all_ids[num] = doc_ids[num][0]
scores.append(doc_ids[num][1])
num += 1
inputs = model_rerank.tokenizer(cur_prompt, return_tensors="pt")["input_ids"].to(model_rerank.device)
generation_output = model_rerank.generate(inputs, max_new_tokens=100, temperature=0.7, do_sample=True)
outputs = model_rerank.tokenizer.batch_decode(generation_output[:, inputs.shape[-1]:])[0].replace('</s>', '').strip()
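# Parse the generated ranking, e.g. "[4] > [2] > [1]", back into document ids.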
components = outputs.split('>')
new_orders = []
for idx,c in enumerate(components):
try:
new_orders.append(all_ids[int(c.strip().strip('[').strip(']').strip())])
except (ValueError, KeyError):
    # Identifier could not be parsed or was never shown to the model.
    print(len(old_orders), outputs)
rerank_orders = {'old_orders':old_orders,'new_orders':new_orders}
# with open(f"SciFact/{qid}.json",'w') as f:
# json.dump(rerank_orders,f,indent=2)
cur_scores = []
for i in rerank_orders['old_orders']:
cur_scores.append(results[qid][i])
for i,s in zip(rerank_orders['new_orders'],cur_scores):
results[qid][i] = s
all_qids = list(results.keys())
for qid in all_qids:
doc_ids = sorted(results[qid].items(), key=lambda x: x[1], reverse=True)
# remove_doc_ids = [d[0] for d in doc_ids[top_k:]]
# for a_doc_id in remove_doc_ids:
# results[qid].pop(a_doc_id)
all_qids = list(results.keys())
bar = tqdm(range(len(all_qids)*top_k),desc='reranking')
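# Debugging helper; apparently unused in the flow below.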
def print_orders(l,tag):
order_to_print = []
for local_i,o in enumerate(l):
order_to_print.append([local_i,o])
print(order_to_print,tag)
for qid in all_qids:
flag = False
rerank_orders = {}
if os.path.isfile(f"rank_cache_new/{task_name}/{qid}.json"):
# continue
with open(f"rank_cache_new/{task_name}/{qid}.json") as f:
rerank_orders = json.load(f)
if 'old_orders' in rerank_orders and 'new_orders' in rerank_orders:
flag = True
if not flag:
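# Write an empty placeholder first; with the `# continue` above enabled,
# parallel runs would skip queries that are already claimed.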
with open(f"rank_cache_new/{task_name}/{qid}.json",'w') as f:
json.dump({},f,indent=2)
doc_ids = sorted(results[qid].items(),key=lambda x:x[1],reverse=True)
orders = [d[0] for d in doc_ids]
old_orders = copy.deepcopy(orders)
new_orders = []
for a_doc_id in orders[:top_k]:
# cut both the query and the doc to 600 chars (for ArguAna)
cur_prompt = template.format(query=queries[qid][:600],passage=corpus[a_doc_id]['title']+' '+corpus[a_doc_id]['text'][:600])
inputs = model_rerank.tokenizer(cur_prompt, return_tensors="pt")["input_ids"].to(model_rerank.device)
generation_output = model_rerank.generate(inputs, max_new_tokens=1, temperature=0,
do_sample=False, return_dict_in_generate=True,
output_scores=True)
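# Logits of the single generated token: the logit at the "yes" token id
# serves as the relevance score (higher = more likely relevant).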
scores = generation_output.scores[0][0].cpu()
new_orders.append([a_doc_id,scores[5081]]) # 708 for no, 5081 for yes
bar.update(1)
new_orders_raw = sorted(new_orders,key=lambda x:x[1],reverse=True)
new_orders = [i[0] for i in new_orders_raw]
rerank_orders = {'old_orders':old_orders,'new_orders':new_orders}
with open(f"rank_cache_new/{task_name}/{qid}.json",'w') as f:
json.dump(rerank_orders,f,indent=2)
# assert set(rerank_orders['new_orders'])==set(rerank_orders['old_orders'])
# assert set(rerank_orders['new_orders'])==set(list(results[qid].keys()))
# selected_scores = []
# for rank_id,o in enumerate(rerank_orders['new_orders']):
# selected_scores.append(results[qid][o])
# selected_scores = sorted(selected_scores,reverse=True)
# for rank_id,o in enumerate(rerank_orders['new_orders']):
# results[qid][o] += (10-rank_id)/kwargs['divisor']
os.environ["BIDIRECTIONAL_ATTN"] = 'true'
print("BIDIRECTIONAL_ATTN", os.getenv("BIDIRECTIONAL_ATTN", False))
