diff --git a/weave/trace/op.py b/weave/trace/op.py index 7a6b4677e1de..4678a863ea92 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -7,6 +7,7 @@ import sys import traceback from collections.abc import Coroutine, Mapping +from contextlib import contextmanager from dataclasses import dataclass from functools import partial, wraps from types import MethodType @@ -519,6 +520,8 @@ def _is_context_manager(obj: Any) -> bool: __enter__ = getattr(obj, "__enter__", None) __exit__ = getattr(obj, "__exit__", None) + print(__enter__, __exit__) + return callable(__enter__) and callable(__exit__) @@ -528,6 +531,7 @@ def _execute_op( __weave: WeaveKwargs | None = None, __should_raise: bool = True, __should_accumulate: bool = False, + __should_use_contextmanager: bool = False, **kwargs: Any, ) -> tuple[Any, Call]: func = __op.resolve_fn @@ -559,38 +563,79 @@ def _execute_op( display_name=call_time_display_name or __op.call_display_name, attributes=attributes, ) - __op.lifecycle_handler.run_before_call({}, None, None, "") + __op.lifecycle_handler.run_before_call(inputs, parent_call, attributes) should_accumulate = __should_accumulate and __should_accumulate(call) + should_use_contextmanager = ( + __should_use_contextmanager and __should_use_contextmanager(func) + ) + if is_generator or should_accumulate: + # TODO: Not sure if this is the right pattern. This hack is only + # used to support the Anthropic streaming context streaming case atm... + if should_use_contextmanager: - @wraps(func) - def _wrapped_sync_generator(): - try: - for val in func(*args, **kwargs): - __op.lifecycle_handler.run_before_yield(call, val) - yield val - __op.lifecycle_handler.run_after_yield(call, val) - except Exception as e: - exception = e - if __should_raise: - raise - else: + @contextmanager + @wraps(func) + def _wrapped_context_manager_yields_sync_generator(): exception = None - __op.lifecycle_handler.run_after_yield_all(call) - finally: + with func(*args, **kwargs) as original_gen: + + def _wrapped_sync_generator(): + nonlocal exception + try: + for val in original_gen: + __op.lifecycle_handler.run_before_yield(call, val) + yield val + __op.lifecycle_handler.run_after_yield(call, val) + except Exception as e: + exception = e + if __should_raise: + raise + else: + __op.lifecycle_handler.run_after_yield_all(call) + + yield _wrapped_sync_generator() + if __op.lifecycle_handler.has_finished: raise OpCallError("Should not call finish more than once") boxed_output = box.box(call.output) client.finish_call(call, boxed_output, exception=exception, op=__op) - # TODO: We choose to print the call link at the end, but for streaming cases it might be helpful to: - # 1. Print the link at the beginning of the stream - # 2. (maybe?): Update the call object periodically with the latest output if not call_context.get_current_call(): print_call_link(call) call_context.pop_call(call.id) - res = _wrapped_sync_generator() + res = _wrapped_context_manager_yields_sync_generator() + + else: + + @wraps(func) + def _wrapped_sync_generator(): + try: + for val in func(*args, **kwargs): + __op.lifecycle_handler.run_before_yield(call, val) + yield val + __op.lifecycle_handler.run_after_yield(call, val) + except Exception as e: + exception = e + if __should_raise: + raise + else: + exception = None + __op.lifecycle_handler.run_after_yield_all(call) + finally: + if __op.lifecycle_handler.has_finished: + raise OpCallError("Should not call finish more than once") + boxed_output = box.box(call.output) + client.finish_call(call, boxed_output, exception=exception, op=__op) + # TODO: We choose to print the call link at the end, but for streaming cases it might be helpful to: + # 1. Print the link at the beginning of the stream + # 2. (maybe?): Update the call object periodically with the latest output + if not call_context.get_current_call(): + print_call_link(call) + call_context.pop_call(call.id) + + res = _wrapped_sync_generator() else: # Regular sync function try: @@ -650,8 +695,6 @@ def op( postprocess_output: PostprocessOutputFunc | None = None, callbacks: list[Callback] | None = None, reducers: list[Reducer] | None = None, - __should_accumulate: Callable[[Call], bool] - | None = None, # escape hatch for integrations ) -> Op: ... @@ -664,8 +707,6 @@ def op( postprocess_output: PostprocessOutputFunc | None = None, callbacks: list[Callback] | None = None, reducers: list[Reducer] | None = None, - __should_accumulate: Callable[[Call], bool] - | None = None, # escape hatch for integrations ) -> Callable[[Callable], Op]: ... @@ -678,8 +719,9 @@ def op( postprocess_output: PostprocessOutputFunc | None = None, callbacks: list[Callback] | None = None, reducers: list[Reducer] | None = None, - __should_accumulate: Callable[[Call], bool] - | None = None, # escape hatch for integrations + # escape hatch for integrations + __should_accumulate: Callable[[Call], bool] | None = None, + __should_use_contextmanager: Callable[[Callable], bool] | None = None, ) -> Callable[[Callable], Op] | Op: """ A decorator to weave op-ify a function or method. Works for both sync and async. @@ -768,6 +810,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: *args, __should_raise=True, __should_accumulate=__should_accumulate, + __should_use_contextmanager=__should_use_contextmanager, **kwargs, ) return res