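"""Benchmark script: answer KILT Natural Questions with a fastRAG pipeline that
combines a PLAID document store, ColBERT retrieval, cross-encoder reranking, and
a Fusion-in-Decoder (FiD) reader, then report KILT downstream/retrieval metrics
and per-component latency."""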
import logging

import kilt.eval_retrieval as retrieval_metrics
import pandas as pd
import torch
from datasets import load_dataset
from haystack import Pipeline
from haystack.nodes import PromptModel, SentenceTransformersRanker
from haystack.nodes.prompt import AnswerParser, PromptNode
from haystack.nodes.prompt.prompt_template import PromptTemplate
from kilt.eval_downstream import _calculate_metrics, validate_input
from tqdm import tqdm

from fastrag.prompters.invocation_layers import fid
from fastrag.retrievers.colbert import ColBERTRetriever
from fastrag.stores import PLAIDDocumentStore
from fastrag.utils import get_timing_from_pipeline
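
# KILT evaluation helpers: `evaluate` combines the downstream answer metrics
# (accuracy, EM, F1, ROUGE-L) from kilt.eval_downstream with retrieval metrics
# (R-precision, recall@5) computed against the gold Wikipedia provenance.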
def evaluate(gold_records, guess_records):
    # 0. validate input
    gold_records, guess_records = validate_input(gold_records, guess_records)

    # 1. downstream + kilt
    result = _calculate_metrics(gold_records, guess_records)

    # 2. retrieval performance
    retrieval_results = retrieval_metrics.compute(
        gold_records, guess_records, ks=[1, 5], rank_keys=["wikipedia_id"]
    )
    result["retrieval"] = {
        "Rprec": retrieval_results["Rprec"],
        "recall@5": retrieval_results["recall@5"],
    }
    return result


def create_json_entry(jid, input_text, answer, documents):
    return {
        "id": jid,
        "input": input_text,
        "output": [{"answer": answer, "provenance": [{"wikipedia_id": d.id} for d in documents]}],
    }


def create_records(test_dataset, result_collection):
    guess_records = []
    for i in range(len(test_dataset)):
        example = test_dataset[i]
        results = result_collection[i]
        guess_records.append(
            create_json_entry(
                example["id"], example["input"], results["answers"][0].answer, results["documents"]
            )
        )
    return guess_records


def evaluate_from_answers(gold_records, result_collection):
    guess_records = create_records(gold_records, result_collection)
    return evaluate(gold_records, guess_records)


logging.getLogger().setLevel(logging.INFO)
logging.basicConfig(
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logging.info("Loading PLAID index...")
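# NOTE: "collection_path", "checkpoint_path", and "index_path" below are
# placeholders; point them at your passage collection, ColBERT checkpoint, and
# prebuilt PLAID index before running.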
document_store = PLAIDDocumentStore(
    collection_path="collection_path", checkpoint_path="checkpoint_path", index_path="index_path"
)
# Create Components
retriever = ColBERTRetriever(document_store=document_store)
reranker = SentenceTransformersRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-12-v2")
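# FiD reader: the PromptModel wraps Intel/fid_flan_t5_base_nq through fastRAG's
# FiD invocation layer, loaded in bfloat16 on GPU 0 with greedy decoding
# (do_sample=False) and generation capped at max_length=10.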
PrompterModel = PromptModel(
    model_name_or_path="Intel/fid_flan_t5_base_nq",
    use_gpu=True,
    invocation_layer_class=fid.FiDHFLocalInvocationLayer,
    model_kwargs=dict(
        model_kwargs=dict(device_map={"": 0}, torch_dtype=torch.bfloat16, do_sample=False),
        generation_kwargs=dict(max_length=10),
    ),
)

reader = PromptNode(
    model_name_or_path=PrompterModel,
    default_prompt_template=PromptTemplate("{query}", output_parser=AnswerParser()),
)
# Build Pipeline
p = Pipeline()
p.add_node(component=retriever, name="Retriever", inputs=["Query"])
p.add_node(component=reranker, name="Reranker", inputs=["Retriever"])
p.add_node(component=reader, name="Reader", inputs=["Reranker"])
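# Query flow: ColBERT/PLAID retrieval -> cross-encoder reranking -> FiD reader,
# which fuses the query with the reranked passages to generate a short answer.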
# Load Dataset
data = load_dataset("kilt_tasks", "nq")
validation_data = data["validation"]
# Run Pipeline
retriever_top_k = 100
reranker_top_k = 50
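# Retrieve 100 candidate passages per question and keep the 50 best after reranking.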
all_results = []
efficiency_metrics = []
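# Run every validation question through the pipeline, collecting the answers and
# the per-component latency reported by get_timing_from_pipeline.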
for example in tqdm(validation_data):
    results = p.run(
        query=example["input"],
        params={"Retriever": {"top_k": retriever_top_k}, "Reranker": {"top_k": reranker_top_k}},
    )
    pipeline_latency_report = get_timing_from_pipeline(p)
    efficiency_metrics.append(
        {
            component_name: component_time[1]
            for component_name, component_time in pipeline_latency_report.items()
        }
    )
    all_results.append(results)
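# Score the collected predictions with KILT downstream and retrieval metrics.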
kilt_metrics = evaluate_from_answers(validation_data, all_results)
# Show Results
efficiency_metrics_df = pd.DataFrame(efficiency_metrics)
efficiency_metrics_df_mean = efficiency_metrics_df.mean()
for component in efficiency_metrics_df.columns:
    logging.info(f"Mean per-example latency for {component}: {efficiency_metrics_df_mean[component]} sec")
logging.info(
    f"""
    Accuracy: {kilt_metrics['downstream']['accuracy']}
    EM: {kilt_metrics['downstream']['em']}
    F1: {kilt_metrics['downstream']['f1']}
    ROUGE-L: {kilt_metrics['downstream']['rougel']}
    """
)