diff --git a/docs/docs/guides/evaluation/scorers.md b/docs/docs/guides/evaluation/scorers.md index e313dcb5582..233d162b15c 100644 --- a/docs/docs/guides/evaluation/scorers.md +++ b/docs/docs/guides/evaluation/scorers.md @@ -455,7 +455,7 @@ In Weave, Scorers are used to evaluate AI outputs and return evaluation metrics. from weave.scorers import OpenAIModerationScorer from openai import OpenAI - oai_client = OpenAI(api_key=...) # initialize your LLM client here + oai_client = OpenAI() # initialize your LLM client here scorer = OpenAIModerationScorer( client=oai_client, diff --git a/docs/docs/guides/integrations/local_models.md b/docs/docs/guides/integrations/local_models.md index 090d22a3f76..2597ad19dd1 100644 --- a/docs/docs/guides/integrations/local_models.md +++ b/docs/docs/guides/integrations/local_models.md @@ -14,7 +14,6 @@ First and most important, is the `base_url` change during the `openai.OpenAI()` ```python client = openai.OpenAI( - api_key='fake', base_url="http://localhost:1234", ) ``` diff --git a/docs/docs/guides/integrations/notdiamond.md b/docs/docs/guides/integrations/notdiamond.md index 3106ef11f1b..e98a23d1334 100644 --- a/docs/docs/guides/integrations/notdiamond.md +++ b/docs/docs/guides/integrations/notdiamond.md @@ -68,7 +68,6 @@ preference_id = train_router( response_column="actual", language="en", maximize=True, - api_key=api_key, ) ``` diff --git a/docs/docs/quickstart.md b/docs/docs/quickstart.md index 59ac16ef236..65bc813af03 100644 --- a/docs/docs/quickstart.md +++ b/docs/docs/quickstart.md @@ -50,7 +50,7 @@ _In this example, we're using openai so you will need to add an OpenAI [API key] import weave from openai import OpenAI - client = OpenAI(api_key="...") + client = OpenAI() # Weave will track the inputs, outputs and code of this function # highlight-next-line diff --git a/docs/docs/tutorial-tracing_2.md b/docs/docs/tutorial-tracing_2.md index 719ee99e012..d6e392d9c56 100644 --- a/docs/docs/tutorial-tracing_2.md +++ b/docs/docs/tutorial-tracing_2.md @@ -24,7 +24,7 @@ Building on our [basic tracing example](/quickstart), we will now add additional import json from openai import OpenAI - client = OpenAI(api_key="...") + client = OpenAI() # highlight-next-line @weave.op() diff --git a/pyproject.toml b/pyproject.toml index d392d527b60..f34757b315e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -226,7 +226,7 @@ module = "weave_query.*" ignore_errors = true [tool.bumpversion] -current_version = "0.51.24-dev0" +current_version = "0.51.25-dev0" parse = """(?x) (?P0|[1-9]\\d*)\\. (?P0|[1-9]\\d*)\\. diff --git a/tests/conftest.py b/tests/conftest.py index b28187a3833..85e9b53c36b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -477,7 +477,9 @@ def __getattribute__(self, name): return ServerRecorder(server) -def create_client(request) -> weave_init.InitializedClient: +def create_client( + request, autopatch_settings: typing.Optional[autopatch.AutopatchSettings] = None +) -> weave_init.InitializedClient: inited_client = None weave_server_flag = request.config.getoption("--weave-server") server: tsi.TraceServerInterface @@ -513,7 +515,7 @@ def create_client(request) -> weave_init.InitializedClient: entity, project, make_server_recorder(server) ) inited_client = weave_init.InitializedClient(client) - autopatch.autopatch() + autopatch.autopatch(autopatch_settings) return inited_client @@ -527,6 +529,7 @@ def client(request): yield inited_client.client finally: inited_client.reset() + autopatch.reset_autopatch() @pytest.fixture() @@ -534,12 +537,13 @@ def client_creator(request): """This fixture is useful for delaying the creation of the client (ex. when you want to set settings first)""" @contextlib.contextmanager - def client(): - inited_client = create_client(request) + def client(autopatch_settings: typing.Optional[autopatch.AutopatchSettings] = None): + inited_client = create_client(request, autopatch_settings) try: yield inited_client.client finally: inited_client.reset() + autopatch.reset_autopatch() yield client diff --git a/tests/integrations/openai/cassettes/test_autopatch/test_configuration_with_dicts.yaml b/tests/integrations/openai/cassettes/test_autopatch/test_configuration_with_dicts.yaml new file mode 100644 index 00000000000..7245829a0b3 --- /dev/null +++ b/tests/integrations/openai/cassettes/test_autopatch/test_configuration_with_dicts.yaml @@ -0,0 +1,102 @@ +interactions: +- request: + body: '{"messages":[{"role":"user","content":"tell me a joke"}],"model":"gpt-4o"}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate, zstd + connection: + - keep-alive + content-length: + - '74' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.57.2 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.57.2 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.0rc2 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAAwAAAP//jFJNa9wwEL37V0x1yWVd7P3KspcSSKE5thvooSlGK40tJbJGSOOSNOx/ + L/Z+2KEp9KLDe/Me743mNQMQVostCGUkqza4/Eav1f1O/v4aVvvbL/Pdt7tVHUp1s/vsnhsx6xW0 + f0TFZ9VHRW1wyJb8kVYRJWPvWl4vFpvNYl2UA9GSRtfLmsD5kvJ5MV/mxSYv1iehIaswiS38yAAA + Xoe3j+g1PostFLMz0mJKskGxvQwBiEiuR4RMySaWnsVsJBV5Rj+k/m5eQJO/YkhP6JDJJ6htYxhQ + KgPEBuOnB//g7w2eJ438hcAGoek4fZgaR6y7JPtevnPuhB8uSR01IdI+nfgLXltvk6kiykS+T5WY + ghjYQwbwc9hI96akCJHawBXTE/resCyPdmL8ggm5PJFMLN2Iz1ezd9wqjSytS5ONCiWVQT0qx/XL + TluaENmk899h3vM+9ra++R/7kVAKA6OuQkRt1dvC41jE/kD/NXbZ8RBYpJfE2Fa19Q3GEO3xRupQ + qWslC9xLJUV2yP4AAAD//wMA4O+DUSwDAAA= + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8f01fe3aabd037cf-YYZ + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Wed, 11 Dec 2024 02:20:01 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=xqe_jHZdTV5LijJQYQ3GMY5MjtVrCyxbFO4glgLvgD0-1733883601-1.0.1.1-p.DDUca_cHppJu2hXzzA0CXU1mtalxHUNfBWVgPIQj.UkU603pbNscCvSIi4_Zjlz9Zuc3.hjlvoyZxcDBJTsw; + path=/; expires=Wed, 11-Dec-24 02:50:01 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=WEjxXqkGswaEDhllTROGX_go9tgaWNJcUJ3cCd50xDI-1733883601764-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - wandb + openai-processing-ms: + - '607' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '10000' + x-ratelimit-limit-tokens: + - '30000000' + x-ratelimit-remaining-requests: + - '9999' + x-ratelimit-remaining-tokens: + - '29999979' + x-ratelimit-reset-requests: + - 6ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_8592a74b531c806f65c63c7471101cb6 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/integrations/openai/cassettes/test_autopatch/test_disabled_integration_doesnt_patch.yaml b/tests/integrations/openai/cassettes/test_autopatch/test_disabled_integration_doesnt_patch.yaml new file mode 100644 index 00000000000..1895cdcd5f2 --- /dev/null +++ b/tests/integrations/openai/cassettes/test_autopatch/test_disabled_integration_doesnt_patch.yaml @@ -0,0 +1,102 @@ +interactions: +- request: + body: '{"messages":[{"role":"user","content":"tell me a joke"}],"model":"gpt-4o"}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate, zstd + connection: + - keep-alive + content-length: + - '74' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.57.2 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.57.2 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.0rc2 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAAwAAAP//jFLJbtswEL3rK6a85GIV8lYvl6KX9NQFrYEckkKgyZHImOII5KiJEfjf + C8qLHDQFeuHhbXgzw5cMQFgt1iCUkaya1uWf9Hyuv5H29ebu89Pt80Z9//H0RS2nX/e3LEbJQdtH + VHx2vVfUtA7Zkj/SKqBkTKnjxXS6XCwWk3FPNKTRJVvdcj6jfFJMZnmxzIsPJ6MhqzCKNdxnAAAv + /Zsqeo3PYg3F6Iw0GKOsUawvIgARyCVEyBhtZOmPdU+kIs/o+9Y/u4AjMBjwJoIEZ2vDuUEZGDU8 + 0g6hogB76tYP/sHfmT1o8jcMcYcOmXyEKlkApTJAbDB8TMKNwbPSyN8IbBDqjuO76xoBqy7KtAXf + OXfCD5e5HNVtoG088Re8st5GUwaUkXyaITK1omcPGcCvfn/dq5WINlDTcsm0Q58Cx+NjnBgONpCT + 2YlkYukGfDofvZFWamRpXbzav1BSGdSDcziW7LSlKyK7mvnvMm9lH+e2vv6f+IFQCltGXbYBtVWv + Bx5kAdN3/pfssuO+sIj7yNiUlfU1hjbY44+q2nKl54XSq1WxFdkh+wMAAP//AwAWTTnuWgMAAA== + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8f016eadbff439d2-YYZ + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Wed, 11 Dec 2024 00:42:01 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=8FO1yMjc3pMQWRpWrkIe5mcs39GLeqQPmgHQq0YTT8s-1733877721-1.0.1.1-i4G06DBN08aH1F1H73U_TB9OLK3jLsV1jXydB1cQ4Hqx7I.r8xDn.7hFRZe2hy3D_nABTG1nDcdDoXL_wYiqug; + path=/; expires=Wed, 11-Dec-24 01:12:01 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=jxwySgtriPkUP8L2os1nb_gRq_SSUo3yWFUyJmHPmGY-1733877721989-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - wandb + openai-processing-ms: + - '652' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '10000' + x-ratelimit-limit-tokens: + - '30000000' + x-ratelimit-remaining-requests: + - '9999' + x-ratelimit-remaining-tokens: + - '29999979' + x-ratelimit-reset-requests: + - 6ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_1c86d4fda2ad715edfd41bcd2f4bdd89 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/integrations/openai/cassettes/test_autopatch/test_enabled_integration_patches.yaml b/tests/integrations/openai/cassettes/test_autopatch/test_enabled_integration_patches.yaml new file mode 100644 index 00000000000..f0cdca54158 --- /dev/null +++ b/tests/integrations/openai/cassettes/test_autopatch/test_enabled_integration_patches.yaml @@ -0,0 +1,102 @@ +interactions: +- request: + body: '{"messages":[{"role":"user","content":"tell me a joke"}],"model":"gpt-4o"}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate, zstd + connection: + - keep-alive + content-length: + - '74' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.57.2 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.57.2 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.0rc2 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAAwAAAP//jFLBjtMwEL3nKwZfuDQoTXc3VS8IBOIACCQOHHZR5NrTxDTxWJ4J2rDq + v6Ok2SYrFomLD+/Ne3pvxg8JgHJW7UCZWotpQ5O+sdfXaL9+/PDp+L6gG3ffyudv735v+y82eLUa + FLT/iUYeVa8MtaFBcTTRJqIWHFzXxWazLYoiz0eiJYvNIKuCpFeU5ll+lWbbNLuZhDU5g6x2cJsA + ADyM7xDRW7xXO8hWj0iLzLpCtbsMAahIzYAozexYtBe1mklDXtCPqb/XPVjyLwXYOPTiWBgkdiyg + hVp+fefv/Fs0umMEqbGHVh8RugD4C2MvtfPVi6V3xEPHeqjmu6aZ8NMlbENViLTnib/gB+cd12VE + zeSHYCwU1MieEoAf41K6Jz1ViNQGKYWO6AfD9fpsp+YrLMh8IoVENzOeb1bPuJUWRbuGF0tVRpsa + 7aycL6A762hBJIvOf4d5zvvc2/nqf+xnwhgMgrYMEa0zTwvPYxGHP/qvscuOx8CKexZsy4PzFcYQ + 3fmbHEJpCqMz3GujVXJK/gAAAP//AwAyhdwOLwMAAA== + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8f016eb36bb3a240-YYZ + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Wed, 11 Dec 2024 00:42:02 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=Q_ATX8JU4jFqXJPdwlneOua9wmNmAaASyAfcbPyPqng-1733877722-1.0.1.1-eTMEvBW7oqQa2i3l.Or2I3LF_cCESxfseq.S9DBr8dAJWsVoFfPxKtr5vMaO6yj4hRW8XOSOHcgIcwwqbHrLbg; + path=/; expires=Wed, 11-Dec-24 01:12:02 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=2ak.tRpn6uEHbM8GrWy_ALtrN34jVSNIJI1mFG2etvM-1733877722703-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - wandb + openai-processing-ms: + - '476' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '10000' + x-ratelimit-limit-tokens: + - '30000000' + x-ratelimit-remaining-requests: + - '9999' + x-ratelimit-remaining-tokens: + - '29999979' + x-ratelimit-reset-requests: + - 6ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_52e061e1cc55cdd8847a7ba9342f1a14 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/integrations/openai/cassettes/test_autopatch/test_passthrough_op_kwargs.yaml b/tests/integrations/openai/cassettes/test_autopatch/test_passthrough_op_kwargs.yaml new file mode 100644 index 00000000000..646c57c6123 --- /dev/null +++ b/tests/integrations/openai/cassettes/test_autopatch/test_passthrough_op_kwargs.yaml @@ -0,0 +1,102 @@ +interactions: +- request: + body: '{"messages":[{"role":"user","content":"tell me a joke"}],"model":"gpt-4o"}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate, zstd + connection: + - keep-alive + content-length: + - '74' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.57.2 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.57.2 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.0rc2 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAAwAAAP//jFLLbtswELzrK7a89GIVsmLXsS9Fr0UvBQIERVMINLkS2VBcglwVcQL/ + e0H5IQVNgV54mNkZzCz3pQAQVosdCGUkqz648rNer7G2z3fL8ITq+X6lvn7x/SHhN/d9LxZZQftf + qPii+qCoDw7Zkj/RKqJkzK7Lzc3N7WazqeuR6Emjy7IucLmisq7qVVndltXHs9CQVZjEDn4UAAAv + 45sjeo1PYgfV4oL0mJLsUOyuQwAiksuIkCnZxNKzWEykIs/ox9T35gCa/HuG9IgOmXyC1naGAaUy + QGwwfnrwD/7O4GXSyN8IbBC6gdO7uXHEdkgy9/KDc2f8eE3qqAuR9unMX/HWeptME1Em8jlVYgpi + ZI8FwM9xI8OrkiJE6gM3TI/os+FyebIT0xfMyNWZZGLpJrxeL95wazSytC7NNiqUVAb1pJzWLwdt + aUYUs85/h3nL+9Tb+u5/7CdCKQyMugkRtVWvC09jEfOB/mvsuuMxsEiHxNg3rfUdxhDt6Uba0Gz1 + ulJ6u632ojgWfwAAAP//AwCOwDMjLAMAAA== + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8f016eb76b71ac9a-YYZ + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Wed, 11 Dec 2024 00:42:03 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=r.xSSsYQNFPvMiizFSvjQiecNA6Q1wQa0VR1YElfXi4-1733877723-1.0.1.1-GVW0i7wrpHCQSY5eXu7sIQgxYWl6jfeSordQ7JFxV3lO6UfFhwxRT92bBP4DfnrSYpBpRw3k4aONAURyvKctiQ; + path=/; expires=Wed, 11-Dec-24 01:12:03 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=CQJVOdASzL9ency5_q6SDaInTsvpjA240cIxf.AUwXM-1733877723385-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - wandb + openai-processing-ms: + - '523' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '10000' + x-ratelimit-limit-tokens: + - '30000000' + x-ratelimit-remaining-requests: + - '9999' + x-ratelimit-remaining-tokens: + - '29999979' + x-ratelimit-reset-requests: + - 6ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_c9c57cfa6f37a99aaf0abac013237ed6 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/integrations/openai/test_autopatch.py b/tests/integrations/openai/test_autopatch.py new file mode 100644 index 00000000000..2c2f5201d3f --- /dev/null +++ b/tests/integrations/openai/test_autopatch.py @@ -0,0 +1,116 @@ +# This is included here for convenience. Instead of creating a dummy API, we can test +# autopatching against the actual OpenAI API. + +from typing import Any + +import pytest +from openai import OpenAI + +from weave.integrations.openai import openai_sdk +from weave.trace.autopatch import AutopatchSettings, IntegrationSettings, OpSettings + + +@pytest.mark.skip_clickhouse_client # TODO:VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode +@pytest.mark.vcr( + filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"] +) +def test_disabled_integration_doesnt_patch(client_creator): + autopatch_settings = AutopatchSettings( + openai=IntegrationSettings(enabled=False), + ) + + with client_creator(autopatch_settings=autopatch_settings) as client: + oaiclient = OpenAI() + oaiclient.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "tell me a joke"}], + ) + + calls = list(client.get_calls()) + assert len(calls) == 0 + + +@pytest.mark.skip_clickhouse_client # TODO:VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode +@pytest.mark.vcr( + filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"] +) +def test_enabled_integration_patches(client_creator): + autopatch_settings = AutopatchSettings( + openai=IntegrationSettings(enabled=True), + ) + + with client_creator(autopatch_settings=autopatch_settings) as client: + oaiclient = OpenAI() + oaiclient.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "tell me a joke"}], + ) + + calls = list(client.get_calls()) + assert len(calls) == 1 + + +@pytest.mark.skip_clickhouse_client # TODO:VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode +@pytest.mark.vcr( + filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"] +) +def test_passthrough_op_kwargs(client_creator): + def redact_inputs(inputs: dict[str, Any]) -> dict[str, Any]: + return dict.fromkeys(inputs, "REDACTED") + + autopatch_settings = AutopatchSettings( + openai=IntegrationSettings( + op_settings=OpSettings( + postprocess_inputs=redact_inputs, + ) + ) + ) + + # Explicitly reset the patcher here to pretend like we're starting fresh. We need + # to do this because `_openai_patcher` is a global variable that is shared across + # tests. If we don't reset it, it will retain the state from the previous test, + # which can cause this test to fail. + openai_sdk._openai_patcher = None + + with client_creator(autopatch_settings=autopatch_settings) as client: + oaiclient = OpenAI() + oaiclient.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "tell me a joke"}], + ) + + calls = list(client.get_calls()) + assert len(calls) == 1 + + call = calls[0] + assert all(v == "REDACTED" for v in call.inputs.values()) + + +@pytest.mark.skip_clickhouse_client # TODO:VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode +@pytest.mark.vcr( + filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"] +) +def test_configuration_with_dicts(client_creator): + def redact_inputs(inputs: dict[str, Any]) -> dict[str, Any]: + return dict.fromkeys(inputs, "REDACTED") + + autopatch_settings = { + "openai": { + "op_settings": {"postprocess_inputs": redact_inputs}, + } + } + + openai_sdk._openai_patcher = None + + with client_creator(autopatch_settings=autopatch_settings) as client: + oaiclient = OpenAI() + oaiclient.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "tell me a joke"}], + ) + + calls = list(client.get_calls()) + assert len(calls) == 1 + + call = calls[0] + assert all(v == "REDACTED" for v in call.inputs.values()) diff --git a/tests/trace/test_evaluations.py b/tests/trace/test_evaluations.py index e5c38ef0140..ab74d4c0c0b 100644 --- a/tests/trace/test_evaluations.py +++ b/tests/trace/test_evaluations.py @@ -1021,13 +1021,19 @@ def my_second_scorer(text, output, model_output): ds = [{"text": "hello"}] - with pytest.raises(ValueError, match="Both 'output' and 'model_output'"): + with pytest.raises( + ValueError, match="cannot include both `output` and `model_output`" + ): scorer = MyScorer() - with pytest.raises(ValueError, match="Both 'output' and 'model_output'"): + with pytest.raises( + ValueError, match="cannot include both `output` and `model_output`" + ): evaluation = weave.Evaluation(dataset=ds, scorers=[MyScorer()]) - with pytest.raises(ValueError, match="Both 'output' and 'model_output'"): + with pytest.raises( + ValueError, match="cannot include both `output` and `model_output`" + ): evaluation = weave.Evaluation(dataset=ds, scorers=[my_second_scorer]) diff --git a/tests/trace_server/test_clickhouse_trace_server_migrator.py b/tests/trace_server/test_clickhouse_trace_server_migrator.py new file mode 100644 index 00000000000..3a6a92f2479 --- /dev/null +++ b/tests/trace_server/test_clickhouse_trace_server_migrator.py @@ -0,0 +1,230 @@ +import types +from unittest.mock import Mock, call, patch + +import pytest + +from weave.trace_server import clickhouse_trace_server_migrator as trace_server_migrator +from weave.trace_server.clickhouse_trace_server_migrator import MigrationError + + +@pytest.fixture +def mock_costs(): + with patch( + "weave.trace_server.costs.insert_costs.should_insert_costs", return_value=False + ) as mock_should_insert: + with patch( + "weave.trace_server.costs.insert_costs.get_current_costs", return_value=[] + ) as mock_get_costs: + yield + + +@pytest.fixture +def migrator(): + ch_client = Mock() + migrator = trace_server_migrator.ClickHouseTraceServerMigrator(ch_client) + migrator._get_migration_status = Mock() + migrator._get_migrations = Mock() + migrator._determine_migrations_to_apply = Mock() + migrator._update_migration_status = Mock() + ch_client.command.reset_mock() + return migrator + + +def test_apply_migrations_with_target_version(mock_costs, migrator, tmp_path): + # Setup + migrator._get_migration_status.return_value = { + "curr_version": 1, + "partially_applied_version": None, + } + migrator._get_migrations.return_value = { + "1": {"up": "1.up.sql", "down": "1.down.sql"}, + "2": {"up": "2.up.sql", "down": "2.down.sql"}, + } + migrator._determine_migrations_to_apply.return_value = [(2, "2.up.sql")] + + # Create a temporary migration file + migration_dir = tmp_path / "migrations" + migration_dir.mkdir() + migration_file = migration_dir / "2.up.sql" + migration_file.write_text( + "CREATE TABLE test1 (id Int32);\nCREATE TABLE test2 (id Int32);" + ) + + # Mock the migration directory path + with patch("os.path.dirname") as mock_dirname: + mock_dirname.return_value = str(tmp_path) + + # Execute + migrator.apply_migrations("test_db", target_version=2) + + # Verify + migrator._get_migration_status.assert_called_once_with("test_db") + migrator._get_migrations.assert_called_once() + migrator._determine_migrations_to_apply.assert_called_once_with( + 1, migrator._get_migrations.return_value, 2 + ) + + # Verify migration execution + assert migrator._update_migration_status.call_count == 2 + migrator._update_migration_status.assert_has_calls( + [call("test_db", 2, is_start=True), call("test_db", 2, is_start=False)] + ) + + # Verify the actual SQL commands were executed + ch_client = migrator.ch_client + assert ch_client.command.call_count == 2 + ch_client.command.assert_has_calls( + [call("CREATE TABLE test1 (id Int32)"), call("CREATE TABLE test2 (id Int32)")] + ) + + +def test_execute_migration_command(mock_costs, migrator): + # Setup + ch_client = migrator.ch_client + ch_client.database = "original_db" + + # Execute + migrator._execute_migration_command("test_db", "CREATE TABLE test (id Int32)") + + # Verify + assert ch_client.database == "original_db" # Should restore original database + ch_client.command.assert_called_once_with("CREATE TABLE test (id Int32)") + + +def test_migration_replicated(mock_costs, migrator): + ch_client = migrator.ch_client + orig = "CREATE TABLE test (id String, project_id String) ENGINE = MergeTree ORDER BY (project_id, id);" + migrator._execute_migration_command("test_db", orig) + ch_client.command.assert_called_once_with(orig) + + +def test_update_migration_status(mock_costs, migrator): + # Don't mock _update_migration_status for this test + migrator._update_migration_status = types.MethodType( + trace_server_migrator.ClickHouseTraceServerMigrator._update_migration_status, + migrator, + ) + + # Test start of migration + migrator._update_migration_status("test_db", 2, is_start=True) + migrator.ch_client.command.assert_called_with( + "ALTER TABLE db_management.migrations UPDATE partially_applied_version = 2 WHERE db_name = 'test_db'" + ) + + # Test end of migration + migrator._update_migration_status("test_db", 2, is_start=False) + migrator.ch_client.command.assert_called_with( + "ALTER TABLE db_management.migrations UPDATE curr_version = 2, partially_applied_version = NULL WHERE db_name = 'test_db'" + ) + + +def test_is_safe_identifier(mock_costs, migrator): + # Valid identifiers + assert migrator._is_safe_identifier("test_db") + assert migrator._is_safe_identifier("my_db123") + assert migrator._is_safe_identifier("db.table") + + # Invalid identifiers + assert not migrator._is_safe_identifier("test-db") + assert not migrator._is_safe_identifier("db;") + assert not migrator._is_safe_identifier("db'name") + assert not migrator._is_safe_identifier("db/*") + + +def test_create_db_sql_validation(mock_costs, migrator): + # Test invalid database name + with pytest.raises(MigrationError, match="Invalid database name"): + migrator._create_db_sql("test;db") + + # Test replicated mode with invalid values + migrator.replicated = True + migrator.replicated_cluster = "test;cluster" + with pytest.raises(MigrationError, match="Invalid cluster name"): + migrator._create_db_sql("test_db") + + migrator.replicated_cluster = "test_cluster" + migrator.replicated_path = "/clickhouse/bad;path/{db}" + with pytest.raises(MigrationError, match="Invalid replicated path"): + migrator._create_db_sql("test_db") + + +def test_create_db_sql_non_replicated(mock_costs, migrator): + # Test non-replicated mode + migrator.replicated = False + sql = migrator._create_db_sql("test_db") + assert sql.strip() == "CREATE DATABASE IF NOT EXISTS test_db" + + +def test_create_db_sql_replicated(mock_costs, migrator): + # Test replicated mode + migrator.replicated = True + migrator.replicated_path = "/clickhouse/tables/{db}" + migrator.replicated_cluster = "test_cluster" + + sql = migrator._create_db_sql("test_db") + expected = """ + CREATE DATABASE IF NOT EXISTS test_db ON CLUSTER test_cluster ENGINE=Replicated('/clickhouse/tables/test_db', '{shard}', '{replica}') + """.strip() + assert sql.strip() == expected + + +def test_format_replicated_sql_non_replicated(mock_costs, migrator): + # Test that SQL is unchanged when replicated=False + migrator.replicated = False + test_cases = [ + "CREATE TABLE test (id Int32) ENGINE = MergeTree", + "CREATE TABLE test (id Int32) ENGINE = SummingMergeTree", + "CREATE TABLE test (id Int32) ENGINE=ReplacingMergeTree", + ] + + for sql in test_cases: + assert migrator._format_replicated_sql(sql) == sql + + +def test_format_replicated_sql_replicated(mock_costs, migrator): + # Test that MergeTree engines are converted to Replicated variants + migrator.replicated = True + + test_cases = [ + ( + "CREATE TABLE test (id Int32) ENGINE = MergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree", + ), + ( + "CREATE TABLE test (id Int32) ENGINE = SummingMergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedSummingMergeTree", + ), + ( + "CREATE TABLE test (id Int32) ENGINE=ReplacingMergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedReplacingMergeTree", + ), + # Test with extra whitespace + ( + "CREATE TABLE test (id Int32) ENGINE = MergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree", + ), + # Test with parameters + ( + "CREATE TABLE test (id Int32) ENGINE = MergeTree()", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree()", + ), + ] + + for input_sql, expected_sql in test_cases: + assert migrator._format_replicated_sql(input_sql) == expected_sql + + +def test_format_replicated_sql_non_mergetree(mock_costs, migrator): + # Test that non-MergeTree engines are left unchanged + migrator.replicated = True + + test_cases = [ + "CREATE TABLE test (id Int32) ENGINE = Memory", + "CREATE TABLE test (id Int32) ENGINE = Log", + "CREATE TABLE test (id Int32) ENGINE = TinyLog", + # This should not be changed as it's not a complete word match + "CREATE TABLE test (id Int32) ENGINE = MyMergeTreeCustom", + ] + + for sql in test_cases: + assert migrator._format_replicated_sql(sql) == sql diff --git a/weave-js/package.json b/weave-js/package.json index cb57125143a..db96be60691 100644 --- a/weave-js/package.json +++ b/weave-js/package.json @@ -141,6 +141,7 @@ "unified": "^10.1.0", "unist-util-visit": "3.1.0", "universal-perf-hooks": "^1.0.1", + "uuid": "^11.0.3", "vega": "^5.24.0", "vega-lite": "5.6.0", "vega-tooltip": "^0.28.0", @@ -192,7 +193,6 @@ "@types/react-virtualized-auto-sizer": "^1.0.0", "@types/safe-json-stringify": "^1.1.2", "@types/styled-components": "^5.1.26", - "@types/uuid": "^9.0.1", "@types/wavesurfer.js": "^2.0.0", "@types/zen-observable": "^0.8.3", "@typescript-eslint/eslint-plugin": "5.35.1", @@ -237,7 +237,6 @@ "tslint-config-prettier": "^1.18.0", "tslint-plugin-prettier": "^2.3.0", "typescript": "4.7.4", - "uuid": "^9.0.0", "vite": "5.2.9", "vitest": "^1.6.0" }, diff --git a/weave-js/src/assets/icons/icon-spiral.svg b/weave-js/src/assets/icons/icon-spiral.svg new file mode 100644 index 00000000000..ce5c147b43a --- /dev/null +++ b/weave-js/src/assets/icons/icon-spiral.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/weave-js/src/assets/icons/icon-visible.svg b/weave-js/src/assets/icons/icon-visible.svg new file mode 100644 index 00000000000..823c36bb17f --- /dev/null +++ b/weave-js/src/assets/icons/icon-visible.svg @@ -0,0 +1,4 @@ + + + + diff --git a/weave-js/src/components/ErrorBoundary.tsx b/weave-js/src/components/ErrorBoundary.tsx index 48b4fcf1a1e..053af08aa97 100644 --- a/weave-js/src/components/ErrorBoundary.tsx +++ b/weave-js/src/components/ErrorBoundary.tsx @@ -1,6 +1,7 @@ import {datadogRum} from '@datadog/browser-rum'; import * as Sentry from '@sentry/react'; import React, {Component, ErrorInfo, ReactNode} from 'react'; +import {v7 as uuidv7} from 'uuid'; import {weaveErrorToDDPayload} from '../errors'; import {ErrorPanel} from './ErrorPanel'; @@ -10,24 +11,31 @@ type Props = { }; type State = { - hasError: boolean; + uuid: string | undefined; + timestamp: Date | undefined; + error: Error | undefined; }; export class ErrorBoundary extends Component { - public static getDerivedStateFromError(_: Error): State { - return {hasError: true}; + public static getDerivedStateFromError(error: Error): State { + return {uuid: uuidv7(), timestamp: new Date(), error}; } public state: State = { - hasError: false, + uuid: undefined, + timestamp: undefined, + error: undefined, }; public componentDidCatch(error: Error, errorInfo: ErrorInfo) { + const {uuid} = this.state; datadogRum.addAction( 'weave_panel_error_boundary', - weaveErrorToDDPayload(error) + weaveErrorToDDPayload(error, undefined, uuid) ); - Sentry.captureException(error, { + extra: { + uuid, + }, tags: { weaveErrorBoundary: 'true', }, @@ -35,8 +43,9 @@ export class ErrorBoundary extends Component { } public render() { - if (this.state.hasError) { - return ; + const {uuid, timestamp, error} = this.state; + if (error != null) { + return ; } return this.props.children; diff --git a/weave-js/src/components/ErrorPanel.tsx b/weave-js/src/components/ErrorPanel.tsx index cbcfd244ac9..86945922150 100644 --- a/weave-js/src/components/ErrorPanel.tsx +++ b/weave-js/src/components/ErrorPanel.tsx @@ -1,7 +1,19 @@ -import React, {forwardRef, useEffect, useRef, useState} from 'react'; +import copyToClipboard from 'copy-to-clipboard'; +import _ from 'lodash'; +import React, { + forwardRef, + useCallback, + useEffect, + useRef, + useState, +} from 'react'; import styled from 'styled-components'; +import {toast} from '../common/components/elements/Toast'; import {hexToRGB, MOON_300, MOON_600} from '../common/css/globals.styles'; +import {useViewerInfo} from '../common/hooks/useViewerInfo'; +import {getCookieBool, getFirebaseCookie} from '../common/util/cookie'; +import {Button} from './Button'; import {Icon} from './Icon'; import {Tooltip} from './Tooltip'; @@ -14,6 +26,11 @@ type ErrorPanelProps = { title?: string; subtitle?: string; subtitle2?: string; + + // These props are for error details object + uuid?: string; + timestamp?: Date; + error?: Error; }; export const Centered = styled.div` @@ -85,11 +102,87 @@ export const ErrorPanelSmall = ({ ); }; +const getDateObject = (timestamp?: Date): Record | null => { + if (!timestamp) { + return null; + } + return { + // e.g. "2024-12-12T06:10:19.475Z", + iso: timestamp.toISOString(), + // e.g. "Thursday, December 12, 2024 at 6:10:19 AM Coordinated Universal Time" + long: timestamp.toLocaleString('en-US', { + weekday: 'long', + year: 'numeric', // Full year + month: 'long', // Full month name + day: 'numeric', // Day of the month + hour: 'numeric', // Hour (12-hour or 24-hour depending on locale) + minute: 'numeric', + second: 'numeric', + timeZone: 'UTC', // Ensures it's in UTC + timeZoneName: 'long', // Full time zone name + }), + user: timestamp.toLocaleString('en-US', { + dateStyle: 'full', + timeStyle: 'full', + }), + }; +}; + +const getErrorObject = (error?: Error): Record | null => { + if (!error) { + return null; + } + + // Error object properties are not enumerable so we have to copy them manually + const stack = (error.stack ?? '').split('\n'); + return { + message: error.message, + stack, + }; +}; + export const ErrorPanelLarge = forwardRef( - ({title, subtitle, subtitle2}, ref) => { + ({title, subtitle, subtitle2, uuid, timestamp, error}, ref) => { const titleStr = title ?? DEFAULT_TITLE; const subtitleStr = subtitle ?? DEFAULT_SUBTITLE; const subtitle2Str = subtitle2 ?? DEFAULT_SUBTITLE2; + + const {userInfo} = useViewerInfo(); + + const onClick = useCallback(() => { + const betaVersion = getFirebaseCookie('betaVersion'); + const isUsingAdminPrivileges = getCookieBool('use_admin_privileges'); + const {location, navigator, screen} = window; + const {userAgent, language} = navigator; + const details = { + uuid, + url: location.href, + error: getErrorObject(error), + timestamp_err: getDateObject(timestamp), + timestamp_copied: getDateObject(new Date()), + user: _.pick(userInfo, ['id', 'username']), // Skipping teams and admin + cookies: { + ...(betaVersion && {betaVersion}), + ...(isUsingAdminPrivileges && {use_admin_privileges: true}), + }, + browser: { + userAgent, + language, + screenSize: { + width: screen.width, + height: screen.height, + }, + viewportSize: { + width: window.innerWidth, + height: window.innerHeight, + }, + }, + }; + const detailsText = JSON.stringify(details, null, 2); + copyToClipboard(detailsText); + toast('Copied to clipboard'); + }, [uuid, timestamp, error, userInfo]); + return ( @@ -98,6 +191,14 @@ export const ErrorPanelLarge = forwardRef( {titleStr} {subtitleStr} {subtitle2Str} + ); } diff --git a/weave-js/src/components/Form/TextField.tsx b/weave-js/src/components/Form/TextField.tsx index c40f697ac85..8f5dd1171ff 100644 --- a/weave-js/src/components/Form/TextField.tsx +++ b/weave-js/src/components/Form/TextField.tsx @@ -37,6 +37,7 @@ type TextFieldProps = { dataTest?: string; step?: number; variant?: 'default' | 'ghost'; + isContainerNightAware?: boolean; }; export const TextField = ({ @@ -59,6 +60,7 @@ export const TextField = ({ autoComplete, dataTest, step, + isContainerNightAware, }: TextFieldProps) => { const textFieldSize = size ?? 'medium'; const leftPaddingForIcon = textFieldSize === 'medium' ? 'pl-34' : 'pl-36'; @@ -83,7 +85,6 @@ export const TextField = ({
( export const IconSortDescending = (props: SVGIconProps) => ( ); +export const IconSpiral = (props: SVGIconProps) => ( + +); export const IconSplit = (props: SVGIconProps) => ( ); @@ -1048,6 +1053,9 @@ export const IconVideoPlay = (props: SVGIconProps) => ( export const IconViewGlasses = (props: SVGIconProps) => ( ); +export const IconVisible = (props: SVGIconProps) => ( + +); export const IconWandb = (props: SVGIconProps) => ( ); @@ -1291,6 +1299,7 @@ const ICON_NAME_TO_ICON: Record = { sort: IconSort, 'sort-ascending': IconSortAscending, 'sort-descending': IconSortDescending, + spiral: IconSpiral, split: IconSplit, square: IconSquare, star: IconStar, @@ -1336,6 +1345,7 @@ const ICON_NAME_TO_ICON: Record = { 'vertex-gcp': IconVertexGCP, 'video-play': IconVideoPlay, 'view-glasses': IconViewGlasses, + visible: IconVisible, wandb: IconWandb, warning: IconWarning, 'warning-alt': IconWarningAlt, diff --git a/weave-js/src/components/Icon/index.ts b/weave-js/src/components/Icon/index.ts index 85ea5332649..39c6eed3170 100644 --- a/weave-js/src/components/Icon/index.ts +++ b/weave-js/src/components/Icon/index.ts @@ -211,6 +211,7 @@ export { IconSort, IconSortAscending, IconSortDescending, + IconSpiral, IconSplit, IconSquare, IconStar, @@ -256,6 +257,7 @@ export { IconVertexGCP, IconVideoPlay, IconViewGlasses, + IconVisible, IconWandb, IconWarning, IconWarningAlt, diff --git a/weave-js/src/components/Icon/types.ts b/weave-js/src/components/Icon/types.ts index 87a1207bc85..47f5f357adc 100644 --- a/weave-js/src/components/Icon/types.ts +++ b/weave-js/src/components/Icon/types.ts @@ -210,6 +210,7 @@ export const IconNames = { Sort: 'sort', SortAscending: 'sort-ascending', SortDescending: 'sort-descending', + Spiral: 'spiral', Split: 'split', Square: 'square', Star: 'star', @@ -255,6 +256,7 @@ export const IconNames = { VertexGCP: 'vertex-gcp', VideoPlay: 'video-play', ViewGlasses: 'view-glasses', + Visible: 'visible', Wandb: 'wandb', Warning: 'warning', WarningAlt: 'warning-alt', diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx index 3517f4d3b9c..761fd536930 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx @@ -1,15 +1,5 @@ import {ApolloProvider} from '@apollo/client'; -import {Home} from '@mui/icons-material'; -import { - AppBar, - Box, - Breadcrumbs, - Drawer, - IconButton, - Link as MaterialLink, - Toolbar, - Typography, -} from '@mui/material'; +import {Box, Drawer} from '@mui/material'; import { GridColumnVisibilityModel, GridFilterModel, @@ -21,9 +11,7 @@ import {LicenseInfo} from '@mui/x-license'; import {makeGorillaApolloClient} from '@wandb/weave/apollo'; import {EVALUATE_OP_NAME_POST_PYDANTIC} from '@wandb/weave/components/PagePanelComponents/Home/Browse3/pages/common/heuristics'; import {opVersionKeyToRefUri} from '@wandb/weave/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/utilities'; -import _ from 'lodash'; import React, { - ComponentProps, FC, useCallback, useEffect, @@ -33,7 +21,6 @@ import React, { } from 'react'; import useMousetrap from 'react-hook-mousetrap'; import { - Link as RouterLink, Redirect, Route, Switch, @@ -199,7 +186,6 @@ export const Browse3: FC<{ `/${URL_BROWSE3}`, ]}> @@ -211,7 +197,6 @@ export const Browse3: FC<{ }; const Browse3Mounted: FC<{ - hideHeader?: boolean; headerOffset?: number; navigateAwayFromProject?: () => void; }> = props => { @@ -225,37 +210,6 @@ const Browse3Mounted: FC<{ overflow: 'auto', flexDirection: 'column', }}> - {!props.hideHeader && ( - theme.zIndex.drawer + 1, - height: '60px', - flex: '0 0 auto', - position: 'static', - }}> - - - theme.palette.getContrastText(theme.palette.primary.main), - '&:hover': { - color: theme => - theme.palette.getContrastText(theme.palette.primary.dark), - }, - marginRight: theme => theme.spacing(2), - }}> - - - - - - )} @@ -1050,20 +1004,6 @@ const ComparePageBinding = () => { return ; }; -const AppBarLink = (props: ComponentProps) => ( - theme.palette.getContrastText(theme.palette.primary.main), - '&:hover': { - color: theme => - theme.palette.getContrastText(theme.palette.primary.dark), - }, - }} - {...props} - component={RouterLink} - /> -); - const PlaygroundPageBinding = () => { const params = useParamsDecoded(); return ( @@ -1074,79 +1014,3 @@ const PlaygroundPageBinding = () => { /> ); }; - -const Browse3Breadcrumbs: FC = props => { - const params = useParamsDecoded(); - const query = useURLSearchParamsDict(); - const filePathParts = query.path?.split('/') ?? []; - const refFields = query.extra?.split('/') ?? []; - - return ( - - {params.entity && ( - - {params.entity} - - )} - {params.project && ( - - {params.project} - - )} - {params.tab && ( - - {params.tab} - - )} - {params.itemName && ( - - {params.itemName} - - )} - {params.version && ( - - {params.version} - - )} - {filePathParts.map((part, idx) => ( - - {part} - - ))} - {_.range(0, refFields.length, 2).map(idx => ( - - - theme.palette.getContrastText(theme.palette.primary.main), - }}> - {refFields[idx]} - - - {refFields[idx + 1]} - - - ))} - - ); -}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/HumanAnnotation.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/HumanAnnotation.tsx index 7facffe9556..21dedf10ea0 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/HumanAnnotation.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/HumanAnnotation.tsx @@ -53,17 +53,11 @@ export const HumanAnnotationCell: React.FC = props => { const feedbackSpecRef = props.hfSpec.ref; useEffect(() => { - if (!props.readOnly) { - // We don't need to listen for feedback changes if the cell is editable - // it is being controlled by local state - return; - } return getTsClient().registerOnFeedbackListener( props.callRef, query.refetch ); - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [props.callRef]); + }, [props.callRef, query.refetch, getTsClient]); useEffect(() => { if (foundFeedbackCallRef && props.callRef !== foundFeedbackCallRef) { @@ -183,13 +177,22 @@ const FeedbackComponentSelector: React.FC<{ }) => { const wrappedOnAddFeedback = useCallback( async (value: any) => { + if (value == null || value === foundValue || value === '') { + // Remove from unsaved changes if value is invalid + setUnsavedFeedbackChanges(curr => { + const rest = {...curr}; + delete rest[feedbackSpecRef]; + return rest; + }); + return true; + } setUnsavedFeedbackChanges(curr => ({ ...curr, [feedbackSpecRef]: () => onAddFeedback(value), })); return true; }, - [onAddFeedback, setUnsavedFeedbackChanges, feedbackSpecRef] + [onAddFeedback, setUnsavedFeedbackChanges, feedbackSpecRef, foundValue] ); switch (type) { @@ -346,21 +349,10 @@ export const NumericalFeedbackColumn = ({ focused?: boolean; isInteger?: boolean; }) => { - const debouncedFn = useMemo( - () => - _.debounce((val: number | null) => onAddFeedback?.(val), DEBOUNCE_VAL), - [onAddFeedback] - ); - useEffect(() => { - return () => { - debouncedFn.cancel(); - }; - }, [debouncedFn]); - return ( onAddFeedback?.(value)} min={min} max={max} isInteger={isInteger} @@ -415,7 +407,7 @@ export const TextFeedbackColumn = ({ placeholder="" /> {maxLength && ( -
+
{`Maximum characters: ${maxLength}`}
)} @@ -446,7 +438,7 @@ export const EnumFeedbackColumn = ({ })); return opts; }, [options]); - const [value, setValue] = useState
+ + + ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx index c22df7c63d7..135d297539d 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx @@ -1,9 +1,10 @@ import React, {useState} from 'react'; +import {usePlaygroundContext} from '../PlaygroundPage/PlaygroundContext'; +import {ChoicesDrawer} from './ChoicesDrawer'; import {ChoicesViewCarousel} from './ChoicesViewCarousel'; -import {ChoicesViewLinear} from './ChoicesViewLinear'; import {ChoiceView} from './ChoiceView'; -import {Choice, ChoicesMode} from './types'; +import {Choice} from './types'; type ChoicesViewProps = { choices: Choice[]; @@ -14,32 +15,45 @@ export const ChoicesView = ({ choices, isStructuredOutput, }: ChoicesViewProps) => { - const [mode, setMode] = useState('linear'); + const {setSelectedChoiceIndex: setGlobalSelectedChoiceIndex} = + usePlaygroundContext(); + + const [isDrawerOpen, setIsDrawerOpen] = useState(false); + const [localSelectedChoiceIndex, setLocalSelectedChoiceIndex] = useState(0); + + const handleSetSelectedChoiceIndex = (choiceIndex: number) => { + setLocalSelectedChoiceIndex(choiceIndex); + setGlobalSelectedChoiceIndex(choiceIndex); + }; if (choices.length === 0) { return null; } if (choices.length === 1) { return ( - + ); } return ( <> - {mode === 'linear' && ( - - )} - {mode === 'carousel' && ( - - )} + + ); }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx index f4a52fc6801..b7dc6eb427d 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx @@ -1,62 +1,67 @@ -import React, {useState} from 'react'; +import React from 'react'; import {Button} from '../../../../../Button'; import {ChoiceView} from './ChoiceView'; -import {Choice, ChoicesMode} from './types'; +import {Choice} from './types'; type ChoicesViewCarouselProps = { choices: Choice[]; isStructuredOutput?: boolean; - setMode: React.Dispatch>; + setIsDrawerOpen: React.Dispatch>; + selectedChoiceIndex: number; + setSelectedChoiceIndex: (choiceIndex: number) => void; }; export const ChoicesViewCarousel = ({ choices, isStructuredOutput, - setMode, + setIsDrawerOpen, + selectedChoiceIndex, + setSelectedChoiceIndex, }: ChoicesViewCarouselProps) => { - const [step, setStep] = useState(0); - const onNext = () => { - setStep((step + 1) % choices.length); + setSelectedChoiceIndex((selectedChoiceIndex + 1) % choices.length); }; const onBack = () => { - const newStep = step === 0 ? choices.length - 1 : step - 1; - setStep(newStep); + const newStep = + selectedChoiceIndex === 0 ? choices.length - 1 : selectedChoiceIndex - 1; + setSelectedChoiceIndex(newStep); }; return ( - <> - -
-
+ +
+
-
-
-
- + } + /> ); }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewLinear.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewLinear.tsx deleted file mode 100644 index 92668c6504e..00000000000 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewLinear.tsx +++ /dev/null @@ -1,73 +0,0 @@ -import React from 'react'; - -import {Button} from '../../../../../Button'; -import {ChoiceView} from './ChoiceView'; -import {Choice, ChoicesMode} from './types'; - -type ChoicesViewLinearProps = { - choices: Choice[]; - isStructuredOutput?: boolean; - setMode: React.Dispatch>; -}; - -export const ChoicesViewLinear = ({ - choices, - isStructuredOutput, - setMode, -}: ChoicesViewLinearProps) => { - return ( -
- {choices.map(c => ( -
-
-
- {c.index !== 0 && ( -
-
- -
- ))} -
- ); -}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx index cc1911b60d4..8cec95707fa 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx @@ -15,21 +15,23 @@ type MessagePanelProps = { index: number; message: Message; isStructuredOutput?: boolean; - isChoice?: boolean; + choiceIndex?: number; isNested?: boolean; pendingToolResponseId?: string; + messageHeader?: React.ReactNode; }; export const MessagePanel = ({ index, message, isStructuredOutput, - isChoice, + choiceIndex, isNested, // The id of the tool call response that is pending // If the tool call response is pending, the editor will be shown automatically // and on save the tool call response will be updated and sent to the LLM pendingToolResponseId, + messageHeader, }: MessagePanelProps) => { const [isShowingMore, setIsShowingMore] = useState(false); const [isOverflowing, setIsOverflowing] = useState(false); @@ -45,6 +47,13 @@ export const MessagePanel = ({ } }, [message.content, contentRef?.current?.scrollHeight]); + // Set isShowingMore to true when editor is opened + useEffect(() => { + if (editorHeight !== null) { + setIsShowingMore(true); + } + }, [editorHeight]); + const isUser = message.role === 'user'; const isSystemPrompt = message.role === 'system'; const isTool = message.role === 'tool'; @@ -116,11 +125,12 @@ export const MessagePanel = ({ 'max-h-[400px]': !isShowingMore, 'max-h-full': isShowingMore, })}> + {messageHeader} {isPlayground && editorHeight ? ( ; @@ -18,7 +18,7 @@ export const PlaygroundMessagePanelButtons: React.FC< PlaygroundMessagePanelButtonsProps > = ({ index, - isChoice, + choiceIndex, isTool, hasContent, contentRef, @@ -53,7 +53,7 @@ export const PlaygroundMessagePanelButtons: React.FC< variant="quiet" size="small" startIcon="randomize-reset-reload" - onClick={() => retry?.(index, isChoice)} + onClick={() => retry?.(index, choiceIndex)} tooltip={ !hasContent ? 'We currently do not support retrying functions' @@ -85,8 +85,8 @@ export const PlaygroundMessagePanelButtons: React.FC< size="small" startIcon="delete" onClick={() => { - if (isChoice) { - deleteChoice?.(index); + if (choiceIndex !== undefined) { + deleteChoice?.(index, choiceIndex); } else { deleteMessage?.(index, responseIndexes); } diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/PlaygroundMessagePanelEditor.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/PlaygroundMessagePanelEditor.tsx index 746b033579a..d643f103481 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/PlaygroundMessagePanelEditor.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/PlaygroundMessagePanelEditor.tsx @@ -13,7 +13,7 @@ type PlaygroundMessagePanelEditorProps = { pendingToolResponseId?: string; message: Message; index: number; - isChoice: boolean; + choiceIndex?: number; setEditorHeight: (height: number | null) => void; }; @@ -21,7 +21,7 @@ export const PlaygroundMessagePanelEditor: React.FC< PlaygroundMessagePanelEditorProps > = ({ index, - isChoice, + choiceIndex, setEditorHeight, editorHeight, isNested, @@ -45,10 +45,10 @@ export const PlaygroundMessagePanelEditor: React.FC< }, [initialContent]); const handleSave = () => { - if (isChoice) { - editChoice?.(index, { + if (choiceIndex !== undefined) { + editChoice?.(choiceIndex, { + ...message, content: editedContent, - role: message.role, }); } else { editMessage?.(index, { @@ -68,13 +68,12 @@ export const PlaygroundMessagePanelEditor: React.FC<
setEditedContent(e.target.value)} - autoGrow - maxHeight={160} + startHeight={320} /> {/* 6px vs. 8px to make up for extra padding from textarea field */}
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/types.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/types.ts index b5696055712..3bbe65a5baf 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/types.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/types.ts @@ -113,5 +113,3 @@ export type Chat = { request: ChatRequest | null; result: ChatCompletion | null; }; - -export type ChoicesMode = 'linear' | 'carousel'; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx index 94cd3c17644..b6b6e7c420d 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx @@ -161,8 +161,8 @@ export const PlaygroundChat = ({ addMessage: newMessage => addMessage(idx, newMessage), editChoice: (choiceIndex, newChoice) => editChoice(idx, choiceIndex, newChoice), - retry: (messageIndex: number, isChoice?: boolean) => - handleRetry(idx, messageIndex, isChoice), + retry: (messageIndex: number, choiceIndex?: number) => + handleRetry(idx, messageIndex, choiceIndex), sendMessage: ( role: PlaygroundMessageRole, content: string, @@ -170,6 +170,12 @@ export const PlaygroundChat = ({ ) => { handleSend(role, idx, content, toolCallId); }, + setSelectedChoiceIndex: (choiceIndex: number) => + setPlaygroundStateField( + idx, + 'selectedChoiceIndex', + choiceIndex + ), }}> diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatCompletionFunctions.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatCompletionFunctions.tsx index 10c76fc82d1..c73e5d42919 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatCompletionFunctions.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatCompletionFunctions.tsx @@ -70,7 +70,7 @@ export const useChatCompletionFunctions = ( if (callIndex !== undefined && callIndex !== index) { return state; } - const updatedState = appendChoicesToMessages(state); + const updatedState = appendChoiceToMessages(state); if (updatedState.traceCall?.inputs?.messages) { updatedState.traceCall.inputs.messages.push(newMessage); } @@ -99,14 +99,14 @@ export const useChatCompletionFunctions = ( const handleRetry = async ( callIndex: number, messageIndex: number, - isChoice?: boolean + choiceIndex?: number ) => { try { setIsLoading(true); const updatedStates = playgroundStates.map((state, index) => { if (index === callIndex) { - if (isChoice) { - return appendChoicesToMessages(state); + if (choiceIndex !== undefined) { + return appendChoiceToMessages(state, choiceIndex); } const updatedState = JSON.parse(JSON.stringify(state)); if (updatedState.traceCall?.inputs?.messages) { @@ -203,17 +203,25 @@ const handleUpdateCallWithResponse = ( }; }; -const appendChoicesToMessages = (state: PlaygroundState): PlaygroundState => { +const appendChoiceToMessages = ( + state: PlaygroundState, + choiceIndex?: number +): PlaygroundState => { const updatedState = JSON.parse(JSON.stringify(state)); if ( updatedState.traceCall?.inputs?.messages && updatedState.traceCall.output?.choices ) { - updatedState.traceCall.output.choices.forEach((choice: any) => { - if (choice.message) { - updatedState.traceCall.inputs.messages.push(choice.message); - } - }); + if (choiceIndex !== undefined) { + updatedState.traceCall.inputs.messages.push( + updatedState.traceCall.output.choices[choiceIndex].message + ); + } else { + updatedState.traceCall.inputs.messages.push( + updatedState.traceCall.output.choices[updatedState.selectedChoiceIndex] + .message + ); + } updatedState.traceCall.output.choices = undefined; } return updatedState; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx index e84e2f75d4b..804670a1dc3 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx @@ -114,16 +114,9 @@ export const useChatFunctions = ( newTraceCall?.output && Array.isArray((newTraceCall.output as TraceCallOutput).choices) ) { - // Delete the old choice - (newTraceCall.output as TraceCallOutput).choices!.splice( - choiceIndex, - 1 - ); - - // Add the new choice as a message - newTraceCall.inputs = newTraceCall.inputs ?? {}; - newTraceCall.inputs.messages = newTraceCall.inputs.messages ?? []; - newTraceCall.inputs.messages.push(newChoice); + // Replace the choice + (newTraceCall.output as TraceCallOutput).choices![choiceIndex].message = + newChoice; } return newTraceCall; }); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundContext.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundContext.tsx index a8176292d1a..31369602560 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundContext.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundContext.tsx @@ -10,14 +10,16 @@ export type PlaygroundContextType = { deleteMessage: (messageIndex: number, responseIndexes?: number[]) => void; editChoice: (choiceIndex: number, newChoice: Message) => void; - deleteChoice: (choiceIndex: number) => void; + deleteChoice: (messageIndex: number, choiceIndex: number) => void; - retry: (messageIndex: number, isChoice?: boolean) => void; + retry: (messageIndex: number, choiceIndex?: number) => void; sendMessage: ( role: PlaygroundMessageRole, content: string, toolCallId?: string ) => void; + + setSelectedChoiceIndex: (choiceIndex: number) => void; }; const DEFAULT_CONTEXT: PlaygroundContextType = { @@ -31,6 +33,7 @@ const DEFAULT_CONTEXT: PlaygroundContextType = { retry: () => {}, sendMessage: () => {}, + setSelectedChoiceIndex: () => {}, }; // Create context that can be undefined diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundPage.tsx index c6232631e4e..76d1c6d9e31 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundPage.tsx @@ -7,7 +7,11 @@ import {SimplePageLayoutWithHeader} from '../common/SimplePageLayout'; import {useWFHooks} from '../wfReactInterface/context'; import {PlaygroundChat} from './PlaygroundChat/PlaygroundChat'; import {PlaygroundSettings} from './PlaygroundSettings/PlaygroundSettings'; -import {DEFAULT_SYSTEM_MESSAGE, usePlaygroundState} from './usePlaygroundState'; +import { + DEFAULT_SYSTEM_MESSAGE, + parseTraceCall, + usePlaygroundState, +} from './usePlaygroundState'; export type PlaygroundPageProps = { entity: string; @@ -89,7 +93,10 @@ export const PlaygroundPageInner = (props: PlaygroundPageProps) => { for (const [idx, state] of newStates.entries()) { for (const c of calls || []) { if (state.traceCall.id === c.callId) { - newStates[idx] = {...state, traceCall: c.traceCall || {}}; + newStates[idx] = { + ...state, + traceCall: parseTraceCall(c.traceCall || {}), + }; break; } } diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/PlaygroundSettings.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/PlaygroundSettings.tsx index f0106505518..96e8ef02627 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/PlaygroundSettings.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/PlaygroundSettings.tsx @@ -62,6 +62,12 @@ export const PlaygroundSettings: React.FC = ({ gap: '4px', mt: 2, }}> + + setPlaygroundStateField(idx, 'responseFormat', value) + } + /> = ({ } /> - - setPlaygroundStateField(idx, 'responseFormat', value) + + setPlaygroundStateField(idx, 'stopSequences', value) } /> + {/* TODO: N times to run is not supported for all models */} + {/* TODO: rerun if this is not supported in the backend */} - setPlaygroundStateField(idx, 'temperature', value) + setPlaygroundStateField(idx, 'nTimes', value) } - label="Temperature" - value={playgroundState.temperature} + label="Completion iterations" + value={playgroundState.nTimes} /> = ({ value={playgroundState.maxTokens} /> - - setPlaygroundStateField(idx, 'stopSequences', value) + + setPlaygroundStateField(idx, 'temperature', value) } + label="Temperature" + value={playgroundState.temperature} /> = ({ label="Presence penalty" value={playgroundState.presencePenalty} /> + & { autoGrow?: boolean; maxHeight?: string | number; + startHeight?: string | number; reset?: boolean; }; export const StyledTextArea = forwardRef( - ({className, autoGrow, maxHeight, reset, ...props}, ref) => { + ({className, autoGrow, maxHeight, startHeight, reset, ...props}, ref) => { const textareaRef = React.useRef(null); React.useEffect(() => { @@ -26,11 +27,22 @@ export const StyledTextArea = forwardRef( return; } - // Disable resize when autoGrow is true - textareaElement.style.resize = 'none'; + // Only disable resize when autoGrow is true + textareaElement.style.resize = autoGrow ? 'none' : 'vertical'; + + // Set initial height if provided + if (startHeight && textareaElement.value === '') { + textareaElement.style.height = + typeof startHeight === 'number' ? `${startHeight}px` : startHeight; + return; + } if (reset || textareaElement.value === '') { - textareaElement.style.height = 'auto'; + textareaElement.style.height = startHeight + ? typeof startHeight === 'number' + ? `${startHeight}px` + : startHeight + : 'auto'; return; } @@ -63,7 +75,7 @@ export const StyledTextArea = forwardRef( return () => textareaRefElement.removeEventListener('input', adjustHeight); - }, [autoGrow, maxHeight, reset]); + }, [autoGrow, maxHeight, reset, startHeight]); return ( @@ -86,6 +98,7 @@ export const StyledTextArea = forwardRef( 'focus:outline-none', 'relative bottom-0 top-0 items-center rounded-sm', 'outline outline-1 outline-moon-250', + !autoGrow && 'resize-y', props.disabled ? 'opacity-50' : 'hover:outline hover:outline-2 hover:outline-teal-500/40 focus:outline-2', @@ -94,6 +107,14 @@ export const StyledTextArea = forwardRef( 'placeholder-moon-500 dark:placeholder-moon-600', className )} + style={{ + height: startHeight + ? typeof startHeight === 'number' + ? `${startHeight}px` + : startHeight + : undefined, + ...props.style, + }} {...props} /> diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/types.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/types.ts index fa73e87bf45..dc0e0a370e0 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/types.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/types.ts @@ -20,9 +20,10 @@ export type PlaygroundState = { topP: number; frequencyPenalty: number; presencePenalty: number; - // nTimes: number; + nTimes: number; maxTokensLimit: number; model: LLMMaxTokensKey; + selectedChoiceIndex: number; }; export type PlaygroundStateKey = keyof PlaygroundState; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts index 8c556edaef2..cbcc7c52fb8 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts @@ -1,5 +1,11 @@ +import {cloneDeep} from 'lodash'; import {SetStateAction, useCallback, useState} from 'react'; +import { + anthropicContentBlocksToChoices, + hasStringProp, + isAnthropicCompletionFormat, +} from '../ChatView/hooks'; import {LLM_MAX_TOKENS_KEYS, LLMMaxTokensKey} from './llmMaxTokens'; import { OptionalTraceCallSchema, @@ -34,9 +40,10 @@ const DEFAULT_PLAYGROUND_STATE = { topP: 1, frequencyPenalty: 0, presencePenalty: 0, - // nTimes: 1, + nTimes: 1, maxTokensLimit: 16384, model: DEFAULT_MODEL, + selectedChoiceIndex: 0, }; export const usePlaygroundState = () => { @@ -76,7 +83,7 @@ export const usePlaygroundState = () => { setPlaygroundStates(prevState => { const newState = {...prevState[0]}; - newState.traceCall = traceCall; + newState.traceCall = parseTraceCall(traceCall); if (!inputs) { return [newState]; @@ -90,9 +97,9 @@ export const usePlaygroundState = () => { } } } - // if (inputs.n) { - // newState.nTimes = parseInt(inputs.n, 10); - // } + if (inputs.n) { + newState.nTimes = parseInt(inputs.n, 10); + } if (inputs.temperature) { newState.temperature = parseFloat(inputs.temperature); } @@ -147,10 +154,42 @@ export const getInputFromPlaygroundState = (state: PlaygroundState) => { top_p: state.topP, frequency_penalty: state.frequencyPenalty, presence_penalty: state.presencePenalty, - // n: state.nTimes, + n: state.nTimes, response_format: { type: state.responseFormat, }, tools: tools.length > 0 ? tools : undefined, }; }; + +// This is a helper function to parse the trace call output for anthropic +// so that the playground can display the choices +export const parseTraceCall = (traceCall: OptionalTraceCallSchema) => { + const parsedTraceCall = cloneDeep(traceCall); + + // Handles anthropic outputs + // Anthropic has content and stop_reason as top-level fields + if (isAnthropicCompletionFormat(parsedTraceCall.output)) { + const {content, stop_reason, ...outputs} = parsedTraceCall.output as any; + parsedTraceCall.output = { + ...outputs, + choices: anthropicContentBlocksToChoices(content, stop_reason), + }; + } + // Handles anthropic inputs + // Anthropic has system message as a top-level request field + if (hasStringProp(parsedTraceCall.inputs, 'system')) { + const {messages, system, ...inputs} = parsedTraceCall.inputs as any; + parsedTraceCall.inputs = { + ...inputs, + messages: [ + { + role: 'system', + content: system, + }, + ...messages, + ], + }; + } + return parsedTraceCall; +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx index 9acbdfe6c2f..a478437facb 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx @@ -1,5 +1,5 @@ import {Box} from '@material-ui/core'; -import React, {FC, useCallback, useState} from 'react'; +import React, {FC, useCallback, useEffect, useState} from 'react'; import {z} from 'zod'; import {createBaseObjectInstance} from '../wfReactInterface/baseObjectClassQuery'; @@ -28,7 +28,7 @@ const AnnotationScorerFormSchema = z.object({ }), z.object({ type: z.literal('String'), - 'Max length': z.number().optional(), + 'Maximum length': z.number().optional(), }), z.object({ type: z.literal('Select'), @@ -45,6 +45,9 @@ export const AnnotationScorerForm: FC< ScorerFormProps> > = ({data, onDataChange}) => { const [config, setConfig] = useState(data ?? DEFAULT_STATE); + useEffect(() => { + setConfig(data ?? DEFAULT_STATE); + }, [data]); const [isValid, setIsValid] = useState(false); const handleConfigChange = useCallback( @@ -113,7 +116,7 @@ function convertTypeExtrasToJsonSchema( const typeSchema = obj.Type; const typeExtras: Record = {}; if (typeSchema.type === 'String') { - typeExtras.maxLength = typeSchema['Max length']; + typeExtras.maxLength = typeSchema['Maximum length']; } else if (typeSchema.type === 'Integer' || typeSchema.type === 'Number') { typeExtras.minimum = typeSchema.Minimum; typeExtras.maximum = typeSchema.Maximum; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/FormComponents.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/FormComponents.tsx index 2716bfbfa81..250c896cfea 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/FormComponents.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/FormComponents.tsx @@ -3,7 +3,7 @@ import {Select} from '@wandb/weave/components/Form/Select'; import {TextField} from '@wandb/weave/components/Form/TextField'; import React from 'react'; -export const GAP_BETWEEN_ITEMS_PX = 10; +export const GAP_BETWEEN_ITEMS_PX = 16; export const GAP_BETWEEN_LABEL_AND_FIELD_PX = 10; type AutocompleteWithLabelType - ))} - + ); }; @@ -758,6 +779,7 @@ const LiteralField: React.FC<{ setConfig, }) => { const literalValue = unwrappedSchema.value; + const isOptional = fieldSchema instanceof z.ZodOptional; useEffect(() => { if (value !== literalValue) { @@ -765,7 +787,14 @@ const LiteralField: React.FC<{ } }, [value, literalValue, targetPath, config, setConfig]); - return ; + return ( + + ); }; const BooleanField: React.FC<{ diff --git a/weave-js/src/errors.ts b/weave-js/src/errors.ts index daa3e0c97a0..f2582baade8 100644 --- a/weave-js/src/errors.ts +++ b/weave-js/src/errors.ts @@ -37,7 +37,8 @@ type DDErrorPayload = { export const weaveErrorToDDPayload = ( error: Error, - weave?: WeaveApp + weave?: WeaveApp, + uuid?: string ): DDErrorPayload => { try { return { @@ -49,6 +50,7 @@ export const weaveErrorToDDPayload = ( windowLocationURL: trimString(window.location.href), weaveContext: weave?.client.debugMeta(), isServerError: error instanceof UseNodeValueServerExecutionError, + ...(uuid != null && {uuid}), }; } catch (e) { // If we fail to serialize the error, just return an empty object. diff --git a/weave-js/yarn.lock b/weave-js/yarn.lock index c7f9379e32a..6a5ec14e872 100644 --- a/weave-js/yarn.lock +++ b/weave-js/yarn.lock @@ -4776,11 +4776,6 @@ resolved "https://registry.yarnpkg.com/@types/unist/-/unist-2.0.7.tgz#5b06ad6894b236a1d2bd6b2f07850ca5c59cf4d6" integrity sha512-cputDpIbFgLUaGQn6Vqg3/YsJwxUwHLO13v3i5ouxT4lat0khip9AEWxtERujXV9wxIB1EyF97BSJFt6vpdI8g== -"@types/uuid@^9.0.1": - version "9.0.2" - resolved "https://registry.yarnpkg.com/@types/uuid/-/uuid-9.0.2.tgz#ede1d1b1e451548d44919dc226253e32a6952c4b" - integrity sha512-kNnC1GFBLuhImSnV7w4njQkUiJi0ZXUycu1rUaouPqiKlXkh77JKgdRnTAp1x5eBwcIwbtI+3otwzuIDEuDoxQ== - "@types/wavesurfer.js@^2.0.0": version "2.0.2" resolved "https://registry.yarnpkg.com/@types/wavesurfer.js/-/wavesurfer.js-2.0.2.tgz#b98a4d57ca24ee2028ae6dd5c2208b568bb73842" @@ -15032,6 +15027,11 @@ util-deprecate@^1.0.1, util-deprecate@^1.0.2, util-deprecate@~1.0.1: resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf" integrity sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw== +uuid@^11.0.3: + version "11.0.3" + resolved "https://registry.yarnpkg.com/uuid/-/uuid-11.0.3.tgz#248451cac9d1a4a4128033e765d137e2b2c49a3d" + integrity sha512-d0z310fCWv5dJwnX1Y/MncBAqGMKEzlBb1AOf7z9K8ALnd0utBX/msg/fA0+sbyN1ihbMsLhrBlnl1ak7Wa0rg== + uuid@^2.0.2: version "2.0.3" resolved "https://registry.yarnpkg.com/uuid/-/uuid-2.0.3.tgz#67e2e863797215530dff318e5bf9dcebfd47b21a" @@ -15042,11 +15042,6 @@ uuid@^3.0.0, uuid@^3.4.0: resolved "https://registry.yarnpkg.com/uuid/-/uuid-3.4.0.tgz#b23e4358afa8a202fe7a100af1f5f883f02007ee" integrity sha512-HjSDRw6gZE5JMggctHBcjVak08+KEVhSIiDzFnT9S9aegmp85S/bReBVTb4QTFaRNptJ9kuYaNhnbNEOkbKb/A== -uuid@^9.0.0: - version "9.0.0" - resolved "https://registry.yarnpkg.com/uuid/-/uuid-9.0.0.tgz#592f550650024a38ceb0c562f2f6aa435761efb5" - integrity sha512-MXcSTerfPa4uqyzStbRoTgt5XIe3x5+42+q1sDuy3R5MDk66URdLMOZe5aPX/SQd+kuYAh0FdP/pO28IkQyTeg== - uvu@^0.5.0: version "0.5.6" resolved "https://registry.yarnpkg.com/uvu/-/uvu-0.5.6.tgz#2754ca20bcb0bb59b64e9985e84d2e81058502df" diff --git a/weave/integrations/openai/openai_sdk.py b/weave/integrations/openai/openai_sdk.py index 7814700d4d3..a1e3a9b5831 100644 --- a/weave/integrations/openai/openai_sdk.py +++ b/weave/integrations/openai/openai_sdk.py @@ -1,15 +1,20 @@ +from __future__ import annotations + import importlib from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op import Op, ProcessedInputs from weave.trace.op_extensions.accumulator import add_accumulator -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher if TYPE_CHECKING: from openai.types.chat import ChatCompletionChunk +_openai_patcher: MultiPatcher | None = None + def maybe_unwrap_api_response(value: Any) -> Any: """If the caller requests a raw response, we unwrap the APIResponse object. @@ -43,9 +48,7 @@ def maybe_unwrap_api_response(value: Any) -> Any: return value -def openai_on_finish_post_processor( - value: Optional["ChatCompletionChunk"], -) -> Optional[dict]: +def openai_on_finish_post_processor(value: ChatCompletionChunk | None) -> dict | None: from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat.chat_completion_chunk import ( ChoiceDeltaFunctionCall, @@ -60,8 +63,8 @@ def openai_on_finish_post_processor( value = maybe_unwrap_api_response(value) def _get_function_call( - function_call: Optional[ChoiceDeltaFunctionCall], - ) -> Optional[FunctionCall]: + function_call: ChoiceDeltaFunctionCall | None, + ) -> FunctionCall | None: if function_call is None: return function_call if isinstance(function_call, ChoiceDeltaFunctionCall): @@ -73,8 +76,8 @@ def _get_function_call( return None def _get_tool_calls( - tool_calls: Optional[list[ChoiceDeltaToolCall]], - ) -> Optional[list[ChatCompletionMessageToolCall]]: + tool_calls: list[ChoiceDeltaToolCall] | None, + ) -> list[ChatCompletionMessageToolCall] | None: if tool_calls is None: return tool_calls @@ -128,10 +131,10 @@ def _get_tool_calls( def openai_accumulator( - acc: Optional["ChatCompletionChunk"], - value: "ChatCompletionChunk", + acc: ChatCompletionChunk | None, + value: ChatCompletionChunk, skip_last: bool = False, -) -> "ChatCompletionChunk": +) -> ChatCompletionChunk: from openai.types.chat import ChatCompletionChunk from openai.types.chat.chat_completion_chunk import ( ChoiceDeltaFunctionCall, @@ -285,7 +288,7 @@ def should_use_accumulator(inputs: dict) -> bool: def openai_on_input_handler( func: Op, args: tuple, kwargs: dict -) -> Optional[ProcessedInputs]: +) -> ProcessedInputs | None: if len(args) == 2 and isinstance(args[1], weave.EasyPrompt): original_args = args original_kwargs = kwargs @@ -305,20 +308,16 @@ def openai_on_input_handler( return None -def create_wrapper_sync( - name: str, -) -> Callable[[Callable], Callable]: +def create_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: "We need to do this so we can check if `stream` is used" def _add_stream_options(fn: Callable) -> Callable: @wraps(fn) def _wrapper(*args: Any, **kwargs: Any) -> Any: - if bool(kwargs.get("stream")) and kwargs.get("stream_options") is None: + if kwargs.get("stream") and kwargs.get("stream_options") is None: kwargs["stream_options"] = {"include_usage": True} - return fn( - *args, **kwargs - ) # This is where the final execution of fn is happening. + return fn(*args, **kwargs) return _wrapper @@ -327,8 +326,8 @@ def _openai_stream_options_is_set(inputs: dict) -> bool: return True return False - op = weave.op()(_add_stream_options(fn)) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(_add_stream_options(fn), **op_kwargs) op._set_on_input_handler(openai_on_input_handler) return add_accumulator( op, # type: ignore @@ -345,16 +344,14 @@ def _openai_stream_options_is_set(inputs: dict) -> bool: # Surprisingly, the async `client.chat.completions.create` does not pass # `inspect.iscoroutinefunction`, so we can't dispatch on it and must write # it manually here... -def create_wrapper_async( - name: str, -) -> Callable[[Callable], Callable]: +def create_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: "We need to do this so we can check if `stream` is used" def _add_stream_options(fn: Callable) -> Callable: @wraps(fn) async def _wrapper(*args: Any, **kwargs: Any) -> Any: - if bool(kwargs.get("stream")) and kwargs.get("stream_options") is None: + if kwargs.get("stream") and kwargs.get("stream_options") is None: kwargs["stream_options"] = {"include_usage": True} return await fn(*args, **kwargs) @@ -365,8 +362,8 @@ def _openai_stream_options_is_set(inputs: dict) -> bool: return True return False - op = weave.op()(_add_stream_options(fn)) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(_add_stream_options(fn), **op_kwargs) op._set_on_input_handler(openai_on_input_handler) return add_accumulator( op, # type: ignore @@ -380,28 +377,61 @@ def _openai_stream_options_is_set(inputs: dict) -> bool: return wrapper -symbol_patchers = [ - # Patch the Completions.create method - SymbolPatcher( - lambda: importlib.import_module("openai.resources.chat.completions"), - "Completions.create", - create_wrapper_sync(name="openai.chat.completions.create"), - ), - SymbolPatcher( - lambda: importlib.import_module("openai.resources.chat.completions"), - "AsyncCompletions.create", - create_wrapper_async(name="openai.chat.completions.create"), - ), - SymbolPatcher( - lambda: importlib.import_module("openai.resources.beta.chat.completions"), - "Completions.parse", - create_wrapper_sync(name="openai.beta.chat.completions.parse"), - ), - SymbolPatcher( - lambda: importlib.import_module("openai.resources.beta.chat.completions"), - "AsyncCompletions.parse", - create_wrapper_async(name="openai.beta.chat.completions.parse"), - ), -] - -openai_patcher = MultiPatcher(symbol_patchers) # type: ignore +def get_openai_patcher( + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + + global _openai_patcher + if _openai_patcher is not None: + return _openai_patcher + + base = settings.op_settings + + completions_create_settings = base.model_copy( + update={"name": base.name or "openai.chat.completions.create"} + ) + async_completions_create_settings = base.model_copy( + update={"name": base.name or "openai.chat.completions.create"} + ) + completions_parse_settings = base.model_copy( + update={"name": base.name or "openai.beta.chat.completions.parse"} + ) + async_completions_parse_settings = base.model_copy( + update={"name": base.name or "openai.beta.chat.completions.parse"} + ) + + _openai_patcher = MultiPatcher( + [ + SymbolPatcher( + lambda: importlib.import_module("openai.resources.chat.completions"), + "Completions.create", + create_wrapper_sync(settings=completions_create_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("openai.resources.chat.completions"), + "AsyncCompletions.create", + create_wrapper_async(settings=async_completions_create_settings), + ), + SymbolPatcher( + lambda: importlib.import_module( + "openai.resources.beta.chat.completions" + ), + "Completions.parse", + create_wrapper_sync(settings=completions_parse_settings), + ), + SymbolPatcher( + lambda: importlib.import_module( + "openai.resources.beta.chat.completions" + ), + "AsyncCompletions.parse", + create_wrapper_async(settings=async_completions_parse_settings), + ), + ] + ) + + return _openai_patcher diff --git a/weave/scorers/base_scorer.py b/weave/scorers/base_scorer.py index 5a19adcd04f..4ac27f1a76b 100644 --- a/weave/scorers/base_scorer.py +++ b/weave/scorers/base_scorer.py @@ -1,4 +1,5 @@ import inspect +import textwrap from collections.abc import Sequence from numbers import Number from typing import Any, Callable, Optional, Union @@ -45,7 +46,13 @@ def _validate_scorer_signature(scorer: Union[Callable, Op, Scorer]) -> bool: params = inspect.signature(scorer).parameters if "output" in params and "model_output" in params: raise ValueError( - "Both 'output' and 'model_output' cannot be in the scorer signature; prefer just using `output`." + textwrap.dedent( + """ + The scorer signature cannot include both `output` and `model_output` at the same time. + + To resolve, rename one of the arguments to avoid conflict. Prefer using `output` as the model's output. + """ + ) ) return True diff --git a/weave/scorers/llm_utils.py b/weave/scorers/llm_utils.py index 68ae2ccb366..eef6f018b0f 100644 --- a/weave/scorers/llm_utils.py +++ b/weave/scorers/llm_utils.py @@ -2,10 +2,6 @@ from typing import TYPE_CHECKING, Any, Union -from weave.trace.autopatch import autopatch - -autopatch() # ensure both weave patching and instructor patching are applied - OPENAI_DEFAULT_MODEL = "gpt-4o" OPENAI_DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small" OPENAI_DEFAULT_MODERATION_MODEL = "text-moderation-latest" diff --git a/weave/trace/api.py b/weave/trace/api.py index ee8131b0875..294308cbb67 100644 --- a/weave/trace/api.py +++ b/weave/trace/api.py @@ -13,6 +13,7 @@ # There is probably a better place for this, but including here for now to get the fix in. from weave import type_handlers # noqa: F401 from weave.trace import urls, util, weave_client, weave_init +from weave.trace.autopatch import AutopatchSettings from weave.trace.constants import TRACE_OBJECT_EMOJI from weave.trace.context import call_context from weave.trace.context import weave_client_context as weave_client_context @@ -32,6 +33,7 @@ def init( project_name: str, *, settings: UserSettings | dict[str, Any] | None = None, + autopatch_settings: AutopatchSettings | None = None, ) -> weave_client.WeaveClient: """Initialize weave tracking, logging to a wandb project. @@ -52,7 +54,12 @@ def init( if should_disable_weave(): return weave_init.init_weave_disabled().client - return weave_init.init_weave(project_name).client + initialized_client = weave_init.init_weave( + project_name, + autopatch_settings=autopatch_settings, + ) + + return initialized_client.client @contextlib.contextmanager diff --git a/weave/trace/autopatch.py b/weave/trace/autopatch.py index 3a5dca14556..0619194a224 100644 --- a/weave/trace/autopatch.py +++ b/weave/trace/autopatch.py @@ -4,8 +4,54 @@ check if libraries are installed and imported and patch in the case that they are. """ +from typing import Any, Callable, Optional, Union -def autopatch() -> None: +from pydantic import BaseModel, Field, validate_call + +from weave.trace.weave_client import Call + + +class OpSettings(BaseModel): + """Op settings for a specific integration. + These currently subset the `op` decorator args to provide a consistent interface + when working with auto-patched functions. See the `op` decorator for more details.""" + + name: Optional[str] = None + call_display_name: Optional[Union[str, Callable[[Call], str]]] = None + postprocess_inputs: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None + postprocess_output: Optional[Callable[[Any], Any]] = None + + +class IntegrationSettings(BaseModel): + """Configuration for a specific integration.""" + + enabled: bool = True + op_settings: OpSettings = Field(default_factory=OpSettings) + + +class AutopatchSettings(BaseModel): + """Settings for auto-patching integrations.""" + + # These will be uncommented as we add support for more integrations. Note that + + # anthropic: IntegrationSettings = Field(default_factory=IntegrationSettings) + # cerebras: IntegrationSettings = Field(default_factory=IntegrationSettings) + # cohere: IntegrationSettings = Field(default_factory=IntegrationSettings) + # dspy: IntegrationSettings = Field(default_factory=IntegrationSettings) + # google_ai_studio: IntegrationSettings = Field(default_factory=IntegrationSettings) + # groq: IntegrationSettings = Field(default_factory=IntegrationSettings) + # instructor: IntegrationSettings = Field(default_factory=IntegrationSettings) + # langchain: IntegrationSettings = Field(default_factory=IntegrationSettings) + # litellm: IntegrationSettings = Field(default_factory=IntegrationSettings) + # llamaindex: IntegrationSettings = Field(default_factory=IntegrationSettings) + # mistral: IntegrationSettings = Field(default_factory=IntegrationSettings) + # notdiamond: IntegrationSettings = Field(default_factory=IntegrationSettings) + openai: IntegrationSettings = Field(default_factory=IntegrationSettings) + # vertexai: IntegrationSettings = Field(default_factory=IntegrationSettings) + + +@validate_call +def autopatch(settings: Optional[AutopatchSettings] = None) -> None: from weave.integrations.anthropic.anthropic_sdk import anthropic_patcher from weave.integrations.cerebras.cerebras_sdk import cerebras_patcher from weave.integrations.cohere.cohere_sdk import cohere_patcher @@ -20,10 +66,13 @@ def autopatch() -> None: from weave.integrations.llamaindex.llamaindex import llamaindex_patcher from weave.integrations.mistral import mistral_patcher from weave.integrations.notdiamond.tracing import notdiamond_patcher - from weave.integrations.openai.openai_sdk import openai_patcher + from weave.integrations.openai.openai_sdk import get_openai_patcher from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher - openai_patcher.attempt_patch() + if settings is None: + settings = AutopatchSettings() + + get_openai_patcher(settings.openai).attempt_patch() mistral_patcher.attempt_patch() litellm_patcher.attempt_patch() llamaindex_patcher.attempt_patch() @@ -54,10 +103,10 @@ def reset_autopatch() -> None: from weave.integrations.llamaindex.llamaindex import llamaindex_patcher from weave.integrations.mistral import mistral_patcher from weave.integrations.notdiamond.tracing import notdiamond_patcher - from weave.integrations.openai.openai_sdk import openai_patcher + from weave.integrations.openai.openai_sdk import get_openai_patcher from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher - openai_patcher.undo_patch() + get_openai_patcher().undo_patch() mistral_patcher.undo_patch() litellm_patcher.undo_patch() llamaindex_patcher.undo_patch() diff --git a/weave/trace/patcher.py b/weave/trace/patcher.py index 1567c4e2bb9..c1d0d653ffa 100644 --- a/weave/trace/patcher.py +++ b/weave/trace/patcher.py @@ -17,6 +17,14 @@ def undo_patch(self) -> bool: raise NotImplementedError() +class NoOpPatcher(Patcher): + def attempt_patch(self) -> bool: + return True + + def undo_patch(self) -> bool: + return True + + class MultiPatcher(Patcher): def __init__(self, patchers: Sequence[Patcher]) -> None: self.patchers = patchers diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index 0eca3fcbedb..1d5d54b9b23 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -10,7 +10,7 @@ from collections.abc import Iterator, Sequence from concurrent.futures import Future from functools import lru_cache -from typing import Any, Callable, cast +from typing import Any, Callable, Generic, Protocol, TypeVar, cast, overload import pydantic from requests import HTTPError @@ -90,6 +90,128 @@ logger = logging.getLogger(__name__) +T = TypeVar("T") +R = TypeVar("R", covariant=True) + + +class FetchFunc(Protocol[T]): + def __call__(self, offset: int, limit: int) -> list[T]: ... + + +TransformFunc = Callable[[T], R] + + +class PaginatedIterator(Generic[T, R]): + """An iterator that fetches pages of items from a server and optionally transforms them + into a more user-friendly type.""" + + def __init__( + self, + fetch_func: FetchFunc[T], + page_size: int = 1000, + transform_func: TransformFunc[T, R] | None = None, + ) -> None: + self.fetch_func = fetch_func + self.page_size = page_size + self.transform_func = transform_func + + if page_size <= 0: + raise ValueError("page_size must be greater than 0") + + @lru_cache + def _fetch_page(self, index: int) -> list[T]: + return self.fetch_func(index * self.page_size, self.page_size) + + @overload + def _get_one(self: PaginatedIterator[T, T], index: int) -> T: ... + @overload + def _get_one(self: PaginatedIterator[T, R], index: int) -> R: ... + def _get_one(self, index: int) -> T | R: + if index < 0: + raise IndexError("Negative indexing not supported") + + page_index = index // self.page_size + page_offset = index % self.page_size + + page = self._fetch_page(page_index) + if page_offset >= len(page): + raise IndexError(f"Index {index} out of range") + + res = page[page_offset] + if transform := self.transform_func: + return transform(res) + return res + + @overload + def _get_slice(self: PaginatedIterator[T, T], key: slice) -> Iterator[T]: ... + @overload + def _get_slice(self: PaginatedIterator[T, R], key: slice) -> Iterator[R]: ... + def _get_slice(self, key: slice) -> Iterator[T] | Iterator[R]: + if (start := key.start or 0) < 0: + raise ValueError("Negative start not supported") + if (stop := key.stop) is not None and stop < 0: + raise ValueError("Negative stop not supported") + if (step := key.step or 1) < 0: + raise ValueError("Negative step not supported") + + i = start + while stop is None or i < stop: + try: + yield self._get_one(i) + except IndexError: + break + i += step + + @overload + def __getitem__(self: PaginatedIterator[T, T], key: int) -> T: ... + @overload + def __getitem__(self: PaginatedIterator[T, R], key: int) -> R: ... + @overload + def __getitem__(self: PaginatedIterator[T, T], key: slice) -> list[T]: ... + @overload + def __getitem__(self: PaginatedIterator[T, R], key: slice) -> list[R]: ... + def __getitem__(self, key: slice | int) -> T | R | list[T] | list[R]: + if isinstance(key, slice): + return list(self._get_slice(key)) + return self._get_one(key) + + @overload + def __iter__(self: PaginatedIterator[T, T]) -> Iterator[T]: ... + @overload + def __iter__(self: PaginatedIterator[T, R]) -> Iterator[R]: ... + def __iter__(self) -> Iterator[T] | Iterator[R]: + return self._get_slice(slice(0, None, 1)) + + +# TODO: should be Call, not WeaveObject +CallsIter = PaginatedIterator[CallSchema, WeaveObject] + + +def _make_calls_iterator( + server: TraceServerInterface, + project_id: str, + filter: CallsFilter, + include_costs: bool = False, +) -> CallsIter: + def fetch_func(offset: int, limit: int) -> list[CallSchema]: + response = server.calls_query( + CallsQueryReq( + project_id=project_id, + filter=filter, + offset=offset, + limit=limit, + include_costs=include_costs, + ) + ) + return response.calls + + # TODO: Should be Call, not WeaveObject + def transform_func(call: CallSchema) -> WeaveObject: + entity, project = project_id.split("/") + return make_client_call(entity, project, call, server) + + return PaginatedIterator(fetch_func, transform_func=transform_func) + class OpNameError(ValueError): """Raised when an op name is invalid.""" @@ -284,7 +406,7 @@ def children(self) -> CallsIter: ) client = weave_client_context.require_weave_client() - return CallsIter( + return _make_calls_iterator( client.server, self.project_id, CallsFilter(parent_ids=[self.id]), @@ -362,80 +484,6 @@ def _apply_scorer(self, scorer_op: Op) -> None: ) -class CallsIter: - server: TraceServerInterface - filter: CallsFilter - include_costs: bool - - def __init__( - self, - server: TraceServerInterface, - project_id: str, - filter: CallsFilter, - include_costs: bool = False, - ) -> None: - self.server = server - self.project_id = project_id - self.filter = filter - self._page_size = 1000 - self.include_costs = include_costs - - # seems like this caching should be on the server, but it's here for now... - @lru_cache - def _fetch_page(self, index: int) -> list[CallSchema]: - # caching in here means that any other CallsIter objects would also - # benefit from the cache - response = self.server.calls_query( - CallsQueryReq( - project_id=self.project_id, - filter=self.filter, - offset=index * self._page_size, - limit=self._page_size, - include_costs=self.include_costs, - ) - ) - return response.calls - - def _get_one(self, index: int) -> WeaveObject: - if index < 0: - raise IndexError("Negative indexing not supported") - - page_index = index // self._page_size - page_offset = index % self._page_size - - calls = self._fetch_page(page_index) - if page_offset >= len(calls): - raise IndexError(f"Index {index} out of range") - - call = calls[page_offset] - entity, project = self.project_id.split("/") - return make_client_call(entity, project, call, self.server) - - def _get_slice(self, key: slice) -> Iterator[WeaveObject]: - if (start := key.start or 0) < 0: - raise ValueError("Negative start not supported") - if (stop := key.stop) is not None and stop < 0: - raise ValueError("Negative stop not supported") - if (step := key.step or 1) < 0: - raise ValueError("Negative step not supported") - - i = start - while stop is None or i < stop: - try: - yield self._get_one(i) - except IndexError: - break - i += step - - def __getitem__(self, key: slice | int) -> WeaveObject | list[WeaveObject]: - if isinstance(key, slice): - return list(self._get_slice(key)) - return self._get_one(key) - - def __iter__(self) -> Iterator[WeaveObject]: - return self._get_slice(slice(0, None, 1)) - - def make_client_call( entity: str, project: str, server_call: CallSchema, server: TraceServerInterface ) -> WeaveObject: @@ -642,8 +690,11 @@ def get_calls( if filter is None: filter = CallsFilter() - return CallsIter( - self.server, self._project_id(), filter, include_costs or False + return _make_calls_iterator( + self.server, + self._project_id(), + filter, + include_costs, ) @deprecated(new_name="get_calls") diff --git a/weave/trace/weave_init.py b/weave/trace/weave_init.py index 563dcbdaed4..f51d42d5018 100644 --- a/weave/trace/weave_init.py +++ b/weave/trace/weave_init.py @@ -63,7 +63,9 @@ def get_entity_project_from_project_name(project_name: str) -> tuple[str, str]: def init_weave( - project_name: str, ensure_project_exists: bool = True + project_name: str, + ensure_project_exists: bool = True, + autopatch_settings: autopatch.AutopatchSettings | None = None, ) -> InitializedClient: global _current_inited_client if _current_inited_client is not None: @@ -120,7 +122,7 @@ def init_weave( # autopatching is only supported for the wandb client, because OpenAI calls are not # logged in local mode currently. When that's fixed, this autopatch call can be # moved to InitializedClient.__init__ - autopatch.autopatch() + autopatch.autopatch(autopatch_settings) username = get_username() try: diff --git a/weave/trace_server/clickhouse_trace_server_migrator.py b/weave/trace_server/clickhouse_trace_server_migrator.py index 30dffe89365..4336630bf50 100644 --- a/weave/trace_server/clickhouse_trace_server_migrator.py +++ b/weave/trace_server/clickhouse_trace_server_migrator.py @@ -1,6 +1,7 @@ # Clickhouse Trace Server Manager import logging import os +import re from typing import Optional from clickhouse_connect.driver.client import Client as CHClient @@ -9,6 +10,11 @@ logger = logging.getLogger(__name__) +# These settings are only used when `replicated` mode is enabled for +# self managed clickhouse instances. +DEFAULT_REPLICATED_PATH = "/clickhouse/tables/{db}" +DEFAULT_REPLICATED_CLUSTER = "weave_cluster" + class MigrationError(RuntimeError): """Raised when a migration error occurs.""" @@ -16,15 +22,77 @@ class MigrationError(RuntimeError): class ClickHouseTraceServerMigrator: ch_client: CHClient + replicated: bool + replicated_path: str + replicated_cluster: str def __init__( self, ch_client: CHClient, + replicated: Optional[bool] = None, + replicated_path: Optional[str] = None, + replicated_cluster: Optional[str] = None, ): super().__init__() self.ch_client = ch_client + self.replicated = False if replicated is None else replicated + self.replicated_path = ( + DEFAULT_REPLICATED_PATH if replicated_path is None else replicated_path + ) + self.replicated_cluster = ( + DEFAULT_REPLICATED_CLUSTER + if replicated_cluster is None + else replicated_cluster + ) self._initialize_migration_db() + def _is_safe_identifier(self, value: str) -> bool: + """Check if a string is safe to use as an identifier in SQL.""" + return bool(re.match(r"^[a-zA-Z0-9_\.]+$", value)) + + def _format_replicated_sql(self, sql_query: str) -> str: + """Format SQL query to use replicated engines if replicated mode is enabled.""" + if not self.replicated: + return sql_query + + # Match "ENGINE = MergeTree" followed by word boundary + pattern = r"ENGINE\s*=\s*(\w+)?MergeTree\b" + + def replace_engine(match: re.Match[str]) -> str: + engine_prefix = match.group(1) or "" + return f"ENGINE = Replicated{engine_prefix}MergeTree" + + return re.sub(pattern, replace_engine, sql_query, flags=re.IGNORECASE) + + def _create_db_sql(self, db_name: str) -> str: + """Geneate SQL database create string for normal and replicated databases.""" + if not self._is_safe_identifier(db_name): + raise MigrationError(f"Invalid database name: {db_name}") + + replicated_engine = "" + replicated_cluster = "" + if self.replicated: + if not self._is_safe_identifier(self.replicated_cluster): + raise MigrationError(f"Invalid cluster name: {self.replicated_cluster}") + + replicated_path = self.replicated_path.replace("{db}", db_name) + if not all( + self._is_safe_identifier(part) + for part in replicated_path.split("/") + if part + ): + raise MigrationError(f"Invalid replicated path: {replicated_path}") + + replicated_cluster = f" ON CLUSTER {self.replicated_cluster}" + replicated_engine = ( + f" ENGINE=Replicated('{replicated_path}', '{{shard}}', '{{replica}}')" + ) + + create_db_sql = f""" + CREATE DATABASE IF NOT EXISTS {db_name}{replicated_cluster}{replicated_engine} + """ + return create_db_sql + def apply_migrations( self, target_db: str, target_version: Optional[int] = None ) -> None: @@ -46,20 +114,15 @@ def apply_migrations( return logger.info(f"Migrations to apply: {migrations_to_apply}") if status["curr_version"] == 0: - self.ch_client.command(f"CREATE DATABASE IF NOT EXISTS {target_db}") + self.ch_client.command(self._create_db_sql(target_db)) for target_version, migration_file in migrations_to_apply: self._apply_migration(target_db, target_version, migration_file) if should_insert_costs(status["curr_version"], target_version): insert_costs(self.ch_client, target_db) def _initialize_migration_db(self) -> None: - self.ch_client.command( - """ - CREATE DATABASE IF NOT EXISTS db_management - """ - ) - self.ch_client.command( - """ + self.ch_client.command(self._create_db_sql("db_management")) + create_table_sql = """ CREATE TABLE IF NOT EXISTS db_management.migrations ( db_name String, @@ -69,7 +132,7 @@ def _initialize_migration_db(self) -> None: ENGINE = MergeTree() ORDER BY (db_name) """ - ) + self.ch_client.command(self._format_replicated_sql(create_table_sql)) def _get_migration_status(self, db_name: str) -> dict: column_names = ["db_name", "curr_version", "partially_applied_version"] @@ -184,31 +247,48 @@ def _determine_migrations_to_apply( return [] + def _execute_migration_command(self, target_db: str, command: str) -> None: + """Execute a single migration command in the context of the target database.""" + command = command.strip() + if len(command) == 0: + return + curr_db = self.ch_client.database + self.ch_client.database = target_db + self.ch_client.command(self._format_replicated_sql(command)) + self.ch_client.database = curr_db + + def _update_migration_status( + self, target_db: str, target_version: int, is_start: bool = True + ) -> None: + """Update the migration status in db_management.migrations table.""" + if is_start: + self.ch_client.command( + f"ALTER TABLE db_management.migrations UPDATE partially_applied_version = {target_version} WHERE db_name = '{target_db}'" + ) + else: + self.ch_client.command( + f"ALTER TABLE db_management.migrations UPDATE curr_version = {target_version}, partially_applied_version = NULL WHERE db_name = '{target_db}'" + ) + def _apply_migration( self, target_db: str, target_version: int, migration_file: str ) -> None: logger.info(f"Applying migration {migration_file} to `{target_db}`") migration_dir = os.path.join(os.path.dirname(__file__), "migrations") migration_file_path = os.path.join(migration_dir, migration_file) + with open(migration_file_path) as f: migration_sql = f.read() - self.ch_client.command( - f""" - ALTER TABLE db_management.migrations UPDATE partially_applied_version = {target_version} WHERE db_name = '{target_db}' - """ - ) + + # Mark migration as partially applied + self._update_migration_status(target_db, target_version, is_start=True) + + # Execute each command in the migration migration_sub_commands = migration_sql.split(";") for command in migration_sub_commands: - command = command.strip() - if len(command) == 0: - continue - curr_db = self.ch_client.database - self.ch_client.database = target_db - self.ch_client.command(command) - self.ch_client.database = curr_db - self.ch_client.command( - f""" - ALTER TABLE db_management.migrations UPDATE curr_version = {target_version}, partially_applied_version = NULL WHERE db_name = '{target_db}' - """ - ) + self._execute_migration_command(target_db, command) + + # Mark migration as fully applied + self._update_migration_status(target_db, target_version, is_start=False) + logger.info(f"Migration {migration_file} applied to `{target_db}`") diff --git a/weave/version.py b/weave/version.py index 5212f6aee7d..70f670abc21 100644 --- a/weave/version.py +++ b/weave/version.py @@ -44,4 +44,4 @@ """ -VERSION = "0.51.24-dev0" +VERSION = "0.51.25-dev0"