diff --git a/docs/docs/guides/tracking/ops.md b/docs/docs/guides/tracking/ops.md index b83ff69a72a..48ac1e5ff2a 100644 --- a/docs/docs/guides/tracking/ops.md +++ b/docs/docs/guides/tracking/ops.md @@ -42,22 +42,28 @@ If you want to change the data that is logged to weave without modifying the ori `postprocess_output` takes in any value which would normally be returned by the function and returns the transformed output. ```py -def postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]: - return {k:v for k,v in inputs.items() if k != "hide_me"} - -def postprocess_output(output: CustomObject) -> CustomObject: - return CustomObject(x=output.x, secret_password="REDACTED") - +from dataclasses import dataclass +from typing import Any +import weave @dataclass class CustomObject: x: int secret_password: str +def postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]: + return {k:v for k,v in inputs.items() if k != "hide_me"} + +def postprocess_output(output: CustomObject) -> CustomObject: + return CustomObject(x=output.x, secret_password="REDACTED") + @weave.op( postprocess_inputs=postprocess_inputs, postprocess_output=postprocess_output, ) def func(a: int, hide_me: str) -> CustomObject: return CustomObject(x=a, secret_password=hide_me) + +weave.init('hide-data-example') # 🐝 +func(a=1, hide_me="password123") ``` diff --git a/pyproject.toml b/pyproject.toml index e5e714455fd..9c0cb00d680 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,13 +2,17 @@ name = "weave" description = "A toolkit for building composable interactive data driven applications." 
readme = "README.md" -license = { text = "Apache-2.0" } +license = { file = "LICENSE" } maintainers = [{ name = "W&B", email = "support@wandb.com" }] authors = [ { name = "Shawn Lewis", email = "shawn@wandb.com" }, - { name = "Danny Goldstein", email = "danny@wandb.com" }, { name = "Tim Sweeney", email = "tim@wandb.com" }, { name = "Nick Peneranda", email = "nick.penaranda@wandb.com" }, + { name = "Jeff Raubitschek", email = "jeff@wandb.com" }, + { name = "Jamie Rasmussen", email = "jamie.rasmussen@wandb.com" }, + { name = "Griffin Tarpenning", email = "griffin.tarpenning@wandb.com" }, + { name = "Josiah Lee", email = "josiah.lee@wandb.com" }, + { name = "Andrew Truong", email = "andrew@wandb.com" }, ] classifiers = [ "Development Status :: 4 - Beta", @@ -38,8 +42,8 @@ dependencies = [ "tenacity>=8.3.0,!=8.4.0", # Excluding 8.4.0 because it had a bug on import of AsyncRetrying "emoji>=2.12.1", # For emoji shortcode support in Feedback "uuid-utils>=0.9.0", # Used for ID generation - remove once python's built-in uuid supports UUIDv7 - "numpy>1.21.0", - "rich", + "numpy>1.21.0", # Used in box.py (should be made optional) + "rich", # Used for special formatting of tables (should be made optional) # dependencies for remaining legacy code. 
Remove when possible "httpx", diff --git a/tests/trace/test_client_trace.py b/tests/trace/test_client_trace.py index 4e902fd5673..28b29e490ee 100644 --- a/tests/trace/test_client_trace.py +++ b/tests/trace/test_client_trace.py @@ -2663,6 +2663,31 @@ def return_nested_object(nested_obj: NestedObject): assert call_result.output == nested_ref.uri() +# Batch size is dynamically increased from 10 to MAX_CALLS_STREAM_BATCH_SIZE (500) +# in clickhouse_trace_server_batched.py, this test verifies that the dynamic +# increase works as expected +@pytest.mark.parametrize("batch_size", [1, 10, 100, 110]) +def test_calls_stream_column_expansion_dynamic_batch_size(client, batch_size): + @weave.op + def test_op(x): + return x + + for i in range(batch_size): + test_op(i) + + res = client.server.calls_query_stream( + tsi.CallsQueryReq( + project_id=client._project_id(), + columns=["output"], + expand_columns=["output"], + ) + ) + calls = list(res) + assert len(calls) == batch_size + for i in range(batch_size): + assert calls[i].output == i + + class Custom(weave.Object): val: dict @@ -2792,16 +2817,18 @@ def test(obj: Custom): def test_calls_stream_feedback(client): + BATCH_SIZE = 10 + num_calls = BATCH_SIZE + 1 + @weave.op def test_call(x): return "ello chap" - test_call(1) - test_call(2) - test_call(3) + for i in range(num_calls): + test_call(i) calls = list(test_call.calls()) - assert len(calls) == 3 + assert len(calls) == num_calls # add feedback to the first call calls[0].feedback.add("note", {"note": "this is a note on call1"}) @@ -2820,7 +2847,7 @@ def test_call(x): ) calls = list(res) - assert len(calls) == 3 + assert len(calls) == num_calls assert len(calls[0].summary["weave"]["feedback"]) == 4 assert len(calls[1].summary["weave"]["feedback"]) == 1 assert not calls[2].summary.get("weave", {}).get("feedback") diff --git a/tests/trace/test_weave_client.py b/tests/trace/test_weave_client.py index c482a14d454..5d081f8ce63 100644 --- a/tests/trace/test_weave_client.py +++ 
b/tests/trace/test_weave_client.py @@ -27,7 +27,13 @@ TABLE_ROW_ID_EDGE_NAME, ) from weave.trace.serializer import get_serializer_for_obj, register_serializer -from weave.trace_server.sqlite_trace_server import SqliteTraceServer +from weave.trace_server.clickhouse_trace_server_batched import NotFoundError +from weave.trace_server.sqlite_trace_server import ( + NotFoundError as sqliteNotFoundError, +) +from weave.trace_server.sqlite_trace_server import ( + SqliteTraceServer, +) from weave.trace_server.trace_server_interface import ( FileContentReadReq, FileCreateReq, @@ -1436,7 +1442,31 @@ def test_object_version_read(client): assert obj_res.obj.val == {"a": 9} assert obj_res.obj.version_index == 9 - # now grab version 5 + # now grab each by their digests + for i, digest in enumerate([obj.digest for obj in objs]): + obj_res = client.server.obj_read( + tsi.ObjReadReq( + project_id=client._project_id(), + object_id=refs[0].name, + digest=digest, + ) + ) + assert obj_res.obj.val == {"a": i} + assert obj_res.obj.version_index == i + + # publish another, check that latest is updated + client._save_object({"a": 10}, refs[0].name) + obj_res = client.server.obj_read( + tsi.ObjReadReq( + project_id=client._project_id(), + object_id=refs[0].name, + digest="latest", + ) + ) + assert obj_res.obj.val == {"a": 10} + assert obj_res.obj.version_index == 10 + + # check that v5 is still correct obj_res = client.server.obj_read( tsi.ObjReadReq( project_id=client._project_id(), @@ -1446,3 +1476,26 @@ def test_object_version_read(client): ) assert obj_res.obj.val == {"a": 5} assert obj_res.obj.version_index == 5 + + # check badly formatted digests + digests = ["v1111", "1", ""] + for digest in digests: + with pytest.raises((NotFoundError, sqliteNotFoundError)): + # grab non-existent version + obj_res = client.server.obj_read( + tsi.ObjReadReq( + project_id=client._project_id(), + object_id=refs[0].name, + digest=digest, + ) + ) + + # check non-existent object_id + with 
pytest.raises((NotFoundError, sqliteNotFoundError)): + obj_res = client.server.obj_read( + tsi.ObjReadReq( + project_id=client._project_id(), + object_id="refs[0].name", + digest="v1", + ) + ) diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index 08a244604f7..75fe1f7b8e7 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -28,7 +28,7 @@ import json import logging import threading -from collections import Counter, defaultdict +from collections import defaultdict from contextlib import contextmanager from typing import ( Any, @@ -114,6 +114,7 @@ FILE_CHUNK_SIZE = 100000 MAX_DELETE_CALLS_COUNT = 100 +MAX_CALLS_STREAM_BATCH_SIZE = 500 class NotFoundError(Exception): @@ -356,15 +357,8 @@ def calls_query_stream(self, req: tsi.CallsQueryReq) -> Iterator[tsi.CallSchema] for call in hydrated_batch: yield tsi.CallSchema.model_validate(call) - # *** Dynamic Batch Size *** - # count the number of columns at each depth - depths = Counter(col.count(".") for col in expand_columns) - # take the max number of columns at any depth - max_count_at_ref_depth = max(depths.values()) - # divide max refs that we can resolve 1000 refs at any depth - max_size = 1000 // max_count_at_ref_depth - # double batch size up to what refs_read_batch can handle - batch_size = min(max_size, batch_size * 2) + # *** Dynamic increase from 10 to 500 *** + batch_size = min(MAX_CALLS_STREAM_BATCH_SIZE, batch_size * 10) batch = [] hydrated_batch = self._hydrate_calls( @@ -1440,9 +1434,8 @@ def _select_objs_query( parameters to be passed to the query. Must include all parameters for both conditions and object_id_conditions. metadata_only: - if metadata_only is True, then we exclude the val_dump field in the select query. - generally, "queries" should not include the val_dump, but "reads" should, as - the val_dump is the most expensive part of the query. 
+ if metadata_only is True, then we return early and don't grab the value. + Otherwise, make a second query to grab the val_dump from the db """ if not conditions: conditions = ["1 = 1"] @@ -1475,11 +1468,7 @@ def _select_objs_query( if parameters is None: parameters = {} - # When metadata_only is false, dont actually read from the field - val_dump_field = "'{}' AS val_dump" if metadata_only else "val_dump" - - # The subquery is for deduplication of object versions by digest - select_query = f""" + select_without_val_dump_query = f""" SELECT project_id, object_id, @@ -1487,7 +1476,6 @@ def _select_objs_query( kind, base_object_class, refs, - val_dump, digest, is_op, version_index, @@ -1500,7 +1488,6 @@ def _select_objs_query( kind, base_object_class, refs, - val_dump, digest, is_op, row_number() OVER ( @@ -1518,7 +1505,6 @@ def _select_objs_query( kind, base_object_class, refs, - {val_dump_field}, digest, if (kind = 'op', 1, 0) AS is_op, row_number() OVER ( @@ -1540,7 +1526,7 @@ def _select_objs_query( {offset_part} """ query_result = self._query_stream( - select_query, + select_without_val_dump_query, {"project_id": project_id, **parameters}, ) result: list[SelectableCHObjSchema] = [] @@ -1556,19 +1542,50 @@ def _select_objs_query( "kind", "base_object_class", "refs", - "val_dump", "digest", "is_op", "version_index", "version_count", "is_latest", + "val_dump", ], - row, + # Add an empty val_dump to the end of the row + list(row) + ["{}"], ) ) ) ) + # -- Don't make second query for object values if metadata_only -- + if metadata_only: + return result + + # now get the val_dump for each object + object_ids = list(set([row.object_id for row in result])) + digests = list(set([row.digest for row in result])) + query = """ + SELECT object_id, digest, any(val_dump) + FROM object_versions + WHERE project_id = {project_id: String} AND + object_id IN {object_ids: Array(String)} AND + digest IN {digests: Array(String)} + GROUP BY object_id, digest + """ + parameters = { 
"project_id": project_id, + "object_ids": object_ids, + "digests": digests, + } + query_result = self._query_stream(query, parameters) + # Map (object_id, digest) to val_dump + object_values: Dict[tuple[str, str], Any] = {} + for row in query_result: + (object_id, digest, val_dump) = row + object_values[(object_id, digest)] = val_dump + + # update the val_dump for each object + for obj in result: + obj.val_dump = object_values.get((obj.object_id, obj.digest), "{}") return result def _run_migrations(self) -> None: @@ -1581,7 +1598,7 @@ def _query_stream( query: str, parameters: Dict[str, Any], column_formats: Optional[Dict[str, Any]] = None, - ) -> Iterator[QueryResult]: + ) -> Iterator[tuple]: """Streams the results of a query from the database.""" summary = None parameters = _process_parameters(parameters) diff --git a/weave/trace_server/table_query_builder.py b/weave/trace_server/table_query_builder.py index 0a38a8245ec..c5c204c388e 100644 --- a/weave/trace_server/table_query_builder.py +++ b/weave/trace_server/table_query_builder.py @@ -23,30 +23,33 @@ def make_natural_sort_table_query( """ project_id_name = pb.add_param(project_id) digest_name = pb.add_param(digest) - sql_safe_dir = "ASC" if natural_direction == "ASC" else "DESC" - sql_safe_limit = ( - f"LIMIT {{{pb.add_param(limit)}: Int64}}" if limit is not None else "" - ) - sql_safe_offset = ( - f"OFFSET {{{pb.add_param(offset)}: Int64}}" if offset is not None else "" - ) + row_digests_selection = "row_digests" + if natural_direction.lower() == "desc": + row_digests_selection = f"reverse({row_digests_selection})" + if limit is not None and offset is None: + offset = 0 + if offset is not None: + if limit is None: + row_digests_selection = f"arraySlice({row_digests_selection}, 1 + {{{pb.add_param(offset)}: Int64}})" + else: + row_digests_selection = f"arraySlice({row_digests_selection}, 1 + {{{pb.add_param(offset)}: Int64}}, {{{pb.add_param(limit)}: Int64}})" query = f""" SELECT DISTINCT tr.digest, 
tr.val_dump, t.row_order FROM table_rows tr - RIGHT JOIN ( + INNER JOIN ( SELECT row_digest, row_number() OVER () AS row_order - FROM tables + FROM ( + SELECT {row_digests_selection} as row_digests + FROM tables + WHERE project_id = {{{project_id_name}: String}} + AND digest = {{{digest_name}: String}} + ) ARRAY JOIN row_digests AS row_digest - WHERE project_id = {{{project_id_name}: String}} - AND digest = {{{digest_name}: String}} - ORDER BY row_order {sql_safe_dir} - {sql_safe_limit} - {sql_safe_offset} ) AS t ON tr.digest = t.row_digest WHERE tr.project_id = {{{project_id_name}: String}} - ORDER BY row_order {sql_safe_dir} + ORDER BY row_order ASC """ return query @@ -88,12 +91,15 @@ def make_standard_table_query( ( SELECT DISTINCT tr.digest, tr.val_dump, t.row_order FROM table_rows tr - RIGHT JOIN ( + INNER JOIN ( SELECT row_digest, row_number() OVER () AS row_order - FROM tables + FROM ( + SELECT row_digests + FROM tables + WHERE project_id = {{{project_id_name}: String}} + AND digest = {{{digest_name}: String}} + ) ARRAY JOIN row_digests AS row_digest - WHERE project_id = {{{project_id_name}: String}} - AND digest = {{{digest_name}: String}} ) AS t ON tr.digest = t.row_digest WHERE tr.project_id = {{{project_id_name}: String}} {sql_safe_filter_clause} diff --git a/weave/trace_server_bindings/remote_http_trace_server.py b/weave/trace_server_bindings/remote_http_trace_server.py index 78f163749cf..fd8fd5d3fb6 100644 --- a/weave/trace_server_bindings/remote_http_trace_server.py +++ b/weave/trace_server_bindings/remote_http_trace_server.py @@ -57,7 +57,7 @@ def _is_retryable_exception(e: Exception) -> bool: # Unknown server error # TODO(np): We need to fix the server to return proper status codes # for downstream 401, 403, 404, etc... Those should propagate back to - # the clien + # the client. 
if e.response.status_code == 500: return False @@ -263,13 +263,13 @@ def call_start( req_as_obj = tsi.CallStartReq.model_validate(req) else: req_as_obj = req - if req_as_obj.starid == None or req_as_obj.startrace_id == None: + if req_as_obj.start.id == None or req_as_obj.start.trace_id == None: raise ValueError( "CallStartReq must have id and trace_id when batching." ) self.call_processor.enqueue([StartBatchItem(req=req_as_obj)]) return tsi.CallStartRes( - id=req_as_obj.starid, trace_id=req_as_obj.startrace_id + id=req_as_obj.start.id, trace_id=req_as_obj.start.trace_id ) return self._generic_request( "/call/start", req, tsi.CallStartReq, tsi.CallStartRes @@ -362,7 +362,7 @@ def table_create( """Similar to `calls/batch_upsert`, we can dynamically adjust the payload size due to the property that table creation can be decomposed into a series of updates. This is useful when the table creation size is too big to be sent in - a single reques We can create an empty table first, then update the table + a single request. We can create an empty table first, then update the table with the rows. """ if isinstance(req, dict):