Skip to content

Commit

Permalink
fix: Improve TableQuery performance by adding a streaming interface (#…
Browse files Browse the repository at this point in the history
…2619)

* init

* lint

* init

* implemented

* small test

* implemented

* addressed comments
  • Loading branch information
tssweeney authored Oct 10, 2024
1 parent 8d5d38b commit cd1ac69
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 12 deletions.
23 changes: 23 additions & 0 deletions tests/trace/test_table_query.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import random
from typing import Iterator

from weave.trace.weave_client import WeaveClient
from weave.trace_server import trace_server_interface as tsi
Expand Down Expand Up @@ -55,6 +56,28 @@ def test_table_query(client: WeaveClient):
assert result_digests == row_digests


def test_table_query_stream(client: WeaveClient):
    """Check that `table_query_stream` yields the generated table's rows in order.

    A table is created through `generate_table_data`, then queried by digest
    via the streaming endpoint. The result must be an `Iterator`, and the
    values and digests of the streamed rows must equal the generated data and
    row digests respectively.
    """
    digest, row_digests, data = generate_table_data(client, 10, 10)

    res = client.server.table_query_stream(
        tsi.TableQueryReq(
            project_id=client._project_id(),
            digest=digest,
        )
    )

    # The streaming API must return an iterator rather than a response object.
    assert isinstance(res, Iterator)

    # Drain the stream once; both comparisons below reuse the collected rows.
    streamed_rows = list(res)

    assert [row.val for row in streamed_rows] == data
    assert [row.digest for row in streamed_rows] == row_digests


