Skip to content

Commit

Permalink
review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
gtarpenning committed Aug 21, 2024
1 parent fd5e042 commit 56d34af
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 15 deletions.
14 changes: 8 additions & 6 deletions weave/trace_server/clickhouse_trace_server_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,10 @@ def calls_query_stream(
)

else:
ref_cache = LRUCache(max_size=1000)

batch_size = 10
batch = []

ref_cache = LRUCache(max_size=1000)
for row in raw_res:
call_dict = _ch_call_dict_to_call_schema_dict(
dict(zip(select_columns, row))
Expand All @@ -322,8 +322,6 @@ def calls_query_stream(
for call in hydrated_batch:
yield tsi.CallSchema.model_validate(call)

batch = []

# *** Dynamic Batch Size ***
# count the number of columns at each depth
depths = Counter(col.count(".") for col in req.expand_columns)
Expand All @@ -333,6 +331,7 @@ def calls_query_stream(
max_size = 1000 // max_count_at_ref_depth
# double batch size up to what refs_read_batch can handle
batch_size = min(max_size, batch_size * 2)
batch = []

hydrated_batch = self._hydrate_calls(batch, req.expand_columns, ref_cache)
for call in hydrated_batch:
Expand All @@ -342,8 +341,11 @@ def _hydrate_calls(
self,
calls: list[dict[str, typing.Any]],
expand_columns: typing.List[str],
ref_cache: typing.Dict[str, typing.Any],
ref_cache: LRUCache,
) -> list[dict[str, typing.Any]]:
if len(calls) == 0:
return calls

# TODO: Implement feedback hydration here

calls = self._expand_call_refs(calls, expand_columns, ref_cache)
Expand Down Expand Up @@ -372,7 +374,7 @@ def _expand_call_refs(
self,
calls: list[dict[str, typing.Any]],
expand_columns: typing.List[str],
ref_cache: typing.Dict[str, typing.Any],
ref_cache: LRUCache,
) -> list[dict[str, typing.Any]]:
# format expand columns by depth, iterate through each batch in order
expand_column_by_depth = defaultdict(list)
Expand Down
20 changes: 11 additions & 9 deletions weave/trace_server/trace_server_common.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
import copy
from collections import OrderedDict
from typing import Any, Dict
from typing import Any, Dict, Optional


def get_nested_key(d: Dict[str, Any], col: str) -> Any:
def get_nested_key(d: Dict[str, Any], col: str) -> Optional[Any]:
"""
Get a nested key from a dict.
Get a nested key from a dict. None if not found.
Example:
get_nested_key({"a": {"b": {"c": "d"}}}, "a.b.c") -> "d"
get_nested_key({"a": {"b": {"c": "d"}}}, "a.b") -> {"c": "d"}
get_nested_key({"a": {"b": {"c": "d"}}}, "a.b.c.e") -> None
get_nested_key({"a": {"b": {"c": "d"}}}, "foobar") -> None
"""

def _get(dictionary: Dict[str, Any], key: str) -> Any:
if isinstance(dictionary, dict):
return dictionary.get(key, {})
return None
def _get(data: Optional[Any], key: str) -> Optional[Any]:
if not data or not isinstance(data, dict):
return None
return data.get(key)

keys = col.split(".")
curr = d
curr: Optional[Any] = d
for key in keys[:-1]:
curr = _get(curr, key)
return _get(curr, keys[-1])
Expand All @@ -44,6 +46,6 @@ def __init__(self, max_size: int = 1000, *args: Any, **kwargs: Dict[str, Any]):
super().__init__(*args, **kwargs)

def __setitem__(self, key: str, value: Any) -> None:
if len(self) >= self.max_size:
if key not in self and len(self) >= self.max_size:
self.popitem(last=False)
super().__setitem__(key, value)

0 comments on commit 56d34af

Please sign in to comment.