Commit da6fb39: test
andrewtruong committed Dec 11, 2024
1 parent 3a3fb64 commit da6fb39
Showing 15 changed files with 490 additions and 267 deletions.
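
Note: this commit threads a new settings object through every integration patcher. The settings models themselves are not part of this diff, so the following is a minimal sketch of the shapes the changed code appears to assume (the `enabled` field and the defaults are illustrative guesses, not confirmed by this commit):

# Sketch only: the real definitions live in weave/trace/autopatch.py and may
# carry more fields than shown here.
from typing import Optional

from pydantic import BaseModel


class OpSettings(BaseModel):
    # Dumped via settings.model_dump() and forwarded to weave.op(fn, **kwargs),
    # so every field must be a keyword argument weave.op accepts.
    name: Optional[str] = None


class IntegrationSettings(BaseModel):
    enabled: bool = True  # assumed field, for illustration
    op_settings: OpSettings = OpSettings()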
8 changes: 5 additions & 3 deletions weave/integrations/cerebras/cerebras_sdk.py
@@ -3,7 +3,7 @@
 from typing import Any, Callable, Optional
 
 import weave
-from weave.trace.autopatch import OpSettings
+from weave.trace.autopatch import IntegrationSettings, OpSettings
 from weave.trace.patcher import MultiPatcher, SymbolPatcher
 
 _cerebras_patcher: Optional[MultiPatcher] = None
@@ -34,14 +34,16 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
     return wrapper
 
 
-def get_cerebras_patcher(settings: Optional[OpSettings] = None) -> MultiPatcher:
+def get_cerebras_patcher(
+    settings: Optional[IntegrationSettings] = None,
+) -> MultiPatcher:
     global _cerebras_patcher
 
     if _cerebras_patcher is not None:
         return _cerebras_patcher
 
     if settings is None:
-        settings = OpSettings()
+        settings = IntegrationSettings()
 
     base = settings.op_settings
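
Note: the getter shape above repeats for every integration in this commit: a module-level cache, an optional IntegrationSettings argument, and op-level defaults read from settings.op_settings. A hypothetical caller, assuming the sketched models above and the attempt_patch method that weave's patchers expose:

# Hypothetical usage; only get_cerebras_patcher, IntegrationSettings, and
# OpSettings come from this diff.
from weave.integrations.cerebras.cerebras_sdk import get_cerebras_patcher
from weave.trace.autopatch import IntegrationSettings, OpSettings

patcher = get_cerebras_patcher(
    IntegrationSettings(op_settings=OpSettings(name="cerebras.chat"))
)
patcher.attempt_patch()  # assumed Patcher API

# The cache-first check means later calls ignore their arguments entirely:
assert get_cerebras_patcher(IntegrationSettings()) is patcher

One consequence of the cached instance: settings passed after the first call are silently dropped, so any customization has to happen before the first thing that triggers patching.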
6 changes: 3 additions & 3 deletions weave/integrations/cohere/cohere_sdk.py
@@ -3,7 +3,7 @@
 from typing import TYPE_CHECKING, Any, Callable, Optional
 
 import weave
-from weave.trace.autopatch import OpSettings
+from weave.trace.autopatch import IntegrationSettings, OpSettings
 from weave.trace.op_extensions.accumulator import add_accumulator
 from weave.trace.patcher import MultiPatcher, SymbolPatcher
 
@@ -185,14 +185,14 @@ def wrapper(fn: Callable) -> Callable:
     return wrapper
 
 
-def get_cohere_patcher(settings: Optional[OpSettings] = None) -> MultiPatcher:
+def get_cohere_patcher(settings: Optional[IntegrationSettings] = None) -> MultiPatcher:
     global _cohere_patcher
 
     if _cohere_patcher is not None:
         return _cohere_patcher
 
     if settings is None:
-        settings = OpSettings()
+        settings = IntegrationSettings()
 
     base = settings.op_settings
8 changes: 5 additions & 3 deletions weave/integrations/google_ai_studio/google_ai_studio_sdk.py
@@ -3,7 +3,7 @@
 from typing import TYPE_CHECKING, Any, Callable, Optional
 
 import weave
-from weave.trace.autopatch import OpSettings
+from weave.trace.autopatch import IntegrationSettings, OpSettings
 from weave.trace.op_extensions.accumulator import add_accumulator
 from weave.trace.patcher import MultiPatcher, SymbolPatcher
 from weave.trace.serialize import dictify
@@ -135,14 +135,16 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
     return wrapper
 
 
-def get_google_genai_patcher(settings: Optional[OpSettings] = None) -> MultiPatcher:
+def get_google_genai_patcher(
+    settings: Optional[IntegrationSettings] = None,
+) -> MultiPatcher:
     global _google_genai_patcher
 
     if _google_genai_patcher is not None:
         return _google_genai_patcher
 
     if settings is None:
-        settings = OpSettings()
+        settings = IntegrationSettings()
 
     base = settings.op_settings
6 changes: 3 additions & 3 deletions weave/integrations/groq/groq_sdk.py
@@ -1,7 +1,7 @@
 import importlib
 from typing import TYPE_CHECKING, Callable, Optional
 
-from weave.trace.autopatch import OpSettings
+from weave.trace.autopatch import IntegrationSettings, OpSettings
 
 if TYPE_CHECKING:
     from groq.types.chat import ChatCompletion, ChatCompletionChunk
@@ -100,14 +100,14 @@ def wrapper(fn: Callable) -> Callable:
     return wrapper
 
 
-def get_groq_patcher(settings: Optional[OpSettings] = None) -> MultiPatcher:
+def get_groq_patcher(settings: Optional[IntegrationSettings] = None) -> MultiPatcher:
     global _groq_patcher
 
     if _groq_patcher is not None:
         return _groq_patcher
 
     if settings is None:
-        settings = OpSettings()
+        settings = IntegrationSettings()
 
     base = settings.op_settings
14 changes: 7 additions & 7 deletions weave/integrations/instructor/instructor_iterable_utils.py
@@ -4,6 +4,7 @@
 from pydantic import BaseModel
 
 import weave
+from weave.trace.autopatch import OpSettings
 from weave.trace.op_extensions.accumulator import add_accumulator
 
 
@@ -27,10 +28,10 @@ def should_accumulate_iterable(inputs: dict) -> bool:
     return False
 
 
-def instructor_wrapper_sync(name: str) -> Callable[[Callable], Callable]:
+def instructor_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]:
     def wrapper(fn: Callable) -> Callable:
-        op = weave.op(fn)
-        op.name = name  # type: ignore
+        op_kwargs = settings.model_dump()
+        op = weave.op(fn, **op_kwargs)
         return add_accumulator(
             op,  # type: ignore
             make_accumulator=lambda inputs: instructor_iterable_accumulator,
@@ -40,7 +41,7 @@ def wrapper(fn: Callable) -> Callable:
     return wrapper
 
 
-def instructor_wrapper_async(name: str) -> Callable[[Callable], Callable]:
+def instructor_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]:
     def wrapper(fn: Callable) -> Callable:
         def _fn_wrapper(fn: Callable) -> Callable:
             @wraps(fn)
@@ -49,9 +50,8 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
 
             return _async_wrapper
 
-        "We need to do this so we can check if `stream` is used"
-        op = weave.op(_fn_wrapper(fn))
-        op.name = name  # type: ignore
+        op_kwargs = settings.model_dump()
+        op = weave.op(_fn_wrapper(fn), **op_kwargs)
         return add_accumulator(
             op,  # type: ignore
             make_accumulator=lambda inputs: instructor_iterable_accumulator,
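
Note: the wrappers now build the op from a full OpSettings instead of mutating op.name after creation. Roughly, assuming the sketched OpSettings above:

import weave
from weave.trace.autopatch import OpSettings


def fn(*args, **kwargs):  # stand-in for the patched instructor method
    ...


# Old pattern: create the op, then rename it.
op = weave.op(fn)
op.name = "Instructor.create"  # type: ignore

# New pattern: every OpSettings field travels as a weave.op keyword argument.
settings = OpSettings(name="Instructor.create")
op = weave.op(fn, **settings.model_dump())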
7 changes: 4 additions & 3 deletions weave/integrations/instructor/instructor_partial_utils.py
@@ -3,6 +3,7 @@
 from pydantic import BaseModel
 
 import weave
+from weave.trace.autopatch import OpSettings
 from weave.trace.op_extensions.accumulator import add_accumulator
 
 
@@ -14,10 +15,10 @@ def instructor_partial_accumulator(
     return acc
 
 
-def instructor_wrapper_partial(name: str) -> Callable[[Callable], Callable]:
+def instructor_wrapper_partial(settings: OpSettings) -> Callable[[Callable], Callable]:
     def wrapper(fn: Callable) -> Callable:
-        op = weave.op()(fn)
-        op.name = name  # type: ignore
+        op_kwargs = settings.model_dump()
+        op = weave.op(fn, **op_kwargs)
         return add_accumulator(
             op,  # type: ignore
             make_accumulator=lambda inputs: instructor_partial_accumulator,
79 changes: 55 additions & 24 deletions weave/integrations/instructor/instructor_sdk.py
@@ -1,31 +1,62 @@
 import importlib
+from typing import Optional
 
+from weave.trace.autopatch import IntegrationSettings
 from weave.trace.patcher import MultiPatcher, SymbolPatcher
 
 from .instructor_iterable_utils import instructor_wrapper_async, instructor_wrapper_sync
 from .instructor_partial_utils import instructor_wrapper_partial
 
-instructor_patcher = MultiPatcher(
-    [
-        SymbolPatcher(
-            lambda: importlib.import_module("instructor.client"),
-            "Instructor.create",
-            instructor_wrapper_sync(name="Instructor.create"),
-        ),
-        SymbolPatcher(
-            lambda: importlib.import_module("instructor.client"),
-            "AsyncInstructor.create",
-            instructor_wrapper_async(name="AsyncInstructor.create"),
-        ),
-        SymbolPatcher(
-            lambda: importlib.import_module("instructor.client"),
-            "Instructor.create_partial",
-            instructor_wrapper_partial(name="Instructor.create_partial"),
-        ),
-        SymbolPatcher(
-            lambda: importlib.import_module("instructor.client"),
-            "AsyncInstructor.create_partial",
-            instructor_wrapper_partial(name="AsyncInstructor.create_partial"),
-        ),
-    ]
-)
+_instructor_patcher: Optional[MultiPatcher] = None
+
+
+def get_instructor_patcher(
+    settings: Optional[IntegrationSettings] = None,
+) -> MultiPatcher:
+    global _instructor_patcher
+
+    if _instructor_patcher is not None:
+        return _instructor_patcher
+
+    if settings is None:
+        settings = IntegrationSettings()
+
+    base = settings.op_settings
+
+    create_settings = base.model_copy(update={"name": base.name or "Instructor.create"})
+    async_create_settings = base.model_copy(
+        update={"name": base.name or "AsyncInstructor.create"}
+    )
+    create_partial_settings = base.model_copy(
+        update={"name": base.name or "Instructor.create_partial"}
+    )
+    async_create_partial_settings = base.model_copy(
+        update={"name": base.name or "AsyncInstructor.create_partial"}
+    )
+
+    _instructor_patcher = MultiPatcher(
+        [
+            SymbolPatcher(
+                lambda: importlib.import_module("instructor.client"),
+                "Instructor.create",
+                instructor_wrapper_sync(create_settings),
+            ),
+            SymbolPatcher(
+                lambda: importlib.import_module("instructor.client"),
+                "AsyncInstructor.create",
+                instructor_wrapper_async(async_create_settings),
+            ),
+            SymbolPatcher(
+                lambda: importlib.import_module("instructor.client"),
+                "Instructor.create_partial",
+                instructor_wrapper_partial(create_partial_settings),
+            ),
+            SymbolPatcher(
+                lambda: importlib.import_module("instructor.client"),
+                "AsyncInstructor.create_partial",
+                instructor_wrapper_partial(async_create_partial_settings),
+            ),
+        ]
+    )
+
+    return _instructor_patcher
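
Note one subtlety in the `base.name or "..."` defaults above: if a caller sets a name on the base OpSettings, all four instructor ops inherit that single name; the per-op strings are only fallbacks. A small illustration, assuming the sketched OpSettings:

from weave.trace.autopatch import OpSettings

base = OpSettings(name="my-instructor-op")
s = base.model_copy(update={"name": base.name or "Instructor.create"})
assert s.name == "my-instructor-op"  # an explicit name wins for every op

base = OpSettings()
s = base.model_copy(update={"name": base.name or "Instructor.create"})
assert s.name == "Instructor.create"  # otherwise the per-op default applies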
57 changes: 40 additions & 17 deletions weave/integrations/litellm/litellm.py
@@ -2,12 +2,15 @@
 from typing import TYPE_CHECKING, Any, Callable, Optional
 
 import weave
+from weave.trace.autopatch import IntegrationSettings, OpSettings
 from weave.trace.op_extensions.accumulator import add_accumulator
 from weave.trace.patcher import MultiPatcher, SymbolPatcher
 
 if TYPE_CHECKING:
     from litellm.utils import ModelResponse
 
+_litellm_patcher: Optional[MultiPatcher] = None
+
 
 # This accumulator is nearly identical to the mistral accumulator, just with different types.
 def litellm_accumulator(
@@ -82,10 +85,10 @@ def should_use_accumulator(inputs: dict) -> bool:
     return isinstance(inputs, dict) and bool(inputs.get("stream"))
 
 
-def make_wrapper(name: str) -> Callable:
+def make_wrapper(settings: OpSettings) -> Callable:
     def litellm_wrapper(fn: Callable) -> Callable:
-        op = weave.op()(fn)
-        op.name = name  # type: ignore
+        op_kwargs = settings.model_dump()
+        op = weave.op(fn, **op_kwargs)
         return add_accumulator(
             op,  # type: ignore
             make_accumulator=lambda inputs: litellm_accumulator,
@@ -96,17 +99,37 @@ def litellm_wrapper(fn: Callable) -> Callable:
     return litellm_wrapper
 
 
-litellm_patcher = MultiPatcher(
-    [
-        SymbolPatcher(
-            lambda: importlib.import_module("litellm"),
-            "completion",
-            make_wrapper("litellm.completion"),
-        ),
-        SymbolPatcher(
-            lambda: importlib.import_module("litellm"),
-            "acompletion",
-            make_wrapper("litellm.acompletion"),
-        ),
-    ]
-)
+def get_litellm_patcher(settings: Optional[IntegrationSettings] = None) -> MultiPatcher:
+    global _litellm_patcher
+
+    if _litellm_patcher is not None:
+        return _litellm_patcher
+
+    if settings is None:
+        settings = IntegrationSettings()
+
+    base = settings.op_settings
+
+    completion_settings = base.model_copy(
+        update={"name": base.name or "litellm.completion"}
+    )
+    acompletion_settings = base.model_copy(
+        update={"name": base.name or "litellm.acompletion"}
+    )
+
+    _litellm_patcher = MultiPatcher(
+        [
+            SymbolPatcher(
+                lambda: importlib.import_module("litellm"),
+                "completion",
+                make_wrapper(completion_settings),
+            ),
+            SymbolPatcher(
+                lambda: importlib.import_module("litellm"),
+                "acompletion",
+                make_wrapper(acompletion_settings),
+            ),
+        ]
+    )
+
+    return _litellm_patcher
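
For context, the unchanged should_use_accumulator predicate in this hunk is what decides whether streaming accumulation wraps a call; it keys purely off the captured stream kwarg. Standalone, copied from the context lines above:

def should_use_accumulator(inputs: dict) -> bool:
    return isinstance(inputs, dict) and bool(inputs.get("stream"))


assert should_use_accumulator({"model": "gpt-4o", "stream": True})
assert not should_use_accumulator({"model": "gpt-4o"})  # no stream kwarg
assert not should_use_accumulator({"model": "gpt-4o", "stream": False})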
4 changes: 2 additions & 2 deletions weave/integrations/mistral/__init__.py
@@ -8,10 +8,10 @@
 mistral_version = "1.0"  # we need to return a patching function
 
 if version.parse(mistral_version) < version.parse("1.0.0"):
-    from .v0.mistral import mistral_patcher
+    from .v0.mistral import get_mistral_patcher  # noqa: F401
 
     print(
         f"Using MistralAI version {mistral_version}. Please consider upgrading to version 1.0.0 or later."
     )
 else:
-    from .v1.mistral import mistral_patcher  # noqa: F401
+    from .v1.mistral import get_mistral_patcher  # noqa: F401
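
A quick check of the version gate in the context lines, assuming version here is packaging.version: PEP 440 normalizes "1.0" and "1.0.0" to equal versions, so with the hardcoded mistral_version = "1.0" the comparison is False and the v1 branch runs:

from packaging import version

assert version.parse("1.0") == version.parse("1.0.0")
assert not (version.parse("1.0") < version.parse("1.0.0"))
# So `from .v1.mistral import get_mistral_patcher` is the import taken here.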
(6 more changed files not shown.)