diff --git a/weave-js/src/core/model/helpers.ts b/weave-js/src/core/model/helpers.ts index 32b560ad4a6..e0500d3b365 100644 --- a/weave-js/src/core/model/helpers.ts +++ b/weave-js/src/core/model/helpers.ts @@ -897,10 +897,14 @@ export function isObjectTypeLike(t: any): t is ObjectType | Union { ); } -export function typedDict(propertyTypes: {[key: string]: Type}): TypedDictType { +export function typedDict( + propertyTypes: {[key: string]: Type}, + notRequiredKeys?: string[] +): TypedDictType { return { type: 'typedDict', propertyTypes, + notRequiredKeys, }; } diff --git a/weave-js/src/core/ops/domain/project.ts b/weave-js/src/core/ops/domain/project.ts index ba369ac3e31..45b4351dd31 100644 --- a/weave-js/src/core/ops/domain/project.ts +++ b/weave-js/src/core/ops/domain/project.ts @@ -1,8 +1,14 @@ import * as Urls from '../../_external/util/urls'; -import {hash, list, typedDict} from '../../model'; +import {hash, list, maybe, typedDict, union} from '../../model'; import {docType} from '../../util/docs'; import * as OpKinds from '../opKinds'; -import {connectionToNodes} from './util'; +import { + connectionToNodes, + traceFilterType, + traceLimitType, + traceOffsetType, + traceSortByType, +} from './util'; const makeProjectOp = OpKinds.makeTaggingStandardOp; @@ -298,13 +304,38 @@ export const opProjectRunQueues = makeProjectOp({ resolver: ({project}) => project.runQueues, }); +const projectTracesArgTypes = { + ...projectArgTypes, + payload: union([ + 'none', + typedDict( + { + filter: traceFilterType, + limit: traceLimitType, + offset: traceOffsetType, + sort_by: traceSortByType, + }, + ['filter', 'limit', 'offset', 'sort_by'] + ), + ]), +}; + +const projectTracesArgTypesDescription = { + project: projectArgDescription, + payload: 'The payload object to the trace api', + 'payload.filter': `The filter object used when querying traces`, + 'payload.limit': `A number representing the limit for number of trace calls`, + 'payload.offset': `A number representing the offset for the number of trace calls`, + 'payload.sort_by': `An array with a dictionary with keys \`field\`() and \`direction\` ("asc"|"desc")`, +}; + export const opProjectTracesType = makeProjectOp({ name: 'project-tracesType', - argTypes: projectArgTypes, + argTypes: projectTracesArgTypes, description: `Returns the ${docType('list', { plural: true, })} for a ${docType('project')}`, - argDescriptions: {project: projectArgDescription}, + argDescriptions: projectTracesArgTypesDescription, returnValueDescription: `The ${docType('list', { plural: true, })} for a ${docType('project')}`, @@ -315,27 +346,26 @@ export const opProjectTracesType = makeProjectOp({ export const opProjectTraces = makeProjectOp({ name: 'project-traces', - argTypes: projectArgTypes, + argTypes: projectTracesArgTypes, description: `Returns the ${docType('list', { plural: true, - })} for a ${docType('project')}`, - argDescriptions: {project: projectArgDescription}, + })} of traces for a ${docType('project')}`, + argDescriptions: projectTracesArgTypesDescription, returnValueDescription: `The ${docType('list', { plural: true, })} for a ${docType('project')}`, returnType: inputTypes => list(typedDict({})), resolver: ({project}) => project.traces, - // resolveOutputType: async ( - // inputTypes, - // node, - // executableNode, - // client, - // stack - // ) => { - // console.log(executableNode); - // const res = await client.query( - // opProjectTracesType(executableNode.fromOp.inputs as any) - // ); - // return res; - // }, + resolveOutputType: async ( + inputTypes, + node, + executableNode, + client, + stack + ) => { + const res = await client.query( + opProjectTracesType(executableNode.fromOp.inputs as any) + ); + return res; + }, }); diff --git a/weave-js/src/core/ops/domain/util.ts b/weave-js/src/core/ops/domain/util.ts index df05ccfe3af..a77cbb6dd17 100644 --- a/weave-js/src/core/ops/domain/util.ts +++ b/weave-js/src/core/ops/domain/util.ts @@ -1,3 +1,5 @@ +import {list, typedDict, TypedDictType, union} from '../../model'; + /** * `connectionToNodes` is a helper function that converts a `connection` type * returned from gql to its list of nodes. In many, many location in our @@ -41,3 +43,26 @@ export const connectionToNodes = (connection: MaybeConnection): T[] => (connection?.edges ?? []) .map(edge => edge?.node) .filter(node => node != null) as T[]; + +const traceFilterPropertyTypes = { + trace_roots_only: union(['none', 'boolean']), + op_names: union(['none', list('string')]), + input_refs: union(['none', list('string')]), + output_refs: union(['none', list('string')]), + parent_ids: union(['none', list('string')]), + trace_ids: union(['none', list('string')]), + call_ids: union(['none', list('string')]), + wb_user_ids: union(['none', list('string')]), + wb_run_ids: union(['none', list('string')]), +}; + +export const traceFilterType = union([ + 'none', + typedDict(traceFilterPropertyTypes, Object.keys(traceFilterPropertyTypes)), +]); +export const traceLimitType = union(['none', 'number']); +export const traceOffsetType = union(['none', 'number']); +export const traceSortByType = union([ + 'none', + list(typedDict({field: 'string', direction: 'string'})), +]); diff --git a/weave_query/weave_query/ops_domain/project_ops.py b/weave_query/weave_query/ops_domain/project_ops.py index 20d5184b53b..c7f532a3057 100644 --- a/weave_query/weave_query/ops_domain/project_ops.py +++ b/weave_query/weave_query/ops_domain/project_ops.py @@ -1,6 +1,6 @@ import json import typing -import asyncio + from weave_query import errors from weave_query import weave_types as types from weave_query import ops_arrow @@ -18,7 +18,7 @@ gql_root_op, make_root_op_gql_op_output_type, ) -from weave_query.wandb_trace_server_api import get_wandb_api +from weave_query.wandb_trace_server_api import get_wandb_trace_api_sync # Section 1/6: Tag Getters get_project_tag = make_tag_getter_op("project", wdt.ProjectType, op_name="tag-project") @@ -262,14 +262,48 @@ def artifacts( for edge in typeEdge["node"]["artifactCollections_100"]["edges"] ] -def _get_project_traces(project): - api = get_wandb_api() - res = asyncio.run(api.query_calls_stream(project.id)) - print('#############', res) +def _get_project_traces(project, payload): + api = get_wandb_trace_api_sync() + project_id = f'{project["entity"]["name"]}/{project["name"]}' + filter = None + limit = None + offset = None + sort_by = None + if payload is not None: + filter = payload.get("filter") + limit = payload.get("limit") + offset = payload.get("offset") + sort_by = payload.get("sort_by") + res = [] + for call in api.query_calls_stream(project_id, filter=filter, limit=limit, offset=offset, sort_by=sort_by): + res.append(call) return res +traces_filter_property_types = { + "op_names": types.optional(types.List(types.String())), + "input_refs": types.optional(types.List(types.String())), + "output_refs": types.optional(types.List(types.String())), + "parent_ids": types.optional(types.List(types.String())), + "trace_ids": types.optional(types.List(types.String())), + "call_ids": types.optional(types.List(types.String())), + "trace_roots_only": types.optional(types.Boolean()), + "wb_user_ids": types.optional(types.List(types.String())), + "wb_run_ids": types.optional(types.List(types.String())), +} + +traces_input_types = { + "project": wdt.ProjectType, + "payload": types.optional(types.TypedDict(property_types={ + "filter": types.optional(types.TypedDict(property_types=traces_filter_property_types, not_required_keys=set(traces_filter_property_types.keys()))), + "limit": types.optional(types.Number()), + "offset": types.optional(types.Number()), + "sort_by": types.optional(types.List(types.TypedDict(property_types={"field": types.String(), "direction": types.String()}))) + }, not_required_keys=set(['filter', 'limit', 'offset', 'sort_by']))) +} + @op( name="project-tracesType", + input_type=traces_input_types, output_type=types.TypeType(), plugins=wb_gql_op_plugin( lambda inputs, inner: """ @@ -281,13 +315,14 @@ def _get_project_traces(project): ), hidden=True ) -def traces_type(project: wdt.Project): - ttype = types.TypeRegistry.type_of([{"test": 1}, {"test": 2}]) - print('ttype:', ttype) +def traces_type(project, payload): + res = _get_project_traces(project, payload) + ttype = types.TypeRegistry.type_of(res) return ttype @op( name="project-traces", + input_type=traces_input_types, output_type=ops_arrow.ArrowWeaveListType(types.TypedDict({})), plugins=wb_gql_op_plugin( lambda inputs, inner: """ @@ -299,12 +334,9 @@ def traces_type(project: wdt.Project): ), refine_output_type=traces_type ) -def traces(project: wdt.Project): - # api = get_wandb_api() - res = _get_project_traces(project) - print('#############', res) - return ops_arrow.to_arrow([{"test": 1}, {"test": 2}]) - if "calls" in res[0]: - return ops_arrow.to_arrow(res[0]["calls"]) +def traces(project, payload): + res = _get_project_traces(project, payload) + if res: + return ops_arrow.to_arrow(res) else: return ops_arrow.to_arrow([]) \ No newline at end of file diff --git a/weave_query/weave_query/wandb_trace_server_api.py b/weave_query/weave_query/wandb_trace_server_api.py index 3b32e9eff26..2549915bdaa 100644 --- a/weave_query/weave_query/wandb_trace_server_api.py +++ b/weave_query/weave_query/wandb_trace_server_api.py @@ -3,8 +3,10 @@ import contextlib import contextvars import typing +import json import aiohttp +import requests from weave_query import errors from weave_query import environment as weave_env @@ -17,6 +19,7 @@ def get_wandb_api_context() -> typing.Optional[WandbApiContext]: return _wandb_api_context.get() +# todo(dom): Figure out how to use this async client API within ops. class WandbTraceApiAsync: def __init__(self) -> None: self.connector = aiohttp.TCPConnector(limit=50) @@ -27,15 +30,13 @@ async def query_calls_stream( filter: typing.Optional[dict] = None, limit: typing.Optional[int] = None, offset: typing.Optional[int] = None, + sort_by: typing.Optional[list] = None, include_costs: bool = False, include_feedback: bool = False, **kwargs: typing.Any - ) -> typing.Any: + ) -> typing.AsyncIterator[dict]: wandb_context = get_wandb_api_context() - headers = { - "Accept": "application/jsonl", - "Content-Type": "application/json" - } + headers = {'content-type: application/json'} auth = None if wandb_context is not None: @@ -48,7 +49,9 @@ async def query_calls_stream( if api_key_override: auth = aiohttp.BasicAuth("api", api_key_override) - url = f"{weave_env.weave_trace_server_url()}/calls/stream_query" + # todo(dom): Add env var support instead of hardcoding + # url = f"{weave_env.weave_trace_server_url()}/calls/stream_query" + url = "http://127.0.0.1:6345/calls/stream_query" payload = { "project_id": project_id, @@ -62,17 +65,22 @@ async def query_calls_stream( payload["limit"] = limit if offset: payload["offset"] = offset + if sort_by: + payload["sort_by"] = sort_by payload.update(kwargs) async with aiohttp.ClientSession( connector=self.connector, headers=headers, - auth=auth ) as session: async with session.post(url, json=payload) as response: response.raise_for_status() - return await response.json() + async for line in response.content: + if line: + decoded_line = line.decode('utf-8').strip() + if decoded_line: + yield json.loads(decoded_line) class WandbTraceApiSync: def query_calls_stream( @@ -81,28 +89,28 @@ def query_calls_stream( filter: typing.Optional[dict] = None, limit: typing.Optional[int] = None, offset: typing.Optional[int] = None, + sort_by: typing.Optional[list] = None, include_costs: bool = False, include_feedback: bool = False, **kwargs: typing.Any ) -> typing.Any: wandb_context = get_wandb_api_context() - headers = { - "Accept": "application/jsonl", - "Content-Type": "application/json" - } - auth = None + headers = {} if wandb_context is not None: if wandb_context.headers: headers.update(wandb_context.headers) + headers["authorization"] = "Basic Og==" if wandb_context.api_key is not None: - auth = ( "api", wandb_context.api_key) + auth = ("api", wandb_context.api_key) api_key_override = kwargs.pop("api_key", None) if api_key_override: auth = ("api", api_key_override) - url = f"{weave_env.wandb_base_url()}/calls/stream_query" + # todo(dom): Add env var support instead of hardcoding + # url = f"https://trace_server.wandb.test/calls/stream_query" + url = "http://127.0.0.1:6345/calls/stream_query" payload = { "project_id": project_id, @@ -116,21 +124,27 @@ def query_calls_stream( payload["limit"] = limit if offset: payload["offset"] = offset + if sort_by: + payload["sort_by"] = sort_by payload.update(kwargs) - with aiohttp.ClientSession( - connector=self.connector, + # todo(dom): Figure out a way to specify the auth kwarg with it + # causing a 403 error when it is None (when using the authorization header) + response = requests.post( + url, + json=payload, headers=headers, - auth=auth - ) as session: - with session.post(url, json=payload) as response: - response.raise_for_status() - return response.json() - -async def get_wandb_api() -> WandbTraceApiAsync: - return WandbTraceApiAsync() + stream=True + ) + response.raise_for_status() + + for line in response.iter_lines(): + if line: + yield json.loads(line.decode('utf-8')) +async def get_wandb_trace_api() -> WandbTraceApiAsync: + return WandbTraceApiAsync() -def get_wandb_api_sync() -> WandbTraceApiSync: +def get_wandb_trace_api_sync() -> WandbTraceApiSync: return WandbTraceApiSync()