Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewtruong committed Dec 10, 2024
1 parent 9f641c9 commit 98ab709
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 36 deletions.
4 changes: 2 additions & 2 deletions weave/integrations/mistral/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
mistral_version = "1.0" # we need to return a patching function

if version.parse(mistral_version) < version.parse("1.0.0"):
from .v0.mistral import mistral_patcher
from .v0.mistral import get_mistral_patcher

print(
f"Using MistralAI version {mistral_version}. Please consider upgrading to version 1.0.0 or later."
)
else:
from .v1.mistral import mistral_patcher # noqa: F401
from .v1.mistral import get_mistral_patcher # noqa: F401
107 changes: 73 additions & 34 deletions weave/integrations/mistral/v0/mistral.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import dataclasses
import importlib
from typing import TYPE_CHECKING, Callable, Optional

import weave
from weave.trace.autopatch import IntegrationSettings, OpSettings
from weave.trace.op_extensions.accumulator import add_accumulator
from weave.trace.patcher import MultiPatcher, SymbolPatcher

Expand All @@ -11,6 +13,8 @@
ChatCompletionStreamResponse,
)

_mistral_patcher: Optional[MultiPatcher] = None


def mistral_accumulator(
acc: Optional["ChatCompletionResponse"],
Expand Down Expand Up @@ -72,37 +76,72 @@ def mistral_accumulator(
return acc


def mistral_stream_wrapper(fn: Callable) -> Callable:
op = weave.op()(fn)
acc_op = add_accumulator(op, lambda inputs: mistral_accumulator) # type: ignore
return acc_op


mistral_patcher = MultiPatcher(
[
# Patch the sync, non-streaming chat method
SymbolPatcher(
lambda: importlib.import_module("mistralai.client"),
"MistralClient.chat",
weave.op(),
),
# Patch the sync, streaming chat method
SymbolPatcher(
lambda: importlib.import_module("mistralai.client"),
"MistralClient.chat_stream",
mistral_stream_wrapper,
),
# Patch the async, non-streaming chat method
SymbolPatcher(
lambda: importlib.import_module("mistralai.async_client"),
"MistralAsyncClient.chat",
weave.op(),
),
# Patch the async, streaming chat method
SymbolPatcher(
lambda: importlib.import_module("mistralai.async_client"),
"MistralAsyncClient.chat_stream",
mistral_stream_wrapper,
),
]
)
def mistral_stream_wrapper(settings: OpSettings) -> Callable:
def wrapper(fn: Callable) -> Callable:
op_kwargs = dataclasses.asdict(settings)
op = weave.op(fn, **op_kwargs)
acc_op = add_accumulator(op, lambda inputs: mistral_accumulator) # type: ignore
return acc_op

return wrapper


def mistral_wrapper(settings: OpSettings) -> Callable:
def wrapper(fn: Callable) -> Callable:
op_kwargs = dataclasses.asdict(settings)
op = weave.op(fn, **op_kwargs)
return op

return wrapper


def get_mistral_patcher(settings: Optional[IntegrationSettings] = None) -> MultiPatcher:
global _mistral_patcher

if _mistral_patcher is not None:
return _mistral_patcher

if settings is None:
settings = IntegrationSettings()

base = settings.op_settings

chat_complete_settings = dataclasses.replace(
base,
name=base.name or "mistralai.chat.complete",
)
chat_stream_settings = dataclasses.replace(
base,
name=base.name or "mistralai.chat.stream",
)

_mistral_patcher = MultiPatcher(
[
# Patch the sync, non-streaming chat method
SymbolPatcher(
lambda: importlib.import_module("mistralai.client"),
"MistralClient.chat",
mistral_wrapper(chat_complete_settings),
),
# Patch the sync, streaming chat method
SymbolPatcher(
lambda: importlib.import_module("mistralai.client"),
"MistralClient.chat_stream",
mistral_stream_wrapper(chat_stream_settings),
),
# Patch the async, non-streaming chat method
SymbolPatcher(
lambda: importlib.import_module("mistralai.async_client"),
"MistralAsyncClient.chat",
mistral_wrapper(chat_complete_settings),
),
# Patch the async, streaming chat method
SymbolPatcher(
lambda: importlib.import_module("mistralai.async_client"),
"MistralAsyncClient.chat_stream",
mistral_stream_wrapper(chat_stream_settings),
),
]
)

return _mistral_patcher

0 comments on commit 98ab709

Please sign in to comment.