Commit

Denserank - add support for RepBERT and FAISS

Kevin Martin Jose committed Aug 22, 2021
1 parent 1767d5a commit 7a27a76
Showing 46 changed files with 613,314 additions and 293 deletions.
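Note: the RepBERT and FAISS changes named in the commit title live in files further down the tree, not in the hunks excerpted below. As background, here is a minimal sketch of the dense-retrieval pattern the title refers to: encode passages into vectors, add them to a FAISS index, and search with an encoded query. The encode() helper is a hypothetical stand-in for a RepBERT-style encoder; only the numpy and faiss calls are real API.

# Minimal dense-retrieval sketch (background for the commit title, not code from this commit).
import numpy as np
import faiss

dim = 128  # embedding size; a RepBERT-style encoder would typically output 768-dim vectors

def encode(texts):
    # Placeholder for a trained encoder; returns one float32 vector per text.
    rng = np.random.default_rng(0)
    return rng.standard_normal((len(texts), dim)).astype("float32")

passages = ["first passage", "second passage", "third passage"]
index = faiss.IndexFlatIP(dim)  # exact (brute-force) inner-product search
index.add(encode(passages))     # index every passage vector

scores, ids = index.search(encode(["a query"]), 2)  # top-2 passages for one query
print(ids[0], scores[0])        # row 0 holds the results for the single query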
2 changes: 0 additions & 2 deletions MANIFEST.in
@@ -2,11 +2,9 @@ include capreolus/data/antique.json
include capreolus/data/dummy/data/dummy_trec_doc
include capreolus/data/dummy_folds.json
include capreolus/data/dummy.yaml
include capreolus/data/msmarcopassage.folds.json
include capreolus/data/qrels.antique.txt
include capreolus/data/qrels.antique_test.txt
include capreolus/data/qrels.dummy.txt
include capreolus/data/qrels.msmarcopassage.txt
include capreolus/data/qrels.robust2004.txt
include capreolus/data/rob04_cedr_folds.json
include capreolus/data/rob04_yang19_folds.json
2 changes: 0 additions & 2 deletions capreolus/.gitignore

This file was deleted.

9 changes: 7 additions & 2 deletions capreolus/benchmark/__init__.py
@@ -129,9 +129,12 @@ class Benchmark(ModuleBase):
    relevance_level = 1
    """ Documents with a relevance label >= relevance_level will be considered relevant.
        This corresponds to trec_eval's --level_for_rel (and is passed to pytrec_eval as relevance_level). """
    use_train_as_dev = True
    """ Whether to use the training set as the validation set when no training is needed,
    use_train_as_dev = False
    """ Whether to use the training set as the validation set when no training is needed,
        e.g. for traditional IR algorithms like BM25 """
    need_pooling = False
    """Some benchmarks consist of documents that are really passages; the final score has to aggregate
    the scores of all passages belonging to the same document. This property indicates whether such pooling is required"""

    @property
    def qrels(self):
@@ -188,6 +191,8 @@ def get_topics_file(self, query_sets=None):
            with cached_file(fn) as tmp_fn:
                with open(tmp_fn, "wt") as outf:
                    for qid, query in self.topics[self.query_type].items():
                        if not query.strip():
                            continue
                        if query_sets == "all" or qid in valid_qids:
                            print(f"{qid}\t{query}", file=outf)
        except TargetFileExists as e:
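The need_pooling flag added above, together with the pool option (default "max") that the passage benchmarks below declare, means retrieval scores are produced per passage and must be aggregated into per-document scores. A minimal sketch of max pooling under an assumed doc_of() mapping from passage id to source document; an illustration, not the commit's implementation:

# Sketch of the score pooling implied by need_pooling / pool="max".
from collections import defaultdict

def pool_passage_scores(passage_scores, doc_of, strategy="max"):
    # Group each passage's score under its source document, then reduce.
    per_doc = defaultdict(list)
    for pid, score in passage_scores.items():
        per_doc[doc_of(pid)].append(score)
    if strategy == "max":
        return {docid: max(s) for docid, s in per_doc.items()}
    raise ValueError(f"unknown pooling strategy: {strategy}")

# e.g. passages "D1_0" and "D1_1" both come from document "D1"
scores = {"D1_0": 4.2, "D1_1": 5.0, "D2_0": 3.1}
print(pool_passage_scores(scores, doc_of=lambda pid: pid.split("_")[0]))
# -> {'D1': 5.0, 'D2': 3.1}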
25 changes: 24 additions & 1 deletion capreolus/benchmark/gov2.py
@@ -1,4 +1,4 @@
from capreolus import Dependency, constants
from capreolus import Dependency, constants, ConfigOption

from . import Benchmark, IRDBenchmark

@@ -14,7 +14,30 @@ class Gov2(IRDBenchmark):
    fold_file = PACKAGE_PATH / "data" / "gov2_maxp_folds.json"


@Benchmark.register
class Gov2Passages(IRDBenchmark):
    module_name = "gov2passages"
    query_type = "title"
    ird_dataset_names = ["gov2/trec-tb-2004", "gov2/trec-tb-2005", "gov2/trec-tb-2006"]
    dependencies = [Dependency(key="collection", module="collection", name="gov2passages")]
    config_spec = [ConfigOption("pool", "max", "Strategy used to aggregate passage level scores")]
    fold_file = PACKAGE_PATH / "data" / "gov2_maxp_folds.json"
    need_pooling = True


@Benchmark.register
class Gov2Desc(Gov2):
    module_name = "gov2.desc"
    query_type = "desc"


@Benchmark.register
class MQ2007Passages(Benchmark):
    module_name = "mq2007passages"
    dependencies = [Dependency(key="collection", module="collection", name="gov2passages")]
    config_spec = [ConfigOption("pool", "max", "Strategy used to aggregate passage level scores")]
    qrel_file = PACKAGE_PATH / "data" / "qrels.mq2007.txt"
    topic_file = PACKAGE_PATH / "data" / "topics.mq2007.txt"
    fold_file = PACKAGE_PATH / "data" / "mq2007.folds.json"
    query_type = "title"
    need_pooling = True
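MQ2007Passages points at plain TREC-format qrel and topic files. For reference, a small parser sketch: qrels use the standard whitespace-separated `qid iteration docid relevance` layout, and the tab-separated `qid<TAB>query` layout below matches what get_topics_file writes above (the shipped topic_file itself may use TREC's SGML topic markup instead). A convenience sketch, not framework code:

# Sketch: read TREC qrels ("qid iter docid rel") and tab-separated topics.
def load_qrels(path):
    qrels = {}
    with open(path) as f:
        for line in f:
            qid, _iteration, docid, rel = line.split()
            qrels.setdefault(qid, {})[docid] = int(rel)
    return qrels

def load_tsv_topics(path):
    with open(path) as f:
        return dict(line.rstrip("\n").split("\t", 1) for line in f if line.strip())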
14 changes: 0 additions & 14 deletions capreolus/benchmark/msmarco.py

This file was deleted.

27 changes: 26 additions & 1 deletion capreolus/benchmark/robust04.py
@@ -1,4 +1,4 @@
from capreolus import Dependency, constants
from capreolus import Dependency, constants, ConfigOption

from . import Benchmark

@@ -44,6 +44,30 @@ class Robust04Yang19Desc(Robust04Yang19, Benchmark):
    query_type = "desc"


@Benchmark.register
class Robust04Passages(Benchmark):
    """
    Split robust04 into passages
    """

    module_name = "robust04passages"
    dependencies = [Dependency(key="collection", module="collection", name="robust04passages")]
    config_spec = [ConfigOption("pool", "max", "Strategy used to aggregate passage level scores")]
    qrel_file = PACKAGE_PATH / "data" / "qrels.robust2004.txt"
    topic_file = PACKAGE_PATH / "data" / "topics.robust04.301-450.601-700.txt"
    fold_file = PACKAGE_PATH / "data" / "rob04_cedr_folds.json"
    query_type = "title"
    need_pooling = True


@Benchmark.register
class Robust04PassagesDesc(Robust04Passages, Benchmark):
    """
    Split robust04 into passages, using description queries
    """

    module_name = "robust04passagesdesc"
    query_type = "desc"


@Benchmark.register
class Robust04Huston14(Benchmark):
    module_name = "robust04.huston14.title"
@@ -59,3 +83,4 @@ class Robust04Huston14Desc(Robust04Huston14, Benchmark):
module_name = "robust04.huston14.desc"
fold_file = PACKAGE_PATH / "data" / "rob04_huston14_desc_folds.json"
query_type = "desc"

47 changes: 46 additions & 1 deletion capreolus/collection/gov2.py
@@ -1,5 +1,11 @@
from capreolus import constants
import os
import shutil
import tarfile

from capreolus import ConfigOption, constants, Dependency
from capreolus.utils.common import download_file
from capreolus.utils.loginit import get_logger
from capreolus.utils.trec import anserini_index_to_trec_docs

from . import Collection, IRDCollection

@@ -13,3 +19,42 @@ class Gov2(IRDCollection):
module_name = "gov2"
ird_dataset_name = "gov2"
collection_type = "TrecwebCollection"


@Collection.register
class Gov2Passages(Collection):
    """ GOV2 corpus, split into passages that are indexed separately """

    module_name = "gov2passages"
    collection_type = "TrecCollection"
    generator_type = "DefaultLuceneDocumentGenerator"
    config_keys_not_in_path = ["path"]
    config_spec = [ConfigOption("path", "/GW/NeuralIR/nobackup/GOV2/GOV2_data", "path to corpus")]
    dependencies = [Dependency(key="task", module="task", name="gov2passages")]

    def download_if_missing(self):
        target_dir = os.path.join(self.task.get_cache_path(), "generated")
        if os.path.isdir(target_dir):
            return target_dir

        return self.download_index()

    def _validate_document_path(self, path):
        """
        Validate that the document path appears to contain the generated GOV2 passages
        (i.e., a `generated` subdirectory).
        """

        if not os.path.isdir(path):
            return False

        contents = {fn.lower(): fn for fn in os.listdir(path)}

        if "generated" in contents:
            return True

        return False

    def download_index(self):
        self.task.generate()

        return os.path.join(self.task.get_cache_path(), "generated")
7 changes: 0 additions & 7 deletions capreolus/collection/msmarco.py

This file was deleted.

52 changes: 51 additions & 1 deletion capreolus/collection/robust04.py
@@ -2,7 +2,7 @@
import shutil
import tarfile

from capreolus import ConfigOption, constants
from capreolus import ConfigOption, constants, Dependency
from capreolus.utils.common import download_file
from capreolus.utils.loginit import get_logger
from capreolus.utils.trec import anserini_index_to_trec_docs
@@ -100,3 +100,53 @@ def download_index(
print("", file=outf)

return document_dir


@Collection.register
class Robust04Passages(Collection):
    """
    TREC Robust04 (TREC disks 4 and 5 without the Congressional Record documents).
    Splits each document in Robust04 into passages and indexes them separately.
    """

    module_name = "robust04passages"
    collection_type = "TrecCollection"
    generator_type = "DefaultLuceneDocumentGenerator"
    config_keys_not_in_path = ["path"]
    config_spec = [ConfigOption("path", "Aquaint-TREC-3-4", "path to corpus")]
    dependencies = [Dependency(key="task", module="task", name="robust04passages")]

    def download_if_missing(self):
        target_dir = os.path.join(self.task.get_cache_path(), "generated")
        if os.path.isdir(target_dir):
            return target_dir

        return self.download_index()

    def _validate_document_path(self, path):
        """Validate that the document path appears to contain the generated robust04 passages.
        Validation is performed by looking for a `generated` directory (case-insensitive) inside `path`.
        Returns:
            True if the generated passage directory is found, False otherwise
        """

        if not os.path.isdir(path):
            return False

        contents = {fn.lower(): fn for fn in os.listdir(path)}

        if "generated" in contents:
            return True

        return False

    def download_index(self):
        self.task.generate()

        return os.path.join(self.task.get_cache_path(), "generated")

# Download the collection from URL and extract into a path in the cache directory.
# To avoid re-downloading every call, we create an empty '/done' file in this directory on success.
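Robust04Passages (like gov2passages above) delegates the actual passage generation to the task module it depends on, which is not among the files shown here. To make "splits each document into passages" concrete, here is a sliding-window sketch; the window/stride values and the docid_index passage-id scheme are illustrative assumptions, not taken from this commit:

# Illustrative sliding-window passage splitting; parameters and id scheme are assumed.
def split_into_passages(docid, text, window=150, stride=75):
    words = text.split()
    passages = []
    for i, start in enumerate(range(0, max(len(words) - window, 0) + 1, stride)):
        passages.append((f"{docid}_{i}", " ".join(words[start:start + window])))
    return passages

doc = " ".join(f"w{n}" for n in range(300))
print([pid for pid, _ in split_into_passages("FBIS3-1", doc)])
# -> ['FBIS3-1_0', 'FBIS3-1_1', 'FBIS3-1_2']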