diff --git a/scripts/AbsTaskRetrieval.py b/scripts/AbsTaskRetrieval.py
index 6966643..f39b975 100644
--- a/scripts/AbsTaskRetrieval.py
+++ b/scripts/AbsTaskRetrieval.py
@@ -1,11 +1,10 @@
-"""
-To reproduce Reranking experiments with using GritLM for embedding and subsequent reranking replace the AbsTaskRetrieval file in MTEB with this file.
-"""
+import copy
 import json
 import logging
 from time import time
 from typing import Dict, List
+import numpy as np
 from sentence_transformers import SentenceTransformer
 from sentence_transformers.models import Transformer, WordEmbeddings
 import os
 
@@ -17,6 +16,191 @@
 
 DRES_METHODS = ["encode_queries", "encode_corpus"]
 
+TEMPLATES = {
+    "ArguAna": "<|user|>\n" \
+        "Provided two debate paragraphs, check if they are about the same topic, but contain counter-arguments.\n\n" \
+        "Paragraph 1: {query}\n" \
+        "Paragraph 2: {passage}\n\n" \
+        "Answer with yes if paragraph 1 and paragraph 2 are about the same topic, but contain counter-arguments; Answer with no otherwise.\n" \
+        "<|assistant|>\n" \
+        "Answer:",
+    "SciFact": "<|user|>\n" \
+        "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n" \
+        "Query: {query}\n" \
+        "Passage: {passage}\n\n" \
+        "Answer with yes if the passage is relevant to the query, and no otherwise.\n" \
+        "<|assistant|>\n" \
+        "Answer:",
+    "NFCorpus": "<|user|>\n" \
+        "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n" \
+        "Query: {query}\n" \
+        "Passage: {passage}\n\n" \
+        "Answer with yes if the passage is relevant to the query, and no otherwise.\n" \
+        "<|assistant|>\n" \
+        "Answer:",
+    "CQADupstackAndroidRetrieval": "<|user|>\n" \
+        "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n" \
+        "Query: {query}\n" \
+        "Passage: {passage}\n\n" \
+        "Answer with yes if the passage is relevant to the query, and no otherwise.\n" \
+        "<|assistant|>\n" \
+        "Answer:",
+    "ClimateFEVER": "<|user|>\n" \
+        "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n" \
+        "Query: {query}\n" \
+        "Passage: {passage}\n\n" \
+        "Answer with yes if the passage is relevant to the query, and no otherwise.\n" \
+        "<|assistant|>\n" \
+        "Answer:",
+    "CQADupstackEnglishRetrieval": "<|user|>\n" \
+        "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n" \
+        "Query: {query}\n" \
+        "Passage: {passage}\n\n" \
+        "Answer with yes if the passage is relevant to the query, and no otherwise.\n" \
+        "<|assistant|>\n" \
+        "Answer:",
+    "CQADupstackGamingRetrieval": "<|user|>\n" \
+        "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n" \
+        "Query: {query}\n" \
+        "Passage: {passage}\n\n" \
+        "Answer with yes if the passage is relevant to the query, and no otherwise.\n" \
+        "<|assistant|>\n" \
+        "Answer:",
+    "CQADupstackGisRetrieval": "<|user|>\n" \
+        "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n" \
+        "Query: {query}\n" \
+        "Passage: {passage}\n\n" \
+        "Answer with yes if the passage is relevant to the query, and no otherwise.\n" \
+        "<|assistant|>\n" \
+        "Answer:",
+    "CQADupstackMathematicaRetrieval": "<|user|>\n" \
+        "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n" \
+        "Query: {query}\n" \
+        "Passage: {passage}\n\n" \
+        "Answer with yes if the passage is relevant to the query, and no otherwise.\n" \
+        "<|assistant|>\n" \
+        "Answer:",
+    "CQADupstackPhysicsRetrieval": "<|user|>\n" \
+        "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n" \
+        "Query: {query}\n" \
+        "Passage: {passage}\n\n" \
+        "Answer with yes if the passage is relevant to the query, and no otherwise.\n" \
+        "<|assistant|>\n" \
+        "Answer:",
+    "CQADupstackProgrammersRetrieval": "<|user|>\n" \
+        "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n" \
+        "Query: {query}\n" \
+        "Passage: {passage}\n\n" \
+        "Answer with yes if the passage is relevant to the query, and no otherwise.\n" \
+        "<|assistant|>\n" \
+        "Answer:",
+    "CQADupstackStatsRetrieval": "<|user|>\n" \
+        "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n" \
+        "Query: {query}\n" \
+        "Passage: {passage}\n\n" \
+        "Answer with yes if the passage is relevant to the query, and no otherwise.\n" \
+        "<|assistant|>\n" \
+        "Answer:",
+    "CQADupstackTexRetrieval": "<|user|>\n" \
+        "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n" \
+        "Query: {query}\n" \
+        "Passage: {passage}\n\n" \
+        "Answer with yes if the passage is relevant to the query, and no otherwise.\n" \
+        "<|assistant|>\n" \
+        "Answer:",
+    "CQADupstackUnixRetrieval": "<|user|>\n" \
+        "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n" \
+        "Query: {query}\n" \
+        "Passage: {passage}\n\n" \
+        "Answer with yes if the passage is relevant to the query, and no otherwise.\n" \
+        "<|assistant|>\n" \
+        "Answer:",
+    "CQADupstackWebmastersRetrieval": "<|user|>\n" \
+        "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n" \
+        "Query: {query}\n" \
+        "Passage: {passage}\n\n" \
+        "Answer with yes if the passage is relevant to the query, and no otherwise.\n" \
+        "<|assistant|>\n" \
+        "Answer:",
+    "CQADupstackWordpressRetrieval": "<|user|>\n" \
+        "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n" \
+        "Query: {query}\n" \
+        "Passage: {passage}\n\n" \
+        "Answer with yes if the passage is relevant to the query, and no otherwise.\n" \
+        "<|assistant|>\n" \
+        "Answer:",
+    "DBPedia": "<|user|>\n" \
+        "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n" \
+        "Query: {query}\n" \
+        "Passage: {passage}\n\n" \
+        "Answer with yes if the passage is relevant to the query, and no otherwise.\n" \
+        "<|assistant|>\n" \
+        "Answer:",
+    "FEVER": "<|user|>\n" \
+        "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n" \
+        "Query: {query}\n" \
+        "Passage: {passage}\n\n" \
+        "Answer with yes if the passage is relevant to the query, and no otherwise.\n" \
+        "<|assistant|>\n" \
+        "Answer:",
+    "FiQA2018": "<|user|>\n" \
+        "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n" \
+        "Query: {query}\n" \
+        "Passage: {passage}\n\n" \
+        "Answer with yes if the passage is relevant to the query, and no otherwise.\n" \
+        "<|assistant|>\n" \
+        "Answer:",
+    "HotpotQA": "<|user|>\n" \
+        "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n" \
+        "Query: {query}\n" \
+        "Passage: {passage}\n\n" \
+        "Answer with yes if the passage is relevant to the query, and no otherwise.\n" \
+        "<|assistant|>\n" \
+        "Answer:",
+    "MSMARCO": "<|user|>\n" \
+        "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n" \
+        "Query: {query}\n" \
+        "Passage: {passage}\n\n" \
+        "Answer with yes if the passage is relevant to the query, and no otherwise.\n" \
+        "<|assistant|>\n" \
+        "Answer:",
+    "NQ": "<|user|>\n" \
+        "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n" \
+        "Query: {query}\n" \
+        "Passage: {passage}\n\n" \
+        "Answer with yes if the passage is relevant to the query, and no otherwise.\n" \
+        "<|assistant|>\n" \
+        "Answer:",
+    "QuoraRetrieval": "<|user|>\n" \
+        "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n" \
+        "Query: {query}\n" \
+        "Passage: {passage}\n\n" \
+        "Answer with yes if the passage is relevant to the query, and no otherwise.\n" \
+        "<|assistant|>\n" \
+        "Answer:",
+    "SCIDOCS": "<|user|>\n" \
+        "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n" \
+        "Query: {query}\n" \
+        "Passage: {passage}\n\n" \
+        "Answer with yes if the passage is relevant to the query, and no otherwise.\n" \
+        "<|assistant|>\n" \
+        "Answer:",
+    "TRECCOVID": "<|user|>\n" \
+        "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n" \
+        "Query: {query}\n" \
+        "Passage: {passage}\n\n" \
+        "Answer with yes if the passage is relevant to the query, and no otherwise.\n" \
+        "<|assistant|>\n" \
+        "Answer:",
+    "Touche2020": "<|user|>\n" \
+        "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n" \
+        "Query: {query}\n" \
+        "Passage: {passage}\n\n" \
+        "Answer with yes if the passage is relevant to the query, and no otherwise.\n" \
+        "<|assistant|>\n" \
+        "Answer:",
+}
+
 class AbsTaskRetrieval(AbsTask):
     """
     Abstract class for re-ranking experiments.
@@ -46,6 +230,7 @@ def evaluate(
         score_function="cos_sim",
         **kwargs
     ):
+        task_name = kwargs['task_name']
         sgpt2_model = model
         try:
             from beir.retrieval.evaluation import EvaluateRetrieval
@@ -67,7 +252,7 @@ def evaluate(
                 corpus_chunk_size=corpus_chunk_size if corpus_chunk_size is not None else 50000,
                 **kwargs,
             )
-
+
         else:
             # Distributed (multi-GPU)
             from beir.retrieval.search.dense import (
@@ -82,67 +267,80 @@ def evaluate(
         retriever = EvaluateRetrieval(model, score_function=score_function) # or "cos_sim" or "dot"
 
         start_time = time()
-        # if os.path.isfile('SciFact/SciFact.json'):
-        #     with open('SciFact/SciFact.json') as f:
-        #         results = json.load(f)
-        # else:
+        # with open(f'qrels/{task_name}.json') as f:
+        #     results = json.load(f)
         results = retriever.retrieve(corpus, queries)
-        # with open('SciFact/SciFact.json','w') as f:
-        #     json.dump(results,f,indent=2)
         end_time = time()
         sgpt2_model = sgpt2_model.to('cpu')
        logger.info("Time taken to retrieve: {:.2f} seconds".format(end_time - start_time))
 
         model_rerank = kwargs.get('model_rerank', None)
-        template = "<|user|>\nI will provide you with {num} passages, each indicated by a numerical identifier []. " \
-                   "Rank the passages based on their relevance to the search query {query}.\n\n{passages}\n\n" \
-                   "Search Query: {query}.\n\n" \
-                   "Rank the {num} passages above based on their relevance to the search query. All the passages " \
-                   "should be included and listed using identifiers, in descending order of relevance. " \
-                   "The output format should be [] > [] > ..., e.g., [4] > [2] > ... " \
-                   "Only respond with the ranking results, do not say any word or explain.\n<|assistant|>\n"
+        template = TEMPLATES[task_name]
         if model_rerank is not None:
+            if not os.path.isdir(f"rank_cache_new/{task_name}"):
+                os.makedirs(f"rank_cache_new/{task_name}", exist_ok=True)
+            model_rerank.tokenizer.pad_token_id = model_rerank.tokenizer.eos_token_id
             model_rerank = model_rerank.cuda()
+            top_k = kwargs['top_k']
+            # step_size = kwargs.get('step_size', 2)
+            # window_size = kwargs.get('window_size', -1)
             os.environ.pop("BIDIRECTIONAL_ATTN")
             print("BIDIRECTIONAL_ATTN", os.getenv("BIDIRECTIONAL_ATTN", False))
-            for qid, doc_ids in tqdm(results.items(), desc='reranking'):
-                # if os.path.isfile(f"SciFact/{qid}.json"):
-                #     with open(f"SciFact/{qid}.json") as f:
-                #         rerank_orders = json.load(f)
-                # else:
-                doc_ids = sorted(doc_ids.items(), key=lambda x: x[1], reverse=True)
-                cur_query = queries[qid]
-                num = 0
-                passages = ''
-                cur_prompt = None
-                all_ids = {}
-                scores = []
-                old_orders = []
-                while len(model_rerank.tokenizer(template.format(num=num, query=cur_query, passages=passages), return_tensors="pt")["input_ids"][0]) < 1900:
-                    cur_prompt = template.format(num=num, query=cur_query, passages=passages)
-                    passages += f"[{num}] {corpus[doc_ids[num][0]]['title'] + ' ' + corpus[doc_ids[num][0]]['text']}\n"
-                    old_orders.append(doc_ids[num][0])
-                    all_ids[num] = doc_ids[num][0]
-                    scores.append(doc_ids[num][1])
-                    num += 1
-                inputs = model_rerank.tokenizer(cur_prompt, return_tensors="pt")["input_ids"].to(model_rerank.device)
-                generation_output = model_rerank.generate(inputs, max_new_tokens=100, temperature=0.7, do_sample=True)
-                outputs = model_rerank.tokenizer.batch_decode(generation_output[:, inputs.shape[-1]:])[0].strip('').strip()
-                components = outputs.split('>')
-                new_orders = []
-                for idx, c in enumerate(components):
-                    try:
-                        new_orders.append(all_ids[int(c.strip().strip('[').strip(']').strip())])
-                    except:
-                        print(len(old_orders), outputs)
-                        pass
-                rerank_orders = {'old_orders': old_orders, 'new_orders': new_orders}
-                # with open(f"SciFact/{qid}.json",'w') as f:
-                #     json.dump(rerank_orders,f,indent=2)
-                cur_scores = []
-                for i in rerank_orders['old_orders']:
-                    cur_scores.append(results[qid][i])
-                for i, s in zip(rerank_orders['new_orders'], cur_scores):
-                    results[qid][i] = s
+            all_qids = []
+            for k in results:
+                all_qids.append(k)
+            for qid in all_qids:
+                doc_ids = sorted(results[qid].items(), key=lambda x: x[1], reverse=True)
+                # remove_doc_ids = [d[0] for d in doc_ids[top_k:]]
+                # for a_doc_id in remove_doc_ids:
+                #     results[qid].pop(a_doc_id)
+            all_qids = []
+            for k in results:
+                all_qids.append(k)
+            bar = tqdm(range(len(all_qids) * top_k), desc='reranking')  # one step per (query, candidate) pair
+            def print_orders(l, tag):
+                order_to_print = []
+                for local_i, o in enumerate(l):
+                    order_to_print.append([local_i, o])
+                print(order_to_print, tag)
+            for qid in all_qids:
+                flag = False
+                rerank_orders = {}
+                if os.path.isfile(f"rank_cache_new/{task_name}/{qid}.json"):
+                    # continue
+                    with open(f"rank_cache_new/{task_name}/{qid}.json") as f:
+                        rerank_orders = json.load(f)
+                    if 'old_orders' in rerank_orders and 'new_orders' in rerank_orders:
+                        flag = True
+                if not flag:
+                    with open(f"rank_cache_new/{task_name}/{qid}.json", 'w') as f:
+                        json.dump({}, f, indent=2)
+                    doc_ids = sorted(results[qid].items(), key=lambda x: x[1], reverse=True)
+                    orders = [d[0] for d in doc_ids]
+                    old_orders = copy.deepcopy(orders)
+                    new_orders = []
+                    for a_doc_id in orders[:top_k]:
+                        # cut both the query and the passage to 600 characters (needed for ArguAna)
+                        cur_prompt = template.format(query=queries[qid][:600], passage=corpus[a_doc_id]['title'] + ' ' + corpus[a_doc_id]['text'][:600])
+                        inputs = model_rerank.tokenizer(cur_prompt, return_tensors="pt")["input_ids"].to(model_rerank.device)
+                        generation_output = model_rerank.generate(inputs, max_new_tokens=1, temperature=0,
+                                                                  do_sample=False, return_dict_in_generate=True,
+                                                                  output_scores=True)
+                        scores = generation_output.scores[0][0].cpu()  # logits of the single generated token
+                        new_orders.append([a_doc_id, scores[5081]])  # 708 for no, 5081 for yes
+                        bar.update(1)
+                    new_orders_raw = sorted(new_orders, key=lambda x: x[1], reverse=True)  # rank candidates by their "yes" logit
+                    new_orders = [i[0] for i in new_orders_raw]
+                    rerank_orders = {'old_orders': old_orders, 'new_orders': new_orders}
+                    with open(f"rank_cache_new/{task_name}/{qid}.json", 'w') as f:
+                        json.dump(rerank_orders, f, indent=2)
+                # assert set(rerank_orders['new_orders'])==set(rerank_orders['old_orders'])
+                # assert set(rerank_orders['new_orders'])==set(list(results[qid].keys()))
+                # selected_scores = []
+                # for rank_id,o in enumerate(rerank_orders['new_orders']):
+                #     selected_scores.append(results[qid][o])
+                # selected_scores = sorted(selected_scores,reverse=True)
+                # for rank_id,o in enumerate(rerank_orders['new_orders']):
+                #     results[qid][o] += (10-rank_id)/kwargs['divisor']
             os.environ["BIDIRECTIONAL_ATTN"] = 'true'
             print("BIDIRECTIONAL_ATTN", os.getenv("BIDIRECTIONAL_ATTN", False))
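
Note (not part of the patch): the patched evaluate() reads three extra keyword arguments, task_name (selects the yes/no prompt from TEMPLATES), top_k (how many retrieved candidates to rerank per query) and model_rerank (an object exposing .tokenizer, .generate() and .device). The sketch below is a hedged illustration of how they might be supplied, assuming this file has replaced mteb's AbsTaskRetrieval as the removed module docstring described, that the installed mteb version forwards extra MTEB.run() kwargs down to task.evaluate(), and that GritLM is used for both embedding and reranking; the model name, output folder and top_k value are illustrative only.

    from gritlm import GritLM
    from mteb import MTEB

    # Assumed setup: GritLM provides encode_queries/encode_corpus for retrieval
    # and a tokenizer/generate interface usable for the yes/no prompting above.
    model = GritLM("GritLM/GritLM-7B", torch_dtype="auto")
    rerank_model = model  # may be the same weights or a separate causal LM

    MTEB(tasks=["SciFact"]).run(
        model,
        eval_splits=["test"],
        output_folder="results/gritlm-rerank",  # illustrative path
        task_name="SciFact",       # picks the prompt from TEMPLATES
        top_k=10,                  # rerank the 10 highest-scoring retrieved docs per query
        model_rerank=rerank_model, # generative yes/no reranker
    )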