
Commit

feat: avoid skip tool test and avoid empty tool prompts
drbh committed Aug 26, 2024
1 parent 5b34898 commit e7442fe
Showing 4 changed files with 31 additions and 40 deletions.
7 changes: 6 additions & 1 deletion clients/python/text_generation/client.py
@@ -757,7 +757,12 @@ async def _chat_stream_response(self, request):
                 continue
             payload = byte_payload.decode("utf-8")
             if payload.startswith("data:"):
-                json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
+                payload_data = (
+                    payload.lstrip("data:").rstrip("\n").removeprefix(" ")
+                )
+                if payload_data == "[DONE]":
+                    break
+                json_payload = json.loads(payload_data)
                 try:
                     response = ChatCompletionChunk(**json_payload)
                     yield response
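The added branch handles the end-of-stream sentinel: OpenAI-style SSE streams finish with a literal "data: [DONE]" line, which is not JSON and previously reached json.loads and raised. A standalone sketch of the new parsing, with hypothetical input lines for illustration (str.removeprefix needs Python 3.9+):

    import json

    # Hypothetical SSE lines as the client might receive them.
    lines = [
        'data: {"choices": [{"delta": {"content": "Hi"}}]}\n',
        "data: [DONE]\n",
    ]

    for payload in lines:
        if payload.startswith("data:"):
            # Strip the "data:" prefix, the trailing newline, and the single
            # optional space the SSE format allows after the colon.
            payload_data = payload.lstrip("data:").rstrip("\n").removeprefix(" ")
            if payload_data == "[DONE]":
                break  # end-of-stream sentinel, not a JSON chunk
            chunk = json.loads(payload_data)
            print(chunk["choices"][0]["delta"]["content"])  # -> Hi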
50 changes: 18 additions & 32 deletions integration-tests/models/test_tools_llama.py
@@ -36,6 +36,7 @@ async def flash_llama_grammar_tools(flash_llama_grammar_tools_handle):
                     },
                 },
                 "required": ["location", "format"],
+                "additionalProperties": False,
             },
         },
     },
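"additionalProperties": False is added to both tool schemas (here and in the next hunk); it closes each tool's JSON Schema so grammar-constrained generation cannot invent argument keys beyond those declared. For context, a trimmed reconstruction of one tool entry; only the required/additionalProperties lines are confirmed by these hunks, the other fields are assumed for illustration:

    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",  # assumed wording
            "parameters": {
                "type": "object",
                "properties": {
                    # assumed property shapes
                    "location": {"type": "string"},
                    "format": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location", "format"],
                "additionalProperties": False,  # new: reject unknown argument keys
            },
        },
    }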
@@ -62,21 +63,21 @@ async def flash_llama_grammar_tools(flash_llama_grammar_tools_handle):
                     },
                 },
                 "required": ["location", "format", "num_days"],
+                "additionalProperties": False,
             },
         },
     },
 ]
 
 
-@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
     response = await flash_llama_grammar_tools.chat(
         max_tokens=100,
         seed=1,
         tools=tools,
-        presence_penalty=-1.1,
+        temperature=0.0,
         messages=[
             {
                 "role": "system",
@@ -91,19 +92,18 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
     assert response.choices[0].message.content is None
     assert response.choices[0].message.tool_calls == [
         {
-            "id": 0,
+            "id": "0",
             "type": "function",
             "function": {
                 "description": None,
                 "name": "get_current_weather",
-                "arguments": {"format": "celsius", "location": "New York, NY"},
+                "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
             },
         }
     ]
     assert response == response_snapshot
 
 
-@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools_auto(
@@ -113,8 +113,8 @@ async def test_flash_llama_grammar_tools_auto(
         max_tokens=100,
         seed=1,
         tools=tools,
+        temperature=0.0,
         tool_choice="auto",
-        presence_penalty=-1.1,
         messages=[
             {
                 "role": "system",
@@ -129,20 +129,19 @@ async def test_flash_llama_grammar_tools_auto(
     assert response.choices[0].message.content is None
     assert response.choices[0].message.tool_calls == [
         {
-            "id": 0,
+            "id": "0",
             "type": "function",
             "function": {
                 "description": None,
                 "name": "get_current_weather",
-                "arguments": {"format": "celsius", "location": "New York, NY"},
+                "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
             },
         }
     ]
 
     assert response == response_snapshot
 
 
-@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools_choice(
@@ -152,8 +151,8 @@ async def test_flash_llama_grammar_tools_choice(
         max_tokens=100,
         seed=1,
         tools=tools,
+        temperature=0.0,
         tool_choice="get_current_weather",
-        presence_penalty=-1.1,
         messages=[
             {
                 "role": "system",
@@ -168,20 +167,19 @@ async def test_flash_llama_grammar_tools_choice(
     assert response.choices[0].message.content is None
     assert response.choices[0].message.tool_calls == [
         {
-            "id": 0,
+            "id": "0",
             "type": "function",
             "function": {
                 "description": None,
                 "name": "get_current_weather",
-                "arguments": {"format": "celsius", "location": "New York, NY"},
+                "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
             },
         }
     ]
 
     assert response == response_snapshot
 
 
-@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools_stream(
@@ -191,8 +189,8 @@ async def test_flash_llama_grammar_tools_stream(
         max_tokens=100,
         seed=1,
         tools=tools,
+        temperature=0.0,
         tool_choice="get_current_weather",
-        presence_penalty=-1.1,
         messages=[
             {
                 "role": "system",
@@ -210,25 +208,24 @@ async def test_flash_llama_grammar_tools_stream(
     async for response in responses:
         count += 1
 
-    assert count == 38
+    assert count == 48
     assert response == response_snapshot
 
 
-@pytest.mark.skip(reason="Takes too long to run")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_tools_insufficient_information(
     flash_llama_grammar_tools, response_snapshot
 ):
     responses = await flash_llama_grammar_tools.chat(
         max_tokens=100,
-        seed=8,
+        seed=24,
         tools=tools,
         tool_choice="auto",
         messages=[
             {
                 "role": "system",
-                "content": "ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
+                "content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
             },
             {
                 "role": "user",
@@ -239,18 +236,7 @@ async def test_flash_llama_grammar_tools_insufficient_information(
     )
 
     assert responses.choices[0].message.content is None
-    assert responses.choices[0].message.tool_calls == [
-        {
-            "function": {
-                "arguments": {
-                    "error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
-                },
-                "description": None,
-                "name": "notify_error",
-            },
-            "id": 0,
-            "type": "function",
-        }
-    ]
-
+    assert (
+        responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error"
+    )
     assert responses == response_snapshot
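The deleted block pinned the model's entire error payload, including free-form message text that shifts whenever the seed, prompt, or model changes; the replacement asserts only the invariant (that notify_error was selected) and leaves the full payload to the snapshot. A sketch of the general pattern, with a hypothetical helper name:

    def assert_tool_chosen(response, expected_name):
        # Pin only the stable field; defer the drift-prone payload to snapshots.
        calls = response.choices[0].message.tool_calls
        assert calls and calls[0]["function"]["name"] == expected_name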
8 changes: 3 additions & 5 deletions router/src/lib.rs
@@ -840,7 +840,7 @@ pub(crate) struct ChatRequest {
     pub tools: Option<Vec<Tool>>,
 
     /// A prompt to be appended before the tools
-    #[serde(default = "default_tool_prompt")]
+    #[serde(default)]
     #[schema(
         nullable = true,
         example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
@@ -865,10 +865,8 @@ pub(crate) struct ChatRequest {
     pub guideline: Option<String>,
 }
 
-fn default_tool_prompt() -> Option<String> {
-    Some(
-        "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string(),
-    )
+pub fn default_tool_prompt() -> String {
+    "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
 }
 
 #[derive(Clone, Debug, Deserialize, PartialEq, Serialize, ToSchema)]
6 changes: 4 additions & 2 deletions router/src/server.rs
@@ -8,7 +8,7 @@ use crate::kserve::{
     kserve_model_metadata, kserve_model_metadata_ready,
 };
 use crate::validation::ValidationError;
-use crate::ChatTokenizeResponse;
+use crate::{default_tool_prompt, ChatTokenizeResponse};
 use crate::{
     usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
     GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
@@ -1158,7 +1158,9 @@ async fn chat_completions(
     let repetition_penalty = presence_penalty.map(|x| x + 2.0);
     let max_new_tokens = max_tokens.or(Some(100));
     let logprobs = logprobs.unwrap_or(false);
-    let tool_prompt = tool_prompt.unwrap_or_default();
+    let tool_prompt = tool_prompt
+        .filter(|s| !s.is_empty())
+        .unwrap_or_else(default_tool_prompt);
     let stop = stop.unwrap_or_default();
     // enable greedy only when temperature is 0
     let (do_sample, temperature) = match temperature {
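Taken together, the router changes make an absent tool_prompt and an explicitly empty one behave identically: serde no longer injects the default during deserialization, and the handler falls back to default_tool_prompt only when the field is missing or empty, so tools are never prefixed with a blank prompt. A Python sketch of the equivalent semantics (the actual implementation is the Rust above):

    DEFAULT_TOOL_PROMPT = (
        "\nGiven the functions available, please respond with a JSON for a "
        "function call with its proper arguments that best answers the given "
        "prompt. Respond in the format {name: function name, parameters: "
        "dictionary of argument name and its value}.Do not use variables.\n"
    )

    def resolve_tool_prompt(tool_prompt):
        # None (unset) and "" (explicitly empty) both fall back to the default,
        # mirroring .filter(|s| !s.is_empty()).unwrap_or_else(default_tool_prompt).
        return tool_prompt if tool_prompt else DEFAULT_TOOL_PROMPT

    assert resolve_tool_prompt(None) == DEFAULT_TOOL_PROMPT
    assert resolve_tool_prompt("") == DEFAULT_TOOL_PROMPT
    assert resolve_tool_prompt("Use a tool.") == "Use a tool."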
