
feat: support streaming and improve docs
drbh committed Feb 28, 2024
1 parent 7c04b6d commit 0fc7237
Showing 10 changed files with 219 additions and 39 deletions.
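At a glance, the commit adds a `stream=True` path to both the sync and async chat clients. As orientation only — this sketch is not part of the diff, and the endpoint URL is a placeholder — the new sync path might be driven like this:

    # Illustrative sketch only; the endpoint URL and prompt are invented.
    from text_generation import Client

    client = Client("http://localhost:8080")  # hypothetical local TGI server

    # stream=False (the default) returns a ChatComplete; stream=True returns a
    # generator of ChatCompletionChunk objects parsed from the SSE stream.
    for chunk in client.chat(
        messages=[{"role": "user", "content": "Say hello"}],
        max_tokens=16,
        stream=True,
    ):
        # delta.content can be None (e.g. on tool-call chunks), so guard it
        print(chunk.choices[0].delta.content or "", end="")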
64 changes: 57 additions & 7 deletions clients/python/text_generation/client.py
@@ -3,7 +3,7 @@

from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
-from typing import Dict, Optional, List, AsyncIterator, Iterator
+from typing import Dict, Optional, List, AsyncIterator, Iterator, Union

from text_generation.types import (
    StreamResponse,
@@ -12,6 +12,7 @@
    Parameters,
    Grammar,
    ChatRequest,
+    ChatCompletionChunk,
    ChatComplete,
    Message,
    Tool,
@@ -134,18 +135,42 @@ def chat(
            tools=tools,
            tool_choice=tool_choice,
        )

        if not stream:
            resp = requests.post(
                f"{self.base_url}/v1/chat/completions",
                json=request.dict(),
                headers=self.headers,
                cookies=self.cookies,
                timeout=self.timeout,
            )
            payload = resp.json()
            if resp.status_code != 200:
                raise parse_error(resp.status_code, payload)
            return ChatComplete(**payload)
        else:
            return self._chat_stream_response(request)

    def _chat_stream_response(self, request):
        resp = requests.post(
            f"{self.base_url}/v1/chat/completions",
            json=request.dict(),
            headers=self.headers,
            cookies=self.cookies,
            timeout=self.timeout,
            stream=True,
        )
-        payload = resp.json()
-        if resp.status_code != 200:
-            raise parse_error(resp.status_code, payload)
-        return ChatComplete(**payload)
        # iterate over the SSE stream, yielding each parsed chunk
        for byte_payload in resp.iter_lines():
            if byte_payload == b"\n":
                continue
            payload = byte_payload.decode("utf-8")
            if payload.startswith("data:"):
                json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
                try:
                    response = ChatCompletionChunk(**json_payload)
                    yield response
                except ValidationError:
                    # requests exposes status_code, not status
                    raise parse_error(resp.status_code, json_payload)
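One side effect of the method above is worth spelling out (an editorial note, not part of the commit): `_chat_stream_response` is a generator function, so `chat(..., stream=True)` returns before any HTTP traffic happens. A minimal sketch, assuming a hypothetical local endpoint:

    from text_generation import Client

    client = Client("http://localhost:8080")  # hypothetical endpoint
    stream = client.chat(
        messages=[{"role": "user", "content": "Say hello"}],
        max_tokens=16,
        stream=True,
    )
    # No HTTP request has gone out yet: _chat_stream_response is a generator
    # function, so requests.post only runs once iteration begins.
    first_chunk = next(stream)  # the POST fires here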

    def generate(
        self,
@@ -417,7 +442,7 @@ async def chat(
        top_p: Optional[float] = None,
        tools: Optional[List[Tool]] = None,
        tool_choice: Optional[str] = None,
-    ):
+    ) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:
"""
Given a list of messages, generate a response asynchronously
Expand Down Expand Up @@ -472,6 +497,12 @@ async def chat(
            tools=tools,
            tool_choice=tool_choice,
        )
        if not stream:
            return await self._chat_single_response(request)
        else:
            return self._chat_stream_response(request)

    async def _chat_single_response(self, request):
        async with ClientSession(
            headers=self.headers, cookies=self.cookies, timeout=self.timeout
        ) as session:
@@ -483,6 +514,25 @@
                    raise parse_error(resp.status, payload)
                return ChatComplete(**payload)

    async def _chat_stream_response(self, request):
        async with ClientSession(
            headers=self.headers, cookies=self.cookies, timeout=self.timeout
        ) as session:
            async with session.post(
                f"{self.base_url}/v1/chat/completions", json=request.dict()
            ) as resp:
                async for byte_payload in resp.content:
                    if byte_payload == b"\n":
                        continue
                    payload = byte_payload.decode("utf-8")
                    if payload.startswith("data:"):
                        json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
                        try:
                            response = ChatCompletionChunk(**json_payload)
                            yield response
                        except ValidationError:
                            raise parse_error(resp.status, json_payload)
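For symmetry, a hedged sketch of consuming the async path above; the endpoint URL is again a placeholder, and note that `await client.chat(..., stream=True)` hands back the async generator itself:

    # Illustrative sketch only; endpoint URL is invented.
    import asyncio

    from text_generation import AsyncClient

    async def main():
        client = AsyncClient("http://localhost:8080")  # hypothetical endpoint
        stream = await client.chat(
            messages=[{"role": "user", "content": "Say hello"}],
            max_tokens=16,
            stream=True,
        )
        async for chunk in stream:
            print(chunk.choices[0].delta.content or "", end="")

    asyncio.run(main())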

    async def generate(
        self,
        prompt: str,
34 changes: 34 additions & 0 deletions clients/python/text_generation/types.py
@@ -59,6 +59,40 @@ class ChatCompletionComplete(BaseModel):
    usage: Any


class Function(BaseModel):
    name: Optional[str]
    arguments: str


class ChoiceDeltaToolCall(BaseModel):
    index: int
    id: str
    type: str
    function: Function


class ChoiceDelta(BaseModel):
    role: str
    content: Optional[str]
    tool_calls: Optional[ChoiceDeltaToolCall]


class Choice(BaseModel):
    index: int
    delta: ChoiceDelta
    logprobs: Optional[dict] = None
    finish_reason: Optional[str] = None


class ChatCompletionChunk(BaseModel):
    id: str
    object: str
    created: int
    model: str
    system_fingerprint: str
    choices: List[Choice]


class ChatComplete(BaseModel):
    # Chat completion details
    id: str
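To make the new models concrete, here is an illustration with an invented payload (not captured from a real server) showing how a single SSE `data:` line maps onto `ChatCompletionChunk`:

    import json

    from text_generation.types import ChatCompletionChunk

    # A fabricated SSE line in the shape the models above expect.
    raw = (
        'data: {"id": "", "object": "text_completion", "created": 0, '
        '"model": "tgi", "system_fingerprint": "1.4.2-native", '
        '"choices": [{"index": 0, "delta": {"role": "assistant", '
        '"content": "Hi", "tool_calls": null}, "logprobs": null, '
        '"finish_reason": null}]}'
    )

    chunk = ChatCompletionChunk(**json.loads(raw[len("data:"):]))
    print(chunk.choices[0].delta.content)  # -> Hi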
5 changes: 2 additions & 3 deletions docs/source/guidance.md
@@ -1,8 +1,8 @@
# Guidance

-Text Generation Inference (TGI) now supports the Messages API, which is fully compatible with the OpenAI Chat Completion API. This feature is available starting from version `1.4.3`. You can use OpenAI's client libraries or third-party libraries expecting OpenAI schema to interact with TGI's Messages API. Below are some examples of how to utilize this compatibility.
+Text Generation Inference (TGI) now supports [Grammar and Constraints](#grammar-and-constraints) and [Tools and Functions](#tools-and-functions) to help developers guide the LLM's responses and enhance its capabilities.

-Whether you're a developer, a data scientist, or just a curious mind, we've made it super easy (and fun!) to start integrating advanced text generation capabilities into your applications.
+These features are available starting from version `1.4.3`. They are accessible via the text-generation-client library and are compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them!

### Quick Start

@@ -214,7 +214,6 @@ if __name__ == "__main__":

```


## Tools and Functions 🛠️

### The Tools Parameter
@@ -14,8 +14,8 @@
"name": "tools",
"parameters": {
"format": "celsius",
"location": "San Francisco",
"num_days": 2
"location": "New York, NY",
"num_days": 14
}
},
"id": 0,
@@ -25,14 +25,14 @@
      "usage": null
    }
  ],
-  "created": 1708957016,
+  "created": 1709079417,
  "id": "",
  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "object": "text_completion",
  "system_fingerprint": "1.4.2-native",
  "usage": {
-    "completion_tokens": 36,
-    "prompt_tokens": 313,
-    "total_tokens": 349
+    "completion_tokens": 29,
+    "prompt_tokens": 316,
+    "total_tokens": 345
  }
}
@@ -14,8 +14,8 @@
"name": "tools",
"parameters": {
"format": "celsius",
"location": "San Francisco",
"num_days": 2
"location": "New York, NY",
"num_days": 14
}
},
"id": 0,
@@ -25,14 +25,14 @@
      "usage": null
    }
  ],
-  "created": 1708957016,
+  "created": 1709079492,
  "id": "",
  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "object": "text_completion",
  "system_fingerprint": "1.4.2-native",
  "usage": {
-    "completion_tokens": 36,
-    "prompt_tokens": 313,
-    "total_tokens": 349
+    "completion_tokens": 29,
+    "prompt_tokens": 316,
+    "total_tokens": 345
  }
}
@@ -24,14 +24,14 @@
      "usage": null
    }
  ],
-  "created": 1708957017,
+  "created": 1709079493,
  "id": "",
  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "object": "text_completion",
  "system_fingerprint": "1.4.2-native",
  "usage": {
    "completion_tokens": 21,
-    "prompt_tokens": 184,
-    "total_tokens": 205
+    "prompt_tokens": 187,
+    "total_tokens": 208
  }
}
@@ -0,0 +1,27 @@
{
  "choices": [
    {
      "delta": {
        "content": null,
        "role": "assistant",
        "tool_calls": {
          "function": {
            "arguments": "</s>",
            "name": null
          },
          "id": "",
          "index": 20,
          "type": "function"
        }
      },
      "finish_reason": "eos_token",
      "index": 20,
      "logprobs": null
    }
  ],
  "created": 1709087088,
  "id": "",
  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "object": "text_completion",
  "system_fingerprint": "1.4.2-native"
}
41 changes: 37 additions & 4 deletions integration-tests/models/test_tools_llama.py
@@ -124,8 +124,8 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
"name": "tools",
"parameters": {
"format": "celsius",
"location": "San Francisco",
"num_days": 2,
"location": "New York, NY",
"num_days": 14,
},
},
"id": 0,
@@ -163,8 +163,8 @@ async def test_flash_llama_grammar_tools_auto(
"name": "tools",
"parameters": {
"format": "celsius",
"location": "San Francisco",
"num_days": 2,
"location": "New York, NY",
"num_days": 14,
},
},
"id": 0,
@@ -206,3 +206,36 @@ async def test_flash_llama_grammar_tools_choice(
        },
    }
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_stream(
    flash_llama_grammar_tools, response_snapshot
):
    responses = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        tool_choice="get_current_weather",
        presence_penalty=-1.1,
        messages=[
            {
                "role": "system",
                "content": "Youre a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Paris, France?",
            },
        ],
        stream=True,
    )

    count = 0
    async for response in responses:
        print(response)
        count += 1

    assert count == 20
    assert response == response_snapshot
