From 3eece12edde11bc361fd820f2008f6271e47d921 Mon Sep 17 00:00:00 2001 From: dom phan Date: Mon, 2 Dec 2024 08:13:43 -0800 Subject: [PATCH 1/7] chore: wip --- weave-js/src/core/ops/domain/project.ts | 44 +++++- .../weave_query/ops_domain/project_ops.py | 51 ++++++- .../weave_query/wandb_trace_server_api.py | 136 ++++++++++++++++++ 3 files changed, 229 insertions(+), 2 deletions(-) create mode 100644 weave_query/weave_query/wandb_trace_server_api.py diff --git a/weave-js/src/core/ops/domain/project.ts b/weave-js/src/core/ops/domain/project.ts index d8bbb1cdf64..ba369ac3e31 100644 --- a/weave-js/src/core/ops/domain/project.ts +++ b/weave-js/src/core/ops/domain/project.ts @@ -1,5 +1,5 @@ import * as Urls from '../../_external/util/urls'; -import {hash, list} from '../../model'; +import {hash, list, typedDict} from '../../model'; import {docType} from '../../util/docs'; import * as OpKinds from '../opKinds'; import {connectionToNodes} from './util'; @@ -297,3 +297,45 @@ export const opProjectRunQueues = makeProjectOp({ returnType: inputTypes => list('runQueue'), resolver: ({project}) => project.runQueues, }); + +export const opProjectTracesType = makeProjectOp({ + name: 'project-tracesType', + argTypes: projectArgTypes, + description: `Returns the ${docType('list', { + plural: true, + })} for a ${docType('project')}`, + argDescriptions: {project: projectArgDescription}, + returnValueDescription: `The ${docType('list', { + plural: true, + })} for a ${docType('project')}`, + returnType: inputTypes => 'type', + resolver: ({project}) => project.traces, + hidden: true, +}); + +export const opProjectTraces = makeProjectOp({ + name: 'project-traces', + argTypes: projectArgTypes, + description: `Returns the ${docType('list', { + plural: true, + })} for a ${docType('project')}`, + argDescriptions: {project: projectArgDescription}, + returnValueDescription: `The ${docType('list', { + plural: true, + })} for a ${docType('project')}`, + returnType: inputTypes => list(typedDict({})), + resolver: ({project}) => project.traces, + // resolveOutputType: async ( + // inputTypes, + // node, + // executableNode, + // client, + // stack + // ) => { + // console.log(executableNode); + // const res = await client.query( + // opProjectTracesType(executableNode.fromOp.inputs as any) + // ); + // return res; + // }, +}); diff --git a/weave_query/weave_query/ops_domain/project_ops.py b/weave_query/weave_query/ops_domain/project_ops.py index 0cfbaf05117..20d5184b53b 100644 --- a/weave_query/weave_query/ops_domain/project_ops.py +++ b/weave_query/weave_query/ops_domain/project_ops.py @@ -1,8 +1,9 @@ import json import typing - +import asyncio from weave_query import errors from weave_query import weave_types as types +from weave_query import ops_arrow from weave_query.api import op from weave_query import input_provider from weave_query.gql_op_plugin import wb_gql_op_plugin @@ -17,6 +18,7 @@ gql_root_op, make_root_op_gql_op_output_type, ) +from weave_query.wandb_trace_server_api import get_wandb_api # Section 1/6: Tag Getters get_project_tag = make_tag_getter_op("project", wdt.ProjectType, op_name="tag-project") @@ -259,3 +261,50 @@ def artifacts( for typeEdge in project["artifactTypes_100"]["edges"] for edge in typeEdge["node"]["artifactCollections_100"]["edges"] ] + +def _get_project_traces(project): + api = get_wandb_api() + res = asyncio.run(api.query_calls_stream(project.id)) + print('#############', res) + return res + +@op( + name="project-tracesType", + output_type=types.TypeType(), + plugins=wb_gql_op_plugin( + lambda inputs, inner: """ + entity { + id + name + } + """ + ), + hidden=True +) +def traces_type(project: wdt.Project): + ttype = types.TypeRegistry.type_of([{"test": 1}, {"test": 2}]) + print('ttype:', ttype) + return ttype + +@op( + name="project-traces", + output_type=ops_arrow.ArrowWeaveListType(types.TypedDict({})), + plugins=wb_gql_op_plugin( + lambda inputs, inner: """ + entity { + id + name + } + """ + ), + refine_output_type=traces_type +) +def traces(project: wdt.Project): + # api = get_wandb_api() + res = _get_project_traces(project) + print('#############', res) + return ops_arrow.to_arrow([{"test": 1}, {"test": 2}]) + if "calls" in res[0]: + return ops_arrow.to_arrow(res[0]["calls"]) + else: + return ops_arrow.to_arrow([]) \ No newline at end of file diff --git a/weave_query/weave_query/wandb_trace_server_api.py b/weave_query/weave_query/wandb_trace_server_api.py new file mode 100644 index 00000000000..3b32e9eff26 --- /dev/null +++ b/weave_query/weave_query/wandb_trace_server_api.py @@ -0,0 +1,136 @@ +# This is an experimental client api used to make requests to the +# Weave Trace Server +import contextlib +import contextvars +import typing + +import aiohttp + +from weave_query import errors +from weave_query import environment as weave_env +from weave_query import wandb_client_api, engine_trace + +from weave_query.context_state import WandbApiContext, _wandb_api_context + +tracer = engine_trace.tracer() # type: ignore + +def get_wandb_api_context() -> typing.Optional[WandbApiContext]: + return _wandb_api_context.get() + +class WandbTraceApiAsync: + def __init__(self) -> None: + self.connector = aiohttp.TCPConnector(limit=50) + + async def query_calls_stream( + self, + project_id: str, + filter: typing.Optional[dict] = None, + limit: typing.Optional[int] = None, + offset: typing.Optional[int] = None, + include_costs: bool = False, + include_feedback: bool = False, + **kwargs: typing.Any + ) -> typing.Any: + wandb_context = get_wandb_api_context() + headers = { + "Accept": "application/jsonl", + "Content-Type": "application/json" + } + auth = None + + if wandb_context is not None: + if wandb_context.headers: + headers.update(wandb_context.headers) + if wandb_context.api_key is not None: + auth = aiohttp.BasicAuth("api", wandb_context.api_key) + + api_key_override = kwargs.pop("api_key", None) + if api_key_override: + auth = aiohttp.BasicAuth("api", api_key_override) + + url = f"{weave_env.weave_trace_server_url()}/calls/stream_query" + + payload = { + "project_id": project_id, + "include_costs": include_costs, + "include_feedback": include_feedback, + } + + if filter: + payload["filter"] = filter + if limit: + payload["limit"] = limit + if offset: + payload["offset"] = offset + + payload.update(kwargs) + + async with aiohttp.ClientSession( + connector=self.connector, + headers=headers, + auth=auth + ) as session: + async with session.post(url, json=payload) as response: + response.raise_for_status() + return await response.json() + +class WandbTraceApiSync: + def query_calls_stream( + self, + project_id: str, + filter: typing.Optional[dict] = None, + limit: typing.Optional[int] = None, + offset: typing.Optional[int] = None, + include_costs: bool = False, + include_feedback: bool = False, + **kwargs: typing.Any + ) -> typing.Any: + wandb_context = get_wandb_api_context() + headers = { + "Accept": "application/jsonl", + "Content-Type": "application/json" + } + auth = None + + if wandb_context is not None: + if wandb_context.headers: + headers.update(wandb_context.headers) + if wandb_context.api_key is not None: + auth = ( "api", wandb_context.api_key) + + api_key_override = kwargs.pop("api_key", None) + if api_key_override: + auth = ("api", api_key_override) + + url = f"{weave_env.wandb_base_url()}/calls/stream_query" + + payload = { + "project_id": project_id, + "include_costs": include_costs, + "include_feedback": include_feedback, + } + + if filter: + payload["filter"] = filter + if limit: + payload["limit"] = limit + if offset: + payload["offset"] = offset + + payload.update(kwargs) + + with aiohttp.ClientSession( + connector=self.connector, + headers=headers, + auth=auth + ) as session: + with session.post(url, json=payload) as response: + response.raise_for_status() + return response.json() + +async def get_wandb_api() -> WandbTraceApiAsync: + return WandbTraceApiAsync() + + +def get_wandb_api_sync() -> WandbTraceApiSync: + return WandbTraceApiSync() From 7db8cb9f6ba2cef6cc4e4fc15ef2a58a2e802ded Mon Sep 17 00:00:00 2001 From: dom phan Date: Tue, 3 Dec 2024 13:26:40 -0800 Subject: [PATCH 2/7] feat(weave_query): create ops to support project level traces --- weave-js/src/core/model/helpers.ts | 6 +- weave-js/src/core/ops/domain/project.ts | 70 +++++++++++++------ weave-js/src/core/ops/domain/util.ts | 25 +++++++ .../weave_query/ops_domain/project_ops.py | 64 ++++++++++++----- .../weave_query/wandb_trace_server_api.py | 66 ++++++++++------- 5 files changed, 168 insertions(+), 63 deletions(-) diff --git a/weave-js/src/core/model/helpers.ts b/weave-js/src/core/model/helpers.ts index 32b560ad4a6..e0500d3b365 100644 --- a/weave-js/src/core/model/helpers.ts +++ b/weave-js/src/core/model/helpers.ts @@ -897,10 +897,14 @@ export function isObjectTypeLike(t: any): t is ObjectType | Union { ); } -export function typedDict(propertyTypes: {[key: string]: Type}): TypedDictType { +export function typedDict( + propertyTypes: {[key: string]: Type}, + notRequiredKeys?: string[] +): TypedDictType { return { type: 'typedDict', propertyTypes, + notRequiredKeys, }; } diff --git a/weave-js/src/core/ops/domain/project.ts b/weave-js/src/core/ops/domain/project.ts index ba369ac3e31..45b4351dd31 100644 --- a/weave-js/src/core/ops/domain/project.ts +++ b/weave-js/src/core/ops/domain/project.ts @@ -1,8 +1,14 @@ import * as Urls from '../../_external/util/urls'; -import {hash, list, typedDict} from '../../model'; +import {hash, list, maybe, typedDict, union} from '../../model'; import {docType} from '../../util/docs'; import * as OpKinds from '../opKinds'; -import {connectionToNodes} from './util'; +import { + connectionToNodes, + traceFilterType, + traceLimitType, + traceOffsetType, + traceSortByType, +} from './util'; const makeProjectOp = OpKinds.makeTaggingStandardOp; @@ -298,13 +304,38 @@ export const opProjectRunQueues = makeProjectOp({ resolver: ({project}) => project.runQueues, }); +const projectTracesArgTypes = { + ...projectArgTypes, + payload: union([ + 'none', + typedDict( + { + filter: traceFilterType, + limit: traceLimitType, + offset: traceOffsetType, + sort_by: traceSortByType, + }, + ['filter', 'limit', 'offset', 'sort_by'] + ), + ]), +}; + +const projectTracesArgTypesDescription = { + project: projectArgDescription, + payload: 'The payload object to the trace api', + 'payload.filter': `The filter object used when querying traces`, + 'payload.limit': `A number representing the limit for number of trace calls`, + 'payload.offset': `A number representing the offset for the number of trace calls`, + 'payload.sort_by': `An array with a dictionary with keys \`field\`() and \`direction\` ("asc"|"desc")`, +}; + export const opProjectTracesType = makeProjectOp({ name: 'project-tracesType', - argTypes: projectArgTypes, + argTypes: projectTracesArgTypes, description: `Returns the ${docType('list', { plural: true, })} for a ${docType('project')}`, - argDescriptions: {project: projectArgDescription}, + argDescriptions: projectTracesArgTypesDescription, returnValueDescription: `The ${docType('list', { plural: true, })} for a ${docType('project')}`, @@ -315,27 +346,26 @@ export const opProjectTracesType = makeProjectOp({ export const opProjectTraces = makeProjectOp({ name: 'project-traces', - argTypes: projectArgTypes, + argTypes: projectTracesArgTypes, description: `Returns the ${docType('list', { plural: true, - })} for a ${docType('project')}`, - argDescriptions: {project: projectArgDescription}, + })} of traces for a ${docType('project')}`, + argDescriptions: projectTracesArgTypesDescription, returnValueDescription: `The ${docType('list', { plural: true, })} for a ${docType('project')}`, returnType: inputTypes => list(typedDict({})), resolver: ({project}) => project.traces, - // resolveOutputType: async ( - // inputTypes, - // node, - // executableNode, - // client, - // stack - // ) => { - // console.log(executableNode); - // const res = await client.query( - // opProjectTracesType(executableNode.fromOp.inputs as any) - // ); - // return res; - // }, + resolveOutputType: async ( + inputTypes, + node, + executableNode, + client, + stack + ) => { + const res = await client.query( + opProjectTracesType(executableNode.fromOp.inputs as any) + ); + return res; + }, }); diff --git a/weave-js/src/core/ops/domain/util.ts b/weave-js/src/core/ops/domain/util.ts index df05ccfe3af..a77cbb6dd17 100644 --- a/weave-js/src/core/ops/domain/util.ts +++ b/weave-js/src/core/ops/domain/util.ts @@ -1,3 +1,5 @@ +import {list, typedDict, TypedDictType, union} from '../../model'; + /** * `connectionToNodes` is a helper function that converts a `connection` type * returned from gql to its list of nodes. In many, many location in our @@ -41,3 +43,26 @@ export const connectionToNodes = (connection: MaybeConnection): T[] => (connection?.edges ?? []) .map(edge => edge?.node) .filter(node => node != null) as T[]; + +const traceFilterPropertyTypes = { + trace_roots_only: union(['none', 'boolean']), + op_names: union(['none', list('string')]), + input_refs: union(['none', list('string')]), + output_refs: union(['none', list('string')]), + parent_ids: union(['none', list('string')]), + trace_ids: union(['none', list('string')]), + call_ids: union(['none', list('string')]), + wb_user_ids: union(['none', list('string')]), + wb_run_ids: union(['none', list('string')]), +}; + +export const traceFilterType = union([ + 'none', + typedDict(traceFilterPropertyTypes, Object.keys(traceFilterPropertyTypes)), +]); +export const traceLimitType = union(['none', 'number']); +export const traceOffsetType = union(['none', 'number']); +export const traceSortByType = union([ + 'none', + list(typedDict({field: 'string', direction: 'string'})), +]); diff --git a/weave_query/weave_query/ops_domain/project_ops.py b/weave_query/weave_query/ops_domain/project_ops.py index 20d5184b53b..c7f532a3057 100644 --- a/weave_query/weave_query/ops_domain/project_ops.py +++ b/weave_query/weave_query/ops_domain/project_ops.py @@ -1,6 +1,6 @@ import json import typing -import asyncio + from weave_query import errors from weave_query import weave_types as types from weave_query import ops_arrow @@ -18,7 +18,7 @@ gql_root_op, make_root_op_gql_op_output_type, ) -from weave_query.wandb_trace_server_api import get_wandb_api +from weave_query.wandb_trace_server_api import get_wandb_trace_api_sync # Section 1/6: Tag Getters get_project_tag = make_tag_getter_op("project", wdt.ProjectType, op_name="tag-project") @@ -262,14 +262,48 @@ def artifacts( for edge in typeEdge["node"]["artifactCollections_100"]["edges"] ] -def _get_project_traces(project): - api = get_wandb_api() - res = asyncio.run(api.query_calls_stream(project.id)) - print('#############', res) +def _get_project_traces(project, payload): + api = get_wandb_trace_api_sync() + project_id = f'{project["entity"]["name"]}/{project["name"]}' + filter = None + limit = None + offset = None + sort_by = None + if payload is not None: + filter = payload.get("filter") + limit = payload.get("limit") + offset = payload.get("offset") + sort_by = payload.get("sort_by") + res = [] + for call in api.query_calls_stream(project_id, filter=filter, limit=limit, offset=offset, sort_by=sort_by): + res.append(call) return res +traces_filter_property_types = { + "op_names": types.optional(types.List(types.String())), + "input_refs": types.optional(types.List(types.String())), + "output_refs": types.optional(types.List(types.String())), + "parent_ids": types.optional(types.List(types.String())), + "trace_ids": types.optional(types.List(types.String())), + "call_ids": types.optional(types.List(types.String())), + "trace_roots_only": types.optional(types.Boolean()), + "wb_user_ids": types.optional(types.List(types.String())), + "wb_run_ids": types.optional(types.List(types.String())), +} + +traces_input_types = { + "project": wdt.ProjectType, + "payload": types.optional(types.TypedDict(property_types={ + "filter": types.optional(types.TypedDict(property_types=traces_filter_property_types, not_required_keys=set(traces_filter_property_types.keys()))), + "limit": types.optional(types.Number()), + "offset": types.optional(types.Number()), + "sort_by": types.optional(types.List(types.TypedDict(property_types={"field": types.String(), "direction": types.String()}))) + }, not_required_keys=set(['filter', 'limit', 'offset', 'sort_by']))) +} + @op( name="project-tracesType", + input_type=traces_input_types, output_type=types.TypeType(), plugins=wb_gql_op_plugin( lambda inputs, inner: """ @@ -281,13 +315,14 @@ def _get_project_traces(project): ), hidden=True ) -def traces_type(project: wdt.Project): - ttype = types.TypeRegistry.type_of([{"test": 1}, {"test": 2}]) - print('ttype:', ttype) +def traces_type(project, payload): + res = _get_project_traces(project, payload) + ttype = types.TypeRegistry.type_of(res) return ttype @op( name="project-traces", + input_type=traces_input_types, output_type=ops_arrow.ArrowWeaveListType(types.TypedDict({})), plugins=wb_gql_op_plugin( lambda inputs, inner: """ @@ -299,12 +334,9 @@ def traces_type(project: wdt.Project): ), refine_output_type=traces_type ) -def traces(project: wdt.Project): - # api = get_wandb_api() - res = _get_project_traces(project) - print('#############', res) - return ops_arrow.to_arrow([{"test": 1}, {"test": 2}]) - if "calls" in res[0]: - return ops_arrow.to_arrow(res[0]["calls"]) +def traces(project, payload): + res = _get_project_traces(project, payload) + if res: + return ops_arrow.to_arrow(res) else: return ops_arrow.to_arrow([]) \ No newline at end of file diff --git a/weave_query/weave_query/wandb_trace_server_api.py b/weave_query/weave_query/wandb_trace_server_api.py index 3b32e9eff26..2549915bdaa 100644 --- a/weave_query/weave_query/wandb_trace_server_api.py +++ b/weave_query/weave_query/wandb_trace_server_api.py @@ -3,8 +3,10 @@ import contextlib import contextvars import typing +import json import aiohttp +import requests from weave_query import errors from weave_query import environment as weave_env @@ -17,6 +19,7 @@ def get_wandb_api_context() -> typing.Optional[WandbApiContext]: return _wandb_api_context.get() +# todo(dom): Figure out how to use this async client API within ops. class WandbTraceApiAsync: def __init__(self) -> None: self.connector = aiohttp.TCPConnector(limit=50) @@ -27,15 +30,13 @@ async def query_calls_stream( filter: typing.Optional[dict] = None, limit: typing.Optional[int] = None, offset: typing.Optional[int] = None, + sort_by: typing.Optional[list] = None, include_costs: bool = False, include_feedback: bool = False, **kwargs: typing.Any - ) -> typing.Any: + ) -> typing.AsyncIterator[dict]: wandb_context = get_wandb_api_context() - headers = { - "Accept": "application/jsonl", - "Content-Type": "application/json" - } + headers = {'content-type: application/json'} auth = None if wandb_context is not None: @@ -48,7 +49,9 @@ async def query_calls_stream( if api_key_override: auth = aiohttp.BasicAuth("api", api_key_override) - url = f"{weave_env.weave_trace_server_url()}/calls/stream_query" + # todo(dom): Add env var support instead of hardcoding + # url = f"{weave_env.weave_trace_server_url()}/calls/stream_query" + url = "http://127.0.0.1:6345/calls/stream_query" payload = { "project_id": project_id, @@ -62,17 +65,22 @@ async def query_calls_stream( payload["limit"] = limit if offset: payload["offset"] = offset + if sort_by: + payload["sort_by"] = sort_by payload.update(kwargs) async with aiohttp.ClientSession( connector=self.connector, headers=headers, - auth=auth ) as session: async with session.post(url, json=payload) as response: response.raise_for_status() - return await response.json() + async for line in response.content: + if line: + decoded_line = line.decode('utf-8').strip() + if decoded_line: + yield json.loads(decoded_line) class WandbTraceApiSync: def query_calls_stream( @@ -81,28 +89,28 @@ def query_calls_stream( filter: typing.Optional[dict] = None, limit: typing.Optional[int] = None, offset: typing.Optional[int] = None, + sort_by: typing.Optional[list] = None, include_costs: bool = False, include_feedback: bool = False, **kwargs: typing.Any ) -> typing.Any: wandb_context = get_wandb_api_context() - headers = { - "Accept": "application/jsonl", - "Content-Type": "application/json" - } - auth = None + headers = {} if wandb_context is not None: if wandb_context.headers: headers.update(wandb_context.headers) + headers["authorization"] = "Basic Og==" if wandb_context.api_key is not None: - auth = ( "api", wandb_context.api_key) + auth = ("api", wandb_context.api_key) api_key_override = kwargs.pop("api_key", None) if api_key_override: auth = ("api", api_key_override) - url = f"{weave_env.wandb_base_url()}/calls/stream_query" + # todo(dom): Add env var support instead of hardcoding + # url = f"https://trace_server.wandb.test/calls/stream_query" + url = "http://127.0.0.1:6345/calls/stream_query" payload = { "project_id": project_id, @@ -116,21 +124,27 @@ def query_calls_stream( payload["limit"] = limit if offset: payload["offset"] = offset + if sort_by: + payload["sort_by"] = sort_by payload.update(kwargs) - with aiohttp.ClientSession( - connector=self.connector, + # todo(dom): Figure out a way to specify the auth kwarg with it + # causing a 403 error when it is None (when using the authorization header) + response = requests.post( + url, + json=payload, headers=headers, - auth=auth - ) as session: - with session.post(url, json=payload) as response: - response.raise_for_status() - return response.json() - -async def get_wandb_api() -> WandbTraceApiAsync: - return WandbTraceApiAsync() + stream=True + ) + response.raise_for_status() + + for line in response.iter_lines(): + if line: + yield json.loads(line.decode('utf-8')) +async def get_wandb_trace_api() -> WandbTraceApiAsync: + return WandbTraceApiAsync() -def get_wandb_api_sync() -> WandbTraceApiSync: +def get_wandb_trace_api_sync() -> WandbTraceApiSync: return WandbTraceApiSync() From 6d5391fcd1f919ae184d1fd551037088a265e21c Mon Sep 17 00:00:00 2001 From: dom phan Date: Thu, 12 Dec 2024 10:35:02 -0800 Subject: [PATCH 3/7] chore: integrate async api --- weave-js/src/core/ops/domain/project.ts | 7 +- weave-js/src/core/ops/domain/util.ts | 3 +- weave_query/weave_query/io_service.py | 69 ++++++++++++++++- .../weave_query/ops_domain/project_ops.py | 36 ++++++--- .../weave_query/wandb_trace_server_api.py | 77 +++++++------------ weave_query/weave_query/weave_http.py | 46 +++++++++++ 6 files changed, 172 insertions(+), 66 deletions(-) diff --git a/weave-js/src/core/ops/domain/project.ts b/weave-js/src/core/ops/domain/project.ts index 45b4351dd31..fc2532fb26d 100644 --- a/weave-js/src/core/ops/domain/project.ts +++ b/weave-js/src/core/ops/domain/project.ts @@ -1,5 +1,5 @@ import * as Urls from '../../_external/util/urls'; -import {hash, list, maybe, typedDict, union} from '../../model'; +import {hash, list, typedDict, union} from '../../model'; import {docType} from '../../util/docs'; import * as OpKinds from '../opKinds'; import { @@ -7,6 +7,7 @@ import { traceFilterType, traceLimitType, traceOffsetType, + traceQueryType, traceSortByType, } from './util'; @@ -314,8 +315,9 @@ const projectTracesArgTypes = { limit: traceLimitType, offset: traceOffsetType, sort_by: traceSortByType, + query: traceQueryType, }, - ['filter', 'limit', 'offset', 'sort_by'] + ['filter', 'limit', 'offset', 'sort_by', 'query'] ), ]), }; @@ -327,6 +329,7 @@ const projectTracesArgTypesDescription = { 'payload.limit': `A number representing the limit for number of trace calls`, 'payload.offset': `A number representing the offset for the number of trace calls`, 'payload.sort_by': `An array with a dictionary with keys \`field\`() and \`direction\` ("asc"|"desc")`, + 'payload.query': `A dictionary to query data inspired by mongodb aggregation operators`, }; export const opProjectTracesType = makeProjectOp({ diff --git a/weave-js/src/core/ops/domain/util.ts b/weave-js/src/core/ops/domain/util.ts index a77cbb6dd17..e135635a8e3 100644 --- a/weave-js/src/core/ops/domain/util.ts +++ b/weave-js/src/core/ops/domain/util.ts @@ -1,4 +1,4 @@ -import {list, typedDict, TypedDictType, union} from '../../model'; +import {dict, list, typedDict, union} from '../../model'; /** * `connectionToNodes` is a helper function that converts a `connection` type @@ -66,3 +66,4 @@ export const traceSortByType = union([ 'none', list(typedDict({field: 'string', direction: 'string'})), ]); +export const traceQueryType = union(['none', dict('any')]) diff --git a/weave_query/weave_query/io_service.py b/weave_query/weave_query/io_service.py index 79b917d3634..1713ffd19d5 100644 --- a/weave_query/weave_query/io_service.py +++ b/weave_query/weave_query/io_service.py @@ -35,6 +35,7 @@ uris, wandb_api, wandb_file_manager, + wandb_trace_server_api ) tracer = engine_trace.tracer() # type: ignore @@ -217,6 +218,7 @@ def __init__( self.register_handler_fn("ensure_file", self.handle_ensure_file) self.register_handler_fn("direct_url", self.handle_direct_url) self.register_handler_fn("sleep", self.handle_sleep) + self.register_handler_fn("query_traces", self.handle_query_traces) if process: self.request_handler = aioprocessing.AioProcess( @@ -343,6 +345,7 @@ async def _request_handler_fn_main(self) -> None: self.wandb_file_manager = wandb_file_manager.WandbFileManagerAsync( fs, net, await wandb_api.get_wandb_api() ) + self.wandb_trace_server_api = wandb_trace_server_api.WandbTraceApiAsync(net) self._request_handler_ready_event.set() while True: @@ -417,6 +420,28 @@ async def handle_ensure_file(self, artifact_uri: str) -> typing.Optional[str]: ): raise errors.WeaveInternalError("invalid scheme ", uri) return await self.wandb_file_manager.ensure_file(uri) + + async def handle_query_traces( + self, + project_id: str, + filter: typing.Optional[dict] = None, + limit: typing.Optional[int] = None, + offset: typing.Optional[int] = None, + sort_by: typing.Optional[list] = None, + query: typing.Optional[dict] = None, + include_costs: bool = False, + include_feedback: bool = False, + ) -> list[dict]: + return await self.wandb_trace_server_api.query_calls_stream( + project_id, + filter, + limit, + offset, + sort_by, + query, + include_costs, + include_feedback + ) async def handle_ensure_file_downloaded( self, download_url: str @@ -481,12 +506,24 @@ def response_task_ended(self, task: asyncio.Task) -> None: async def close(self) -> None: self.connected = False + sentinel = ServerResponse(http_error_code=200, error=False, client_id=self.client_id, id=-1, value=None) + await self.response_queue.async_put(sentinel) + await self.response_task async def handle_responses(self) -> None: - while self.connected: - resp = await self.response_queue.async_get() - self.response_queue.task_done() - self.requests[resp.id].set_result(resp) + try: + while self.connected: + resp = await self.response_queue.async_get() + + if resp.id == -1: + break + + self.response_queue.task_done() + self.requests[resp.id].set_result(resp) + finally: + for future in self.requests.values(): + if not future.done(): + future.cancel() async def request(self, name: str, *args: typing.Any) -> typing.Any: # Caller must check ServerResponse.error! @@ -569,6 +606,30 @@ async def direct_url( async def sleep(self, seconds: float) -> float: return await self.request("sleep", seconds) + + async def query_traces( + self, + project_id: str, + filter: typing.Optional[dict] = None, + limit: typing.Optional[int] = None, + offset: typing.Optional[int] = None, + sort_by: typing.Optional[list] = None, + query: typing.Optional[dict] = None, + include_costs: bool = False, + include_feedback: bool = False + ): + res = await self.request( + "query_traces", + project_id, + filter, + limit, + offset, + sort_by, + query, + include_costs, + include_feedback + ) + return res class AsyncClient: diff --git a/weave_query/weave_query/ops_domain/project_ops.py b/weave_query/weave_query/ops_domain/project_ops.py index c7f532a3057..e5a8dd038fc 100644 --- a/weave_query/weave_query/ops_domain/project_ops.py +++ b/weave_query/weave_query/ops_domain/project_ops.py @@ -1,9 +1,11 @@ +import asyncio import json import typing from weave_query import errors from weave_query import weave_types as types from weave_query import ops_arrow +from weave_query import io_service from weave_query.api import op from weave_query import input_provider from weave_query.gql_op_plugin import wb_gql_op_plugin @@ -18,7 +20,6 @@ gql_root_op, make_root_op_gql_op_output_type, ) -from weave_query.wandb_trace_server_api import get_wandb_trace_api_sync # Section 1/6: Tag Getters get_project_tag = make_tag_getter_op("project", wdt.ProjectType, op_name="tag-project") @@ -262,22 +263,34 @@ def artifacts( for edge in typeEdge["node"]["artifactCollections_100"]["edges"] ] -def _get_project_traces(project, payload): - api = get_wandb_trace_api_sync() +async def _get_project_traces(project, payload): + client = io_service.get_async_client() project_id = f'{project["entity"]["name"]}/{project["name"]}' filter = None limit = None offset = None sort_by = None + query = None if payload is not None: filter = payload.get("filter") limit = payload.get("limit") offset = payload.get("offset") sort_by = payload.get("sort_by") - res = [] - for call in api.query_calls_stream(project_id, filter=filter, limit=limit, offset=offset, sort_by=sort_by): - res.append(call) - return res + query = payload.get("query") + + loop = asyncio.get_running_loop() + tasks = set() + async with client.connect() as conn: + task = loop.create_task(conn.query_traces( + project_id, + filter=filter, + limit=limit, + offset=offset, + sort_by=sort_by, + query=query)) + tasks.add(task) + await asyncio.wait(tasks) + return task.result() traces_filter_property_types = { "op_names": types.optional(types.List(types.String())), @@ -297,8 +310,9 @@ def _get_project_traces(project, payload): "filter": types.optional(types.TypedDict(property_types=traces_filter_property_types, not_required_keys=set(traces_filter_property_types.keys()))), "limit": types.optional(types.Number()), "offset": types.optional(types.Number()), - "sort_by": types.optional(types.List(types.TypedDict(property_types={"field": types.String(), "direction": types.String()}))) - }, not_required_keys=set(['filter', 'limit', 'offset', 'sort_by']))) + "sort_by": types.optional(types.List(types.TypedDict(property_types={"field": types.String(), "direction": types.String()}))), + "query": types.optional(types.Dict()) + }, not_required_keys=set(['filter', 'limit', 'offset', 'sort_by', 'query']))) } @op( @@ -316,7 +330,7 @@ def _get_project_traces(project, payload): hidden=True ) def traces_type(project, payload): - res = _get_project_traces(project, payload) + res = asyncio.run(_get_project_traces(project, payload)) ttype = types.TypeRegistry.type_of(res) return ttype @@ -335,7 +349,7 @@ def traces_type(project, payload): refine_output_type=traces_type ) def traces(project, payload): - res = _get_project_traces(project, payload) + res = asyncio.run(_get_project_traces(project, payload)) if res: return ops_arrow.to_arrow(res) else: diff --git a/weave_query/weave_query/wandb_trace_server_api.py b/weave_query/weave_query/wandb_trace_server_api.py index 2549915bdaa..56a716e896a 100644 --- a/weave_query/weave_query/wandb_trace_server_api.py +++ b/weave_query/weave_query/wandb_trace_server_api.py @@ -10,7 +10,7 @@ from weave_query import errors from weave_query import environment as weave_env -from weave_query import wandb_client_api, engine_trace +from weave_query import wandb_client_api, engine_trace, weave_http from weave_query.context_state import WandbApiContext, _wandb_api_context @@ -19,11 +19,10 @@ def get_wandb_api_context() -> typing.Optional[WandbApiContext]: return _wandb_api_context.get() -# todo(dom): Figure out how to use this async client API within ops. class WandbTraceApiAsync: - def __init__(self) -> None: - self.connector = aiohttp.TCPConnector(limit=50) - + def __init__(self, http: weave_http.HttpAsync) -> None: + self.http = http + async def query_calls_stream( self, project_id: str, @@ -31,28 +30,24 @@ async def query_calls_stream( limit: typing.Optional[int] = None, offset: typing.Optional[int] = None, sort_by: typing.Optional[list] = None, + query: typing.Optional[dict] = None, include_costs: bool = False, include_feedback: bool = False, - **kwargs: typing.Any - ) -> typing.AsyncIterator[dict]: - wandb_context = get_wandb_api_context() - headers = {'content-type: application/json'} + ) -> typing.List[dict]: + wandb_api_context = get_wandb_api_context() + headers = {'content-type': 'application/json'} auth = None - - if wandb_context is not None: - if wandb_context.headers: - headers.update(wandb_context.headers) - if wandb_context.api_key is not None: - auth = aiohttp.BasicAuth("api", wandb_context.api_key) + cookies = None + if wandb_api_context is not None: + headers = wandb_api_context.headers + cookies = wandb_api_context.cookies + if wandb_api_context.api_key is not None: + auth = aiohttp.BasicAuth("api", wandb_api_context.api_key) + if cookies: + headers["authorization"] = "Basic Og==" - api_key_override = kwargs.pop("api_key", None) - if api_key_override: - auth = aiohttp.BasicAuth("api", api_key_override) + url = f"{weave_env.weave_trace_server_url()}/calls/stream_query" - # todo(dom): Add env var support instead of hardcoding - # url = f"{weave_env.weave_trace_server_url()}/calls/stream_query" - url = "http://127.0.0.1:6345/calls/stream_query" - payload = { "project_id": project_id, "include_costs": include_costs, @@ -67,20 +62,10 @@ async def query_calls_stream( payload["offset"] = offset if sort_by: payload["sort_by"] = sort_by + if query: + payload["query"] = query - payload.update(kwargs) - - async with aiohttp.ClientSession( - connector=self.connector, - headers=headers, - ) as session: - async with session.post(url, json=payload) as response: - response.raise_for_status() - async for line in response.content: - if line: - decoded_line = line.decode('utf-8').strip() - if decoded_line: - yield json.loads(decoded_line) + return await self.http.query_traces(url, payload, headers=headers, cookies=cookies, auth=auth) class WandbTraceApiSync: def query_calls_stream( @@ -90,27 +75,28 @@ def query_calls_stream( limit: typing.Optional[int] = None, offset: typing.Optional[int] = None, sort_by: typing.Optional[list] = None, + query: typing.Optional[dict] = None, include_costs: bool = False, include_feedback: bool = False, **kwargs: typing.Any ) -> typing.Any: wandb_context = get_wandb_api_context() - headers = {} + headers = {'content-type': 'application/json'} + auth = None if wandb_context is not None: if wandb_context.headers: headers.update(wandb_context.headers) - headers["authorization"] = "Basic Og==" if wandb_context.api_key is not None: auth = ("api", wandb_context.api_key) + else: + headers["authorization"] = "Basic Og==" api_key_override = kwargs.pop("api_key", None) if api_key_override: auth = ("api", api_key_override) - # todo(dom): Add env var support instead of hardcoding - # url = f"https://trace_server.wandb.test/calls/stream_query" - url = "http://127.0.0.1:6345/calls/stream_query" + url = f"{weave_env.weave_trace_server_url()}/calls/stream_query" payload = { "project_id": project_id, @@ -126,15 +112,16 @@ def query_calls_stream( payload["offset"] = offset if sort_by: payload["sort_by"] = sort_by + if query: + payload["query"] = query payload.update(kwargs) - # todo(dom): Figure out a way to specify the auth kwarg with it - # causing a 403 error when it is None (when using the authorization header) response = requests.post( url, json=payload, headers=headers, + auth=auth if auth else None, stream=True ) response.raise_for_status() @@ -142,9 +129,3 @@ def query_calls_stream( for line in response.iter_lines(): if line: yield json.loads(line.decode('utf-8')) - -async def get_wandb_trace_api() -> WandbTraceApiAsync: - return WandbTraceApiAsync() - -def get_wandb_trace_api_sync() -> WandbTraceApiSync: - return WandbTraceApiSync() diff --git a/weave_query/weave_query/weave_http.py b/weave_query/weave_query/weave_http.py index 39bd91fbb9c..5d8f6cd67f9 100644 --- a/weave_query/weave_query/weave_http.py +++ b/weave_query/weave_query/weave_http.py @@ -12,6 +12,7 @@ import requests import requests.auth import yarl +import json from weave_query import engine_trace, filesystem, server_error_handling @@ -128,6 +129,29 @@ async def download_file( r.status, "Download failed" ) + async def query_traces( + self, + url: str, + payload: typing.Optional[dict], + headers: typing.Optional[dict[str, str]] = None, + cookies: typing.Optional[dict[str, str]] = None, + auth: typing.Optional[aiohttp.BasicAuth] = None, + ) -> list[dict]: + with tracer.trace("query_traces"): + results = [] + async with self.session.post(url, json=payload, headers=headers, cookies=cookies, auth=auth) as response: + if response.status == 200: + async for line in response.content: + if line: + decoded_line = line.decode('utf-8').strip() + if decoded_line: + results.append(json.loads(decoded_line)) + return results + else: + raise server_error_handling.WeaveInternalHttpException.from_code( + response.status_code, + "Traces query failed", + ) class Http: def __init__(self, fs: filesystem.Filesystem) -> None: @@ -169,3 +193,25 @@ def download_file( r.status_code, "Download failed", # type: ignore ) + + def query_traces( + self, + url: str, + payload: typing.Optional[dict], + headers: typing.Optional[dict[str, str]] = None, + cookies: typing.Optional[dict[str, str]] = None, + auth: typing.Optional[aiohttp.BasicAuth] = None, + ) -> list[dict]: + with tracer.trace("query_traces"): + results = [] + with self.session.post(url, json=payload, headers=headers, cookies=cookies, auth=auth) as response: + if response.status_code == 200: + for line in response.iter_lines(): + if line: + results.append(json.loads(line)) + else: + raise server_error_handling.WeaveInternalHttpException.from_code( + response.status_code, + "Traces query failed", + ) + return results From 02e29baaedb5b0abf8ec542762f3298059e84486 Mon Sep 17 00:00:00 2001 From: dom phan Date: Mon, 16 Dec 2024 16:13:18 -0800 Subject: [PATCH 4/7] chore: lint --- weave-js/src/core/ops/domain/util.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/weave-js/src/core/ops/domain/util.ts b/weave-js/src/core/ops/domain/util.ts index e135635a8e3..094cbde2e6c 100644 --- a/weave-js/src/core/ops/domain/util.ts +++ b/weave-js/src/core/ops/domain/util.ts @@ -66,4 +66,4 @@ export const traceSortByType = union([ 'none', list(typedDict({field: 'string', direction: 'string'})), ]); -export const traceQueryType = union(['none', dict('any')]) +export const traceQueryType = union(['none', dict('any')]); From 9429656077d1e825ea2058ab2d06447f8b91f0a8 Mon Sep 17 00:00:00 2001 From: dom phan Date: Mon, 16 Dec 2024 16:25:15 -0800 Subject: [PATCH 5/7] chore: add value to test --- weave-js/src/core/ll.test.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/weave-js/src/core/ll.test.ts b/weave-js/src/core/ll.test.ts index 6e7bb2666a7..95fa418dcc9 100644 --- a/weave-js/src/core/ll.test.ts +++ b/weave-js/src/core/ll.test.ts @@ -340,6 +340,7 @@ describe('ll', () => { 'project-artifactTypes', 'project-artifact', 'project-artifactVersion', + 'project-traces', ]); }); From a0e8dab411048e82fc7e79ccbc93b32b66dbce2b Mon Sep 17 00:00:00 2001 From: dom phan Date: Mon, 16 Dec 2024 17:16:58 -0800 Subject: [PATCH 6/7] chore: remove costs and feedback for now --- weave_query/weave_query/io_service.py | 8 ----- .../weave_query/wandb_trace_server_api.py | 30 ++++++++----------- 2 files changed, 12 insertions(+), 26 deletions(-) diff --git a/weave_query/weave_query/io_service.py b/weave_query/weave_query/io_service.py index 1713ffd19d5..d8327a17b8f 100644 --- a/weave_query/weave_query/io_service.py +++ b/weave_query/weave_query/io_service.py @@ -429,8 +429,6 @@ async def handle_query_traces( offset: typing.Optional[int] = None, sort_by: typing.Optional[list] = None, query: typing.Optional[dict] = None, - include_costs: bool = False, - include_feedback: bool = False, ) -> list[dict]: return await self.wandb_trace_server_api.query_calls_stream( project_id, @@ -439,8 +437,6 @@ async def handle_query_traces( offset, sort_by, query, - include_costs, - include_feedback ) async def handle_ensure_file_downloaded( @@ -615,8 +611,6 @@ async def query_traces( offset: typing.Optional[int] = None, sort_by: typing.Optional[list] = None, query: typing.Optional[dict] = None, - include_costs: bool = False, - include_feedback: bool = False ): res = await self.request( "query_traces", @@ -626,8 +620,6 @@ async def query_traces( offset, sort_by, query, - include_costs, - include_feedback ) return res diff --git a/weave_query/weave_query/wandb_trace_server_api.py b/weave_query/weave_query/wandb_trace_server_api.py index 56a716e896a..8ea05172bce 100644 --- a/weave_query/weave_query/wandb_trace_server_api.py +++ b/weave_query/weave_query/wandb_trace_server_api.py @@ -31,8 +31,6 @@ async def query_calls_stream( offset: typing.Optional[int] = None, sort_by: typing.Optional[list] = None, query: typing.Optional[dict] = None, - include_costs: bool = False, - include_feedback: bool = False, ) -> typing.List[dict]: wandb_api_context = get_wandb_api_context() headers = {'content-type': 'application/json'} @@ -44,14 +42,12 @@ async def query_calls_stream( if wandb_api_context.api_key is not None: auth = aiohttp.BasicAuth("api", wandb_api_context.api_key) if cookies: - headers["authorization"] = "Basic Og==" + headers["authorization"] = "Basic Og==" # base64 encoding of ":" url = f"{weave_env.weave_trace_server_url()}/calls/stream_query" payload = { "project_id": project_id, - "include_costs": include_costs, - "include_feedback": include_feedback, } if filter: @@ -76,21 +72,21 @@ def query_calls_stream( offset: typing.Optional[int] = None, sort_by: typing.Optional[list] = None, query: typing.Optional[dict] = None, - include_costs: bool = False, - include_feedback: bool = False, **kwargs: typing.Any ) -> typing.Any: - wandb_context = get_wandb_api_context() + wandb_api_context = get_wandb_api_context() headers = {'content-type': 'application/json'} auth = None - - if wandb_context is not None: - if wandb_context.headers: - headers.update(wandb_context.headers) - if wandb_context.api_key is not None: - auth = ("api", wandb_context.api_key) - else: - headers["authorization"] = "Basic Og==" + cookies = None + if wandb_api_context is not None: + headers = wandb_api_context.headers + cookies = wandb_api_context.cookies + if wandb_api_context.api_key is not None: + auth = aiohttp.BasicAuth("api", wandb_api_context.api_key) + if cookies: + headers["authorization"] = "Basic Og==" # base64 encoding of ":" + + url = f"{weave_env.weave_trace_server_url()}/calls/stream_query" api_key_override = kwargs.pop("api_key", None) if api_key_override: @@ -100,8 +96,6 @@ def query_calls_stream( payload = { "project_id": project_id, - "include_costs": include_costs, - "include_feedback": include_feedback, } if filter: From 6622028b64a609d8e13eae835ba02de94cfbf8e4 Mon Sep 17 00:00:00 2001 From: dom phan Date: Tue, 17 Dec 2024 08:17:42 -0800 Subject: [PATCH 7/7] chore: adjust trace api handlers --- weave_query/weave_query/io_service.py | 50 ++++++++++++++++--- .../weave_query/ops_domain/project_ops.py | 6 ++- .../weave_query/wandb_trace_server_api.py | 30 ++--------- weave_query/weave_query/weave_http.py | 2 +- 4 files changed, 53 insertions(+), 35 deletions(-) diff --git a/weave_query/weave_query/io_service.py b/weave_query/weave_query/io_service.py index d8327a17b8f..acf67cbdf19 100644 --- a/weave_query/weave_query/io_service.py +++ b/weave_query/weave_query/io_service.py @@ -432,11 +432,11 @@ async def handle_query_traces( ) -> list[dict]: return await self.wandb_trace_server_api.query_calls_stream( project_id, - filter, - limit, - offset, - sort_by, - query, + filter=filter, + limit=limit, + offset=offset, + sort_by=sort_by, + query=query, ) async def handle_ensure_file_downloaded( @@ -611,7 +611,7 @@ async def query_traces( offset: typing.Optional[int] = None, sort_by: typing.Optional[list] = None, query: typing.Optional[dict] = None, - ): + ) -> typing.Optional[list[dict]]: res = await self.request( "query_traces", project_id, @@ -718,6 +718,24 @@ def direct_url( def sleep(self, seconds: float) -> None: return self.request("sleep", seconds) + def query_traces( + self, + project_id: str, + filter: typing.Optional[dict] = None, + limit: typing.Optional[int] = None, + offset: typing.Optional[int] = None, + sort_by: typing.Optional[list] = None, + query: typing.Optional[dict] = None, + ) -> typing.Optional[list[dict]]: + return self.request( + "query_traces", + project_id, + filter, + limit, + offset, + sort_by, + query, + ) class ServerlessClient: def __init__(self, fs: filesystem.Filesystem) -> None: @@ -727,6 +745,7 @@ def __init__(self, fs: filesystem.Filesystem) -> None: self.wandb_file_manager = wandb_file_manager.WandbFileManager( self.fs, self.http, self.wandb_api ) + self.wandb_trace_server_api = wandb_trace_server_api.WandbTraceApiSync(self.http) def manifest( self, @@ -758,6 +777,25 @@ def direct_url( def sleep(self, seconds: float) -> None: time.sleep(seconds) + def query_traces( + self, + project_id: str, + filter: typing.Optional[dict] = None, + limit: typing.Optional[int] = None, + offset: typing.Optional[int] = None, + sort_by: typing.Optional[list] = None, + query: typing.Optional[dict] = None, + ) -> typing.Optional[list[dict]]: + res = self.wandb_trace_server_api.query_calls_stream( + project_id, + filter=filter, + limit=limit, + offset=offset, + sort_by=sort_by, + query=query, + ) + return res + def get_sync_client() -> typing.Union[SyncClient, ServerlessClient]: if context_state.serverless_io_service(): diff --git a/weave_query/weave_query/ops_domain/project_ops.py b/weave_query/weave_query/ops_domain/project_ops.py index e5a8dd038fc..63c5a3ef29f 100644 --- a/weave_query/weave_query/ops_domain/project_ops.py +++ b/weave_query/weave_query/ops_domain/project_ops.py @@ -331,8 +331,10 @@ async def _get_project_traces(project, payload): ) def traces_type(project, payload): res = asyncio.run(_get_project_traces(project, payload)) - ttype = types.TypeRegistry.type_of(res) - return ttype + if res: + return types.TypeRegistry.type_of(res) + else: + return types.TypeRegistry.type_of([]) @op( name="project-traces", diff --git a/weave_query/weave_query/wandb_trace_server_api.py b/weave_query/weave_query/wandb_trace_server_api.py index 8ea05172bce..0f6187353ab 100644 --- a/weave_query/weave_query/wandb_trace_server_api.py +++ b/weave_query/weave_query/wandb_trace_server_api.py @@ -1,16 +1,11 @@ # This is an experimental client api used to make requests to the # Weave Trace Server -import contextlib -import contextvars import typing -import json import aiohttp -import requests -from weave_query import errors from weave_query import environment as weave_env -from weave_query import wandb_client_api, engine_trace, weave_http +from weave_query import engine_trace, weave_http from weave_query.context_state import WandbApiContext, _wandb_api_context @@ -64,6 +59,8 @@ async def query_calls_stream( return await self.http.query_traces(url, payload, headers=headers, cookies=cookies, auth=auth) class WandbTraceApiSync: + def __init__(self, http: weave_http.Http) -> None: + self.http = http def query_calls_stream( self, project_id: str, @@ -87,12 +84,6 @@ def query_calls_stream( headers["authorization"] = "Basic Og==" # base64 encoding of ":" url = f"{weave_env.weave_trace_server_url()}/calls/stream_query" - - api_key_override = kwargs.pop("api_key", None) - if api_key_override: - auth = ("api", api_key_override) - - url = f"{weave_env.weave_trace_server_url()}/calls/stream_query" payload = { "project_id": project_id, @@ -109,17 +100,4 @@ def query_calls_stream( if query: payload["query"] = query - payload.update(kwargs) - - response = requests.post( - url, - json=payload, - headers=headers, - auth=auth if auth else None, - stream=True - ) - response.raise_for_status() - - for line in response.iter_lines(): - if line: - yield json.loads(line.decode('utf-8')) + return self.http.query_traces(url, payload, headers=headers, cookies=cookies, auth=auth) diff --git a/weave_query/weave_query/weave_http.py b/weave_query/weave_query/weave_http.py index 5d8f6cd67f9..16d9c8f0a40 100644 --- a/weave_query/weave_query/weave_http.py +++ b/weave_query/weave_query/weave_http.py @@ -149,7 +149,7 @@ async def query_traces( return results else: raise server_error_handling.WeaveInternalHttpException.from_code( - response.status_code, + response.status, "Traces query failed", )