Commit: Merge branch 'andrew/tests3' into andrew/tests4

andrewtruong authored Oct 2, 2024
2 parents 07f1fe5 + 6b6b17a commit 9487a4e
Showing 7 changed files with 178 additions and 65 deletions.
18 changes: 12 additions & 6 deletions docs/docs/guides/tracking/ops.md
@@ -42,22 +42,28 @@ If you want to change the data that is logged to weave without modifying the original…
`postprocess_output` takes the value that the function would normally return and returns the transformed output.

```py
from dataclasses import dataclass
from typing import Any
import weave

@dataclass
class CustomObject:
    x: int
    secret_password: str

def postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]:
    return {k: v for k, v in inputs.items() if k != "hide_me"}

def postprocess_output(output: CustomObject) -> CustomObject:
    return CustomObject(x=output.x, secret_password="REDACTED")

@weave.op(
    postprocess_inputs=postprocess_inputs,
    postprocess_output=postprocess_output,
)
def func(a: int, hide_me: str) -> CustomObject:
    return CustomObject(x=a, secret_password=hide_me)

weave.init('hide-data-example')  # 🐝
func(a=1, hide_me="password123")
```
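
Run end to end, the call is logged with `hide_me` stripped from the inputs and `secret_password` redacted in the output. As a quick local check of the two postprocessors (using only the definitions above; not part of the original page):

```py
# The postprocessors are plain functions, so they can be exercised directly.
assert postprocess_inputs({"a": 1, "hide_me": "password123"}) == {"a": 1}

redacted = postprocess_output(CustomObject(x=1, secret_password="password123"))
assert redacted == CustomObject(x=1, secret_password="REDACTED")
```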
12 changes: 8 additions & 4 deletions pyproject.toml
@@ -2,13 +2,17 @@
name = "weave"
description = "A toolkit for building composable interactive data driven applications."
readme = "README.md"
-license = { text = "Apache-2.0" }
+license = { file = "LICENSE" }
maintainers = [{ name = "W&B", email = "[email protected]" }]
authors = [
{ name = "Shawn Lewis", email = "[email protected]" },
{ name = "Danny Goldstein", email = "[email protected]" },
{ name = "Tim Sweeney", email = "[email protected]" },
{ name = "Nick Peneranda", email = "[email protected]" },
{ name = "Jeff Raubitschek", email = "[email protected]" },
{ name = "Jamie Rasmussen", email = "[email protected]" },
{ name = "Griffin Tarpenning", email = "[email protected]" },
{ name = "Josiah Lee", email = "[email protected]" },
{ name = "Andrew Truong", email = "[email protected]" },
]
classifiers = [
"Development Status :: 4 - Beta",
Expand Down Expand Up @@ -38,8 +42,8 @@ dependencies = [
"tenacity>=8.3.0,!=8.4.0", # Excluding 8.4.0 because it had a bug on import of AsyncRetrying
"emoji>=2.12.1", # For emoji shortcode support in Feedback
"uuid-utils>=0.9.0", # Used for ID generation - remove once python's built-in uuid supports UUIDv7
"numpy>1.21.0",
"rich",
"numpy>1.21.0", # Used in box.py (should be made optional)
"rich", # Used for special formatting of tables (should be made optional)

    # dependencies for remaining legacy code. Remove when possible
    "httpx",
37 changes: 32 additions & 5 deletions tests/trace/test_client_trace.py
@@ -2663,6 +2663,31 @@ def return_nested_object(nested_obj: NestedObject):
    assert call_result.output == nested_ref.uri()


+# Batch size is dynamically increased from 10 to MAX_CALLS_STREAM_BATCH_SIZE (500)
+# in clickhouse_trace_server_batched.py; this test verifies that the dynamic
+# increase works as expected.
+@pytest.mark.parametrize("batch_size", [1, 10, 100, 110])
+def test_calls_stream_column_expansion_dynamic_batch_size(client, batch_size):
+    @weave.op
+    def test_op(x):
+        return x
+
+    for i in range(batch_size):
+        test_op(i)
+
+    res = client.server.calls_query_stream(
+        tsi.CallsQueryReq(
+            project_id=client._project_id(),
+            columns=["output"],
+            expand_columns=["output"],
+        )
+    )
+    calls = list(res)
+    assert len(calls) == batch_size
+    for i in range(batch_size):
+        assert calls[i].output == i
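
For intuition, here is a small sketch (not part of the commit) of the batch-size schedule this test exercises, assuming the stream starts at a batch size of 10 and grows tenfold per batch, capped at MAX_CALLS_STREAM_BATCH_SIZE = 500 as in the server change below:

```py
# Hypothetical helper: compute the sizes of the successive batches the
# server would emit for a given number of calls under this schedule.
MAX_CALLS_STREAM_BATCH_SIZE = 500

def batch_sizes(total_calls: int, start: int = 10) -> list[int]:
    sizes: list[int] = []
    size, remaining = start, total_calls
    while remaining > 0:
        sizes.append(min(size, remaining))
        remaining -= sizes[-1]
        size = min(MAX_CALLS_STREAM_BATCH_SIZE, size * 10)
    return sizes

print(batch_sizes(110))  # [10, 100] -- the 110-call case spans two batches
print(batch_sizes(1))    # [1]
```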


class Custom(weave.Object):
    val: dict

@@ -2792,16 +2817,18 @@ def test(obj: Custom):


def test_calls_stream_feedback(client):
+    BATCH_SIZE = 10
+    num_calls = BATCH_SIZE + 1
+
    @weave.op
    def test_call(x):
        return "ello chap"

-    test_call(1)
-    test_call(2)
-    test_call(3)
+    for i in range(num_calls):
+        test_call(i)

    calls = list(test_call.calls())
-    assert len(calls) == 3
+    assert len(calls) == num_calls

    # add feedback to the first call
    calls[0].feedback.add("note", {"note": "this is a note on call1"})
@@ -2820,7 +2847,7 @@ def test_call(x):
    )
    calls = list(res)

-    assert len(calls) == 3
+    assert len(calls) == num_calls
    assert len(calls[0].summary["weave"]["feedback"]) == 4
    assert len(calls[1].summary["weave"]["feedback"]) == 1
    assert not calls[2].summary.get("weave", {}).get("feedback")
57 changes: 55 additions & 2 deletions tests/trace/test_weave_client.py
@@ -27,7 +27,13 @@
    TABLE_ROW_ID_EDGE_NAME,
)
from weave.trace.serializer import get_serializer_for_obj, register_serializer
-from weave.trace_server.sqlite_trace_server import SqliteTraceServer
+from weave.trace_server.clickhouse_trace_server_batched import NotFoundError
+from weave.trace_server.sqlite_trace_server import (
+    NotFoundError as sqliteNotFoundError,
+)
+from weave.trace_server.sqlite_trace_server import (
+    SqliteTraceServer,
+)
from weave.trace_server.trace_server_interface import (
    FileContentReadReq,
    FileCreateReq,

@@ -1436,7 +1442,31 @@ def test_object_version_read(client):
    assert obj_res.obj.val == {"a": 9}
    assert obj_res.obj.version_index == 9

-    # now grab version 5
+    # now grab each version by its digest
+    for i, digest in enumerate([obj.digest for obj in objs]):
+        obj_res = client.server.obj_read(
+            tsi.ObjReadReq(
+                project_id=client._project_id(),
+                object_id=refs[0].name,
+                digest=digest,
+            )
+        )
+        assert obj_res.obj.val == {"a": i}
+        assert obj_res.obj.version_index == i
+
+    # publish another, check that latest is updated
+    client._save_object({"a": 10}, refs[0].name)
+    obj_res = client.server.obj_read(
+        tsi.ObjReadReq(
+            project_id=client._project_id(),
+            object_id=refs[0].name,
+            digest="latest",
+        )
+    )
+    assert obj_res.obj.val == {"a": 10}
+    assert obj_res.obj.version_index == 10
+
+    # check that v5 is still correct
    obj_res = client.server.obj_read(
        tsi.ObjReadReq(
            project_id=client._project_id(),
            object_id=refs[0].name,
            digest="v5",
        )
    )
    assert obj_res.obj.val == {"a": 5}
    assert obj_res.obj.version_index == 5

+    # check badly formatted digests
+    digests = ["v1111", "1", ""]
+    for digest in digests:
+        with pytest.raises((NotFoundError, sqliteNotFoundError)):
+            # grab non-existent version
+            obj_res = client.server.obj_read(
+                tsi.ObjReadReq(
+                    project_id=client._project_id(),
+                    object_id=refs[0].name,
+                    digest=digest,
+                )
+            )
+
+    # check non-existent object_id
+    with pytest.raises((NotFoundError, sqliteNotFoundError)):
+        obj_res = client.server.obj_read(
+            tsi.ObjReadReq(
+                project_id=client._project_id(),
+                object_id="refs[0].name",
+                digest="v1",
+            )
+        )
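
Taken together, the reads in this test exercise three digest forms for `obj_read`: a version alias such as "v5", the literal "latest", and a raw content digest. A compact helper sketch mirroring the requests above (hypothetical, for illustration only):

```py
def read_val(client, object_id: str, digest: str):
    # digest may be "v<index>", "latest", or a raw content digest; anything
    # else raises NotFoundError (sqliteNotFoundError on the SQLite backend).
    res = client.server.obj_read(
        tsi.ObjReadReq(
            project_id=client._project_id(),
            object_id=object_id,
            digest=digest,
        )
    )
    return res.obj.val
```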
67 changes: 42 additions & 25 deletions weave/trace_server/clickhouse_trace_server_batched.py
@@ -28,7 +28,7 @@
import json
import logging
import threading
-from collections import Counter, defaultdict
+from collections import defaultdict
from contextlib import contextmanager
from typing import (
    Any,

@@ -114,6 +114,7 @@
FILE_CHUNK_SIZE = 100000

MAX_DELETE_CALLS_COUNT = 100
+MAX_CALLS_STREAM_BATCH_SIZE = 500


class NotFoundError(Exception):
@@ -356,15 +357,8 @@ def calls_query_stream(self, req: tsi.CallsQueryReq) -> Iterator[tsi.CallSchema]:
                for call in hydrated_batch:
                    yield tsi.CallSchema.model_validate(call)

-                # *** Dynamic Batch Size ***
-                # count the number of columns at each depth
-                depths = Counter(col.count(".") for col in expand_columns)
-                # take the max number of columns at any depth
-                max_count_at_ref_depth = max(depths.values())
-                # divide max refs that we can resolve 1000 refs at any depth
-                max_size = 1000 // max_count_at_ref_depth
-                # double batch size up to what refs_read_batch can handle
-                batch_size = min(max_size, batch_size * 2)
+                # *** Dynamic increase from 10 to 500 ***
+                batch_size = min(MAX_CALLS_STREAM_BATCH_SIZE, batch_size * 10)
                batch = []

        hydrated_batch = self._hydrate_calls(
@@ -1440,9 +1434,8 @@ def _select_objs_query(
            parameters to be passed to the query. Must include all parameters for both
                conditions and object_id_conditions.
            metadata_only:
-                if metadata_only is True, then we exclude the val_dump field in the select query.
-                generally, "queries" should not include the val_dump, but "reads" should, as
-                the val_dump is the most expensive part of the query.
+                if metadata_only is True, then we return early and don't grab the value.
+                Otherwise, make a second query to grab the val_dump from the db.
        """
        if not conditions:
            conditions = ["1 = 1"]
@@ -1475,19 +1468,14 @@
        if parameters is None:
            parameters = {}

-        # When metadata_only is false, dont actually read from the field
-        val_dump_field = "'{}' AS val_dump" if metadata_only else "val_dump"
-
        # The subquery is for deduplication of object versions by digest
-        select_query = f"""
+        select_without_val_dump_query = f"""
            SELECT
                project_id,
                object_id,
                created_at,
                kind,
                base_object_class,
                refs,
-                val_dump,
                digest,
                is_op,
                version_index,

@@ -1500,7 +1488,6 @@
                kind,
                base_object_class,
                refs,
-                val_dump,
                digest,
                is_op,
                row_number() OVER (

@@ -1518,7 +1505,6 @@
                kind,
                base_object_class,
                refs,
-                {val_dump_field},
                digest,
                if (kind = 'op', 1, 0) AS is_op,
                row_number() OVER (

@@ -1540,7 +1526,7 @@
            {offset_part}
        """
        query_result = self._query_stream(
-            select_query,
+            select_without_val_dump_query,
            {"project_id": project_id, **parameters},
        )
        result: list[SelectableCHObjSchema] = []

@@ -1556,19 +1542,50 @@
"kind",
"base_object_class",
"refs",
"val_dump",
"digest",
"is_op",
"version_index",
"version_count",
"is_latest",
"val_dump",
],
row,
# Add an empty val_dump to the end of the row
list(row) + ["{}"],
)
)
)
)

+        # -- Don't make second query for object values if metadata_only --
+        if metadata_only:
+            return result
+
+        # now get the val_dump for each object
+        object_ids = list(set([row.object_id for row in result]))
+        digests = list(set([row.digest for row in result]))
+        query = """
+            SELECT object_id, digest, any(val_dump)
+            FROM object_versions
+            WHERE project_id = {project_id: String} AND
+                object_id IN {object_ids: Array(String)} AND
+                digest IN {digests: Array(String)}
+            GROUP BY object_id, digest
+        """
+        parameters = {
+            "project_id": project_id,
+            "object_ids": object_ids,
+            "digests": digests,
+        }
+        query_result = self._query_stream(query, parameters)
+        # Map (object_id, digest) to val_dump
+        object_values: Dict[tuple[str, str], Any] = {}
+        for row in query_result:
+            (object_id, digest, val_dump) = row
+            object_values[(object_id, digest)] = val_dump
+
+        # update the val_dump for each object
+        for obj in result:
+            obj.val_dump = object_values.get((obj.object_id, obj.digest), "{}")
        return result

    def _run_migrations(self) -> None:

@@ -1581,7 +1598,7 @@ def _query_stream(
        query: str,
        parameters: Dict[str, Any],
        column_formats: Optional[Dict[str, Any]] = None,
-    ) -> Iterator[QueryResult]:
+    ) -> Iterator[tuple]:
        """Streams the results of a query from the database."""
        summary = None
        parameters = _process_parameters(parameters)