diff --git a/tests/trace/test_trace_server_common.py b/tests/trace/test_trace_server_common.py
index 9bc7495481f..d9170f83eee 100644
--- a/tests/trace/test_trace_server_common.py
+++ b/tests/trace/test_trace_server_common.py
@@ -1,4 +1,5 @@
from weave.trace_server.trace_server_common import (
+ DynamicBatchProcessor,
LRUCache,
get_nested_key,
set_nested_key,
@@ -54,3 +55,26 @@ def test_lru_cache():
cache["c"] = 10
assert cache["c"] == 10
assert cache["d"] == 4
+
+
+def test_dynamic_batch_processor():
+ # Initialize processor with:
+ # - initial batch size of 2
+ # - max size of 8
+ # - growth factor of 2
+ processor = DynamicBatchProcessor(initial_size=2, max_size=8, growth_factor=2)
+
+ test_data = range(15)
+
+ batches = list(processor.make_batches(iter(test_data)))
+
+ # Expected batch sizes: 2, 4, 8, 1
+ assert batches[0] == [0, 1]
+ assert batches[1] == [2, 3, 4, 5]
+ assert batches[2] == [6, 7, 8, 9, 10, 11, 12, 13]
+ assert batches[3] == [14]
+ assert len(batches) == 4
+
+ # Verify all items were processed
+ flattened = [item for batch in batches for item in batch]
+ assert flattened == list(range(15))
diff --git a/weave-js/src/components/FancyPage/FancyPageMenu.tsx b/weave-js/src/components/FancyPage/FancyPageMenu.tsx
index c5971756016..6829b70fc06 100644
--- a/weave-js/src/components/FancyPage/FancyPageMenu.tsx
+++ b/weave-js/src/components/FancyPage/FancyPageMenu.tsx
@@ -60,7 +60,6 @@ export const FancyPageMenu = ({
return null;
}
const linkProps = {
- key: menuItem.slug,
to: menuItem.isDisabled
? {}
: {
@@ -76,7 +75,7 @@ export const FancyPageMenu = ({
},
};
return (
-        <Link {...linkProps}>
+        <Link key={menuItem.slug} {...linkProps}>
{menuItem.nameTooltip || menuItem.name}
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx
index 9f845ad0cd1..18130d64341 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallsPage/CallsTable.tsx
@@ -85,7 +85,10 @@ import {TraceCallSchema} from '../wfReactInterface/traceServerClientTypes';
import {traceCallToUICallSchema} from '../wfReactInterface/tsDataModelHooks';
import {EXPANDED_REF_REF_KEY} from '../wfReactInterface/tsDataModelHooksCallRefExpansion';
import {objectVersionNiceString} from '../wfReactInterface/utilities';
-import {CallSchema} from '../wfReactInterface/wfDataModelHooksInterface';
+import {
+ CallSchema,
+ OpVersionSchema,
+} from '../wfReactInterface/wfDataModelHooksInterface';
import {CallsCharts} from './CallsCharts';
import {CallsCustomColumnMenu} from './CallsCustomColumnMenu';
import {
@@ -743,71 +746,13 @@ export const CallsTable: FC<{
calls.refetch()} />
{!hideOpSelector && (
-            <Box sx={{flex: '0 0 auto'}}>
-              <FormControl fullWidth>
-                <Autocomplete
-                  PaperComponent={paperProps => (
-                    <StyledPaper {...paperProps} />
-                  )}
- sx={{
- '& .MuiOutlinedInput-root': {
- height: '32px',
- '& fieldset': {
- borderColor: MOON_200,
- },
- '&:hover fieldset': {
- borderColor: `rgba(${TEAL_300}, 0.48)`,
- },
- },
- '& .MuiOutlinedInput-input': {
- height: '32px',
- padding: '0 14px',
- boxSizing: 'border-box',
- },
- }}
- size="small"
- // Temp disable multiple for simplicity - may want to re-enable
- // multiple
- limitTags={1}
- disabled={Object.keys(frozenFilter ?? {}).includes(
- 'opVersions'
- )}
- value={selectedOpVersionOption}
- onChange={(event, newValue) => {
- if (newValue === ALL_TRACES_OR_CALLS_REF_KEY) {
- setFilter({
- ...filter,
- opVersionRefs: [],
- });
- } else {
- setFilter({
- ...filter,
- opVersionRefs: newValue ? [newValue] : [],
- });
- }
- }}
-                  renderInput={renderParams => (
-                    <StyledTextField {...renderParams} />
-                  )}
- getOptionLabel={option => {
- return opVersionOptions[option]?.title ?? 'loading...';
- }}
- disableClearable={
- selectedOpVersionOption === ALL_TRACES_OR_CALLS_REF_KEY
- }
- groupBy={option => opVersionOptions[option]?.group}
- options={Object.keys(opVersionOptions)}
-                  popupIcon={<IconChevronDown />}
-                  clearIcon={<IconClose />}
-                />
-              </FormControl>
-            </Box>
+            <OpSelector
+              frozenFilter={frozenFilter}
+              filter={filter}
+              setFilter={setFilter}
+              selectedOpVersionOption={selectedOpVersionOption}
+              opVersionOptions={opVersionOptions}
+            />
)}
{filterModel && setFilterModel && (
+const OpSelector = ({
+  frozenFilter,
+  filter,
+  setFilter,
+  selectedOpVersionOption,
+  opVersionOptions,
+}: {
+  frozenFilter?: WFHighLevelCallFilter;
+  filter: WFHighLevelCallFilter;
+  setFilter: (filter: WFHighLevelCallFilter) => void;
+ selectedOpVersionOption: string;
+ opVersionOptions: Record<
+ string,
+ {
+ title: string;
+ ref: string;
+ group: string;
+ objectVersion?: OpVersionSchema;
+ }
+ >;
+}) => {
+ const frozenOpFilter = Object.keys(frozenFilter ?? {}).includes('opVersions');
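+  // Selecting the "all traces/calls" sentinel clears the op filter; any other
+  // selection narrows it to the single chosen op version ref.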
+ const handleChange = useCallback(
+ (event: any, newValue: string | null) => {
+ if (newValue === ALL_TRACES_OR_CALLS_REF_KEY) {
+ setFilter({
+ ...filter,
+ opVersionRefs: [],
+ });
+ } else {
+ setFilter({
+ ...filter,
+ opVersionRefs: newValue ? [newValue] : [],
+ });
+ }
+ },
+ [filter, setFilter]
+ );
+
+ return (
+    <Box sx={{flex: '0 0 auto'}}>
+      <FormControl fullWidth>
+        <Autocomplete
+          PaperComponent={paperProps => <StyledPaper {...paperProps} />}
+ sx={{
+ '& .MuiOutlinedInput-root': {
+ height: '32px',
+ '& fieldset': {
+ borderColor: MOON_200,
+ },
+ '&:hover fieldset': {
+ borderColor: `rgba(${TEAL_300}, 0.48)`,
+ },
+ },
+ '& .MuiOutlinedInput-input': {
+ height: '32px',
+ padding: '0 14px',
+ boxSizing: 'border-box',
+ },
+ }}
+ size="small"
+ limitTags={1}
+ disabled={frozenOpFilter}
+ value={selectedOpVersionOption}
+ onChange={handleChange}
+          renderInput={renderParams => (
+            <StyledTextField {...renderParams} />
+          )}
+ getOptionLabel={option => opVersionOptions[option]?.title ?? ''}
+ disableClearable={
+ selectedOpVersionOption === ALL_TRACES_OR_CALLS_REF_KEY
+ }
+ groupBy={option => opVersionOptions[option]?.group}
+ options={Object.keys(opVersionOptions)}
+          popupIcon={<IconChevronDown />}
+          clearIcon={<IconClose />}
+        />
+      </FormControl>
+    </Box>
+ );
+};
+
const ButtonDivider = () => (
);
diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py
index 943ef31b0b8..d40d7bcc2a3 100644
--- a/weave/trace_server/clickhouse_trace_server_batched.py
+++ b/weave/trace_server/clickhouse_trace_server_batched.py
@@ -95,6 +95,7 @@
validate_cost_purge_req,
)
from weave.trace_server.trace_server_common import (
+ DynamicBatchProcessor,
LRUCache,
digest_is_version_like,
empty_str_to_none,
@@ -120,6 +121,7 @@
FILE_CHUNK_SIZE = 100000
MAX_DELETE_CALLS_COUNT = 100
+INITIAL_CALLS_STREAM_BATCH_SIZE = 100
MAX_CALLS_STREAM_BATCH_SIZE = 500
@@ -343,68 +345,47 @@ def calls_query_stream(self, req: tsi.CallsQueryReq) -> Iterator[tsi.CallSchema]
)
select_columns = [c.field for c in cq.select_fields]
+ expand_columns = req.expand_columns or []
+ include_feedback = req.include_feedback or False
- if not req.expand_columns and not req.include_feedback:
- for row in raw_res:
- yield tsi.CallSchema.model_validate(
- _ch_call_dict_to_call_schema_dict(dict(zip(select_columns, row)))
- )
-
- else:
- expand_columns = req.expand_columns or []
- ref_cache = LRUCache(max_size=1000)
+ def row_to_call_schema_dict(row: tuple[Any, ...]) -> dict[str, Any]:
+ return _ch_call_dict_to_call_schema_dict(dict(zip(select_columns, row)))
- batch_size = 10
- batch = []
+ if not expand_columns and not include_feedback:
for row in raw_res:
- call_dict = _ch_call_dict_to_call_schema_dict(
- dict(zip(select_columns, row))
+ yield tsi.CallSchema.model_validate(row_to_call_schema_dict(row))
+ return
+
+ ref_cache = LRUCache(max_size=1000)
+ batch_processor = DynamicBatchProcessor(
+ initial_size=INITIAL_CALLS_STREAM_BATCH_SIZE,
+ max_size=MAX_CALLS_STREAM_BATCH_SIZE,
+ growth_factor=10,
+ )
+
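+        # Batch sizes grow from 100 to the 500 cap (growth_factor=10), so the
+        # first rows stream back quickly while later batches amortize the cost
+        # of ref expansion and feedback hydration.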
+ for batch in batch_processor.make_batches(raw_res):
+ call_dicts = [row_to_call_schema_dict(row) for row in batch]
+ if expand_columns:
+ self._expand_call_refs(
+ req.project_id, call_dicts, expand_columns, ref_cache
)
- batch.append(call_dict)
- if len(batch) >= batch_size:
- hydrated_batch = self._hydrate_calls(
- req.project_id,
- batch,
- expand_columns,
- req.include_feedback or False,
- ref_cache,
- )
- for call in hydrated_batch:
- yield tsi.CallSchema.model_validate(call)
-
- # *** Dynamic increase from 10 to 500 ***
- batch_size = min(MAX_CALLS_STREAM_BATCH_SIZE, batch_size * 10)
- batch = []
-
- hydrated_batch = self._hydrate_calls(
- req.project_id,
- batch,
- expand_columns,
- req.include_feedback or False,
- ref_cache,
- )
- for call in hydrated_batch:
+ if include_feedback:
+ self._add_feedback_to_calls(req.project_id, call_dicts)
+
+ for call in call_dicts:
yield tsi.CallSchema.model_validate(call)
- def _hydrate_calls(
- self,
- project_id: str,
- calls: list[dict[str, Any]],
- expand_columns: list[str],
- include_feedback: bool,
- ref_cache: LRUCache,
- ) -> list[dict[str, Any]]:
+ def _add_feedback_to_calls(
+ self, project_id: str, calls: list[dict[str, Any]]
+ ) -> None:
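+        """Queries feedback for the given calls and merges it in, mutating `calls`."""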
if len(calls) == 0:
- return calls
+ return
- self._expand_call_refs(project_id, calls, expand_columns, ref_cache)
- if include_feedback:
- feedback_query_req = make_feedback_query_req(project_id, calls)
+ feedback_query_req = make_feedback_query_req(project_id, calls)
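+        # Run the feedback query on a fresh client (with its own session ID) so
+        # it does not share a client with the still-open calls stream.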
+ with self.with_new_client():
feedback = self.feedback_query(feedback_query_req)
- hydrate_calls_with_feedback(calls, feedback)
-
- return calls
+ hydrate_calls_with_feedback(calls, feedback)
def _get_refs_to_resolve(
self, calls: list[dict[str, Any]], expand_columns: list[str]
@@ -436,6 +417,9 @@ def _expand_call_refs(
expand_columns: list[str],
ref_cache: LRUCache,
) -> None:
+ if len(calls) == 0:
+ return
+
# format expand columns by depth, iterate through each batch in order
expand_column_by_depth = defaultdict(list)
for col in expand_columns:
@@ -448,9 +432,10 @@ def _expand_call_refs(
if not refs_to_resolve:
continue
- vals = self._refs_read_batch_within_project(
- project_id, list(refs_to_resolve.values()), ref_cache
- )
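+            # As with feedback hydration, resolve refs on a fresh client so the
+            # lookup does not share a session with an in-flight streamed query.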
+ with self.with_new_client():
+ vals = self._refs_read_batch_within_project(
+ project_id, list(refs_to_resolve.values()), ref_cache
+ )
for ((i, col), ref), val in zip(refs_to_resolve.items(), vals):
if isinstance(val, dict) and "_ref" not in val:
val["_ref"] = ref.uri()
@@ -1521,6 +1506,7 @@ def completions_create(
# Private Methods
@property
def ch_client(self) -> CHClient:
+ """Returns and creates (if necessary) the clickhouse client"""
if not hasattr(self._thread_local, "ch_client"):
self._thread_local.ch_client = self._mint_client()
return self._thread_local.ch_client
@@ -1538,6 +1524,26 @@ def _mint_client(self) -> CHClient:
client.database = self._database
return client
+ @contextmanager
+ def with_new_client(self) -> Iterator[None]:
+ """Context manager to use a new client for operations.
+ Each call gets a fresh client with its own clickhouse session ID.
+
+ Usage:
+ ```
+ with self.with_new_client():
+ self.feedback_query(req)
+ ```
+ """
+ client = self._mint_client()
+ original_client = self.ch_client
+ self._thread_local.ch_client = client
+ try:
+ yield
+ finally:
+ self._thread_local.ch_client = original_client
+ client.close()
+
# def __del__(self) -> None:
# self.ch_client.close()
diff --git a/weave/trace_server/trace_server_common.py b/weave/trace_server/trace_server_common.py
index 0ff14d4396b..0691927bc47 100644
--- a/weave/trace_server/trace_server_common.py
+++ b/weave/trace_server/trace_server_common.py
@@ -1,6 +1,7 @@
import copy
import datetime
from collections import OrderedDict, defaultdict
+from collections.abc import Iterator
from typing import Any, Optional, cast
from weave.trace_server import refs_internal as ri
@@ -170,6 +171,33 @@ def __setitem__(self, key: str, value: Any) -> None:
super().__setitem__(key, value)
+class DynamicBatchProcessor:
+ """Helper class to handle dynamic batch processing with growing batch sizes."""
+
+ def __init__(self, initial_size: int, max_size: int, growth_factor: int):
+ self.batch_size = initial_size
+ self.max_size = max_size
+ self.growth_factor = growth_factor
+
+ def make_batches(self, iterator: Iterator[Any]) -> Iterator[list[Any]]:
+ batch = []
+
+ for item in iterator:
+ batch.append(item)
+
+ if len(batch) >= self.batch_size:
+ yield batch
+
+ batch = []
+ self.batch_size = self._compute_batch_size()
+
+ if batch:
+ yield batch
+
+ def _compute_batch_size(self) -> int:
+ return min(self.max_size, self.batch_size * self.growth_factor)
+
+
def digest_is_version_like(digest: str) -> tuple[bool, int]:
"""
Check if a digest is a version like string.