Skip to content

Commit

Permalink
feat(weave_query): add ops to query weave traces from a query panel (#…
Browse files Browse the repository at this point in the history
…3269)

* feat(weave_query): create api client for trace server

* feat(weave_query): Create ops to fetch traces and traces types (#3270)
  • Loading branch information
domphan-wandb authored Dec 18, 2024
1 parent c9a5c09 commit 5fd95bd
Show file tree
Hide file tree
Showing 5 changed files with 284 additions and 2 deletions.
6 changes: 5 additions & 1 deletion weave-js/src/core/model/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -897,10 +897,14 @@ export function isObjectTypeLike(t: any): t is ObjectType | Union {
);
}

export function typedDict(propertyTypes: {[key: string]: Type}): TypedDictType {
/**
 * Builds a `typedDict` type descriptor.
 *
 * @param propertyTypes Map from property name to that property's type.
 * @param notRequiredKeys Optional list of property names, stored on the
 *   result under `notRequiredKeys`.
 * @returns The descriptor. `notRequiredKeys` is only present when the
 *   argument was supplied, so one-argument calls keep producing exactly
 *   `{type, propertyTypes}` instead of gaining an explicit `undefined`
 *   property that key enumeration and strict deep-equality would observe.
 */
export function typedDict(
  propertyTypes: {[key: string]: Type},
  notRequiredKeys?: string[]
): TypedDictType {
  return {
    type: 'typedDict',
    propertyTypes,
    // Spread an empty object when the argument is absent so no key is added.
    ...(notRequiredKeys != null ? {notRequiredKeys} : {}),
  };
}

Expand Down
78 changes: 77 additions & 1 deletion weave-js/src/core/ops/domain/project.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import * as Urls from '../../_external/util/urls';
import {hash, list} from '../../model';
import {hash, list, typedDict, union} from '../../model';
import {docType} from '../../util/docs';
import * as OpKinds from '../opKinds';
import {
traceFilterType,
traceLimitType,
traceOffsetType,
traceQueryType,
traceSortByType,
} from '../traceTypes';
import {connectionToNodes} from './util';

const makeProjectOp = OpKinds.makeTaggingStandardOp;
Expand Down Expand Up @@ -297,3 +304,72 @@ export const opProjectRunQueues = makeProjectOp({
returnType: inputTypes => list('runQueue'),
resolver: ({project}) => project.runQueues,
});

// Properties accepted inside the `payload` argument of the project trace ops.
const projectTracesPayloadPropertyTypes = {
  filter: traceFilterType,
  limit: traceLimitType,
  offset: traceOffsetType,
  sort_by: traceSortByType,
  query: traceQueryType,
};

// Argument types for the project trace ops: the regular project arguments
// plus a `payload` that is either 'none' or a typedDict in which every
// property is listed as not required.
const projectTracesArgTypes = {
  ...projectArgTypes,
  payload: union([
    'none',
    typedDict(
      projectTracesPayloadPropertyTypes,
      Object.keys(projectTracesPayloadPropertyTypes)
    ),
  ]),
};

// Human-readable descriptions for the arguments of the project trace ops,
// keyed by argument name (dotted keys describe fields inside `payload`).
const projectTracesArgTypesDescription = {
  project: projectArgDescription,
  payload: 'The payload object to the trace api',
  'payload.filter': 'The filter object used when querying traces',
  'payload.limit': 'A number representing the limit for number of trace calls',
  'payload.offset':
    'A number representing the offset for the number of trace calls',
  'payload.sort_by':
    'An array with a dictionary with keys `field`(<string>) and `direction` ("asc"|"desc")',
  'payload.query':
    'A dictionary to query data inspired by mongodb aggregation operators',
};

// Hidden op `project-tracesType`: takes the same arguments as the project
// traces op and is declared to return a 'type'. `opProjectTraces` queries it
// from its `resolveOutputType` hook.
export const opProjectTracesType = makeProjectOp({
  name: 'project-tracesType',
  argTypes: projectTracesArgTypes,
  description:
    'Returns the ' +
    docType('list', {plural: true}) +
    ' for a ' +
    docType('project'),
  argDescriptions: projectTracesArgTypesDescription,
  returnValueDescription:
    'The ' + docType('list', {plural: true}) + ' for a ' + docType('project'),
  returnType: _inputTypes => 'type',
  resolver: inputs => inputs.project.traces,
  hidden: true,
});

// Hidden op `project-traces`: statically declared as a list of empty
// typedDicts; the concrete output type is obtained at resolve time by
// querying `opProjectTracesType` with the same inputs.
export const opProjectTraces = makeProjectOp({
  name: 'project-traces',
  argTypes: projectTracesArgTypes,
  description:
    'Returns the ' +
    docType('list', {plural: true}) +
    ' of traces for a ' +
    docType('project'),
  argDescriptions: projectTracesArgTypesDescription,
  returnValueDescription:
    'The ' + docType('list', {plural: true}) + ' for a ' + docType('project'),
  returnType: _inputTypes => list(typedDict({})),
  resolver: inputs => inputs.project.traces,
  hidden: true,
  // Only `executableNode` and `client` are used; the remaining parameters
  // are kept so the callback keeps its original arity.
  resolveOutputType: async (
    _inputTypes,
    _node,
    executableNode,
    client,
    _stack
  ) => {
    return await client.query(
      opProjectTracesType(executableNode.fromOp.inputs as any)
    );
  },
});
25 changes: 25 additions & 0 deletions weave-js/src/core/ops/traceTypes.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import {dict, list, typedDict, union} from '../model';

// Builds a fresh "'none' or list of strings" type for each filter field.
const optionalStringList = () => union(['none', list('string')]);

// Property types of the trace filter object.
const traceFilterPropertyTypes = {
  trace_roots_only: union(['none', 'boolean']),
  op_names: optionalStringList(),
  input_refs: optionalStringList(),
  output_refs: optionalStringList(),
  parent_ids: optionalStringList(),
  trace_ids: optionalStringList(),
  call_ids: optionalStringList(),
  wb_user_ids: optionalStringList(),
  wb_run_ids: optionalStringList(),
};

// Either 'none' or a typedDict of the filter properties, all of which are
// listed as not required.
export const traceFilterType = union([
  'none',
  typedDict(traceFilterPropertyTypes, Object.keys(traceFilterPropertyTypes)),
]);

// 'none' or a number.
export const traceLimitType = union(['none', 'number']);

// 'none' or a number.
export const traceOffsetType = union(['none', 'number']);

// 'none' or a list of {field, direction} dictionaries (both strings).
export const traceSortByType = union([
  'none',
  list(typedDict({field: 'string', direction: 'string'})),
]);

// 'none' or a dictionary with values of any type.
export const traceQueryType = union(['none', dict('any')]);
103 changes: 103 additions & 0 deletions weave_query/weave_query/ops_domain/project_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from weave_query import errors
from weave_query import weave_types as types
from weave_query import ops_arrow
from weave_query.wandb_trace_server_api import get_wandb_trace_api
from weave_query.api import op
from weave_query import input_provider
from weave_query.gql_op_plugin import wb_gql_op_plugin
Expand Down Expand Up @@ -259,3 +261,104 @@ def artifacts(
for typeEdge in project["artifactTypes_100"]["edges"]
for edge in typeEdge["node"]["artifactCollections_100"]["edges"]
]

def _get_project_traces(project, payload):
    """Query the trace API for the calls of ``project``.

    ``project`` is indexed as ``project["entity"]["name"]`` and
    ``project["name"]`` to build the ``"<entity>/<name>"`` project id.
    ``payload`` may be ``None``; otherwise its ``filter``, ``limit``,
    ``offset``, ``sort_by`` and ``query`` entries are read with ``.get`` and
    forwarded (a missing entry is forwarded as ``None``).

    Returns whatever ``query_calls_stream`` returns.
    """
    project_id = f'{project["entity"]["name"]}/{project["name"]}'
    # A missing payload behaves like one without keys: every option is None.
    options = payload if payload is not None else {}
    call_kwargs = {
        key: options.get(key)
        for key in ("filter", "limit", "offset", "sort_by", "query")
    }
    return get_wandb_trace_api().query_calls_stream(project_id, **call_kwargs)

def _optional_string_list():
    """Build a fresh ``optional(List(String))`` type for one filter field."""
    return types.optional(types.List(types.String()))


# Property types of the ``filter`` entry of the traces payload.
traces_filter_property_types = {
    "op_names": _optional_string_list(),
    "input_refs": _optional_string_list(),
    "output_refs": _optional_string_list(),
    "parent_ids": _optional_string_list(),
    "trace_ids": _optional_string_list(),
    "call_ids": _optional_string_list(),
    "trace_roots_only": types.optional(types.Boolean()),
    "wb_user_ids": _optional_string_list(),
    "wb_run_ids": _optional_string_list(),
}

# Property types of the ``payload`` argument; every key is optional and is
# also listed in ``not_required_keys`` below.
_traces_payload_property_types = {
    "filter": types.optional(
        types.TypedDict(
            property_types=traces_filter_property_types,
            not_required_keys=set(traces_filter_property_types),
        )
    ),
    "limit": types.optional(types.Number()),
    "offset": types.optional(types.Number()),
    "sort_by": types.optional(
        types.List(
            types.TypedDict(
                property_types={
                    "field": types.String(),
                    "direction": types.String(),
                }
            )
        )
    ),
    "query": types.optional(types.Dict()),
}

# Input types shared by the ``project-tracesType`` and ``project-traces`` ops.
traces_input_types = {
    "project": wdt.ProjectType,
    "payload": types.optional(
        types.TypedDict(
            property_types=_traces_payload_property_types,
            not_required_keys=set(_traces_payload_property_types),
        )
    ),
}

# Declared shape of one trace call; used as the element type of the
# ``project-traces`` op's declared output.
traces_output_type = types.TypedDict(
    property_types={
        "id": types.String(),
        "project_id": types.String(),
        "op_name": types.String(),
        "display_name": types.optional(types.String()),
        "trace_id": types.String(),
        "parent_id": types.optional(types.String()),
        "started_at": types.Timestamp(),
        "attributes": types.Dict(types.String(), types.Any()),
        "inputs": types.Dict(types.String(), types.Any()),
        "ended_at": types.optional(types.Timestamp()),
        "exception": types.optional(types.String()),
        "output": types.optional(types.Any()),
        "summary": types.optional(types.Any()),
        "wb_user_id": types.optional(types.String()),
        "wb_run_id": types.optional(types.String()),
        "deleted_at": types.optional(types.Timestamp()),
    }
)

@op(
    name="project-tracesType",
    input_type=traces_input_types,
    output_type=types.TypeType(),
    # NOTE(review): presumably the GraphQL selection this op needs on the
    # project; the string content is kept exactly as it appears in the source.
    plugins=wb_gql_op_plugin(
        lambda inputs, inner: """
entity {
id
name
}
"""
    ),
    hidden=True,
)
def traces_type(project, payload):
    """Return the type of the trace calls fetched for ``project``.

    The calls are fetched with ``_get_project_traces``; a falsy result is
    typed as an empty list.
    """
    calls = _get_project_traces(project, payload)
    return types.TypeRegistry.type_of(calls if calls else [])

@op(
    name="project-traces",
    input_type=traces_input_types,
    output_type=ops_arrow.ArrowWeaveListType(traces_output_type),
    # NOTE(review): presumably the GraphQL selection this op needs on the
    # project; the string content is kept exactly as it appears in the source.
    plugins=wb_gql_op_plugin(
        lambda inputs, inner: """
entity {
id
name
}
"""
    ),
    refine_output_type=traces_type,
    hidden=True,
)
def traces(project, payload):
    """Return the trace calls fetched for ``project``, converted with ``to_arrow``.

    The calls are fetched with ``_get_project_traces``; a falsy result is
    converted as an empty list.
    """
    calls = _get_project_traces(project, payload)
    return ops_arrow.to_arrow(calls if calls else [])
74 changes: 74 additions & 0 deletions weave_query/weave_query/wandb_trace_server_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# This is an experimental client api used to make requests to the
# Weave Trace Server
import typing
import requests
import json
from requests.auth import HTTPBasicAuth

from weave_query import environment as weave_env
from weave_query import engine_trace, server_error_handling

from weave_query.context_state import WandbApiContext, _wandb_api_context

tracer = engine_trace.tracer() # type: ignore

def _get_wandb_api_context() -> typing.Optional[WandbApiContext]:
    """Return the value currently held by ``_wandb_api_context`` (may be ``None``)."""
    current_context = _wandb_api_context.get()
    return current_context

class WandbTraceApi:
    """Experimental HTTP client for the Weave Trace Server."""

    def __init__(self) -> None:
        # Session used for every request made through this client.
        self.session = requests.Session()

    def query_calls_stream(
        self,
        project_id: str,
        filter: typing.Optional[dict] = None,
        limit: typing.Optional[int] = None,
        offset: typing.Optional[int] = None,
        sort_by: typing.Optional[list] = None,
        query: typing.Optional[dict] = None,
    ) -> typing.Any:
        """POST a query to ``<trace server url>/calls/stream_query``.

        Args:
            project_id: Sent as ``project_id`` in the JSON body.
            filter: Sent as ``filter`` when non-empty.
            limit: Sent as ``limit`` whenever it is not ``None`` (``0`` is
                forwarded rather than dropped).
            offset: Sent as ``offset`` whenever it is not ``None``.
            sort_by: Sent as ``sort_by`` when non-empty.
            query: Sent as ``query`` when non-empty.

        Returns:
            A list with one decoded JSON value per non-empty response line.

        Raises:
            The exception built by
            ``server_error_handling.WeaveInternalHttpException.from_code``
            when the response status is not 200.
        """
        with tracer.trace("query_calls_stream"):
            wandb_api_context = _get_wandb_api_context()
            headers: typing.Dict[str, str] = {"content-type": "application/json"}
            auth = None
            cookies = None
            if wandb_api_context is not None:
                # Work on a copy: the authorization header added below must not
                # be written into the context's own headers, and the context
                # may carry no headers at all.
                headers = dict(wandb_api_context.headers or {})
                cookies = wandb_api_context.cookies
                if wandb_api_context.api_key is not None:
                    auth = HTTPBasicAuth("api", wandb_api_context.api_key)
                if cookies:
                    headers["authorization"] = "Basic Og=="  # base64 encoding of ":"

            url = f"{weave_env.weave_trace_server_url()}/calls/stream_query"

            payload: typing.Dict[str, typing.Any] = {
                "project_id": project_id,
            }

            if filter:
                payload["filter"] = filter
            # Compare against None so an explicit 0 still reaches the server.
            if limit is not None:
                payload["limit"] = limit
            if offset is not None:
                payload["offset"] = offset
            if sort_by:
                payload["sort_by"] = sort_by
            if query:
                payload["query"] = query

            with self.session.post(
                url=url, headers=headers, auth=auth, cookies=cookies, json=payload
            ) as r:
                if r.status_code != 200:
                    raise server_error_handling.WeaveInternalHttpException.from_code(
                        r.status_code, "Weave Traces query failed"
                    )
                # Blank lines in the response are skipped.
                return [json.loads(line) for line in r.iter_lines() if line]

def get_wandb_trace_api() -> WandbTraceApi:
    """Create and return a new ``WandbTraceApi`` instance (one per call)."""
    trace_api = WandbTraceApi()
    return trace_api

0 comments on commit 5fd95bd

Please sign in to comment.