Skip to content

Commit

Permalink
fixed it up
Browse files Browse the repository at this point in the history
  • Loading branch information
tssweeney committed Oct 11, 2024
1 parent 167af04 commit bce97e9
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 49 deletions.
169 changes: 121 additions & 48 deletions weave/trace/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def _execute_call(
*args: Any,
__should_raise: bool = True,
**kwargs: Any,
) -> Any:
) -> Union[tuple[Any, "Call"], Coroutine[Any, Any, tuple[Any, "Call"]]]:
func = __op.resolve_fn
client = weave_client_context.require_weave_client()
has_finished = False
Expand All @@ -266,7 +266,7 @@ def on_output(output: Any) -> Any:
finish(output)
return output

def process(res: Any) -> Any:
def process(res: Any) -> tuple[Any, "Call"]:
res = box.box(res)
try:
# Here we do a try/catch because we don't want to
Expand All @@ -285,15 +285,15 @@ def process(res: Any) -> Any:

return res, call

def handle_exception(e: Exception) -> Any:
def handle_exception(e: Exception) -> tuple[Any, "Call"]:
finish(exception=e)
if __should_raise:
raise
return None, call

if inspect.iscoroutinefunction(func):

async def _call_async() -> Coroutine[Any, Any, Any]:
async def _call_async() -> tuple[Any, "Call"]:
try:
res = await func(*args, **kwargs)
except Exception as e:
Expand All @@ -314,8 +314,12 @@ async def _call_async() -> Coroutine[Any, Any, Any]:


def call(
op: Op, *args: Any, __weave: Optional[WeaveKwargs] = None, **kwargs: Any
) -> tuple[Any, "Call"]:
op: Op,
*args: Any,
__weave: Optional[WeaveKwargs] = None,
__should_raise: bool = False,
**kwargs: Any,
) -> Union[tuple[Any, "Call"], Coroutine[Any, Any, tuple[Any, "Call"]]]:
"""
Executes the op and returns both the result and a Call representing the execution.
Expand All @@ -332,8 +336,111 @@ def add(a: int, b: int) -> int:
result, call = add.call(1, 2)
```
"""
c = _create_call(op, *args, __weave=__weave, **kwargs)
return _execute_call(op, c, *args, __should_raise=False, **kwargs)
if inspect.iscoroutinefunction(op.resolve_fn):
return do_call_async(
op, *args, __weave=__weave, __should_raise=__should_raise, **kwargs
)
else:
return do_call(
op, *args, __weave=__weave, __should_raise=__should_raise, **kwargs
)


def _placeholder_call() -> "Call":
# Import here to avoid circular dependency
from weave.trace.weave_client import Call

return Call(
_op_name="",
trace_id="",
project_id="",
parent_id=None,
inputs={},
)


def do_call(
op: Op,
*args: Any,
__weave: Optional[WeaveKwargs] = None,
__should_raise: bool = False,
**kwargs: Any,
) -> tuple[Any, "Call"]:
func = op.resolve_fn
call = _placeholder_call()
if settings.should_disable_weave():
res = func(*args, **kwargs)
elif weave_client_context.get_weave_client() is None:
res = func(*args, **kwargs)
elif not op._tracing_enabled:
res = func(*args, **kwargs)
else:
try:
# This try/except allows us to fail gracefully and
# still let the user code continue to execute
call = _create_call(op, *args, __weave=__weave, **kwargs)
except Exception as e:
if get_raise_on_captured_errors():
raise
log_once(
logger.error,
CALL_CREATE_MSG.format(traceback.format_exc()),
)
res = func(*args, **kwargs)
else:
execute_result = _execute_call(
op, call, *args, __should_raise=__should_raise, **kwargs
)
if inspect.iscoroutine(execute_result):
raise Exception(
"Internal error: Expected `_execute_call` to return a sync result"
)
execute_result = cast(tuple[Any, "Call"], execute_result)
res, call = execute_result
return res, call


