From 8f381d2f33b351486fe94365256b28efde5db9cd Mon Sep 17 00:00:00 2001
From: Roman Solomatin <36135455+Samoed@users.noreply.github.com>
Date: Wed, 1 Jan 2025 21:59:30 +0300
Subject: [PATCH 1/2] allow kwargs in init

---
 mteb/models/rerankers_custom.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mteb/models/rerankers_custom.py b/mteb/models/rerankers_custom.py
index 40977f1e04..b1425fa313 100644
--- a/mteb/models/rerankers_custom.py
+++ b/mteb/models/rerankers_custom.py
@@ -22,6 +22,7 @@ def __init__(
         batch_size: int = 4,
         fp_options: bool = None,
         silent: bool = False,
+        **kwargs,
     ):
         self.model_name_or_path = model_name_or_path
         self.batch_size = batch_size
@@ -34,7 +35,7 @@ def __init__(
             self.fp_options = torch.float32
         elif self.fp_options == "bfloat16":
             self.fp_options = torch.bfloat16
-        print(f"Using fp_options of {self.fp_options}")
+        logger.info(f"Using fp_options of {self.fp_options}")
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.silent = silent
         self.first_print = True  # for debugging

From b95a9acf4241cdc100618dc4061229d47bb6b824 Mon Sep 17 00:00:00 2001
From: Roman Solomatin <36135455+Samoed@users.noreply.github.com>
Date: Wed, 1 Jan 2025 23:46:50 +0300
Subject: [PATCH 2/2] fix retrieval

---
 .../evaluators/RetrievalEvaluator.py  | 12 ++++----
 mteb/models/rerankers_custom.py       | 29 +++++++++++++++----
 mteb/models/rerankers_monot5_based.py | 15 ++++++++--
 3 files changed, 43 insertions(+), 13 deletions(-)

diff --git a/mteb/evaluation/evaluators/RetrievalEvaluator.py b/mteb/evaluation/evaluators/RetrievalEvaluator.py
index 70f26e2236..0cb2311f98 100644
--- a/mteb/evaluation/evaluators/RetrievalEvaluator.py
+++ b/mteb/evaluation/evaluators/RetrievalEvaluator.py
@@ -43,7 +43,7 @@ def corpus_to_str(
             else corpus["text"][i].strip()
             for i in range(len(corpus["text"]))
         ]
-    elif isinstance(corpus, list) and isinstance(corpus[0], dict):
+    elif isinstance(corpus, (list, tuple)) and isinstance(corpus[0], dict):
         sentences = [
             (doc["title"] + " " + doc["text"]).strip()
             if "title" in doc
@@ -307,15 +307,17 @@ def search_cross_encoder(
             assert (
                 len(queries_in_pair) == len(corpus_in_pair) == len(instructions_in_pair)
             )
+            corpus_in_pair = corpus_to_str(corpus_in_pair)

             if hasattr(self.model, "model") and isinstance(
                 self.model.model, CrossEncoder
             ):
                 # can't take instructions, so add them here
-                queries_in_pair = [
-                    f"{q} {i}".strip()
-                    for i, q in zip(instructions_in_pair, queries_in_pair)
-                ]
+                if instructions_in_pair[0] is not None:
+                    queries_in_pair = [
+                        f"{q} {i}".strip()
+                        for i, q in zip(instructions_in_pair, queries_in_pair)
+                    ]
                 scores = self.model.predict(list(zip(queries_in_pair, corpus_in_pair)))  # type: ignore
             else:
                 # may use the instructions in a unique way, so give them also
diff --git a/mteb/models/rerankers_custom.py b/mteb/models/rerankers_custom.py
index b1425fa313..e8bb483a3d 100644
--- a/mteb/models/rerankers_custom.py
+++ b/mteb/models/rerankers_custom.py
@@ -71,7 +71,12 @@ def __init__(

     @torch.inference_mode()
     def predict(self, input_to_rerank, **kwargs):
-        queries, passages, instructions = list(zip(*input_to_rerank))
+        inputs = list(zip(*input_to_rerank))
+        if len(input_to_rerank[0]) == 2:
+            queries, passages = inputs
+            instructions = None
+        else:
+            queries, passages, instructions = inputs
         if instructions is not None and instructions[0] is not None:
             assert len(instructions) == len(queries)
             queries = [f"{q} {i}".strip() for i, q in zip(instructions, queries)]
@@ -113,7 +118,13 @@ def __init__(

     @torch.inference_mode()
     def predict(self, input_to_rerank, **kwargs):
-        queries, passages, instructions = list(zip(*input_to_rerank))
+        inputs = list(zip(*input_to_rerank))
+        if len(input_to_rerank[0]) == 2:
+            queries, passages = inputs
+            instructions = None
+        else:
+            queries, passages, instructions = inputs
+
         if instructions is not None and instructions[0] is not None:
             queries = [f"{q} {i}".strip() for i, q in zip(instructions, queries)]

@@ -153,7 +164,13 @@ def __init__(
         )

     def predict(self, input_to_rerank, **kwargs):
-        queries, passages, instructions = list(zip(*input_to_rerank))
+        inputs = list(zip(*input_to_rerank))
+        if len(input_to_rerank[0]) == 2:
+            queries, passages = inputs
+            instructions = None
+        else:
+            queries, passages, instructions = inputs
+
         if instructions is not None and instructions[0] is not None:
             queries = [f"{q} {i}".strip() for i, q in zip(instructions, queries)]

@@ -180,7 +197,7 @@ def loader_inner(**kwargs: Any) -> Encoder:
         _loader,
         wrapper=MonoBERTReranker,
         model_name_or_path="castorini/monobert-large-msmarco",
-        fp_options="float1616",
+        fp_options="float16",
     ),
     name="castorini/monobert-large-msmarco",
     languages=["eng_Latn"],
@@ -195,7 +212,7 @@ def loader_inner(**kwargs: Any) -> Encoder:
         _loader,
         wrapper=JinaReranker,
         model_name_or_path="jinaai/jina-reranker-v2-base-multilingual",
-        fp_options="float1616",
+        fp_options="float16",
     ),
     name="jinaai/jina-reranker-v2-base-multilingual",
     languages=["eng_Latn"],
@@ -209,7 +226,7 @@ def loader_inner(**kwargs: Any) -> Encoder:
         _loader,
         wrapper=BGEReranker,
         model_name_or_path="BAAI/bge-reranker-v2-m3",
-        fp_options="float1616",
+        fp_options="float16",
     ),
     name="BAAI/bge-reranker-v2-m3",
     languages=[
diff --git a/mteb/models/rerankers_monot5_based.py b/mteb/models/rerankers_monot5_based.py
index 7ece40e3cf..d72a893406 100644
--- a/mteb/models/rerankers_monot5_based.py
+++ b/mteb/models/rerankers_monot5_based.py
@@ -105,7 +105,12 @@ def get_prediction_tokens(

     @torch.inference_mode()
     def predict(self, input_to_rerank, **kwargs):
-        queries, passages, instructions = list(zip(*input_to_rerank))
+        inputs = list(zip(*input_to_rerank))
+        if len(input_to_rerank[0]) == 2:
+            queries, passages = inputs
+            instructions = None
+        else:
+            queries, passages, instructions = inputs
         if instructions is not None and instructions[0] is not None:
             queries = [f"{q} {i}".strip() for i, q in zip(instructions, queries)]
@@ -194,7 +199,13 @@ def __init__(

     @torch.inference_mode()
     def predict(self, input_to_rerank, **kwargs):
-        queries, passages, instructions = list(zip(*input_to_rerank))
+        inputs = list(zip(*input_to_rerank))
+        if len(input_to_rerank[0]) == 2:
+            queries, passages = inputs
+            instructions = None
+        else:
+            queries, passages, instructions = inputs
+
         if instructions is not None and instructions[0] is not None:
             # logger.info(f"Adding instructions to LLAMA queries")
             queries = [
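
Note: the recurring change in patch 2 lets every reranker's predict() accept either
(query, passage) pairs or (query, passage, instruction) triples instead of assuming
triples. Below is a minimal standalone sketch of that normalization logic, mirroring
the diff hunks above; the helper name normalize_rerank_input is illustrative only and
is not part of mteb's API.

    def normalize_rerank_input(input_to_rerank):
        # Transpose the batch: [(q, p), ...] -> (queries, passages),
        # or [(q, p, i), ...] -> (queries, passages, instructions).
        inputs = list(zip(*input_to_rerank))
        if len(input_to_rerank[0]) == 2:
            # No instructions were supplied for this batch.
            queries, passages = inputs
            instructions = None
        else:
            queries, passages, instructions = inputs
        if instructions is not None and instructions[0] is not None:
            # Same ordering as the patch: the instruction is appended after the query.
            queries = [f"{q} {i}".strip() for i, q in zip(instructions, queries)]
        return list(queries), list(passages)

    # A 2-tuple batch no longer raises a "not enough values to unpack" error:
    queries, passages = normalize_rerank_input(
        [("what is mteb?", "MTEB is an embedding benchmark."),
         ("what is bm25?", "BM25 is a ranking function.")]
    )
    assert queries == ["what is mteb?", "what is bm25?"]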