Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewtruong committed Dec 11, 2024
1 parent c19fe6b commit 7a70b54
Show file tree
Hide file tree
Showing 9 changed files with 242 additions and 58 deletions.
12 changes: 8 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,9 @@ def __getattribute__(self, name):
return ServerRecorder(server)


def create_client(request) -> weave_init.InitializedClient:
def create_client(
request, autopatch_settings: typing.Optional[autopatch.AutopatchSettings] = None
) -> weave_init.InitializedClient:
inited_client = None
weave_server_flag = request.config.getoption("--weave-server")
server: tsi.TraceServerInterface
Expand Down Expand Up @@ -513,7 +515,7 @@ def create_client(request) -> weave_init.InitializedClient:
entity, project, make_server_recorder(server)
)
inited_client = weave_init.InitializedClient(client)
autopatch.autopatch()
autopatch.autopatch(autopatch_settings)

return inited_client

Expand All @@ -527,19 +529,21 @@ def client(request):
yield inited_client.client
finally:
inited_client.reset()
autopatch.reset_autopatch()


@pytest.fixture()
def client_creator(request):
"""This fixture is useful for delaying the creation of the client (ex. when you want to set settings first)"""

@contextlib.contextmanager
def client():
inited_client = create_client(request)
def client(autopatch_settings: typing.Optional[autopatch.AutopatchSettings] = None):
inited_client = create_client(request, autopatch_settings)
try:
yield inited_client.client
finally:
inited_client.reset()
autopatch.reset_autopatch()

yield client

Expand Down
87 changes: 87 additions & 0 deletions tests/integrations/openai/test_autopatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# This is included here for convenience. Instead of creating a dummy API, we can test
# autopatching against the actual OpenAI API.

from typing import Any

import pytest
from openai import OpenAI

from weave.integrations.openai import openai_sdk
from weave.trace.autopatch import AutopatchSettings, IntegrationSettings, OpSettings


@pytest.mark.skip_clickhouse_client  # TODO:VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode
@pytest.mark.vcr(
    filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"]
)
def test_disabled_integration_doesnt_patch(client_creator):
    """With the openai integration disabled, no calls should be traced."""
    settings = AutopatchSettings(openai=IntegrationSettings(enabled=False))

    with client_creator(autopatch_settings=settings) as client:
        openai_client = OpenAI()
        openai_client.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": "tell me a joke"}],
        )

        # Patching was disabled, so the traced-call list stays empty.
        calls = list(client.get_calls())
        assert len(calls) == 0


@pytest.mark.skip_clickhouse_client  # TODO:VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode
@pytest.mark.vcr(
    filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"]
)
def test_enabled_integration_patches(client_creator):
    """With the openai integration explicitly enabled, the call is traced."""
    settings = AutopatchSettings(openai=IntegrationSettings(enabled=True))

    with client_creator(autopatch_settings=settings) as client:
        openai_client = OpenAI()
        openai_client.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": "tell me a joke"}],
        )

        # Exactly one traced call: the completion request above.
        calls = list(client.get_calls())
        assert len(calls) == 1


