Skip to content

Commit

Permalink
chore: wip
Browse files Browse the repository at this point in the history
  • Loading branch information
domphan-wandb committed Dec 2, 2024
1 parent 8950d3b commit 0899fce
Show file tree
Hide file tree
Showing 3 changed files with 229 additions and 2 deletions.
44 changes: 43 additions & 1 deletion weave-js/src/core/ops/domain/project.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import * as Urls from '../../_external/util/urls';
import {hash, list} from '../../model';
import {hash, list, typedDict} from '../../model';
import {docType} from '../../util/docs';
import * as OpKinds from '../opKinds';
import {connectionToNodes} from './util';
Expand Down Expand Up @@ -297,3 +297,45 @@ export const opProjectRunQueues = makeProjectOp({
returnType: inputTypes => list('runQueue'),
resolver: ({project}) => project.runQueues,
});

export const opProjectTracesType = makeProjectOp({
name: 'project-tracesType',
argTypes: projectArgTypes,
description: `Returns the ${docType('list', {
plural: true,
})} for a ${docType('project')}`,
argDescriptions: {project: projectArgDescription},
returnValueDescription: `The ${docType('list', {
plural: true,
})} for a ${docType('project')}`,
returnType: inputTypes => 'type',
resolver: ({project}) => project.traces,
hidden: true,
});

export const opProjectTraces = makeProjectOp({
name: 'project-traces',
argTypes: projectArgTypes,
description: `Returns the ${docType('list', {
plural: true,
})} for a ${docType('project')}`,
argDescriptions: {project: projectArgDescription},
returnValueDescription: `The ${docType('list', {
plural: true,
})} for a ${docType('project')}`,
returnType: inputTypes => list(typedDict({})),
resolver: ({project}) => project.traces,
// resolveOutputType: async (
// inputTypes,
// node,
// executableNode,
// client,
// stack
// ) => {
// console.log(executableNode);
// const res = await client.query(
// opProjectTracesType(executableNode.fromOp.inputs as any)
// );
// return res;
// },
});
51 changes: 50 additions & 1 deletion weave_query/weave_query/ops_domain/project_ops.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import json
import typing

import asyncio
from weave_query import errors
from weave_query import weave_types as types
from weave_query import ops_arrow
from weave_query.api import op
from weave_query import input_provider
from weave_query.gql_op_plugin import wb_gql_op_plugin
Expand All @@ -17,6 +18,7 @@
gql_root_op,
make_root_op_gql_op_output_type,
)
from weave_query.wandb_trace_server_api import get_wandb_api

# Section 1/6: Tag Getters
get_project_tag = make_tag_getter_op("project", wdt.ProjectType, op_name="tag-project")
Expand Down Expand Up @@ -259,3 +261,50 @@ def artifacts(
for typeEdge in project["artifactTypes_100"]["edges"]
for edge in typeEdge["node"]["artifactCollections_100"]["edges"]
]

def _get_project_traces(project):
api = get_wandb_api()
res = asyncio.run(api.query_calls_stream(project.id))
print('#############', res)
return res

@op(
name="project-tracesType",
output_type=types.TypeType(),
plugins=wb_gql_op_plugin(
lambda inputs, inner: """
entity {
id
name
}
"""
),
hidden=True
)
def traces_type(project: wdt.Project):
ttype = types.TypeRegistry.type_of([{"test": 1}, {"test": 2}])
print('ttype:', ttype)
return ttype

@op(
name="project-traces",
output_type=ops_arrow.ArrowWeaveListType(types.TypedDict({})),
plugins=wb_gql_op_plugin(
lambda inputs, inner: """
entity {
id
name
}
"""
),
refine_output_type=traces_type
)
def traces(project: wdt.Project):
# api = get_wandb_api()
res = _get_project_traces(project)
print('#############', res)
return ops_arrow.to_arrow([{"test": 1}, {"test": 2}])
if "calls" in res[0]:
return ops_arrow.to_arrow(res[0]["calls"])
else:
return ops_arrow.to_arrow([])
136 changes: 136 additions & 0 deletions weave_query/weave_query/wandb_trace_server_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# This is an experimental client api used to make requests to the
# Weave Trace Server
import contextlib
import contextvars
import typing

import aiohttp

from weave_query import errors
from weave_query import environment as weave_env
from weave_query import wandb_client_api, engine_trace

from weave_query.context_state import WandbApiContext, _wandb_api_context

tracer = engine_trace.tracer() # type: ignore

def get_wandb_api_context() -> typing.Optional[WandbApiContext]:
return _wandb_api_context.get()

class WandbTraceApiAsync:
def __init__(self) -> None:
self.connector = aiohttp.TCPConnector(limit=50)

async def query_calls_stream(
self,
project_id: str,
filter: typing.Optional[dict] = None,
limit: typing.Optional[int] = None,
offset: typing.Optional[int] = None,
include_costs: bool = False,
include_feedback: bool = False,
**kwargs: typing.Any
) -> typing.Any:
wandb_context = get_wandb_api_context()
headers = {
"Accept": "application/jsonl",
"Content-Type": "application/json"
}
auth = None

if wandb_context is not None:
if wandb_context.headers:
headers.update(wandb_context.headers)
if wandb_context.api_key is not None:
auth = aiohttp.BasicAuth("api", wandb_context.api_key)

api_key_override = kwargs.pop("api_key", None)
if api_key_override:
auth = aiohttp.BasicAuth("api", api_key_override)

url = f"{weave_env.weave_trace_server_url()}/calls/stream_query"

payload = {
"project_id": project_id,
"include_costs": include_costs,
"include_feedback": include_feedback,
}

if filter:
payload["filter"] = filter
if limit:
payload["limit"] = limit
if offset:
payload["offset"] = offset

payload.update(kwargs)

async with aiohttp.ClientSession(
connector=self.connector,
headers=headers,
auth=auth
) as session:
async with session.post(url, json=payload) as response:
response.raise_for_status()
return await response.json()

class WandbTraceApiSync:
def query_calls_stream(
self,
project_id: str,
filter: typing.Optional[dict] = None,
limit: typing.Optional[int] = None,
offset: typing.Optional[int] = None,
include_costs: bool = False,
include_feedback: bool = False,
**kwargs: typing.Any
) -> typing.Any:
wandb_context = get_wandb_api_context()
headers = {
"Accept": "application/jsonl",
"Content-Type": "application/json"
}
auth = None

if wandb_context is not None:
if wandb_context.headers:
headers.update(wandb_context.headers)
if wandb_context.api_key is not None:
auth = ( "api", wandb_context.api_key)

api_key_override = kwargs.pop("api_key", None)
if api_key_override:
auth = ("api", api_key_override)

url = f"{weave_env.wandb_base_url()}/calls/stream_query"

payload = {
"project_id": project_id,
"include_costs": include_costs,
"include_feedback": include_feedback,
}

if filter:
payload["filter"] = filter
if limit:
payload["limit"] = limit
if offset:
payload["offset"] = offset

payload.update(kwargs)

with aiohttp.ClientSession(
connector=self.connector,
headers=headers,
auth=auth
) as session:
with session.post(url, json=payload) as response:
response.raise_for_status()
return response.json()

async def get_wandb_api() -> WandbTraceApiAsync:
return WandbTraceApiAsync()


def get_wandb_api_sync() -> WandbTraceApiSync:
return WandbTraceApiSync()

0 comments on commit 0899fce

Please sign in to comment.