fix: allow kwargs in init for RerankingWrapper #1676

Open · wants to merge 2 commits into main
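For context, the core of the fix: the reranking wrapper's __init__ previously rejected any keyword argument it did not declare, so loaders that forwarded extra options crashed. A pared-down sketch of the before/after behavior (the class body and the device_map kwarg are illustrative, not the exact mteb code):

    class RerankingWrapper:
        def __init__(
            self,
            model_name_or_path: str,
            batch_size: int = 4,
            fp_options=None,
            silent: bool = False,
            **kwargs,  # absorbs loader-specific extras instead of raising TypeError
        ):
            self.model_name_or_path = model_name_or_path
            self.batch_size = batch_size
            self.fp_options = fp_options
            self.silent = silent

    # Without **kwargs this raised: TypeError: __init__() got an unexpected
    # keyword argument 'device_map'. With it, the extra argument is accepted.
    wrapper = RerankingWrapper("castorini/monobert-large-msmarco", device_map="auto")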
12 changes: 7 additions & 5 deletions mteb/evaluation/evaluators/RetrievalEvaluator.py
@@ -43,7 +43,7 @@ def corpus_to_str(
                 else corpus["text"][i].strip()
                 for i in range(len(corpus["text"]))
             ]
-    elif isinstance(corpus, list) and isinstance(corpus[0], dict):
+    elif isinstance(corpus, (list, tuple)) and isinstance(corpus[0], dict):
         sentences = [
             (doc["title"] + " " + doc["text"]).strip()
             if "title" in doc
@@ -307,15 +307,17 @@ def search_cross_encoder(
             assert (
                 len(queries_in_pair) == len(corpus_in_pair) == len(instructions_in_pair)
             )
+            corpus_in_pair = corpus_to_str(corpus_in_pair)

             if hasattr(self.model, "model") and isinstance(
                 self.model.model, CrossEncoder
             ):
                 # can't take instructions, so add them here
-                queries_in_pair = [
-                    f"{q} {i}".strip()
-                    for i, q in zip(instructions_in_pair, queries_in_pair)
-                ]
+                if instructions_in_pair[0] is not None:
+                    queries_in_pair = [
+                        f"{q} {i}".strip()
+                        for i, q in zip(instructions_in_pair, queries_in_pair)
+                    ]
                 scores = self.model.predict(list(zip(queries_in_pair, corpus_in_pair)))  # type: ignore
             else:
                 # may use the instructions in a unique way, so give them also
32 changes: 25 additions & 7 deletions mteb/models/rerankers_custom.py
@@ -22,6 +22,7 @@ def __init__(
         batch_size: int = 4,
         fp_options: bool = None,
         silent: bool = False,
+        **kwargs,
     ):
         self.model_name_or_path = model_name_or_path
         self.batch_size = batch_size
@@ -34,7 +35,7 @@ def __init__(
             self.fp_options = torch.float32
         elif self.fp_options == "bfloat16":
             self.fp_options = torch.bfloat16
-        print(f"Using fp_options of {self.fp_options}")
+        logger.info(f"Using fp_options of {self.fp_options}")
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.silent = silent
         self.first_print = True  # for debugging
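The print-to-logger change sits next to the string-to-dtype dispatch that the fp_options fixes further down depend on: an unrecognized value such as the old "float1616" matches no branch and silently stays a string. A condensed sketch of that dispatch (dict-based here for brevity; a "float16" branch is assumed above the visible hunk):

    import torch

    _FP_OPTIONS = {
        "float16": torch.float16,
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
    }

    def resolve_fp_options(fp_options):
        # Unknown strings fall through unchanged, which is how the
        # "float1616" typo went unnoticed rather than erroring out.
        return _FP_OPTIONS.get(fp_options, fp_options)

    assert resolve_fp_options("float16") is torch.float16
    assert resolve_fp_options("float1616") == "float1616"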
@@ -70,7 +71,12 @@ def __init__(

     @torch.inference_mode()
     def predict(self, input_to_rerank, **kwargs):
-        queries, passages, instructions = list(zip(*input_to_rerank))
+        inputs = list(zip(*input_to_rerank))
+        if len(input_to_rerank[0]) == 2:
+            queries, passages = inputs
+            instructions = None
+        else:
+            queries, passages, instructions = inputs
         if instructions is not None and instructions[0] is not None:
             assert len(instructions) == len(queries)
             queries = [f"{q} {i}".strip() for i, q in zip(instructions, queries)]
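This unpacking pattern is repeated in every predict below: it lets rerankers accept both (query, passage) pairs and (query, passage, instruction) triples. A standalone illustration with toy data:

    def split_inputs(input_to_rerank):
        # zip(*...) transposes a list of tuples into per-field tuples.
        inputs = list(zip(*input_to_rerank))
        if len(input_to_rerank[0]) == 2:  # no instruction field supplied
            queries, passages = inputs
            instructions = None
        else:
            queries, passages, instructions = inputs
        if instructions is not None and instructions[0] is not None:
            queries = [f"{q} {i}".strip() for i, q in zip(instructions, queries)]
        return list(queries), list(passages)

    # Pairs: instructions default to None and queries pass through unchanged.
    print(split_inputs([("what is mteb", "MTEB is a benchmark.")]))
    # Triples: each instruction is appended to its query before scoring.
    print(split_inputs([("what is mteb", "MTEB is a benchmark.", "Answer briefly.")]))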
@@ -112,7 +118,13 @@ def __init__(

     @torch.inference_mode()
     def predict(self, input_to_rerank, **kwargs):
-        queries, passages, instructions = list(zip(*input_to_rerank))
+        inputs = list(zip(*input_to_rerank))
+        if len(input_to_rerank[0]) == 2:
+            queries, passages = inputs
+            instructions = None
+        else:
+            queries, passages, instructions = inputs
+
         if instructions is not None and instructions[0] is not None:
             queries = [f"{q} {i}".strip() for i, q in zip(instructions, queries)]

@@ -152,7 +164,13 @@ def __init__(
         )

     def predict(self, input_to_rerank, **kwargs):
-        queries, passages, instructions = list(zip(*input_to_rerank))
+        inputs = list(zip(*input_to_rerank))
+        if len(input_to_rerank[0]) == 2:
+            queries, passages = inputs
+            instructions = None
+        else:
+            queries, passages, instructions = inputs
+
         if instructions is not None and instructions[0] is not None:
             queries = [f"{q} {i}".strip() for i, q in zip(instructions, queries)]

@@ -179,7 +197,7 @@ def loader_inner(**kwargs: Any) -> Encoder:
         _loader,
         wrapper=MonoBERTReranker,
         model_name_or_path="castorini/monobert-large-msmarco",
-        fp_options="float1616",
+        fp_options="float16",
     ),
     name="castorini/monobert-large-msmarco",
     languages=["eng_Latn"],
@@ -194,7 +212,7 @@ def loader_inner(**kwargs: Any) -> Encoder:
         _loader,
         wrapper=JinaReranker,
         model_name_or_path="jinaai/jina-reranker-v2-base-multilingual",
-        fp_options="float1616",
+        fp_options="float16",
     ),
     name="jinaai/jina-reranker-v2-base-multilingual",
     languages=["eng_Latn"],
@@ -208,7 +226,7 @@ def loader_inner(**kwargs: Any) -> Encoder:
         _loader,
         wrapper=BGEReranker,
         model_name_or_path="BAAI/bge-reranker-v2-m3",
-        fp_options="float1616",
+        fp_options="float16",
     ),
     name="BAAI/bge-reranker-v2-m3",
     languages=[
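These registry entries also show why __init__ needed **kwargs in the first place: the partial-based loader forwards whatever the caller passes straight into the wrapper. A toy version of that plumbing (_loader and the device kwarg are simplified stand-ins, not the exact mteb implementation):

    from functools import partial

    class BGEReranker:  # stand-in for the real wrapper
        def __init__(self, model_name_or_path, fp_options=None, **kwargs):
            self.model_name_or_path = model_name_or_path
            self.fp_options = fp_options

    def _loader(wrapper, **kwargs):
        return wrapper(**kwargs)

    loader = partial(
        _loader,
        wrapper=BGEReranker,
        model_name_or_path="BAAI/bge-reranker-v2-m3",
        fp_options="float16",
    )
    # Callers can now add run-time options; before the **kwargs fix this
    # extra keyword propagated into __init__ and raised TypeError.
    model = loader(device="cuda")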
15 changes: 13 additions & 2 deletions mteb/models/rerankers_monot5_based.py
@@ -105,7 +105,12 @@ def get_prediction_tokens(

     @torch.inference_mode()
     def predict(self, input_to_rerank, **kwargs):
-        queries, passages, instructions = list(zip(*input_to_rerank))
+        inputs = list(zip(*input_to_rerank))
+        if len(input_to_rerank[0]) == 2:
+            queries, passages = inputs
+            instructions = None
+        else:
+            queries, passages, instructions = inputs

         if instructions is not None and instructions[0] is not None:
             queries = [f"{q} {i}".strip() for i, q in zip(instructions, queries)]
@@ -194,7 +199,13 @@ def __init__(

     @torch.inference_mode()
     def predict(self, input_to_rerank, **kwargs):
-        queries, passages, instructions = list(zip(*input_to_rerank))
+        inputs = list(zip(*input_to_rerank))
+        if len(input_to_rerank[0]) == 2:
+            queries, passages = inputs
+            instructions = None
+        else:
+            queries, passages, instructions = inputs
+
         if instructions is not None and instructions[0] is not None:
             # logger.info(f"Adding instructions to LLAMA queries")
             queries = [