Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(weave_query): add ability to view traces on query panels #3147

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions weave-js/src/core/ll.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ describe('ll', () => {
'project-artifactTypes',
'project-artifact',
'project-artifactVersion',
'project-traces',
]);
});

Expand Down
6 changes: 5 additions & 1 deletion weave-js/src/core/model/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -897,10 +897,14 @@ export function isObjectTypeLike(t: any): t is ObjectType | Union {
);
}

export function typedDict(propertyTypes: {[key: string]: Type}): TypedDictType {
export function typedDict(
propertyTypes: {[key: string]: Type},
notRequiredKeys?: string[]
): TypedDictType {
return {
type: 'typedDict',
propertyTypes,
notRequiredKeys,
};
}

Expand Down
79 changes: 77 additions & 2 deletions weave-js/src/core/ops/domain/project.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import * as Urls from '../../_external/util/urls';
import {hash, list} from '../../model';
import {hash, list, typedDict, union} from '../../model';
import {docType} from '../../util/docs';
import * as OpKinds from '../opKinds';
import {connectionToNodes} from './util';
import {
connectionToNodes,
traceFilterType,
traceLimitType,
traceOffsetType,
traceQueryType,
traceSortByType,
} from './util';

const makeProjectOp = OpKinds.makeTaggingStandardOp;

Expand Down Expand Up @@ -297,3 +304,71 @@ export const opProjectRunQueues = makeProjectOp({
returnType: inputTypes => list('runQueue'),
resolver: ({project}) => project.runQueues,
});

const projectTracesArgTypes = {
...projectArgTypes,
payload: union([
'none',
typedDict(
{
filter: traceFilterType,
limit: traceLimitType,
offset: traceOffsetType,
sort_by: traceSortByType,
query: traceQueryType,
},
['filter', 'limit', 'offset', 'sort_by', 'query']
),
]),
};

const projectTracesArgTypesDescription = {
project: projectArgDescription,
payload: 'The payload object to the trace api',
'payload.filter': `The filter object used when querying traces`,
'payload.limit': `A number representing the limit for number of trace calls`,
'payload.offset': `A number representing the offset for the number of trace calls`,
'payload.sort_by': `An array with a dictionary with keys \`field\`(<string>) and \`direction\` ("asc"|"desc")`,
'payload.query': `A dictionary to query data inspired by mongodb aggregation operators`,
};

export const opProjectTracesType = makeProjectOp({
name: 'project-tracesType',
argTypes: projectTracesArgTypes,
description: `Returns the ${docType('list', {
plural: true,
})} for a ${docType('project')}`,
argDescriptions: projectTracesArgTypesDescription,
returnValueDescription: `The ${docType('list', {
plural: true,
})} for a ${docType('project')}`,
returnType: inputTypes => 'type',
resolver: ({project}) => project.traces,
hidden: true,
});

export const opProjectTraces = makeProjectOp({
name: 'project-traces',
argTypes: projectTracesArgTypes,
description: `Returns the ${docType('list', {
plural: true,
})} of traces for a ${docType('project')}`,
argDescriptions: projectTracesArgTypesDescription,
returnValueDescription: `The ${docType('list', {
plural: true,
})} for a ${docType('project')}`,
returnType: inputTypes => list(typedDict({})),
resolver: ({project}) => project.traces,
resolveOutputType: async (
inputTypes,
node,
executableNode,
client,
stack
) => {
const res = await client.query(
opProjectTracesType(executableNode.fromOp.inputs as any)
);
return res;
},
});
26 changes: 26 additions & 0 deletions weave-js/src/core/ops/domain/util.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import {dict, list, typedDict, union} from '../../model';

/**
* `connectionToNodes` is a helper function that converts a `connection` type
* returned from gql to its list of nodes. In many, many location in our
Expand Down Expand Up @@ -41,3 +43,27 @@ export const connectionToNodes = <T>(connection: MaybeConnection<T>): T[] =>
(connection?.edges ?? [])
.map(edge => edge?.node)
.filter(node => node != null) as T[];

const traceFilterPropertyTypes = {
trace_roots_only: union(['none', 'boolean']),
op_names: union(['none', list('string')]),
input_refs: union(['none', list('string')]),
output_refs: union(['none', list('string')]),
parent_ids: union(['none', list('string')]),
trace_ids: union(['none', list('string')]),
call_ids: union(['none', list('string')]),
wb_user_ids: union(['none', list('string')]),
wb_run_ids: union(['none', list('string')]),
};

export const traceFilterType = union([
'none',
typedDict(traceFilterPropertyTypes, Object.keys(traceFilterPropertyTypes)),
]);
export const traceLimitType = union(['none', 'number']);
export const traceOffsetType = union(['none', 'number']);
export const traceSortByType = union([
'none',
list(typedDict({field: 'string', direction: 'string'})),
]);
export const traceQueryType = union(['none', dict('any')]);
69 changes: 65 additions & 4 deletions weave_query/weave_query/io_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
uris,
wandb_api,
wandb_file_manager,
wandb_trace_server_api
)

tracer = engine_trace.tracer() # type: ignore
Expand Down Expand Up @@ -217,6 +218,7 @@ def __init__(
self.register_handler_fn("ensure_file", self.handle_ensure_file)
self.register_handler_fn("direct_url", self.handle_direct_url)
self.register_handler_fn("sleep", self.handle_sleep)
self.register_handler_fn("query_traces", self.handle_query_traces)

if process:
self.request_handler = aioprocessing.AioProcess(
Expand Down Expand Up @@ -343,6 +345,7 @@ async def _request_handler_fn_main(self) -> None:
self.wandb_file_manager = wandb_file_manager.WandbFileManagerAsync(
fs, net, await wandb_api.get_wandb_api()
)
self.wandb_trace_server_api = wandb_trace_server_api.WandbTraceApiAsync(net)

self._request_handler_ready_event.set()
while True:
Expand Down Expand Up @@ -417,6 +420,28 @@ async def handle_ensure_file(self, artifact_uri: str) -> typing.Optional[str]:
):
raise errors.WeaveInternalError("invalid scheme ", uri)
return await self.wandb_file_manager.ensure_file(uri)

async def handle_query_traces(
self,
project_id: str,
filter: typing.Optional[dict] = None,
limit: typing.Optional[int] = None,
offset: typing.Optional[int] = None,
sort_by: typing.Optional[list] = None,
query: typing.Optional[dict] = None,
include_costs: bool = False,
include_feedback: bool = False,
) -> list[dict]:
return await self.wandb_trace_server_api.query_calls_stream(
project_id,
filter,
limit,
offset,
sort_by,
query,
include_costs,
include_feedback
)

async def handle_ensure_file_downloaded(
self, download_url: str
Expand Down Expand Up @@ -481,12 +506,24 @@ def response_task_ended(self, task: asyncio.Task) -> None:

async def close(self) -> None:
self.connected = False
sentinel = ServerResponse(http_error_code=200, error=False, client_id=self.client_id, id=-1, value=None)
await self.response_queue.async_put(sentinel)
await self.response_task

async def handle_responses(self) -> None:
while self.connected:
resp = await self.response_queue.async_get()
self.response_queue.task_done()
self.requests[resp.id].set_result(resp)
try:
while self.connected:
resp = await self.response_queue.async_get()

if resp.id == -1:
break

self.response_queue.task_done()
self.requests[resp.id].set_result(resp)
finally:
for future in self.requests.values():
if not future.done():
future.cancel()

async def request(self, name: str, *args: typing.Any) -> typing.Any:
# Caller must check ServerResponse.error!
Expand Down Expand Up @@ -569,6 +606,30 @@ async def direct_url(

async def sleep(self, seconds: float) -> float:
return await self.request("sleep", seconds)

async def query_traces(
self,
project_id: str,
filter: typing.Optional[dict] = None,
limit: typing.Optional[int] = None,
offset: typing.Optional[int] = None,
sort_by: typing.Optional[list] = None,
query: typing.Optional[dict] = None,
include_costs: bool = False,
include_feedback: bool = False
):
res = await self.request(
"query_traces",
project_id,
filter,
limit,
offset,
sort_by,
query,
include_costs,
include_feedback
)
return res


class AsyncClient:
Expand Down
95 changes: 95 additions & 0 deletions weave_query/weave_query/ops_domain/project_ops.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import asyncio
import json
import typing

from weave_query import errors
from weave_query import weave_types as types
from weave_query import ops_arrow
from weave_query import io_service
from weave_query.api import op
from weave_query import input_provider
from weave_query.gql_op_plugin import wb_gql_op_plugin
Expand Down Expand Up @@ -259,3 +262,95 @@ def artifacts(
for typeEdge in project["artifactTypes_100"]["edges"]
for edge in typeEdge["node"]["artifactCollections_100"]["edges"]
]

async def _get_project_traces(project, payload):
client = io_service.get_async_client()
project_id = f'{project["entity"]["name"]}/{project["name"]}'
filter = None
limit = None
offset = None
sort_by = None
query = None
if payload is not None:
filter = payload.get("filter")
limit = payload.get("limit")
offset = payload.get("offset")
sort_by = payload.get("sort_by")
query = payload.get("query")

loop = asyncio.get_running_loop()
tasks = set()
async with client.connect() as conn:
task = loop.create_task(conn.query_traces(
project_id,
filter=filter,
limit=limit,
offset=offset,
sort_by=sort_by,
query=query))
tasks.add(task)
await asyncio.wait(tasks)
return task.result()

traces_filter_property_types = {
"op_names": types.optional(types.List(types.String())),
"input_refs": types.optional(types.List(types.String())),
"output_refs": types.optional(types.List(types.String())),
"parent_ids": types.optional(types.List(types.String())),
"trace_ids": types.optional(types.List(types.String())),
"call_ids": types.optional(types.List(types.String())),
"trace_roots_only": types.optional(types.Boolean()),
"wb_user_ids": types.optional(types.List(types.String())),
"wb_run_ids": types.optional(types.List(types.String())),
}

traces_input_types = {
"project": wdt.ProjectType,
"payload": types.optional(types.TypedDict(property_types={
"filter": types.optional(types.TypedDict(property_types=traces_filter_property_types, not_required_keys=set(traces_filter_property_types.keys()))),
"limit": types.optional(types.Number()),
"offset": types.optional(types.Number()),
"sort_by": types.optional(types.List(types.TypedDict(property_types={"field": types.String(), "direction": types.String()}))),
"query": types.optional(types.Dict())
}, not_required_keys=set(['filter', 'limit', 'offset', 'sort_by', 'query'])))
}

@op(
name="project-tracesType",
input_type=traces_input_types,
output_type=types.TypeType(),
plugins=wb_gql_op_plugin(
lambda inputs, inner: """
entity {
id
name
}
"""
),
hidden=True
)
def traces_type(project, payload):
res = asyncio.run(_get_project_traces(project, payload))
ttype = types.TypeRegistry.type_of(res)
return ttype

@op(
name="project-traces",
input_type=traces_input_types,
output_type=ops_arrow.ArrowWeaveListType(types.TypedDict({})),
plugins=wb_gql_op_plugin(
lambda inputs, inner: """
entity {
id
name
}
"""
),
refine_output_type=traces_type
)
def traces(project, payload):
res = asyncio.run(_get_project_traces(project, payload))
if res:
return ops_arrow.to_arrow(res)
else:
return ops_arrow.to_arrow([])
Loading
Loading