From 5fd95bdd0a6fdba5bcadfb27bdc52daf5a8ee001 Mon Sep 17 00:00:00 2001 From: domphan-wandb Date: Wed, 18 Dec 2024 10:57:04 -0800 Subject: [PATCH] feat(weave_query): add ops to query weave traces from a query panel (#3269) * feat(weave_query): create api client for trace server * feat(weave_query): Create ops to fetch traces and traces types (#3270) --- weave-js/src/core/model/helpers.ts | 6 +- weave-js/src/core/ops/domain/project.ts | 78 ++++++++++++- weave-js/src/core/ops/traceTypes.ts | 25 +++++ .../weave_query/ops_domain/project_ops.py | 103 ++++++++++++++++++ .../weave_query/wandb_trace_server_api.py | 74 +++++++++++++ 5 files changed, 284 insertions(+), 2 deletions(-) create mode 100644 weave-js/src/core/ops/traceTypes.ts create mode 100644 weave_query/weave_query/wandb_trace_server_api.py diff --git a/weave-js/src/core/model/helpers.ts b/weave-js/src/core/model/helpers.ts index 32b560ad4a6..e0500d3b365 100644 --- a/weave-js/src/core/model/helpers.ts +++ b/weave-js/src/core/model/helpers.ts @@ -897,10 +897,14 @@ export function isObjectTypeLike(t: any): t is ObjectType | Union { ); } -export function typedDict(propertyTypes: {[key: string]: Type}): TypedDictType { +export function typedDict( + propertyTypes: {[key: string]: Type}, + notRequiredKeys?: string[] +): TypedDictType { return { type: 'typedDict', propertyTypes, + notRequiredKeys, }; } diff --git a/weave-js/src/core/ops/domain/project.ts b/weave-js/src/core/ops/domain/project.ts index d8bbb1cdf64..ee3d14b655f 100644 --- a/weave-js/src/core/ops/domain/project.ts +++ b/weave-js/src/core/ops/domain/project.ts @@ -1,7 +1,14 @@ import * as Urls from '../../_external/util/urls'; -import {hash, list} from '../../model'; +import {hash, list, typedDict, union} from '../../model'; import {docType} from '../../util/docs'; import * as OpKinds from '../opKinds'; +import { + traceFilterType, + traceLimitType, + traceOffsetType, + traceQueryType, + traceSortByType, +} from '../traceTypes'; import {connectionToNodes} from './util'; const makeProjectOp = OpKinds.makeTaggingStandardOp; @@ -297,3 +304,72 @@ export const opProjectRunQueues = makeProjectOp({ returnType: inputTypes => list('runQueue'), resolver: ({project}) => project.runQueues, }); + +const projectTracesArgTypes = { + ...projectArgTypes, + payload: union([ + 'none', + typedDict( + { + filter: traceFilterType, + limit: traceLimitType, + offset: traceOffsetType, + sort_by: traceSortByType, + query: traceQueryType, + }, + ['filter', 'limit', 'offset', 'sort_by', 'query'] + ), + ]), +}; + +const projectTracesArgTypesDescription = { + project: projectArgDescription, + payload: 'The payload object to the trace api', + 'payload.filter': `The filter object used when querying traces`, + 'payload.limit': `A number representing the limit for number of trace calls`, + 'payload.offset': `A number representing the offset for the number of trace calls`, + 'payload.sort_by': `An array with a dictionary with keys \`field\`() and \`direction\` ("asc"|"desc")`, + 'payload.query': `A dictionary to query data inspired by mongodb aggregation operators`, +}; + +export const opProjectTracesType = makeProjectOp({ + name: 'project-tracesType', + argTypes: projectTracesArgTypes, + description: `Returns the ${docType('list', { + plural: true, + })} for a ${docType('project')}`, + argDescriptions: projectTracesArgTypesDescription, + returnValueDescription: `The ${docType('list', { + plural: true, + })} for a ${docType('project')}`, + returnType: inputTypes => 'type', + resolver: ({project}) => project.traces, + hidden: true, +}); + +export const opProjectTraces = makeProjectOp({ + name: 'project-traces', + argTypes: projectTracesArgTypes, + description: `Returns the ${docType('list', { + plural: true, + })} of traces for a ${docType('project')}`, + argDescriptions: projectTracesArgTypesDescription, + returnValueDescription: `The ${docType('list', { + plural: true, + })} for a ${docType('project')}`, + returnType: inputTypes => list(typedDict({})), + resolver: ({project}) => project.traces, + hidden: true, + resolveOutputType: async ( + inputTypes, + node, + executableNode, + client, + stack + ) => { + const res = await client.query( + opProjectTracesType(executableNode.fromOp.inputs as any) + ); + return res; + }, +}); diff --git a/weave-js/src/core/ops/traceTypes.ts b/weave-js/src/core/ops/traceTypes.ts new file mode 100644 index 00000000000..9cfa2e75616 --- /dev/null +++ b/weave-js/src/core/ops/traceTypes.ts @@ -0,0 +1,25 @@ +import {dict, list, typedDict, union} from '../model'; + +const traceFilterPropertyTypes = { + trace_roots_only: union(['none', 'boolean']), + op_names: union(['none', list('string')]), + input_refs: union(['none', list('string')]), + output_refs: union(['none', list('string')]), + parent_ids: union(['none', list('string')]), + trace_ids: union(['none', list('string')]), + call_ids: union(['none', list('string')]), + wb_user_ids: union(['none', list('string')]), + wb_run_ids: union(['none', list('string')]), +}; + +export const traceFilterType = union([ + 'none', + typedDict(traceFilterPropertyTypes, Object.keys(traceFilterPropertyTypes)), +]); +export const traceLimitType = union(['none', 'number']); +export const traceOffsetType = union(['none', 'number']); +export const traceSortByType = union([ + 'none', + list(typedDict({field: 'string', direction: 'string'})), +]); +export const traceQueryType = union(['none', dict('any')]); diff --git a/weave_query/weave_query/ops_domain/project_ops.py b/weave_query/weave_query/ops_domain/project_ops.py index 0cfbaf05117..de4272fa8ce 100644 --- a/weave_query/weave_query/ops_domain/project_ops.py +++ b/weave_query/weave_query/ops_domain/project_ops.py @@ -3,6 +3,8 @@ from weave_query import errors from weave_query import weave_types as types +from weave_query import ops_arrow +from weave_query.wandb_trace_server_api import get_wandb_trace_api from weave_query.api import op from weave_query import input_provider from weave_query.gql_op_plugin import wb_gql_op_plugin @@ -259,3 +261,104 @@ def artifacts( for typeEdge in project["artifactTypes_100"]["edges"] for edge in typeEdge["node"]["artifactCollections_100"]["edges"] ] + +def _get_project_traces(project, payload): + project_id = f'{project["entity"]["name"]}/{project["name"]}' + filter = None + limit = None + offset = None + sort_by = None + query = None + if payload is not None: + filter = payload.get("filter") + limit = payload.get("limit") + offset = payload.get("offset") + sort_by = payload.get("sort_by") + query = payload.get("query") + trace_api = get_wandb_trace_api() + return trace_api.query_calls_stream(project_id, filter=filter, limit=limit, offset=offset, sort_by=sort_by, query=query) + +traces_filter_property_types = { + "op_names": types.optional(types.List(types.String())), + "input_refs": types.optional(types.List(types.String())), + "output_refs": types.optional(types.List(types.String())), + "parent_ids": types.optional(types.List(types.String())), + "trace_ids": types.optional(types.List(types.String())), + "call_ids": types.optional(types.List(types.String())), + "trace_roots_only": types.optional(types.Boolean()), + "wb_user_ids": types.optional(types.List(types.String())), + "wb_run_ids": types.optional(types.List(types.String())), +} + +traces_input_types = { + "project": wdt.ProjectType, + "payload": types.optional(types.TypedDict(property_types={ + "filter": types.optional(types.TypedDict(property_types=traces_filter_property_types, not_required_keys=set(traces_filter_property_types.keys()))), + "limit": types.optional(types.Number()), + "offset": types.optional(types.Number()), + "sort_by": types.optional(types.List(types.TypedDict(property_types={"field": types.String(), "direction": types.String()}))), + "query": types.optional(types.Dict()) + }, not_required_keys=set(['filter', 'limit', 'offset', 'sort_by', 'query']))) +} + +traces_output_type = types.TypedDict(property_types={ + "id": types.String(), + "project_id": types.String(), + "op_name": types.String(), + "display_name": types.optional(types.String()), + "trace_id": types.String(), + "parent_id": types.optional(types.String()), + "started_at": types.Timestamp(), + "attributes": types.Dict(types.String(), types.Any()), + "inputs": types.Dict(types.String(), types.Any()), + "ended_at": types.optional(types.Timestamp()), + "exception": types.optional(types.String()), + "output": types.optional(types.Any()), + "summary": types.optional(types.Any()), + "wb_user_id": types.optional(types.String()), + "wb_run_id": types.optional(types.String()), + "deleted_at": types.optional(types.Timestamp()) +}) + +@op( + name="project-tracesType", + input_type=traces_input_types, + output_type=types.TypeType(), + plugins=wb_gql_op_plugin( + lambda inputs, inner: """ + entity { + id + name + } + """ + ), + hidden=True +) +def traces_type(project, payload): + res = _get_project_traces(project, payload) + if res: + return types.TypeRegistry.type_of(res) + else: + return types.TypeRegistry.type_of([]) + +@op( + name="project-traces", + input_type=traces_input_types, + output_type=ops_arrow.ArrowWeaveListType(traces_output_type), + plugins=wb_gql_op_plugin( + lambda inputs, inner: """ + entity { + id + name + } + """ + ), + refine_output_type=traces_type, + hidden=True +) +def traces(project, payload): + res = _get_project_traces(project, payload) + if res: + return ops_arrow.to_arrow(res) + else: + return ops_arrow.to_arrow([]) diff --git a/weave_query/weave_query/wandb_trace_server_api.py b/weave_query/weave_query/wandb_trace_server_api.py new file mode 100644 index 00000000000..4af605a90a4 --- /dev/null +++ b/weave_query/weave_query/wandb_trace_server_api.py @@ -0,0 +1,74 @@ +# This is an experimental client api used to make requests to the +# Weave Trace Server +import typing +import requests +import json +from requests.auth import HTTPBasicAuth + +from weave_query import environment as weave_env +from weave_query import engine_trace, server_error_handling + +from weave_query.context_state import WandbApiContext, _wandb_api_context + +tracer = engine_trace.tracer() # type: ignore + +def _get_wandb_api_context() -> typing.Optional[WandbApiContext]: + return _wandb_api_context.get() + +class WandbTraceApi: + def __init__(self) -> None: + self.session = requests.Session() + + def query_calls_stream( + self, + project_id: str, + filter: typing.Optional[dict] = None, + limit: typing.Optional[int] = None, + offset: typing.Optional[int] = None, + sort_by: typing.Optional[list] = None, + query: typing.Optional[dict] = None, + ) -> typing.Any: + with tracer.trace("query_calls_stream"): + wandb_api_context = _get_wandb_api_context() + headers = {'content-type': 'application/json'} + auth = None + cookies = None + if wandb_api_context is not None: + headers = wandb_api_context.headers + cookies = wandb_api_context.cookies + if wandb_api_context.api_key is not None: + auth = HTTPBasicAuth("api", wandb_api_context.api_key) + if cookies: + headers["authorization"] = "Basic Og==" # base64 encoding of ":" + + url = f"{weave_env.weave_trace_server_url()}/calls/stream_query" + + payload = { + "project_id": project_id, + } + + if filter: + payload["filter"] = filter + if limit: + payload["limit"] = limit + if offset: + payload["offset"] = offset + if sort_by: + payload["sort_by"] = sort_by + if query: + payload["query"] = query + + with self.session.post( + url=url, headers=headers, auth=auth, cookies=cookies, json=payload + ) as r: + results = [] + if r.status_code == 200: + for line in r.iter_lines(): + if line: + results.append(json.loads(line)) + return results + else: + raise server_error_handling.WeaveInternalHttpException.from_code(r.status_code, "Weave Traces query failed") + +def get_wandb_trace_api() -> WandbTraceApi: + return WandbTraceApi() \ No newline at end of file