Skip to content

Commit

Permalink
chore(weave): Move all sentry tracking to just the weave client
Browse files Browse the repository at this point in the history
  • Loading branch information
tssweeney authored Apr 3, 2024
1 parent d37a8fb commit b305008
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 9 deletions.
3 changes: 0 additions & 3 deletions weave/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from weave.trace import context as trace_context
from .trace.constants import TRACE_OBJECT_EMOJI
from weave.trace.refs import ObjectRef
from . import trace_sentry

# exposed as part of api
from . import weave_types as types
Expand Down Expand Up @@ -186,7 +185,6 @@ def local_client() -> typing.Iterator[_weave_client.WeaveClient]:
inited_client.reset()


@trace_sentry.global_trace_sentry.watch()
def publish(obj: typing.Any, name: Optional[str] = None) -> _weave_client.ObjectRef:
"""Save and version a python object.
Expand Down Expand Up @@ -226,7 +224,6 @@ def publish(obj: typing.Any, name: Optional[str] = None) -> _weave_client.Object
return ref


@trace_sentry.global_trace_sentry.watch()
def ref(location: str) -> _weave_client.ObjectRef:
"""Construct a Ref to a Weave object.
Expand Down
3 changes: 1 addition & 2 deletions weave/trace/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from weave.trace.errors import OpCallError
from weave.trace.refs import ObjectRef
from weave.trace.context import call_attributes
from weave import graph_client_context, trace_sentry
from weave import graph_client_context
from weave import run_context
from weave import box

Expand Down Expand Up @@ -43,7 +43,6 @@ def __get__(
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self._watched_call(*args, **kwargs)

@trace_sentry.global_trace_sentry.watch()
def _watched_call(self, *args: Any, **kwargs: Any) -> Any:
maybe_client = graph_client_context.get_graph_client()
if maybe_client is None:
Expand Down
2 changes: 0 additions & 2 deletions weave/trace/refs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Union, Any
import dataclasses
from ..trace_server import refs_internal
from .. import trace_sentry

DICT_KEY_EDGE_NAME = refs_internal.DICT_KEY_EDGE_NAME
LIST_INDEX_EDGE_NAME = refs_internal.LIST_INDEX_EDGE_NAME
Expand Down Expand Up @@ -59,7 +58,6 @@ def uri(self) -> str:
u += "/" + "/".join(self.extra)
return u

@trace_sentry.global_trace_sentry.watch()
def get(self) -> Any:
# Move import here so that it only happens when the function is called.
# This import is invalid in the trace server and represents a dependency
Expand Down
13 changes: 12 additions & 1 deletion weave/weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from requests import HTTPError

from weave.table import Table
from weave import urls
from weave import trace_sentry, urls
from weave import run_context
from weave.trace.op import Op
from weave.trace.object_record import (
Expand Down Expand Up @@ -259,6 +259,7 @@ def _project_id(self) -> str:

# This is used by tests and op_execute still, but the save() interface
# is nicer for clients I think?
@trace_sentry.global_trace_sentry.watch()
def save_object(self, val: Any, name: str, branch: str = "latest") -> ObjectRef:
self.save_nested_objects(val, name=name)
return self._save_object(val, name, branch)
Expand Down Expand Up @@ -288,12 +289,14 @@ def _save_object(self, val: Any, name: str, branch: str = "latest") -> ObjectRef
# save instead?
return ref

@trace_sentry.global_trace_sentry.watch()
def save(self, val: Any, name: str, branch: str = "latest") -> Any:
ref = self.save_object(val, name, branch)
if not isinstance(ref, ObjectRef):
raise ValueError(f"Expected ObjectRef, got {ref}")
return self.get(ref)

@trace_sentry.global_trace_sentry.watch()
def get(self, ref: ObjectRef) -> Any:
try:
read_res = self.server.obj_read(
Expand Down Expand Up @@ -340,6 +343,7 @@ def get(self, ref: ObjectRef) -> Any:

return make_trace_obj(val, ref, self.server, None)

@trace_sentry.global_trace_sentry.watch()
def save_table(self, table: Table) -> TableRef:
response = self.server.table_create(
TableCreateReq(
Expand All @@ -352,12 +356,14 @@ def save_table(self, table: Table) -> TableRef:
entity=self.entity, project=self.project, digest=response.digest
)

@trace_sentry.global_trace_sentry.watch()
def calls(self, filter: Optional[_CallsFilter] = None) -> CallsIter:
if filter is None:
filter = _CallsFilter()

return CallsIter(self.server, self._project_id(), filter)

@trace_sentry.global_trace_sentry.watch()
def call(self, call_id: str) -> TraceObject:
response = self.server.calls_query(
CallsQueryReq(
Expand All @@ -370,12 +376,14 @@ def call(self, call_id: str) -> TraceObject:
response_call = response.calls[0]
return make_client_call(self.entity, self.project, response_call, self.server)

@trace_sentry.global_trace_sentry.watch()
def op_calls(self, op: Op) -> CallsIter:
op_ref = get_ref(op)
if op_ref is None:
raise ValueError(f"Can't get runs for unpublished op: {op}")
return self.calls(_CallsFilter(op_names=[op_ref.uri()]))

@trace_sentry.global_trace_sentry.watch()
def objects(self, filter: Optional[_ObjectVersionFilter] = None) -> list[ObjSchema]:
if not filter:
filter = _ObjectVersionFilter()
Expand All @@ -399,6 +407,7 @@ def _save_op(self, op: Op) -> Ref:
op.ref = op_def_ref # type: ignore
return op_def_ref

@trace_sentry.global_trace_sentry.watch()
def create_call(
self,
op: Union[str, Op],
Expand Down Expand Up @@ -450,6 +459,7 @@ def create_call(
self.server.call_start(CallStartReq(start=start))
return call

@trace_sentry.global_trace_sentry.watch()
def finish_call(self, call: Call, output: Any) -> None:
self.save_nested_objects(output)
output = map_to_refs(output)
Expand Down Expand Up @@ -483,6 +493,7 @@ def finish_call(self, call: Call, output: Any) -> None:
# )["successes"] += 1
call.summary = summary

@trace_sentry.global_trace_sentry.watch()
def fail_call(self, call: Call, exception: BaseException) -> None:
# Full traceback disabled til we fix UI.
# stack_trace = "".join(
Expand Down
1 change: 0 additions & 1 deletion weave/weave_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def get_entity_project_from_project_name(project_name: str) -> tuple[str, str]:
return entity_name, project_name


@trace_sentry.global_trace_sentry.watch()
def init_weave(project_name: str) -> InitializedClient:
from . import wandb_api

Expand Down

0 comments on commit b305008

Please sign in to comment.