Skip to content

Commit

Permalink
chore(weave): Add generic iterator for trace server API objects (#3177)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewtruong authored Dec 11, 2024
1 parent 2c93fa6 commit a0f1263
Showing 1 changed file with 129 additions and 78 deletions.
207 changes: 129 additions & 78 deletions weave/trace/weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from collections.abc import Iterator, Sequence
from concurrent.futures import Future
from functools import lru_cache
from typing import Any, Callable, cast
from typing import Any, Callable, Generic, Protocol, TypeVar, cast, overload

import pydantic
from requests import HTTPError
Expand Down Expand Up @@ -90,6 +90,128 @@

logger = logging.getLogger(__name__)

T = TypeVar("T")
R = TypeVar("R", covariant=True)


class FetchFunc(Protocol[T]):
def __call__(self, offset: int, limit: int) -> list[T]: ...


TransformFunc = Callable[[T], R]


class PaginatedIterator(Generic[T, R]):
"""An iterator that fetches pages of items from a server and optionally transforms them
into a more user-friendly type."""

def __init__(
self,
fetch_func: FetchFunc[T],
page_size: int = 1000,
transform_func: TransformFunc[T, R] | None = None,
) -> None:
self.fetch_func = fetch_func
self.page_size = page_size
self.transform_func = transform_func

if page_size <= 0:
raise ValueError("page_size must be greater than 0")

@lru_cache
def _fetch_page(self, index: int) -> list[T]:
return self.fetch_func(index * self.page_size, self.page_size)

@overload
def _get_one(self: PaginatedIterator[T, T], index: int) -> T: ...
@overload
def _get_one(self: PaginatedIterator[T, R], index: int) -> R: ...
def _get_one(self, index: int) -> T | R:
if index < 0:
raise IndexError("Negative indexing not supported")

page_index = index // self.page_size
page_offset = index % self.page_size

page = self._fetch_page(page_index)
if page_offset >= len(page):
raise IndexError(f"Index {index} out of range")

res = page[page_offset]
if transform := self.transform_func:
return transform(res)
return res

@overload
def _get_slice(self: PaginatedIterator[T, T], key: slice) -> Iterator[T]: ...
@overload
def _get_slice(self: PaginatedIterator[T, R], key: slice) -> Iterator[R]: ...
def _get_slice(self, key: slice) -> Iterator[T] | Iterator[R]:
if (start := key.start or 0) < 0:
raise ValueError("Negative start not supported")
if (stop := key.stop) is not None and stop < 0:
raise ValueError("Negative stop not supported")
if (step := key.step or 1) < 0:
raise ValueError("Negative step not supported")

i = start
while stop is None or i < stop:
try:
yield self._get_one(i)
except IndexError:
break
i += step

@overload
def __getitem__(self: PaginatedIterator[T, T], key: int) -> T: ...
@overload
def __getitem__(self: PaginatedIterator[T, R], key: int) -> R: ...
@overload
def __getitem__(self: PaginatedIterator[T, T], key: slice) -> list[T]: ...
@overload
def __getitem__(self: PaginatedIterator[T, R], key: slice) -> list[R]: ...
def __getitem__(self, key: slice | int) -> T | R | list[T] | list[R]:
if isinstance(key, slice):
return list(self._get_slice(key))
return self._get_one(key)

@overload
def __iter__(self: PaginatedIterator[T, T]) -> Iterator[T]: ...
@overload
def __iter__(self: PaginatedIterator[T, R]) -> Iterator[R]: ...
def __iter__(self) -> Iterator[T] | Iterator[R]:
return self._get_slice(slice(0, None, 1))


# TODO: should be Call, not WeaveObject
CallsIter = PaginatedIterator[CallSchema, WeaveObject]


def _make_calls_iterator(
server: TraceServerInterface,
project_id: str,
filter: CallsFilter,
include_costs: bool = False,
) -> CallsIter:
def fetch_func(offset: int, limit: int) -> list[CallSchema]:
response = server.calls_query(
CallsQueryReq(
project_id=project_id,
filter=filter,
offset=offset,
limit=limit,
include_costs=include_costs,
)
)
return response.calls

# TODO: Should be Call, not WeaveObject
def transform_func(call: CallSchema) -> WeaveObject:
entity, project = project_id.split("/")
return make_client_call(entity, project, call, server)

return PaginatedIterator(fetch_func, transform_func=transform_func)


class OpNameError(ValueError):
"""Raised when an op name is invalid."""
Expand Down Expand Up @@ -284,7 +406,7 @@ def children(self) -> CallsIter:
)

client = weave_client_context.require_weave_client()
return CallsIter(
return _make_calls_iterator(
client.server,
self.project_id,
CallsFilter(parent_ids=[self.id]),
Expand Down Expand Up @@ -362,80 +484,6 @@ def _apply_scorer(self, scorer_op: Op) -> None:
)


class CallsIter:
server: TraceServerInterface
filter: CallsFilter
include_costs: bool

def __init__(
self,
server: TraceServerInterface,
project_id: str,
filter: CallsFilter,
include_costs: bool = False,
) -> None:
self.server = server
self.project_id = project_id
self.filter = filter
self._page_size = 1000
self.include_costs = include_costs

# seems like this caching should be on the server, but it's here for now...
@lru_cache
def _fetch_page(self, index: int) -> list[CallSchema]:
# caching in here means that any other CallsIter objects would also
# benefit from the cache
response = self.server.calls_query(
CallsQueryReq(
project_id=self.project_id,
filter=self.filter,
offset=index * self._page_size,
limit=self._page_size,
include_costs=self.include_costs,
)
)
return response.calls

def _get_one(self, index: int) -> WeaveObject:
if index < 0:
raise IndexError("Negative indexing not supported")

page_index = index // self._page_size
page_offset = index % self._page_size

calls = self._fetch_page(page_index)
if page_offset >= len(calls):
raise IndexError(f"Index {index} out of range")

call = calls[page_offset]
entity, project = self.project_id.split("/")
return make_client_call(entity, project, call, self.server)

def _get_slice(self, key: slice) -> Iterator[WeaveObject]:
if (start := key.start or 0) < 0:
raise ValueError("Negative start not supported")
if (stop := key.stop) is not None and stop < 0:
raise ValueError("Negative stop not supported")
if (step := key.step or 1) < 0:
raise ValueError("Negative step not supported")

i = start
while stop is None or i < stop:
try:
yield self._get_one(i)
except IndexError:
break
i += step

def __getitem__(self, key: slice | int) -> WeaveObject | list[WeaveObject]:
if isinstance(key, slice):
return list(self._get_slice(key))
return self._get_one(key)

def __iter__(self) -> Iterator[WeaveObject]:
return self._get_slice(slice(0, None, 1))


def make_client_call(
entity: str, project: str, server_call: CallSchema, server: TraceServerInterface
) -> WeaveObject:
Expand Down Expand Up @@ -642,8 +690,11 @@ def get_calls(
if filter is None:
filter = CallsFilter()

return CallsIter(
self.server, self._project_id(), filter, include_costs or False
return _make_calls_iterator(
self.server,
self._project_id(),
filter,
include_costs,
)

@deprecated(new_name="get_calls")
Expand Down

0 comments on commit a0f1263

Please sign in to comment.