Skip to content

Commit

Permalink
Merge pull request #14 from lm-sys/main
Browse files Browse the repository at this point in the history
merge
  • Loading branch information
icowan authored Nov 7, 2024
2 parents 8959971 + 185e1a9 commit ade9ba8
Show file tree
Hide file tree
Showing 24 changed files with 1,776 additions and 580 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# FastChat
| [**Demo**](https://lmarena.ai/) | [**Discord**](https://discord.gg/HSWAKCrnFx) | [**X**](https://x.com/lmsysorg) |
| [**Demo**](https://lmarena.ai/) | [**Discord**](https://discord.gg/6GXcFg3TH8) | [**X**](https://x.com/lmsysorg) |

FastChat is an open platform for training, serving, and evaluating large language model based chatbots.
- FastChat powers Chatbot Arena ([lmarena.ai](https://lmarena.ai)), serving over 10 million chat requests for 70+ LLMs.
Expand Down
13 changes: 8 additions & 5 deletions fastchat/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@

REPO_PATH = os.path.dirname(os.path.dirname(__file__))

# Survey Link URL (to be removed)
SURVEY_LINK = """<div style='text-align: center; margin: 20px 0;'>
<div style='display: inline-block; border: 2px solid #DE3163; padding: 10px; border-radius: 5px;'>
<span style='color: #DE3163; font-weight: bold;'>We would love your feedback! Fill out <a href='https://docs.google.com/forms/d/e/1FAIpQLSfKSxwFOW6qD05phh4fwYjk8q0YV1VQe_bmK0_qOVTbC66_MA/viewform?usp=sf_link' style='color: #DE3163; text-decoration: underline;'>this short survey</a> to tell us what you like about the arena, what you don't like, and what you want to see in the future.</span>
# Survey Link URL (to be removed) #00729c
SURVEY_LINK = """<div style='text-align: left; margin: 20px 0;'>
<div style='display: inline-block; border: 2px solid #C41E3A; padding: 20px; padding-bottom: 10px; padding-top: 10px; border-radius: 5px;'>
<span style='color: #C41E3A; font-weight: bold;'>New Launch! Jailbreak models at <a href='https://redarena.ai' style='color: #C41E3A; text-decoration: underline;'>RedTeam Arena</a>. </span>
</div>
</div>"""
# SURVEY_LINK = ""

##### For the gradio web server
SERVER_ERROR_MSG = (
Expand All @@ -27,7 +28,9 @@
MODERATION_MSG = "$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES."
CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION."
INACTIVE_MSG = "THIS SESSION HAS BEEN INACTIVE FOR TOO LONG. PLEASE REFRESH THIS PAGE."
SLOW_MODEL_MSG = "⚠️ Both models will show the responses all at once. Please stay patient as it may take over 30 seconds."
SLOW_MODEL_MSG = (
"⚠️ Models are thinking. Please stay patient as it may take over a minute."
)
RATE_LIMIT_MSG = "**RATE LIMIT OF THIS MODEL IS REACHED. PLEASE COME BACK LATER OR USE <span style='color: red; font-weight: bold;'>[BATTLE MODE](https://lmarena.ai)</span> (the 1st tab).**"
# Maximum input length
INPUT_CHAR_LEN_LIMIT = int(os.getenv("FASTCHAT_INPUT_CHAR_LEN_LIMIT", 12000))
Expand Down
242 changes: 221 additions & 21 deletions fastchat/conversation.py

Large diffs are not rendered by default.

91 changes: 74 additions & 17 deletions fastchat/model/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,13 @@
"gpt2-chatbot",
"im-also-a-good-gpt2-chatbot",
"im-a-good-gpt2-chatbot",
"gpt-4o-mini-2024-07-18",
"gpt-4o-2024-05-13",
"gpt-4o-2024-08-06",
"chatgpt-4o-latest-20240903",
"chatgpt-4o-latest",
"o1-preview",
"o1-mini",
)


Expand Down Expand Up @@ -1118,8 +1124,20 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("gpt-4-turbo-2024-04-09")
if "gpt2-chatbot" in model_path:
return get_conv_template("gpt-4-turbo-2024-04-09")
if "gpt-4o" in model_path:
if "gpt-4o-2024-05-13" in model_path:
return get_conv_template("gpt-4-turbo-2024-04-09")
if "gpt-4o-2024-08-06" in model_path:
return get_conv_template("gpt-mini")
if "anonymous-chatbot" in model_path:
return get_conv_template("gpt-4-turbo-2024-04-09")
if "chatgpt-4o-latest" in model_path:
return get_conv_template("gpt-4-turbo-2024-04-09")
if "gpt-mini" in model_path:
return get_conv_template("gpt-mini")
if "gpt-4o-mini-2024-07-18" in model_path:
return get_conv_template("gpt-mini")
if "o1" in model_path:
return get_conv_template("api_based_default")
return get_conv_template("chatgpt")


Expand Down Expand Up @@ -1167,7 +1185,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
if "claude-3-sonnet" in model_path:
return get_conv_template("claude-3-sonnet-20240229")
if "claude-3-5-sonnet" in model_path:
return get_conv_template("claude-3-5-sonnet-20240620")
return get_conv_template("claude-3-5-sonnet-20240620-v2")
if "claude-3-opus" in model_path:
return get_conv_template("claude-3-opus-20240229")
return get_conv_template("claude")
Expand Down Expand Up @@ -1212,19 +1230,6 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("gemini")


class GeminiDevAdapter(BaseModelAdapter):
"""The model adapter for Gemini 1.5 Pro"""

def match(self, model_path: str):
return "gemini-1.5-pro" in model_path.lower()

def load_model(self, model_path: str, from_pretrained_kwargs: dict):
raise NotImplementedError()

def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("gemini-dev")


class BiLLaAdapter(BaseModelAdapter):
"""The model adapter for Neutralzz/BiLLa-7B-SFT"""

Expand Down Expand Up @@ -1589,7 +1594,7 @@ class Llama3Adapter(BaseModelAdapter):
"""The model adapter for Llama-3 (e.g., meta-llama/Meta-Llama-3-8B-Instruct)"""

def match(self, model_path: str):
return "llama-3" in model_path.lower()
return "llama-3-" in model_path.lower()

def load_model(self, model_path: str, from_pretrained_kwargs: dict):
model, tokenizer = super().load_model(model_path, from_pretrained_kwargs)
Expand All @@ -1601,6 +1606,43 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("llama-3")


class Llama31Adapter(BaseModelAdapter):
"""The model adapter for Llama-3 (e.g., meta-llama/Meta-Llama-3-8B-Instruct)"""

def match(self, model_path: str):
keywords = [
"llama-3.1",
]
for keyword in keywords:
if keyword in model_path.lower():
return True

def load_model(self, model_path: str, from_pretrained_kwargs: dict):
model, tokenizer = super().load_model(model_path, from_pretrained_kwargs)
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
return model, tokenizer

def get_default_conv_template(self, model_path: str) -> Conversation:
if model_path.lower() in [
"llama-3.1-8b-instruct",
"llama-3.1-70b-instruct",
"the-real-chatbot-v2",
]:
return get_conv_template("meta-llama-3.1-sp")
return get_conv_template("meta-llama-3.1")


class GrokAdapter(BaseModelAdapter):
def match(self, model_path: str):
return "grok" in model_path.lower()

def get_default_conv_template(self, model_path: str) -> Conversation:
if "mini" in model_path.lower():
return get_conv_template("grok-2-mini")
return get_conv_template("grok-2")


class CuteGPTAdapter(BaseModelAdapter):
"""The model adapter for CuteGPT"""

Expand Down Expand Up @@ -2459,6 +2501,19 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("api_based_default")


class NoSystemAdapter(BaseModelAdapter):
def match(self, model_path: str):
keyword_list = ["athene-70b"]

for keyword in keyword_list:
if keyword == model_path.lower():
return True
return False

def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("api_based_default")


# Note: the registration order matters.
# The one registered earlier has a higher matching priority.
register_model_adapter(PeftModelAdapter)
Expand All @@ -2484,7 +2539,6 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
register_model_adapter(BardAdapter)
register_model_adapter(PaLM2Adapter)
register_model_adapter(GeminiAdapter)
register_model_adapter(GeminiDevAdapter)
register_model_adapter(GemmaAdapter)
register_model_adapter(ChatGPTAdapter)
register_model_adapter(AzureOpenAIAdapter)
Expand Down Expand Up @@ -2559,6 +2613,9 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
register_model_adapter(RekaAdapter)
register_model_adapter(SmaugChatAdapter)
register_model_adapter(Llama3Adapter)
register_model_adapter(Llama31Adapter)
register_model_adapter(GrokAdapter)
register_model_adapter(NoSystemAdapter)

# After all adapters, try the default base adapter.
register_model_adapter(BaseModelAdapter)
Loading

0 comments on commit ade9ba8

Please sign in to comment.