@pytest.mark.skip_clickhouse_client  # TODO:VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode
@pytest.mark.vcr(
    filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"]
)
def test_passthrough_op_kwargs(client_creator):
    """Op settings supplied via autopatch (here: postprocess_inputs) are
    forwarded to the patched op, so the recorded call reflects them."""

    def redact_inputs(inputs: dict[str, Any]) -> dict[str, Any]:
        # Replace every input value with a fixed marker so we can assert that
        # this hook actually ran on the recorded call.
        return dict.fromkeys(inputs, "REDACTED")

    autopatch_settings = AutopatchSettings(
        openai=IntegrationSettings(
            op_settings=OpSettings(
                postprocess_inputs=redact_inputs,
            )
        )
    )

    # Explicitly reset the patcher here to pretend like we're starting fresh. We need
    # to do this because `_openai_patcher` is a global variable that is shared across
    # tests. If we don't reset it, it will retain the state from the previous test,
    # which can cause this test to fail.
    openai_sdk._openai_patcher = None

    with client_creator(autopatch_settings=autopatch_settings) as client:
        oaiclient = OpenAI()
        oaiclient.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": "tell me a joke"}],
        )

        calls = list(client.get_calls())
        assert len(calls) == 1

        call = calls[0]
        assert all(v == "REDACTED" for v in call.inputs.values())
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ export const FeedbackSidebar = ({
<div className="text-lg font-semibold">Feedback</div>
<div className="flex-grow" />
</div>
<div className="min-h-1 mb-8 h-1 flex-grow overflow-auto bg-moon-300" />
<div className="min-h-1 mb-8 h-1 overflow-auto bg-moon-300" />
{humanAnnotationSpecs.length > 0 ? (
<>
<div className="ml-6 h-full flex-grow overflow-auto">
Expand Down
115 changes: 74 additions & 41 deletions weave/integrations/openai/openai_sdk.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import dataclasses
import importlib
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Optional

import weave
from weave.trace.autopatch import IntegrationSettings, OpSettings
from weave.trace.op import Op, ProcessedInputs
from weave.trace.op_extensions.accumulator import add_accumulator
from weave.trace.patcher import MultiPatcher, SymbolPatcher
from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher

if TYPE_CHECKING:
from openai.types.chat import ChatCompletionChunk

_openai_patcher: Optional[MultiPatcher] = None


def maybe_unwrap_api_response(value: Any) -> Any:
"""If the caller requests a raw response, we unwrap the APIResponse object.
Expand Down Expand Up @@ -305,20 +309,16 @@ def openai_on_input_handler(
return None


def create_wrapper_sync(
name: str,
) -> Callable[[Callable], Callable]:
def create_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
"We need to do this so we can check if `stream` is used"

def _add_stream_options(fn: Callable) -> Callable:
@wraps(fn)
def _wrapper(*args: Any, **kwargs: Any) -> Any:
if bool(kwargs.get("stream")) and kwargs.get("stream_options") is None:
if kwargs.get("stream") and kwargs.get("stream_options") is None:
kwargs["stream_options"] = {"include_usage": True}
return fn(
*args, **kwargs
) # This is where the final execution of fn is happening.
return fn(*args, **kwargs)

return _wrapper

Expand All @@ -327,8 +327,8 @@ def _openai_stream_options_is_set(inputs: dict) -> bool:
return True
return False

op = weave.op()(_add_stream_options(fn))
op.name = name # type: ignore
op_kwargs = dataclasses.asdict(settings)
op = weave.op(_add_stream_options(fn), **op_kwargs)
op._set_on_input_handler(openai_on_input_handler)
return add_accumulator(
op, # type: ignore
Expand All @@ -345,16 +345,14 @@ def _openai_stream_options_is_set(inputs: dict) -> bool:
# Surprisingly, the async `client.chat.completions.create` does not pass
# `inspect.iscoroutinefunction`, so we can't dispatch on it and must write
# it manually here...
def create_wrapper_async(
name: str,
) -> Callable[[Callable], Callable]:
def create_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]:
def wrapper(fn: Callable) -> Callable:
"We need to do this so we can check if `stream` is used"

def _add_stream_options(fn: Callable) -> Callable:
@wraps(fn)
async def _wrapper(*args: Any, **kwargs: Any) -> Any:
if bool(kwargs.get("stream")) and kwargs.get("stream_options") is None:
if kwargs.get("stream") and kwargs.get("stream_options") is None:
kwargs["stream_options"] = {"include_usage": True}
return await fn(*args, **kwargs)

Expand All @@ -365,8 +363,8 @@ def _openai_stream_options_is_set(inputs: dict) -> bool:
return True
return False

op = weave.op()(_add_stream_options(fn))
op.name = name # type: ignore
op_kwargs = dataclasses.asdict(settings)
op = weave.op(_add_stream_options(fn), **op_kwargs)
op._set_on_input_handler(openai_on_input_handler)
return add_accumulator(
op, # type: ignore
Expand All @@ -380,28 +378,63 @@ def _openai_stream_options_is_set(inputs: dict) -> bool:
return wrapper


symbol_patchers = [
# Patch the Completions.create method
SymbolPatcher(
lambda: importlib.import_module("openai.resources.chat.completions"),
"Completions.create",
create_wrapper_sync(name="openai.chat.completions.create"),
),
SymbolPatcher(
lambda: importlib.import_module("openai.resources.chat.completions"),
"AsyncCompletions.create",
create_wrapper_async(name="openai.chat.completions.create"),
),
SymbolPatcher(
lambda: importlib.import_module("openai.resources.beta.chat.completions"),
"Completions.parse",
create_wrapper_sync(name="openai.beta.chat.completions.parse"),
),
SymbolPatcher(
lambda: importlib.import_module("openai.resources.beta.chat.completions"),
"AsyncCompletions.parse",
create_wrapper_async(name="openai.beta.chat.completions.parse"),
),
]

openai_patcher = MultiPatcher(symbol_patchers) # type: ignore
def get_openai_patcher(
    settings: Optional[IntegrationSettings] = None,
) -> "MultiPatcher | NoOpPatcher":
    """Return the patcher that instruments the openai SDK.

    Args:
        settings: Per-integration settings. ``None`` means defaults. When
            ``settings.enabled`` is False, a fresh ``NoOpPatcher`` is returned
            and nothing is patched.

    Returns:
        The (module-cached) ``MultiPatcher`` wrapping the openai chat
        completion entry points, or a ``NoOpPatcher`` when disabled.

    NOTE: the built patcher is cached in the module-global ``_openai_patcher``.
    Once it exists, subsequent calls return it unchanged and any newly passed
    op settings are silently ignored; callers that need different settings
    must reset ``_openai_patcher`` first (see tests/integrations/openai).
    """
    if settings is None:
        settings = IntegrationSettings()

    if not settings.enabled:
        return NoOpPatcher()

    global _openai_patcher
    if _openai_patcher is not None:
        return _openai_patcher

    base = settings.op_settings

    # The sync and async variants of each endpoint trace under the same op
    # name, so one settings object per endpoint suffices. The wrappers only
    # read the settings (via dataclasses.asdict), so sharing is safe.
    completions_create_settings = dataclasses.replace(
        base,
        name=base.name or "openai.chat.completions.create",
    )
    completions_parse_settings = dataclasses.replace(
        base,
        name=base.name or "openai.beta.chat.completions.parse",
    )

    _openai_patcher = MultiPatcher(
        [
            SymbolPatcher(
                lambda: importlib.import_module("openai.resources.chat.completions"),
                "Completions.create",
                create_wrapper_sync(settings=completions_create_settings),
            ),
            SymbolPatcher(
                lambda: importlib.import_module("openai.resources.chat.completions"),
                "AsyncCompletions.create",
                create_wrapper_async(settings=completions_create_settings),
            ),
            SymbolPatcher(
                lambda: importlib.import_module(
                    "openai.resources.beta.chat.completions"
                ),
                "Completions.parse",
                create_wrapper_sync(settings=completions_parse_settings),
            ),
            SymbolPatcher(
                lambda: importlib.import_module(
                    "openai.resources.beta.chat.completions"
                ),
                "AsyncCompletions.parse",
                create_wrapper_async(settings=completions_parse_settings),
            ),
        ]
    )

    return _openai_patcher
4 changes: 0 additions & 4 deletions weave/scorers/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@

from typing import TYPE_CHECKING, Any, Union

from weave.trace.autopatch import autopatch

autopatch() # ensure both weave patching and instructor patching are applied

OPENAI_DEFAULT_MODEL = "gpt-4o"
OPENAI_DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
OPENAI_DEFAULT_MODERATION_MODEL = "text-moderation-latest"
Expand Down
9 changes: 8 additions & 1 deletion weave/trace/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# There is probably a better place for this, but including here for now to get the fix in.
from weave import type_handlers # noqa: F401
from weave.trace import urls, util, weave_client, weave_init
from weave.trace.autopatch import AutopatchSettings
from weave.trace.constants import TRACE_OBJECT_EMOJI
from weave.trace.context import call_context
from weave.trace.context import weave_client_context as weave_client_context
Expand All @@ -32,6 +33,7 @@ def init(
project_name: str,
*,
settings: UserSettings | dict[str, Any] | None = None,
autopatch_settings: AutopatchSettings | None = None,
) -> weave_client.WeaveClient:
"""Initialize weave tracking, logging to a wandb project.
Expand All @@ -52,7 +54,12 @@ def init(
if should_disable_weave():
return weave_init.init_weave_disabled().client

return weave_init.init_weave(project_name).client
initialized_client = weave_init.init_weave(
project_name,
autopatch_settings=autopatch_settings,
)

return initialized_client.client


@contextlib.contextmanager
Expand Down
Loading

0 comments on commit 7a70b54

Please sign in to comment.