Skip to content

Commit

Permalink
feat(weave_query): create ops to support project level traces
Browse files Browse the repository at this point in the history
  • Loading branch information
domphan-wandb committed Dec 4, 2024
1 parent 0899fce commit 8337219
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 63 deletions.
6 changes: 5 additions & 1 deletion weave-js/src/core/model/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -897,10 +897,14 @@ export function isObjectTypeLike(t: any): t is ObjectType | Union {
);
}

export function typedDict(propertyTypes: {[key: string]: Type}): TypedDictType {
export function typedDict(
propertyTypes: {[key: string]: Type},
notRequiredKeys?: string[]
): TypedDictType {
return {
type: 'typedDict',
propertyTypes,
notRequiredKeys,
};
}

Expand Down
70 changes: 50 additions & 20 deletions weave-js/src/core/ops/domain/project.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import * as Urls from '../../_external/util/urls';
import {hash, list, typedDict} from '../../model';
import {hash, list, maybe, typedDict, union} from '../../model';

Check warning on line 2 in weave-js/src/core/ops/domain/project.ts

View workflow job for this annotation

GitHub Actions / WeaveJS Lint and Compile

'maybe' is defined but never used
import {docType} from '../../util/docs';
import * as OpKinds from '../opKinds';
import {connectionToNodes} from './util';
import {
connectionToNodes,
traceFilterType,
traceLimitType,
traceOffsetType,
traceSortByType,
} from './util';

const makeProjectOp = OpKinds.makeTaggingStandardOp;

Expand Down Expand Up @@ -298,13 +304,38 @@ export const opProjectRunQueues = makeProjectOp({
resolver: ({project}) => project.runQueues,
});

const projectTracesArgTypes = {
...projectArgTypes,
payload: union([
'none',
typedDict(
{
filter: traceFilterType,
limit: traceLimitType,
offset: traceOffsetType,
sort_by: traceSortByType,
},
['filter', 'limit', 'offset', 'sort_by']
),
]),
};

const projectTracesArgTypesDescription = {
project: projectArgDescription,
payload: 'The payload object to the trace api',
'payload.filter': `The filter object used when querying traces`,
'payload.limit': `A number representing the limit for number of trace calls`,
'payload.offset': `A number representing the offset for the number of trace calls`,
'payload.sort_by': `An array with a dictionary with keys \`field\`(<string>) and \`direction\` ("asc"|"desc")`,
};

export const opProjectTracesType = makeProjectOp({
name: 'project-tracesType',
argTypes: projectArgTypes,
argTypes: projectTracesArgTypes,
description: `Returns the ${docType('list', {
plural: true,
})} for a ${docType('project')}`,
argDescriptions: {project: projectArgDescription},
argDescriptions: projectTracesArgTypesDescription,
returnValueDescription: `The ${docType('list', {
plural: true,
})} for a ${docType('project')}`,
Expand All @@ -315,27 +346,26 @@ export const opProjectTracesType = makeProjectOp({

export const opProjectTraces = makeProjectOp({
name: 'project-traces',
argTypes: projectArgTypes,
argTypes: projectTracesArgTypes,
description: `Returns the ${docType('list', {
plural: true,
})} for a ${docType('project')}`,
argDescriptions: {project: projectArgDescription},
})} of traces for a ${docType('project')}`,
argDescriptions: projectTracesArgTypesDescription,
returnValueDescription: `The ${docType('list', {
plural: true,
})} for a ${docType('project')}`,
returnType: inputTypes => list(typedDict({})),
resolver: ({project}) => project.traces,
// resolveOutputType: async (
// inputTypes,
// node,
// executableNode,
// client,
// stack
// ) => {
// console.log(executableNode);
// const res = await client.query(
// opProjectTracesType(executableNode.fromOp.inputs as any)
// );
// return res;
// },
resolveOutputType: async (
inputTypes,
node,
executableNode,
client,
stack
) => {
const res = await client.query(
opProjectTracesType(executableNode.fromOp.inputs as any)
);
return res;
},
});
25 changes: 25 additions & 0 deletions weave-js/src/core/ops/domain/util.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import {list, typedDict, TypedDictType, union} from '../../model';

Check warning on line 1 in weave-js/src/core/ops/domain/util.ts

View workflow job for this annotation

GitHub Actions / WeaveJS Lint and Compile

'TypedDictType' is defined but never used

/**
* `connectionToNodes` is a helper function that converts a `connection` type
* returned from gql to its list of nodes. In many, many location in our
Expand Down Expand Up @@ -41,3 +43,26 @@ export const connectionToNodes = <T>(connection: MaybeConnection<T>): T[] =>
(connection?.edges ?? [])
.map(edge => edge?.node)
.filter(node => node != null) as T[];

const traceFilterPropertyTypes = {
trace_roots_only: union(['none', 'boolean']),
op_names: union(['none', list('string')]),
input_refs: union(['none', list('string')]),
output_refs: union(['none', list('string')]),
parent_ids: union(['none', list('string')]),
trace_ids: union(['none', list('string')]),
call_ids: union(['none', list('string')]),
wb_user_ids: union(['none', list('string')]),
wb_run_ids: union(['none', list('string')]),
};

export const traceFilterType = union([
'none',
typedDict(traceFilterPropertyTypes, Object.keys(traceFilterPropertyTypes)),
]);
export const traceLimitType = union(['none', 'number']);
export const traceOffsetType = union(['none', 'number']);
export const traceSortByType = union([
'none',
list(typedDict({field: 'string', direction: 'string'})),
]);
64 changes: 48 additions & 16 deletions weave_query/weave_query/ops_domain/project_ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import typing
import asyncio

from weave_query import errors
from weave_query import weave_types as types
from weave_query import ops_arrow
Expand All @@ -18,7 +18,7 @@
gql_root_op,
make_root_op_gql_op_output_type,
)
from weave_query.wandb_trace_server_api import get_wandb_api
from weave_query.wandb_trace_server_api import get_wandb_trace_api_sync

# Section 1/6: Tag Getters
get_project_tag = make_tag_getter_op("project", wdt.ProjectType, op_name="tag-project")
Expand Down Expand Up @@ -262,14 +262,48 @@ def artifacts(
for edge in typeEdge["node"]["artifactCollections_100"]["edges"]
]

def _get_project_traces(project):
api = get_wandb_api()
res = asyncio.run(api.query_calls_stream(project.id))
print('#############', res)
def _get_project_traces(project, payload):
api = get_wandb_trace_api_sync()
project_id = f'{project["entity"]["name"]}/{project["name"]}'
filter = None
limit = None
offset = None
sort_by = None
if payload is not None:
filter = payload.get("filter")
limit = payload.get("limit")
offset = payload.get("offset")
sort_by = payload.get("sort_by")
res = []
for call in api.query_calls_stream(project_id, filter=filter, limit=limit, offset=offset, sort_by=sort_by):
res.append(call)
return res

traces_filter_property_types = {
"op_names": types.optional(types.List(types.String())),
"input_refs": types.optional(types.List(types.String())),
"output_refs": types.optional(types.List(types.String())),
"parent_ids": types.optional(types.List(types.String())),
"trace_ids": types.optional(types.List(types.String())),
"call_ids": types.optional(types.List(types.String())),
"trace_roots_only": types.optional(types.Boolean()),
"wb_user_ids": types.optional(types.List(types.String())),
"wb_run_ids": types.optional(types.List(types.String())),
}

traces_input_types = {
"project": wdt.ProjectType,
"payload": types.optional(types.TypedDict(property_types={
"filter": types.optional(types.TypedDict(property_types=traces_filter_property_types, not_required_keys=set(traces_filter_property_types.keys()))),
"limit": types.optional(types.Number()),
"offset": types.optional(types.Number()),
"sort_by": types.optional(types.List(types.TypedDict(property_types={"field": types.String(), "direction": types.String()})))
}, not_required_keys=set(['filter', 'limit', 'offset', 'sort_by'])))
}

@op(
name="project-tracesType",
input_type=traces_input_types,
output_type=types.TypeType(),
plugins=wb_gql_op_plugin(
lambda inputs, inner: """
Expand All @@ -281,13 +315,14 @@ def _get_project_traces(project):
),
hidden=True
)
def traces_type(project: wdt.Project):
ttype = types.TypeRegistry.type_of([{"test": 1}, {"test": 2}])
print('ttype:', ttype)
def traces_type(project, payload):
res = _get_project_traces(project, payload)
ttype = types.TypeRegistry.type_of(res)
return ttype

@op(
name="project-traces",
input_type=traces_input_types,
output_type=ops_arrow.ArrowWeaveListType(types.TypedDict({})),
plugins=wb_gql_op_plugin(
lambda inputs, inner: """
Expand All @@ -299,12 +334,9 @@ def traces_type(project: wdt.Project):
),
refine_output_type=traces_type
)
def traces(project: wdt.Project):
# api = get_wandb_api()
res = _get_project_traces(project)
print('#############', res)
return ops_arrow.to_arrow([{"test": 1}, {"test": 2}])
if "calls" in res[0]:
return ops_arrow.to_arrow(res[0]["calls"])
def traces(project, payload):
res = _get_project_traces(project, payload)
if res:
return ops_arrow.to_arrow(res)
else:
return ops_arrow.to_arrow([])
Loading

0 comments on commit 8337219

Please sign in to comment.