Skip to content

Commit

Permalink
Merge pull request #1311 from pengjunfeng11/master
Browse files Browse the repository at this point in the history
Update model_mapping.py
  • Loading branch information
545999961 authored Jan 7, 2025
2 parents 0498828 + 3c2d321 commit d4735c3
Showing 1 changed file with 40 additions and 9 deletions.
49 changes: 40 additions & 9 deletions FlagEmbedding/inference/embedder/model_mapping.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Type
from typing import Type, List
from dataclasses import dataclass
from collections import OrderedDict

Expand Down Expand Up @@ -36,8 +36,8 @@ class EmbedderConfig:
query_instruction_format: str = "{}{}"


AUTO_EMBEDDER_MAPPING = OrderedDict([
# ============================== BGE ==============================
# BGE models mapping
BGE_MAPPING = OrderedDict([
(
"bge-en-icl",
EmbedderConfig(FlagICLModel, PoolingMethod.LAST_TOKEN, query_instruction_format="<instruct>{}\n<query>{}")
Expand Down Expand Up @@ -98,7 +98,10 @@ class EmbedderConfig:
"bge-small-zh",
EmbedderConfig(FlagModel, PoolingMethod.CLS)
),
# ============================== E5 ==============================
])

# E5 models mapping
E5_MAPPING = OrderedDict([
(
"e5-mistral-7b-instruct",
EmbedderConfig(FlagLLMModel, PoolingMethod.LAST_TOKEN, query_instruction_format="Instruct: {}\nQuery: {}")
Expand Down Expand Up @@ -143,7 +146,10 @@ class EmbedderConfig:
"e5-small",
EmbedderConfig(FlagModel, PoolingMethod.MEAN)
),
# ============================== GTE ==============================
])

# GTE models mapping
GTE_MAPPING = OrderedDict([
(
"gte-Qwen2-7B-instruct",
EmbedderConfig(FlagLLMModel, PoolingMethod.LAST_TOKEN, trust_remote_code=True, query_instruction_format="Instruct: {}\nQuery: {}")
Expand Down Expand Up @@ -192,7 +198,10 @@ class EmbedderConfig:
'gte-small-zh',
EmbedderConfig(FlagModel, PoolingMethod.CLS)
),
# ============================== SFR ==============================
])

# SFR models mapping
SFR_MAPPING = OrderedDict([
(
'SFR-Embedding-2_R',
EmbedderConfig(FlagLLMModel, PoolingMethod.LAST_TOKEN, query_instruction_format="Instruct: {}\nQuery: {}")
Expand All @@ -201,15 +210,37 @@ class EmbedderConfig:
'SFR-Embedding-Mistral',
EmbedderConfig(FlagLLMModel, PoolingMethod.LAST_TOKEN, query_instruction_format="Instruct: {}\nQuery: {}")
),
# ============================== Linq ==============================
])

# Linq family: model name -> EmbedderConfig used to load it.
LINQ_MAPPING = OrderedDict({
    # Loaded through FlagLLMModel with last-token pooling; queries are
    # wrapped in the "Instruct: ... / Query: ..." template below.
    "Linq-Embed-Mistral": EmbedderConfig(
        FlagLLMModel,
        PoolingMethod.LAST_TOKEN,
        query_instruction_format="Instruct: {}\nQuery: {}",
    ),
})

# BCE family: model name -> EmbedderConfig used to load it.
BCE_MAPPING = OrderedDict({
    # Loaded through FlagModel with CLS pooling; every other EmbedderConfig
    # field keeps its default.
    "bce-embedding-base_v1": EmbedderConfig(FlagModel, PoolingMethod.CLS),
    # TODO: Add more models, such as Jina, Stella_v5, NV-Embed, etc.
})

# Global lookup table: every per-family table merged into one OrderedDict.
# Families are inserted in the order listed here. Should a model name ever
# appear in more than one family, it keeps its first position but takes the
# config from the later family (same semantics as successive ``update`` calls).
AUTO_EMBEDDER_MAPPING = OrderedDict([
    *BGE_MAPPING.items(),
    *E5_MAPPING.items(),
    *GTE_MAPPING.items(),
    *SFR_MAPPING.items(),
    *LINQ_MAPPING.items(),
    *BCE_MAPPING.items(),
])

# TODO: Add more models, such as Jina, Stella_v5, NV-Embed, etc.

def support_native_bge_model_list() -> List[str]:
    """Return the model names registered in ``BGE_MAPPING``.

    Returns:
        List[str]: A new list holding the keys of ``BGE_MAPPING`` in the
        mapping's iteration order.
    """
    # Iterating a mapping yields its keys, so this equals list(m.keys()).
    return list(BGE_MAPPING)

def support_model_list() -> List[str]:
    """Return the model names registered in ``AUTO_EMBEDDER_MAPPING``.

    Returns:
        List[str]: A new list holding the keys of ``AUTO_EMBEDDER_MAPPING``
        in the mapping's iteration order.
    """
    # Iterating a mapping yields its keys, so this equals list(m.keys()).
    return list(AUTO_EMBEDDER_MAPPING)

0 comments on commit d4735c3

Please sign in to comment.