From 5999fb08e8aa12d6d6a8ef2d7545660fb238b6b4 Mon Sep 17 00:00:00 2001
From: Andrew Truong
Date: Thu, 12 Dec 2024 00:41:20 -0500
Subject: [PATCH] test

---
Notes (text in this section is not part of the commit message):

* groq_sdk.py: the module-level `groq_patcher` object is replaced by
  `get_groq_patcher(settings)`. It returns a `NoOpPatcher` when
  `settings.enabled` is false; otherwise it builds the `MultiPatcher` once
  and caches it in the module global `_groq_patcher`.
* Once `_groq_patcher` is set, later calls return the cached patcher as-is,
  so settings passed on those later calls are not applied to it.
* `groq_wrapper` now takes an `OpSettings` and forwards
  `settings.model_dump()` as keyword arguments to `weave.op`. The op names
  default to "groq.chat.completions.create" and
  "groq.async.chat.completions.create"; a non-empty `op_settings.name`
  replaces both.
* autopatch.py: enables the `groq` field on `AutopatchSettings` and switches
  `autopatch()` / `reset_autopatch()` over to `get_groq_patcher`.

 weave/integrations/groq/groq_sdk.py | 77 ++++++++++++++++++++---------
 weave/trace/autopatch.py            | 10 ++--
 2 files changed, 58 insertions(+), 29 deletions(-)

diff --git a/weave/integrations/groq/groq_sdk.py b/weave/integrations/groq/groq_sdk.py
index 4f470e6d743..c5c07fd705f 100644
--- a/weave/integrations/groq/groq_sdk.py
+++ b/weave/integrations/groq/groq_sdk.py
@@ -1,17 +1,23 @@
+from __future__ import annotations
+
 import importlib
-from typing import TYPE_CHECKING, Callable, Optional
+from typing import TYPE_CHECKING, Callable
+
+import weave
+from weave.trace.autopatch import IntegrationSettings, OpSettings
+from weave.trace.op_extensions.accumulator import add_accumulator
+from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
 
 if TYPE_CHECKING:
     from groq.types.chat import ChatCompletion, ChatCompletionChunk
 
-import weave
-from weave.trace.op_extensions.accumulator import add_accumulator
-from weave.trace.patcher import MultiPatcher, SymbolPatcher
+
+_groq_patcher: MultiPatcher | None = None
 
 
 def groq_accumulator(
-    acc: Optional["ChatCompletion"], value: "ChatCompletionChunk"
-) -> "ChatCompletion":
+    acc: ChatCompletion | None, value: ChatCompletionChunk
+) -> ChatCompletion:
     from groq.types.chat import ChatCompletion, ChatCompletionMessage
     from groq.types.chat.chat_completion import Choice
     from groq.types.chat.chat_completion_chunk import Choice as ChoiceChunk
@@ -83,11 +89,10 @@ def should_use_accumulator(inputs: dict) -> bool:
     return isinstance(inputs, dict) and bool(inputs.get("stream"))
 
 
-def groq_wrapper(name: str) -> Callable[[Callable], Callable]:
+def groq_wrapper(settings: OpSettings) -> Callable[[Callable], Callable]:
     def wrapper(fn: Callable) -> Callable:
