
Commit

test
andrewtruong committed Dec 12, 2024
1 parent 3571f93 commit f4d6c6d
Showing 2 changed files with 62 additions and 31 deletions.
73 changes: 51 additions & 22 deletions weave/integrations/litellm/litellm.py
@@ -1,19 +1,24 @@
+from __future__ import annotations
+
 import importlib
-from typing import TYPE_CHECKING, Any, Callable, Optional
+from typing import TYPE_CHECKING, Any, Callable

 import weave
+from weave.trace.autopatch import IntegrationSettings, OpSettings
 from weave.trace.op_extensions.accumulator import add_accumulator
-from weave.trace.patcher import MultiPatcher, SymbolPatcher
+from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher

 if TYPE_CHECKING:
     from litellm.utils import ModelResponse

+_litellm_patcher: MultiPatcher | None = None
+

 # This accumulator is nearly identical to the mistral accumulator, just with different types.
 def litellm_accumulator(
-    acc: Optional["ModelResponse"],
-    value: "ModelResponse",
-) -> "ModelResponse":
+    acc: ModelResponse | None,
+    value: ModelResponse,
+) -> ModelResponse:
     # This import should be safe at this point
     from litellm.utils import Choices, Message, ModelResponse, Usage
@@ -82,10 +87,10 @@ def should_use_accumulator(inputs: dict) -> bool:
     return isinstance(inputs, dict) and bool(inputs.get("stream"))


-def make_wrapper(name: str) -> Callable:
+def make_wrapper(settings: OpSettings) -> Callable:
     def litellm_wrapper(fn: Callable) -> Callable:
-        op = weave.op()(fn)
-        op.name = name  # type: ignore
+        op_kwargs = settings.model_dump()
+        op = weave.op(fn, **op_kwargs)
         return add_accumulator(
             op,  # type: ignore
             make_accumulator=lambda inputs: litellm_accumulator,
@@ -96,17 +101,41 @@ def litellm_wrapper(fn: Callable) -> Callable:
     return litellm_wrapper


-litellm_patcher = MultiPatcher(
-    [
-        SymbolPatcher(
-            lambda: importlib.import_module("litellm"),
-            "completion",
-            make_wrapper("litellm.completion"),
-        ),
-        SymbolPatcher(
-            lambda: importlib.import_module("litellm"),
-            "acompletion",
-            make_wrapper("litellm.acompletion"),
-        ),
-    ]
-)
+def get_litellm_patcher(
+    settings: IntegrationSettings | None = None,
+) -> MultiPatcher | NoOpPatcher:
+    if settings is None:
+        settings = IntegrationSettings()
+
+    if not settings.enabled:
+        return NoOpPatcher()
+
+    global _litellm_patcher
+    if _litellm_patcher is not None:
+        return _litellm_patcher
+
+    base = settings.op_settings
+
+    completion_settings = base.model_copy(
+        update={"name": base.name or "litellm.completion"}
+    )
+    acompletion_settings = base.model_copy(
+        update={"name": base.name or "litellm.acompletion"}
+    )
+
+    _litellm_patcher = MultiPatcher(
+        [
+            SymbolPatcher(
+                lambda: importlib.import_module("litellm"),
+                "completion",
+                make_wrapper(completion_settings),
+            ),
+            SymbolPatcher(
+                lambda: importlib.import_module("litellm"),
+                "acompletion",
+                make_wrapper(acompletion_settings),
+            ),
+        ]
+    )
+
+    return _litellm_patcher
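
Taken together, the litellm.py changes replace the eagerly built module-level patcher with a lazy, cached, settings-aware factory. A minimal behavior sketch, using only names visible in this diff (running it assumes litellm is installed so the patched symbols can be resolved):

    from weave.integrations.litellm.litellm import get_litellm_patcher
    from weave.trace.autopatch import IntegrationSettings

    # Disabled settings short-circuit before the cache check:
    # a NoOpPatcher is returned and attempt_patch() does nothing.
    get_litellm_patcher(IntegrationSettings(enabled=False)).attempt_patch()

    # Default settings build the MultiPatcher once and store it in the
    # module-level _litellm_patcher global; later calls reuse it.
    patcher = get_litellm_patcher()
    patcher.attempt_patch()
    assert patcher is get_litellm_patcher()  # cached instance is returned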
20 changes: 11 additions & 9 deletions weave/trace/autopatch.py
@@ -41,7 +41,7 @@ class AutopatchSettings(BaseModel):
     google_ai_studio: IntegrationSettings = Field(default_factory=IntegrationSettings)
     groq: IntegrationSettings = Field(default_factory=IntegrationSettings)
     # instructor: IntegrationSettings = Field(default_factory=IntegrationSettings)
-    # litellm: IntegrationSettings = Field(default_factory=IntegrationSettings)
+    litellm: IntegrationSettings = Field(default_factory=IntegrationSettings)
     mistral: IntegrationSettings = Field(default_factory=IntegrationSettings)
     notdiamond: IntegrationSettings = Field(default_factory=IntegrationSettings)
     openai: IntegrationSettings = Field(default_factory=IntegrationSettings)
@@ -60,7 +60,7 @@ def autopatch(settings: Optional[AutopatchSettings] = None) -> None:
     from weave.integrations.groq.groq_sdk import get_groq_patcher
     from weave.integrations.instructor.instructor_sdk import instructor_patcher
     from weave.integrations.langchain.langchain import langchain_patcher
-    from weave.integrations.litellm.litellm import litellm_patcher
+    from weave.integrations.litellm.litellm import get_litellm_patcher
     from weave.integrations.llamaindex.llamaindex import llamaindex_patcher
     from weave.integrations.mistral import get_mistral_patcher
     from weave.integrations.notdiamond.tracing import get_notdiamond_patcher
@@ -72,9 +72,7 @@ def autopatch(settings: Optional[AutopatchSettings] = None) -> None:

     get_openai_patcher(settings.openai).attempt_patch()
     get_mistral_patcher(settings.mistral).attempt_patch()
-    litellm_patcher.attempt_patch()
-    llamaindex_patcher.attempt_patch()
-    langchain_patcher.attempt_patch()
+    get_litellm_patcher(settings.litellm).attempt_patch()
     get_anthropic_patcher(settings.anthropic).attempt_patch()
     get_groq_patcher(settings.groq).attempt_patch()
     instructor_patcher.attempt_patch()
@@ -85,6 +83,9 @@ def autopatch(settings: Optional[AutopatchSettings] = None) -> None:
     get_notdiamond_patcher(settings.notdiamond).attempt_patch()
     get_vertexai_patcher(settings.vertexai).attempt_patch()

+    llamaindex_patcher.attempt_patch()
+    langchain_patcher.attempt_patch()
+

 def reset_autopatch() -> None:
     from weave.integrations.anthropic.anthropic_sdk import get_anthropic_patcher
@@ -97,7 +98,7 @@ def reset_autopatch() -> None:
     from weave.integrations.groq.groq_sdk import get_groq_patcher
     from weave.integrations.instructor.instructor_sdk import instructor_patcher
     from weave.integrations.langchain.langchain import langchain_patcher
-    from weave.integrations.litellm.litellm import litellm_patcher
+    from weave.integrations.litellm.litellm import get_litellm_patcher
     from weave.integrations.llamaindex.llamaindex import llamaindex_patcher
     from weave.integrations.mistral import get_mistral_patcher
     from weave.integrations.notdiamond.tracing import get_notdiamond_patcher
@@ -106,9 +107,7 @@ def reset_autopatch() -> None:

     get_openai_patcher().undo_patch()
     get_mistral_patcher().undo_patch()
-    litellm_patcher.undo_patch()
-    llamaindex_patcher.undo_patch()
-    langchain_patcher.undo_patch()
+    get_litellm_patcher().undo_patch()
     get_anthropic_patcher().undo_patch()
     get_groq_patcher().undo_patch()
     instructor_patcher.undo_patch()
@@ -118,3 +117,6 @@ def reset_autopatch() -> None:
     get_google_genai_patcher().undo_patch()
     get_notdiamond_patcher().undo_patch()
     get_vertexai_patcher().undo_patch()
+
+    llamaindex_patcher.undo_patch()
+    langchain_patcher.undo_patch()
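
And with litellm promoted to a first-class field on AutopatchSettings, the integration can be toggled per-SDK like the others. A hedged usage sketch, again assuming only the names shown in this diff:

    from weave.trace.autopatch import (
        AutopatchSettings,
        IntegrationSettings,
        autopatch,
        reset_autopatch,
    )

    # Patch everything except LiteLLM: autopatch() passes settings.litellm
    # to get_litellm_patcher(), which returns a NoOpPatcher when disabled.
    autopatch(AutopatchSettings(litellm=IntegrationSettings(enabled=False)))

    # Undo all patches; llamaindex/langchain are undone last, mirroring autopatch().
    reset_autopatch()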
