Commit e6423d2

test

andrewtruong committed Dec 12, 2024
1 parent 2e10809 commit e6423d2

Showing 4 changed files with 74 additions and 39 deletions.
weave/integrations/instructor/instructor_iterable_utils.py (14 changes: 7 additions & 7 deletions)

```diff
@@ -4,6 +4,7 @@
 from pydantic import BaseModel
 
 import weave
+from weave.trace.autopatch import OpSettings
 from weave.trace.op_extensions.accumulator import add_accumulator
 
 
@@ -27,10 +28,10 @@ def should_accumulate_iterable(inputs: dict) -> bool:
     return False
 
 
-def instructor_wrapper_sync(name: str) -> Callable[[Callable], Callable]:
+def instructor_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]:
     def wrapper(fn: Callable) -> Callable:
-        op = weave.op(fn)
-        op.name = name  # type: ignore
+        op_kwargs = settings.model_dump()
+        op = weave.op(fn, **op_kwargs)
         return add_accumulator(
             op,  # type: ignore
             make_accumulator=lambda inputs: instructor_iterable_accumulator,
@@ -40,7 +41,7 @@ def wrapper(fn: Callable) -> Callable:
     return wrapper
 
 
-def instructor_wrapper_async(name: str) -> Callable[[Callable], Callable]:
+def instructor_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]:
     def wrapper(fn: Callable) -> Callable:
         def _fn_wrapper(fn: Callable) -> Callable:
             @wraps(fn)
@@ -49,9 +50,8 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
 
             return _async_wrapper
 
-        "We need to do this so we can check if `stream` is used"
-        op = weave.op(_fn_wrapper(fn))
-        op.name = name  # type: ignore
+        op_kwargs = settings.model_dump()
+        op = weave.op(_fn_wrapper(fn), **op_kwargs)
         return add_accumulator(
             op,  # type: ignore
             make_accumulator=lambda inputs: instructor_iterable_accumulator,
```
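The substantive change in this file is that the wrappers no longer create an op and then mutate `op.name` after the fact; they expand the settings model into `weave.op` keyword arguments up front via `model_dump()`. A minimal, self-contained sketch of that mechanism, using stand-ins for `OpSettings` and `weave.op` rather than the real weave definitions (the field names on the stand-in are illustrative):

```python
from typing import Callable, Optional

from pydantic import BaseModel


class FakeOpSettings(BaseModel):
    """Stand-in for weave.trace.autopatch.OpSettings; fields are assumed."""

    name: Optional[str] = None


def fake_op(fn: Callable, *, name: Optional[str] = None) -> Callable:
    """Stand-in for weave.op that records a display name on the function."""
    fn.display_name = name or fn.__name__  # type: ignore[attr-defined]
    return fn


def make_wrapper(settings: FakeOpSettings) -> Callable[[Callable], Callable]:
    def wrapper(fn: Callable) -> Callable:
        # model_dump() turns every configured field into an op kwarg,
        # replacing the old post-hoc `op.name = name` mutation.
        op_kwargs = settings.model_dump()
        return fake_op(fn, **op_kwargs)

    return wrapper


@make_wrapper(FakeOpSettings(name="Instructor.create"))
def create() -> str:
    return "ok"


assert create.display_name == "Instructor.create"  # type: ignore[attr-defined]
```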
weave/integrations/instructor/instructor_partial_utils.py (7 changes: 4 additions & 3 deletions)

```diff
@@ -3,6 +3,7 @@
 from pydantic import BaseModel
 
 import weave
+from weave.trace.autopatch import OpSettings
 from weave.trace.op_extensions.accumulator import add_accumulator
 
 
@@ -14,10 +15,10 @@ def instructor_partial_accumulator(
     return acc
 
 
-def instructor_wrapper_partial(name: str) -> Callable[[Callable], Callable]:
+def instructor_wrapper_partial(settings: OpSettings) -> Callable[[Callable], Callable]:
     def wrapper(fn: Callable) -> Callable:
-        op = weave.op()(fn)
-        op.name = name  # type: ignore
+        op_kwargs = settings.model_dump()
+        op = weave.op(fn, **op_kwargs)
         return add_accumulator(
             op,  # type: ignore
             make_accumulator=lambda inputs: instructor_partial_accumulator,
```
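The body of `instructor_partial_accumulator` is truncated in this view. As a rough illustration of the contract `add_accumulator` expects (a function that folds each streamed chunk into a running value), here is an assumed-shape sketch, not the real implementation. With instructor's `create_partial`, each chunk is a progressively more complete model, so keeping the newest chunk is a plausible strategy:

```python
from typing import Optional

from pydantic import BaseModel


class PartialUser(BaseModel):
    """Hypothetical partial result of the kind create_partial streams."""

    name: Optional[str] = None
    age: Optional[int] = None


def keep_latest(acc: Optional[PartialUser], value: PartialUser) -> PartialUser:
    # Each chunk supersedes the previous one: partial mode streams
    # progressively more complete models, so the running state is
    # simply the newest chunk.
    return value


acc: Optional[PartialUser] = None
for chunk in [
    PartialUser(),
    PartialUser(name="Ada"),
    PartialUser(name="Ada", age=36),
]:
    acc = keep_latest(acc, chunk)

assert acc == PartialUser(name="Ada", age=36)
```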
weave/integrations/instructor/instructor_sdk.py (84 changes: 59 additions & 25 deletions)

```diff
@@ -1,31 +1,65 @@
+from __future__ import annotations
+
 import importlib
 
-from weave.trace.patcher import MultiPatcher, SymbolPatcher
+from weave.trace.autopatch import IntegrationSettings
+from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
 
 from .instructor_iterable_utils import instructor_wrapper_async, instructor_wrapper_sync
 from .instructor_partial_utils import instructor_wrapper_partial
 
-instructor_patcher = MultiPatcher(
-    [
-        SymbolPatcher(
-            lambda: importlib.import_module("instructor.client"),
-            "Instructor.create",
-            instructor_wrapper_sync(name="Instructor.create"),
-        ),
-        SymbolPatcher(
-            lambda: importlib.import_module("instructor.client"),
-            "AsyncInstructor.create",
-            instructor_wrapper_async(name="AsyncInstructor.create"),
-        ),
-        SymbolPatcher(
-            lambda: importlib.import_module("instructor.client"),
-            "Instructor.create_partial",
-            instructor_wrapper_partial(name="Instructor.create_partial"),
-        ),
-        SymbolPatcher(
-            lambda: importlib.import_module("instructor.client"),
-            "AsyncInstructor.create_partial",
-            instructor_wrapper_partial(name="AsyncInstructor.create_partial"),
-        ),
-    ]
-)
+_instructor_patcher: MultiPatcher | None = None
+
+
+def get_instructor_patcher(
+    settings: IntegrationSettings | None = None,
+) -> MultiPatcher | NoOpPatcher:
+    if settings is None:
+        settings = IntegrationSettings()
+
+    if not settings.enabled:
+        return NoOpPatcher()
+
+    global _instructor_patcher
+    if _instructor_patcher is not None:
+        return _instructor_patcher
+
+    base = settings.op_settings
+
+    create_settings = base.model_copy(update={"name": base.name or "Instructor.create"})
+    async_create_settings = base.model_copy(
+        update={"name": base.name or "AsyncInstructor.create"}
+    )
+    create_partial_settings = base.model_copy(
+        update={"name": base.name or "Instructor.create_partial"}
+    )
+    async_create_partial_settings = base.model_copy(
+        update={"name": base.name or "AsyncInstructor.create_partial"}
+    )
+
+    _instructor_patcher = MultiPatcher(
+        [
+            SymbolPatcher(
+                lambda: importlib.import_module("instructor.client"),
+                "Instructor.create",
+                instructor_wrapper_sync(create_settings),
+            ),
+            SymbolPatcher(
+                lambda: importlib.import_module("instructor.client"),
+                "AsyncInstructor.create",
+                instructor_wrapper_async(async_create_settings),
+            ),
+            SymbolPatcher(
+                lambda: importlib.import_module("instructor.client"),
+                "Instructor.create_partial",
+                instructor_wrapper_partial(create_partial_settings),
+            ),
+            SymbolPatcher(
+                lambda: importlib.import_module("instructor.client"),
+                "AsyncInstructor.create_partial",
+                instructor_wrapper_partial(async_create_partial_settings),
+            ),
+        ]
+    )
+
+    return _instructor_patcher
```
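With this refactor the patcher is built lazily and cached at module level. A hedged usage sketch, assuming `IntegrationSettings` exposes the `enabled` and `op_settings` fields used above and that `enabled` defaults to True; note that the cache means only the first constructing call's settings take effect:

```python
from weave.integrations.instructor.instructor_sdk import get_instructor_patcher
from weave.trace.autopatch import IntegrationSettings, OpSettings

# A disabled integration short-circuits to a NoOpPatcher before the cache
# is consulted, so nothing gets wrapped.
noop = get_instructor_patcher(IntegrationSettings(enabled=False))

# A caller-supplied name overrides every per-symbol default, because
# `base.name or "Instructor.create"` only falls back when name is unset.
patcher = get_instructor_patcher(
    IntegrationSettings(op_settings=OpSettings(name="my_instructor_op"))
)
patcher.attempt_patch()

# Later enabled calls return the cached MultiPatcher regardless of settings.
assert get_instructor_patcher() is patcher
```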
weave/trace/autopatch.py (8 changes: 4 additions & 4 deletions)

```diff
@@ -58,7 +58,7 @@ def autopatch(settings: Optional[AutopatchSettings] = None) -> None:
         get_google_genai_patcher,
     )
     from weave.integrations.groq.groq_sdk import get_groq_patcher
-    from weave.integrations.instructor.instructor_sdk import instructor_patcher
+    from weave.integrations.instructor.instructor_sdk import get_instructor_patcher
     from weave.integrations.langchain.langchain import langchain_patcher
     from weave.integrations.litellm.litellm import get_litellm_patcher
     from weave.integrations.llamaindex.llamaindex import llamaindex_patcher
@@ -75,7 +75,7 @@ def autopatch(settings: Optional[AutopatchSettings] = None) -> None:
     get_litellm_patcher(settings.litellm).attempt_patch()
     get_anthropic_patcher(settings.anthropic).attempt_patch()
     get_groq_patcher(settings.groq).attempt_patch()
-    instructor_patcher.attempt_patch()
+    get_instructor_patcher(settings.instructor).attempt_patch()
     get_dspy_patcher(settings.dspy).attempt_patch()
     get_cerebras_patcher(settings.cerebras).attempt_patch()
     get_cohere_patcher(settings.cohere).attempt_patch()
@@ -96,7 +96,7 @@ def reset_autopatch() -> None:
         get_google_genai_patcher,
     )
     from weave.integrations.groq.groq_sdk import get_groq_patcher
-    from weave.integrations.instructor.instructor_sdk import instructor_patcher
+    from weave.integrations.instructor.instructor_sdk import get_instructor_patcher
     from weave.integrations.langchain.langchain import langchain_patcher
     from weave.integrations.litellm.litellm import get_litellm_patcher
     from weave.integrations.llamaindex.llamaindex import llamaindex_patcher
@@ -110,7 +110,7 @@ def reset_autopatch() -> None:
     get_litellm_patcher().undo_patch()
     get_anthropic_patcher().undo_patch()
     get_groq_patcher().undo_patch()
-    instructor_patcher.undo_patch()
+    get_instructor_patcher().undo_patch()
     get_dspy_patcher().undo_patch()
     get_cerebras_patcher().undo_patch()
     get_cohere_patcher().undo_patch()
```
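At the call site, the new signature lets per-integration settings flow from the top-level autopatch entry point. A sketch under stated assumptions: `AutopatchSettings` has an `instructor` field (the diff passes `settings.instructor`), and `weave.init` accepts an `autopatch_settings` parameter; treat that parameter as an assumption, with a direct `autopatch()` call as the fallback this file supports:

```python
import weave
from weave.trace.autopatch import AutopatchSettings, IntegrationSettings, OpSettings

# Rename every traced instructor op; other integrations keep their defaults.
settings = AutopatchSettings(
    instructor=IntegrationSettings(op_settings=OpSettings(name="instructor.chat")),
)

# Assumption: weave.init forwards autopatch_settings to autopatch(). If the
# installed version lacks this parameter, calling
# weave.trace.autopatch.autopatch(settings) directly matches this diff.
weave.init("my-project", autopatch_settings=settings)
```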
