diff --git a/docs/docs/guides/evaluation/scorers.md b/docs/docs/guides/evaluation/scorers.md index 5babbfb68e2..581d068402d 100644 --- a/docs/docs/guides/evaluation/scorers.md +++ b/docs/docs/guides/evaluation/scorers.md @@ -260,4 +260,3 @@ The following section describes how to output the standard final summarization f - diff --git a/docs/docs/guides/integrations/local_models.md b/docs/docs/guides/integrations/local_models.md index 090d22a3f76..2597ad19dd1 100644 --- a/docs/docs/guides/integrations/local_models.md +++ b/docs/docs/guides/integrations/local_models.md @@ -14,7 +14,6 @@ First and most important, is the `base_url` change during the `openai.OpenAI()` ```python client = openai.OpenAI( - api_key='fake', base_url="http://localhost:1234", ) ``` diff --git a/docs/docs/guides/integrations/notdiamond.md b/docs/docs/guides/integrations/notdiamond.md index 3106ef11f1b..e98a23d1334 100644 --- a/docs/docs/guides/integrations/notdiamond.md +++ b/docs/docs/guides/integrations/notdiamond.md @@ -68,7 +68,6 @@ preference_id = train_router( response_column="actual", language="en", maximize=True, - api_key=api_key, ) ``` diff --git a/docs/docs/guides/tracking/ops.md b/docs/docs/guides/tracking/ops.md index b69d5d1d91a..4c1e064b0aa 100644 --- a/docs/docs/guides/tracking/ops.md +++ b/docs/docs/guides/tracking/ops.md @@ -116,6 +116,39 @@ A Weave op is a versioned function that automatically logs all calls. +## Control sampling rate + + + + You can control how frequently an op's calls are traced by setting the `tracing_sample_rate` parameter in the `@weave.op` decorator. This is useful for high-frequency ops where you only need to trace a subset of calls. + + Note that sampling rates are only applied to root calls. If an op has a sample rate, but is called by another op first, then that sampling rate will be ignored. + + ```python + @weave.op(tracing_sample_rate=0.1) # Only trace ~10% of calls + def high_frequency_op(x: int) -> int: + return x + 1 + + @weave.op(tracing_sample_rate=1.0) # Always trace (default) + def always_traced_op(x: int) -> int: + return x + 1 + ``` + + When an op's call is not sampled: + - The function executes normally + - No trace data is sent to Weave + - Child ops are also not traced for that call + + The sampling rate must be between 0.0 and 1.0 inclusive. + + + + ```plaintext + This feature is not available in TypeScript yet. Stay tuned! + ``` + + + ### Control call link output If you want to suppress the printing of call links during logging, you can use the `WEAVE_PRINT_CALL_LINK` environment variable to `false`. This can be useful if you want to reduce output verbosity and reduce clutter in your logs. diff --git a/docs/docs/quickstart.md b/docs/docs/quickstart.md index 59ac16ef236..65bc813af03 100644 --- a/docs/docs/quickstart.md +++ b/docs/docs/quickstart.md @@ -50,7 +50,7 @@ _In this example, we're using openai so you will need to add an OpenAI [API key] import weave from openai import OpenAI - client = OpenAI(api_key="...") + client = OpenAI() # Weave will track the inputs, outputs and code of this function # highlight-next-line diff --git a/docs/docs/tutorial-tracing_2.md b/docs/docs/tutorial-tracing_2.md index 719ee99e012..d6e392d9c56 100644 --- a/docs/docs/tutorial-tracing_2.md +++ b/docs/docs/tutorial-tracing_2.md @@ -24,7 +24,7 @@ Building on our [basic tracing example](/quickstart), we will now add additional import json from openai import OpenAI - client = OpenAI(api_key="...") + client = OpenAI() # highlight-next-line @weave.op() diff --git a/pyproject.toml b/pyproject.toml index d392d527b60..f34757b315e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -226,7 +226,7 @@ module = "weave_query.*" ignore_errors = true [tool.bumpversion] -current_version = "0.51.24-dev0" +current_version = "0.51.25-dev0" parse = """(?x) (?P0|[1-9]\\d*)\\. (?P0|[1-9]\\d*)\\. diff --git a/tests/conftest.py b/tests/conftest.py index b28187a3833..85e9b53c36b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -477,7 +477,9 @@ def __getattribute__(self, name): return ServerRecorder(server) -def create_client(request) -> weave_init.InitializedClient: +def create_client( + request, autopatch_settings: typing.Optional[autopatch.AutopatchSettings] = None +) -> weave_init.InitializedClient: inited_client = None weave_server_flag = request.config.getoption("--weave-server") server: tsi.TraceServerInterface @@ -513,7 +515,7 @@ def create_client(request) -> weave_init.InitializedClient: entity, project, make_server_recorder(server) ) inited_client = weave_init.InitializedClient(client) - autopatch.autopatch() + autopatch.autopatch(autopatch_settings) return inited_client @@ -527,6 +529,7 @@ def client(request): yield inited_client.client finally: inited_client.reset() + autopatch.reset_autopatch() @pytest.fixture() @@ -534,12 +537,13 @@ def client_creator(request): """This fixture is useful for delaying the creation of the client (ex. when you want to set settings first)""" @contextlib.contextmanager - def client(): - inited_client = create_client(request) + def client(autopatch_settings: typing.Optional[autopatch.AutopatchSettings] = None): + inited_client = create_client(request, autopatch_settings) try: yield inited_client.client finally: inited_client.reset() + autopatch.reset_autopatch() yield client diff --git a/tests/integrations/openai/cassettes/test_autopatch/test_configuration_with_dicts.yaml b/tests/integrations/openai/cassettes/test_autopatch/test_configuration_with_dicts.yaml new file mode 100644 index 00000000000..7245829a0b3 --- /dev/null +++ b/tests/integrations/openai/cassettes/test_autopatch/test_configuration_with_dicts.yaml @@ -0,0 +1,102 @@ +interactions: +- request: + body: '{"messages":[{"role":"user","content":"tell me a joke"}],"model":"gpt-4o"}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate, zstd + connection: + - keep-alive + content-length: + - '74' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.57.2 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.57.2 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.0rc2 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAAwAAAP//jFJNa9wwEL37V0x1yWVd7P3KspcSSKE5thvooSlGK40tJbJGSOOSNOx/ + L/Z+2KEp9KLDe/Me743mNQMQVostCGUkqza4/Eav1f1O/v4aVvvbL/Pdt7tVHUp1s/vsnhsx6xW0 + f0TFZ9VHRW1wyJb8kVYRJWPvWl4vFpvNYl2UA9GSRtfLmsD5kvJ5MV/mxSYv1iehIaswiS38yAAA + Xoe3j+g1PostFLMz0mJKskGxvQwBiEiuR4RMySaWnsVsJBV5Rj+k/m5eQJO/YkhP6JDJJ6htYxhQ + KgPEBuOnB//g7w2eJ438hcAGoek4fZgaR6y7JPtevnPuhB8uSR01IdI+nfgLXltvk6kiykS+T5WY + ghjYQwbwc9hI96akCJHawBXTE/resCyPdmL8ggm5PJFMLN2Iz1ezd9wqjSytS5ONCiWVQT0qx/XL + TluaENmk899h3vM+9ra++R/7kVAKA6OuQkRt1dvC41jE/kD/NXbZ8RBYpJfE2Fa19Q3GEO3xRupQ + qWslC9xLJUV2yP4AAAD//wMA4O+DUSwDAAA= + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8f01fe3aabd037cf-YYZ + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Wed, 11 Dec 2024 02:20:01 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=xqe_jHZdTV5LijJQYQ3GMY5MjtVrCyxbFO4glgLvgD0-1733883601-1.0.1.1-p.DDUca_cHppJu2hXzzA0CXU1mtalxHUNfBWVgPIQj.UkU603pbNscCvSIi4_Zjlz9Zuc3.hjlvoyZxcDBJTsw; + path=/; expires=Wed, 11-Dec-24 02:50:01 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=WEjxXqkGswaEDhllTROGX_go9tgaWNJcUJ3cCd50xDI-1733883601764-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - wandb + openai-processing-ms: + - '607' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '10000' + x-ratelimit-limit-tokens: + - '30000000' + x-ratelimit-remaining-requests: + - '9999' + x-ratelimit-remaining-tokens: + - '29999979' + x-ratelimit-reset-requests: + - 6ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_8592a74b531c806f65c63c7471101cb6 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/integrations/openai/cassettes/test_autopatch/test_disabled_integration_doesnt_patch.yaml b/tests/integrations/openai/cassettes/test_autopatch/test_disabled_integration_doesnt_patch.yaml new file mode 100644 index 00000000000..1895cdcd5f2 --- /dev/null +++ b/tests/integrations/openai/cassettes/test_autopatch/test_disabled_integration_doesnt_patch.yaml @@ -0,0 +1,102 @@ +interactions: +- request: + body: '{"messages":[{"role":"user","content":"tell me a joke"}],"model":"gpt-4o"}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate, zstd + connection: + - keep-alive + content-length: + - '74' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.57.2 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.57.2 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.0rc2 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAAwAAAP//jFLJbtswEL3rK6a85GIV8lYvl6KX9NQFrYEckkKgyZHImOII5KiJEfjf + C8qLHDQFeuHhbXgzw5cMQFgt1iCUkaya1uWf9Hyuv5H29ebu89Pt80Z9//H0RS2nX/e3LEbJQdtH + VHx2vVfUtA7Zkj/SKqBkTKnjxXS6XCwWk3FPNKTRJVvdcj6jfFJMZnmxzIsPJ6MhqzCKNdxnAAAv + /Zsqeo3PYg3F6Iw0GKOsUawvIgARyCVEyBhtZOmPdU+kIs/o+9Y/u4AjMBjwJoIEZ2vDuUEZGDU8 + 0g6hogB76tYP/sHfmT1o8jcMcYcOmXyEKlkApTJAbDB8TMKNwbPSyN8IbBDqjuO76xoBqy7KtAXf + OXfCD5e5HNVtoG088Re8st5GUwaUkXyaITK1omcPGcCvfn/dq5WINlDTcsm0Q58Cx+NjnBgONpCT + 2YlkYukGfDofvZFWamRpXbzav1BSGdSDcziW7LSlKyK7mvnvMm9lH+e2vv6f+IFQCltGXbYBtVWv + Bx5kAdN3/pfssuO+sIj7yNiUlfU1hjbY44+q2nKl54XSq1WxFdkh+wMAAP//AwAWTTnuWgMAAA== + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8f016eadbff439d2-YYZ + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Wed, 11 Dec 2024 00:42:01 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=8FO1yMjc3pMQWRpWrkIe5mcs39GLeqQPmgHQq0YTT8s-1733877721-1.0.1.1-i4G06DBN08aH1F1H73U_TB9OLK3jLsV1jXydB1cQ4Hqx7I.r8xDn.7hFRZe2hy3D_nABTG1nDcdDoXL_wYiqug; + path=/; expires=Wed, 11-Dec-24 01:12:01 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=jxwySgtriPkUP8L2os1nb_gRq_SSUo3yWFUyJmHPmGY-1733877721989-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - wandb + openai-processing-ms: + - '652' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '10000' + x-ratelimit-limit-tokens: + - '30000000' + x-ratelimit-remaining-requests: + - '9999' + x-ratelimit-remaining-tokens: + - '29999979' + x-ratelimit-reset-requests: + - 6ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_1c86d4fda2ad715edfd41bcd2f4bdd89 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/integrations/openai/cassettes/test_autopatch/test_enabled_integration_patches.yaml b/tests/integrations/openai/cassettes/test_autopatch/test_enabled_integration_patches.yaml new file mode 100644 index 00000000000..f0cdca54158 --- /dev/null +++ b/tests/integrations/openai/cassettes/test_autopatch/test_enabled_integration_patches.yaml @@ -0,0 +1,102 @@ +interactions: +- request: + body: '{"messages":[{"role":"user","content":"tell me a joke"}],"model":"gpt-4o"}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate, zstd + connection: + - keep-alive + content-length: + - '74' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.57.2 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.57.2 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.0rc2 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAAwAAAP//jFLBjtMwEL3nKwZfuDQoTXc3VS8IBOIACCQOHHZR5NrTxDTxWJ4J2rDq + v6Ok2SYrFomLD+/Ne3pvxg8JgHJW7UCZWotpQ5O+sdfXaL9+/PDp+L6gG3ffyudv735v+y82eLUa + FLT/iUYeVa8MtaFBcTTRJqIWHFzXxWazLYoiz0eiJYvNIKuCpFeU5ll+lWbbNLuZhDU5g6x2cJsA + ADyM7xDRW7xXO8hWj0iLzLpCtbsMAahIzYAozexYtBe1mklDXtCPqb/XPVjyLwXYOPTiWBgkdiyg + hVp+fefv/Fs0umMEqbGHVh8RugD4C2MvtfPVi6V3xEPHeqjmu6aZ8NMlbENViLTnib/gB+cd12VE + zeSHYCwU1MieEoAf41K6Jz1ViNQGKYWO6AfD9fpsp+YrLMh8IoVENzOeb1bPuJUWRbuGF0tVRpsa + 7aycL6A762hBJIvOf4d5zvvc2/nqf+xnwhgMgrYMEa0zTwvPYxGHP/qvscuOx8CKexZsy4PzFcYQ + 3fmbHEJpCqMz3GujVXJK/gAAAP//AwAyhdwOLwMAAA== + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8f016eb36bb3a240-YYZ + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Wed, 11 Dec 2024 00:42:02 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=Q_ATX8JU4jFqXJPdwlneOua9wmNmAaASyAfcbPyPqng-1733877722-1.0.1.1-eTMEvBW7oqQa2i3l.Or2I3LF_cCESxfseq.S9DBr8dAJWsVoFfPxKtr5vMaO6yj4hRW8XOSOHcgIcwwqbHrLbg; + path=/; expires=Wed, 11-Dec-24 01:12:02 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=2ak.tRpn6uEHbM8GrWy_ALtrN34jVSNIJI1mFG2etvM-1733877722703-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - wandb + openai-processing-ms: + - '476' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '10000' + x-ratelimit-limit-tokens: + - '30000000' + x-ratelimit-remaining-requests: + - '9999' + x-ratelimit-remaining-tokens: + - '29999979' + x-ratelimit-reset-requests: + - 6ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_52e061e1cc55cdd8847a7ba9342f1a14 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/integrations/openai/cassettes/test_autopatch/test_passthrough_op_kwargs.yaml b/tests/integrations/openai/cassettes/test_autopatch/test_passthrough_op_kwargs.yaml new file mode 100644 index 00000000000..646c57c6123 --- /dev/null +++ b/tests/integrations/openai/cassettes/test_autopatch/test_passthrough_op_kwargs.yaml @@ -0,0 +1,102 @@ +interactions: +- request: + body: '{"messages":[{"role":"user","content":"tell me a joke"}],"model":"gpt-4o"}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate, zstd + connection: + - keep-alive + content-length: + - '74' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.57.2 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.57.2 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.13.0rc2 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAAwAAAP//jFLLbtswELzrK7a89GIVsmLXsS9Fr0UvBQIERVMINLkS2VBcglwVcQL/ + e0H5IQVNgV54mNkZzCz3pQAQVosdCGUkqz648rNer7G2z3fL8ITq+X6lvn7x/SHhN/d9LxZZQftf + qPii+qCoDw7Zkj/RKqJkzK7Lzc3N7WazqeuR6Emjy7IucLmisq7qVVndltXHs9CQVZjEDn4UAAAv + 45sjeo1PYgfV4oL0mJLsUOyuQwAiksuIkCnZxNKzWEykIs/ox9T35gCa/HuG9IgOmXyC1naGAaUy + QGwwfnrwD/7O4GXSyN8IbBC6gdO7uXHEdkgy9/KDc2f8eE3qqAuR9unMX/HWeptME1Em8jlVYgpi + ZI8FwM9xI8OrkiJE6gM3TI/os+FyebIT0xfMyNWZZGLpJrxeL95wazSytC7NNiqUVAb1pJzWLwdt + aUYUs85/h3nL+9Tb+u5/7CdCKQyMugkRtVWvC09jEfOB/mvsuuMxsEiHxNg3rfUdxhDt6Uba0Gz1 + ulJ6u632ojgWfwAAAP//AwCOwDMjLAMAAA== + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8f016eb76b71ac9a-YYZ + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Wed, 11 Dec 2024 00:42:03 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=r.xSSsYQNFPvMiizFSvjQiecNA6Q1wQa0VR1YElfXi4-1733877723-1.0.1.1-GVW0i7wrpHCQSY5eXu7sIQgxYWl6jfeSordQ7JFxV3lO6UfFhwxRT92bBP4DfnrSYpBpRw3k4aONAURyvKctiQ; + path=/; expires=Wed, 11-Dec-24 01:12:03 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=CQJVOdASzL9ency5_q6SDaInTsvpjA240cIxf.AUwXM-1733877723385-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + openai-organization: + - wandb + openai-processing-ms: + - '523' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '10000' + x-ratelimit-limit-tokens: + - '30000000' + x-ratelimit-remaining-requests: + - '9999' + x-ratelimit-remaining-tokens: + - '29999979' + x-ratelimit-reset-requests: + - 6ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_c9c57cfa6f37a99aaf0abac013237ed6 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/integrations/openai/test_autopatch.py b/tests/integrations/openai/test_autopatch.py new file mode 100644 index 00000000000..2c2f5201d3f --- /dev/null +++ b/tests/integrations/openai/test_autopatch.py @@ -0,0 +1,116 @@ +# This is included here for convenience. Instead of creating a dummy API, we can test +# autopatching against the actual OpenAI API. + +from typing import Any + +import pytest +from openai import OpenAI + +from weave.integrations.openai import openai_sdk +from weave.trace.autopatch import AutopatchSettings, IntegrationSettings, OpSettings + + +@pytest.mark.skip_clickhouse_client # TODO:VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode +@pytest.mark.vcr( + filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"] +) +def test_disabled_integration_doesnt_patch(client_creator): + autopatch_settings = AutopatchSettings( + openai=IntegrationSettings(enabled=False), + ) + + with client_creator(autopatch_settings=autopatch_settings) as client: + oaiclient = OpenAI() + oaiclient.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "tell me a joke"}], + ) + + calls = list(client.get_calls()) + assert len(calls) == 0 + + +@pytest.mark.skip_clickhouse_client # TODO:VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode +@pytest.mark.vcr( + filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"] +) +def test_enabled_integration_patches(client_creator): + autopatch_settings = AutopatchSettings( + openai=IntegrationSettings(enabled=True), + ) + + with client_creator(autopatch_settings=autopatch_settings) as client: + oaiclient = OpenAI() + oaiclient.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "tell me a joke"}], + ) + + calls = list(client.get_calls()) + assert len(calls) == 1 + + +@pytest.mark.skip_clickhouse_client # TODO:VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode +@pytest.mark.vcr( + filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"] +) +def test_passthrough_op_kwargs(client_creator): + def redact_inputs(inputs: dict[str, Any]) -> dict[str, Any]: + return dict.fromkeys(inputs, "REDACTED") + + autopatch_settings = AutopatchSettings( + openai=IntegrationSettings( + op_settings=OpSettings( + postprocess_inputs=redact_inputs, + ) + ) + ) + + # Explicitly reset the patcher here to pretend like we're starting fresh. We need + # to do this because `_openai_patcher` is a global variable that is shared across + # tests. If we don't reset it, it will retain the state from the previous test, + # which can cause this test to fail. + openai_sdk._openai_patcher = None + + with client_creator(autopatch_settings=autopatch_settings) as client: + oaiclient = OpenAI() + oaiclient.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "tell me a joke"}], + ) + + calls = list(client.get_calls()) + assert len(calls) == 1 + + call = calls[0] + assert all(v == "REDACTED" for v in call.inputs.values()) + + +@pytest.mark.skip_clickhouse_client # TODO:VCR recording does not seem to allow us to make requests to the clickhouse db in non-recording mode +@pytest.mark.vcr( + filter_headers=["authorization"], allowed_hosts=["api.wandb.ai", "localhost"] +) +def test_configuration_with_dicts(client_creator): + def redact_inputs(inputs: dict[str, Any]) -> dict[str, Any]: + return dict.fromkeys(inputs, "REDACTED") + + autopatch_settings = { + "openai": { + "op_settings": {"postprocess_inputs": redact_inputs}, + } + } + + openai_sdk._openai_patcher = None + + with client_creator(autopatch_settings=autopatch_settings) as client: + oaiclient = OpenAI() + oaiclient.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "tell me a joke"}], + ) + + calls = list(client.get_calls()) + assert len(calls) == 1 + + call = calls[0] + assert all(v == "REDACTED" for v in call.inputs.values()) diff --git a/tests/trace/test_client_trace.py b/tests/trace/test_client_trace.py index 5b45abc8432..005c79f5cb0 100644 --- a/tests/trace/test_client_trace.py +++ b/tests/trace/test_client_trace.py @@ -60,7 +60,7 @@ def get_client_project_id(client: weave_client.WeaveClient) -> str: def test_simple_op(client): - @weave.op() + @weave.op def my_op(a: int) -> int: return a + 1 @@ -229,7 +229,7 @@ def test_call_read_not_found(client): def test_graph_call_ordering(client): - @weave.op() + @weave.op def my_op(a: int) -> int: return a + 1 @@ -263,27 +263,27 @@ def simple_line_call_bootstrap(init_wandb: bool = False) -> OpCallSpec: class Number(weave.Object): value: int - @weave.op() + @weave.op def adder(a: Number) -> Number: return Number(value=a.value + a.value) adder_v0 = adder - @weave.op() + @weave.op # type: ignore def adder(a: Number, b) -> Number: return Number(value=a.value + b) - @weave.op() + @weave.op def subtractor(a: Number, b) -> Number: return Number(value=a.value - b) - @weave.op() + @weave.op def multiplier( a: Number, b ) -> int: # intentionally deviant in returning plain int - so that we have a different type return a.value * b - @weave.op() + @weave.op def liner(m: Number, b, x) -> Number: return adder(Number(value=multiplier(m, x)), b) @@ -691,7 +691,7 @@ def test_trace_call_query_offset(client): def test_trace_call_sort(client): - @weave.op() + @weave.op def basic_op(in_val: dict, delay) -> dict: import time @@ -727,7 +727,7 @@ def test_trace_call_sort_with_mixed_types(client): # SQLite does not support sorting over mixed types in a column, so we skip this test return - @weave.op() + @weave.op def basic_op(in_val: dict) -> dict: import time @@ -769,7 +769,7 @@ def basic_op(in_val: dict) -> dict: def test_trace_call_filter(client): is_sqlite = client_is_sqlite(client) - @weave.op() + @weave.op def basic_op(in_val: dict, delay) -> dict: return in_val @@ -1160,7 +1160,7 @@ def basic_op(in_val: dict, delay) -> dict: def test_ops_with_default_params(client): - @weave.op() + @weave.op def op_with_default(a: int, b: int = 10) -> int: return a + b @@ -1234,7 +1234,7 @@ class BaseTypeC(BaseTypeB): def test_attributes_on_ops(client): - @weave.op() + @weave.op def op_with_attrs(a: int, b: int) -> int: return a + b @@ -1277,7 +1277,7 @@ def test_dataclass_support(client): class MyDataclass: val: int - @weave.op() + @weave.op def dataclass_maker(a: MyDataclass, b: MyDataclass) -> MyDataclass: return MyDataclass(a.val + b.val) @@ -1322,7 +1322,7 @@ def dataclass_maker(a: MyDataclass, b: MyDataclass) -> MyDataclass: def test_op_retrieval(client): - @weave.op() + @weave.op def my_op(a: int) -> int: return a + 1 @@ -1336,7 +1336,7 @@ def test_bound_op_retrieval(client): class CustomType(weave.Object): a: int - @weave.op() + @weave.op def op_with_custom_type(self, v): return self.a + v @@ -1359,7 +1359,7 @@ def test_bound_op_retrieval_no_self(client): class CustomTypeWithoutSelf(weave.Object): a: int - @weave.op() + @weave.op def op_with_custom_type(me, v): return me.a + v @@ -1387,7 +1387,7 @@ def test_dataset_row_ref(client): def test_tuple_support(client): - @weave.op() + @weave.op def tuple_maker(a, b): return (a, b) @@ -1411,7 +1411,7 @@ def tuple_maker(a, b): def test_namedtuple_support(client): - @weave.op() + @weave.op def tuple_maker(a, b): return (a, b) @@ -1442,7 +1442,7 @@ def test_named_reuse(client): d_ref = weave.publish(d, "test_dataset") dataset = weave.ref(d_ref.uri()).get() - @weave.op() + @weave.op async def dummy_score(output): return 1 @@ -1489,7 +1489,7 @@ class MyUnknownClassB: def __init__(self, b_val) -> None: self.b_val = b_val - @weave.op() + @weave.op def op_with_unknown_types(a: MyUnknownClassA, b: float) -> MyUnknownClassB: return MyUnknownClassB(a.a_val + b) @@ -1564,19 +1564,19 @@ def init_weave_get_server_patched(api_key): def test_single_primitive_output(client): - @weave.op() + @weave.op def single_int_output(a: int) -> int: return a - @weave.op() + @weave.op def single_bool_output(a: int) -> bool: return a == 1 - @weave.op() + @weave.op def single_none_output(a: int) -> None: return None - @weave.op() + @weave.op def dict_output(a: int, b: bool, c: None) -> dict: return {"a": a, "b": b, "c": c} @@ -1669,14 +1669,14 @@ def test_mapped_execution(client, mapper): events = [] - @weave.op() + @weave.op def op_a(a: int) -> int: events.append("A(S):" + str(a)) time.sleep(0.3) events.append("A(E):" + str(a)) return a - @weave.op() + @weave.op def op_b(b: int) -> int: events.append("B(S):" + str(b)) time.sleep(0.2) @@ -1684,7 +1684,7 @@ def op_b(b: int) -> int: events.append("B(E):" + str(b)) return res - @weave.op() + @weave.op def op_c(c: int) -> int: events.append("C(S):" + str(c)) time.sleep(0.1) @@ -1692,7 +1692,7 @@ def op_c(c: int) -> int: events.append("C(E):" + str(c)) return res - @weave.op() + @weave.op def op_mapper(vals): return mapper(op_c, vals) @@ -2127,7 +2127,7 @@ def calculate(a: int, b: int) -> int: def test_call_query_stream_columns(client): @weave.op - def calculate(a: int, b: int) -> int: + def calculate(a: int, b: int) -> dict[str, Any]: return {"result": {"a + b": a + b}, "not result": 123} for i in range(2): @@ -2170,7 +2170,7 @@ def test_call_query_stream_columns_with_costs(client): return @weave.op - def calculate(a: int, b: int) -> int: + def calculate(a: int, b: int) -> dict[str, Any]: return { "result": {"a + b": a + b}, "not result": 123, @@ -2238,7 +2238,7 @@ def calculate(a: int, b: int) -> int: @pytest.mark.skip("Not implemented: filter / sort through refs") def test_sort_and_filter_through_refs(client): - @weave.op() + @weave.op def test_op(label, val): return val @@ -2272,7 +2272,8 @@ def test_obj(val): # Ref at A, B and C test_op( - values[7], {"a": test_obj({"b": test_obj({"c": test_obj({"d": values[7]})})})} + values[7], + {"a": test_obj({"b": test_obj({"c": test_obj({"d": values[7]})})})}, ) for first, last, sort_by in [ @@ -2355,7 +2356,7 @@ def test_obj(val): def test_in_operation(client): - @weave.op() + @weave.op def test_op(label, val): return val @@ -2500,7 +2501,7 @@ def func(x): class BasicModel(weave.Model): - @weave.op() + @weave.op def predict(self, x): return {"answer": "42"} @@ -2546,7 +2547,7 @@ class SimpleObject(weave.Object): class NestedObject(weave.Object): b: SimpleObject - @weave.op() + @weave.op def return_nested_object(nested_obj: NestedObject): return nested_obj @@ -2997,3 +2998,224 @@ def foo(): foo() assert len(list(weave_client.get_calls())) == 1 assert weave.trace.weave_init._current_inited_client is None + + +def test_op_sampling(client): + never_traced_calls = 0 + always_traced_calls = 0 + sometimes_traced_calls = 0 + + @weave.op(tracing_sample_rate=0.0) + def never_traced(x: int) -> int: + nonlocal never_traced_calls + never_traced_calls += 1 + return x + 1 + + @weave.op(tracing_sample_rate=1.0) + def always_traced(x: int) -> int: + nonlocal always_traced_calls + always_traced_calls += 1 + return x + 1 + + @weave.op(tracing_sample_rate=0.5) + def sometimes_traced(x: int) -> int: + nonlocal sometimes_traced_calls + sometimes_traced_calls += 1 + return x + 1 + + weave.publish(never_traced) + # Never traced should execute but not be traced + for i in range(10): + never_traced(i) + assert never_traced_calls == 10 # Function was called + assert len(list(never_traced.calls())) == 0 # Not traced + + # Always traced should execute and be traced + for i in range(10): + always_traced(i) + assert always_traced_calls == 10 # Function was called + assert len(list(always_traced.calls())) == 10 # And traced + # Sanity check that the call_start was logged, unlike in the never_traced case. + assert "call_start" in client.server.attribute_access_log + + # Sometimes traced should execute always but only be traced sometimes + num_runs = 100 + for i in range(num_runs): + sometimes_traced(i) + assert sometimes_traced_calls == num_runs # Function was called every time + num_traces = len(list(sometimes_traced.calls())) + assert 35 < num_traces < 65 # But only traced ~50% of the time + + +def test_op_sampling_async(client): + never_traced_calls = 0 + always_traced_calls = 0 + sometimes_traced_calls = 0 + + @weave.op(tracing_sample_rate=0.0) + async def never_traced(x: int) -> int: + nonlocal never_traced_calls + never_traced_calls += 1 + return x + 1 + + @weave.op(tracing_sample_rate=1.0) + async def always_traced(x: int) -> int: + nonlocal always_traced_calls + always_traced_calls += 1 + return x + 1 + + @weave.op(tracing_sample_rate=0.5) + async def sometimes_traced(x: int) -> int: + nonlocal sometimes_traced_calls + sometimes_traced_calls += 1 + return x + 1 + + import asyncio + + weave.publish(never_traced) + # Never traced should execute but not be traced + for i in range(10): + asyncio.run(never_traced(i)) + assert never_traced_calls == 10 # Function was called + assert len(list(never_traced.calls())) == 0 # Not traced + + # Always traced should execute and be traced + for i in range(10): + asyncio.run(always_traced(i)) + assert always_traced_calls == 10 # Function was called + assert len(list(always_traced.calls())) == 10 # And traced + assert "call_start" in client.server.attribute_access_log + + # Sometimes traced should execute always but only be traced sometimes + num_runs = 100 + for i in range(num_runs): + asyncio.run(sometimes_traced(i)) + assert sometimes_traced_calls == num_runs # Function was called every time + num_traces = len(list(sometimes_traced.calls())) + assert 35 < num_traces < 65 # But only traced ~50% of the time + + +def test_op_sampling_inheritance(client): + parent_calls = 0 + child_calls = 0 + + @weave.op + def child_op(x: int) -> int: + nonlocal child_calls + child_calls += 1 + return x + 1 + + @weave.op(tracing_sample_rate=0.0) + def parent_op(x: int) -> int: + nonlocal parent_calls + parent_calls += 1 + return child_op(x) + + weave.publish(parent_op) + # When parent is sampled out, child should still execute but not be traced + for i in range(10): + parent_op(i) + + assert parent_calls == 10 # Parent function executed + assert child_calls == 10 # Child function executed + assert len(list(parent_op.calls())) == 0 # Parent not traced + + # Reset counters + child_calls = 0 + + # Direct calls to child should execute and be traced + for i in range(10): + child_op(i) + + assert child_calls == 10 # Child function executed + assert len(list(child_op.calls())) == 10 # And was traced + assert "call_start" in client.server.attribute_access_log # Verify tracing occurred + + +def test_op_sampling_inheritance_async(client): + parent_calls = 0 + child_calls = 0 + + @weave.op + async def child_op(x: int) -> int: + nonlocal child_calls + child_calls += 1 + return x + 1 + + @weave.op(tracing_sample_rate=0.0) + async def parent_op(x: int) -> int: + nonlocal parent_calls + parent_calls += 1 + return await child_op(x) + + import asyncio + + weave.publish(parent_op) + # When parent is sampled out, child should still execute but not be traced + for i in range(10): + asyncio.run(parent_op(i)) + + assert parent_calls == 10 # Parent function executed + assert child_calls == 10 # Child function executed + assert len(list(parent_op.calls())) == 0 # Parent not traced + + # Reset counters + child_calls = 0 + + # Direct calls to child should execute and be traced + for i in range(10): + asyncio.run(child_op(i)) + + assert child_calls == 10 # Child function executed + assert len(list(child_op.calls())) == 10 # And was traced + assert "call_start" in client.server.attribute_access_log # Verify tracing occurred + + +def test_op_sampling_invalid_rates(client): + with pytest.raises(ValueError): + + @weave.op(tracing_sample_rate=-0.5) + def negative_rate(): + pass + + with pytest.raises(ValueError): + + @weave.op(tracing_sample_rate=1.5) + def too_high_rate(): + pass + + with pytest.raises(TypeError): + + @weave.op(tracing_sample_rate="invalid") # type: ignore + def invalid_type(): + pass + + +def test_op_sampling_child_follows_parent(client): + parent_calls = 0 + child_calls = 0 + + @weave.op(tracing_sample_rate=0.0) # Never traced + def child_op(x: int) -> int: + nonlocal child_calls + child_calls += 1 + return x + 1 + + @weave.op(tracing_sample_rate=1.0) # Always traced + def parent_op(x: int) -> int: + nonlocal parent_calls + parent_calls += 1 + return child_op(x) + + num_runs = 100 + for i in range(num_runs): + parent_op(i) + + assert parent_calls == num_runs # Parent was always executed + assert child_calls == num_runs # Child was always executed + + parent_traces = len(list(parent_op.calls())) + child_traces = len(list(child_op.calls())) + + assert parent_traces == num_runs # Parent was always traced + assert child_traces == num_runs # Child was traced whenever parent was diff --git a/tests/trace/test_evaluations.py b/tests/trace/test_evaluations.py index f4993e9227d..ab74d4c0c0b 100644 --- a/tests/trace/test_evaluations.py +++ b/tests/trace/test_evaluations.py @@ -7,7 +7,7 @@ from PIL import Image import weave -from tests.trace.util import AnyIntMatcher +from tests.trace.util import AnyIntMatcher, AnyStrMatcher from weave import Evaluation, Model from weave.scorers import Scorer from weave.trace.refs import CallRef @@ -504,8 +504,8 @@ async def test_evaluation_data_topology(client): } }, "weave": { + "display_name": AnyStrMatcher(), "latency_ms": AnyIntMatcher(), - "trace_name": "Evaluation.evaluate", "status": "success", }, } @@ -1021,11 +1021,35 @@ def my_second_scorer(text, output, model_output): ds = [{"text": "hello"}] - with pytest.raises(ValueError, match="Both 'output' and 'model_output'"): + with pytest.raises( + ValueError, match="cannot include both `output` and `model_output`" + ): scorer = MyScorer() - with pytest.raises(ValueError, match="Both 'output' and 'model_output'"): + with pytest.raises( + ValueError, match="cannot include both `output` and `model_output`" + ): evaluation = weave.Evaluation(dataset=ds, scorers=[MyScorer()]) - with pytest.raises(ValueError, match="Both 'output' and 'model_output'"): + with pytest.raises( + ValueError, match="cannot include both `output` and `model_output`" + ): evaluation = weave.Evaluation(dataset=ds, scorers=[my_second_scorer]) + + +@pytest.mark.asyncio +async def test_evaluation_with_custom_name(client): + dataset = weave.Dataset(rows=[{"input": "hi", "output": "hello"}]) + evaluation = weave.Evaluation(dataset=dataset, evaluation_name="wow-custom!") + + @weave.op() + def model(input: str) -> str: + return "hmmm" + + await evaluation.evaluate(model) + + calls = list(client.get_calls(filter=tsi.CallsFilter(trace_roots_only=True))) + assert len(calls) == 1 + + call = calls[0] + assert call.display_name == "wow-custom!" diff --git a/tests/trace/test_trace_server_common.py b/tests/trace/test_trace_server_common.py index 9bc7495481f..d9170f83eee 100644 --- a/tests/trace/test_trace_server_common.py +++ b/tests/trace/test_trace_server_common.py @@ -1,4 +1,5 @@ from weave.trace_server.trace_server_common import ( + DynamicBatchProcessor, LRUCache, get_nested_key, set_nested_key, @@ -54,3 +55,26 @@ def test_lru_cache(): cache["c"] = 10 assert cache["c"] == 10 assert cache["d"] == 4 + + +def test_dynamic_batch_processor(): + # Initialize processor with: + # - initial batch size of 2 + # - max size of 8 + # - growth factor of 2 + processor = DynamicBatchProcessor(initial_size=2, max_size=8, growth_factor=2) + + test_data = range(15) + + batches = list(processor.make_batches(iter(test_data))) + + # Expected batch sizes: 2, 4, 8, 1 + assert batches[0] == [0, 1] + assert batches[1] == [2, 3, 4, 5] + assert batches[2] == [6, 7, 8, 9, 10, 11, 12, 13] + assert batches[3] == [14] + assert len(batches) == 4 + + # Verify all items were processed + flattened = [item for batch in batches for item in batch] + assert flattened == list(range(15)) diff --git a/tests/trace/util.py b/tests/trace/util.py index eb4c6002beb..beb651722e3 100644 --- a/tests/trace/util.py +++ b/tests/trace/util.py @@ -8,6 +8,13 @@ def client_is_sqlite(client): return isinstance(client.server._internal_trace_server, SqliteTraceServer) +class AnyStrMatcher: + """Matches any string.""" + + def __eq__(self, other): + return isinstance(other, str) + + class AnyIntMatcher: """Matches any integer.""" diff --git a/tests/trace_server/test_clickhouse_trace_server_migrator.py b/tests/trace_server/test_clickhouse_trace_server_migrator.py new file mode 100644 index 00000000000..3a6a92f2479 --- /dev/null +++ b/tests/trace_server/test_clickhouse_trace_server_migrator.py @@ -0,0 +1,230 @@ +import types +from unittest.mock import Mock, call, patch + +import pytest + +from weave.trace_server import clickhouse_trace_server_migrator as trace_server_migrator +from weave.trace_server.clickhouse_trace_server_migrator import MigrationError + + +@pytest.fixture +def mock_costs(): + with patch( + "weave.trace_server.costs.insert_costs.should_insert_costs", return_value=False + ) as mock_should_insert: + with patch( + "weave.trace_server.costs.insert_costs.get_current_costs", return_value=[] + ) as mock_get_costs: + yield + + +@pytest.fixture +def migrator(): + ch_client = Mock() + migrator = trace_server_migrator.ClickHouseTraceServerMigrator(ch_client) + migrator._get_migration_status = Mock() + migrator._get_migrations = Mock() + migrator._determine_migrations_to_apply = Mock() + migrator._update_migration_status = Mock() + ch_client.command.reset_mock() + return migrator + + +def test_apply_migrations_with_target_version(mock_costs, migrator, tmp_path): + # Setup + migrator._get_migration_status.return_value = { + "curr_version": 1, + "partially_applied_version": None, + } + migrator._get_migrations.return_value = { + "1": {"up": "1.up.sql", "down": "1.down.sql"}, + "2": {"up": "2.up.sql", "down": "2.down.sql"}, + } + migrator._determine_migrations_to_apply.return_value = [(2, "2.up.sql")] + + # Create a temporary migration file + migration_dir = tmp_path / "migrations" + migration_dir.mkdir() + migration_file = migration_dir / "2.up.sql" + migration_file.write_text( + "CREATE TABLE test1 (id Int32);\nCREATE TABLE test2 (id Int32);" + ) + + # Mock the migration directory path + with patch("os.path.dirname") as mock_dirname: + mock_dirname.return_value = str(tmp_path) + + # Execute + migrator.apply_migrations("test_db", target_version=2) + + # Verify + migrator._get_migration_status.assert_called_once_with("test_db") + migrator._get_migrations.assert_called_once() + migrator._determine_migrations_to_apply.assert_called_once_with( + 1, migrator._get_migrations.return_value, 2 + ) + + # Verify migration execution + assert migrator._update_migration_status.call_count == 2 + migrator._update_migration_status.assert_has_calls( + [call("test_db", 2, is_start=True), call("test_db", 2, is_start=False)] + ) + + # Verify the actual SQL commands were executed + ch_client = migrator.ch_client + assert ch_client.command.call_count == 2 + ch_client.command.assert_has_calls( + [call("CREATE TABLE test1 (id Int32)"), call("CREATE TABLE test2 (id Int32)")] + ) + + +def test_execute_migration_command(mock_costs, migrator): + # Setup + ch_client = migrator.ch_client + ch_client.database = "original_db" + + # Execute + migrator._execute_migration_command("test_db", "CREATE TABLE test (id Int32)") + + # Verify + assert ch_client.database == "original_db" # Should restore original database + ch_client.command.assert_called_once_with("CREATE TABLE test (id Int32)") + + +def test_migration_replicated(mock_costs, migrator): + ch_client = migrator.ch_client + orig = "CREATE TABLE test (id String, project_id String) ENGINE = MergeTree ORDER BY (project_id, id);" + migrator._execute_migration_command("test_db", orig) + ch_client.command.assert_called_once_with(orig) + + +def test_update_migration_status(mock_costs, migrator): + # Don't mock _update_migration_status for this test + migrator._update_migration_status = types.MethodType( + trace_server_migrator.ClickHouseTraceServerMigrator._update_migration_status, + migrator, + ) + + # Test start of migration + migrator._update_migration_status("test_db", 2, is_start=True) + migrator.ch_client.command.assert_called_with( + "ALTER TABLE db_management.migrations UPDATE partially_applied_version = 2 WHERE db_name = 'test_db'" + ) + + # Test end of migration + migrator._update_migration_status("test_db", 2, is_start=False) + migrator.ch_client.command.assert_called_with( + "ALTER TABLE db_management.migrations UPDATE curr_version = 2, partially_applied_version = NULL WHERE db_name = 'test_db'" + ) + + +def test_is_safe_identifier(mock_costs, migrator): + # Valid identifiers + assert migrator._is_safe_identifier("test_db") + assert migrator._is_safe_identifier("my_db123") + assert migrator._is_safe_identifier("db.table") + + # Invalid identifiers + assert not migrator._is_safe_identifier("test-db") + assert not migrator._is_safe_identifier("db;") + assert not migrator._is_safe_identifier("db'name") + assert not migrator._is_safe_identifier("db/*") + + +def test_create_db_sql_validation(mock_costs, migrator): + # Test invalid database name + with pytest.raises(MigrationError, match="Invalid database name"): + migrator._create_db_sql("test;db") + + # Test replicated mode with invalid values + migrator.replicated = True + migrator.replicated_cluster = "test;cluster" + with pytest.raises(MigrationError, match="Invalid cluster name"): + migrator._create_db_sql("test_db") + + migrator.replicated_cluster = "test_cluster" + migrator.replicated_path = "/clickhouse/bad;path/{db}" + with pytest.raises(MigrationError, match="Invalid replicated path"): + migrator._create_db_sql("test_db") + + +def test_create_db_sql_non_replicated(mock_costs, migrator): + # Test non-replicated mode + migrator.replicated = False + sql = migrator._create_db_sql("test_db") + assert sql.strip() == "CREATE DATABASE IF NOT EXISTS test_db" + + +def test_create_db_sql_replicated(mock_costs, migrator): + # Test replicated mode + migrator.replicated = True + migrator.replicated_path = "/clickhouse/tables/{db}" + migrator.replicated_cluster = "test_cluster" + + sql = migrator._create_db_sql("test_db") + expected = """ + CREATE DATABASE IF NOT EXISTS test_db ON CLUSTER test_cluster ENGINE=Replicated('/clickhouse/tables/test_db', '{shard}', '{replica}') + """.strip() + assert sql.strip() == expected + + +def test_format_replicated_sql_non_replicated(mock_costs, migrator): + # Test that SQL is unchanged when replicated=False + migrator.replicated = False + test_cases = [ + "CREATE TABLE test (id Int32) ENGINE = MergeTree", + "CREATE TABLE test (id Int32) ENGINE = SummingMergeTree", + "CREATE TABLE test (id Int32) ENGINE=ReplacingMergeTree", + ] + + for sql in test_cases: + assert migrator._format_replicated_sql(sql) == sql + + +def test_format_replicated_sql_replicated(mock_costs, migrator): + # Test that MergeTree engines are converted to Replicated variants + migrator.replicated = True + + test_cases = [ + ( + "CREATE TABLE test (id Int32) ENGINE = MergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree", + ), + ( + "CREATE TABLE test (id Int32) ENGINE = SummingMergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedSummingMergeTree", + ), + ( + "CREATE TABLE test (id Int32) ENGINE=ReplacingMergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedReplacingMergeTree", + ), + # Test with extra whitespace + ( + "CREATE TABLE test (id Int32) ENGINE = MergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree", + ), + # Test with parameters + ( + "CREATE TABLE test (id Int32) ENGINE = MergeTree()", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree()", + ), + ] + + for input_sql, expected_sql in test_cases: + assert migrator._format_replicated_sql(input_sql) == expected_sql + + +def test_format_replicated_sql_non_mergetree(mock_costs, migrator): + # Test that non-MergeTree engines are left unchanged + migrator.replicated = True + + test_cases = [ + "CREATE TABLE test (id Int32) ENGINE = Memory", + "CREATE TABLE test (id Int32) ENGINE = Log", + "CREATE TABLE test (id Int32) ENGINE = TinyLog", + # This should not be changed as it's not a complete word match + "CREATE TABLE test (id Int32) ENGINE = MyMergeTreeCustom", + ] + + for sql in test_cases: + assert migrator._format_replicated_sql(sql) == sql diff --git a/weave-js/package.json b/weave-js/package.json index cb57125143a..db96be60691 100644 --- a/weave-js/package.json +++ b/weave-js/package.json @@ -141,6 +141,7 @@ "unified": "^10.1.0", "unist-util-visit": "3.1.0", "universal-perf-hooks": "^1.0.1", + "uuid": "^11.0.3", "vega": "^5.24.0", "vega-lite": "5.6.0", "vega-tooltip": "^0.28.0", @@ -192,7 +193,6 @@ "@types/react-virtualized-auto-sizer": "^1.0.0", "@types/safe-json-stringify": "^1.1.2", "@types/styled-components": "^5.1.26", - "@types/uuid": "^9.0.1", "@types/wavesurfer.js": "^2.0.0", "@types/zen-observable": "^0.8.3", "@typescript-eslint/eslint-plugin": "5.35.1", @@ -237,7 +237,6 @@ "tslint-config-prettier": "^1.18.0", "tslint-plugin-prettier": "^2.3.0", "typescript": "4.7.4", - "uuid": "^9.0.0", "vite": "5.2.9", "vitest": "^1.6.0" }, diff --git a/weave-js/src/assets/icons/icon-moon.svg b/weave-js/src/assets/icons/icon-moon.svg new file mode 100644 index 00000000000..e448eab96c3 --- /dev/null +++ b/weave-js/src/assets/icons/icon-moon.svg @@ -0,0 +1,3 @@ + + + diff --git a/weave-js/src/assets/icons/icon-not-visible.svg b/weave-js/src/assets/icons/icon-not-visible.svg index 766810c7811..b2782d987b9 100644 --- a/weave-js/src/assets/icons/icon-not-visible.svg +++ b/weave-js/src/assets/icons/icon-not-visible.svg @@ -2,4 +2,4 @@ - \ No newline at end of file + diff --git a/weave-js/src/assets/icons/icon-pin-to-right.svg b/weave-js/src/assets/icons/icon-pin-to-right.svg index 1ae05ea52ae..46a9c0bf114 100644 --- a/weave-js/src/assets/icons/icon-pin-to-right.svg +++ b/weave-js/src/assets/icons/icon-pin-to-right.svg @@ -2,4 +2,4 @@ - \ No newline at end of file + diff --git a/weave-js/src/assets/icons/icon-spiral.svg b/weave-js/src/assets/icons/icon-spiral.svg new file mode 100644 index 00000000000..ce5c147b43a --- /dev/null +++ b/weave-js/src/assets/icons/icon-spiral.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/weave-js/src/assets/icons/icon-sun.svg b/weave-js/src/assets/icons/icon-sun.svg new file mode 100644 index 00000000000..bb0c57891b0 --- /dev/null +++ b/weave-js/src/assets/icons/icon-sun.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/weave-js/src/assets/icons/icon-visible.svg b/weave-js/src/assets/icons/icon-visible.svg new file mode 100644 index 00000000000..823c36bb17f --- /dev/null +++ b/weave-js/src/assets/icons/icon-visible.svg @@ -0,0 +1,4 @@ + + + + diff --git a/weave-js/src/common/util/SdkPointCloudToBabylon.test.ts b/weave-js/src/common/util/SdkPointCloudToBabylon.test.ts index 95c7639dc4a..df5eca57b46 100644 --- a/weave-js/src/common/util/SdkPointCloudToBabylon.test.ts +++ b/weave-js/src/common/util/SdkPointCloudToBabylon.test.ts @@ -4,6 +4,7 @@ import { DEFAULT_POINT_COLOR, getFilteringOptionsForPointCloud, getVertexCompatiblePositionsAndColors, + loadPointCloud, MAX_BOUNDING_BOX_LABELS_FOR_DISPLAY, MaxAlphaValue, } from './SdkPointCloudToBabylon'; @@ -174,3 +175,16 @@ describe('getFilteringOptionsForPointCloud', () => { expect(newClassIdToLabel.get(49)).toEqual('label49'); }); }); +describe('loadPointCloud', () => { + it('appropriate defaults set when loading point cloud from file', () => { + const fileContents = JSON.stringify({ + boxes: [], + points: [[]], + type: 'lidar/beta', + vectors: [], + }); + const babylonPointCloud = loadPointCloud(fileContents); + expect(babylonPointCloud.points).toHaveLength(1); + expect(babylonPointCloud.points[0].position).toEqual([0, 0, 0]); + }); +}); diff --git a/weave-js/src/common/util/SdkPointCloudToBabylon.ts b/weave-js/src/common/util/SdkPointCloudToBabylon.ts index 274e1676be4..d52682743ee 100644 --- a/weave-js/src/common/util/SdkPointCloudToBabylon.ts +++ b/weave-js/src/common/util/SdkPointCloudToBabylon.ts @@ -160,7 +160,7 @@ export const handlePoints = (object3D: Object3DScene): ScenePoint[] => { // Draw Points return truncatedPoints.map(point => { const [x, y, z, r, g, b] = point; - const position: Position = [x, y, z]; + const position: Position = [x ?? 0, y ?? 0, z ?? 0]; const category = r; if (r !== undefined && g !== undefined && b !== undefined) { diff --git a/weave-js/src/common/util/render_babylon.ts b/weave-js/src/common/util/render_babylon.ts index 10aee3f6c51..ebd213c2677 100644 --- a/weave-js/src/common/util/render_babylon.ts +++ b/weave-js/src/common/util/render_babylon.ts @@ -394,6 +394,15 @@ const pointCloudScene = ( // Apply vertexData to custom mesh vertexData.applyToMesh(pcMesh); + // A file without any points defined still includes a single, empty "point". + // In order to play nice with Babylon, we position this empty point at 0,0,0. + // Hence, a pointCloud with a single point at 0,0,0 is likely empty. + const isEmpty = + pointCloud.points.length === 1 && + pointCloud.points[0].position[0] === 0 && + pointCloud.points[0].position[1] === 0 && + pointCloud.points[0].position[2] === 0; + camera.parent = pcMesh; const pcMaterial = new Babylon.StandardMaterial('mat', scene); @@ -472,8 +481,8 @@ const pointCloudScene = ( new Vector3(edges.length * 2, edges.length * 2, edges.length * 2) ); - // If we are iterating over camera, target a box - if (index === meta?.cameraIndex) { + // If we are iterating over camera or the cloud is empty, target a box + if (index === meta?.cameraIndex || (index === 0 && isEmpty)) { camera.position = center.add(new Vector3(0, 0, 1000)); camera.target = center; camera.zoomOn([lines]); diff --git a/weave-js/src/components/ErrorBoundary.tsx b/weave-js/src/components/ErrorBoundary.tsx index 48b4fcf1a1e..053af08aa97 100644 --- a/weave-js/src/components/ErrorBoundary.tsx +++ b/weave-js/src/components/ErrorBoundary.tsx @@ -1,6 +1,7 @@ import {datadogRum} from '@datadog/browser-rum'; import * as Sentry from '@sentry/react'; import React, {Component, ErrorInfo, ReactNode} from 'react'; +import {v7 as uuidv7} from 'uuid'; import {weaveErrorToDDPayload} from '../errors'; import {ErrorPanel} from './ErrorPanel'; @@ -10,24 +11,31 @@ type Props = { }; type State = { - hasError: boolean; + uuid: string | undefined; + timestamp: Date | undefined; + error: Error | undefined; }; export class ErrorBoundary extends Component { - public static getDerivedStateFromError(_: Error): State { - return {hasError: true}; + public static getDerivedStateFromError(error: Error): State { + return {uuid: uuidv7(), timestamp: new Date(), error}; } public state: State = { - hasError: false, + uuid: undefined, + timestamp: undefined, + error: undefined, }; public componentDidCatch(error: Error, errorInfo: ErrorInfo) { + const {uuid} = this.state; datadogRum.addAction( 'weave_panel_error_boundary', - weaveErrorToDDPayload(error) + weaveErrorToDDPayload(error, undefined, uuid) ); - Sentry.captureException(error, { + extra: { + uuid, + }, tags: { weaveErrorBoundary: 'true', }, @@ -35,8 +43,9 @@ export class ErrorBoundary extends Component { } public render() { - if (this.state.hasError) { - return ; + const {uuid, timestamp, error} = this.state; + if (error != null) { + return ; } return this.props.children; diff --git a/weave-js/src/components/ErrorPanel.tsx b/weave-js/src/components/ErrorPanel.tsx index cbcfd244ac9..86945922150 100644 --- a/weave-js/src/components/ErrorPanel.tsx +++ b/weave-js/src/components/ErrorPanel.tsx @@ -1,7 +1,19 @@ -import React, {forwardRef, useEffect, useRef, useState} from 'react'; +import copyToClipboard from 'copy-to-clipboard'; +import _ from 'lodash'; +import React, { + forwardRef, + useCallback, + useEffect, + useRef, + useState, +} from 'react'; import styled from 'styled-components'; +import {toast} from '../common/components/elements/Toast'; import {hexToRGB, MOON_300, MOON_600} from '../common/css/globals.styles'; +import {useViewerInfo} from '../common/hooks/useViewerInfo'; +import {getCookieBool, getFirebaseCookie} from '../common/util/cookie'; +import {Button} from './Button'; import {Icon} from './Icon'; import {Tooltip} from './Tooltip'; @@ -14,6 +26,11 @@ type ErrorPanelProps = { title?: string; subtitle?: string; subtitle2?: string; + + // These props are for error details object + uuid?: string; + timestamp?: Date; + error?: Error; }; export const Centered = styled.div` @@ -85,11 +102,87 @@ export const ErrorPanelSmall = ({ ); }; +const getDateObject = (timestamp?: Date): Record | null => { + if (!timestamp) { + return null; + } + return { + // e.g. "2024-12-12T06:10:19.475Z", + iso: timestamp.toISOString(), + // e.g. "Thursday, December 12, 2024 at 6:10:19 AM Coordinated Universal Time" + long: timestamp.toLocaleString('en-US', { + weekday: 'long', + year: 'numeric', // Full year + month: 'long', // Full month name + day: 'numeric', // Day of the month + hour: 'numeric', // Hour (12-hour or 24-hour depending on locale) + minute: 'numeric', + second: 'numeric', + timeZone: 'UTC', // Ensures it's in UTC + timeZoneName: 'long', // Full time zone name + }), + user: timestamp.toLocaleString('en-US', { + dateStyle: 'full', + timeStyle: 'full', + }), + }; +}; + +const getErrorObject = (error?: Error): Record | null => { + if (!error) { + return null; + } + + // Error object properties are not enumerable so we have to copy them manually + const stack = (error.stack ?? '').split('\n'); + return { + message: error.message, + stack, + }; +}; + export const ErrorPanelLarge = forwardRef( - ({title, subtitle, subtitle2}, ref) => { + ({title, subtitle, subtitle2, uuid, timestamp, error}, ref) => { const titleStr = title ?? DEFAULT_TITLE; const subtitleStr = subtitle ?? DEFAULT_SUBTITLE; const subtitle2Str = subtitle2 ?? DEFAULT_SUBTITLE2; + + const {userInfo} = useViewerInfo(); + + const onClick = useCallback(() => { + const betaVersion = getFirebaseCookie('betaVersion'); + const isUsingAdminPrivileges = getCookieBool('use_admin_privileges'); + const {location, navigator, screen} = window; + const {userAgent, language} = navigator; + const details = { + uuid, + url: location.href, + error: getErrorObject(error), + timestamp_err: getDateObject(timestamp), + timestamp_copied: getDateObject(new Date()), + user: _.pick(userInfo, ['id', 'username']), // Skipping teams and admin + cookies: { + ...(betaVersion && {betaVersion}), + ...(isUsingAdminPrivileges && {use_admin_privileges: true}), + }, + browser: { + userAgent, + language, + screenSize: { + width: screen.width, + height: screen.height, + }, + viewportSize: { + width: window.innerWidth, + height: window.innerHeight, + }, + }, + }; + const detailsText = JSON.stringify(details, null, 2); + copyToClipboard(detailsText); + toast('Copied to clipboard'); + }, [uuid, timestamp, error, userInfo]); + return ( @@ -98,6 +191,14 @@ export const ErrorPanelLarge = forwardRef( {titleStr} {subtitleStr} {subtitle2Str} + ); } diff --git a/weave-js/src/components/FancyPage/FancyPageMenu.tsx b/weave-js/src/components/FancyPage/FancyPageMenu.tsx index c5971756016..6829b70fc06 100644 --- a/weave-js/src/components/FancyPage/FancyPageMenu.tsx +++ b/weave-js/src/components/FancyPage/FancyPageMenu.tsx @@ -60,7 +60,6 @@ export const FancyPageMenu = ({ return null; } const linkProps = { - key: menuItem.slug, to: menuItem.isDisabled ? {} : { @@ -76,7 +75,7 @@ export const FancyPageMenu = ({ }, }; return ( - + {menuItem.nameTooltip || menuItem.name} diff --git a/weave-js/src/components/Form/TextField.tsx b/weave-js/src/components/Form/TextField.tsx index c40f697ac85..8f5dd1171ff 100644 --- a/weave-js/src/components/Form/TextField.tsx +++ b/weave-js/src/components/Form/TextField.tsx @@ -37,6 +37,7 @@ type TextFieldProps = { dataTest?: string; step?: number; variant?: 'default' | 'ghost'; + isContainerNightAware?: boolean; }; export const TextField = ({ @@ -59,6 +60,7 @@ export const TextField = ({ autoComplete, dataTest, step, + isContainerNightAware, }: TextFieldProps) => { const textFieldSize = size ?? 'medium'; const leftPaddingForIcon = textFieldSize === 'medium' ? 'pl-34' : 'pl-36'; @@ -83,7 +85,6 @@ export const TextField = ({
( export const IconMolecule = (props: SVGIconProps) => ( ); +export const IconMoon = (props: SVGIconProps) => ( + +); export const IconMusicAudio = (props: SVGIconProps) => ( ); @@ -908,6 +915,9 @@ export const IconSortAscending = (props: SVGIconProps) => ( export const IconSortDescending = (props: SVGIconProps) => ( ); +export const IconSpiral = (props: SVGIconProps) => ( + +); export const IconSplit = (props: SVGIconProps) => ( ); @@ -926,6 +936,9 @@ export const IconStop = (props: SVGIconProps) => ( export const IconStopped = (props: SVGIconProps) => ( ); +export const IconSun = (props: SVGIconProps) => ( + +); export const IconSwap = (props: SVGIconProps) => ( ); @@ -1040,6 +1053,9 @@ export const IconVideoPlay = (props: SVGIconProps) => ( export const IconViewGlasses = (props: SVGIconProps) => ( ); +export const IconVisible = (props: SVGIconProps) => ( + +); export const IconWandb = (props: SVGIconProps) => ( ); @@ -1211,6 +1227,7 @@ const ICON_NAME_TO_ICON: Record = { model: IconModel, 'model-on-dark': IconModelOnDark, molecule: IconMolecule, + moon: IconMoon, 'music-audio': IconMusicAudio, 'new-section-above': IconNewSectionAbove, 'new-section-below': IconNewSectionBelow, @@ -1282,12 +1299,14 @@ const ICON_NAME_TO_ICON: Record = { sort: IconSort, 'sort-ascending': IconSortAscending, 'sort-descending': IconSortDescending, + spiral: IconSpiral, split: IconSplit, square: IconSquare, star: IconStar, 'star-filled': IconStarFilled, stop: IconStop, stopped: IconStopped, + sun: IconSun, swap: IconSwap, 'sweep-bayes': IconSweepBayes, 'sweep-grid': IconSweepGrid, @@ -1326,6 +1345,7 @@ const ICON_NAME_TO_ICON: Record = { 'vertex-gcp': IconVertexGCP, 'video-play': IconVideoPlay, 'view-glasses': IconViewGlasses, + visible: IconVisible, wandb: IconWandb, warning: IconWarning, 'warning-alt': IconWarningAlt, diff --git a/weave-js/src/components/Icon/index.ts b/weave-js/src/components/Icon/index.ts index 81839a71019..39c6eed3170 100644 --- a/weave-js/src/components/Icon/index.ts +++ b/weave-js/src/components/Icon/index.ts @@ -139,6 +139,7 @@ export { IconModel, IconModelOnDark, IconMolecule, + IconMoon, IconMusicAudio, IconNewSectionAbove, IconNewSectionBelow, @@ -210,12 +211,14 @@ export { IconSort, IconSortAscending, IconSortDescending, + IconSpiral, IconSplit, IconSquare, IconStar, IconStarFilled, IconStop, IconStopped, + IconSun, IconSwap, IconSweepBayes, IconSweepGrid, @@ -254,6 +257,7 @@ export { IconVertexGCP, IconVideoPlay, IconViewGlasses, + IconVisible, IconWandb, IconWarning, IconWarningAlt, diff --git a/weave-js/src/components/Icon/types.ts b/weave-js/src/components/Icon/types.ts index 55f46c52833..47f5f357adc 100644 --- a/weave-js/src/components/Icon/types.ts +++ b/weave-js/src/components/Icon/types.ts @@ -138,6 +138,7 @@ export const IconNames = { Model: 'model', ModelOnDark: 'model-on-dark', Molecule: 'molecule', + Moon: 'moon', MusicAudio: 'music-audio', NewSectionAbove: 'new-section-above', NewSectionBelow: 'new-section-below', @@ -209,12 +210,14 @@ export const IconNames = { Sort: 'sort', SortAscending: 'sort-ascending', SortDescending: 'sort-descending', + Spiral: 'spiral', Split: 'split', Square: 'square', Star: 'star', StarFilled: 'star-filled', Stop: 'stop', Stopped: 'stopped', + Sun: 'sun', Swap: 'swap', SweepBayes: 'sweep-bayes', SweepGrid: 'sweep-grid', @@ -253,6 +256,7 @@ export const IconNames = { VertexGCP: 'vertex-gcp', VideoPlay: 'video-play', ViewGlasses: 'view-glasses', + Visible: 'visible', Wandb: 'wandb', Warning: 'warning', WarningAlt: 'warning-alt', diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx index 3517f4d3b9c..761fd536930 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3.tsx @@ -1,15 +1,5 @@ import {ApolloProvider} from '@apollo/client'; -import {Home} from '@mui/icons-material'; -import { - AppBar, - Box, - Breadcrumbs, - Drawer, - IconButton, - Link as MaterialLink, - Toolbar, - Typography, -} from '@mui/material'; +import {Box, Drawer} from '@mui/material'; import { GridColumnVisibilityModel, GridFilterModel, @@ -21,9 +11,7 @@ import {LicenseInfo} from '@mui/x-license'; import {makeGorillaApolloClient} from '@wandb/weave/apollo'; import {EVALUATE_OP_NAME_POST_PYDANTIC} from '@wandb/weave/components/PagePanelComponents/Home/Browse3/pages/common/heuristics'; import {opVersionKeyToRefUri} from '@wandb/weave/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/utilities'; -import _ from 'lodash'; import React, { - ComponentProps, FC, useCallback, useEffect, @@ -33,7 +21,6 @@ import React, { } from 'react'; import useMousetrap from 'react-hook-mousetrap'; import { - Link as RouterLink, Redirect, Route, Switch, @@ -199,7 +186,6 @@ export const Browse3: FC<{ `/${URL_BROWSE3}`, ]}> @@ -211,7 +197,6 @@ export const Browse3: FC<{ }; const Browse3Mounted: FC<{ - hideHeader?: boolean; headerOffset?: number; navigateAwayFromProject?: () => void; }> = props => { @@ -225,37 +210,6 @@ const Browse3Mounted: FC<{ overflow: 'auto', flexDirection: 'column', }}> - {!props.hideHeader && ( - theme.zIndex.drawer + 1, - height: '60px', - flex: '0 0 auto', - position: 'static', - }}> - - - theme.palette.getContrastText(theme.palette.primary.main), - '&:hover': { - color: theme => - theme.palette.getContrastText(theme.palette.primary.dark), - }, - marginRight: theme => theme.spacing(2), - }}> - - - - - - )} @@ -1050,20 +1004,6 @@ const ComparePageBinding = () => { return ; }; -const AppBarLink = (props: ComponentProps) => ( - theme.palette.getContrastText(theme.palette.primary.main), - '&:hover': { - color: theme => - theme.palette.getContrastText(theme.palette.primary.dark), - }, - }} - {...props} - component={RouterLink} - /> -); - const PlaygroundPageBinding = () => { const params = useParamsDecoded(); return ( @@ -1074,79 +1014,3 @@ const PlaygroundPageBinding = () => { /> ); }; - -const Browse3Breadcrumbs: FC = props => { - const params = useParamsDecoded(); - const query = useURLSearchParamsDict(); - const filePathParts = query.path?.split('/') ?? []; - const refFields = query.extra?.split('/') ?? []; - - return ( - - {params.entity && ( - - {params.entity} - - )} - {params.project && ( - - {params.project} - - )} - {params.tab && ( - - {params.tab} - - )} - {params.itemName && ( - - {params.itemName} - - )} - {params.version && ( - - {params.version} - - )} - {filePathParts.map((part, idx) => ( - - {part} - - ))} - {_.range(0, refFields.length, 2).map(idx => ( - - - theme.palette.getContrastText(theme.palette.primary.main), - }}> - {refFields[idx]} - - - {refFields[idx + 1]} - - - ))} - - ); -}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx index ef6bcbd69ff..0b3c9603fef 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx @@ -96,7 +96,7 @@ export const FeedbackSidebar = ({
Feedback
-
+
{humanAnnotationSpecs.length > 0 ? ( <>
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/HumanAnnotation.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/HumanAnnotation.tsx index 7facffe9556..21dedf10ea0 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/HumanAnnotation.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/HumanAnnotation.tsx @@ -53,17 +53,11 @@ export const HumanAnnotationCell: React.FC = props => { const feedbackSpecRef = props.hfSpec.ref; useEffect(() => { - if (!props.readOnly) { - // We don't need to listen for feedback changes if the cell is editable - // it is being controlled by local state - return; - } return getTsClient().registerOnFeedbackListener( props.callRef, query.refetch ); - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [props.callRef]); + }, [props.callRef, query.refetch, getTsClient]); useEffect(() => { if (foundFeedbackCallRef && props.callRef !== foundFeedbackCallRef) { @@ -183,13 +177,22 @@ const FeedbackComponentSelector: React.FC<{ }) => { const wrappedOnAddFeedback = useCallback( async (value: any) => { + if (value == null || value === foundValue || value === '') { + // Remove from unsaved changes if value is invalid + setUnsavedFeedbackChanges(curr => { + const rest = {...curr}; + delete rest[feedbackSpecRef]; + return rest; + }); + return true; + } setUnsavedFeedbackChanges(curr => ({ ...curr, [feedbackSpecRef]: () => onAddFeedback(value), })); return true; }, - [onAddFeedback, setUnsavedFeedbackChanges, feedbackSpecRef] + [onAddFeedback, setUnsavedFeedbackChanges, feedbackSpecRef, foundValue] ); switch (type) { @@ -346,21 +349,10 @@ export const NumericalFeedbackColumn = ({ focused?: boolean; isInteger?: boolean; }) => { - const debouncedFn = useMemo( - () => - _.debounce((val: number | null) => onAddFeedback?.(val), DEBOUNCE_VAL), - [onAddFeedback] - ); - useEffect(() => { - return () => { - debouncedFn.cancel(); - }; - }, [debouncedFn]); - return ( onAddFeedback?.(value)} min={min} max={max} isInteger={isInteger} @@ -415,7 +407,7 @@ export const TextFeedbackColumn = ({ placeholder="" /> {maxLength && ( -
+
{`Maximum characters: ${maxLength}`}
)} @@ -446,7 +438,7 @@ export const EnumFeedbackColumn = ({ })); return opts; }, [options]); - const [value, setValue] = useState
+ + + ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx index c22df7c63d7..135d297539d 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx @@ -1,9 +1,10 @@ import React, {useState} from 'react'; +import {usePlaygroundContext} from '../PlaygroundPage/PlaygroundContext'; +import {ChoicesDrawer} from './ChoicesDrawer'; import {ChoicesViewCarousel} from './ChoicesViewCarousel'; -import {ChoicesViewLinear} from './ChoicesViewLinear'; import {ChoiceView} from './ChoiceView'; -import {Choice, ChoicesMode} from './types'; +import {Choice} from './types'; type ChoicesViewProps = { choices: Choice[]; @@ -14,32 +15,45 @@ export const ChoicesView = ({ choices, isStructuredOutput, }: ChoicesViewProps) => { - const [mode, setMode] = useState('linear'); + const {setSelectedChoiceIndex: setGlobalSelectedChoiceIndex} = + usePlaygroundContext(); + + const [isDrawerOpen, setIsDrawerOpen] = useState(false); + const [localSelectedChoiceIndex, setLocalSelectedChoiceIndex] = useState(0); + + const handleSetSelectedChoiceIndex = (choiceIndex: number) => { + setLocalSelectedChoiceIndex(choiceIndex); + setGlobalSelectedChoiceIndex(choiceIndex); + }; if (choices.length === 0) { return null; } if (choices.length === 1) { return ( - + ); } return ( <> - {mode === 'linear' && ( - - )} - {mode === 'carousel' && ( - - )} + + ); }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx index f4a52fc6801..b7dc6eb427d 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx @@ -1,62 +1,67 @@ -import React, {useState} from 'react'; +import React from 'react'; import {Button} from '../../../../../Button'; import {ChoiceView} from './ChoiceView'; -import {Choice, ChoicesMode} from './types'; +import {Choice} from './types'; type ChoicesViewCarouselProps = { choices: Choice[]; isStructuredOutput?: boolean; - setMode: React.Dispatch>; + setIsDrawerOpen: React.Dispatch>; + selectedChoiceIndex: number; + setSelectedChoiceIndex: (choiceIndex: number) => void; }; export const ChoicesViewCarousel = ({ choices, isStructuredOutput, - setMode, + setIsDrawerOpen, + selectedChoiceIndex, + setSelectedChoiceIndex, }: ChoicesViewCarouselProps) => { - const [step, setStep] = useState(0); - const onNext = () => { - setStep((step + 1) % choices.length); + setSelectedChoiceIndex((selectedChoiceIndex + 1) % choices.length); }; const onBack = () => { - const newStep = step === 0 ? choices.length - 1 : step - 1; - setStep(newStep); + const newStep = + selectedChoiceIndex === 0 ? choices.length - 1 : selectedChoiceIndex - 1; + setSelectedChoiceIndex(newStep); }; return ( - <> - -
-
+ +
+
-
-
-
- + } + /> ); }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewLinear.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewLinear.tsx deleted file mode 100644 index 92668c6504e..00000000000 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewLinear.tsx +++ /dev/null @@ -1,73 +0,0 @@ -import React from 'react'; - -import {Button} from '../../../../../Button'; -import {ChoiceView} from './ChoiceView'; -import {Choice, ChoicesMode} from './types'; - -type ChoicesViewLinearProps = { - choices: Choice[]; - isStructuredOutput?: boolean; - setMode: React.Dispatch>; -}; - -export const ChoicesViewLinear = ({ - choices, - isStructuredOutput, - setMode, -}: ChoicesViewLinearProps) => { - return ( -
- {choices.map(c => ( -
-
-
- {c.index !== 0 && ( -
-
- -
- ))} -
- ); -}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx index c9ed9313f59..8cec95707fa 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/MessagePanel.tsx @@ -15,25 +15,26 @@ type MessagePanelProps = { index: number; message: Message; isStructuredOutput?: boolean; - isChoice?: boolean; + choiceIndex?: number; isNested?: boolean; pendingToolResponseId?: string; + messageHeader?: React.ReactNode; }; export const MessagePanel = ({ index, message, isStructuredOutput, - isChoice, + choiceIndex, isNested, // The id of the tool call response that is pending // If the tool call response is pending, the editor will be shown automatically // and on save the tool call response will be updated and sent to the LLM pendingToolResponseId, + messageHeader, }: MessagePanelProps) => { const [isShowingMore, setIsShowingMore] = useState(false); const [isOverflowing, setIsOverflowing] = useState(false); - const [isHovering, setIsHovering] = useState(false); const [editorHeight, setEditorHeight] = useState( pendingToolResponseId ? 100 : null ); @@ -46,6 +47,13 @@ export const MessagePanel = ({ } }, [message.content, contentRef?.current?.scrollHeight]); + // Set isShowingMore to true when editor is opened + useEffect(() => { + if (editorHeight !== null) { + setIsShowingMore(true); + } + }, [editorHeight]); + const isUser = message.role === 'user'; const isSystemPrompt = message.role === 'system'; const isTool = message.role === 'tool'; @@ -62,126 +70,128 @@ export const MessagePanel = ({ : undefined; return ( -
- {!isNested && !isSystemPrompt && ( -
- {!isUser && !isTool && ( - - )} -
- )} +
+
+ {!isNested && !isSystemPrompt && ( +
+ {!isUser && !isTool && ( + + )} +
+ )} -
setIsHovering(true)} - onMouseLeave={() => setIsHovering(false)}> -
- {isSystemPrompt && ( -
-
- {message.role.charAt(0).toUpperCase() + message.role.slice(1)} +
+
+ {isSystemPrompt && ( +
+
+ {message.role.charAt(0).toUpperCase() + message.role.slice(1)} +
-
- )} + )} - {isTool && ( -
-
- Response + {isTool && ( +
+
+ Response +
+ )} + +
+ {messageHeader} + {isPlayground && editorHeight ? ( + + ) : ( + <> + {hasContent && ( +
+ {_.isString(message.content) ? ( + + ) : ( + message.content!.map((p, i) => ( + + )) + )} +
+ )} + {hasToolCalls && ( +
+ +
+ )} + + )}
- )} -
- {isPlayground && editorHeight ? ( - - ) : ( - <> - {hasContent && ( -
- {_.isString(message.content) ? ( - - ) : ( - message.content!.map((p, i) => ( - - )) - )} -
- )} - {hasToolCalls && ( -
- -
- )} - )}
- - {isOverflowing && !editorHeight && ( - - )} - - {/* Playground buttons (retry, edit, delete) */} - {isPlayground && isHovering && !editorHeight && ( -
- -
- )}
+ + {/* Playground buttons (retry, edit, delete) - using group and group-hover to control opacity. */} + {isPlayground && !editorHeight ? ( +
+ +
+ ) : null}
); }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/PlaygroundMessagePanelButtons.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/PlaygroundMessagePanelButtons.tsx index e491c38ede7..923f70762de 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/PlaygroundMessagePanelButtons.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/PlaygroundMessagePanelButtons.tsx @@ -5,7 +5,7 @@ import {usePlaygroundContext} from '../PlaygroundPage/PlaygroundContext'; type PlaygroundMessagePanelButtonsProps = { index: number; - isChoice: boolean; + choiceIndex?: number; isTool: boolean; hasContent: boolean; contentRef: React.RefObject; @@ -17,7 +17,7 @@ export const PlaygroundMessagePanelButtons: React.FC< PlaygroundMessagePanelButtonsProps > = ({ index, - isChoice, + choiceIndex, isTool, hasContent, contentRef, @@ -27,12 +27,12 @@ export const PlaygroundMessagePanelButtons: React.FC< const {deleteMessage, deleteChoice, retry} = usePlaygroundContext(); return ( -
+
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/hooks.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/hooks.ts index 38b5c820195..33ced58ec49 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/hooks.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/hooks.ts @@ -246,6 +246,63 @@ export const isTraceCallChatFormatGemini = (call: TraceCallSchema): boolean => { ); }; +export const isAnthropicContentBlock = (block: any): boolean => { + if (!_.isPlainObject(block)) { + return false; + } + // TODO: Are there other types? + if (block.type !== 'text') { + return false; + } + if (!hasStringProp(block, 'text')) { + return false; + } + return true; +}; + +export const isAnthropicCompletionFormat = (output: any): boolean => { + if (output !== null) { + // TODO: Could have additional checks here on things like usage + if ( + _.isPlainObject(output) && + output.type === 'message' && + output.role === 'assistant' && + hasStringProp(output, 'model') && + _.isArray(output.content) && + output.content.every((c: any) => isAnthropicContentBlock(c)) + ) { + return true; + } + return false; + } + return true; +}; + +type AnthropicContentBlock = { + type: 'text'; + text: string; +}; + +export const anthropicContentBlocksToChoices = ( + blocks: AnthropicContentBlock[], + stopReason: string +): Choice[] => { + const choices: Choice[] = []; + for (let i = 0; i < blocks.length; i++) { + const block = blocks[i]; + choices.push({ + index: i, + message: { + role: 'assistant', + content: block.text, + }, + // TODO: What is correct way to map this? + finish_reason: stopReason, + }); + } + return choices; +}; + export const isTraceCallChatFormatOpenAI = (call: TraceCallSchema): boolean => { if (!('messages' in call.inputs)) { return false; @@ -336,6 +393,19 @@ export const normalizeChatRequest = (request: any): ChatRequest => { ], }; } + // Anthropic has system message as a top-level request field + if (hasStringProp(request, 'system')) { + return { + ...request, + messages: [ + { + role: 'system', + content: request.system, + }, + ...request.messages, + ], + }; + } return request as ChatRequest; }; @@ -360,6 +430,24 @@ export const normalizeChatCompletion = ( }, }; } + if (isAnthropicCompletionFormat(completion)) { + return { + id: completion.id, + choices: anthropicContentBlocksToChoices( + completion.content, + completion.stop_reason + ), + created: 0, + model: completion.model, + system_fingerprint: '', + usage: { + prompt_tokens: completion.usage.input_tokens, + completion_tokens: completion.usage.output_tokens, + total_tokens: + completion.usage.input_tokens + completion.usage.output_tokens, + }, + }; + } return completion as ChatCompletion; }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/types.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/types.ts index b5696055712..3bbe65a5baf 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/types.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/types.ts @@ -113,5 +113,3 @@ export type Chat = { request: ChatRequest | null; result: ChatCompletion | null; }; - -export type ChoicesMode = 'linear' | 'carousel'; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx index b302b0262ae..478c4887546 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/CompareEvaluationsPage.tsx @@ -77,8 +77,6 @@ export const CompareEvaluationsPage: React.FC< export const CompareEvaluationsPageContent: React.FC< CompareEvaluationsPageProps > = props => { - const [baselineEvaluationCallId, setBaselineEvaluationCallId] = - React.useState(null); const [comparisonDimensions, setComparisonDimensions] = React.useState(null); @@ -104,14 +102,6 @@ export const CompareEvaluationsPageContent: React.FC< [comparisonDimensions] ); - React.useEffect(() => { - // Only update the baseline if we are switching evaluations, if there - // is more than 1, we are in the compare view and baseline is auto set - if (props.evaluationCallIds.length === 1) { - setBaselineEvaluationCallId(props.evaluationCallIds[0]); - } - }, [props.evaluationCallIds]); - if (props.evaluationCallIds.length === 0) { return
No evaluations to compare
; } @@ -120,13 +110,11 @@ export const CompareEvaluationsPageContent: React.FC< diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compareEvaluationsContext.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compareEvaluationsContext.tsx index b65658a890a..638565e8ad6 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compareEvaluationsContext.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/compareEvaluationsContext.tsx @@ -10,9 +10,6 @@ import {ComparisonDimensionsType} from './ecpState'; const CompareEvaluationsContext = React.createContext<{ state: EvaluationComparisonState; - setBaselineEvaluationCallId: React.Dispatch< - React.SetStateAction - >; setComparisonDimensions: React.Dispatch< React.SetStateAction >; @@ -20,6 +17,7 @@ const CompareEvaluationsContext = React.createContext<{ setSelectedMetrics: (newModel: Record) => void; addEvaluationCall: (newCallId: string) => void; removeEvaluationCall: (callId: string) => void; + setEvaluationCallOrder: (newCallIdOrder: string[]) => void; } | null>(null); export const useCompareEvaluationsState = () => { @@ -33,34 +31,26 @@ export const useCompareEvaluationsState = () => { export const CompareEvaluationsProvider: React.FC<{ entity: string; project: string; + initialEvaluationCallIds: string[]; selectedMetrics: Record | null; setSelectedMetrics: (newModel: Record) => void; - initialEvaluationCallIds: string[]; onEvaluationCallIdsUpdate: (newEvaluationCallIds: string[]) => void; - setBaselineEvaluationCallId: React.Dispatch< - React.SetStateAction - >; setComparisonDimensions: React.Dispatch< React.SetStateAction >; setSelectedInputDigest: React.Dispatch>; - baselineEvaluationCallId?: string; comparisonDimensions?: ComparisonDimensionsType; selectedInputDigest?: string; }> = ({ entity, project, + initialEvaluationCallIds, selectedMetrics, setSelectedMetrics, - - initialEvaluationCallIds, onEvaluationCallIdsUpdate, - setBaselineEvaluationCallId, setComparisonDimensions, - setSelectedInputDigest, - baselineEvaluationCallId, comparisonDimensions, selectedInputDigest, children, @@ -77,7 +67,6 @@ export const CompareEvaluationsProvider: React.FC<{ entity, project, evaluationCallIds, - baselineEvaluationCallId, comparisonDimensions, selectedInputDigest, selectedMetrics ?? undefined @@ -89,7 +78,6 @@ export const CompareEvaluationsProvider: React.FC<{ } return { state: initialState.result, - setBaselineEvaluationCallId, setComparisonDimensions, setSelectedInputDigest, setSelectedMetrics, @@ -105,14 +93,17 @@ export const CompareEvaluationsProvider: React.FC<{ setEvaluationCallIds(newEvaluationCallIds); onEvaluationCallIdsUpdate(newEvaluationCallIds); }, + setEvaluationCallOrder: (newCallIdOrder: string[]) => { + setEvaluationCallIds(newCallIdOrder); + onEvaluationCallIdsUpdate(newCallIdOrder); + }, }; }, [ initialState.loading, initialState.result, + setEvaluationCallIds, evaluationCallIds, onEvaluationCallIdsUpdate, - setEvaluationCallIds, - setBaselineEvaluationCallId, setComparisonDimensions, setSelectedInputDigest, setSelectedMetrics, diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts index 35f95dbf14f..e5c1b03d60a 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpState.ts @@ -18,14 +18,14 @@ import {getMetricIds} from './ecpUtil'; export type EvaluationComparisonState = { // The normalized data for the evaluations data: EvaluationComparisonData; - // The evaluation call id of the baseline model - baselineEvaluationCallId: string; // The dimensions to compare & filter results comparisonDimensions?: ComparisonDimensionsType; // The current digest which is in view selectedInputDigest?: string; // The selected metrics to display selectedMetrics?: Record; + // Ordered call Ids + evaluationCallIdsOrdered: string[]; }; export type ComparisonDimensionsType = Array<{ @@ -43,12 +43,14 @@ export const useEvaluationComparisonState = ( entity: string, project: string, evaluationCallIds: string[], - baselineEvaluationCallId?: string, comparisonDimensions?: ComparisonDimensionsType, selectedInputDigest?: string, selectedMetrics?: Record ): Loadable => { - const data = useEvaluationComparisonData(entity, project, evaluationCallIds); + const orderedCallIds = useMemo(() => { + return getCallIdsOrderedForQuery(evaluationCallIds); + }, [evaluationCallIds]); + const data = useEvaluationComparisonData(entity, project, orderedCallIds); const value = useMemo(() => { if (data.result == null || data.loading) { @@ -92,42 +94,45 @@ export const useEvaluationComparisonState = ( loading: false, result: { data: data.result, - baselineEvaluationCallId: - baselineEvaluationCallId ?? evaluationCallIds[0], comparisonDimensions: newComparisonDimensions, selectedInputDigest, selectedMetrics, + evaluationCallIdsOrdered: evaluationCallIds, }, }; }, [ data.result, data.loading, - baselineEvaluationCallId, - evaluationCallIds, comparisonDimensions, selectedInputDigest, selectedMetrics, + evaluationCallIds, ]); return value; }; +export const getOrderedCallIds = (state: EvaluationComparisonState) => { + return Array.from(state.evaluationCallIdsOrdered); +}; + +export const getBaselineCallId = (state: EvaluationComparisonState) => { + return getOrderedCallIds(state)[0]; +}; + /** - * Should use this over keys of `state.data.evaluationCalls` because it ensures the baseline - * evaluation call is first. + * Sort call IDs to ensure consistent order for memoized query params */ -export const getOrderedCallIds = (state: EvaluationComparisonState) => { - const initial = Object.keys(state.data.evaluationCalls); - moveItemToFront(initial, state.baselineEvaluationCallId); - return initial; +const getCallIdsOrderedForQuery = (callIds: string[]) => { + return Array.from(callIds).sort(); }; /** * Should use this over keys of `state.data.models` because it ensures the baseline model is first. */ export const getOrderedModelRefs = (state: EvaluationComparisonState) => { - const baselineRef = - state.data.evaluationCalls[state.baselineEvaluationCallId].modelRef; + const baselineCallId = getBaselineCallId(state); + const baselineRef = state.data.evaluationCalls[baselineCallId].modelRef; const refs = Object.keys(state.data.models); // Make sure the baseline model is first moveItemToFront(refs, baselineRef); @@ -145,3 +150,29 @@ const moveItemToFront = (arr: T[], item: T): T[] => { } return arr; }; + +export const getOrderedEvalsWithNewBaseline = ( + evaluationCallIds: string[], + newBaselineCallId: string +) => { + return moveItemToFront(evaluationCallIds, newBaselineCallId); +}; + +export const swapEvaluationCalls = ( + evaluationCallIds: string[], + ndx1: number, + ndx2: number +): string[] => { + return swapArrayItems(evaluationCallIds, ndx1, ndx2); +}; + +const swapArrayItems = (arr: T[], ndx1: number, ndx2: number): T[] => { + if (ndx1 < 0 || ndx2 < 0 || ndx1 >= arr.length || ndx2 >= arr.length) { + throw new Error('Index out of bounds'); + } + const newArr = [...arr]; + const from = newArr[ndx1]; + newArr[ndx1] = newArr[ndx2]; + newArr[ndx2] = from; + return newArr; +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpTypes.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpTypes.ts index 7454e0707b4..b4642fae240 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpTypes.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/ecpTypes.ts @@ -18,6 +18,7 @@ export type EvaluationComparisonData = { }; // EvaluationCalls are the specific calls of an evaluation. + // The visual order of the evaluation calls is determined by the order of the keys. evaluationCalls: { [callId: string]: EvaluationCall; }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx index b5c1a4bf96c..2704a66cbea 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/ComparisonDefinitionSection.tsx @@ -12,44 +12,81 @@ import { import {useCallsForQuery} from '../../../CallsPage/callsTableQuery'; import {useEvaluationsFilter} from '../../../CallsPage/evaluationsFilter'; import {Id} from '../../../common/Id'; +import {opNiceName} from '../../../common/Links'; import {useWFHooks} from '../../../wfReactInterface/context'; import { CallSchema, ObjectVersionKey, } from '../../../wfReactInterface/wfDataModelHooksInterface'; import {useCompareEvaluationsState} from '../../compareEvaluationsContext'; -import {STANDARD_PADDING} from '../../ecpConstants'; -import {getOrderedCallIds} from '../../ecpState'; -import {EvaluationComparisonState} from '../../ecpState'; +import { + EvaluationComparisonState, + getOrderedCallIds, + getOrderedEvalsWithNewBaseline, + swapEvaluationCalls, +} from '../../ecpState'; import {HorizontalBox} from '../../Layout'; -import {EvaluationDefinition, VerticalBar} from './EvaluationDefinition'; +import {ItemDef} from '../DraggableSection/DraggableItem'; +import {DraggableSection} from '../DraggableSection/DraggableSection'; +import {VerticalBar} from './EvaluationDefinition'; export const ComparisonDefinitionSection: React.FC<{ state: EvaluationComparisonState; }> = props => { - const evalCallIds = useMemo( - () => getOrderedCallIds(props.state), - [props.state] - ); + const {setEvaluationCallOrder, removeEvaluationCall} = + useCompareEvaluationsState(); + + const callIds = useMemo(() => { + return getOrderedCallIds(props.state); + }, [props.state]); + + const items: ItemDef[] = useMemo(() => { + return callIds.map(callId => ({ + key: 'evaluations', + value: callId, + label: props.state.data.evaluationCalls[callId]?.name ?? callId, + })); + }, [callIds, props.state.data.evaluationCalls]); + + const onSetBaseline = (value: string | null) => { + if (!value) { + return; + } + const newSortOrder = getOrderedEvalsWithNewBaseline(callIds, value); + setEvaluationCallOrder(newSortOrder); + }; + const onRemoveItem = (value: string) => removeEvaluationCall(value); + const onSortEnd = ({ + oldIndex, + newIndex, + }: { + oldIndex: number; + newIndex: number; + }) => { + const newSortOrder = swapEvaluationCalls(callIds, oldIndex, newIndex); + setEvaluationCallOrder(newSortOrder); + }; return ( - - {evalCallIds.map((key, ndx) => { - return ( - - - - ); - })} - - + +
+ + + + + + +
+
); }; @@ -81,7 +118,7 @@ const ModelRefLabel: React.FC<{modelRef: string}> = props => { const objectVersion = useObjectVersion(objVersionKey); return ( - {objectVersion.result?.objectId}:{objectVersion.result?.versionIndex} + {objectVersion.result?.objectId}:v{objectVersion.result?.versionIndex} ); }; @@ -105,7 +142,7 @@ const AddEvaluationButton: React.FC<{ ); const expandedRefCols = useMemo(() => new Set(), []); // Don't query for output here, re-queried in tsDataModelHooksEvaluationComparison.ts - const columns = useMemo(() => ['inputs'], []); + const columns = useMemo(() => ['inputs', 'display_name'], []); const calls = useCallsForQuery( props.state.data.entity, props.state.data.project, @@ -119,10 +156,9 @@ const AddEvaluationButton: React.FC<{ const evalsNotComparing = useMemo(() => { return calls.result.filter( - call => - !Object.keys(props.state.data.evaluationCalls).includes(call.callId) + call => !getOrderedCallIds(props.state).includes(call.callId) ); - }, [calls.result, props.state.data.evaluationCalls]); + }, [calls.result, props.state]); const [menuOptions, setMenuOptions] = useState(evalsNotComparing); @@ -137,7 +173,7 @@ const AddEvaluationButton: React.FC<{ return; } - const filteredOptions = calls.result.filter(call => { + const filteredOptions = evalsNotComparing.filter(call => { if ( (call.displayName ?? call.spanName) .toLowerCase() @@ -222,12 +258,18 @@ const AddEvaluationButton: React.FC<{ variant="ghost" size="small" className="pb-8 pt-8 font-['Source_Sans_Pro'] text-base font-normal text-moon-800" - onClick={() => { - addEvaluationCall(call.callId); - }}> + onClick={() => addEvaluationCall(call.callId)}> <> - {call.displayName ?? call.spanName} - + + {call.displayName ?? opNiceName(call.spanName)} + + + + diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/EvaluationDefinition.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/EvaluationDefinition.tsx index adc80d044f6..5dcf835e378 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/EvaluationDefinition.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ComparisonDefinitionSection/EvaluationDefinition.tsx @@ -1,8 +1,5 @@ import {Box} from '@material-ui/core'; import {Circle} from '@mui/icons-material'; -import {PopupDropdown} from '@wandb/weave/common/components/PopupDropdown'; -import {Button} from '@wandb/weave/components/Button'; -import {Pill} from '@wandb/weave/components/Tag'; import React, {useMemo} from 'react'; import { @@ -17,87 +14,17 @@ import {SmallRef} from '../../../../../Browse2/SmallRef'; import {CallLink, ObjectVersionLink} from '../../../common/Links'; import {useWFHooks} from '../../../wfReactInterface/context'; import {ObjectVersionKey} from '../../../wfReactInterface/wfDataModelHooksInterface'; -import {useCompareEvaluationsState} from '../../compareEvaluationsContext'; -import { - BOX_RADIUS, - CIRCLE_SIZE, - EVAL_DEF_HEIGHT, - STANDARD_BORDER, -} from '../../ecpConstants'; +import {CIRCLE_SIZE} from '../../ecpConstants'; import {EvaluationComparisonState} from '../../ecpState'; -import {HorizontalBox} from '../../Layout'; - -export const EvaluationDefinition: React.FC<{ - state: EvaluationComparisonState; - callId: string; -}> = props => { - const {removeEvaluationCall, setBaselineEvaluationCallId} = - useCompareEvaluationsState(); - - const menuOptions = useMemo(() => { - return [ - { - key: 'add-to-baseline', - content: 'Set as baseline', - onClick: () => { - setBaselineEvaluationCallId(props.callId); - }, - disabled: props.callId === props.state.baselineEvaluationCallId, - }, - { - key: 'remove', - content: 'Remove', - onClick: () => { - removeEvaluationCall(props.callId); - }, - disabled: Object.keys(props.state.data.evaluationCalls).length === 1, - }, - ]; - }, [ - props.callId, - props.state.baselineEvaluationCallId, - props.state.data.evaluationCalls, - removeEvaluationCall, - setBaselineEvaluationCallId, - ]); - - return ( - - - {props.callId === props.state.baselineEvaluationCallId && ( - - )} -
- - } - /> -
-
- ); -}; export const EvaluationCallLink: React.FC<{ callId: string; state: EvaluationComparisonState; }> = props => { - const evaluationCall = props.state.data.evaluationCalls[props.callId]; + const evaluationCall = props.state.data.evaluationCalls?.[props.callId]; + if (!evaluationCall) { + return null; + } const {entity, project} = props.state.data; return ( diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableItem.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableItem.tsx new file mode 100644 index 00000000000..1510c502b99 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableItem.tsx @@ -0,0 +1,103 @@ +import {Button} from '@wandb/weave/components/Button'; +import * as DropdownMenu from '@wandb/weave/components/DropdownMenu'; +import {Icon} from '@wandb/weave/components/Icon'; +import {Pill} from '@wandb/weave/components/Tag/Pill'; +import {Tailwind} from '@wandb/weave/components/Tailwind'; +import classNames from 'classnames'; +import React, {useState} from 'react'; +import {SortableElement, SortableHandle} from 'react-sortable-hoc'; + +import {EvaluationComparisonState} from '../../ecpState'; +import {EvaluationCallLink} from '../ComparisonDefinitionSection/EvaluationDefinition'; + +export type ItemDef = { + key: string; + value: string; + label?: string; +}; + +type DraggableItemProps = { + state: EvaluationComparisonState; + item: ItemDef; + numItems: number; + idx: number; + onRemoveItem: (value: string) => void; + onSetBaseline: (value: string | null) => void; +}; + +export const DraggableItem = SortableElement( + ({ + state, + item, + numItems, + idx, + onRemoveItem, + onSetBaseline, + }: DraggableItemProps) => { + const isDeletable = numItems > 1; + const isBaseline = idx === 0; + const [isOpen, setIsOpen] = useState(false); + + const onMakeBaselinePropagated = (e: React.MouseEvent) => { + e.stopPropagation(); + onSetBaseline(item.value); + }; + + const onRemoveItemPropagated = (e: React.MouseEvent) => { + e.stopPropagation(); + onRemoveItem(item.value); + }; + + return ( + +
+ +
+ + {isBaseline && ( + + )} +
+ + +
+
+ ); + } +); + +const DragHandle = SortableHandle(() => ( +
+ +
+)); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableSection.tsx new file mode 100644 index 00000000000..23a03ceb5b3 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/DraggableSection/DraggableSection.tsx @@ -0,0 +1,34 @@ +import React from 'react'; +import {SortableContainer} from 'react-sortable-hoc'; + +import {EvaluationComparisonState} from '../../ecpState'; +import {DraggableItem} from './DraggableItem'; +import {ItemDef} from './DraggableItem'; + +type DraggableSectionProps = { + state: EvaluationComparisonState; + items: ItemDef[]; + onSetBaseline: (value: string | null) => void; + onRemoveItem: (value: string) => void; +}; + +export const DraggableSection = SortableContainer( + ({state, items, onSetBaseline, onRemoveItem}: DraggableSectionProps) => { + return ( +
+ {items.map((item, index) => ( + + ))} +
+ ); + } +); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/ExampleCompareSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/ExampleCompareSection.tsx index 6041492b5c5..398f65ecd45 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/ExampleCompareSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleCompareSection/ExampleCompareSection.tsx @@ -149,7 +149,7 @@ const stickySidebarHeaderMixin: React.CSSProperties = { /** * This component will occupy the entire space provided by the parent container. - * It is intended to be used in teh CompareEvaluations page, as it depends on + * It is intended to be used in the CompareEvaluations page, as it depends on * the EvaluationComparisonState. However, in principle, it is a general purpose * model-output comparison tool. It allows the user to view inputs, then compare * model outputs and evaluation metrics across multiple trials. diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx index b2c8773a7d1..1146c8ea960 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ExampleFilterSection/ExampleFilterSection.tsx @@ -13,7 +13,7 @@ import { } from '../../compositeMetricsUtil'; import {PLOT_HEIGHT, STANDARD_PADDING} from '../../ecpConstants'; import {MAX_PLOT_DOT_SIZE, MIN_PLOT_DOT_SIZE} from '../../ecpConstants'; -import {EvaluationComparisonState} from '../../ecpState'; +import {EvaluationComparisonState, getBaselineCallId} from '../../ecpState'; import {metricDefinitionId} from '../../ecpUtil'; import { flattenedDimensionPath, @@ -103,7 +103,7 @@ const SingleDimensionFilter: React.FC<{ }, [props.state.data]); const {setComparisonDimensions} = useCompareEvaluationsState(); - const baselineCallId = props.state.baselineEvaluationCallId; + const baselineCallId = getBaselineCallId(props.state); const compareCallId = Object.keys(props.state.data.evaluationCalls).find( callId => callId !== baselineCallId )!; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ScorecardSection/ScorecardSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ScorecardSection/ScorecardSection.tsx index 4b319150fa8..303f640f47e 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ScorecardSection/ScorecardSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/ScorecardSection/ScorecardSection.tsx @@ -32,6 +32,7 @@ import { } from '../../ecpConstants'; import { EvaluationComparisonState, + getBaselineCallId, getOrderedCallIds, getOrderedModelRefs, } from '../../ecpState'; @@ -414,7 +415,7 @@ export const ScorecardSection: React.FC<{ {evalCallIds.map((evalCallId, mNdx) => { const baseline = resolveSummaryMetricResult( - props.state.baselineEvaluationCallId, + getBaselineCallId(props.state), groupName, metricKey, compositeSummaryMetrics, diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/MetricsSelector.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/MetricsSelector.tsx new file mode 100644 index 00000000000..3e6dfdd30a1 --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/MetricsSelector.tsx @@ -0,0 +1,173 @@ +import {Popover} from '@mui/material'; +import {Switch} from '@wandb/weave/components'; +import {Button} from '@wandb/weave/components/Button'; +import { + DraggableGrow, + DraggableHandle, +} from '@wandb/weave/components/DraggablePopups'; +import {TextField} from '@wandb/weave/components/Form/TextField'; +import {Tailwind} from '@wandb/weave/components/Tailwind'; +import {maybePluralize} from '@wandb/weave/core/util/string'; +import classNames from 'classnames'; +import React, {useRef, useState} from 'react'; + +export const MetricsSelector: React.FC<{ + setSelectedMetrics: (newModel: Record) => void; + selectedMetrics: Record | undefined; + allMetrics: string[]; +}> = ({setSelectedMetrics, selectedMetrics, allMetrics}) => { + const [search, setSearch] = useState(''); + + const ref = useRef(null); + const [anchorEl, setAnchorEl] = useState(null); + const onClick = (event: React.MouseEvent) => { + setAnchorEl(anchorEl ? null : ref.current); + setSearch(''); + }; + const open = Boolean(anchorEl); + const id = open ? 'simple-popper' : undefined; + + const filteredCols = search + ? allMetrics.filter(col => col.toLowerCase().includes(search.toLowerCase())) + : allMetrics; + + const shownMetrics = Object.values(selectedMetrics ?? {}).filter(Boolean); + + const numHidden = allMetrics.length - shownMetrics.length; + const buttonSuffix = search ? `(${filteredCols.length})` : 'all'; + + return ( + <> + + +
+ +
+
+ + + + ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyBarPlot.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyBarPlot.tsx index 9706ac09567..7942aea195e 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyBarPlot.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyBarPlot.tsx @@ -2,50 +2,54 @@ import * as Plotly from 'plotly.js'; import React, {useEffect, useMemo, useRef} from 'react'; import {PLOT_GRID_COLOR} from '../../ecpConstants'; -import {RadarPlotData} from './PlotlyRadarPlot'; export const PlotlyBarPlot: React.FC<{ height: number; - data: RadarPlotData; + yRange: [number, number]; + plotlyData: Plotly.Data; }> = props => { const divRef = useRef(null); - const plotlyData: Plotly.Data[] = useMemo(() => { - return Object.keys(props.data).map((key, i) => { - const {metrics, name, color} = props.data[key]; - return { - type: 'bar', - y: Object.values(metrics), - x: Object.keys(metrics), - name, - marker: {color}, - }; - }); - }, [props.data]); - const plotlyLayout: Partial = useMemo(() => { return { - height: props.height - 40, + height: props.height - 30, showlegend: false, margin: { - l: 0, + l: 20, r: 0, b: 20, - t: 0, - pad: 0, + t: 26, }, + bargap: 0.1, xaxis: { automargin: true, fixedrange: true, gridcolor: PLOT_GRID_COLOR, linecolor: PLOT_GRID_COLOR, + showticklabels: false, }, yaxis: { fixedrange: true, + range: props.yRange, gridcolor: PLOT_GRID_COLOR, linecolor: PLOT_GRID_COLOR, + showticklabels: true, + tickfont: { + size: 10, + }, + }, + title: { + multiline: true, + text: props.plotlyData.name ?? '', + font: {size: 12}, + xref: 'paper', + x: 0.5, + y: 1, + yanchor: 'top', + pad: {t: 2}, }, }; - }, [props.height]); + }, [props.height, props.plotlyData, props.yRange]); + const plotlyConfig = useMemo(() => { return { displayModeBar: false, @@ -57,11 +61,11 @@ export const PlotlyBarPlot: React.FC<{ useEffect(() => { Plotly.newPlot( divRef.current as any, - plotlyData, + [props.plotlyData], plotlyLayout, plotlyConfig ); - }, [plotlyConfig, plotlyData, plotlyLayout]); + }, [plotlyConfig, props.plotlyData, plotlyLayout]); return
; }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyRadarPlot.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyRadarPlot.tsx index d459d1354f1..47d0fa3f10c 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyRadarPlot.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/PlotlyRadarPlot.tsx @@ -31,13 +31,13 @@ export const PlotlyRadarPlot: React.FC<{ }, [props.data]); const plotlyLayout: Partial = useMemo(() => { return { - height: props.height, + height: props.height - 40, showlegend: false, margin: { - l: 60, - r: 0, + l: 20, + r: 20, b: 30, - t: 30, + t: 20, pad: 0, }, polar: { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx index 5bfaa8fcb04..02c456df850 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CompareEvaluationsPage/sections/SummaryPlotsSection/SummaryPlotsSection.tsx @@ -1,15 +1,6 @@ import {Box} from '@material-ui/core'; -import {Popover} from '@mui/material'; -import {Switch} from '@wandb/weave/components'; import {Button} from '@wandb/weave/components/Button'; -import { - DraggableGrow, - DraggableHandle, -} from '@wandb/weave/components/DraggablePopups'; -import {TextField} from '@wandb/weave/components/Form/TextField'; import {Tailwind} from '@wandb/weave/components/Tailwind'; -import {maybePluralize} from '@wandb/weave/core/util/string'; -import classNames from 'classnames'; import React, {useEffect, useMemo, useRef, useState} from 'react'; import {buildCompositeMetricsMap} from '../../compositeMetricsUtil'; @@ -27,6 +18,7 @@ import { resolveSummaryMetricValueForEvaluateCall, } from '../../ecpUtil'; import {HorizontalBox, VerticalBox} from '../../Layout'; +import {MetricsSelector} from './MetricsSelector'; import {PlotlyBarPlot} from './PlotlyBarPlot'; import {PlotlyRadarPlot, RadarPlotData} from './PlotlyRadarPlot'; @@ -36,15 +28,12 @@ import {PlotlyRadarPlot, RadarPlotData} from './PlotlyRadarPlot'; export const SummaryPlots: React.FC<{ state: EvaluationComparisonState; setSelectedMetrics: (newModel: Record) => void; -}> = props => { - const {radarData, allMetricNames} = useNormalizedPlotDataFromMetrics( - props.state - ); - const {selectedMetrics} = props.state; - const setSelectedMetrics = props.setSelectedMetrics; +}> = ({state, setSelectedMetrics}) => { + const {radarData, allMetricNames} = usePlotDataFromMetrics(state); + const {selectedMetrics} = state; + // Initialize selectedMetrics if null useEffect(() => { - // If selectedMetrics is null, we should show all metrics if (selectedMetrics == null) { setSelectedMetrics( Object.fromEntries(Array.from(allMetricNames).map(m => [m, true])) @@ -52,10 +41,184 @@ export const SummaryPlots: React.FC<{ } }, [selectedMetrics, setSelectedMetrics, allMetricNames]); - // filter down the plotlyRadarData to only include the selected metrics, after - // computation, to allow quick addition/removal of metrics - const filteredPlotlyRadarData = useMemo(() => { - const filteredData: RadarPlotData = {}; + const filteredData = useFilteredData(radarData, selectedMetrics); + const normalizedRadarData = normalizeDataForRadarPlot(filteredData); + const barPlotData = useBarPlotData(filteredData); + + const { + containerRef, + isInitialRender, + plotsPerPage, + currentPage, + setCurrentPage, + } = useContainerDimensions(); + + const {plotsToShow, totalPlots, startIndex, endIndex, totalPages} = + usePaginatedPlots( + normalizedRadarData, + barPlotData, + plotsPerPage, + currentPage + ); + + // Render placeholder during initial render + if (isInitialRender) { + return
; + } + + return ( + + +
+ {plotsToShow} +
+ setCurrentPage(prev => Math.max(prev - 1, 0))} + onNextPage={() => + setCurrentPage(prev => Math.min(prev + 1, totalPages - 1)) + } + /> +
+ ); +}; + +const SectionHeader: React.FC<{ + selectedMetrics: Record | undefined; + setSelectedMetrics: (newModel: Record) => void; + allMetrics: string[]; +}> = ({selectedMetrics, setSelectedMetrics, allMetrics}) => ( + + + Summary Metrics + + +
+
Configure displayed metrics
+ +
+
+
+); + +const RadarPlotBox: React.FC<{data: RadarPlotData}> = ({data}) => ( + + + +); + +const BarPlotBox: React.FC<{ + plot: {plotlyData: Plotly.Data; yRange: [number, number]}; +}> = ({plot}) => ( + + + +); + +const PaginationControls: React.FC<{ + currentPage: number; + totalPages: number; + startIndex: number; + endIndex: number; + totalPlots: number; + onPrevPage: () => void; + onNextPage: () => void; +}> = ({ + currentPage, + totalPages, + startIndex, + endIndex, + totalPlots, + onPrevPage, + onNextPage, +}) => ( + + + +
+
+
+
+
+); + +const useFilteredData = ( + radarData: RadarPlotData, + selectedMetrics: Record | undefined +) => + useMemo(() => { + const data: RadarPlotData = {}; for (const [callId, metricBin] of Object.entries(radarData)) { const metrics: {[metric: string]: number} = {}; for (const [metric, value] of Object.entries(metricBin.metrics)) { @@ -64,255 +227,206 @@ export const SummaryPlots: React.FC<{ } } if (Object.keys(metrics).length > 0) { - filteredData[callId] = { + data[callId] = { metrics, name: metricBin.name, color: metricBin.color, }; } } - return filteredData; + return data; }, [radarData, selectedMetrics]); - return ( - - - - Summary Metrics - - -
-
Configure displayed metrics
- -
-
-
- - - - - - - - -
- ); -}; +function getMetricValuesMap(radarData: RadarPlotData): { + [metric: string]: number[]; +} { + const metricValues: {[metric: string]: number[]} = {}; + Object.values(radarData).forEach(callData => { + Object.entries(callData.metrics).forEach(([metric, value]) => { + if (!metricValues[metric]) { + metricValues[metric] = []; + } + metricValues[metric].push(value); + }); + }); + return metricValues; +} -const MetricsSelector: React.FC<{ - setSelectedMetrics: (newModel: Record) => void; - selectedMetrics: Record | undefined; - allMetrics: string[]; -}> = ({setSelectedMetrics, selectedMetrics, allMetrics}) => { - const [search, setSearch] = useState(''); - - const ref = useRef(null); - const [anchorEl, setAnchorEl] = useState(null); - const onClick = (event: React.MouseEvent) => { - setAnchorEl(anchorEl ? null : ref.current); - setSearch(''); - }; - const open = Boolean(anchorEl); - const id = open ? 'simple-popper' : undefined; +function normalizeMetricValues(values: number[]): { + normalizedValues: number[]; + normalizer: number; +} { + const min = Math.min(...values); + const max = Math.max(...values); - const filteredCols = search - ? allMetrics.filter(col => col.toLowerCase().includes(search.toLowerCase())) - : allMetrics; + if (min === max) { + return { + normalizedValues: values.map(() => 0.5), + normalizer: 1, + }; + } - const shownMetrics = Object.values(selectedMetrics ?? {}).filter(Boolean); + // Handle negative values by shifting + const shiftedValues = min < 0 ? values.map(v => v - min) : values; + const maxValue = min < 0 ? max - min : max; - const numHidden = allMetrics.length - shownMetrics.length; - const buttonSuffix = search ? `(${filteredCols.length})` : 'all'; + const maxPower = Math.ceil(Math.log2(maxValue)); + const normalizer = Math.pow(2, maxPower); - return ( - <> - - -
- -
-
- - - + return { + normalizedValues: shiftedValues.map(v => v / normalizer), + normalizer, + }; +} + +function normalizeDataForRadarPlot( + radarDataOriginal: RadarPlotData +): RadarPlotData { + const radarData = Object.fromEntries( + Object.entries(radarDataOriginal).map(([callId, callData]) => [ + callId, + {...callData, metrics: {...callData.metrics}}, + ]) ); + + const metricValues = getMetricValuesMap(radarData); + + // Normalize each metric independently + Object.entries(metricValues).forEach(([metric, values]) => { + const {normalizedValues} = normalizeMetricValues(values); + Object.values(radarData).forEach((callData, index) => { + callData.metrics[metric] = normalizedValues[index]; + }); + }); + + return radarData; +} + +const useBarPlotData = (filteredData: RadarPlotData) => + useMemo(() => { + const metrics: { + [metric: string]: { + callIds: string[]; + values: number[]; + name: string; + colors: string[]; + }; + } = {}; + + // Reorganize data by metric instead of by call + for (const [callId, metricBin] of Object.entries(filteredData)) { + for (const [metric, value] of Object.entries(metricBin.metrics)) { + if (!metrics[metric]) { + metrics[metric] = {callIds: [], values: [], name: metric, colors: []}; + } + metrics[metric].callIds.push(callId); + metrics[metric].values.push(value); + metrics[metric].colors.push(metricBin.color); + } + } + + // Convert metrics object to Plotly data format + return Object.entries(metrics).map(([metric, metricBin]) => { + const maxY = Math.max(...metricBin.values) * 1.1; + const minY = Math.min(...metricBin.values, 0); + const plotlyData: Plotly.Data = { + type: 'bar', + y: metricBin.values, + x: metricBin.callIds, + text: metricBin.values.map(value => + Number.isInteger(value) ? value.toString() : value.toFixed(3) + ), + textposition: 'outside', + textfont: {size: 14, color: 'black'}, + name: metric, + marker: {color: metricBin.colors}, + }; + return {plotlyData, yRange: [minY, maxY] as [number, number]}; + }); + }, [filteredData]); + +const useContainerDimensions = () => { + const containerRef = useRef(null); + const [containerWidth, setContainerWidth] = useState(0); + const [isInitialRender, setIsInitialRender] = useState(true); + const [currentPage, setCurrentPage] = useState(0); + + useEffect(() => { + const updateWidth = () => { + if (containerRef.current) { + setContainerWidth(containerRef.current.offsetWidth); + } + }; + + updateWidth(); + setIsInitialRender(false); + + window.addEventListener('resize', updateWidth); + return () => window.removeEventListener('resize', updateWidth); + }, []); + + const plotsPerPage = useMemo(() => { + return Math.max(1, Math.floor(containerWidth / PLOT_HEIGHT)); + }, [containerWidth]); + + return { + containerRef, + isInitialRender, + plotsPerPage, + currentPage, + setCurrentPage, + }; }; -const normalizeValues = (values: Array): number[] => { - // find the max value - // find the power of 2 that is greater than the max value - // divide all values by that power of 2 - const maxVal = Math.max(...(values.filter(v => v !== undefined) as number[])); - const maxPower = Math.ceil(Math.log2(maxVal)); - return values.map(val => (val ? val / 2 ** maxPower : 0)); +const usePaginatedPlots = ( + filteredData: RadarPlotData, + barPlotData: Array<{plotlyData: Plotly.Data; yRange: [number, number]}>, + plotsPerPage: number, + currentPage: number +) => { + const radarPlotWidth = 2; + const totalBarPlots = barPlotData.length; + const totalPlotWidth = radarPlotWidth + totalBarPlots; + const totalPages = Math.ceil(totalPlotWidth / plotsPerPage); + + const plotsToShow = useMemo(() => { + // First page always shows radar plot + if (currentPage === 0) { + const availableSpace = plotsPerPage - radarPlotWidth; + return [ + , + ...barPlotData + .slice(0, availableSpace) + .map((plot, index) => ( + + )), + ]; + } else { + // Subsequent pages show only bar plots + const startIdx = + (currentPage - 1) * plotsPerPage + (plotsPerPage - radarPlotWidth); + const endIdx = startIdx + plotsPerPage; + return barPlotData + .slice(startIdx, endIdx) + .map((plot, index) => ( + + )); + } + }, [currentPage, plotsPerPage, filteredData, barPlotData]); + + // Calculate pagination details + const totalPlots = barPlotData.length + 1; // +1 for the radar plot + const startIndex = + currentPage === 0 ? 1 : Math.min(plotsPerPage + 1, totalPlots); + const endIndex = + currentPage === 0 + ? Math.min(plotsToShow.length, totalPlots) + : Math.min(startIndex + plotsToShow.length - 1, totalPlots); + + return {plotsToShow, totalPlots, startIndex, endIndex, totalPages}; }; -const useNormalizedPlotDataFromMetrics = ( +const usePlotDataFromMetrics = ( state: EvaluationComparisonState ): {radarData: RadarPlotData; allMetricNames: Set} => { const compositeMetrics = useMemo(() => { @@ -323,7 +437,7 @@ const useNormalizedPlotDataFromMetrics = ( }, [state]); return useMemo(() => { - const normalizedMetrics = Object.values(compositeMetrics) + const metrics = Object.values(compositeMetrics) .map(scoreGroup => Object.values(scoreGroup.metrics)) .flat() .map(metric => { @@ -344,11 +458,8 @@ const useNormalizedPlotDataFromMetrics = ( return val; } }); - const normalizedValues = normalizeValues(values); const evalScores: {[evalCallId: string]: number | undefined} = - Object.fromEntries( - callIds.map((key, i) => [key, normalizedValues[i]]) - ); + Object.fromEntries(callIds.map((key, i) => [key, values[i]])); const metricLabel = flattenedDimensionPath( Object.values(metric.scorerRefs)[0].metric @@ -367,7 +478,7 @@ const useNormalizedPlotDataFromMetrics = ( name: evalCall.name, color: evalCall.color, metrics: Object.fromEntries( - normalizedMetrics.map(metric => { + metrics.map(metric => { return [ metric.metricLabel, metric.evalScores[evalCall.callId] ?? 0, @@ -378,7 +489,7 @@ const useNormalizedPlotDataFromMetrics = ( ]; }) ); - const allMetricNames = new Set(normalizedMetrics.map(m => m.metricLabel)); + const allMetricNames = new Set(metrics.map(m => m.metricLabel)); return {radarData, allMetricNames}; }, [callIds, compositeMetrics, state.data.evaluationCalls]); }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx index 6108ed5407e..085587a64d8 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ObjectVersionPage.tsx @@ -211,47 +211,50 @@ const ObjectVersionPageInner: React.FC<{ } headerContent={ - - {objectName}{' '} - {objectVersions.loading ? ( - - ) : ( - <> - [ - +
+
+

Name

+
+ +
+ {objectName} + {objectVersions.loading ? ( + + ) : ( + + ({objectVersionCount} version + {objectVersionCount !== 1 ? 's' : ''}) + + )} + - ] - - )} - - ), - Version: <>{objectVersionIndex}, - ...(refExtra - ? { - Subpath: refExtra, - } - : {}), - // 'Type Version': ( - // - // ), - }} - /> +
+
+
+
+
+

Version

+

{objectVersionIndex}

+
+ {refExtra && ( +
+

Subpath

+

{refExtra}

+
+ )} +
+ } // menuItems={[ // { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx index 1a6e4afc577..36f4e44afc5 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/OpVersionPage.tsx @@ -1,5 +1,6 @@ import React, {useMemo} from 'react'; +import {Icon} from '../../../../Icon'; import {LoadingDots} from '../../../../LoadingDots'; import {Tailwind} from '../../../../Tailwind'; import {NotFoundPanel} from '../NotFoundPanel'; @@ -13,7 +14,6 @@ import { import {CenteredAnimatedLoader} from './common/Loader'; import { ScrollableTabContent, - SimpleKeyValueTable, SimplePageLayoutWithHeader, } from './common/SimplePageLayout'; import {TabUseOp} from './TabUseOp'; @@ -75,49 +75,71 @@ const OpVersionPageInner: React.FC<{ - {opId}{' '} - {opVersions.loading ? ( - - ) : ( - <> - [ - - ] - - )} - - ), - Version: <>{versionIndex}, - Calls: - !callsStats.loading || opVersionCallCount > 0 ? ( - +
+
+

Name

+
+ + variant="secondary"> +
+ {opId} + {opVersions.loading ? ( + + ) : ( + + ({opVersionCount} version + {opVersionCount !== 1 ? 's' : ''}) + + )} + +
+
+
+
+
+

Version

+

{versionIndex}

+
+
+

Calls:

+ {!callsStats.loading || opVersionCallCount > 0 ? ( +
+ + +
) : ( - <> - ), - }} - /> +

-

+ )} +
+
+ } tabs={[ { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx index b5f675633fe..b6b6e7c420d 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChat.tsx @@ -115,11 +115,15 @@ export const PlaygroundChat = ({ }}> -
+
{state.traceCall && ( addMessage(idx, newMessage), editChoice: (choiceIndex, newChoice) => editChoice(idx, choiceIndex, newChoice), - retry: (messageIndex: number, isChoice?: boolean) => - handleRetry(idx, messageIndex, isChoice), + retry: (messageIndex: number, choiceIndex?: number) => + handleRetry(idx, messageIndex, choiceIndex), sendMessage: ( role: PlaygroundMessageRole, content: string, @@ -166,6 +170,12 @@ export const PlaygroundChat = ({ ) => { handleSend(role, idx, content, toolCallId); }, + setSelectedChoiceIndex: (choiceIndex: number) => + setPlaygroundStateField( + idx, + 'selectedChoiceIndex', + choiceIndex + ), }}> diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChatTopBar.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChatTopBar.tsx index 8b49483a4ed..e00f1a02b2c 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChatTopBar.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/PlaygroundChatTopBar.tsx @@ -110,7 +110,6 @@ export const PlaygroundChatTopBar: React.FC = ({ width: '100%', display: 'flex', justifyContent: 'space-between', - paddingBottom: '8px', }}> { try { setIsLoading(true); const updatedStates = playgroundStates.map((state, index) => { if (index === callIndex) { - if (isChoice) { - return appendChoicesToMessages(state); + if (choiceIndex !== undefined) { + return appendChoiceToMessages(state, choiceIndex); } const updatedState = JSON.parse(JSON.stringify(state)); if (updatedState.traceCall?.inputs?.messages) { @@ -203,17 +203,25 @@ const handleUpdateCallWithResponse = ( }; }; -const appendChoicesToMessages = (state: PlaygroundState): PlaygroundState => { +const appendChoiceToMessages = ( + state: PlaygroundState, + choiceIndex?: number +): PlaygroundState => { const updatedState = JSON.parse(JSON.stringify(state)); if ( updatedState.traceCall?.inputs?.messages && updatedState.traceCall.output?.choices ) { - updatedState.traceCall.output.choices.forEach((choice: any) => { - if (choice.message) { - updatedState.traceCall.inputs.messages.push(choice.message); - } - }); + if (choiceIndex !== undefined) { + updatedState.traceCall.inputs.messages.push( + updatedState.traceCall.output.choices[choiceIndex].message + ); + } else { + updatedState.traceCall.inputs.messages.push( + updatedState.traceCall.output.choices[updatedState.selectedChoiceIndex] + .message + ); + } updatedState.traceCall.output.choices = undefined; } return updatedState; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx index e84e2f75d4b..804670a1dc3 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundChat/useChatFunctions.tsx @@ -114,16 +114,9 @@ export const useChatFunctions = ( newTraceCall?.output && Array.isArray((newTraceCall.output as TraceCallOutput).choices) ) { - // Delete the old choice - (newTraceCall.output as TraceCallOutput).choices!.splice( - choiceIndex, - 1 - ); - - // Add the new choice as a message - newTraceCall.inputs = newTraceCall.inputs ?? {}; - newTraceCall.inputs.messages = newTraceCall.inputs.messages ?? []; - newTraceCall.inputs.messages.push(newChoice); + // Replace the choice + (newTraceCall.output as TraceCallOutput).choices![choiceIndex].message = + newChoice; } return newTraceCall; }); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundContext.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundContext.tsx index a8176292d1a..31369602560 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundContext.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundContext.tsx @@ -10,14 +10,16 @@ export type PlaygroundContextType = { deleteMessage: (messageIndex: number, responseIndexes?: number[]) => void; editChoice: (choiceIndex: number, newChoice: Message) => void; - deleteChoice: (choiceIndex: number) => void; + deleteChoice: (messageIndex: number, choiceIndex: number) => void; - retry: (messageIndex: number, isChoice?: boolean) => void; + retry: (messageIndex: number, choiceIndex?: number) => void; sendMessage: ( role: PlaygroundMessageRole, content: string, toolCallId?: string ) => void; + + setSelectedChoiceIndex: (choiceIndex: number) => void; }; const DEFAULT_CONTEXT: PlaygroundContextType = { @@ -31,6 +33,7 @@ const DEFAULT_CONTEXT: PlaygroundContextType = { retry: () => {}, sendMessage: () => {}, + setSelectedChoiceIndex: () => {}, }; // Create context that can be undefined diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundPage.tsx index c6232631e4e..76d1c6d9e31 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundPage.tsx @@ -7,7 +7,11 @@ import {SimplePageLayoutWithHeader} from '../common/SimplePageLayout'; import {useWFHooks} from '../wfReactInterface/context'; import {PlaygroundChat} from './PlaygroundChat/PlaygroundChat'; import {PlaygroundSettings} from './PlaygroundSettings/PlaygroundSettings'; -import {DEFAULT_SYSTEM_MESSAGE, usePlaygroundState} from './usePlaygroundState'; +import { + DEFAULT_SYSTEM_MESSAGE, + parseTraceCall, + usePlaygroundState, +} from './usePlaygroundState'; export type PlaygroundPageProps = { entity: string; @@ -89,7 +93,10 @@ export const PlaygroundPageInner = (props: PlaygroundPageProps) => { for (const [idx, state] of newStates.entries()) { for (const c of calls || []) { if (state.traceCall.id === c.callId) { - newStates[idx] = {...state, traceCall: c.traceCall || {}}; + newStates[idx] = { + ...state, + traceCall: parseTraceCall(c.traceCall || {}), + }; break; } } diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/PlaygroundSettings.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/PlaygroundSettings.tsx index 5a2e8fae32c..e0971d35bfb 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/PlaygroundSettings.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/PlaygroundSettings/PlaygroundSettings.tsx @@ -59,6 +59,12 @@ export const PlaygroundSettings: React.FC = ({ gap: '4px', mt: 2, }}> + + setPlaygroundStateField(idx, 'responseFormat', value) + } + /> = ({ } /> - - setPlaygroundStateField(idx, 'responseFormat', value) + + setPlaygroundStateField(idx, 'stopSequences', value) } /> + {/* TODO: N times to run is not supported for all models */} + {/* TODO: rerun if this is not supported in the backend */} - setPlaygroundStateField(idx, 'temperature', value) + setPlaygroundStateField(idx, 'nTimes', value) } - label="Temperature" - value={playgroundState.temperature} + label="Completion iterations" + value={playgroundState.nTimes} /> = ({ value={playgroundState.maxTokens} /> - - setPlaygroundStateField(idx, 'stopSequences', value) + + setPlaygroundStateField(idx, 'temperature', value) } + label="Temperature" + value={playgroundState.temperature} /> = ({ label="Presence penalty" value={playgroundState.presencePenalty} /> + & { autoGrow?: boolean; maxHeight?: string | number; + startHeight?: string | number; reset?: boolean; }; export const StyledTextArea = forwardRef( - ({className, autoGrow, maxHeight, reset, ...props}, ref) => { + ({className, autoGrow, maxHeight, startHeight, reset, ...props}, ref) => { const textareaRef = React.useRef(null); React.useEffect(() => { @@ -26,11 +27,22 @@ export const StyledTextArea = forwardRef( return; } - // Disable resize when autoGrow is true - textareaElement.style.resize = 'none'; + // Only disable resize when autoGrow is true + textareaElement.style.resize = autoGrow ? 'none' : 'vertical'; + + // Set initial height if provided + if (startHeight && textareaElement.value === '') { + textareaElement.style.height = + typeof startHeight === 'number' ? `${startHeight}px` : startHeight; + return; + } if (reset || textareaElement.value === '') { - textareaElement.style.height = 'auto'; + textareaElement.style.height = startHeight + ? typeof startHeight === 'number' + ? `${startHeight}px` + : startHeight + : 'auto'; return; } @@ -63,7 +75,7 @@ export const StyledTextArea = forwardRef( return () => textareaRefElement.removeEventListener('input', adjustHeight); - }, [autoGrow, maxHeight, reset]); + }, [autoGrow, maxHeight, reset, startHeight]); return ( @@ -86,6 +98,7 @@ export const StyledTextArea = forwardRef( 'focus:outline-none', 'relative bottom-0 top-0 items-center rounded-sm', 'outline outline-1 outline-moon-250', + !autoGrow && 'resize-y', props.disabled ? 'opacity-50' : 'hover:outline hover:outline-2 hover:outline-teal-500/40 focus:outline-2', @@ -94,6 +107,14 @@ export const StyledTextArea = forwardRef( 'placeholder-moon-500 dark:placeholder-moon-600', className )} + style={{ + height: startHeight + ? typeof startHeight === 'number' + ? `${startHeight}px` + : startHeight + : undefined, + ...props.style, + }} {...props} /> diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/types.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/types.ts index fa73e87bf45..dc0e0a370e0 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/types.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/types.ts @@ -20,9 +20,10 @@ export type PlaygroundState = { topP: number; frequencyPenalty: number; presencePenalty: number; - // nTimes: number; + nTimes: number; maxTokensLimit: number; model: LLMMaxTokensKey; + selectedChoiceIndex: number; }; export type PlaygroundStateKey = keyof PlaygroundState; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts index 8c556edaef2..cbcc7c52fb8 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/PlaygroundPage/usePlaygroundState.ts @@ -1,5 +1,11 @@ +import {cloneDeep} from 'lodash'; import {SetStateAction, useCallback, useState} from 'react'; +import { + anthropicContentBlocksToChoices, + hasStringProp, + isAnthropicCompletionFormat, +} from '../ChatView/hooks'; import {LLM_MAX_TOKENS_KEYS, LLMMaxTokensKey} from './llmMaxTokens'; import { OptionalTraceCallSchema, @@ -34,9 +40,10 @@ const DEFAULT_PLAYGROUND_STATE = { topP: 1, frequencyPenalty: 0, presencePenalty: 0, - // nTimes: 1, + nTimes: 1, maxTokensLimit: 16384, model: DEFAULT_MODEL, + selectedChoiceIndex: 0, }; export const usePlaygroundState = () => { @@ -76,7 +83,7 @@ export const usePlaygroundState = () => { setPlaygroundStates(prevState => { const newState = {...prevState[0]}; - newState.traceCall = traceCall; + newState.traceCall = parseTraceCall(traceCall); if (!inputs) { return [newState]; @@ -90,9 +97,9 @@ export const usePlaygroundState = () => { } } } - // if (inputs.n) { - // newState.nTimes = parseInt(inputs.n, 10); - // } + if (inputs.n) { + newState.nTimes = parseInt(inputs.n, 10); + } if (inputs.temperature) { newState.temperature = parseFloat(inputs.temperature); } @@ -147,10 +154,42 @@ export const getInputFromPlaygroundState = (state: PlaygroundState) => { top_p: state.topP, frequency_penalty: state.frequencyPenalty, presence_penalty: state.presencePenalty, - // n: state.nTimes, + n: state.nTimes, response_format: { type: state.responseFormat, }, tools: tools.length > 0 ? tools : undefined, }; }; + +// This is a helper function to parse the trace call output for anthropic +// so that the playground can display the choices +export const parseTraceCall = (traceCall: OptionalTraceCallSchema) => { + const parsedTraceCall = cloneDeep(traceCall); + + // Handles anthropic outputs + // Anthropic has content and stop_reason as top-level fields + if (isAnthropicCompletionFormat(parsedTraceCall.output)) { + const {content, stop_reason, ...outputs} = parsedTraceCall.output as any; + parsedTraceCall.output = { + ...outputs, + choices: anthropicContentBlocksToChoices(content, stop_reason), + }; + } + // Handles anthropic inputs + // Anthropic has system message as a top-level request field + if (hasStringProp(parsedTraceCall.inputs, 'system')) { + const {messages, system, ...inputs} = parsedTraceCall.inputs as any; + parsedTraceCall.inputs = { + ...inputs, + messages: [ + { + role: 'system', + content: system, + }, + ...messages, + ], + }; + } + return parsedTraceCall; +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx index 9acbdfe6c2f..a478437facb 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/AnnotationScorerForm.tsx @@ -1,5 +1,5 @@ import {Box} from '@material-ui/core'; -import React, {FC, useCallback, useState} from 'react'; +import React, {FC, useCallback, useEffect, useState} from 'react'; import {z} from 'zod'; import {createBaseObjectInstance} from '../wfReactInterface/baseObjectClassQuery'; @@ -28,7 +28,7 @@ const AnnotationScorerFormSchema = z.object({ }), z.object({ type: z.literal('String'), - 'Max length': z.number().optional(), + 'Maximum length': z.number().optional(), }), z.object({ type: z.literal('Select'), @@ -45,6 +45,9 @@ export const AnnotationScorerForm: FC< ScorerFormProps> > = ({data, onDataChange}) => { const [config, setConfig] = useState(data ?? DEFAULT_STATE); + useEffect(() => { + setConfig(data ?? DEFAULT_STATE); + }, [data]); const [isValid, setIsValid] = useState(false); const handleConfigChange = useCallback( @@ -113,7 +116,7 @@ function convertTypeExtrasToJsonSchema( const typeSchema = obj.Type; const typeExtras: Record = {}; if (typeSchema.type === 'String') { - typeExtras.maxLength = typeSchema['Max length']; + typeExtras.maxLength = typeSchema['Maximum length']; } else if (typeSchema.type === 'Integer' || typeSchema.type === 'Number') { typeExtras.minimum = typeSchema.Minimum; typeExtras.maximum = typeSchema.Maximum; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/FormComponents.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/FormComponents.tsx index 2716bfbfa81..250c896cfea 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/FormComponents.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ScorersPage/FormComponents.tsx @@ -3,7 +3,7 @@ import {Select} from '@wandb/weave/components/Form/Select'; import {TextField} from '@wandb/weave/components/Form/TextField'; import React from 'react'; -export const GAP_BETWEEN_ITEMS_PX = 10; +export const GAP_BETWEEN_ITEMS_PX = 16; export const GAP_BETWEEN_LABEL_AND_FIELD_PX = 10; type AutocompleteWithLabelType - ))} - + ); }; @@ -758,6 +779,7 @@ const LiteralField: React.FC<{ setConfig, }) => { const literalValue = unwrappedSchema.value; + const isOptional = fieldSchema instanceof z.ZodOptional; useEffect(() => { if (value !== literalValue) { @@ -765,7 +787,14 @@ const LiteralField: React.FC<{ } }, [value, literalValue, targetPath, config, setConfig]); - return ; + return ( + + ); }; const BooleanField: React.FC<{ diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/Links.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/Links.tsx index 8a11252bc33..4060735cc67 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/Links.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/Links.tsx @@ -9,6 +9,7 @@ import {Link as LinkComp, useHistory} from 'react-router-dom'; import styled, {css} from 'styled-components'; import {TargetBlank} from '../../../../../../common/util/links'; +import {maybePluralizeWord} from '../../../../../../core/util/string'; import { FEEDBACK_EXPAND_PARAM, PATH_PARAM, @@ -414,7 +415,8 @@ export const CallsLink: React.FC<{ $variant={props.variant} to={router.callsUIUrl(props.entity, props.project, props.filter)}> {props.callCount} - {props.countIsLimited ? '+' : ''} calls + {props.countIsLimited ? '+' : ''}{' '} + {maybePluralizeWord(props.callCount, 'call')} ); }; @@ -427,6 +429,7 @@ export const ObjectVersionsLink: React.FC<{ filter?: WFHighLevelObjectVersionFilter; neverPeek?: boolean; variant?: LinkVariant; + children?: React.ReactNode; }> = props => { const {peekingRouter, baseRouter} = useWeaveflowRouteContext(); const router = props.neverPeek ? baseRouter : peekingRouter; @@ -438,9 +441,13 @@ export const ObjectVersionsLink: React.FC<{ props.project, props.filter )}> - {props.versionCount} - {props.countIsLimited ? '+' : ''} version - {props.versionCount !== 1 ? 's' : ''} + {props.children ?? ( + <> + {props.versionCount} + {props.countIsLimited ? '+' : ''} version + {props.versionCount !== 1 ? 's' : ''} + + )} ); }; @@ -453,6 +460,7 @@ export const OpVersionsLink: React.FC<{ filter?: WFHighLevelOpVersionFilter; neverPeek?: boolean; variant?: LinkVariant; + children?: React.ReactNode; }> = props => { const {peekingRouter, baseRouter} = useWeaveflowRouteContext(); const router = props.neverPeek ? baseRouter : peekingRouter; @@ -460,9 +468,13 @@ export const OpVersionsLink: React.FC<{ - {props.versionCount} - {props.countIsLimited ? '+' : ''} version - {props.versionCount !== 1 ? 's' : ''} + {props.children ?? ( + <> + {props.versionCount} + {props.countIsLimited ? '+' : ''} version + {props.versionCount !== 1 ? 's' : ''} + + )} ); }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx index e8965d3f4a2..9d2c6ab9718 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/common/SimplePageLayout.tsx @@ -313,18 +313,20 @@ const SimpleTabView: FC<{ height: '100%', overflow: 'hidden', }}> - - {props.headerContent} - + {props.headerContent && ( + + {props.headerContent} + + )} {(!props.hideTabsIfSingle || props.tabs.length > 1) && ( - traceCall.op_name.includes(`.${name}:`) - ) && - modelRefs.includes(traceCall.inputs.self)) || + modelRefs.includes(traceCall.inputs.self) || modelRefs.includes(traceCall.op_name); const isProbablyScoreCall = scorerRefs.has(traceCall.op_name); diff --git a/weave-js/src/components/Tag/Pill.tsx b/weave-js/src/components/Tag/Pill.tsx index 6f734f9cbfc..ed958a0c543 100644 --- a/weave-js/src/components/Tag/Pill.tsx +++ b/weave-js/src/components/Tag/Pill.tsx @@ -59,3 +59,42 @@ export const IconOnlyPill: FC = ({ ); }; + +export type ExpandingPillProps = { + className?: string; + color?: TagColorName; + icon: IconName; + label: string; +}; +export const ExpandingPill = ({ + className, + color, + icon, + label, +}: ExpandingPillProps) => { + const classes = useTagClasses({color, isInteractive: true}); + return ( + +
+ + + {label} + +
+
+ ); +}; diff --git a/weave-js/src/components/ToggleButtonGroup.tsx b/weave-js/src/components/ToggleButtonGroup.tsx index 65b2c538975..eca93e95657 100644 --- a/weave-js/src/components/ToggleButtonGroup.tsx +++ b/weave-js/src/components/ToggleButtonGroup.tsx @@ -9,6 +9,7 @@ import {Tailwind} from './Tailwind'; export type ToggleOption = { value: string; icon?: IconName; + isDisabled?: boolean; }; export type ToggleButtonGroupProps = { @@ -37,7 +38,10 @@ export const ToggleButtonGroup = React.forwardRef< } const handleValueChange = (newValue: string) => { - if (newValue !== value) { + if ( + newValue !== value && + options.find(option => option.value === newValue)?.isDisabled !== true + ) { onValueChange(newValue); } }; @@ -49,34 +53,39 @@ export const ToggleButtonGroup = React.forwardRef< onValueChange={handleValueChange} className="flex gap-px" ref={ref}> - {options.map(({value: optionValue, icon}) => ( - - - - ))} + {options.map( + ({value: optionValue, icon, isDisabled: optionIsDisabled}) => ( + + + + ) + )} ); diff --git a/weave-js/src/components/TooltipDeprecated.tsx b/weave-js/src/components/TooltipDeprecated.tsx new file mode 100644 index 00000000000..1b600a6589c --- /dev/null +++ b/weave-js/src/components/TooltipDeprecated.tsx @@ -0,0 +1,36 @@ +/** + * @deprecated Don't use this in any new code, we're trying to get rid of semantic-ui. + */ +import {Popup} from 'semantic-ui-react'; +import styled from 'styled-components'; + +import { + hexToRGB, + MOON_650, + MOON_800, + OBLIVION, + WHITE, +} from '../common/css/globals.styles'; + +export const TooltipDeprecated = styled(Popup).attrs({ + basic: true, // This removes the pointing arrow. + mouseEnterDelay: 500, + popperModifiers: { + preventOverflow: { + // Prevent popper from erroneously constraining the popup. + // Without this, tooltips in single row table cells get positioned under the cursor, + // causing them to immediately close. + boundariesElement: 'viewport', + }, + }, +})` + && { + color: ${WHITE}; + background: ${MOON_800}; + border-color: ${MOON_650}; + box-shadow: 0px 4px 6px ${hexToRGB(OBLIVION, 0.2)}; + font-size: 14px; + line-height: 140%; + max-width: 300px; + } +`; diff --git a/weave-js/src/errors.ts b/weave-js/src/errors.ts index daa3e0c97a0..f2582baade8 100644 --- a/weave-js/src/errors.ts +++ b/weave-js/src/errors.ts @@ -37,7 +37,8 @@ type DDErrorPayload = { export const weaveErrorToDDPayload = ( error: Error, - weave?: WeaveApp + weave?: WeaveApp, + uuid?: string ): DDErrorPayload => { try { return { @@ -49,6 +50,7 @@ export const weaveErrorToDDPayload = ( windowLocationURL: trimString(window.location.href), weaveContext: weave?.client.debugMeta(), isServerError: error instanceof UseNodeValueServerExecutionError, + ...(uuid != null && {uuid}), }; } catch (e) { // If we fail to serialize the error, just return an empty object. diff --git a/weave-js/tailwind.config.cjs b/weave-js/tailwind.config.cjs index b234cc3cf1a..d8a580c928c 100644 --- a/weave-js/tailwind.config.cjs +++ b/weave-js/tailwind.config.cjs @@ -13,8 +13,13 @@ module.exports = { */ boxShadow: { none: 'none', - md: '0px 12px 24px 0px #15181F29', - lg: '0px 24px 48px 0px #15181F29', + flat: '0px 4px 8px 0px #0D0F120a', // oblivion 4% + medium: '0px 12px 24px 0px #0D0F1229', // oblivion 16% + deep: '0px 24px 48px 0px #0D0F123d', // oblivion 24% + + // deprecated shadow configs + md: '0px 12px 24px 0px #15181F29', // use shadow-medium instead + lg: '0px 24px 48px 0px #15181F29', // use shadow-deep instead }, spacing: { 0: '0rem', @@ -189,17 +194,17 @@ module.exports = { }, extend: { animation: { - 'wave': 'wave 3s linear infinite' + wave: 'wave 3s linear infinite', }, keyframes: { - "wave": { - "0%, 30%, 100%": { - transform: "initial" + wave: { + '0%, 30%, 100%': { + transform: 'initial', }, - "15%": { - transform: "translateY(-10px)" - } - } + '15%': { + transform: 'translateY(-10px)', + }, + }, }, opacity: { 35: '.35', @@ -221,6 +226,6 @@ module.exports = { in their parent hierarchy */ important: '.tw-style', experimental: { - optimizeUniversalDefaults: true + optimizeUniversalDefaults: true, }, }; diff --git a/weave-js/yarn.lock b/weave-js/yarn.lock index c7f9379e32a..6a5ec14e872 100644 --- a/weave-js/yarn.lock +++ b/weave-js/yarn.lock @@ -4776,11 +4776,6 @@ resolved "https://registry.yarnpkg.com/@types/unist/-/unist-2.0.7.tgz#5b06ad6894b236a1d2bd6b2f07850ca5c59cf4d6" integrity sha512-cputDpIbFgLUaGQn6Vqg3/YsJwxUwHLO13v3i5ouxT4lat0khip9AEWxtERujXV9wxIB1EyF97BSJFt6vpdI8g== -"@types/uuid@^9.0.1": - version "9.0.2" - resolved "https://registry.yarnpkg.com/@types/uuid/-/uuid-9.0.2.tgz#ede1d1b1e451548d44919dc226253e32a6952c4b" - integrity sha512-kNnC1GFBLuhImSnV7w4njQkUiJi0ZXUycu1rUaouPqiKlXkh77JKgdRnTAp1x5eBwcIwbtI+3otwzuIDEuDoxQ== - "@types/wavesurfer.js@^2.0.0": version "2.0.2" resolved "https://registry.yarnpkg.com/@types/wavesurfer.js/-/wavesurfer.js-2.0.2.tgz#b98a4d57ca24ee2028ae6dd5c2208b568bb73842" @@ -15032,6 +15027,11 @@ util-deprecate@^1.0.1, util-deprecate@^1.0.2, util-deprecate@~1.0.1: resolved "https://registry.yarnpkg.com/util-deprecate/-/util-deprecate-1.0.2.tgz#450d4dc9fa70de732762fbd2d4a28981419a0ccf" integrity sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw== +uuid@^11.0.3: + version "11.0.3" + resolved "https://registry.yarnpkg.com/uuid/-/uuid-11.0.3.tgz#248451cac9d1a4a4128033e765d137e2b2c49a3d" + integrity sha512-d0z310fCWv5dJwnX1Y/MncBAqGMKEzlBb1AOf7z9K8ALnd0utBX/msg/fA0+sbyN1ihbMsLhrBlnl1ak7Wa0rg== + uuid@^2.0.2: version "2.0.3" resolved "https://registry.yarnpkg.com/uuid/-/uuid-2.0.3.tgz#67e2e863797215530dff318e5bf9dcebfd47b21a" @@ -15042,11 +15042,6 @@ uuid@^3.0.0, uuid@^3.4.0: resolved "https://registry.yarnpkg.com/uuid/-/uuid-3.4.0.tgz#b23e4358afa8a202fe7a100af1f5f883f02007ee" integrity sha512-HjSDRw6gZE5JMggctHBcjVak08+KEVhSIiDzFnT9S9aegmp85S/bReBVTb4QTFaRNptJ9kuYaNhnbNEOkbKb/A== -uuid@^9.0.0: - version "9.0.0" - resolved "https://registry.yarnpkg.com/uuid/-/uuid-9.0.0.tgz#592f550650024a38ceb0c562f2f6aa435761efb5" - integrity sha512-MXcSTerfPa4uqyzStbRoTgt5XIe3x5+42+q1sDuy3R5MDk66URdLMOZe5aPX/SQd+kuYAh0FdP/pO28IkQyTeg== - uvu@^0.5.0: version "0.5.6" resolved "https://registry.yarnpkg.com/uvu/-/uvu-0.5.6.tgz#2754ca20bcb0bb59b64e9985e84d2e81058502df" diff --git a/weave/flow/eval.py b/weave/flow/eval.py index 5f4a961f904..bf78dc06d85 100644 --- a/weave/flow/eval.py +++ b/weave/flow/eval.py @@ -5,9 +5,10 @@ import time import traceback from collections.abc import Coroutine +from datetime import datetime from typing import Any, Callable, Literal, Optional, Union, cast -from pydantic import PrivateAttr +from pydantic import PrivateAttr, model_validator from rich import print from rich.console import Console @@ -16,6 +17,7 @@ from weave.flow.dataset import Dataset from weave.flow.model import Model, get_infer_method from weave.flow.obj import Object +from weave.flow.util import make_memorable_name from weave.scorers import ( Scorer, _has_oldstyle_scorers, @@ -28,7 +30,7 @@ from weave.trace.env import get_weave_parallelism from weave.trace.errors import OpCallError from weave.trace.isinstance import weave_isinstance -from weave.trace.op import Op, as_op, is_op +from weave.trace.op import CallDisplayNameFunc, Op, as_op, is_op from weave.trace.vals import WeaveObject from weave.trace.weave_client import Call, get_ref @@ -41,6 +43,12 @@ ) +def default_evaluation_display_name(call: Call) -> str: + date = datetime.now().strftime("%Y-%m-%d") + unique_name = make_memorable_name() + return f"eval-{date}-{unique_name}" + + def async_call(func: Union[Callable, Op], *args: Any, **kwargs: Any) -> Coroutine: is_async = False if is_op(func): @@ -116,9 +124,21 @@ def function_to_evaluate(question: str): preprocess_model_input: Optional[Callable] = None trials: int = 1 + # Custom evaluation name for display in the UI. This is the same API as passing a + # custom `call_display_name` to `weave.op` (see that for more details). + evaluation_name: Optional[Union[str, CallDisplayNameFunc]] = None + # internal attr to track whether to use the new `output` or old `model_output` key for outputs _output_key: Literal["output", "model_output"] = PrivateAttr("output") + @model_validator(mode="after") + def _update_display_name(self) -> "Evaluation": + if self.evaluation_name: + # Treat user-specified `evaluation_name` as the name for `Evaluation.evaluate` + eval_op = as_op(self.evaluate) + eval_op.call_display_name = self.evaluation_name + return self + def model_post_init(self, __context: Any) -> None: scorers: list[Union[Callable, Scorer, Op]] = [] for scorer in self.scorers or []: @@ -486,7 +506,7 @@ async def eval_example(example: dict) -> dict: eval_rows.append(eval_row) return EvaluationResults(rows=weave.Table(eval_rows)) - @weave.op() + @weave.op(call_display_name=default_evaluation_display_name) async def evaluate(self, model: Union[Callable, Model]) -> dict: # The need for this pattern is quite unfortunate and highlights a gap in our # data model. As a user, I just want to pass a list of data `eval_rows` to diff --git a/weave/flow/util.py b/weave/flow/util.py index 4d89e777d88..ba35d5ebe4a 100644 --- a/weave/flow/util.py +++ b/weave/flow/util.py @@ -1,6 +1,7 @@ import asyncio import logging import multiprocessing +import random from collections.abc import AsyncIterator, Awaitable, Iterable from typing import Any, Callable, TypeVar @@ -81,3 +82,89 @@ def warn_once(logger: logging.Logger, message: str) -> None: if message not in _shown_warnings: logger.warning(message) _shown_warnings.add(message) + + +def make_memorable_name() -> str: + adjectives = [ + "brave", + "bright", + "calm", + "charming", + "clever", + "daring", + "dazzling", + "eager", + "elegant", + "eloquent", + "fierce", + "friendly", + "gentle", + "graceful", + "happy", + "honest", + "imaginative", + "innocent", + "joyful", + "jubilant", + "keen", + "kind", + "lively", + "loyal", + "merry", + "nice", + "noble", + "optimistic", + "proud", + "quiet", + "rich", + "sweet", + "tender", + "unique", + "wise", + "zealous", + ] + + nouns = [ + "bear", + "bird", + "breeze", + "cedar", + "cloud", + "daisy", + "dawn", + "dolphin", + "dusk", + "eagle", + "fish", + "flower", + "forest", + "hill", + "horizon", + "island", + "lake", + "lion", + "maple", + "meadow", + "moon", + "mountain", + "oak", + "ocean", + "pine", + "plateau", + "rain", + "river", + "rose", + "star", + "stream", + "sun", + "tiger", + "tree", + "valley", + "whale", + "wind", + "wolf", + ] + + adj = random.choice(adjectives) + noun = random.choice(nouns) + return f"{adj}-{noun}" diff --git a/weave/integrations/openai/openai_sdk.py b/weave/integrations/openai/openai_sdk.py index 7814700d4d3..a1e3a9b5831 100644 --- a/weave/integrations/openai/openai_sdk.py +++ b/weave/integrations/openai/openai_sdk.py @@ -1,15 +1,20 @@ +from __future__ import annotations + import importlib from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable import weave +from weave.trace.autopatch import IntegrationSettings, OpSettings from weave.trace.op import Op, ProcessedInputs from weave.trace.op_extensions.accumulator import add_accumulator -from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.patcher import MultiPatcher, NoOpPatcher, SymbolPatcher if TYPE_CHECKING: from openai.types.chat import ChatCompletionChunk +_openai_patcher: MultiPatcher | None = None + def maybe_unwrap_api_response(value: Any) -> Any: """If the caller requests a raw response, we unwrap the APIResponse object. @@ -43,9 +48,7 @@ def maybe_unwrap_api_response(value: Any) -> Any: return value -def openai_on_finish_post_processor( - value: Optional["ChatCompletionChunk"], -) -> Optional[dict]: +def openai_on_finish_post_processor(value: ChatCompletionChunk | None) -> dict | None: from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat.chat_completion_chunk import ( ChoiceDeltaFunctionCall, @@ -60,8 +63,8 @@ def openai_on_finish_post_processor( value = maybe_unwrap_api_response(value) def _get_function_call( - function_call: Optional[ChoiceDeltaFunctionCall], - ) -> Optional[FunctionCall]: + function_call: ChoiceDeltaFunctionCall | None, + ) -> FunctionCall | None: if function_call is None: return function_call if isinstance(function_call, ChoiceDeltaFunctionCall): @@ -73,8 +76,8 @@ def _get_function_call( return None def _get_tool_calls( - tool_calls: Optional[list[ChoiceDeltaToolCall]], - ) -> Optional[list[ChatCompletionMessageToolCall]]: + tool_calls: list[ChoiceDeltaToolCall] | None, + ) -> list[ChatCompletionMessageToolCall] | None: if tool_calls is None: return tool_calls @@ -128,10 +131,10 @@ def _get_tool_calls( def openai_accumulator( - acc: Optional["ChatCompletionChunk"], - value: "ChatCompletionChunk", + acc: ChatCompletionChunk | None, + value: ChatCompletionChunk, skip_last: bool = False, -) -> "ChatCompletionChunk": +) -> ChatCompletionChunk: from openai.types.chat import ChatCompletionChunk from openai.types.chat.chat_completion_chunk import ( ChoiceDeltaFunctionCall, @@ -285,7 +288,7 @@ def should_use_accumulator(inputs: dict) -> bool: def openai_on_input_handler( func: Op, args: tuple, kwargs: dict -) -> Optional[ProcessedInputs]: +) -> ProcessedInputs | None: if len(args) == 2 and isinstance(args[1], weave.EasyPrompt): original_args = args original_kwargs = kwargs @@ -305,20 +308,16 @@ def openai_on_input_handler( return None -def create_wrapper_sync( - name: str, -) -> Callable[[Callable], Callable]: +def create_wrapper_sync(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: "We need to do this so we can check if `stream` is used" def _add_stream_options(fn: Callable) -> Callable: @wraps(fn) def _wrapper(*args: Any, **kwargs: Any) -> Any: - if bool(kwargs.get("stream")) and kwargs.get("stream_options") is None: + if kwargs.get("stream") and kwargs.get("stream_options") is None: kwargs["stream_options"] = {"include_usage": True} - return fn( - *args, **kwargs - ) # This is where the final execution of fn is happening. + return fn(*args, **kwargs) return _wrapper @@ -327,8 +326,8 @@ def _openai_stream_options_is_set(inputs: dict) -> bool: return True return False - op = weave.op()(_add_stream_options(fn)) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(_add_stream_options(fn), **op_kwargs) op._set_on_input_handler(openai_on_input_handler) return add_accumulator( op, # type: ignore @@ -345,16 +344,14 @@ def _openai_stream_options_is_set(inputs: dict) -> bool: # Surprisingly, the async `client.chat.completions.create` does not pass # `inspect.iscoroutinefunction`, so we can't dispatch on it and must write # it manually here... -def create_wrapper_async( - name: str, -) -> Callable[[Callable], Callable]: +def create_wrapper_async(settings: OpSettings) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: "We need to do this so we can check if `stream` is used" def _add_stream_options(fn: Callable) -> Callable: @wraps(fn) async def _wrapper(*args: Any, **kwargs: Any) -> Any: - if bool(kwargs.get("stream")) and kwargs.get("stream_options") is None: + if kwargs.get("stream") and kwargs.get("stream_options") is None: kwargs["stream_options"] = {"include_usage": True} return await fn(*args, **kwargs) @@ -365,8 +362,8 @@ def _openai_stream_options_is_set(inputs: dict) -> bool: return True return False - op = weave.op()(_add_stream_options(fn)) - op.name = name # type: ignore + op_kwargs = settings.model_dump() + op = weave.op(_add_stream_options(fn), **op_kwargs) op._set_on_input_handler(openai_on_input_handler) return add_accumulator( op, # type: ignore @@ -380,28 +377,61 @@ def _openai_stream_options_is_set(inputs: dict) -> bool: return wrapper -symbol_patchers = [ - # Patch the Completions.create method - SymbolPatcher( - lambda: importlib.import_module("openai.resources.chat.completions"), - "Completions.create", - create_wrapper_sync(name="openai.chat.completions.create"), - ), - SymbolPatcher( - lambda: importlib.import_module("openai.resources.chat.completions"), - "AsyncCompletions.create", - create_wrapper_async(name="openai.chat.completions.create"), - ), - SymbolPatcher( - lambda: importlib.import_module("openai.resources.beta.chat.completions"), - "Completions.parse", - create_wrapper_sync(name="openai.beta.chat.completions.parse"), - ), - SymbolPatcher( - lambda: importlib.import_module("openai.resources.beta.chat.completions"), - "AsyncCompletions.parse", - create_wrapper_async(name="openai.beta.chat.completions.parse"), - ), -] - -openai_patcher = MultiPatcher(symbol_patchers) # type: ignore +def get_openai_patcher( + settings: IntegrationSettings | None = None, +) -> MultiPatcher | NoOpPatcher: + if settings is None: + settings = IntegrationSettings() + + if not settings.enabled: + return NoOpPatcher() + + global _openai_patcher + if _openai_patcher is not None: + return _openai_patcher + + base = settings.op_settings + + completions_create_settings = base.model_copy( + update={"name": base.name or "openai.chat.completions.create"} + ) + async_completions_create_settings = base.model_copy( + update={"name": base.name or "openai.chat.completions.create"} + ) + completions_parse_settings = base.model_copy( + update={"name": base.name or "openai.beta.chat.completions.parse"} + ) + async_completions_parse_settings = base.model_copy( + update={"name": base.name or "openai.beta.chat.completions.parse"} + ) + + _openai_patcher = MultiPatcher( + [ + SymbolPatcher( + lambda: importlib.import_module("openai.resources.chat.completions"), + "Completions.create", + create_wrapper_sync(settings=completions_create_settings), + ), + SymbolPatcher( + lambda: importlib.import_module("openai.resources.chat.completions"), + "AsyncCompletions.create", + create_wrapper_async(settings=async_completions_create_settings), + ), + SymbolPatcher( + lambda: importlib.import_module( + "openai.resources.beta.chat.completions" + ), + "Completions.parse", + create_wrapper_sync(settings=completions_parse_settings), + ), + SymbolPatcher( + lambda: importlib.import_module( + "openai.resources.beta.chat.completions" + ), + "AsyncCompletions.parse", + create_wrapper_async(settings=async_completions_parse_settings), + ), + ] + ) + + return _openai_patcher diff --git a/weave/scorers/base_scorer.py b/weave/scorers/base_scorer.py index 5a19adcd04f..4ac27f1a76b 100644 --- a/weave/scorers/base_scorer.py +++ b/weave/scorers/base_scorer.py @@ -1,4 +1,5 @@ import inspect +import textwrap from collections.abc import Sequence from numbers import Number from typing import Any, Callable, Optional, Union @@ -45,7 +46,13 @@ def _validate_scorer_signature(scorer: Union[Callable, Op, Scorer]) -> bool: params = inspect.signature(scorer).parameters if "output" in params and "model_output" in params: raise ValueError( - "Both 'output' and 'model_output' cannot be in the scorer signature; prefer just using `output`." + textwrap.dedent( + """ + The scorer signature cannot include both `output` and `model_output` at the same time. + + To resolve, rename one of the arguments to avoid conflict. Prefer using `output` as the model's output. + """ + ) ) return True diff --git a/weave/scorers/llm_utils.py b/weave/scorers/llm_utils.py index 8e2672b8ee7..eef6f018b0f 100644 --- a/weave/scorers/llm_utils.py +++ b/weave/scorers/llm_utils.py @@ -2,10 +2,6 @@ from typing import TYPE_CHECKING, Any, Union -from weave.trace.autopatch import autopatch - -autopatch() # ensure both weave patching and instructor patching are applied - OPENAI_DEFAULT_MODEL = "gpt-4o" OPENAI_DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small" OPENAI_DEFAULT_MODERATION_MODEL = "text-moderation-latest" @@ -23,10 +19,17 @@ from google.generativeai import GenerativeModel from instructor.patch import InstructorChatCompletionCreate from mistralai import Mistral - from openai import AsyncOpenAI, OpenAI + from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI _LLM_CLIENTS = Union[ - OpenAI, AsyncOpenAI, Anthropic, AsyncAnthropic, Mistral, GenerativeModel + OpenAI, + AsyncOpenAI, + AzureOpenAI, + AsyncAzureOpenAI, + Anthropic, + AsyncAnthropic, + Mistral, + GenerativeModel, ] else: _LLM_CLIENTS = object @@ -34,6 +37,8 @@ _LLM_CLIENTS_NAMES = ( "OpenAI", "AsyncOpenAI", + "AzureOpenAI", + "AsyncAzureOpenAI", "Anthropic", "AsyncAnthropic", "Mistral", diff --git a/weave/trace/api.py b/weave/trace/api.py index ee8131b0875..294308cbb67 100644 --- a/weave/trace/api.py +++ b/weave/trace/api.py @@ -13,6 +13,7 @@ # There is probably a better place for this, but including here for now to get the fix in. from weave import type_handlers # noqa: F401 from weave.trace import urls, util, weave_client, weave_init +from weave.trace.autopatch import AutopatchSettings from weave.trace.constants import TRACE_OBJECT_EMOJI from weave.trace.context import call_context from weave.trace.context import weave_client_context as weave_client_context @@ -32,6 +33,7 @@ def init( project_name: str, *, settings: UserSettings | dict[str, Any] | None = None, + autopatch_settings: AutopatchSettings | None = None, ) -> weave_client.WeaveClient: """Initialize weave tracking, logging to a wandb project. @@ -52,7 +54,12 @@ def init( if should_disable_weave(): return weave_init.init_weave_disabled().client - return weave_init.init_weave(project_name).client + initialized_client = weave_init.init_weave( + project_name, + autopatch_settings=autopatch_settings, + ) + + return initialized_client.client @contextlib.contextmanager diff --git a/weave/trace/autopatch.py b/weave/trace/autopatch.py index 3a5dca14556..0619194a224 100644 --- a/weave/trace/autopatch.py +++ b/weave/trace/autopatch.py @@ -4,8 +4,54 @@ check if libraries are installed and imported and patch in the case that they are. """ +from typing import Any, Callable, Optional, Union -def autopatch() -> None: +from pydantic import BaseModel, Field, validate_call + +from weave.trace.weave_client import Call + + +class OpSettings(BaseModel): + """Op settings for a specific integration. + These currently subset the `op` decorator args to provide a consistent interface + when working with auto-patched functions. See the `op` decorator for more details.""" + + name: Optional[str] = None + call_display_name: Optional[Union[str, Callable[[Call], str]]] = None + postprocess_inputs: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None + postprocess_output: Optional[Callable[[Any], Any]] = None + + +class IntegrationSettings(BaseModel): + """Configuration for a specific integration.""" + + enabled: bool = True + op_settings: OpSettings = Field(default_factory=OpSettings) + + +class AutopatchSettings(BaseModel): + """Settings for auto-patching integrations.""" + + # These will be uncommented as we add support for more integrations. Note that + + # anthropic: IntegrationSettings = Field(default_factory=IntegrationSettings) + # cerebras: IntegrationSettings = Field(default_factory=IntegrationSettings) + # cohere: IntegrationSettings = Field(default_factory=IntegrationSettings) + # dspy: IntegrationSettings = Field(default_factory=IntegrationSettings) + # google_ai_studio: IntegrationSettings = Field(default_factory=IntegrationSettings) + # groq: IntegrationSettings = Field(default_factory=IntegrationSettings) + # instructor: IntegrationSettings = Field(default_factory=IntegrationSettings) + # langchain: IntegrationSettings = Field(default_factory=IntegrationSettings) + # litellm: IntegrationSettings = Field(default_factory=IntegrationSettings) + # llamaindex: IntegrationSettings = Field(default_factory=IntegrationSettings) + # mistral: IntegrationSettings = Field(default_factory=IntegrationSettings) + # notdiamond: IntegrationSettings = Field(default_factory=IntegrationSettings) + openai: IntegrationSettings = Field(default_factory=IntegrationSettings) + # vertexai: IntegrationSettings = Field(default_factory=IntegrationSettings) + + +@validate_call +def autopatch(settings: Optional[AutopatchSettings] = None) -> None: from weave.integrations.anthropic.anthropic_sdk import anthropic_patcher from weave.integrations.cerebras.cerebras_sdk import cerebras_patcher from weave.integrations.cohere.cohere_sdk import cohere_patcher @@ -20,10 +66,13 @@ def autopatch() -> None: from weave.integrations.llamaindex.llamaindex import llamaindex_patcher from weave.integrations.mistral import mistral_patcher from weave.integrations.notdiamond.tracing import notdiamond_patcher - from weave.integrations.openai.openai_sdk import openai_patcher + from weave.integrations.openai.openai_sdk import get_openai_patcher from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher - openai_patcher.attempt_patch() + if settings is None: + settings = AutopatchSettings() + + get_openai_patcher(settings.openai).attempt_patch() mistral_patcher.attempt_patch() litellm_patcher.attempt_patch() llamaindex_patcher.attempt_patch() @@ -54,10 +103,10 @@ def reset_autopatch() -> None: from weave.integrations.llamaindex.llamaindex import llamaindex_patcher from weave.integrations.mistral import mistral_patcher from weave.integrations.notdiamond.tracing import notdiamond_patcher - from weave.integrations.openai.openai_sdk import openai_patcher + from weave.integrations.openai.openai_sdk import get_openai_patcher from weave.integrations.vertexai.vertexai_sdk import vertexai_patcher - openai_patcher.undo_patch() + get_openai_patcher().undo_patch() mistral_patcher.undo_patch() litellm_patcher.undo_patch() llamaindex_patcher.undo_patch() diff --git a/weave/trace/context/call_context.py b/weave/trace/context/call_context.py index 402e1843ade..3a03bd167c3 100644 --- a/weave/trace/context/call_context.py +++ b/weave/trace/context/call_context.py @@ -20,6 +20,8 @@ class NoCurrentCallError(Exception): ... logger = logging.getLogger(__name__) +_tracing_enabled = contextvars.ContextVar("tracing_enabled", default=True) + def push_call(call: Call) -> None: new_stack = copy.copy(_call_stack.get()) @@ -136,3 +138,22 @@ def set_call_stack(stack: list[Call]) -> Iterator[list[Call]]: call_attributes: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar( "call_attributes", default={} ) + + +def get_tracing_enabled() -> bool: + return _tracing_enabled.get() + + +@contextlib.contextmanager +def set_tracing_enabled(enabled: bool) -> Iterator[None]: + token = _tracing_enabled.set(enabled) + try: + yield + finally: + _tracing_enabled.reset(token) + + +@contextlib.contextmanager +def tracing_disabled() -> Iterator[None]: + with set_tracing_enabled(False): + yield diff --git a/weave/trace/op.py b/weave/trace/op.py index 5e33e8bdbf8..a89c7400d8b 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -4,6 +4,7 @@ import inspect import logging +import random import sys import traceback from collections.abc import Coroutine, Mapping @@ -26,7 +27,11 @@ from weave.trace.constants import TRACE_CALL_EMOJI from weave.trace.context import call_context from weave.trace.context import weave_client_context as weave_client_context -from weave.trace.context.call_context import call_attributes +from weave.trace.context.call_context import ( + call_attributes, + get_tracing_enabled, + tracing_disabled, +) from weave.trace.context.tests_context import get_raise_on_captured_errors from weave.trace.errors import OpCallError from weave.trace.refs import ObjectRef @@ -107,16 +112,15 @@ def _apply_fn_defaults_to_inputs( ) -> dict[str, Any]: inputs = {**inputs} sig = inspect.signature(fn) - for param_name, param in sig.parameters.items(): - if param_name not in inputs: - if param.default != inspect.Parameter.empty and not _value_is_sentinel( - param - ): - inputs[param_name] = param.default - if param.kind == inspect.Parameter.VAR_POSITIONAL: - inputs[param_name] = () - elif param.kind == inspect.Parameter.VAR_KEYWORD: - inputs[param_name] = {} + for name, param in sig.parameters.items(): + if name in inputs: + continue + if param.default != inspect.Parameter.empty and not _value_is_sentinel(param): + inputs[name] = param.default + if param.kind == inspect.Parameter.VAR_POSITIONAL: + inputs[name] = () + if param.kind == inspect.Parameter.VAR_KEYWORD: + inputs[name] = {} return inputs @@ -175,6 +179,8 @@ class Op(Protocol): # it disables child ops as well. _tracing_enabled: bool + tracing_sample_rate: float + def _set_on_input_handler(func: Op, on_input: OnInputHandlerType) -> None: if func._on_input_handler is not None: @@ -216,6 +222,7 @@ def _default_on_input_handler(func: Op, args: tuple, kwargs: dict) -> ProcessedI inputs = sig.bind(*args, **kwargs).arguments except TypeError as e: raise OpCallError(f"Error calling {func.name}: {e}") + inputs_with_defaults = _apply_fn_defaults_to_inputs(func, inputs) return ProcessedInputs( original_args=args, @@ -407,37 +414,54 @@ def _do_call( if not pargs: pargs = _default_on_input_handler(op, args, kwargs) + # Handle all of the possible cases where we would skip tracing. if settings.should_disable_weave(): res = func(*pargs.args, **pargs.kwargs) - elif weave_client_context.get_weave_client() is None: + return res, call + if weave_client_context.get_weave_client() is None: + res = func(*pargs.args, **pargs.kwargs) + return res, call + if not op._tracing_enabled: res = func(*pargs.args, **pargs.kwargs) - elif not op._tracing_enabled: + return res, call + if not get_tracing_enabled(): + res = func(*pargs.args, **pargs.kwargs) + return res, call + + current_call = call_context.get_current_call() + if current_call is None: + # Root call: decide whether to trace based on sample rate + if random.random() > op.tracing_sample_rate: + # Disable tracing for this call and all descendants + with tracing_disabled(): + res = func(*pargs.args, **pargs.kwargs) + return res, call + + # Proceed with tracing. Note that we don't check the sample rate here. + # Only root calls get sampling applied. + # If the parent was traced (sampled in), the child will be too. + try: + call = _create_call(op, *args, __weave=__weave, **kwargs) + except OpCallError as e: + raise e + except Exception as e: + if get_raise_on_captured_errors(): + raise + log_once( + logger.error, + CALL_CREATE_MSG.format(traceback.format_exc()), + ) res = func(*pargs.args, **pargs.kwargs) else: - try: - # This try/except allows us to fail gracefully and - # still let the user code continue to execute - call = _create_call(op, *args, __weave=__weave, **kwargs) - except OpCallError as e: - raise e - except Exception as e: - if get_raise_on_captured_errors(): - raise - log_once( - logger.error, - CALL_CREATE_MSG.format(traceback.format_exc()), - ) - res = func(*pargs.args, **pargs.kwargs) - else: - execute_result = _execute_op( - op, call, *pargs.args, __should_raise=__should_raise, **pargs.kwargs + execute_result = _execute_op( + op, call, *pargs.args, __should_raise=__should_raise, **pargs.kwargs + ) + if inspect.iscoroutine(execute_result): + raise TypeError( + "Internal error: Expected `_execute_call` to return a sync result" ) - if inspect.iscoroutine(execute_result): - raise TypeError( - "Internal error: Expected `_execute_call` to return a sync result" - ) - execute_result = cast(tuple[Any, "Call"], execute_result) - res, call = execute_result + execute_result = cast(tuple[Any, "Call"], execute_result) + res, call = execute_result return res, call @@ -450,39 +474,52 @@ async def _do_call_async( ) -> tuple[Any, Call]: func = op.resolve_fn call = _placeholder_call() + + # Handle all of the possible cases where we would skip tracing. if settings.should_disable_weave(): res = await func(*args, **kwargs) - elif weave_client_context.get_weave_client() is None: + return res, call + if weave_client_context.get_weave_client() is None: res = await func(*args, **kwargs) - elif not op._tracing_enabled: + return res, call + if not op._tracing_enabled: + res = await func(*args, **kwargs) + return res, call + if not get_tracing_enabled(): + res = await func(*args, **kwargs) + return res, call + + current_call = call_context.get_current_call() + if current_call is None: + # Root call: decide whether to trace based on sample rate + if random.random() > op.tracing_sample_rate: + # Disable tracing for this call and all descendants + with tracing_disabled(): + res = await func(*args, **kwargs) + return res, call + + # Proceed with tracing + try: + call = _create_call(op, *args, __weave=__weave, **kwargs) + except OpCallError as e: + raise e + except Exception as e: + if get_raise_on_captured_errors(): + raise + log_once( + logger.error, + ASYNC_CALL_CREATE_MSG.format(traceback.format_exc()), + ) res = await func(*args, **kwargs) else: - try: - # This try/except allows us to fail gracefully and - # still let the user code continue to execute - call = _create_call(op, *args, __weave=__weave, **kwargs) - except OpCallError as e: - raise e - except Exception as e: - if get_raise_on_captured_errors(): - raise - log_once( - logger.error, - ASYNC_CALL_CREATE_MSG.format(traceback.format_exc()), - ) - res = await func(*args, **kwargs) - else: - execute_result = _execute_op( - op, call, *args, __should_raise=__should_raise, **kwargs - ) - if not inspect.iscoroutine(execute_result): - raise TypeError( - "Internal error: Expected `_execute_call` to return a coroutine" - ) - execute_result = cast( - Coroutine[Any, Any, tuple[Any, "Call"]], execute_result + execute_result = _execute_op( + op, call, *args, __should_raise=__should_raise, **kwargs + ) + if not inspect.iscoroutine(execute_result): + raise TypeError( + "Internal error: Expected `_execute_call` to return a coroutine" ) - res, call = await execute_result + res, call = await execute_result return res, call @@ -540,6 +577,7 @@ def op( call_display_name: str | CallDisplayNameFunc | None = None, postprocess_inputs: PostprocessInputsFunc | None = None, postprocess_output: PostprocessOutputFunc | None = None, + tracing_sample_rate: float = 1.0, ) -> Callable[[Callable], Op] | Op: """ A decorator to weave op-ify a function or method. Works for both sync and async. @@ -565,6 +603,7 @@ def op( postprocess_output (Optional[Callable[..., Any]]): A function to process the output after it's been returned from the function but before it's logged. This does not affect the actual output of the function, only the displayed output. + tracing_sample_rate (float): The sampling rate for tracing this function. Defaults to 1.0 (always trace). Returns: Union[Callable[[Any], Op], Op]: If called without arguments, returns a decorator. @@ -591,6 +630,10 @@ async def extract(): await extract() # calls the function and tracks the call in the Weave UI ``` """ + if not isinstance(tracing_sample_rate, (int, float)): + raise TypeError("tracing_sample_rate must be a float") + if not 0 <= tracing_sample_rate <= 1: + raise ValueError("tracing_sample_rate must be between 0 and 1") def op_deco(func: Callable) -> Op: # Check function type @@ -647,6 +690,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: wrapper._on_finish_handler = None # type: ignore wrapper._tracing_enabled = True # type: ignore + wrapper.tracing_sample_rate = tracing_sample_rate # type: ignore wrapper.get_captured_code = partial(get_captured_code, wrapper) # type: ignore @@ -736,7 +780,9 @@ def as_op(fn: Callable) -> Op: if not is_op(fn): raise ValueError("fn must be a weave.op() decorated function") - return cast(Op, fn) + # The unbinding is necessary for methods because `MethodType` is applied after the + # func is decorated into an Op. + return maybe_unbind_method(cast(Op, fn)) __docspec__ = [call, calls] diff --git a/weave/trace/patcher.py b/weave/trace/patcher.py index 1567c4e2bb9..c1d0d653ffa 100644 --- a/weave/trace/patcher.py +++ b/weave/trace/patcher.py @@ -17,6 +17,14 @@ def undo_patch(self) -> bool: raise NotImplementedError() +class NoOpPatcher(Patcher): + def attempt_patch(self) -> bool: + return True + + def undo_patch(self) -> bool: + return True + + class MultiPatcher(Patcher): def __init__(self, patchers: Sequence[Patcher]) -> None: self.patchers = patchers diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index 86ba7b8653e..1d5d54b9b23 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -10,7 +10,7 @@ from collections.abc import Iterator, Sequence from concurrent.futures import Future from functools import lru_cache -from typing import Any, Callable, cast +from typing import Any, Callable, Generic, Protocol, TypeVar, cast, overload import pydantic from requests import HTTPError @@ -90,6 +90,128 @@ logger = logging.getLogger(__name__) +T = TypeVar("T") +R = TypeVar("R", covariant=True) + + +class FetchFunc(Protocol[T]): + def __call__(self, offset: int, limit: int) -> list[T]: ... + + +TransformFunc = Callable[[T], R] + + +class PaginatedIterator(Generic[T, R]): + """An iterator that fetches pages of items from a server and optionally transforms them + into a more user-friendly type.""" + + def __init__( + self, + fetch_func: FetchFunc[T], + page_size: int = 1000, + transform_func: TransformFunc[T, R] | None = None, + ) -> None: + self.fetch_func = fetch_func + self.page_size = page_size + self.transform_func = transform_func + + if page_size <= 0: + raise ValueError("page_size must be greater than 0") + + @lru_cache + def _fetch_page(self, index: int) -> list[T]: + return self.fetch_func(index * self.page_size, self.page_size) + + @overload + def _get_one(self: PaginatedIterator[T, T], index: int) -> T: ... + @overload + def _get_one(self: PaginatedIterator[T, R], index: int) -> R: ... + def _get_one(self, index: int) -> T | R: + if index < 0: + raise IndexError("Negative indexing not supported") + + page_index = index // self.page_size + page_offset = index % self.page_size + + page = self._fetch_page(page_index) + if page_offset >= len(page): + raise IndexError(f"Index {index} out of range") + + res = page[page_offset] + if transform := self.transform_func: + return transform(res) + return res + + @overload + def _get_slice(self: PaginatedIterator[T, T], key: slice) -> Iterator[T]: ... + @overload + def _get_slice(self: PaginatedIterator[T, R], key: slice) -> Iterator[R]: ... + def _get_slice(self, key: slice) -> Iterator[T] | Iterator[R]: + if (start := key.start or 0) < 0: + raise ValueError("Negative start not supported") + if (stop := key.stop) is not None and stop < 0: + raise ValueError("Negative stop not supported") + if (step := key.step or 1) < 0: + raise ValueError("Negative step not supported") + + i = start + while stop is None or i < stop: + try: + yield self._get_one(i) + except IndexError: + break + i += step + + @overload + def __getitem__(self: PaginatedIterator[T, T], key: int) -> T: ... + @overload + def __getitem__(self: PaginatedIterator[T, R], key: int) -> R: ... + @overload + def __getitem__(self: PaginatedIterator[T, T], key: slice) -> list[T]: ... + @overload + def __getitem__(self: PaginatedIterator[T, R], key: slice) -> list[R]: ... + def __getitem__(self, key: slice | int) -> T | R | list[T] | list[R]: + if isinstance(key, slice): + return list(self._get_slice(key)) + return self._get_one(key) + + @overload + def __iter__(self: PaginatedIterator[T, T]) -> Iterator[T]: ... + @overload + def __iter__(self: PaginatedIterator[T, R]) -> Iterator[R]: ... + def __iter__(self) -> Iterator[T] | Iterator[R]: + return self._get_slice(slice(0, None, 1)) + + +# TODO: should be Call, not WeaveObject +CallsIter = PaginatedIterator[CallSchema, WeaveObject] + + +def _make_calls_iterator( + server: TraceServerInterface, + project_id: str, + filter: CallsFilter, + include_costs: bool = False, +) -> CallsIter: + def fetch_func(offset: int, limit: int) -> list[CallSchema]: + response = server.calls_query( + CallsQueryReq( + project_id=project_id, + filter=filter, + offset=offset, + limit=limit, + include_costs=include_costs, + ) + ) + return response.calls + + # TODO: Should be Call, not WeaveObject + def transform_func(call: CallSchema) -> WeaveObject: + entity, project = project_id.split("/") + return make_client_call(entity, project, call, server) + + return PaginatedIterator(fetch_func, transform_func=transform_func) + class OpNameError(ValueError): """Raised when an op name is invalid.""" @@ -239,7 +361,9 @@ def func_name(self) -> str: @property def feedback(self) -> RefFeedbackQuery: if not self.id: - raise ValueError("Can't get feedback for call without ID") + raise ValueError( + "Can't get feedback for call without ID, was `weave.init` called?" + ) if self._feedback is None: try: @@ -253,7 +377,9 @@ def feedback(self) -> RefFeedbackQuery: @property def ui_url(self) -> str: if not self.id: - raise ValueError("Can't get URL for call without ID") + raise ValueError( + "Can't get URL for call without ID, was `weave.init` called?" + ) try: entity, project = self.project_id.split("/") @@ -265,7 +391,9 @@ def ui_url(self) -> str: def ref(self) -> CallRef: entity, project = self.project_id.split("/") if not self.id: - raise ValueError("Can't get ref for call without ID") + raise ValueError( + "Can't get ref for call without ID, was `weave.init` called?" + ) return CallRef(entity, project, self.id) @@ -273,10 +401,12 @@ def ref(self) -> CallRef: def children(self) -> CallsIter: client = weave_client_context.require_weave_client() if not self.id: - raise ValueError("Can't get children of call without ID") + raise ValueError( + "Can't get children of call without ID, was `weave.init` called?" + ) client = weave_client_context.require_weave_client() - return CallsIter( + return _make_calls_iterator( client.server, self.project_id, CallsFilter(parent_ids=[self.id]), @@ -354,80 +484,6 @@ def _apply_scorer(self, scorer_op: Op) -> None: ) -class CallsIter: - server: TraceServerInterface - filter: CallsFilter - include_costs: bool - - def __init__( - self, - server: TraceServerInterface, - project_id: str, - filter: CallsFilter, - include_costs: bool = False, - ) -> None: - self.server = server - self.project_id = project_id - self.filter = filter - self._page_size = 1000 - self.include_costs = include_costs - - # seems like this caching should be on the server, but it's here for now... - @lru_cache - def _fetch_page(self, index: int) -> list[CallSchema]: - # caching in here means that any other CallsIter objects would also - # benefit from the cache - response = self.server.calls_query( - CallsQueryReq( - project_id=self.project_id, - filter=self.filter, - offset=index * self._page_size, - limit=self._page_size, - include_costs=self.include_costs, - ) - ) - return response.calls - - def _get_one(self, index: int) -> WeaveObject: - if index < 0: - raise IndexError("Negative indexing not supported") - - page_index = index // self._page_size - page_offset = index % self._page_size - - calls = self._fetch_page(page_index) - if page_offset >= len(calls): - raise IndexError(f"Index {index} out of range") - - call = calls[page_offset] - entity, project = self.project_id.split("/") - return make_client_call(entity, project, call, self.server) - - def _get_slice(self, key: slice) -> Iterator[WeaveObject]: - if (start := key.start or 0) < 0: - raise ValueError("Negative start not supported") - if (stop := key.stop) is not None and stop < 0: - raise ValueError("Negative stop not supported") - if (step := key.step or 1) < 0: - raise ValueError("Negative step not supported") - - i = start - while stop is None or i < stop: - try: - yield self._get_one(i) - except IndexError: - break - i += step - - def __getitem__(self, key: slice | int) -> WeaveObject | list[WeaveObject]: - if isinstance(key, slice): - return list(self._get_slice(key)) - return self._get_one(key) - - def __iter__(self) -> Iterator[WeaveObject]: - return self._get_slice(slice(0, None, 1)) - - def make_client_call( entity: str, project: str, server_call: CallSchema, server: TraceServerInterface ) -> WeaveObject: @@ -634,8 +690,11 @@ def get_calls( if filter is None: filter = CallsFilter() - return CallsIter( - self.server, self._project_id(), filter, include_costs or False + return _make_calls_iterator( + self.server, + self._project_id(), + filter, + include_costs, ) @deprecated(new_name="get_calls") diff --git a/weave/trace/weave_init.py b/weave/trace/weave_init.py index 563dcbdaed4..f51d42d5018 100644 --- a/weave/trace/weave_init.py +++ b/weave/trace/weave_init.py @@ -63,7 +63,9 @@ def get_entity_project_from_project_name(project_name: str) -> tuple[str, str]: def init_weave( - project_name: str, ensure_project_exists: bool = True + project_name: str, + ensure_project_exists: bool = True, + autopatch_settings: autopatch.AutopatchSettings | None = None, ) -> InitializedClient: global _current_inited_client if _current_inited_client is not None: @@ -120,7 +122,7 @@ def init_weave( # autopatching is only supported for the wandb client, because OpenAI calls are not # logged in local mode currently. When that's fixed, this autopatch call can be # moved to InitializedClient.__init__ - autopatch.autopatch() + autopatch.autopatch(autopatch_settings) username = get_username() try: diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index 943ef31b0b8..d40d7bcc2a3 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -95,6 +95,7 @@ validate_cost_purge_req, ) from weave.trace_server.trace_server_common import ( + DynamicBatchProcessor, LRUCache, digest_is_version_like, empty_str_to_none, @@ -120,6 +121,7 @@ FILE_CHUNK_SIZE = 100000 MAX_DELETE_CALLS_COUNT = 100 +INITIAL_CALLS_STREAM_BATCH_SIZE = 100 MAX_CALLS_STREAM_BATCH_SIZE = 500 @@ -343,68 +345,47 @@ def calls_query_stream(self, req: tsi.CallsQueryReq) -> Iterator[tsi.CallSchema] ) select_columns = [c.field for c in cq.select_fields] + expand_columns = req.expand_columns or [] + include_feedback = req.include_feedback or False - if not req.expand_columns and not req.include_feedback: - for row in raw_res: - yield tsi.CallSchema.model_validate( - _ch_call_dict_to_call_schema_dict(dict(zip(select_columns, row))) - ) - - else: - expand_columns = req.expand_columns or [] - ref_cache = LRUCache(max_size=1000) + def row_to_call_schema_dict(row: tuple[Any, ...]) -> dict[str, Any]: + return _ch_call_dict_to_call_schema_dict(dict(zip(select_columns, row))) - batch_size = 10 - batch = [] + if not expand_columns and not include_feedback: for row in raw_res: - call_dict = _ch_call_dict_to_call_schema_dict( - dict(zip(select_columns, row)) + yield tsi.CallSchema.model_validate(row_to_call_schema_dict(row)) + return + + ref_cache = LRUCache(max_size=1000) + batch_processor = DynamicBatchProcessor( + initial_size=INITIAL_CALLS_STREAM_BATCH_SIZE, + max_size=MAX_CALLS_STREAM_BATCH_SIZE, + growth_factor=10, + ) + + for batch in batch_processor.make_batches(raw_res): + call_dicts = [row_to_call_schema_dict(row) for row in batch] + if expand_columns: + self._expand_call_refs( + req.project_id, call_dicts, expand_columns, ref_cache ) - batch.append(call_dict) - if len(batch) >= batch_size: - hydrated_batch = self._hydrate_calls( - req.project_id, - batch, - expand_columns, - req.include_feedback or False, - ref_cache, - ) - for call in hydrated_batch: - yield tsi.CallSchema.model_validate(call) - - # *** Dynamic increase from 10 to 500 *** - batch_size = min(MAX_CALLS_STREAM_BATCH_SIZE, batch_size * 10) - batch = [] - - hydrated_batch = self._hydrate_calls( - req.project_id, - batch, - expand_columns, - req.include_feedback or False, - ref_cache, - ) - for call in hydrated_batch: + if include_feedback: + self._add_feedback_to_calls(req.project_id, call_dicts) + + for call in call_dicts: yield tsi.CallSchema.model_validate(call) - def _hydrate_calls( - self, - project_id: str, - calls: list[dict[str, Any]], - expand_columns: list[str], - include_feedback: bool, - ref_cache: LRUCache, - ) -> list[dict[str, Any]]: + def _add_feedback_to_calls( + self, project_id: str, calls: list[dict[str, Any]] + ) -> None: if len(calls) == 0: - return calls + return - self._expand_call_refs(project_id, calls, expand_columns, ref_cache) - if include_feedback: - feedback_query_req = make_feedback_query_req(project_id, calls) + feedback_query_req = make_feedback_query_req(project_id, calls) + with self.with_new_client(): feedback = self.feedback_query(feedback_query_req) - hydrate_calls_with_feedback(calls, feedback) - - return calls + hydrate_calls_with_feedback(calls, feedback) def _get_refs_to_resolve( self, calls: list[dict[str, Any]], expand_columns: list[str] @@ -436,6 +417,9 @@ def _expand_call_refs( expand_columns: list[str], ref_cache: LRUCache, ) -> None: + if len(calls) == 0: + return + # format expand columns by depth, iterate through each batch in order expand_column_by_depth = defaultdict(list) for col in expand_columns: @@ -448,9 +432,10 @@ def _expand_call_refs( if not refs_to_resolve: continue - vals = self._refs_read_batch_within_project( - project_id, list(refs_to_resolve.values()), ref_cache - ) + with self.with_new_client(): + vals = self._refs_read_batch_within_project( + project_id, list(refs_to_resolve.values()), ref_cache + ) for ((i, col), ref), val in zip(refs_to_resolve.items(), vals): if isinstance(val, dict) and "_ref" not in val: val["_ref"] = ref.uri() @@ -1521,6 +1506,7 @@ def completions_create( # Private Methods @property def ch_client(self) -> CHClient: + """Returns and creates (if necessary) the clickhouse client""" if not hasattr(self._thread_local, "ch_client"): self._thread_local.ch_client = self._mint_client() return self._thread_local.ch_client @@ -1538,6 +1524,26 @@ def _mint_client(self) -> CHClient: client.database = self._database return client + @contextmanager + def with_new_client(self) -> Iterator[None]: + """Context manager to use a new client for operations. + Each call gets a fresh client with its own clickhouse session ID. + + Usage: + ``` + with self.with_new_client(): + self.feedback_query(req) + ``` + """ + client = self._mint_client() + original_client = self.ch_client + self._thread_local.ch_client = client + try: + yield + finally: + self._thread_local.ch_client = original_client + client.close() + # def __del__(self) -> None: # self.ch_client.close() diff --git a/weave/trace_server/clickhouse_trace_server_migrator.py b/weave/trace_server/clickhouse_trace_server_migrator.py index 30dffe89365..4336630bf50 100644 --- a/weave/trace_server/clickhouse_trace_server_migrator.py +++ b/weave/trace_server/clickhouse_trace_server_migrator.py @@ -1,6 +1,7 @@ # Clickhouse Trace Server Manager import logging import os +import re from typing import Optional from clickhouse_connect.driver.client import Client as CHClient @@ -9,6 +10,11 @@ logger = logging.getLogger(__name__) +# These settings are only used when `replicated` mode is enabled for +# self managed clickhouse instances. +DEFAULT_REPLICATED_PATH = "/clickhouse/tables/{db}" +DEFAULT_REPLICATED_CLUSTER = "weave_cluster" + class MigrationError(RuntimeError): """Raised when a migration error occurs.""" @@ -16,15 +22,77 @@ class MigrationError(RuntimeError): class ClickHouseTraceServerMigrator: ch_client: CHClient + replicated: bool + replicated_path: str + replicated_cluster: str def __init__( self, ch_client: CHClient, + replicated: Optional[bool] = None, + replicated_path: Optional[str] = None, + replicated_cluster: Optional[str] = None, ): super().__init__() self.ch_client = ch_client + self.replicated = False if replicated is None else replicated + self.replicated_path = ( + DEFAULT_REPLICATED_PATH if replicated_path is None else replicated_path + ) + self.replicated_cluster = ( + DEFAULT_REPLICATED_CLUSTER + if replicated_cluster is None + else replicated_cluster + ) self._initialize_migration_db() + def _is_safe_identifier(self, value: str) -> bool: + """Check if a string is safe to use as an identifier in SQL.""" + return bool(re.match(r"^[a-zA-Z0-9_\.]+$", value)) + + def _format_replicated_sql(self, sql_query: str) -> str: + """Format SQL query to use replicated engines if replicated mode is enabled.""" + if not self.replicated: + return sql_query + + # Match "ENGINE = MergeTree" followed by word boundary + pattern = r"ENGINE\s*=\s*(\w+)?MergeTree\b" + + def replace_engine(match: re.Match[str]) -> str: + engine_prefix = match.group(1) or "" + return f"ENGINE = Replicated{engine_prefix}MergeTree" + + return re.sub(pattern, replace_engine, sql_query, flags=re.IGNORECASE) + + def _create_db_sql(self, db_name: str) -> str: + """Geneate SQL database create string for normal and replicated databases.""" + if not self._is_safe_identifier(db_name): + raise MigrationError(f"Invalid database name: {db_name}") + + replicated_engine = "" + replicated_cluster = "" + if self.replicated: + if not self._is_safe_identifier(self.replicated_cluster): + raise MigrationError(f"Invalid cluster name: {self.replicated_cluster}") + + replicated_path = self.replicated_path.replace("{db}", db_name) + if not all( + self._is_safe_identifier(part) + for part in replicated_path.split("/") + if part + ): + raise MigrationError(f"Invalid replicated path: {replicated_path}") + + replicated_cluster = f" ON CLUSTER {self.replicated_cluster}" + replicated_engine = ( + f" ENGINE=Replicated('{replicated_path}', '{{shard}}', '{{replica}}')" + ) + + create_db_sql = f""" + CREATE DATABASE IF NOT EXISTS {db_name}{replicated_cluster}{replicated_engine} + """ + return create_db_sql + def apply_migrations( self, target_db: str, target_version: Optional[int] = None ) -> None: @@ -46,20 +114,15 @@ def apply_migrations( return logger.info(f"Migrations to apply: {migrations_to_apply}") if status["curr_version"] == 0: - self.ch_client.command(f"CREATE DATABASE IF NOT EXISTS {target_db}") + self.ch_client.command(self._create_db_sql(target_db)) for target_version, migration_file in migrations_to_apply: self._apply_migration(target_db, target_version, migration_file) if should_insert_costs(status["curr_version"], target_version): insert_costs(self.ch_client, target_db) def _initialize_migration_db(self) -> None: - self.ch_client.command( - """ - CREATE DATABASE IF NOT EXISTS db_management - """ - ) - self.ch_client.command( - """ + self.ch_client.command(self._create_db_sql("db_management")) + create_table_sql = """ CREATE TABLE IF NOT EXISTS db_management.migrations ( db_name String, @@ -69,7 +132,7 @@ def _initialize_migration_db(self) -> None: ENGINE = MergeTree() ORDER BY (db_name) """ - ) + self.ch_client.command(self._format_replicated_sql(create_table_sql)) def _get_migration_status(self, db_name: str) -> dict: column_names = ["db_name", "curr_version", "partially_applied_version"] @@ -184,31 +247,48 @@ def _determine_migrations_to_apply( return [] + def _execute_migration_command(self, target_db: str, command: str) -> None: + """Execute a single migration command in the context of the target database.""" + command = command.strip() + if len(command) == 0: + return + curr_db = self.ch_client.database + self.ch_client.database = target_db + self.ch_client.command(self._format_replicated_sql(command)) + self.ch_client.database = curr_db + + def _update_migration_status( + self, target_db: str, target_version: int, is_start: bool = True + ) -> None: + """Update the migration status in db_management.migrations table.""" + if is_start: + self.ch_client.command( + f"ALTER TABLE db_management.migrations UPDATE partially_applied_version = {target_version} WHERE db_name = '{target_db}'" + ) + else: + self.ch_client.command( + f"ALTER TABLE db_management.migrations UPDATE curr_version = {target_version}, partially_applied_version = NULL WHERE db_name = '{target_db}'" + ) + def _apply_migration( self, target_db: str, target_version: int, migration_file: str ) -> None: logger.info(f"Applying migration {migration_file} to `{target_db}`") migration_dir = os.path.join(os.path.dirname(__file__), "migrations") migration_file_path = os.path.join(migration_dir, migration_file) + with open(migration_file_path) as f: migration_sql = f.read() - self.ch_client.command( - f""" - ALTER TABLE db_management.migrations UPDATE partially_applied_version = {target_version} WHERE db_name = '{target_db}' - """ - ) + + # Mark migration as partially applied + self._update_migration_status(target_db, target_version, is_start=True) + + # Execute each command in the migration migration_sub_commands = migration_sql.split(";") for command in migration_sub_commands: - command = command.strip() - if len(command) == 0: - continue - curr_db = self.ch_client.database - self.ch_client.database = target_db - self.ch_client.command(command) - self.ch_client.database = curr_db - self.ch_client.command( - f""" - ALTER TABLE db_management.migrations UPDATE curr_version = {target_version}, partially_applied_version = NULL WHERE db_name = '{target_db}' - """ - ) + self._execute_migration_command(target_db, command) + + # Mark migration as fully applied + self._update_migration_status(target_db, target_version, is_start=False) + logger.info(f"Migration {migration_file} applied to `{target_db}`") diff --git a/weave/trace_server/trace_server_common.py b/weave/trace_server/trace_server_common.py index 0ff14d4396b..0691927bc47 100644 --- a/weave/trace_server/trace_server_common.py +++ b/weave/trace_server/trace_server_common.py @@ -1,6 +1,7 @@ import copy import datetime from collections import OrderedDict, defaultdict +from collections.abc import Iterator from typing import Any, Optional, cast from weave.trace_server import refs_internal as ri @@ -170,6 +171,33 @@ def __setitem__(self, key: str, value: Any) -> None: super().__setitem__(key, value) +class DynamicBatchProcessor: + """Helper class to handle dynamic batch processing with growing batch sizes.""" + + def __init__(self, initial_size: int, max_size: int, growth_factor: int): + self.batch_size = initial_size + self.max_size = max_size + self.growth_factor = growth_factor + + def make_batches(self, iterator: Iterator[Any]) -> Iterator[list[Any]]: + batch = [] + + for item in iterator: + batch.append(item) + + if len(batch) >= self.batch_size: + yield batch + + batch = [] + self.batch_size = self._compute_batch_size() + + if batch: + yield batch + + def _compute_batch_size(self) -> int: + return min(self.max_size, self.batch_size * self.growth_factor) + + def digest_is_version_like(digest: str) -> tuple[bool, int]: """ Check if a digest is a version like string. diff --git a/weave/version.py b/weave/version.py index 5212f6aee7d..70f670abc21 100644 --- a/weave/version.py +++ b/weave/version.py @@ -44,4 +44,4 @@ """ -VERSION = "0.51.24-dev0" +VERSION = "0.51.25-dev0"