From c2506f5e40625eb826fb70cb49fb6b13b207c66a Mon Sep 17 00:00:00 2001 From: pengjunfeng11 <34857167+pengjunfeng11@users.noreply.github.com> Date: Thu, 2 Jan 2025 14:21:06 +0800 Subject: [PATCH 1/2] Update model_mapping.py REF: Refactor the code supporting the model list and provided two methods to return all models supported by the framework and the beg series models supported by the framework. --- .../inference/embedder/model_mapping.py | 47 +++++++++++++++---- 1 file changed, 39 insertions(+), 8 deletions(-) diff --git a/FlagEmbedding/inference/embedder/model_mapping.py b/FlagEmbedding/inference/embedder/model_mapping.py index c17436bc..f9d727d7 100644 --- a/FlagEmbedding/inference/embedder/model_mapping.py +++ b/FlagEmbedding/inference/embedder/model_mapping.py @@ -36,8 +36,8 @@ class EmbedderConfig: query_instruction_format: str = "{}{}" -AUTO_EMBEDDER_MAPPING = OrderedDict([ - # ============================== BGE ============================== +# BGE models mapping +BGE_MAPPING = OrderedDict([ ( "bge-en-icl", EmbedderConfig(FlagICLModel, PoolingMethod.LAST_TOKEN, query_instruction_format="{}\n{}") @@ -98,7 +98,10 @@ class EmbedderConfig: "bge-small-zh", EmbedderConfig(FlagModel, PoolingMethod.CLS) ), - # ============================== E5 ============================== +]) + +# E5 models mapping +E5_MAPPING = OrderedDict([ ( "e5-mistral-7b-instruct", EmbedderConfig(FlagLLMModel, PoolingMethod.LAST_TOKEN, query_instruction_format="Instruct: {}\nQuery: {}") @@ -143,7 +146,10 @@ class EmbedderConfig: "e5-small", EmbedderConfig(FlagModel, PoolingMethod.MEAN) ), - # ============================== GTE ============================== +]) + +# GTE models mapping +GTE_MAPPING = OrderedDict([ ( "gte-Qwen2-7B-instruct", EmbedderConfig(FlagLLMModel, PoolingMethod.LAST_TOKEN, trust_remote_code=True, query_instruction_format="Instruct: {}\nQuery: {}") @@ -192,7 +198,10 @@ class EmbedderConfig: 'gte-small-zh', EmbedderConfig(FlagModel, PoolingMethod.CLS) ), - # ============================== SFR ============================== +]) + +# SFR models mapping +SFR_MAPPING = OrderedDict([ ( 'SFR-Embedding-2_R', EmbedderConfig(FlagLLMModel, PoolingMethod.LAST_TOKEN, query_instruction_format="Instruct: {}\nQuery: {}") @@ -201,15 +210,37 @@ class EmbedderConfig: 'SFR-Embedding-Mistral', EmbedderConfig(FlagLLMModel, PoolingMethod.LAST_TOKEN, query_instruction_format="Instruct: {}\nQuery: {}") ), - # ============================== Linq ============================== +]) + +# Linq models mapping +LINQ_MAPPING = OrderedDict([ ( 'Linq-Embed-Mistral', EmbedderConfig(FlagLLMModel, PoolingMethod.LAST_TOKEN, query_instruction_format="Instruct: {}\nQuery: {}") ), - # ============================== BCE ============================== +]) + +# BCE models mapping +BCE_MAPPING = OrderedDict([ ( 'bce-embedding-base_v1', EmbedderConfig(FlagModel, PoolingMethod.CLS) ), - # TODO: Add more models, such as Jina, Stella_v5, NV-Embed, etc. ]) + +# Combine all mappings +AUTO_EMBEDDER_MAPPING = OrderedDict() +AUTO_EMBEDDER_MAPPING.update(BGE_MAPPING) +AUTO_EMBEDDER_MAPPING.update(E5_MAPPING) +AUTO_EMBEDDER_MAPPING.update(GTE_MAPPING) +AUTO_EMBEDDER_MAPPING.update(SFR_MAPPING) +AUTO_EMBEDDER_MAPPING.update(LINQ_MAPPING) +AUTO_EMBEDDER_MAPPING.update(BCE_MAPPING) + +# TODO: Add more models, such as Jina, Stella_v5, NV-Embed, etc. + +def support_native_bge_model_list(): + return list(BGE_MAPPING.keys()) + +def support_model_list(): + return (AUTO_EMBEDDER_MAPPING.keys()) From 3c2d3215ea5528a7888b2f99849da7bb56d68cea Mon Sep 17 00:00:00 2001 From: pengjunfeng11 <34857167+pengjunfeng11@users.noreply.github.com> Date: Sat, 4 Jan 2025 16:11:31 +0800 Subject: [PATCH 2/2] Update model_mapping.py --- FlagEmbedding/inference/embedder/model_mapping.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/FlagEmbedding/inference/embedder/model_mapping.py b/FlagEmbedding/inference/embedder/model_mapping.py index f9d727d7..78c9253b 100644 --- a/FlagEmbedding/inference/embedder/model_mapping.py +++ b/FlagEmbedding/inference/embedder/model_mapping.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Type +from typing import Type, List from dataclasses import dataclass from collections import OrderedDict @@ -239,8 +239,8 @@ class EmbedderConfig: # TODO: Add more models, such as Jina, Stella_v5, NV-Embed, etc. -def support_native_bge_model_list(): +def support_native_bge_model_list()->List[str]: return list(BGE_MAPPING.keys()) -def support_model_list(): - return (AUTO_EMBEDDER_MAPPING.keys()) +def support_model_list()->List[str]: + return list(AUTO_EMBEDDER_MAPPING.keys())