From f4d6c6d9c0a9001fea5dade4f8f1f5b5e19c79f2 Mon Sep 17 00:00:00 2001
From: Andrew Truong
Date: Thu, 12 Dec 2024 00:49:25 -0500
Subject: [PATCH] test

---
 weave/integrations/litellm/litellm.py | 73 +++++++++++++++++++--------
 weave/trace/autopatch.py              | 20 ++++----
 2 files changed, 62 insertions(+), 31 deletions(-)

diff --git a/weave/integrations/litellm/litellm.py b/weave/integrations/litellm/litellm.py
index c3bf1bf114a..9ae6e492c84 100644
--- a/weave/integrations/litellm/litellm.py
+++ b/weave/integrations/litellm/litellm.py
@@ -1,19 +1,24 @@
+from __future__ import annotations
+
 import importlib
-from typing import TYPE_CHECKING, Any, Callable, Optional
+from typing import TYPE_CHECKING, Any, Callable
 
 import weave
+from weave.trace.autopatch import IntegrationSettings, OpSettings
 from weave.trace.op_extensions.accumulator import add_accumulator
-from weave.trace.patcher import MultiPatcher, SymbolPatcher
+from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
 
 if TYPE_CHECKING:
     from litellm.utils import ModelResponse
 
+_litellm_patcher: MultiPatcher | None = None
+
 
 # This accumulator is nearly identical to the mistral accumulator, just with different types.
 def litellm_accumulator(
-    acc: Optional["ModelResponse"],
-    value: "ModelResponse",
-) -> "ModelResponse":
+    acc: ModelResponse | None,
+    value: ModelResponse,
+) -> ModelResponse:
     # This import should be safe at this point
     from litellm.utils import Choices, Message, ModelResponse, Usage
 
@@ -82,10 +87,10 @@ def should_use_accumulator(inputs: dict) -> bool:
     return isinstance(inputs, dict) and bool(inputs.get("stream"))
 
 
-def make_wrapper(name: str) -> Callable:
+def make_wrapper(settings: OpSettings) -> Callable:
     def litellm_wrapper(fn: Callable) -> Callable:
-        op = weave.op()(fn)
-        op.name = name  # type: ignore
+        op_kwargs = settings.model_dump()
+        op = weave.op(fn, **op_kwargs)
         return add_accumulator(
             op,  # type: ignore
             make_accumulator=lambda inputs: litellm_accumulator,
@@ -96,17 +101,41 @@ def litellm_wrapper(fn: Callable) -> Callable:
     return litellm_wrapper
 
 
-litellm_patcher = MultiPatcher(
-    [
-        SymbolPatcher(
-            lambda: importlib.import_module("litellm"),
-            "completion",
-            make_wrapper("litellm.completion"),
-        ),
-        SymbolPatcher(
-            lambda: importlib.import_module("litellm"),
-            "acompletion",
-            make_wrapper("litellm.acompletion"),
-        ),
-    ]
-)
+def get_litellm_patcher(
+    settings: IntegrationSettings | None = None,
+) -> MultiPatcher | NoOpPatcher:
+    if settings is None:
+        settings = IntegrationSettings()
+
+    if not settings.enabled:
+        return NoOpPatcher()
+
+    global _litellm_patcher
+    if _litellm_patcher is not None:
+        return _litellm_patcher
+
+    base = settings.op_settings
+
+    completion_settings = base.model_copy(
+        update={"name": base.name or "litellm.completion"}
+    )
+    acompletion_settings = base.model_copy(
+        update={"name": base.name or "litellm.acompletion"}
+    )
+
+    _litellm_patcher = MultiPatcher(
+        [
+            SymbolPatcher(
+                lambda: importlib.import_module("litellm"),
+                "completion",
+                make_wrapper(completion_settings),
+            ),
+            SymbolPatcher(
+                lambda: importlib.import_module("litellm"),
+                "acompletion",
+                make_wrapper(acompletion_settings),
+            ),
+        ]
+    )
+
+    return _litellm_patcher
diff --git a/weave/trace/autopatch.py b/weave/trace/autopatch.py
index e24c0a36c81..16a365ea02c 100644
--- a/weave/trace/autopatch.py
+++ b/weave/trace/autopatch.py
@@ -41,7 +41,7 @@ class AutopatchSettings(BaseModel):
     google_ai_studio: IntegrationSettings = Field(default_factory=IntegrationSettings)
     groq: IntegrationSettings = Field(default_factory=IntegrationSettings)
     # instructor: IntegrationSettings = Field(default_factory=IntegrationSettings)
-    # litellm: IntegrationSettings = Field(default_factory=IntegrationSettings)
+    litellm: IntegrationSettings = Field(default_factory=IntegrationSettings)
     mistral: IntegrationSettings = Field(default_factory=IntegrationSettings)
     notdiamond: IntegrationSettings = Field(default_factory=IntegrationSettings)
     openai: IntegrationSettings = Field(default_factory=IntegrationSettings)
@@ -60,7 +60,7 @@ def autopatch(settings: Optional[AutopatchSettings] = None) -> None:
     from weave.integrations.groq.groq_sdk import get_groq_patcher
     from weave.integrations.instructor.instructor_sdk import instructor_patcher
     from weave.integrations.langchain.langchain import langchain_patcher
-    from weave.integrations.litellm.litellm import litellm_patcher
+    from weave.integrations.litellm.litellm import get_litellm_patcher
     from weave.integrations.llamaindex.llamaindex import llamaindex_patcher
     from weave.integrations.mistral import get_mistral_patcher
     from weave.integrations.notdiamond.tracing import get_notdiamond_patcher
@@ -72,9 +72,7 @@ def autopatch(settings: Optional[AutopatchSettings] = None) -> None:
 
     get_openai_patcher(settings.openai).attempt_patch()
     get_mistral_patcher(settings.mistral).attempt_patch()
-    litellm_patcher.attempt_patch()
-    llamaindex_patcher.attempt_patch()
-    langchain_patcher.attempt_patch()
+    get_litellm_patcher(settings.litellm).attempt_patch()
     get_anthropic_patcher(settings.anthropic).attempt_patch()
     get_groq_patcher(settings.groq).attempt_patch()
     instructor_patcher.attempt_patch()
@@ -85,6 +83,9 @@ def autopatch(settings: Optional[AutopatchSettings] = None) -> None:
     get_notdiamond_patcher(settings.notdiamond).attempt_patch()
     get_vertexai_patcher(settings.vertexai).attempt_patch()
 
+    llamaindex_patcher.attempt_patch()
+    langchain_patcher.attempt_patch()
+
 
 def reset_autopatch() -> None:
     from weave.integrations.anthropic.anthropic_sdk import get_anthropic_patcher
@@ -97,7 +98,7 @@ def reset_autopatch() -> None:
     from weave.integrations.groq.groq_sdk import get_groq_patcher
     from weave.integrations.instructor.instructor_sdk import instructor_patcher
     from weave.integrations.langchain.langchain import langchain_patcher
-    from weave.integrations.litellm.litellm import litellm_patcher
+    from weave.integrations.litellm.litellm import get_litellm_patcher
     from weave.integrations.llamaindex.llamaindex import llamaindex_patcher
     from weave.integrations.mistral import get_mistral_patcher
     from weave.integrations.notdiamond.tracing import get_notdiamond_patcher
@@ -106,9 +107,7 @@ def reset_autopatch() -> None:
 
     get_openai_patcher().undo_patch()
     get_mistral_patcher().undo_patch()
-    litellm_patcher.undo_patch()
-    llamaindex_patcher.undo_patch()
-    langchain_patcher.undo_patch()
+    get_litellm_patcher().undo_patch()
     get_anthropic_patcher().undo_patch()
     get_groq_patcher().undo_patch()
     instructor_patcher.undo_patch()
@@ -118,3 +117,6 @@ def reset_autopatch() -> None:
     get_google_genai_patcher().undo_patch()
     get_notdiamond_patcher().undo_patch()
     get_vertexai_patcher().undo_patch()
+
+    llamaindex_patcher.undo_patch()
+    langchain_patcher.undo_patch()
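
Usage note: with the litellm field now exposed on AutopatchSettings, the LiteLLM
integration is configured through the same IntegrationSettings/OpSettings plumbing
as the other get_*_patcher integrations. The sketch below is illustrative only; it
assumes weave.init() accepts and forwards an autopatch_settings argument to
autopatch(), and the project name and op name are hypothetical placeholders.

    import weave
    from weave.trace.autopatch import AutopatchSettings, IntegrationSettings, OpSettings

    weave.init(
        "my-project",  # hypothetical project name
        autopatch_settings=AutopatchSettings(
            litellm=IntegrationSettings(
                # enabled=False makes get_litellm_patcher() return a NoOpPatcher,
                # so litellm.completion/acompletion are left unpatched.
                enabled=True,
                op_settings=OpSettings(
                    # Overrides BOTH default op names ("litellm.completion" and
                    # "litellm.acompletion"), since completion_settings and
                    # acompletion_settings each fall back to base.name.
                    name="litellm.chat",  # hypothetical custom op name
                ),
            ),
        ),
    )

One design consequence worth flagging: because both per-op settings are derived via
base.name or "<default>", a caller-supplied op_settings.name is applied to the sync
and async ops alike; leave name unset to keep the two distinct default op names.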