diff --git a/tests/conftest.py b/tests/conftest.py
index b28187a3833..85e9b53c36b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -477,7 +477,9 @@ def __getattribute__(self, name):
return ServerRecorder(server)
-def create_client(request) -> weave_init.InitializedClient:
+def create_client(
+ request, autopatch_settings: typing.Optional[autopatch.AutopatchSettings] = None
+) -> weave_init.InitializedClient:
inited_client = None
weave_server_flag = request.config.getoption("--weave-server")
server: tsi.TraceServerInterface
@@ -513,7 +515,7 @@ def create_client(request) -> weave_init.InitializedClient:
entity, project, make_server_recorder(server)
)
inited_client = weave_init.InitializedClient(client)
- autopatch.autopatch()
+ autopatch.autopatch(autopatch_settings)
return inited_client
@@ -527,6 +529,7 @@ def client(request):
yield inited_client.client
finally:
inited_client.reset()
+ autopatch.reset_autopatch()
@pytest.fixture()
@@ -534,12 +537,13 @@ def client_creator(request):
"""This fixture is useful for delaying the creation of the client (ex. when you want to set settings first)"""
@contextlib.contextmanager
- def client():
- inited_client = create_client(request)
+ def client(autopatch_settings: typing.Optional[autopatch.AutopatchSettings] = None):
+ inited_client = create_client(request, autopatch_settings)
try:
yield inited_client.client
finally:
inited_client.reset()
+ autopatch.reset_autopatch()
yield client
diff --git a/tests/integrations/openai/test_autopatch.py b/tests/integrations/openai/test_autopatch.py
new file mode 100644
index 00000000000..f6a537fbf19
--- /dev/null
+++ b/tests/integrations/openai/test_autopatch.py
@@ -0,0 +1,87 @@
+# These tests are included here for convenience. Instead of creating a dummy API,
+# we test autopatching against the actual OpenAI API (recorded with VCR).
+
+from typing import Any
+
+import pytest
+from openai import OpenAI
+
+from weave.integrations.openai import openai_sdk
+from weave.trace.autopatch import AutopatchSettings, IntegrationSettings, OpSettings
+
+
+@pytest.mark.skip_clickhouse_client  # TODO: VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode
+@pytest.mark.vcr(
+ filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"]
+)
+def test_disabled_integration_doesnt_patch(client_creator):
+ autopatch_settings = AutopatchSettings(
+ openai=IntegrationSettings(enabled=False),
+ )
+
+ with client_creator(autopatch_settings=autopatch_settings) as client:
+ oaiclient = OpenAI()
+ oaiclient.chat.completions.create(
+ model="gpt-4o",
+ messages=[{"role": "user", "content": "tell me a joke"}],
+ )
+
+ calls = list(client.get_calls())
+ assert len(calls) == 0
+
+
+@pytest.mark.skip_clickhouse_client  # TODO: VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode
+@pytest.mark.vcr(
+ filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"]
+)
+def test_enabled_integration_patches(client_creator):
+ autopatch_settings = AutopatchSettings(
+ openai=IntegrationSettings(enabled=True),
+ )
+
+ with client_creator(autopatch_settings=autopatch_settings) as client:
+ oaiclient = OpenAI()
+ oaiclient.chat.completions.create(
+ model="gpt-4o",
+ messages=[{"role": "user", "content": "tell me a joke"}],
+ )
+
+ calls = list(client.get_calls())
+ assert len(calls) == 1
+
+
+@pytest.mark.skip_clickhouse_client  # TODO: VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode
+@pytest.mark.vcr(
+ filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"]
+)
+def test_passthrough_op_kwargs(client_creator):
+ def redact_inputs(inputs: dict[str, Any]) -> dict[str, Any]:
+ return dict.fromkeys(inputs, "REDACTED")
+
+ autopatch_settings = AutopatchSettings(
+ openai=IntegrationSettings(
+ op_settings=OpSettings(
+ postprocess_inputs=redact_inputs,
+ )
+ )
+ )
+
+    # Explicitly reset the patcher so this test starts fresh. `_openai_patcher` is a
+    # module-level global shared across tests; if it is not reset, it retains the
+    # cached patcher (and settings) from a previous test, which can cause this test
+    # to fail.
+ openai_sdk._openai_patcher = None
+
+ with client_creator(autopatch_settings=autopatch_settings) as client:
+ oaiclient = OpenAI()
+ oaiclient.chat.completions.create(
+ model="gpt-4o",
+ messages=[{"role": "user", "content": "tell me a joke"}],
+ )
+
+ calls = list(client.get_calls())
+ assert len(calls) == 1
+
+ call = calls[0]
+ assert all(v == "REDACTED" for v in call.inputs.values())
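
Outside the test fixture, the same redaction flows through the public entry point added in this diff. A minimal sketch, assuming the settings classes below; the project name is hypothetical:

```python
from typing import Any

import weave
from weave.trace.autopatch import AutopatchSettings, IntegrationSettings, OpSettings


def redact_inputs(inputs: dict[str, Any]) -> dict[str, Any]:
    # Replace every input value before it is written to the trace.
    return dict.fromkeys(inputs, "REDACTED")


weave.init(
    "my-entity/my-project",  # hypothetical project name
    autopatch_settings=AutopatchSettings(
        openai=IntegrationSettings(
            op_settings=OpSettings(postprocess_inputs=redact_inputs),
        ),
    ),
)
```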
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx
index ef6bcbd69ff..0b3c9603fef 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx
@@ -96,7 +96,7 @@ export const FeedbackSidebar = ({
Feedback
-
+
{humanAnnotationSpecs.length > 0 ? (
<>
diff --git a/weave/integrations/openai/openai_sdk.py b/weave/integrations/openai/openai_sdk.py
index 7814700d4d3..8f4aca93805 100644
--- a/weave/integrations/openai/openai_sdk.py
+++ b/weave/integrations/openai/openai_sdk.py
@@ -1,15 +1,19 @@
+import dataclasses
import importlib
from functools import wraps
-from typing import TYPE_CHECKING, Any, Callable, Optional
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import weave
+from weave.trace.autopatch import IntegrationSettings, OpSettings
from weave.trace.op import Op, ProcessedInputs
from weave.trace.op_extensions.accumulator import add_accumulator
-from weave.trace.patcher import MultiPatcher, SymbolPatcher
+from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher
if TYPE_CHECKING:
from openai.types.chat import ChatCompletionChunk
+_openai_patcher: Optional[MultiPatcher] = None
+
def maybe_unwrap_api_response(value: Any) -> Any:
"""If the caller requests a raw response, we unwrap the APIResponse object.
@@ -305,20 +309,16 @@ def openai_on_input_handler(
return None
-def create_wrapper_sync(
- name: str,
-) -> Callable[[Callable], Callable]:
+def create_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
"We need to do this so we can check if `stream` is used"
def _add_stream_options(fn: Callable) -> Callable:
@wraps(fn)
def _wrapper(*args: Any, **kwargs: Any) -> Any:
- if bool(kwargs.get("stream")) and kwargs.get("stream_options") is None:
+ if kwargs.get("stream") and kwargs.get("stream_options") is None:
kwargs["stream_options"] = {"include_usage": True}
- return fn(
- *args, **kwargs
- ) # This is where the final execution of fn is happening.
+ return fn(*args, **kwargs)
return _wrapper
@@ -327,8 +327,8 @@ def _openai_stream_options_is_set(inputs: dict) -> bool:
return True
return False
- op = weave.op()(_add_stream_options(fn))
- op.name = name # type: ignore
+ op_kwargs = dataclasses.asdict(settings)
+ op = weave.op(_add_stream_options(fn), **op_kwargs)
op._set_on_input_handler(openai_on_input_handler)
return add_accumulator(
op, # type: ignore
@@ -345,16 +345,14 @@ def _openai_stream_options_is_set(inputs: dict) -> bool:
# Surprisingly, the async `client.chat.completions.create` does not pass
# `inspect.iscoroutinefunction`, so we can't dispatch on it and must write
# it manually here...
-def create_wrapper_async(
- name: str,
-) -> Callable[[Callable], Callable]:
+def create_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
"We need to do this so we can check if `stream` is used"
def _add_stream_options(fn: Callable) -> Callable:
@wraps(fn)
async def _wrapper(*args: Any, **kwargs: Any) -> Any:
- if bool(kwargs.get("stream")) and kwargs.get("stream_options") is None:
+ if kwargs.get("stream") and kwargs.get("stream_options") is None:
kwargs["stream_options"] = {"include_usage": True}
return await fn(*args, **kwargs)
@@ -365,8 +363,8 @@ def _openai_stream_options_is_set(inputs: dict) -> bool:
return True
return False
- op = weave.op()(_add_stream_options(fn))
- op.name = name # type: ignore
+ op_kwargs = dataclasses.asdict(settings)
+ op = weave.op(_add_stream_options(fn), **op_kwargs)
op._set_on_input_handler(openai_on_input_handler)
return add_accumulator(
op, # type: ignore
@@ -380,28 +378,63 @@ def _openai_stream_options_is_set(inputs: dict) -> bool:
return wrapper
-symbol_patchers = [
- # Patch the Completions.create method
- SymbolPatcher(
- lambda: importlib.import_module("openai.resources.chat.completions"),
- "Completions.create",
- create_wrapper_sync(name="openai.chat.completions.create"),
- ),
- SymbolPatcher(
- lambda: importlib.import_module("openai.resources.chat.completions"),
- "AsyncCompletions.create",
- create_wrapper_async(name="openai.chat.completions.create"),
- ),
- SymbolPatcher(
- lambda: importlib.import_module("openai.resources.beta.chat.completions"),
- "Completions.parse",
- create_wrapper_sync(name="openai.beta.chat.completions.parse"),
- ),
- SymbolPatcher(
- lambda: importlib.import_module("openai.resources.beta.chat.completions"),
- "AsyncCompletions.parse",
- create_wrapper_async(name="openai.beta.chat.completions.parse"),
- ),
-]
-
-openai_patcher = MultiPatcher(symbol_patchers) # type: ignore
+def get_openai_patcher(
+    settings: Optional[IntegrationSettings] = None,
+) -> Union[MultiPatcher, NoOpPatcher]:
+ if settings is None:
+ settings = IntegrationSettings()
+
+ if not settings.enabled:
+ return NoOpPatcher()
+
+ global _openai_patcher
+ if _openai_patcher is not None:
+ return _openai_patcher
+
+ base = settings.op_settings
+
+ completions_create_settings = dataclasses.replace(
+ base,
+ name=base.name or "openai.chat.completions.create",
+ )
+ async_completions_create_settings = dataclasses.replace(
+ base,
+ name=base.name or "openai.chat.completions.create",
+ )
+ completions_parse_settings = dataclasses.replace(
+ base,
+ name=base.name or "openai.beta.chat.completions.parse",
+ )
+ async_completions_parse_settings = dataclasses.replace(
+ base,
+ name=base.name or "openai.beta.chat.completions.parse",
+ )
+
+ _openai_patcher = MultiPatcher(
+ [
+ SymbolPatcher(
+ lambda: importlib.import_module("openai.resources.chat.completions"),
+ "Completions.create",
+ create_wrapper_sync(settings=completions_create_settings),
+ ),
+ SymbolPatcher(
+ lambda: importlib.import_module("openai.resources.chat.completions"),
+ "AsyncCompletions.create",
+ create_wrapper_async(settings=async_completions_create_settings),
+ ),
+ SymbolPatcher(
+ lambda: importlib.import_module(
+ "openai.resources.beta.chat.completions"
+ ),
+ "Completions.parse",
+ create_wrapper_sync(settings=completions_parse_settings),
+ ),
+ SymbolPatcher(
+ lambda: importlib.import_module(
+ "openai.resources.beta.chat.completions"
+ ),
+ "AsyncCompletions.parse",
+ create_wrapper_async(settings=async_completions_parse_settings),
+ ),
+ ]
+ )
+
+ return _openai_patcher
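
Note that the patcher is cached in the module-level `_openai_patcher`, so `settings` only takes effect on first construction; later calls return the cached instance regardless of their arguments, except for the disabled path, which short-circuits to a fresh `NoOpPatcher` before the cache check. This is why the test above resets `openai_sdk._openai_patcher`. A small sketch of the resulting behavior, assuming the functions in this diff:

```python
from weave.integrations.openai import openai_sdk
from weave.trace.autopatch import IntegrationSettings, OpSettings

first = openai_sdk.get_openai_patcher(
    IntegrationSettings(op_settings=OpSettings(name="custom.openai.create"))
)
second = openai_sdk.get_openai_patcher()  # settings ignored: cache hit
assert first is second

# A disabled integration bypasses the cache entirely:
noop = openai_sdk.get_openai_patcher(IntegrationSettings(enabled=False))
assert noop is not first
```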
diff --git a/weave/scorers/llm_utils.py b/weave/scorers/llm_utils.py
index 68ae2ccb366..eef6f018b0f 100644
--- a/weave/scorers/llm_utils.py
+++ b/weave/scorers/llm_utils.py
@@ -2,10 +2,6 @@
from typing import TYPE_CHECKING, Any, Union
-from weave.trace.autopatch import autopatch
-
-autopatch() # ensure both weave patching and instructor patching are applied
-
OPENAI_DEFAULT_MODEL = "gpt-4o"
OPENAI_DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
OPENAI_DEFAULT_MODERATION_MODEL = "text-moderation-latest"
diff --git a/weave/trace/api.py b/weave/trace/api.py
index ee8131b0875..294308cbb67 100644
--- a/weave/trace/api.py
+++ b/weave/trace/api.py
@@ -13,6 +13,7 @@
# There is probably a better place for this, but including here for now to get the fix in.
from weave import type_handlers # noqa: F401
from weave.trace import urls, util, weave_client, weave_init
+from weave.trace.autopatch import AutopatchSettings
from weave.trace.constants import TRACE_OBJECT_EMOJI
from weave.trace.context import call_context
from weave.trace.context import weave_client_context as weave_client_context
@@ -32,6 +33,7 @@ def init(
project_name: str,
*,
settings: UserSettings | dict[str, Any] | None = None,
+ autopatch_settings: AutopatchSettings | None = None,
) -> weave_client.WeaveClient:
"""Initialize weave tracking, logging to a wandb project.
@@ -52,7 +54,12 @@ def init(
if should_disable_weave():
return weave_init.init_weave_disabled().client
- return weave_init.init_weave(project_name).client
+ initialized_client = weave_init.init_weave(
+ project_name,
+ autopatch_settings=autopatch_settings,
+ )
+
+ return initialized_client.client
@contextlib.contextmanager
diff --git a/weave/trace/autopatch.py b/weave/trace/autopatch.py
index 3a5dca14556..f6873f5cf47 100644
--- a/weave/trace/autopatch.py
+++ b/weave/trace/autopatch.py
@@ -4,8 +4,15 @@
check if libraries are installed and imported and patch in the case that they are.
"""
+from __future__ import annotations
-def autopatch() -> None:
+from dataclasses import dataclass, field
+from typing import Any, Callable
+
+from weave.trace.weave_client import Call
+
+
+def autopatch(settings: AutopatchSettings | None = None) -> None:
+    if settings is None:
+        settings = AutopatchSettings()
+
from weave.integrations.anthropic.anthropic_sdk import anthropic_patcher
from weave.integrations.cerebras.cerebras_sdk import cerebras_patcher
from weave.integrations.cohere.cohere_sdk import cohere_patcher
@@ -20,10 +27,10 @@ def autopatch() -> None:
from weave.integrations.llamaindex.llamaindex import llamaindex_patcher
from weave.integrations.mistral import mistral_patcher
from weave.integrations.notdiamond.tracing import notdiamond_patcher
- from weave.integrations.openai.openai_sdk import openai_patcher
+ from weave.integrations.openai.openai_sdk import get_openai_patcher
from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher
- openai_patcher.attempt_patch()
+ get_openai_patcher(settings.openai).attempt_patch()
mistral_patcher.attempt_patch()
litellm_patcher.attempt_patch()
llamaindex_patcher.attempt_patch()
@@ -54,10 +61,10 @@ def reset_autopatch() -> None:
from weave.integrations.llamaindex.llamaindex import llamaindex_patcher
from weave.integrations.mistral import mistral_patcher
from weave.integrations.notdiamond.tracing import notdiamond_patcher
- from weave.integrations.openai.openai_sdk import openai_patcher
+ from weave.integrations.openai.openai_sdk import get_openai_patcher
from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher
- openai_patcher.undo_patch()
+ get_openai_patcher().undo_patch()
mistral_patcher.undo_patch()
litellm_patcher.undo_patch()
llamaindex_patcher.undo_patch()
@@ -71,3 +78,43 @@ def reset_autopatch() -> None:
google_genai_patcher.undo_patch()
notdiamond_patcher.undo_patch()
vertexai_patcher.undo_patch()
+
+
+@dataclass
+class OpSettings:
+ """Op settings for a specific integration.
+ These currently subset the `op` decorator args to provide a consistent interface
+ when working with auto-patched functions. See the `op` decorator for more details."""
+
+ name: str | None = None
+ call_display_name: str | Callable[[Call], str] | None = None
+ postprocess_inputs: Callable[[dict[str, Any]], dict[str, Any]] | None = None
+ postprocess_output: Callable[[Any], Any] | None = None
+
+
+@dataclass
+class IntegrationSettings:
+ """Configuration for a specific integration."""
+
+ enabled: bool = True
+ op_settings: OpSettings = field(default_factory=OpSettings)
+
+
+@dataclass
+class AutopatchSettings:
+ """Settings for auto-patching integrations."""
+
+ anthropic: IntegrationSettings = field(default_factory=IntegrationSettings)
+ cerebras: IntegrationSettings = field(default_factory=IntegrationSettings)
+ cohere: IntegrationSettings = field(default_factory=IntegrationSettings)
+ dspy: IntegrationSettings = field(default_factory=IntegrationSettings)
+ google_ai_studio: IntegrationSettings = field(default_factory=IntegrationSettings)
+ groq: IntegrationSettings = field(default_factory=IntegrationSettings)
+ instructor: IntegrationSettings = field(default_factory=IntegrationSettings)
+ langchain: IntegrationSettings = field(default_factory=IntegrationSettings)
+ litellm: IntegrationSettings = field(default_factory=IntegrationSettings)
+ llamaindex: IntegrationSettings = field(default_factory=IntegrationSettings)
+ mistral: IntegrationSettings = field(default_factory=IntegrationSettings)
+ notdiamond: IntegrationSettings = field(default_factory=IntegrationSettings)
+ openai: IntegrationSettings = field(default_factory=IntegrationSettings)
+ vertexai: IntegrationSettings = field(default_factory=IntegrationSettings)
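
For reference, a hedged sketch of composing these settings and applying them directly. Note that in this diff only the `openai` field is actually consumed by `autopatch`; the remaining integrations still attach their module-level patchers unconditionally:

```python
from weave.trace.autopatch import (
    AutopatchSettings,
    IntegrationSettings,
    OpSettings,
    autopatch,
)

settings = AutopatchSettings(
    openai=IntegrationSettings(
        enabled=True,
        op_settings=OpSettings(name="custom.openai.op"),
    ),
)
autopatch(settings)  # the same call weave.init makes internally
```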
diff --git a/weave/trace/patcher.py b/weave/trace/patcher.py
index 1567c4e2bb9..c1d0d653ffa 100644
--- a/weave/trace/patcher.py
+++ b/weave/trace/patcher.py
@@ -17,6 +17,14 @@ def undo_patch(self) -> bool:
raise NotImplementedError()
+class NoOpPatcher(Patcher):
+ def attempt_patch(self) -> bool:
+ return True
+
+ def undo_patch(self) -> bool:
+ return True
+
+
class MultiPatcher(Patcher):
def __init__(self, patchers: Sequence[Patcher]) -> None:
self.patchers = patchers
diff --git a/weave/trace/weave_init.py b/weave/trace/weave_init.py
index 563dcbdaed4..f51d42d5018 100644
--- a/weave/trace/weave_init.py
+++ b/weave/trace/weave_init.py
@@ -63,7 +63,9 @@ def get_entity_project_from_project_name(project_name: str) -> tuple[str, str]:
def init_weave(
- project_name: str, ensure_project_exists: bool = True
+ project_name: str,
+ ensure_project_exists: bool = True,
+ autopatch_settings: autopatch.AutopatchSettings | None = None,
) -> InitializedClient:
global _current_inited_client
if _current_inited_client is not None:
@@ -120,7 +122,7 @@ def init_weave(
# autopatching is only supported for the wandb client, because OpenAI calls are not
# logged in local mode currently. When that's fixed, this autopatch call can be
# moved to InitializedClient.__init__
- autopatch.autopatch()
+ autopatch.autopatch(autopatch_settings)
username = get_username()
try: