Skip to content

Commit

Permalink
Merge pull request #34 from jepler/mistral
Browse files Browse the repository at this point in the history
Add mistral.ai backend
  • Loading branch information
jepler authored Mar 10, 2024
2 parents fac2dfb + ddc5214 commit 3682753
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 11 deletions.
22 changes: 11 additions & 11 deletions src/chap/backends/llama_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ class Parameters:
url: str = "http://localhost:8080/completion"
"""The URL of a llama.cpp server's completion endpoint."""

start_prompt: str = """<s>[INST] <<SYS>>\n"""
after_system: str = "\n<</SYS>>\n\n"
after_user: str = """ [/INST] """
after_assistant: str = """ </s><s>[INST] """
start_prompt: str = "<s>"
system_format: str = "<<SYS>>{}<</SYS>>"
user_format: str = " [INST] {} [/INST]"
assistant_format: str = " {}</s>"

def __init__(self) -> None:
super().__init__()
Expand All @@ -34,18 +34,18 @@ def __init__(self) -> None:
def make_full_query(self, messages: Session, max_query_size: int) -> str:
del messages[1:-max_query_size]
result = [self.parameters.start_prompt]
formats = {
Role.SYSTEM: self.parameters.system_format,
Role.USER: self.parameters.user_format,
Role.ASSISTANT: self.parameters.assistant_format,
}
for m in messages:
content = (m.content or "").strip()
if not content:
continue
result.append(content)
if m.role == Role.SYSTEM:
result.append(self.parameters.after_system)
elif m.role == Role.ASSISTANT:
result.append(self.parameters.after_assistant)
elif m.role == Role.USER:
result.append(self.parameters.after_user)
result.append(formats[m.role].format(content))
full_query = "".join(result)
print("fq", full_query)
return full_query

async def aask(
Expand Down
99 changes: 99 additions & 0 deletions src/chap/backends/mistral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# SPDX-FileCopyrightText: 2024 Jeff Epler <[email protected]>
#
# SPDX-License-Identifier: MIT

import json
from dataclasses import dataclass
from typing import AsyncGenerator, Any

import httpx

from ..core import AutoAskMixin
from ..key import get_key
from ..session import Assistant, Session, User


class Mistral(AutoAskMixin):
@dataclass
class Parameters:
url: str = "https://api.mistral.ai"
model: str = "open-mistral-7b"
max_new_tokens: int = 1000

def __init__(self) -> None:
super().__init__()
self.parameters = self.Parameters()

system_message = """\
Answer each question accurately and thoroughly.
"""

def make_full_query(self, messages: Session, max_query_size: int) -> dict[str, Any]:
messages = [m for m in messages if m.content]
del messages[1:-max_query_size]
result = dict(
model=self.parameters.model,
max_tokens=self.parameters.max_new_tokens,
messages=[dict(role=str(m.role), content=m.content) for m in messages],
stream=True,
)
return result

async def aask(
self,
session: Session,
query: str,
*,
max_query_size: int = 5,
timeout: float = 180,
) -> AsyncGenerator[str, None]:
new_content: list[str] = []
params = self.make_full_query(session + [User(query)], max_query_size)
try:
async with httpx.AsyncClient(timeout=timeout) as client:
async with client.stream(
"POST",
f"{self.parameters.url}/v1/chat/completions",
json=params,
headers={
"Authorization": f"Bearer {self.get_key()}",
"content-type": "application/json",
"accept": "application/json",
"model": "application/json",
},
) as response:
if response.status_code == 200:
async for line in response.aiter_lines():
if line.startswith("data:"):
data = line.removeprefix("data:").strip()
if data == "[DONE]":
break
j = json.loads(data)
content = (
j.get("choices", [{}])[0]
.get("delta", {})
.get("content", "")
)
if content:
new_content.append(content)
yield content
else:
content = f"\nFailed with {response=!r}"
new_content.append(content)
yield content
async for line in response.aiter_lines():
new_content.append(line)
yield line
except httpx.HTTPError as e:
content = f"\nException: {e!r}"
new_content.append(content)
yield content

session.extend([User(query), Assistant("".join(new_content))])

@classmethod
def get_key(cls) -> str:
return get_key("mistral_api_key")


factory = Mistral

0 comments on commit 3682753

Please sign in to comment.