Validate benchmark topics, qrels, and folds (#178)
* validation for benchmark (topic, qrel, and fold files)
crystina-z authored Aug 17, 2021
1 parent d6b5d10 commit 1ef78b2
Showing 6 changed files with 138 additions and 10 deletions.
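As a quick reference for the structure these validators enforce: the folds file is expected to be a JSON object mapping each fold name to a "train_qids" list and a "predict" dict with "dev" and "test" lists, and the qrels file follows the four-column TREC format. A minimal sketch of conforming files (all file names, fold names, and ids here are hypothetical, not taken from the commit):

import json

# Hypothetical folds file matching the structure validate_folds_file() checks:
# each fold maps to "train_qids" plus a "predict" dict holding "dev" and "test" lists.
folds = {
    "s1": {
        "train_qids": ["101", "102", "103"],
        "predict": {"dev": ["104"], "test": ["105"]},
    }
}
with open("benchmark.folds.json", "w") as f:
    json.dump(folds, f, indent=2)

# Hypothetical qrels file in the four-column format the qrels validator parses
# (qid, unused, docid, label); a repeated qid/docid pair with the same label counts
# as a duplicate, while differing labels raise an error.
with open("benchmark.qrels.txt", "w") as f:
    f.write("101 Q0 doc1 1\n")
    f.write("101 Q0 doc2 0\n")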
122 changes: 120 additions & 2 deletions capreolus/benchmark/__init__.py
@@ -1,11 +1,114 @@
import json
import os
from copy import deepcopy
from collections import defaultdict

import ir_datasets

from capreolus import ModuleBase
from capreolus.utils.caching import cached_file, TargetFileExists
from capreolus.utils.trec import write_qrels, load_qrels, load_trec_topics
from capreolus.utils.loginit import get_logger


logger = get_logger(__name__)


def validate(build_f):
def validate_folds_file(self):
if not hasattr(self, "fold_file"):
logger.warning(f"Folds file is not found for Module {self.module_name}")
return

if self.fold_file.suffix != ".json":
raise ValueError(f"Expect folds file to be in .json format.")

raw_folds = json.load(open(self.fold_file))
        # fold names themselves do not need to be validated

for fold_name, fold_sets in raw_folds.items():
if set(fold_sets) != {"train_qids", "predict"}:
raise ValueError(f"Expect each fold to contain ['train_qids', 'predict'] fields.")

if set(fold_sets["predict"]) != {"dev", "test"}:
raise ValueError(f"Expect each fold to contain ['dev', 'test'] fields under 'predict'.")
logger.info("Folds file validation finishes.")

def validate_qrels_file(self):
if not hasattr(self, "qrel_file"):
logger.warning(f"Qrel file is not found for Module {self.module_name}")
return

n_dup, qrels = 0, defaultdict(dict)
with open(self.qrel_file) as f:
for line in f:
qid, _, docid, label = line.strip().split()
if docid in qrels[qid]:
n_dup += 1
if int(label) != qrels[qid][docid]:
raise ValueError(f"Found conflicting label in {self.qrel_file} for query {qid} and document {docid}.")
qrels[qid][docid] = int(label)

if n_dup > 0:
qrel_file_no_ext, ext = os.path.splitext(self.qrel_file)
dup_qrel_file = qrel_file_no_ext + "-contain-dup-entries" + ext
os.rename(self.qrel_file, dup_qrel_file)
write_qrels(qrels, self.qrel_file)
            logger.warning(
                f"Removed {n_dup} duplicate entries from {self.qrel_file}. The original version can be found at {dup_qrel_file}."
            )

logger.info("Qrel file validation finishes.")

def validate_query_alignment(self):
topic_qids = set(self.topics[self.query_type])
qrels_qids = set(self.qrels)

for fold_name, fold_sets in self.folds.items():
            # check for overlap among the training, dev, and test sets
train_qids, dev_qids, test_qids = (
set(fold_sets["train_qids"]),
set(fold_sets["predict"]["dev"]),
set(fold_sets["predict"]["test"]),
)
if len(train_qids & dev_qids) > 0:
logger.warning(
f"Found {len(train_qids & dev_qids)} overlap queries between training and dev set in fold {fold_name}."
)
if len(train_qids & test_qids) > 0:
logger.warning(
f"Found {len(train_qids & dev_qids)} overlap queries between training and dev set in fold {fold_name}."
)
if len(dev_qids & test_qids) > 0:
logger.warning(
f"Found {len(train_qids & dev_qids)} overlap queries between training and dev set in fold {fold_name}."
)

            # check whether the topics, qrels, and folds files share a reasonable (if not identical) set of queries
folds_qids = train_qids | dev_qids | test_qids
n_overlap = len(set(topic_qids) & set(qrels_qids) & set(folds_qids))
if not len(topic_qids) == len(qrels_qids) == len(folds_qids) == n_overlap:
                logger.warning(
                    f"Number of queries is not aligned across the topics, qrels, and folds files in fold {fold_name}: "
                    f"{len(topic_qids)} queries in the topics file, {len(qrels_qids)} in the qrels file, and {len(folds_qids)} in the folds file; "
                    f"{n_overlap} queries are shared among all three."
                )

            # check whether any query in the folds file is missing from the topics file
            for set_name, set_qids in zip(["training", "dev", "test"], [train_qids, dev_qids, test_qids]):
                if len(set_qids - topic_qids) > 0:
                    raise ValueError(
                        f"{len(set_qids - topic_qids)} queries in the {set_name} set of fold {fold_name} cannot be found in the topics file."
                    )

logger.info("Query Alignment validation finishes.")

def _validate(self):
"""Rewrite the files that contain invalid (duplicate) entries, and remove the currently loaded variables"""
build_f(self)
validate_folds_file(self)
validate_qrels_file(self)
validate_query_alignment(self)

return _validate


class Benchmark(ModuleBase):
@@ -26,6 +129,9 @@ class Benchmark(ModuleBase):
relevance_level = 1
""" Documents with a relevance label >= relevance_level will be considered relevant.
This corresponds to trec_eval's --level_for_rel (and is passed to pytrec_eval as relevance_level). """
use_train_as_dev = True
""" Whether to use training set as validate set when there is no training needed,
e.g. for traditional IR algorithms like BM25 """

@property
def qrels(self):
@@ -45,6 +151,14 @@ def folds(self):
self._folds = json.load(open(self.fold_file, "rt"), parse_int=str)
return self._folds

@property
def non_nn_dev(self):
dev_per_fold = {fold_name: deepcopy(folds["predict"]["dev"]) for fold_name, folds in self.folds.items()}
if self.use_train_as_dev:
for fold_name, folds in self.folds.items():
dev_per_fold[fold_name].extend(folds["train_qids"])
return dev_per_fold

def get_topics_file(self, query_sets=None):
"""Returns path to a topics file in TSV format containing queries from query_sets.
query_sets may contain any combination of 'train', 'dev', and 'test'.
@@ -81,6 +195,10 @@ def get_topics_file(self, query_sets=None):

return fn

@validate
def build(self):
return


class IRDBenchmark(Benchmark):
ird_dataset_names = []
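The new use_train_as_dev flag and non_nn_dev property above fold each split's training qids into its dev qids when no model training is needed (e.g. BM25). A standalone sketch of that logic on a plain dict, reusing the hypothetical folds structure from earlier rather than a real Benchmark instance:

from copy import deepcopy

# Hypothetical folds dict with the same shape Benchmark.folds returns.
folds = {
    "s1": {
        "train_qids": ["101", "102", "103"],
        "predict": {"dev": ["104"], "test": ["105"]},
    }
}
use_train_as_dev = True  # mirrors the new Benchmark.use_train_as_dev flag

# Same logic as the non_nn_dev property: start from each fold's dev qids,
# then append the training qids when no training is required.
dev_per_fold = {name: deepcopy(fold["predict"]["dev"]) for name, fold in folds.items()}
if use_train_as_dev:
    for name, fold in folds.items():
        dev_per_fold[name].extend(fold["train_qids"])

print(dev_per_fold)  # {'s1': ['104', '101', '102', '103']}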
3 changes: 2 additions & 1 deletion capreolus/benchmark/cds.py
@@ -2,7 +2,7 @@

from capreolus import Dependency, constants

from . import Benchmark, IRDBenchmark
from . import Benchmark, IRDBenchmark, validate

PACKAGE_PATH = constants["PACKAGE_PATH"]

@@ -16,6 +16,7 @@ class CDS(IRDBenchmark):
query_type = "summary"
query_types = {} # diagnosis, treatment, or test

@validate
def build(self):
self.topics

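The @validate decorator applied to build here (and in the benchmarks below) simply wraps the subclass's own build and then runs the folds/qrels/alignment checks. A toy sketch of that pattern with the validation steps stubbed out, purely to illustrate the control flow:

# Toy sketch: the decorator runs the wrapped build first, then the checks.
def validate(build_f):
    def _validate(self):
        build_f(self)  # the subclass's original build logic
        print(f"running folds/qrels/alignment checks for {type(self).__name__}")

    return _validate


class ToyBenchmark:
    @validate
    def build(self):
        print("building ToyBenchmark")


ToyBenchmark().build()
# prints:
#   building ToyBenchmark
#   running folds/qrels/alignment checks for ToyBenchmark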
3 changes: 2 additions & 1 deletion capreolus/benchmark/codesearchnet.py
@@ -12,7 +12,7 @@
from capreolus.utils.loginit import get_logger
from capreolus.utils.trec import topic_to_trectxt

from . import Benchmark
from . import Benchmark, validate

logger = get_logger(__name__)
PACKAGE_PATH = constants["PACKAGE_PATH"]
@@ -41,6 +41,7 @@ class CodeSearchNetCorpus(Benchmark):

config_spec = [ConfigOption("lang", "ruby", "CSN language dataset to use")]

@validate
def build(self):
lang = self.config["lang"]

3 changes: 2 additions & 1 deletion capreolus/benchmark/covid.py
@@ -9,7 +9,7 @@
from capreolus.utils.loginit import get_logger
from capreolus.utils.trec import load_qrels, topic_to_trectxt

from . import Benchmark
from . import Benchmark, validate

logger = get_logger(__name__)
PACKAGE_PATH = constants["PACKAGE_PATH"]
@@ -30,6 +30,7 @@ class COVID(Benchmark):

config_spec = [ConfigOption("udelqexpand", False), ConfigOption("useprevqrels", True)]

@validate
def build(self):
if self.collection.config["round"] == self.lastest_round and not self.config["useprevqrels"]:
logger.warning(f"No evaluation can be done for the lastest round without using previous qrels")
3 changes: 2 additions & 1 deletion capreolus/benchmark/nf.py
@@ -5,7 +5,7 @@
from capreolus.utils.loginit import get_logger
from capreolus.utils.trec import topic_to_trectxt

from . import Benchmark
from . import Benchmark, validate

logger = get_logger(__name__)
PACKAGE_PATH = constants["PACKAGE_PATH"]
@@ -34,6 +34,7 @@ class NF(Benchmark):

query_type = "title"

@validate
def build(self):
fields, label_range = self.config["fields"], self.config["labelrange"]
self.field2kws = {
14 changes: 10 additions & 4 deletions capreolus/searcher/__init__.py
@@ -32,17 +32,23 @@ def load_trec_run(fn):
run = OrderedDefaultDict()

with open(fn, "rt") as f:
for line in f:
for i, line in enumerate(f):
line = line.strip()
if len(line) > 0:
qid, _, docid, rank, score, desc = line.split(" ")
try:
qid, _, docid, rank, score, desc = line.split()
except ValueError as e:
                        logger.error(
                            f"Encountered a malformed line while reading {fn} [line #{i}]; writing to the run file may have been interrupted."
                        )
raise e
run[qid][docid] = float(score)
return run

@staticmethod
def write_trec_run(preds, outfn):
def write_trec_run(preds, outfn, mode="wt"):
count = 0
with open(outfn, "wt") as outf:
with open(outfn, mode) as outf:
qids = sorted(preds.keys(), key=lambda k: int(k))
for qid in qids:
rank = 1
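The new mode parameter on write_trec_run allows appending to an existing run file instead of overwriting it, e.g. when results are flushed in batches. A usage sketch, assuming write_trec_run and load_trec_run remain staticmethods on the Searcher class as in the module above, and using hypothetical qids, docids, and file names:

from capreolus.searcher import Searcher

# Hypothetical result batches; qids must be numeric strings because
# write_trec_run sorts them with int().
batch1 = {"101": {"doc1": 14.2, "doc2": 13.7}}
batch2 = {"102": {"doc9": 12.1}}

Searcher.write_trec_run(batch1, "searcher.run")             # default mode="wt" starts a fresh file
Searcher.write_trec_run(batch2, "searcher.run", mode="at")  # append a later batch

run = Searcher.load_trec_run("searcher.run")  # both batches are read back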