async def do_call_async(
op: Op,
*args: Any,
__weave: Optional[WeaveKwargs] = None,
__should_raise: bool = False,
**kwargs: Any,
) -> tuple[Any, "Call"]:
func = op.resolve_fn
call = _placeholder_call()
if settings.should_disable_weave():
res = await func(*args, **kwargs)
elif weave_client_context.get_weave_client() is None:
res = await func(*args, **kwargs)
elif not op._tracing_enabled:
res = await func(*args, **kwargs)
else:
try:
# This try/except allows us to fail gracefully and
# still let the user code continue to execute
call = _create_call(op, *args, __weave=__weave, **kwargs)
except Exception as e:
if get_raise_on_captured_errors():
raise
log_once(
logger.error,
ASYNC_CALL_CREATE_MSG.format(traceback.format_exc()),
)
res = await func(*args, **kwargs)
else:
execute_result = _execute_call(
op, call, *args, __should_raise=__should_raise, **kwargs
)
if not inspect.iscoroutine(execute_result):
raise Exception(
"Internal error: Expected `_execute_call` to return a coroutine"
)
execute_result = cast(
Coroutine[Any, Any, tuple[Any, "Call"]], execute_result
)
res, call = await execute_result
return res, call


def calls(op: Op) -> "CallsIter":
Expand Down Expand Up @@ -453,51 +560,17 @@ def create_wrapper(func: Callable) -> Op:

@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
__weave: Optional[WeaveKwargs] = kwargs.pop(WEAVE_KWARGS_KEY, None)
if settings.should_disable_weave():
return await func(*args, **kwargs)
if weave_client_context.get_weave_client() is None:
return await func(*args, **kwargs)
if not wrapper._tracing_enabled: # type: ignore
return await func(*args, **kwargs)
try:
# This try/except allows us to fail gracefully and
# still let the user code continue to execute
call = _create_call(wrapper, *args, __weave=__weave, **kwargs) # type: ignore
except Exception as e:
if get_raise_on_captured_errors():
raise
log_once(
logger.error,
ASYNC_CALL_CREATE_MSG.format(traceback.format_exc()),
)
return await func(*args, **kwargs)
res, _ = await _execute_call(wrapper, call, *args, **kwargs) # type: ignore
res, _ = await do_call_async(
cast(Op, wrapper), *args, __should_raise=True, **kwargs
)
return res
else:

@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
__weave: Optional[WeaveKwargs] = kwargs.pop(WEAVE_KWARGS_KEY, None)
if settings.should_disable_weave():
return func(*args, **kwargs)
if weave_client_context.get_weave_client() is None:
return func(*args, **kwargs)
if not wrapper._tracing_enabled: # type: ignore
return func(*args, **kwargs)
try:
# This try/except allows us to fail gracefully and
# still let the user code continue to execute

call = _create_call(wrapper, *args, __weave=__weave, **kwargs) # type: ignore
except Exception as e:
if get_raise_on_captured_errors():
raise
log_once(
logger.error, CALL_CREATE_MSG.format(traceback.format_exc())
)
return func(*args, **kwargs)
res, _ = _execute_call(wrapper, call, *args, **kwargs) # type: ignore
res, _ = do_call(
cast(Op, wrapper), *args, __should_raise=True, **kwargs
)
return res

# Tack these helpers on to our wrapper
Expand Down
3 changes: 2 additions & 1 deletion weave/trace/weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,8 @@ def finish_call(
*,
op: Optional[Op] = None,
) -> None:
ended_at = datetime.datetime.now(tz=datetime.timezone.utc)
call.ended_at = ended_at
original_output = output

if op is not None and op.postprocess_output:
Expand Down Expand Up @@ -771,7 +773,6 @@ def finish_call(
call.exception = exception_str

project_id = self._project_id()
ended_at = datetime.datetime.now(tz=datetime.timezone.utc)

# The finish handler serves as a last chance for integrations
# to customize what gets logged for a call.
Expand Down

0 comments on commit bce97e9

Please sign in to comment.