From e8b1f406baec7802d3e5c5010f5079bba91d5d40 Mon Sep 17 00:00:00 2001 From: Tim Sweeney Date: Fri, 11 Oct 2024 16:16:38 -0700 Subject: [PATCH] chore: Test all coroutine patterns and refactor op call (#2684) * tests! * fixed it up * small fix * comments * comments * comments --- tests/trace/test_op_coroutines.py | 220 +++++++++++++++++++++ tests/trace/test_op_decorator_behaviour.py | 2 +- weave/trace/op.py | 173 +++++++++++----- weave/trace/weave_client.py | 3 +- 4 files changed, 348 insertions(+), 50 deletions(-) create mode 100644 tests/trace/test_op_coroutines.py diff --git a/tests/trace/test_op_coroutines.py b/tests/trace/test_op_coroutines.py new file mode 100644 index 00000000000..6117a381882 --- /dev/null +++ b/tests/trace/test_op_coroutines.py @@ -0,0 +1,220 @@ +import asyncio +from typing import Coroutine + +import pytest + +import weave +from weave.trace.weave_client import Call + + +def test_sync_val(client): + @weave.op() + def sync_val(): + return 1 + + res = sync_val() + assert res == 1 + res, call = sync_val.call() + assert isinstance(call, Call) + assert res == 1 + + +def test_sync_val_method(client): + class TestClass: + @weave.op() + def sync_val(self): + return 1 + + test_inst = TestClass() + res = test_inst.sync_val() + assert res == 1 + res, call = test_inst.sync_val.call(test_inst) + assert isinstance(call, Call) + assert res == 1 + + +@pytest.mark.asyncio +async def test_sync_coro(client): + @weave.op() + def sync_coro(): + return asyncio.to_thread(lambda: 1) + + res = sync_coro() + assert isinstance(res, Coroutine) + assert await res == 1 + res, call = sync_coro.call() + assert isinstance(call, Call) + assert isinstance(res, Coroutine) + assert await res == 1 + + +@pytest.mark.asyncio +async def test_sync_coro_method(client): + class TestClass: + @weave.op() + def sync_coro(self): + return asyncio.to_thread(lambda: 1) + + test_inst = TestClass() + res = test_inst.sync_coro() + assert isinstance(res, Coroutine) + assert await res == 1 + res, call = test_inst.sync_coro.call(test_inst) + assert isinstance(call, Call) + assert isinstance(res, Coroutine) + assert await res == 1 + + +@pytest.mark.asyncio +async def test_async_coro(client): + @weave.op() + async def async_coro(): + return asyncio.to_thread(lambda: 1) + + res = async_coro() + assert isinstance(res, Coroutine) + res2 = await res + assert isinstance(res2, Coroutine) + assert await res2 == 1 + res, call = await async_coro.call() + assert isinstance(call, Call) + assert isinstance(res, Coroutine) + assert await res == 1 + + +@pytest.mark.asyncio +async def test_async_coro_method(client): + class TestClass: + @weave.op() + async def async_coro(self): + return asyncio.to_thread(lambda: 1) + + test_inst = TestClass() + + res = test_inst.async_coro() + assert isinstance(res, Coroutine) + res2 = await res + assert isinstance(res2, Coroutine) + assert await res2 == 1 + res, call = await test_inst.async_coro.call(test_inst) + assert isinstance(call, Call) + assert isinstance(res, Coroutine) + assert await res == 1 + + +@pytest.mark.asyncio +async def test_async_awaited_coro(client): + @weave.op() + async def async_awaited_coro(): + return await asyncio.to_thread(lambda: 1) + + res = async_awaited_coro() + assert isinstance(res, Coroutine) + assert await res == 1 + res, call = await async_awaited_coro.call() + assert isinstance(call, Call) + assert res == 1 + + +@pytest.mark.asyncio +async def test_async_awaited_coro_method(client): + class TestClass: + @weave.op() + async def async_awaited_coro(self): + return await asyncio.to_thread(lambda: 1) + + test_inst = TestClass() + res = test_inst.async_awaited_coro() + assert isinstance(res, Coroutine) + assert await res == 1 + res, call = await test_inst.async_awaited_coro.call(test_inst) + assert isinstance(call, Call) + assert res == 1 + + +@pytest.mark.asyncio +async def test_async_val(client): + @weave.op() + async def async_val(): + return 1 + + res = async_val() + assert isinstance(res, Coroutine) + assert await res == 1 + res, call = await async_val.call() + assert isinstance(call, Call) + assert res == 1 + + +@pytest.mark.asyncio +async def test_async_val_method(client): + class TestClass: + @weave.op() + async def async_val(self): + return 1 + + test_inst = TestClass() + res = test_inst.async_val() + assert isinstance(res, Coroutine) + assert await res == 1 + res, call = await test_inst.async_val.call(test_inst) + assert isinstance(call, Call) + assert res == 1 + + +def test_sync_with_exception(client): + @weave.op() + def sync_with_exception(): + raise ValueError("test") + + with pytest.raises(ValueError, match="test"): + sync_with_exception() + res, call = sync_with_exception.call() + assert isinstance(call, Call) + assert call.exception is not None + assert res is None + + +def test_sync_with_exception_method(client): + class TestClass: + @weave.op() + def sync_with_exception(self): + raise ValueError("test") + + test_inst = TestClass() + with pytest.raises(ValueError, match="test"): + test_inst.sync_with_exception() + res, call = test_inst.sync_with_exception.call(test_inst) + assert isinstance(call, Call) + assert call.exception is not None + assert res is None + + +@pytest.mark.asyncio +async def test_async_with_exception(client): + @weave.op() + async def async_with_exception(): + raise ValueError("test") + + with pytest.raises(ValueError, match="test"): + await async_with_exception() + res, call = await async_with_exception.call() + assert isinstance(call, Call) + assert call.exception is not None + assert res is None + + +@pytest.mark.asyncio +async def test_async_with_exception_method(client): + class TestClass: + @weave.op() + async def async_with_exception(self): + raise ValueError("test") + + test_inst = TestClass() + with pytest.raises(ValueError, match="test"): + await test_inst.async_with_exception() + res, call = await test_inst.async_with_exception.call(test_inst) + assert isinstance(call, Call) + assert call.exception is not None + assert res is None diff --git a/tests/trace/test_op_decorator_behaviour.py b/tests/trace/test_op_decorator_behaviour.py index 6e788fac150..03d9b6991a2 100644 --- a/tests/trace/test_op_decorator_behaviour.py +++ b/tests/trace/test_op_decorator_behaviour.py @@ -141,7 +141,7 @@ def test_sync_method_call(client, weave_obj, py_obj): weave_obj_method2 = weave_obj_method_ref.get() with pytest.raises(errors.OpCallError): - res2, call2 = py_obj.amethod.call(1) + res2, call2 = py_obj.method.call(1) @pytest.mark.asyncio diff --git a/weave/trace/op.py b/weave/trace/op.py index 1130ffd9fa4..7614b1d8630 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -241,7 +241,7 @@ def _execute_call( *args: Any, __should_raise: bool = True, **kwargs: Any, -) -> Any: +) -> Union[tuple[Any, "Call"], Coroutine[Any, Any, tuple[Any, "Call"]]]: func = __op.resolve_fn client = weave_client_context.require_weave_client() has_finished = False @@ -266,7 +266,7 @@ def on_output(output: Any) -> Any: finish(output) return output - def process(res: Any) -> Any: + def process(res: Any) -> tuple[Any, "Call"]: res = box.box(res) try: # Here we do a try/catch because we don't want to @@ -285,7 +285,7 @@ def process(res: Any) -> Any: return res, call - def handle_exception(e: Exception) -> Any: + def handle_exception(e: Exception) -> tuple[Any, "Call"]: finish(exception=e) if __should_raise: raise @@ -293,7 +293,7 @@ def handle_exception(e: Exception) -> Any: if inspect.iscoroutinefunction(func): - async def _call_async() -> Coroutine[Any, Any, Any]: + async def _call_async() -> tuple[Any, "Call"]: try: res = await func(*args, **kwargs) except Exception as e: @@ -314,8 +314,12 @@ async def _call_async() -> Coroutine[Any, Any, Any]: def call( - op: Op, *args: Any, __weave: Optional[WeaveKwargs] = None, **kwargs: Any -) -> tuple[Any, "Call"]: + op: Op, + *args: Any, + __weave: Optional[WeaveKwargs] = None, + __should_raise: bool = False, + **kwargs: Any, +) -> Union[tuple[Any, "Call"], Coroutine[Any, Any, tuple[Any, "Call"]]]: """ Executes the op and returns both the result and a Call representing the execution. @@ -332,8 +336,115 @@ def add(a: int, b: int) -> int: result, call = add.call(1, 2) ``` """ - c = _create_call(op, *args, __weave=__weave, **kwargs) - return _execute_call(op, c, *args, __should_raise=False, **kwargs) + if inspect.iscoroutinefunction(op.resolve_fn): + return _do_call_async( + op, *args, __weave=__weave, __should_raise=__should_raise, **kwargs + ) + else: + return _do_call( + op, *args, __weave=__weave, __should_raise=__should_raise, **kwargs + ) + + +def _placeholder_call() -> "Call": + # Import here to avoid circular dependency + from weave.trace.weave_client import Call + + return Call( + _op_name="", + trace_id="", + project_id="", + parent_id=None, + inputs={}, + ) + + +def _do_call( + op: Op, + *args: Any, + __weave: Optional[WeaveKwargs] = None, + __should_raise: bool = False, + **kwargs: Any, +) -> tuple[Any, "Call"]: + func = op.resolve_fn + call = _placeholder_call() + if settings.should_disable_weave(): + res = func(*args, **kwargs) + elif weave_client_context.get_weave_client() is None: + res = func(*args, **kwargs) + elif not op._tracing_enabled: + res = func(*args, **kwargs) + else: + try: + # This try/except allows us to fail gracefully and + # still let the user code continue to execute + call = _create_call(op, *args, __weave=__weave, **kwargs) + except OpCallError as e: + raise e + except Exception as e: + if get_raise_on_captured_errors(): + raise + log_once( + logger.error, + CALL_CREATE_MSG.format(traceback.format_exc()), + ) + res = func(*args, **kwargs) + else: + execute_result = _execute_call( + op, call, *args, __should_raise=__should_raise, **kwargs + ) + if inspect.iscoroutine(execute_result): + raise Exception( + "Internal error: Expected `_execute_call` to return a sync result" + ) + execute_result = cast(tuple[Any, "Call"], execute_result) + res, call = execute_result + return res, call + + +async def _do_call_async( + op: Op, + *args: Any, + __weave: Optional[WeaveKwargs] = None, + __should_raise: bool = False, + **kwargs: Any, +) -> tuple[Any, "Call"]: + func = op.resolve_fn + call = _placeholder_call() + if settings.should_disable_weave(): + res = await func(*args, **kwargs) + elif weave_client_context.get_weave_client() is None: + res = await func(*args, **kwargs) + elif not op._tracing_enabled: + res = await func(*args, **kwargs) + else: + try: + # This try/except allows us to fail gracefully and + # still let the user code continue to execute + call = _create_call(op, *args, __weave=__weave, **kwargs) + except OpCallError as e: + raise e + except Exception as e: + if get_raise_on_captured_errors(): + raise + log_once( + logger.error, + ASYNC_CALL_CREATE_MSG.format(traceback.format_exc()), + ) + res = await func(*args, **kwargs) + else: + execute_result = _execute_call( + op, call, *args, __should_raise=__should_raise, **kwargs + ) + if not inspect.iscoroutine(execute_result): + raise Exception( + "Internal error: Expected `_execute_call` to return a coroutine" + ) + execute_result = cast( + Coroutine[Any, Any, tuple[Any, "Call"]], execute_result + ) + res, call = await execute_result + return res, call def calls(op: Op) -> "CallsIter": @@ -453,51 +564,17 @@ def create_wrapper(func: Callable) -> Op: @wraps(func) async def wrapper(*args: Any, **kwargs: Any) -> Any: - __weave: Optional[WeaveKwargs] = kwargs.pop(WEAVE_KWARGS_KEY, None) - if settings.should_disable_weave(): - return await func(*args, **kwargs) - if weave_client_context.get_weave_client() is None: - return await func(*args, **kwargs) - if not wrapper._tracing_enabled: # type: ignore - return await func(*args, **kwargs) - try: - # This try/except allows us to fail gracefully and - # still let the user code continue to execute - call = _create_call(wrapper, *args, __weave=__weave, **kwargs) # type: ignore - except Exception as e: - if get_raise_on_captured_errors(): - raise - log_once( - logger.error, - ASYNC_CALL_CREATE_MSG.format(traceback.format_exc()), - ) - return await func(*args, **kwargs) - res, _ = await _execute_call(wrapper, call, *args, **kwargs) # type: ignore + res, _ = await _do_call_async( + cast(Op, wrapper), *args, __should_raise=True, **kwargs + ) return res else: @wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: - __weave: Optional[WeaveKwargs] = kwargs.pop(WEAVE_KWARGS_KEY, None) - if settings.should_disable_weave(): - return func(*args, **kwargs) - if weave_client_context.get_weave_client() is None: - return func(*args, **kwargs) - if not wrapper._tracing_enabled: # type: ignore - return func(*args, **kwargs) - try: - # This try/except allows us to fail gracefully and - # still let the user code continue to execute - - call = _create_call(wrapper, *args, __weave=__weave, **kwargs) # type: ignore - except Exception as e: - if get_raise_on_captured_errors(): - raise - log_once( - logger.error, CALL_CREATE_MSG.format(traceback.format_exc()) - ) - return func(*args, **kwargs) - res, _ = _execute_call(wrapper, call, *args, **kwargs) # type: ignore + res, _ = _do_call( + cast(Op, wrapper), *args, __should_raise=True, **kwargs + ) return res # Tack these helpers on to our wrapper diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index 55d6053b560..ac6c17f9994 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -715,6 +715,8 @@ def finish_call( *, op: Optional[Op] = None, ) -> None: + ended_at = datetime.datetime.now(tz=datetime.timezone.utc) + call.ended_at = ended_at original_output = output if op is not None and op.postprocess_output: @@ -771,7 +773,6 @@ def finish_call( call.exception = exception_str project_id = self._project_id() - ended_at = datetime.datetime.now(tz=datetime.timezone.utc) # The finish handler serves as a last chance for integrations # to customize what gets logged for a call.