diff --git a/weave-js/src/core/ll.test.ts b/weave-js/src/core/ll.test.ts index 6e7bb2666a7..95fa418dcc9 100644 --- a/weave-js/src/core/ll.test.ts +++ b/weave-js/src/core/ll.test.ts @@ -340,6 +340,7 @@ describe('ll', () => { 'project-artifactTypes', 'project-artifact', 'project-artifactVersion', + 'project-traces', ]); }); diff --git a/weave-js/src/core/model/helpers.ts b/weave-js/src/core/model/helpers.ts index 32b560ad4a6..e0500d3b365 100644 --- a/weave-js/src/core/model/helpers.ts +++ b/weave-js/src/core/model/helpers.ts @@ -897,10 +897,14 @@ export function isObjectTypeLike(t: any): t is ObjectType | Union { ); } -export function typedDict(propertyTypes: {[key: string]: Type}): TypedDictType { +export function typedDict( + propertyTypes: {[key: string]: Type}, + notRequiredKeys?: string[] +): TypedDictType { return { type: 'typedDict', propertyTypes, + notRequiredKeys, }; } diff --git a/weave-js/src/core/ops/domain/project.ts b/weave-js/src/core/ops/domain/project.ts index d8bbb1cdf64..fc2532fb26d 100644 --- a/weave-js/src/core/ops/domain/project.ts +++ b/weave-js/src/core/ops/domain/project.ts @@ -1,8 +1,15 @@ import * as Urls from '../../_external/util/urls'; -import {hash, list} from '../../model'; +import {hash, list, typedDict, union} from '../../model'; import {docType} from '../../util/docs'; import * as OpKinds from '../opKinds'; -import {connectionToNodes} from './util'; +import { + connectionToNodes, + traceFilterType, + traceLimitType, + traceOffsetType, + traceQueryType, + traceSortByType, +} from './util'; const makeProjectOp = OpKinds.makeTaggingStandardOp; @@ -297,3 +304,71 @@ export const opProjectRunQueues = makeProjectOp({ returnType: inputTypes => list('runQueue'), resolver: ({project}) => project.runQueues, }); + +const projectTracesArgTypes = { + ...projectArgTypes, + payload: union([ + 'none', + typedDict( + { + filter: traceFilterType, + limit: traceLimitType, + offset: traceOffsetType, + sort_by: traceSortByType, + query: traceQueryType, + }, + ['filter', 'limit', 'offset', 'sort_by', 'query'] + ), + ]), +}; + +const projectTracesArgTypesDescription = { + project: projectArgDescription, + payload: 'The payload object to the trace api', + 'payload.filter': `The filter object used when querying traces`, + 'payload.limit': `A number representing the limit for number of trace calls`, + 'payload.offset': `A number representing the offset for the number of trace calls`, + 'payload.sort_by': `An array with a dictionary with keys \`field\`() and \`direction\` ("asc"|"desc")`, + 'payload.query': `A dictionary to query data inspired by mongodb aggregation operators`, +}; + +export const opProjectTracesType = makeProjectOp({ + name: 'project-tracesType', + argTypes: projectTracesArgTypes, + description: `Returns the ${docType('list', { + plural: true, + })} for a ${docType('project')}`, + argDescriptions: projectTracesArgTypesDescription, + returnValueDescription: `The ${docType('list', { + plural: true, + })} for a ${docType('project')}`, + returnType: inputTypes => 'type', + resolver: ({project}) => project.traces, + hidden: true, +}); + +export const opProjectTraces = makeProjectOp({ + name: 'project-traces', + argTypes: projectTracesArgTypes, + description: `Returns the ${docType('list', { + plural: true, + })} of traces for a ${docType('project')}`, + argDescriptions: projectTracesArgTypesDescription, + returnValueDescription: `The ${docType('list', { + plural: true, + })} for a ${docType('project')}`, + returnType: inputTypes => list(typedDict({})), + resolver: ({project}) => project.traces, + resolveOutputType: async ( + inputTypes, + node, + executableNode, + client, + stack + ) => { + const res = await client.query( + opProjectTracesType(executableNode.fromOp.inputs as any) + ); + return res; + }, +}); diff --git a/weave-js/src/core/ops/domain/util.ts b/weave-js/src/core/ops/domain/util.ts index df05ccfe3af..094cbde2e6c 100644 --- a/weave-js/src/core/ops/domain/util.ts +++ b/weave-js/src/core/ops/domain/util.ts @@ -1,3 +1,5 @@ +import {dict, list, typedDict, union} from '../../model'; + /** * `connectionToNodes` is a helper function that converts a `connection` type * returned from gql to its list of nodes. In many, many location in our @@ -41,3 +43,27 @@ export const connectionToNodes = (connection: MaybeConnection): T[] => (connection?.edges ?? []) .map(edge => edge?.node) .filter(node => node != null) as T[]; + +const traceFilterPropertyTypes = { + trace_roots_only: union(['none', 'boolean']), + op_names: union(['none', list('string')]), + input_refs: union(['none', list('string')]), + output_refs: union(['none', list('string')]), + parent_ids: union(['none', list('string')]), + trace_ids: union(['none', list('string')]), + call_ids: union(['none', list('string')]), + wb_user_ids: union(['none', list('string')]), + wb_run_ids: union(['none', list('string')]), +}; + +export const traceFilterType = union([ + 'none', + typedDict(traceFilterPropertyTypes, Object.keys(traceFilterPropertyTypes)), +]); +export const traceLimitType = union(['none', 'number']); +export const traceOffsetType = union(['none', 'number']); +export const traceSortByType = union([ + 'none', + list(typedDict({field: 'string', direction: 'string'})), +]); +export const traceQueryType = union(['none', dict('any')]); diff --git a/weave_query/weave_query/io_service.py b/weave_query/weave_query/io_service.py index 79b917d3634..acf67cbdf19 100644 --- a/weave_query/weave_query/io_service.py +++ b/weave_query/weave_query/io_service.py @@ -35,6 +35,7 @@ uris, wandb_api, wandb_file_manager, + wandb_trace_server_api ) tracer = engine_trace.tracer() # type: ignore @@ -217,6 +218,7 @@ def __init__( self.register_handler_fn("ensure_file", self.handle_ensure_file) self.register_handler_fn("direct_url", self.handle_direct_url) self.register_handler_fn("sleep", self.handle_sleep) + self.register_handler_fn("query_traces", self.handle_query_traces) if process: self.request_handler = aioprocessing.AioProcess( @@ -343,6 +345,7 @@ async def _request_handler_fn_main(self) -> None: self.wandb_file_manager = wandb_file_manager.WandbFileManagerAsync( fs, net, await wandb_api.get_wandb_api() ) + self.wandb_trace_server_api = wandb_trace_server_api.WandbTraceApiAsync(net) self._request_handler_ready_event.set() while True: @@ -417,6 +420,24 @@ async def handle_ensure_file(self, artifact_uri: str) -> typing.Optional[str]: ): raise errors.WeaveInternalError("invalid scheme ", uri) return await self.wandb_file_manager.ensure_file(uri) + + async def handle_query_traces( + self, + project_id: str, + filter: typing.Optional[dict] = None, + limit: typing.Optional[int] = None, + offset: typing.Optional[int] = None, + sort_by: typing.Optional[list] = None, + query: typing.Optional[dict] = None, + ) -> list[dict]: + return await self.wandb_trace_server_api.query_calls_stream( + project_id, + filter=filter, + limit=limit, + offset=offset, + sort_by=sort_by, + query=query, + ) async def handle_ensure_file_downloaded( self, download_url: str @@ -481,12 +502,24 @@ def response_task_ended(self, task: asyncio.Task) -> None: async def close(self) -> None: self.connected = False + sentinel = ServerResponse(http_error_code=200, error=False, client_id=self.client_id, id=-1, value=None) + await self.response_queue.async_put(sentinel) + await self.response_task async def handle_responses(self) -> None: - while self.connected: - resp = await self.response_queue.async_get() - self.response_queue.task_done() - self.requests[resp.id].set_result(resp) + try: + while self.connected: + resp = await self.response_queue.async_get() + + if resp.id == -1: + break + + self.response_queue.task_done() + self.requests[resp.id].set_result(resp) + finally: + for future in self.requests.values(): + if not future.done(): + future.cancel() async def request(self, name: str, *args: typing.Any) -> typing.Any: # Caller must check ServerResponse.error! @@ -569,6 +602,26 @@ async def direct_url( async def sleep(self, seconds: float) -> float: return await self.request("sleep", seconds) + + async def query_traces( + self, + project_id: str, + filter: typing.Optional[dict] = None, + limit: typing.Optional[int] = None, + offset: typing.Optional[int] = None, + sort_by: typing.Optional[list] = None, + query: typing.Optional[dict] = None, + ) -> typing.Optional[list[dict]]: + res = await self.request( + "query_traces", + project_id, + filter, + limit, + offset, + sort_by, + query, + ) + return res class AsyncClient: @@ -665,6 +718,24 @@ def direct_url( def sleep(self, seconds: float) -> None: return self.request("sleep", seconds) + def query_traces( + self, + project_id: str, + filter: typing.Optional[dict] = None, + limit: typing.Optional[int] = None, + offset: typing.Optional[int] = None, + sort_by: typing.Optional[list] = None, + query: typing.Optional[dict] = None, + ) -> typing.Optional[list[dict]]: + return self.request( + "query_traces", + project_id, + filter, + limit, + offset, + sort_by, + query, + ) class ServerlessClient: def __init__(self, fs: filesystem.Filesystem) -> None: @@ -674,6 +745,7 @@ def __init__(self, fs: filesystem.Filesystem) -> None: self.wandb_file_manager = wandb_file_manager.WandbFileManager( self.fs, self.http, self.wandb_api ) + self.wandb_trace_server_api = wandb_trace_server_api.WandbTraceApiSync(self.http) def manifest( self, @@ -705,6 +777,25 @@ def direct_url( def sleep(self, seconds: float) -> None: time.sleep(seconds) + def query_traces( + self, + project_id: str, + filter: typing.Optional[dict] = None, + limit: typing.Optional[int] = None, + offset: typing.Optional[int] = None, + sort_by: typing.Optional[list] = None, + query: typing.Optional[dict] = None, + ) -> typing.Optional[list[dict]]: + res = self.wandb_trace_server_api.query_calls_stream( + project_id, + filter=filter, + limit=limit, + offset=offset, + sort_by=sort_by, + query=query, + ) + return res + def get_sync_client() -> typing.Union[SyncClient, ServerlessClient]: if context_state.serverless_io_service(): diff --git a/weave_query/weave_query/ops_domain/project_ops.py b/weave_query/weave_query/ops_domain/project_ops.py index 0cfbaf05117..63c5a3ef29f 100644 --- a/weave_query/weave_query/ops_domain/project_ops.py +++ b/weave_query/weave_query/ops_domain/project_ops.py @@ -1,8 +1,11 @@ +import asyncio import json import typing from weave_query import errors from weave_query import weave_types as types +from weave_query import ops_arrow +from weave_query import io_service from weave_query.api import op from weave_query import input_provider from weave_query.gql_op_plugin import wb_gql_op_plugin @@ -259,3 +262,97 @@ def artifacts( for typeEdge in project["artifactTypes_100"]["edges"] for edge in typeEdge["node"]["artifactCollections_100"]["edges"] ] + +async def _get_project_traces(project, payload): + client = io_service.get_async_client() + project_id = f'{project["entity"]["name"]}/{project["name"]}' + filter = None + limit = None + offset = None + sort_by = None + query = None + if payload is not None: + filter = payload.get("filter") + limit = payload.get("limit") + offset = payload.get("offset") + sort_by = payload.get("sort_by") + query = payload.get("query") + + loop = asyncio.get_running_loop() + tasks = set() + async with client.connect() as conn: + task = loop.create_task(conn.query_traces( + project_id, + filter=filter, + limit=limit, + offset=offset, + sort_by=sort_by, + query=query)) + tasks.add(task) + await asyncio.wait(tasks) + return task.result() + +traces_filter_property_types = { + "op_names": types.optional(types.List(types.String())), + "input_refs": types.optional(types.List(types.String())), + "output_refs": types.optional(types.List(types.String())), + "parent_ids": types.optional(types.List(types.String())), + "trace_ids": types.optional(types.List(types.String())), + "call_ids": types.optional(types.List(types.String())), + "trace_roots_only": types.optional(types.Boolean()), + "wb_user_ids": types.optional(types.List(types.String())), + "wb_run_ids": types.optional(types.List(types.String())), +} + +traces_input_types = { + "project": wdt.ProjectType, + "payload": types.optional(types.TypedDict(property_types={ + "filter": types.optional(types.TypedDict(property_types=traces_filter_property_types, not_required_keys=set(traces_filter_property_types.keys()))), + "limit": types.optional(types.Number()), + "offset": types.optional(types.Number()), + "sort_by": types.optional(types.List(types.TypedDict(property_types={"field": types.String(), "direction": types.String()}))), + "query": types.optional(types.Dict()) + }, not_required_keys=set(['filter', 'limit', 'offset', 'sort_by', 'query']))) +} + +@op( + name="project-tracesType", + input_type=traces_input_types, + output_type=types.TypeType(), + plugins=wb_gql_op_plugin( + lambda inputs, inner: """ + entity { + id + name + } + """ + ), + hidden=True +) +def traces_type(project, payload): + res = asyncio.run(_get_project_traces(project, payload)) + if res: + return types.TypeRegistry.type_of(res) + else: + return types.TypeRegistry.type_of([]) + +@op( + name="project-traces", + input_type=traces_input_types, + output_type=ops_arrow.ArrowWeaveListType(types.TypedDict({})), + plugins=wb_gql_op_plugin( + lambda inputs, inner: """ + entity { + id + name + } + """ + ), + refine_output_type=traces_type +) +def traces(project, payload): + res = asyncio.run(_get_project_traces(project, payload)) + if res: + return ops_arrow.to_arrow(res) + else: + return ops_arrow.to_arrow([]) \ No newline at end of file diff --git a/weave_query/weave_query/wandb_trace_server_api.py b/weave_query/weave_query/wandb_trace_server_api.py new file mode 100644 index 00000000000..0f6187353ab --- /dev/null +++ b/weave_query/weave_query/wandb_trace_server_api.py @@ -0,0 +1,103 @@ +# This is an experimental client api used to make requests to the +# Weave Trace Server +import typing + +import aiohttp + +from weave_query import environment as weave_env +from weave_query import engine_trace, weave_http + +from weave_query.context_state import WandbApiContext, _wandb_api_context + +tracer = engine_trace.tracer() # type: ignore + +def get_wandb_api_context() -> typing.Optional[WandbApiContext]: + return _wandb_api_context.get() + +class WandbTraceApiAsync: + def __init__(self, http: weave_http.HttpAsync) -> None: + self.http = http + + async def query_calls_stream( + self, + project_id: str, + filter: typing.Optional[dict] = None, + limit: typing.Optional[int] = None, + offset: typing.Optional[int] = None, + sort_by: typing.Optional[list] = None, + query: typing.Optional[dict] = None, + ) -> typing.List[dict]: + wandb_api_context = get_wandb_api_context() + headers = {'content-type': 'application/json'} + auth = None + cookies = None + if wandb_api_context is not None: + headers = wandb_api_context.headers + cookies = wandb_api_context.cookies + if wandb_api_context.api_key is not None: + auth = aiohttp.BasicAuth("api", wandb_api_context.api_key) + if cookies: + headers["authorization"] = "Basic Og==" # base64 encoding of ":" + + url = f"{weave_env.weave_trace_server_url()}/calls/stream_query" + + payload = { + "project_id": project_id, + } + + if filter: + payload["filter"] = filter + if limit: + payload["limit"] = limit + if offset: + payload["offset"] = offset + if sort_by: + payload["sort_by"] = sort_by + if query: + payload["query"] = query + + return await self.http.query_traces(url, payload, headers=headers, cookies=cookies, auth=auth) + +class WandbTraceApiSync: + def __init__(self, http: weave_http.Http) -> None: + self.http = http + def query_calls_stream( + self, + project_id: str, + filter: typing.Optional[dict] = None, + limit: typing.Optional[int] = None, + offset: typing.Optional[int] = None, + sort_by: typing.Optional[list] = None, + query: typing.Optional[dict] = None, + **kwargs: typing.Any + ) -> typing.Any: + wandb_api_context = get_wandb_api_context() + headers = {'content-type': 'application/json'} + auth = None + cookies = None + if wandb_api_context is not None: + headers = wandb_api_context.headers + cookies = wandb_api_context.cookies + if wandb_api_context.api_key is not None: + auth = aiohttp.BasicAuth("api", wandb_api_context.api_key) + if cookies: + headers["authorization"] = "Basic Og==" # base64 encoding of ":" + + url = f"{weave_env.weave_trace_server_url()}/calls/stream_query" + + payload = { + "project_id": project_id, + } + + if filter: + payload["filter"] = filter + if limit: + payload["limit"] = limit + if offset: + payload["offset"] = offset + if sort_by: + payload["sort_by"] = sort_by + if query: + payload["query"] = query + + return self.http.query_traces(url, payload, headers=headers, cookies=cookies, auth=auth) diff --git a/weave_query/weave_query/weave_http.py b/weave_query/weave_query/weave_http.py index 39bd91fbb9c..16d9c8f0a40 100644 --- a/weave_query/weave_query/weave_http.py +++ b/weave_query/weave_query/weave_http.py @@ -12,6 +12,7 @@ import requests import requests.auth import yarl +import json from weave_query import engine_trace, filesystem, server_error_handling @@ -128,6 +129,29 @@ async def download_file( r.status, "Download failed" ) + async def query_traces( + self, + url: str, + payload: typing.Optional[dict], + headers: typing.Optional[dict[str, str]] = None, + cookies: typing.Optional[dict[str, str]] = None, + auth: typing.Optional[aiohttp.BasicAuth] = None, + ) -> list[dict]: + with tracer.trace("query_traces"): + results = [] + async with self.session.post(url, json=payload, headers=headers, cookies=cookies, auth=auth) as response: + if response.status == 200: + async for line in response.content: + if line: + decoded_line = line.decode('utf-8').strip() + if decoded_line: + results.append(json.loads(decoded_line)) + return results + else: + raise server_error_handling.WeaveInternalHttpException.from_code( + response.status, + "Traces query failed", + ) class Http: def __init__(self, fs: filesystem.Filesystem) -> None: @@ -169,3 +193,25 @@ def download_file( r.status_code, "Download failed", # type: ignore ) + + def query_traces( + self, + url: str, + payload: typing.Optional[dict], + headers: typing.Optional[dict[str, str]] = None, + cookies: typing.Optional[dict[str, str]] = None, + auth: typing.Optional[aiohttp.BasicAuth] = None, + ) -> list[dict]: + with tracer.trace("query_traces"): + results = [] + with self.session.post(url, json=payload, headers=headers, cookies=cookies, auth=auth) as response: + if response.status_code == 200: + for line in response.iter_lines(): + if line: + results.append(json.loads(line)) + else: + raise server_error_handling.WeaveInternalHttpException.from_code( + response.status_code, + "Traces query failed", + ) + return results