def test_table_query_invalid_digest(client: WeaveClient):
res = client.server.table_query(
tsi.TableQueryReq(
Expand Down
31 changes: 20 additions & 11 deletions weave/trace_server/clickhouse_trace_server_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,9 @@ def __init__(

@classmethod
def from_env(cls, use_async_insert: bool = False) -> "ClickHouseTraceServer":
return cls(
# Explicitly calling `ClickHouseTraceServer` constructor here to ensure
# that type checking is applied to the constructor.
return ClickHouseTraceServer(
host=wf_env.wf_clickhouse_host(),
port=wf_env.wf_clickhouse_port(),
user=wf_env.wf_clickhouse_user(),
Expand Down Expand Up @@ -770,6 +772,12 @@ def add_new_row_needed_to_insert(row_data: Any) -> str:
return tsi.TableUpdateRes(digest=digest, updated_row_digests=updated_digests)

def table_query(self, req: tsi.TableQueryReq) -> tsi.TableQueryRes:
    """Return every row matching `req` in a single `TableQueryRes`.

    This is the non-streaming entry point: it drains `table_query_stream`
    for the same request into a list and wraps that list in the response.
    """
    collected = [*self.table_query_stream(req)]
    return tsi.TableQueryRes(rows=collected)

def table_query_stream(
self, req: tsi.TableQueryReq
) -> Iterator[tsi.TableRowSchema]:
conds = []
pb = ParamBuilder()
if req.filter:
Expand All @@ -790,7 +798,8 @@ def table_query(self, req: tsi.TableQueryReq) -> tsi.TableQueryRes:
direction="ASC" if sort.direction.lower() == "asc" else "DESC",
)
sort_fields.append(field)
rows = self._table_query(

rows = self._table_query_stream(
req.project_id,
req.digest,
pb,
Expand All @@ -799,9 +808,10 @@ def table_query(self, req: tsi.TableQueryReq) -> tsi.TableQueryRes:
limit=req.limit,
offset=req.offset,
)
return tsi.TableQueryRes(rows=rows)
for row in rows:
yield row

def _table_query(
def _table_query_stream(
self,
project_id: str,
digest: str,
Expand All @@ -813,7 +823,7 @@ def _table_query(
sort_fields: Optional[list[OrderField]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> list[tsi.TableRowSchema]:
) -> Iterator[tsi.TableRowSchema]:
if not sort_fields:
sort_fields = [
OrderField(
Expand Down Expand Up @@ -850,12 +860,10 @@ def _table_query(
offset=offset,
)

query_result = self.ch_client.query(query, parameters=pb.get_params())
res = self._query_stream(query, parameters=pb.get_params())

return [
tsi.TableRowSchema(digest=r[0], val=json.loads(r[1]))
for r in query_result.result_rows
]
for row in res:
yield tsi.TableRowSchema(digest=row[0], val=json.loads(row[1]))

def table_query_stats(self, req: tsi.TableQueryStatsReq) -> tsi.TableQueryStatsRes:
parameters: Dict[str, Any] = {
Expand Down Expand Up @@ -1124,14 +1132,15 @@ def resolve_extra(extra: list[str], val: Any) -> PartialRefResult:
raise ValueError("Will not resolve cross-project refs.")
pb = ParamBuilder()
row_digests_name = pb.add_param(row_digests)
rows = self._table_query(
rows_stream = self._table_query_stream(
project_id=project_id_scope,
digest=digest,
pb=pb,
sql_safe_conditions=[
f"digest IN {{{row_digests_name}: Array(String)}}"
],
)
rows = list(rows_stream)
# Unpack the results into the target rows
row_digest_vals = {r.digest: r.val for r in rows}
for index, row_digest in index_digests:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,14 @@ def table_query(self, req: tsi.TableQueryReq) -> tsi.TableQueryRes:
req.project_id = self._idc.ext_to_int_project_id(req.project_id)
return self._ref_apply(self._internal_trace_server.table_query, req)

def table_query_stream(
    self, req: tsi.TableQueryReq
) -> Iterator[tsi.TableRowSchema]:
    """Stream table rows from the internal trace server for an external request.

    The request's `project_id` is rewritten in place to the value returned by
    `ext_to_int_project_id`, then the internal server's `table_query_stream`
    is invoked through `_stream_ref_apply`, whose result is returned as is.
    """
    internal_project_id = self._idc.ext_to_int_project_id(req.project_id)
    req.project_id = internal_project_id
    return self._stream_ref_apply(
        self._internal_trace_server.table_query_stream,
        req,
    )

def table_query_stats(self, req: tsi.TableQueryStatsReq) -> tsi.TableQueryStatsRes:
    """Forward a table-stats request to the internal trace server.

    The request's `project_id` is rewritten in place to the value returned by
    `ext_to_int_project_id` before the internal server's `table_query_stats`
    is invoked through `_ref_apply`.
    """
    internal_project_id = self._idc.ext_to_int_project_id(req.project_id)
    req.project_id = internal_project_id
    return self._ref_apply(
        self._internal_trace_server.table_query_stats,
        req,
    )
Expand Down
7 changes: 7 additions & 0 deletions weave/trace_server/sqlite_trace_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1170,6 +1170,13 @@ def _select_objs_query(
)
return result

def table_query_stream(
    self, req: tsi.TableQueryReq
) -> Iterator[tsi.TableRowSchema]:
    """Yield the rows of `table_query(req)` one at a time.

    The query itself is not streamed here: `table_query` is called once (on
    first iteration, since this is a generator) and the rows of its response
    are then yielded in order.
    """
    yield from self.table_query(req).rows


def get_type(val: Any) -> str:
if val == None:
Expand Down
1 change: 1 addition & 0 deletions weave/trace_server/trace_server_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,7 @@ def objs_query(self, req: ObjQueryReq) -> ObjQueryRes: ...
def table_create(self, req: TableCreateReq) -> TableCreateRes: ...
def table_update(self, req: TableUpdateReq) -> TableUpdateRes: ...
def table_query(self, req: TableQueryReq) -> TableQueryRes: ...
# Streaming variant of `table_query`: accepts the same TableQueryReq but
# returns an iterator of TableRowSchema items instead of a TableQueryRes.
def table_query_stream(self, req: TableQueryReq) -> Iterator[TableRowSchema]: ...
def table_query_stats(self, req: TableQueryStatsReq) -> TableQueryStatsRes: ...
def refs_read_batch(self, req: RefsReadBatchReq) -> RefsReadBatchRes: ...
def file_create(self, req: FileCreateReq) -> FileCreateRes: ...
Expand Down
12 changes: 11 additions & 1 deletion weave/trace_server_bindings/remote_http_trace_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ def ensure_project_exists(

@classmethod
def from_env(cls, should_batch: bool = False) -> "RemoteHTTPTraceServer":
    """Build a `RemoteHTTPTraceServer` pointed at `weave_trace_server_url()`.

    Args:
        should_batch: Passed through as the constructor's second positional
            argument.

    Returns:
        A new `RemoteHTTPTraceServer`. The class is named explicitly rather
        than using `cls`, so calling this on a subclass still returns the
        base type.
    """
    # Explicitly calling `RemoteHTTPTraceServer` constructor here to ensure
    # that type checking is applied to the constructor.
    # The superseded `return cls(...)` line is dropped: it returned first and
    # made this statement unreachable.
    return RemoteHTTPTraceServer(weave_trace_server_url(), should_batch)

def set_auth(self, auth: Tuple[str, str]) -> None:
self._auth = auth
Expand Down Expand Up @@ -439,6 +441,14 @@ def table_query(
"/table/query", req, tsi.TableQueryReq, tsi.TableQueryRes
)

def table_query_stream(
    self, req: tsi.TableQueryReq
) -> Iterator[tsi.TableRowSchema]:
    """Yield the rows of `table_query(req)` one at a time.

    `table_query` is called once (on first iteration, since this is a
    generator) and the rows of its response are yielded in order.
    """
    # Need to manually iterate over this until the stream endpoint is built
    # and shipped.
    yield from self.table_query(req).rows

def table_query_stats(
self, req: Union[tsi.TableQueryStatsReq, dict[str, Any]]
) -> tsi.TableQueryStatsRes:
Expand Down

0 comments on commit cd1ac69

Please sign in to comment.