-        op = weave.op()(fn)
-        op.name = name  # type: ignore
-        # return op
+        op_kwargs = settings.model_dump()
+        op = weave.op(fn, **op_kwargs)
         return add_accumulator(
             op,  # type: ignore
             make_accumulator=lambda inputs: groq_accumulator,
@@ -97,17 +102,41 @@ def wrapper(fn: Callable) -> Callable:
     return wrapper
 
 
-groq_patcher = MultiPatcher(
-    [
-        SymbolPatcher(
-            lambda: importlib.import_module("groq.resources.chat.completions"),
-            "Completions.create",
-            groq_wrapper(name="groq.chat.completions.create"),
-        ),
-        SymbolPatcher(
-            lambda: importlib.import_module("groq.resources.chat.completions"),
-            "AsyncCompletions.create",
-            groq_wrapper(name="groq.async.chat.completions.create"),
-        ),
-    ]
-)
+def get_groq_patcher(
+    settings: IntegrationSettings | None = None,
+) -> MultiPatcher | NoOpPatcher:
+    if settings is None:
+        settings = IntegrationSettings()
+
+    if not settings.enabled:
+        return NoOpPatcher()
+
+    global _groq_patcher
+    if _groq_patcher is not None:
+        return _groq_patcher
+
+    base = settings.op_settings
+
+    chat_completions_settings = base.model_copy(
+        update={"name": base.name or "groq.chat.completions.create"}
+    )
+    async_chat_completions_settings = base.model_copy(
+        update={"name": base.name or "groq.async.chat.completions.create"}
+    )
+
+    _groq_patcher = MultiPatcher(
+        [
+            SymbolPatcher(
+                lambda: importlib.import_module("groq.resources.chat.completions"),
+                "Completions.create",
+                groq_wrapper(chat_completions_settings),
+            ),
+            SymbolPatcher(
+                lambda: importlib.import_module("groq.resources.chat.completions"),
+                "AsyncCompletions.create",
+                groq_wrapper(async_chat_completions_settings),
+            ),
+        ]
+    )
+
+    return _groq_patcher
diff --git a/weave/trace/autopatch.py b/weave/trace/autopatch.py
index d7e8e041931..dc8fbf369c1 100644
--- a/weave/trace/autopatch.py
+++ b/weave/trace/autopatch.py
@@ -39,7 +39,7 @@ class AutopatchSettings(BaseModel):
     cohere: IntegrationSettings = Field(default_factory=IntegrationSettings)
     dspy: IntegrationSettings = Field(default_factory=IntegrationSettings)
     google_ai_studio: IntegrationSettings = Field(default_factory=IntegrationSettings)
-    # groq: IntegrationSettings = Field(default_factory=IntegrationSettings)
+    groq: IntegrationSettings = Field(default_factory=IntegrationSettings)
     # instructor: IntegrationSettings = Field(default_factory=IntegrationSettings)
     # langchain: IntegrationSettings = Field(default_factory=IntegrationSettings)
     # litellm: IntegrationSettings = Field(default_factory=IntegrationSettings)
@@ -59,7 +59,7 @@ def autopatch(settings: Optional[AutopatchSettings] = None) -> None:
     from weave.integrations.google_ai_studio.google_ai_studio_sdk import (
         get_google_genai_patcher,
     )
-    from weave.integrations.groq.groq_sdk import groq_patcher
+    from weave.integrations.groq.groq_sdk import get_groq_patcher
     from weave.integrations.instructor.instructor_sdk import instructor_patcher
     from weave.integrations.langchain.langchain import langchain_patcher
     from weave.integrations.litellm.litellm import litellm_patcher
@@ -78,7 +78,7 @@ def autopatch(settings: Optional[AutopatchSettings] = None) -> None:
     llamaindex_patcher.attempt_patch()
     langchain_patcher.attempt_patch()
     get_anthropic_patcher(settings.anthropic).attempt_patch()
-    groq_patcher.attempt_patch()
+    get_groq_patcher(settings.groq).attempt_patch()
     instructor_patcher.attempt_patch()
     get_dspy_patcher(settings.dspy).attempt_patch()
     get_cerebras_patcher(settings.cerebras).attempt_patch()
@@ -96,7 +96,7 @@ def reset_autopatch() -> None:
     from weave.integrations.google_ai_studio.google_ai_studio_sdk import (
         get_google_genai_patcher,
     )
-    from weave.integrations.groq.groq_sdk import groq_patcher
+    from weave.integrations.groq.groq_sdk import get_groq_patcher
     from weave.integrations.instructor.instructor_sdk import instructor_patcher
     from weave.integrations.langchain.langchain import langchain_patcher
     from weave.integrations.litellm.litellm import litellm_patcher
@@ -112,7 +112,7 @@ def reset_autopatch() -> None:
     llamaindex_patcher.undo_patch()
     langchain_patcher.undo_patch()
     get_anthropic_patcher().undo_patch()
-    groq_patcher.undo_patch()
+    get_groq_patcher().undo_patch()
     instructor_patcher.undo_patch()
     get_dspy_patcher().undo_patch()
     get_cerebras_patcher().undo_patch()