Commit 5999fb0: test

andrewtruong committed Dec 12, 2024 (1 parent: 435ca80)
Showing 2 changed files with 58 additions and 29 deletions.

weave/integrations/groq/groq_sdk.py (77 changes: 53 additions & 24 deletions)
@@ -1,17 +1,23 @@
 from __future__ import annotations

 import importlib
-from typing import TYPE_CHECKING, Callable, Optional
+from typing import TYPE_CHECKING, Callable

+import weave
+from weave.trace.autopatch import IntegrationSettings, OpSettings
+from weave.trace.op_extensions.accumulator import add_accumulator
+from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher

 if TYPE_CHECKING:
     from groq.types.chat import ChatCompletion, ChatCompletionChunk

-import weave
-from weave.trace.op_extensions.accumulator import add_accumulator
-from weave.trace.patcher import MultiPatcher, SymbolPatcher

+_groq_patcher: MultiPatcher | None = None


 def groq_accumulator(
-    acc: Optional["ChatCompletion"], value: "ChatCompletionChunk"
-) -> "ChatCompletion":
+    acc: ChatCompletion | None, value: ChatCompletionChunk
+) -> ChatCompletion:
     from groq.types.chat import ChatCompletion, ChatCompletionMessage
     from groq.types.chat.chat_completion import Choice
     from groq.types.chat.chat_completion_chunk import Choice as ChoiceChunk
@@ -83,11 +89,10 @@ def should_use_accumulator(inputs: dict) -> bool:
     return isinstance(inputs, dict) and bool(inputs.get("stream"))


-def groq_wrapper(name: str) -> Callable[[Callable], Callable]:
+def groq_wrapper(settings: OpSettings) -> Callable[[Callable], Callable]:
     def wrapper(fn: Callable) -> Callable:
-        op = weave.op()(fn)
-        op.name = name  # type: ignore
-        # return op
+        op_kwargs = settings.model_dump()
+        op = weave.op(fn, **op_kwargs)
         return add_accumulator(
             op,  # type: ignore
             make_accumulator=lambda inputs: groq_accumulator,
@@ -97,17 +102,41 @@ def wrapper(fn: Callable) -> Callable:
     return wrapper


-groq_patcher = MultiPatcher(
-    [
-        SymbolPatcher(
-            lambda: importlib.import_module("groq.resources.chat.completions"),
-            "Completions.create",
-            groq_wrapper(name="groq.chat.completions.create"),
-        ),
-        SymbolPatcher(
-            lambda: importlib.import_module("groq.resources.chat.completions"),
-            "AsyncCompletions.create",
-            groq_wrapper(name="groq.async.chat.completions.create"),
-        ),
-    ]
-)
+def get_groq_patcher(
+    settings: IntegrationSettings | None = None,
+) -> MultiPatcher | NoOpPatcher:
+    if settings is None:
+        settings = IntegrationSettings()
+
+    if not settings.enabled:
+        return NoOpPatcher()
+
+    global _groq_patcher
+    if _groq_patcher is not None:
+        return _groq_patcher
+
+    base = settings.op_settings
+
+    chat_completions_settings = base.model_copy(
+        update={"name": base.name or "groq.chat.completions.create"}
+    )
+    async_chat_completions_settings = base.model_copy(
+        update={"name": base.name or "groq.async.chat.completions.create"}
+    )
+
+    _groq_patcher = MultiPatcher(
+        [
+            SymbolPatcher(
+                lambda: importlib.import_module("groq.resources.chat.completions"),
+                "Completions.create",
+                groq_wrapper(chat_completions_settings),
+            ),
+            SymbolPatcher(
+                lambda: importlib.import_module("groq.resources.chat.completions"),
+                "AsyncCompletions.create",
+                groq_wrapper(async_chat_completions_settings),
+            ),
+        ]
+    )
+
+    return _groq_patcher
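
For review context, here is a minimal sketch of how the new factory is expected to be driven, using only names visible in this diff (IntegrationSettings.enabled, IntegrationSettings.op_settings, OpSettings.name, attempt_patch/undo_patch). Constructing the settings objects by keyword is an assumption about their pydantic definitions, not something shown above.

# Hypothetical usage sketch; keyword construction of IntegrationSettings/OpSettings
# is assumed from their pydantic-style usage in the diff above.
from weave.integrations.groq.groq_sdk import get_groq_patcher
from weave.trace.autopatch import IntegrationSettings, OpSettings

# Disabled integration: the factory short-circuits to a NoOpPatcher.
noop = get_groq_patcher(IntegrationSettings(enabled=False))
noop.attempt_patch()  # no symbols are patched

# Enabled integration: op settings flow into weave.op via settings.model_dump(),
# and the patcher wraps Completions.create and AsyncCompletions.create.
patcher = get_groq_patcher(
    IntegrationSettings(op_settings=OpSettings(name="my.groq.create"))
)
patcher.attempt_patch()
patcher.undo_patch()

Note that both SymbolPatcher entries fall back to base.name, so an explicit op_settings.name would apply to both the sync and async create ops.
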

weave/trace/autopatch.py (10 changes: 5 additions & 5 deletions)
@@ -39,7 +39,7 @@ class AutopatchSettings(BaseModel):
     cohere: IntegrationSettings = Field(default_factory=IntegrationSettings)
     dspy: IntegrationSettings = Field(default_factory=IntegrationSettings)
     google_ai_studio: IntegrationSettings = Field(default_factory=IntegrationSettings)
-    # groq: IntegrationSettings = Field(default_factory=IntegrationSettings)
+    groq: IntegrationSettings = Field(default_factory=IntegrationSettings)
     # instructor: IntegrationSettings = Field(default_factory=IntegrationSettings)
     # langchain: IntegrationSettings = Field(default_factory=IntegrationSettings)
     # litellm: IntegrationSettings = Field(default_factory=IntegrationSettings)
@@ -59,7 +59,7 @@ def autopatch(settings: Optional[AutopatchSettings] = None) -> None:
     from weave.integrations.google_ai_studio.google_ai_studio_sdk import (
         get_google_genai_patcher,
     )
-    from weave.integrations.groq.groq_sdk import groq_patcher
+    from weave.integrations.groq.groq_sdk import get_groq_patcher
     from weave.integrations.instructor.instructor_sdk import instructor_patcher
     from weave.integrations.langchain.langchain import langchain_patcher
     from weave.integrations.litellm.litellm import litellm_patcher
@@ -78,7 +78,7 @@ def autopatch(settings: Optional[AutopatchSettings] = None) -> None:
     llamaindex_patcher.attempt_patch()
     langchain_patcher.attempt_patch()
     get_anthropic_patcher(settings.anthropic).attempt_patch()
-    groq_patcher.attempt_patch()
+    get_groq_patcher(settings.groq).attempt_patch()
     instructor_patcher.attempt_patch()
     get_dspy_patcher(settings.dspy).attempt_patch()
     get_cerebras_patcher(settings.cerebras).attempt_patch()
@@ -96,7 +96,7 @@ def reset_autopatch() -> None:
     from weave.integrations.google_ai_studio.google_ai_studio_sdk import (
         get_google_genai_patcher,
     )
-    from weave.integrations.groq.groq_sdk import groq_patcher
+    from weave.integrations.groq.groq_sdk import get_groq_patcher
     from weave.integrations.instructor.instructor_sdk import instructor_patcher
     from weave.integrations.langchain.langchain import langchain_patcher
     from weave.integrations.litellm.litellm import litellm_patcher
@@ -112,7 +112,7 @@ def reset_autopatch() -> None:
     llamaindex_patcher.undo_patch()
     langchain_patcher.undo_patch()
     get_anthropic_patcher().undo_patch()
-    groq_patcher.undo_patch()
+    get_groq_patcher().undo_patch()
     instructor_patcher.undo_patch()
     get_dspy_patcher().undo_patch()
     get_cerebras_patcher().undo_patch()
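
For completeness, a short sketch of how the newly exposed groq field could be exercised through this module's own entry points. Everything here is inferred from the signatures in this diff, and the keyword construction of the settings models is again an assumption rather than documented API.

# Hypothetical usage sketch of the autopatch entry points with the new groq field.
from weave.trace.autopatch import (
    AutopatchSettings,
    IntegrationSettings,
    OpSettings,
    autopatch,
    reset_autopatch,
)

# Patch all integrations, giving the Groq ops an explicit op name.
autopatch(
    AutopatchSettings(
        groq=IntegrationSettings(
            op_settings=OpSettings(name="groq.chat.completions.create"),
        )
    )
)

# Undo everything, then re-patch with the Groq integration disabled;
# get_groq_patcher(settings.groq) returns a NoOpPatcher in that case.
reset_autopatch()
autopatch(AutopatchSettings(groq=IntegrationSettings(enabled=False)))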
