Commit
Fix AudioQnA gateway and orchestrator logics (#233)
* add asr/tts component for xeon and hpu

Signed-off-by: Spycsh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix ffmpeg JSONDecode error on HPU

* add tests

* trigger

* try

* add gateway

* import

* add asr check

* fix

* lower max_token

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Spycsh <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Spycsh and pre-commit-ci[bot] authored Jun 21, 2024
1 parent 9001783 commit 48ed5d8
Showing 4 changed files with 66 additions and 0 deletions.
1 change: 1 addition & 0 deletions comps/__init__.py
@@ -34,6 +34,7 @@
    CodeTransGateway,
    DocSumGateway,
    TranslationGateway,
    AudioQnAGateway,
)

# Telemetry
35 changes: 35 additions & 0 deletions comps/cores/mega/gateway.py
@@ -5,6 +5,7 @@
from fastapi.responses import StreamingResponse

from ..proto.api_protocol import (
    AudioChatCompletionRequest,
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
@@ -315,3 +316,37 @@ async def handle_request(self, request: Request):
            )
        )
        return ChatCompletionResponse(model="docsum", choices=choices, usage=usage)


class AudioQnAGateway(Gateway):
    def __init__(self, megaservice, host="0.0.0.0", port=8888):
        super().__init__(
            megaservice,
            host,
            port,
            str(MegaServiceEndpoint.AUDIO_QNA),
            AudioChatCompletionRequest,
            ChatCompletionResponse,
        )

    async def handle_request(self, request: Request):
        data = await request.json()

        chat_request = AudioChatCompletionRequest.parse_obj(data)
        parameters = LLMParams(
            # relatively lower max_tokens for audio conversation
            max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 128,
            top_k=chat_request.top_k if chat_request.top_k else 10,
            top_p=chat_request.top_p if chat_request.top_p else 0.95,
            temperature=chat_request.temperature if chat_request.temperature else 0.01,
            repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
            streaming=False,  # TODO add streaming LLM output as input to TTS
        )
        result_dict = await self.megaservice.schedule(
            initial_inputs={"byte_str": chat_request.audio}, llm_parameters=parameters
        )

        last_node = self.megaservice.all_leaves()[-1]
        response = result_dict[last_node]["byte_str"]

        return response
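For illustration only, here is a minimal sketch of how this gateway might be wired onto an asr -> llm -> tts megaservice. The service names, hosts, ports, and endpoint paths below are hypothetical placeholders and are not part of this commit; the MicroService and ServiceOrchestrator calls follow the usual comps wiring pattern.

from comps import AudioQnAGateway, MicroService, ServiceOrchestrator, ServiceType

# Hypothetical remote microservices that already serve ASR, LLM, and TTS.
asr = MicroService(
    name="asr",
    host="0.0.0.0",
    port=9099,
    endpoint="/v1/audio/transcriptions",
    use_remote_service=True,
    service_type=ServiceType.ASR,
)
llm = MicroService(
    name="llm",
    host="0.0.0.0",
    port=9000,
    endpoint="/v1/chat/completions",
    use_remote_service=True,
    service_type=ServiceType.LLM,
)
tts = MicroService(
    name="tts",
    host="0.0.0.0",
    port=9088,
    endpoint="/v1/audio/speech",
    use_remote_service=True,
    service_type=ServiceType.TTS,
)

megaservice = ServiceOrchestrator()
megaservice.add(asr).add(llm).add(tts)
megaservice.flow_to(asr, llm)
megaservice.flow_to(llm, tts)

# handle_request above then feeds the request's audio string in as byte_str and
# returns the byte_str produced by the last leaf node (the TTS output).
gateway = AudioQnAGateway(megaservice=megaservice, host="0.0.0.0", port=3008)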
8 changes: 8 additions & 0 deletions comps/cores/mega/orchestrator.py
@@ -79,6 +79,7 @@ async def execute(
        for field, value in llm_parameters_dict.items():
            if inputs.get(field) != value:
                inputs[field] = value

        if self.services[cur_node].service_type == ServiceType.LLM and llm_parameters.streaming:
            # Still leave to sync requests.post for StreamingResponse
            response = requests.post(
@@ -93,6 +94,13 @@ def generate():

            return StreamingResponse(generate(), media_type="text/event-stream"), cur_node
        else:
            if (
                self.services[cur_node].service_type == ServiceType.LLM
                and self.predecessors(cur_node)
                and "asr" in self.predecessors(cur_node)[0]
            ):
                inputs["query"] = inputs["text"]
                del inputs["text"]
            async with session.post(endpoint, json=inputs) as response:
                print(response.status)
                return await response.json(), cur_node
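The new else-branch check bridges the ASR and LLM payload schemas: an ASR predecessor replies with a text field, while the LLM microservice expects query, so the orchestrator renames the key before posting. A standalone sketch of the same remapping, using a made-up predecessor node name:

# ASR output flowing into an LLM node: rename "text" -> "query" before the POST.
inputs = {"text": "what is deep learning?", "max_new_tokens": 128}
predecessors = ["asr/MicroService"]  # hypothetical node name containing "asr"

if predecessors and "asr" in predecessors[0]:
    inputs["query"] = inputs.pop("text")

print(inputs)  # {'query': 'what is deep learning?', 'max_new_tokens': 128}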
22 changes: 22 additions & 0 deletions comps/cores/proto/api_protocol.py
@@ -49,6 +49,28 @@ class ChatCompletionRequest(BaseModel):
    user: Optional[str] = None


class AudioChatCompletionRequest(BaseModel):
    audio: str
    messages: Optional[
        Union[
            str,
            List[Dict[str, str]],
            List[Dict[str, Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]]],
        ]
    ] = None
    model: Optional[str] = "Intel/neural-chat-7b-v3-3"
    temperature: Optional[float] = 0.01
    top_p: Optional[float] = 0.95
    top_k: Optional[int] = 10
    n: Optional[int] = 1
    max_tokens: Optional[int] = 1024
    stop: Optional[Union[str, List[str]]] = None
    stream: Optional[bool] = False
    presence_penalty: Optional[float] = 1.03
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None


class ChatMessage(BaseModel):
    role: str
    content: str
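For reference, a hedged example of how a client might call the assembled AudioQnA megaservice with this request schema. The URL, port, and the assumption that audio carries base64-encoded WAV bytes are inferred from the byte_str handling in the gateway above, not defined by this commit:

import base64

import requests

# Encode a local recording; the gateway forwards it to the ASR node as byte_str.
with open("sample.wav", "rb") as f:
    audio_b64 = base64.b64encode(f.read()).decode("utf-8")

# Unset fields fall back to the defaults declared in AudioChatCompletionRequest.
payload = {"audio": audio_b64, "max_tokens": 64}
resp = requests.post("http://localhost:3008/v1/audioqna", json=payload)

# The body is whatever byte_str the last leaf node returned, typically TTS audio.
print(resp.json())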