From ad75ab9151aba90b3caf779657e60afa00dd556c Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Tue, 19 Nov 2024 12:25:54 -0500 Subject: [PATCH 01/17] add basic handler --- weave/trace/op.py | 53 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/weave/trace/op.py b/weave/trace/op.py index 75bea0b9dd74..24d83941202b 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -138,6 +138,50 @@ class WeaveKwargs(TypedDict): display_name: str | None +class Callback(Protocol): + def before_call(self, call: Call) -> None: ... + def before_yield(self, call: Call, value: Any) -> Any: ... + def after_yield(self, call: Call, value: Any) -> Any: ... + def after_call(self, call: Call) -> None: ... + def after_error(self, call: Call, exc: Exception) -> None: ... + + +class LifecycleHandler: + def __init__(self, callbacks: list[Callback] | None = None): + self.callbacks = callbacks or [] + self.intermediate_results: list[Any] = [] + + def add_callback(self, callback: Callback): + self.callbacks.append(callback) + + def run_before_call(self, call: Call) -> None: + for callback in self.callbacks: + if hasattr(callback, "before_call"): + callback.before_call(call) + + def run_before_yield(self, call: Call, value: Any) -> Any: + for callback in self.callbacks: + if hasattr(callback, "before_yield"): + value = callback.before_yield(call, value) + return value + + def run_after_yield(self, call: Call, value: Any) -> Any: + for callback in self.callbacks: + if hasattr(callback, "after_yield"): + value = callback.after_yield(call, value) + return value + + def run_after_call(self, call: Call) -> None: + for callback in self.callbacks: + if hasattr(callback, "after_call"): + callback.after_call(call) + + def run_after_error(self, call: Call, exc: Exception) -> None: + for callback in self.callbacks: + if hasattr(callback, "after_error"): + callback.after_error(call, exc) + + @runtime_checkable class Op(Protocol): """ @@ -169,6 +213,8 @@ class Op(Protocol): call: Callable[..., Any] calls: Callable[..., CallsIter] + lifecycle_handler: LifecycleHandler + _set_on_input_handler: Callable[[OnInputHandlerType], None] _on_input_handler: OnInputHandlerType | None @@ -312,7 +358,7 @@ def process(res: Any) -> tuple[Any, Call]: # break the user process if we trip up on processing # the output res = on_output(res) - except Exception as e: + except Exception: if get_raise_on_captured_errors(): raise log_once(logger.error, ON_OUTPUT_MSG.format(traceback.format_exc())) @@ -527,6 +573,7 @@ def op( call_display_name: str | CallDisplayNameFunc | None = None, postprocess_inputs: PostprocessInputsFunc | None = None, postprocess_output: PostprocessOutputFunc | None = None, + callbacks: list[Callback] | None = None, ) -> Op: ... @@ -537,6 +584,7 @@ def op( call_display_name: str | CallDisplayNameFunc | None = None, postprocess_inputs: PostprocessInputsFunc | None = None, postprocess_output: PostprocessOutputFunc | None = None, + callbacks: list[Callback] | None = None, ) -> Callable[[Callable], Op]: ... @@ -547,6 +595,7 @@ def op( call_display_name: str | CallDisplayNameFunc | None = None, postprocess_inputs: PostprocessInputsFunc | None = None, postprocess_output: PostprocessOutputFunc | None = None, + callbacks: list[Callback] | None = None, ) -> Callable[[Callable], Op] | Op: """ A decorator to weave op-ify a function or method. Works for both sync and async. @@ -637,6 +686,8 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: wrapper.signature = sig # type: ignore wrapper.ref = None # type: ignore + wrapper.lifecycle_handler = LifecycleHandler(callbacks) + wrapper.postprocess_inputs = postprocess_inputs # type: ignore wrapper.postprocess_output = postprocess_output # type: ignore From 30b37ef1791ee225915548e44008c27e4178fa05 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Tue, 19 Nov 2024 20:01:45 -0500 Subject: [PATCH 02/17] basic sync generator support --- weave/trace/op.py | 131 ++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 116 insertions(+), 15 deletions(-) diff --git a/weave/trace/op.py b/weave/trace/op.py index 24d83941202b..6da7d4a16d74 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -139,9 +139,16 @@ class WeaveKwargs(TypedDict): class Callback(Protocol): - def before_call(self, call: Call) -> None: ... - def before_yield(self, call: Call, value: Any) -> Any: ... - def after_yield(self, call: Call, value: Any) -> Any: ... + def before_call( + self, + inputs: dict, + parent: Call | None, + attributes: dict | None, + display_name: str | Callable[[Call], str], + ) -> None: ... + def before_yield(self, call: Call, value: Any) -> None: ... + def after_yield(self, call: Call, value: Any) -> None: ... + def after_yield_all(self, call: Call) -> None: ... def after_call(self, call: Call) -> None: ... def after_error(self, call: Call, exc: Exception) -> None: ... @@ -151,35 +158,69 @@ def __init__(self, callbacks: list[Callback] | None = None): self.callbacks = callbacks or [] self.intermediate_results: list[Any] = [] - def add_callback(self, callback: Callback): + def add_callback(self, callback: Callback) -> None: self.callbacks.append(callback) - def run_before_call(self, call: Call) -> None: + def run_before_call( + self, + inputs: dict, + parent: Call | None, + attributes: dict | None, + display_name: str | Callable[[Call], str], + ) -> None: for callback in self.callbacks: if hasattr(callback, "before_call"): - callback.before_call(call) + callback.before_call(inputs, parent, attributes, display_name) def run_before_yield(self, call: Call, value: Any) -> Any: for callback in self.callbacks: if hasattr(callback, "before_yield"): - value = callback.before_yield(call, value) - return value + try: + callback.before_yield(call, value) + except Exception: + logger.exception( + f"Error in before_yield callback:\n{traceback.format_exc()}" + ) def run_after_yield(self, call: Call, value: Any) -> Any: for callback in self.callbacks: if hasattr(callback, "after_yield"): - value = callback.after_yield(call, value) - return value + try: + callback.after_yield(call, value) + except Exception: + logger.exception( + f"Error in after_yield callback:\n{traceback.format_exc()}" + ) + + def run_after_yield_all(self, call: Call) -> None: + for callback in self.callbacks: + if hasattr(callback, "after_yield_all"): + try: + callback.after_yield_all(call) + except Exception: + logger.exception( + f"Error in after_yield_all callback:\n{traceback.format_exc()}" + ) def run_after_call(self, call: Call) -> None: for callback in self.callbacks: if hasattr(callback, "after_call"): - callback.after_call(call) + try: + callback.after_call(call) + except Exception: + logger.exception( + f"Error in after_call callback:\n{traceback.format_exc()}" + ) def run_after_error(self, call: Call, exc: Exception) -> None: for callback in self.callbacks: if hasattr(callback, "after_error"): - callback.after_error(call, exc) + try: + callback.after_error(call, exc) + except Exception: + logger.exception( + f"Error in after_error callback:\n{traceback.format_exc()}" + ) @runtime_checkable @@ -653,6 +694,8 @@ def op_deco(func: Callable) -> Op: sig = inspect.signature(func) is_method = _is_unbound_method(func) is_async = inspect.iscoroutinefunction(func) + is_generator = inspect.isgeneratorfunction(func) + is_async_generator = inspect.isasyncgenfunction(func) def create_wrapper(func: Callable) -> Op: if is_async: @@ -667,9 +710,67 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: # pyright: ignore[reportRe @wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: - res, _ = _do_call( - cast(Op, wrapper), *args, __should_raise=True, **kwargs - ) + def _wrapper(__op: Op) -> Any: + # This exists only so we can cast the wrapper to an Op + client = weave_client_context.require_weave_client() + parent_call = call_context.get_current_call() + attributes = call_attributes.get() + + # Instead of creating a call inside here, it should be a dummy + # first before optionally getting passed in + call = client.create_call( + __op, + {"testing": "dict"}, + parent_call, + display_name="testing", + attributes=attributes, + ) + __op.lifecycle_handler.run_before_call({}, None, None, "") + if is_generator: + + def _wrapper_generator(): + for val in func(*args, **kwargs): + __op.lifecycle_handler.run_before_yield(call, val) + yield val + __op.lifecycle_handler.run_after_yield(call, val) + __op.lifecycle_handler.run_after_yield_all(call) + # box the result (but how do you do it here?) + # call on_output + # call the on_output_handler (but this is just after_call?) + # call finish + client.finish_call( + call, + box.box(call.output), + None, + op=__op, + ) + if not call_context.get_current_call(): + print_call_link(call) + + call_context.pop_call(call.id) + + res = _wrapper_generator() + __op.lifecycle_handler.run_after_call(call) + else: + res = func(*args, **kwargs) + __op.lifecycle_handler.run_after_call(call) + + client.finish_call( + call, + box.box(res), + None, + op=__op, + ) + if not call_context.get_current_call(): + print_call_link(call) + # box the result + # call on_output + # call the on_output_handler + # call finish + call_context.pop_call(call.id) + return res, call + + res, _ = _wrapper(cast(Op, wrapper)) return res # Tack these helpers on to our wrapper From 38a7796b4b9e95d786500c6f82c81e1451dd0b20 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Tue, 19 Nov 2024 20:57:05 -0500 Subject: [PATCH 03/17] add basic reducer feature --- weave/trace/op.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/weave/trace/op.py b/weave/trace/op.py index 6da7d4a16d74..4ace80e65dbc 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -14,9 +14,11 @@ TYPE_CHECKING, Any, Callable, + Generic, Optional, Protocol, TypedDict, + TypeVar, cast, overload, runtime_checkable, @@ -153,6 +155,34 @@ def after_call(self, call: Call) -> None: ... def after_error(self, call: Call, exc: Exception) -> None: ... +T = TypeVar("T") +Acc = TypeVar("Acc") + + +class ReducerFunc(Protocol, Generic[T, Acc]): + """Any function that implements this can be automatically converted into a reducer callback.""" + + def __call__(self, val: T, acc: Acc) -> Acc: ... + + +@dataclass +class Reducer(Generic[T, Acc]): + func: ReducerFunc[T, Acc] + initial_acc: Acc + + +class ReducerCallback(Generic[T, Acc]): + def __init__(self, reducer: Reducer[T, Acc]): + self.func = reducer.func + self.acc = reducer.initial_acc + + def after_yield(self, call, val): + self.acc = self.func(val, self.acc) + + def after_yield_all(self, call): + call.output = self.acc + + class LifecycleHandler: def __init__(self, callbacks: list[Callback] | None = None): self.callbacks = callbacks or [] @@ -615,6 +645,7 @@ def op( postprocess_inputs: PostprocessInputsFunc | None = None, postprocess_output: PostprocessOutputFunc | None = None, callbacks: list[Callback] | None = None, + reducers: list[Reducer] | None = None, ) -> Op: ... @@ -626,6 +657,7 @@ def op( postprocess_inputs: PostprocessInputsFunc | None = None, postprocess_output: PostprocessOutputFunc | None = None, callbacks: list[Callback] | None = None, + reducers: list[Reducer] | None = None, ) -> Callable[[Callable], Op]: ... @@ -637,6 +669,7 @@ def op( postprocess_inputs: PostprocessInputsFunc | None = None, postprocess_output: PostprocessOutputFunc | None = None, callbacks: list[Callback] | None = None, + reducers: list[Reducer] | None = None, ) -> Callable[[Callable], Op] | Op: """ A decorator to weave op-ify a function or method. Works for both sync and async. @@ -691,6 +724,7 @@ async def extract(): def op_deco(func: Callable) -> Op: # Check function type + sig = inspect.signature(func) is_method = _is_unbound_method(func) is_async = inspect.iscoroutinefunction(func) @@ -787,6 +821,11 @@ def _wrapper_generator(): wrapper.signature = sig # type: ignore wrapper.ref = None # type: ignore + nonlocal reducers + nonlocal callbacks + reducers = reducers or [] + callbacks = callbacks or [] + callbacks += [ReducerCallback(reducer) for reducer in reducers] wrapper.lifecycle_handler = LifecycleHandler(callbacks) wrapper.postprocess_inputs = postprocess_inputs # type: ignore From e556722212d615e9b8aeba8afc45cc1883451958 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Tue, 19 Nov 2024 21:45:08 -0500 Subject: [PATCH 04/17] simplify --- weave/trace/op.py | 96 ++++++++++++++++++++++++++++++----------------- 1 file changed, 62 insertions(+), 34 deletions(-) diff --git a/weave/trace/op.py b/weave/trace/op.py index 4ace80e65dbc..a7199c707cfe 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -187,6 +187,7 @@ class LifecycleHandler: def __init__(self, callbacks: list[Callback] | None = None): self.callbacks = callbacks or [] self.intermediate_results: list[Any] = [] + self.has_finished = False def add_callback(self, callback: Callback) -> None: self.callbacks.append(callback) @@ -763,48 +764,75 @@ def _wrapper(__op: Op) -> Any: if is_generator: def _wrapper_generator(): - for val in func(*args, **kwargs): - __op.lifecycle_handler.run_before_yield(call, val) - yield val - __op.lifecycle_handler.run_after_yield(call, val) - __op.lifecycle_handler.run_after_yield_all(call) - # box the result (but how do you do it here?) - # call on_output - # call the on_output_handler (but this is just after_call?) - # call finish + try: + for val in func(*args, **kwargs): + __op.lifecycle_handler.run_before_yield( + call, val + ) + yield val + __op.lifecycle_handler.run_after_yield( + call, val + ) + except Exception as e: + exception = e + if __op.lifecycle_handler.has_finished: + raise ValueError( + "Should not call finish more than once" + ) + __should_raise = True + if __should_raise: + raise + else: + exception = None + __op.lifecycle_handler.run_after_yield_all(call) + finally: + # box the result (but how do you do it here?) + # call on_output + # call the on_output_handler (but this is just after_call?) + # call finish + client.finish_call( + call, + box.box(call.output), + exception=exception, + op=__op, + ) + if not call_context.get_current_call(): + print_call_link(call) + call_context.pop_call(call.id) + + # TODO: may need to wrap this too? + res = _wrapper_generator() + __op.lifecycle_handler.run_after_call(call) + return res, call + else: + try: + res = func(*args, **kwargs) + except Exception as e: + exception = e + res = None + if __op.lifecycle_handler.has_finished: + raise ValueError( + "Should not call finish more than once" + ) + __should_raise = True + if __should_raise: + raise + else: + exception = None + __op.lifecycle_handler.run_after_call(call) + finally: client.finish_call( call, - box.box(call.output), - None, + output=box.box(res), + exception=exception, op=__op, ) if not call_context.get_current_call(): print_call_link(call) - call_context.pop_call(call.id) + return res, call - res = _wrapper_generator() - __op.lifecycle_handler.run_after_call(call) - else: - res = func(*args, **kwargs) - __op.lifecycle_handler.run_after_call(call) - - client.finish_call( - call, - box.box(res), - None, - op=__op, - ) - if not call_context.get_current_call(): - print_call_link(call) - # box the result - # call on_output - # call the on_output_handler - # call finish - call_context.pop_call(call.id) - return res, call - - res, _ = _wrapper(cast(Op, wrapper)) + res, _ = _wrapper(as_op(wrapper)) return res # Tack these helpers on to our wrapper From cc1dbc143ec70e53a6c81d11cf22f41da507a5b9 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Tue, 19 Nov 2024 21:58:52 -0500 Subject: [PATCH 05/17] simplify and cleanup --- weave/trace/op.py | 42 ++++++++++++++++++++++++++---------------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/weave/trace/op.py b/weave/trace/op.py index a7199c707cfe..86e35b970dc0 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -159,22 +159,26 @@ def after_error(self, call: Call, exc: Exception) -> None: ... Acc = TypeVar("Acc") -class ReducerFunc(Protocol, Generic[T, Acc]): - """Any function that implements this can be automatically converted into a reducer callback.""" - - def __call__(self, val: T, acc: Acc) -> Acc: ... +class Reducer(Protocol, Generic[T, Acc]): + """Any function that implements this can be automatically converted into a reducer callback. + Note that `acc` must have a default value to serve as the initializer!""" -@dataclass -class Reducer(Generic[T, Acc]): - func: ReducerFunc[T, Acc] - initial_acc: Acc + def __call__(self, val: T, acc: Acc) -> Acc: ... class ReducerCallback(Generic[T, Acc]): def __init__(self, reducer: Reducer[T, Acc]): - self.func = reducer.func - self.acc = reducer.initial_acc + self.func = reducer + sig = inspect.signature(reducer) + if not (acc := sig.parameters.get("acc")): + raise ValueError("Reducer must have an 'acc' parameter") + if acc.default is inspect.Parameter.empty: + raise ValueError( + "Reducer 'acc' parameter must have a default value (the initial value)" + ) + + self.acc = acc.default def after_yield(self, call, val): self.acc = self.func(val, self.acc) @@ -751,19 +755,26 @@ def _wrapper(__op: Op) -> Any: parent_call = call_context.get_current_call() attributes = call_attributes.get() + __weave = None + call_time_display_name = ( + __weave.get("display_name") if __weave else None + ) + inputs = inspect.signature(func).bind(*args, **kwargs).arguments + # Instead of creating a call inside here, it should be a dummy # first before optionally getting passed in call = client.create_call( __op, - {"testing": "dict"}, + inputs, parent_call, - display_name="testing", + display_name=call_time_display_name + or __op.call_display_name, attributes=attributes, ) __op.lifecycle_handler.run_before_call({}, None, None, "") if is_generator: - def _wrapper_generator(): + def _generator(): try: for val in func(*args, **kwargs): __op.lifecycle_handler.run_before_yield( @@ -787,8 +798,7 @@ def _wrapper_generator(): __op.lifecycle_handler.run_after_yield_all(call) finally: # box the result (but how do you do it here?) - # call on_output - # call the on_output_handler (but this is just after_call?) + # call the on_output_handler (but this is just after_call, or in the generator case after_yield_all?) # call finish client.finish_call( call, @@ -801,7 +811,7 @@ def _wrapper_generator(): call_context.pop_call(call.id) # TODO: may need to wrap this too? - res = _wrapper_generator() + res = _generator() __op.lifecycle_handler.run_after_call(call) return res, call else: From 2ff4beaf3fe15a2c90c8e67c65989a081a2f327b Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 20 Nov 2024 00:16:47 -0500 Subject: [PATCH 06/17] streamline sync op execution --- weave/trace/op.py | 190 +++++++++++++++++++++++----------------------- 1 file changed, 94 insertions(+), 96 deletions(-) diff --git a/weave/trace/op.py b/weave/trace/op.py index 86e35b970dc0..709bd5b3148c 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -80,6 +80,9 @@ class DisplayNameFuncError(ValueError): ... +class OpExecutionError(Exception): ... + + def print_call_link(call: Call) -> None: if settings.should_print_call_link(): print(f"{TRACE_CALL_EMOJI} {call.ui_url}") @@ -286,8 +289,8 @@ class Op(Protocol): postprocess_inputs: Callable[[dict[str, Any]], dict[str, Any]] | None postprocess_output: Callable[..., Any] | None - call: Callable[..., Any] - calls: Callable[..., CallsIter] + call: Callable[..., Any] | Coroutine[Any, Any, Any] + calls: Callable[..., CallsIter] | Coroutine[Any, Any, CallsIter] lifecycle_handler: LifecycleHandler @@ -520,6 +523,92 @@ def _placeholder_call() -> Call: ) +def _execute_op( + __op: Op, + *args: Any, + __weave: WeaveKwargs | None = None, + __should_raise: bool = True, + **kwargs: Any, +) -> tuple[Any, Call]: + func = __op.resolve_fn + is_generator = inspect.isgeneratorfunction(func) + call = _placeholder_call() + + # TODO: Add back on_input_handler (maybe part of callback?) + + # Early returns for disabled cases -- no call is created + if settings.should_disable_weave(): + return func(*args, **kwargs), call + elif weave_client_context.get_weave_client() is None: + return func(*args, **kwargs), call + elif not __op._tracing_enabled: + return func(*args, **kwargs), call + + # Setup call context + client = weave_client_context.require_weave_client() + parent_call = call_context.get_current_call() + attributes = call_attributes.get() + + # Create the call + call_time_display_name = __weave.get("display_name") if __weave else None + inputs = inspect.signature(func).bind(*args, **kwargs).arguments + call = client.create_call( + __op, + inputs, + parent_call, + display_name=call_time_display_name or __op.call_display_name, + attributes=attributes, + ) + __op.lifecycle_handler.run_before_call({}, None, None, "") + + if is_generator: + + def _wrapped_sync_generator(): + try: + for val in func(*args, **kwargs): + __op.lifecycle_handler.run_before_yield(call, val) + yield val + __op.lifecycle_handler.run_after_yield(call, val) + except Exception as e: + exception = e + if __should_raise: + raise + else: + exception = None + __op.lifecycle_handler.run_after_yield_all(call) + finally: + if __op.lifecycle_handler.has_finished: + raise OpExecutionError("Should not call finish more than once") + boxed_output = box.box(call.output) + client.finish_call(call, boxed_output, exception=exception, op=__op) + if not call_context.get_current_call(): + print_call_link(call) + call_context.pop_call(call.id) + + res = _wrapped_sync_generator() + else: + # Regular sync function + try: + res = func(*args, **kwargs) + exception = None + except Exception as e: + res = None + exception = e + if __should_raise: + raise + finally: + if __op.lifecycle_handler.has_finished: + raise OpExecutionError("Should not call finish more than once") + boxed_output = box.box(res) + client.finish_call(call, boxed_output, exception=exception, op=__op) + if not call_context.get_current_call(): + print_call_link(call) + call_context.pop_call(call.id) + + __op.lifecycle_handler.run_after_call(call) + return res, call + + def _do_call( op: Op, *args: Any, @@ -749,100 +838,9 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: # pyright: ignore[reportRe @wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: - def _wrapper(__op: Op) -> Any: - # This exists only so we can cast the wrapper to an Op - client = weave_client_context.require_weave_client() - parent_call = call_context.get_current_call() - attributes = call_attributes.get() - - __weave = None - call_time_display_name = ( - __weave.get("display_name") if __weave else None - ) - inputs = inspect.signature(func).bind(*args, **kwargs).arguments - - # Instead of creating a call inside here, it should be a dummy - # first before optionally getting passed in - call = client.create_call( - __op, - inputs, - parent_call, - display_name=call_time_display_name - or __op.call_display_name, - attributes=attributes, - ) - __op.lifecycle_handler.run_before_call({}, None, None, "") - if is_generator: - - def _generator(): - try: - for val in func(*args, **kwargs): - __op.lifecycle_handler.run_before_yield( - call, val - ) - yield val - __op.lifecycle_handler.run_after_yield( - call, val - ) - except Exception as e: - exception = e - if __op.lifecycle_handler.has_finished: - raise ValueError( - "Should not call finish more than once" - ) - __should_raise = True - if __should_raise: - raise - else: - exception = None - __op.lifecycle_handler.run_after_yield_all(call) - finally: - # box the result (but how do you do it here?) - # call the on_output_handler (but this is just after_call, or in the generator case after_yield_all?) - # call finish - client.finish_call( - call, - box.box(call.output), - exception=exception, - op=__op, - ) - if not call_context.get_current_call(): - print_call_link(call) - call_context.pop_call(call.id) - - # TODO: may need to wrap this too? - res = _generator() - __op.lifecycle_handler.run_after_call(call) - return res, call - else: - try: - res = func(*args, **kwargs) - except Exception as e: - exception = e - res = None - if __op.lifecycle_handler.has_finished: - raise ValueError( - "Should not call finish more than once" - ) - __should_raise = True - if __should_raise: - raise - else: - exception = None - __op.lifecycle_handler.run_after_call(call) - finally: - client.finish_call( - call, - output=box.box(res), - exception=exception, - op=__op, - ) - if not call_context.get_current_call(): - print_call_link(call) - call_context.pop_call(call.id) - return res, call - - res, _ = _wrapper(as_op(wrapper)) + res, _ = _execute_op( + cast(Op, wrapper), *args, __should_raise=True, **kwargs + ) return res # Tack these helpers on to our wrapper From 86435d3ac7230c768c368accff63b0dbe377f3b4 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 20 Nov 2024 00:43:44 -0500 Subject: [PATCH 07/17] cleanup --- weave/trace/op.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/weave/trace/op.py b/weave/trace/op.py index 709bd5b3148c..fd2ef50cf1dd 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -80,9 +80,6 @@ class DisplayNameFuncError(ValueError): ... -class OpExecutionError(Exception): ... - - def print_call_link(call: Call) -> None: if settings.should_print_call_link(): print(f"{TRACE_CALL_EMOJI} {call.ui_url}") @@ -578,7 +575,7 @@ def _wrapped_sync_generator(): __op.lifecycle_handler.run_after_yield_all(call) finally: if __op.lifecycle_handler.has_finished: - raise OpExecutionError("Should not call finish more than once") + raise OpCallError("Should not call finish more than once") boxed_output = box.box(call.output) client.finish_call(call, boxed_output, exception=exception, op=__op) if not call_context.get_current_call(): @@ -598,7 +595,7 @@ def _wrapped_sync_generator(): raise finally: if __op.lifecycle_handler.has_finished: - raise OpExecutionError("Should not call finish more than once") + raise OpCallError("Should not call finish more than once") boxed_output = box.box(res) client.finish_call(call, boxed_output, exception=exception, op=__op) if not call_context.get_current_call(): From ec5b4c47928cef06c306ece608e3ffbd3d757451 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 20 Nov 2024 01:24:33 -0500 Subject: [PATCH 08/17] add basic async support --- weave/trace/op.py | 123 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 121 insertions(+), 2 deletions(-) diff --git a/weave/trace/op.py b/weave/trace/op.py index fd2ef50cf1dd..92f455810e23 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -155,6 +155,32 @@ def after_call(self, call: Call) -> None: ... def after_error(self, call: Call, exc: Exception) -> None: ... +class DebugCallback: + def before_call( + self, + inputs: dict, + parent: Call | None, + attributes: dict | None, + display_name: str | Callable[[Call], str], + ) -> None: + print(f"before_call: {inputs} {parent} {attributes} {display_name}") + + def before_yield(self, call: Call, value: Any) -> None: + print(f"before_yield: {call} {value}") + + def after_yield(self, call: Call, value: Any) -> None: + print(f"after_yield: {call} {value}") + + def after_yield_all(self, call: Call) -> None: + print(f"after_yield_all: {call}") + + def after_call(self, call: Call) -> None: + print(f"after_call: {call}") + + def after_error(self, call: Call, exc: Exception) -> None: + print(f"after_error: {call} {exc}") + + T = TypeVar("T") Acc = TypeVar("Acc") @@ -520,6 +546,90 @@ def _placeholder_call() -> Call: ) +async def _execute_op_async( + __op: Op, + *args: Any, + __weave: WeaveKwargs | None = None, + __should_raise: bool = True, + **kwargs: Any, +) -> tuple[Any, Call]: + func = __op.resolve_fn + is_async_generator = inspect.isasyncgenfunction(func) + call = _placeholder_call() + + # TODO: Add back on_input_handler (maybe part of callback?) + + # Early returns for disabled cases -- no call is created + if settings.should_disable_weave(): + return await func(*args, **kwargs), call + elif weave_client_context.get_weave_client() is None: + return await func(*args, **kwargs), call + elif not __op._tracing_enabled: + return await func(*args, **kwargs), call + + # Setup call context + client = weave_client_context.require_weave_client() + parent_call = call_context.get_current_call() + attributes = call_attributes.get() + + # Create the call + call_time_display_name = __weave.get("display_name") if __weave else None + inputs = inspect.signature(func).bind(*args, **kwargs).arguments + call = client.create_call( + __op, + inputs, + parent_call, + display_name=call_time_display_name or __op.call_display_name, + attributes=attributes, + ) + __op.lifecycle_handler.run_before_call({}, None, None, "") + if is_async_generator: + + async def _wrapped_async_generator(): + try: + async for val in func(*args, **kwargs): + __op.lifecycle_handler.run_before_yield(call, val) + yield val + __op.lifecycle_handler.run_after_yield(call, val) + except Exception as e: + exception = e + if __should_raise: + raise + else: + exception = None + __op.lifecycle_handler.run_after_yield_all(call) + finally: + if __op.lifecycle_handler.has_finished: + raise OpCallError("Should not call finish more than once") + boxed_output = box.box(call.output) + client.finish_call(call, boxed_output, exception=exception, op=__op) + if not call_context.get_current_call(): + print_call_link(call) + call_context.pop_call(call.id) + + res = _wrapped_async_generator() + else: + try: + res = await func(*args, **kwargs) + exception = None + except Exception as e: + res = None + exception = e + if __should_raise: + raise + finally: + if __op.lifecycle_handler.has_finished: + raise OpCallError("Should not call finish more than once") + boxed_output = box.box(res) + client.finish_call(call, boxed_output, exception=exception, op=__op) + if not call_context.get_current_call(): + print_call_link(call) + call_context.pop_call(call.id) + + __op.lifecycle_handler.run_after_call(call) + return res, call + + def _execute_op( __op: Op, *args: Any, @@ -819,18 +929,27 @@ def op_deco(func: Callable) -> Op: sig = inspect.signature(func) is_method = _is_unbound_method(func) is_async = inspect.iscoroutinefunction(func) - is_generator = inspect.isgeneratorfunction(func) is_async_generator = inspect.isasyncgenfunction(func) + # TODO: Maybe split out the 4 execute op methods (sync/async, gen/not-gen) def create_wrapper(func: Callable) -> Op: if is_async: @wraps(func) async def wrapper(*args: Any, **kwargs: Any) -> Any: # pyright: ignore[reportRedeclaration] - res, _ = await _do_call_async( + res, _ = await _execute_op_async( cast(Op, wrapper), *args, __should_raise=True, **kwargs ) return res + elif is_async_generator: + + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: # pyright: ignore[reportRedeclaration] + res, _ = await _execute_op_async( + cast(Op, wrapper), *args, __should_raise=True, **kwargs + ) + async for v in res: + yield v else: @wraps(func) From 89ae6044760d98047126969607996fc8ab8b1b92 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 20 Nov 2024 01:30:32 -0500 Subject: [PATCH 09/17] add basic call support --- weave/trace/op.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/weave/trace/op.py b/weave/trace/op.py index 92f455810e23..b65bbe885a46 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -523,12 +523,17 @@ def add(a: int, b: int) -> int: result, call = add.call(1, 2) ``` """ - if inspect.iscoroutinefunction(op.resolve_fn): - return _do_call_async( + func = op.resolve_fn + is_async = inspect.iscoroutinefunction(func) + is_async_generator = inspect.isasyncgenfunction(func) + + if is_async or is_async_generator: + # TODO: This might not be right for async generators + return _execute_op_async( op, *args, __weave=__weave, __should_raise=__should_raise, **kwargs ) else: - return _do_call( + return _execute_op( op, *args, __weave=__weave, __should_raise=__should_raise, **kwargs ) @@ -688,6 +693,9 @@ def _wrapped_sync_generator(): raise OpCallError("Should not call finish more than once") boxed_output = box.box(call.output) client.finish_call(call, boxed_output, exception=exception, op=__op) + # TODO: We choose to print the call link at the end, but for streaming cases it might be helpful to: + # 1. Print the link at the beginning of the stream + # 2. (maybe?): Update the call object periodically with the latest output if not call_context.get_current_call(): print_call_link(call) call_context.pop_call(call.id) From 7f7d149ba7a25e751fff47e76dc302fbb6d720df Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 20 Nov 2024 01:31:31 -0500 Subject: [PATCH 10/17] cleanup old methods --- weave/trace/op.py | 208 ---------------------------------------------- 1 file changed, 208 deletions(-) diff --git a/weave/trace/op.py b/weave/trace/op.py index b65bbe885a46..783fcad3825b 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -29,9 +29,7 @@ from weave.trace.context import call_context from weave.trace.context import weave_client_context as weave_client_context from weave.trace.context.call_context import call_attributes -from weave.trace.context.tests_context import get_raise_on_captured_errors from weave.trace.errors import OpCallError -from weave.trace.op_extensions.log_once import log_once from weave.trace.refs import ObjectRef logger = logging.getLogger(__name__) @@ -389,117 +387,6 @@ def default_on_input_handler(func: Op, args: tuple, kwargs: dict) -> ProcessedIn ) -def _create_call( - func: Op, *args: Any, __weave: WeaveKwargs | None = None, **kwargs: Any -) -> Call: - client = weave_client_context.require_weave_client() - - pargs = None - if func._on_input_handler is not None: - pargs = func._on_input_handler(func, args, kwargs) - if not pargs: - pargs = default_on_input_handler(func, args, kwargs) - inputs_with_defaults = pargs.inputs - - # This should probably be configurable, but for now we redact the api_key - if "api_key" in inputs_with_defaults: - inputs_with_defaults["api_key"] = "REDACTED" - - call_time_display_name = __weave.get("display_name") if __weave else None - - # If/When we do memoization, this would be a good spot - - parent_call = call_context.get_current_call() - attributes = call_attributes.get() - - return client.create_call( - func, - inputs_with_defaults, - parent_call, - # Very important for `call_time_display_name` to take precedence over `func.call_display_name` - display_name=call_time_display_name or func.call_display_name, - attributes=attributes, - ) - - -def _execute_call( - __op: Op, - call: Any, - *args: Any, - __should_raise: bool = True, - **kwargs: Any, -) -> tuple[Any, Call] | Coroutine[Any, Any, tuple[Any, Call]]: - func = __op.resolve_fn - client = weave_client_context.require_weave_client() - has_finished = False - - def finish(output: Any = None, exception: BaseException | None = None) -> None: - nonlocal has_finished - if has_finished: - raise ValueError("Should not call finish more than once") - - client.finish_call( - call, - output, - exception, - op=__op, - ) - if not call_context.get_current_call(): - print_call_link(call) - - def on_output(output: Any) -> Any: - if handler := getattr(__op, "_on_output_handler", None): - return handler(output, finish, call.inputs) - finish(output) - return output - - def process(res: Any) -> tuple[Any, Call]: - res = box.box(res) - try: - # Here we do a try/catch because we don't want to - # break the user process if we trip up on processing - # the output - res = on_output(res) - except Exception: - if get_raise_on_captured_errors(): - raise - log_once(logger.error, ON_OUTPUT_MSG.format(traceback.format_exc())) - finally: - # Is there a better place for this? We want to ensure that even - # if the final output fails to be captured, we still pop the call - # so we don't put future calls under the old call. - call_context.pop_call(call.id) - - return res, call - - def handle_exception(e: Exception) -> tuple[Any, Call]: - finish(exception=e) - if __should_raise: - raise - return None, call - - if inspect.iscoroutinefunction(func): - - async def _call_async() -> tuple[Any, Call]: - try: - res = await func(*args, **kwargs) - except Exception as e: - return handle_exception(e) - else: - return process(res) - - return _call_async() - - try: - res = func(*args, **kwargs) - except Exception as e: - handle_exception(e) - else: - return process(res) - - return None, call - - def call( op: Op, *args: Any, @@ -724,101 +611,6 @@ def _wrapped_sync_generator(): return res, call -def _do_call( - op: Op, - *args: Any, - __weave: WeaveKwargs | None = None, - __should_raise: bool = False, - **kwargs: Any, -) -> tuple[Any, Call]: - func = op.resolve_fn - call = _placeholder_call() - - pargs = None - if op._on_input_handler is not None: - pargs = op._on_input_handler(op, args, kwargs) - if not pargs: - pargs = default_on_input_handler(op, args, kwargs) - - if settings.should_disable_weave(): - res = func(*pargs.args, **pargs.kwargs) - elif weave_client_context.get_weave_client() is None: - res = func(*pargs.args, **pargs.kwargs) - elif not op._tracing_enabled: - res = func(*pargs.args, **pargs.kwargs) - else: - try: - # This try/except allows us to fail gracefully and - # still let the user code continue to execute - call = _create_call(op, *args, __weave=__weave, **kwargs) - except OpCallError as e: - raise e - except Exception as e: - if get_raise_on_captured_errors(): - raise - log_once( - logger.error, - CALL_CREATE_MSG.format(traceback.format_exc()), - ) - res = func(*pargs.args, **pargs.kwargs) - else: - execute_result = _execute_call( - op, call, *pargs.args, __should_raise=__should_raise, **pargs.kwargs - ) - if inspect.iscoroutine(execute_result): - raise TypeError( - "Internal error: Expected `_execute_call` to return a sync result" - ) - execute_result = cast(tuple[Any, "Call"], execute_result) - res, call = execute_result - return res, call - - -async def _do_call_async( - op: Op, - *args: Any, - __weave: WeaveKwargs | None = None, - __should_raise: bool = False, - **kwargs: Any, -) -> tuple[Any, Call]: - func = op.resolve_fn - call = _placeholder_call() - if settings.should_disable_weave(): - res = await func(*args, **kwargs) - elif weave_client_context.get_weave_client() is None: - res = await func(*args, **kwargs) - elif not op._tracing_enabled: - res = await func(*args, **kwargs) - else: - try: - # This try/except allows us to fail gracefully and - # still let the user code continue to execute - call = _create_call(op, *args, __weave=__weave, **kwargs) - except OpCallError as e: - raise e - except Exception as e: - if get_raise_on_captured_errors(): - raise - log_once( - logger.error, - ASYNC_CALL_CREATE_MSG.format(traceback.format_exc()), - ) - res = await func(*args, **kwargs) - else: - execute_result = _execute_call( - op, call, *args, __should_raise=__should_raise, **kwargs - ) - if not inspect.iscoroutine(execute_result): - raise TypeError( - "Internal error: Expected `_execute_call` to return a coroutine" - ) - execute_result = cast( - Coroutine[Any, Any, tuple[Any, "Call"]], execute_result - ) - res, call = await execute_result - return res, call - - def calls(op: Op) -> CallsIter: """ Get an iterator over all calls to this op. From d0a34d100f24bc8b9af7f7004b8c3cb2235ce994 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 20 Nov 2024 01:42:19 -0500 Subject: [PATCH 11/17] tidy --- weave/trace/op.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/weave/trace/op.py b/weave/trace/op.py index 783fcad3825b..dbf3fd4b9f4a 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -527,6 +527,7 @@ def _execute_op( *args: Any, __weave: WeaveKwargs | None = None, __should_raise: bool = True, + __should_accumulate: bool = True, **kwargs: Any, ) -> tuple[Any, Call]: func = __op.resolve_fn @@ -560,7 +561,7 @@ def _execute_op( ) __op.lifecycle_handler.run_before_call({}, None, None, "") - if is_generator: + if is_generator and __should_accumulate: def _wrapped_sync_generator(): try: From 092c3d26af26192e8e8ea375e2ccacc4da5da6c3 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 20 Nov 2024 03:39:20 -0500 Subject: [PATCH 12/17] add __should_accumulate escape hatch for integrations --- weave/trace/op.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/weave/trace/op.py b/weave/trace/op.py index dbf3fd4b9f4a..1219f57189e4 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -443,6 +443,7 @@ async def _execute_op_async( *args: Any, __weave: WeaveKwargs | None = None, __should_raise: bool = True, + __should_accumulate: bool = False, **kwargs: Any, ) -> tuple[Any, Call]: func = __op.resolve_fn @@ -475,8 +476,9 @@ async def _execute_op_async( attributes=attributes, ) __op.lifecycle_handler.run_before_call({}, None, None, "") - if is_async_generator: + if is_async_generator or __should_accumulate: + @wraps(func) async def _wrapped_async_generator(): try: async for val in func(*args, **kwargs): @@ -527,7 +529,7 @@ def _execute_op( *args: Any, __weave: WeaveKwargs | None = None, __should_raise: bool = True, - __should_accumulate: bool = True, + __should_accumulate: bool = False, **kwargs: Any, ) -> tuple[Any, Call]: func = __op.resolve_fn @@ -561,8 +563,9 @@ def _execute_op( ) __op.lifecycle_handler.run_before_call({}, None, None, "") - if is_generator and __should_accumulate: + if is_generator or __should_accumulate: + @wraps(func) def _wrapped_sync_generator(): try: for val in func(*args, **kwargs): @@ -648,6 +651,7 @@ def op( postprocess_output: PostprocessOutputFunc | None = None, callbacks: list[Callback] | None = None, reducers: list[Reducer] | None = None, + __should_accumulate: bool = False, ) -> Op: ... @@ -660,6 +664,7 @@ def op( postprocess_output: PostprocessOutputFunc | None = None, callbacks: list[Callback] | None = None, reducers: list[Reducer] | None = None, + __should_accumulate: bool = False, ) -> Callable[[Callable], Op]: ... @@ -672,6 +677,7 @@ def op( postprocess_output: PostprocessOutputFunc | None = None, callbacks: list[Callback] | None = None, reducers: list[Reducer] | None = None, + __should_accumulate: bool = False, ) -> Callable[[Callable], Op] | Op: """ A decorator to weave op-ify a function or method. Works for both sync and async. @@ -756,7 +762,11 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: # pyright: ignore[reportRe @wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: res, _ = _execute_op( - cast(Op, wrapper), *args, __should_raise=True, **kwargs + cast(Op, wrapper), + *args, + __should_raise=True, + __should_accumulate=__should_accumulate, + **kwargs, ) return res From 65a9d97d5706bb7e7bbe413a598fdb6deefa9bc6 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 20 Nov 2024 11:28:41 -0500 Subject: [PATCH 13/17] add should_accumulate escape hatch --- weave/trace/op.py | 42 ++++++++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/weave/trace/op.py b/weave/trace/op.py index 1219f57189e4..ca5290cd4e76 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -155,13 +155,9 @@ def after_error(self, call: Call, exc: Exception) -> None: ... class DebugCallback: def before_call( - self, - inputs: dict, - parent: Call | None, - attributes: dict | None, - display_name: str | Callable[[Call], str], + self, inputs: dict, parent: Call | None, attributes: dict | None ) -> None: - print(f"before_call: {inputs} {parent} {attributes} {display_name}") + print(f"before_call: {inputs} {parent} {attributes}") def before_yield(self, call: Call, value: Any) -> None: print(f"before_yield: {call} {value}") @@ -221,15 +217,11 @@ def add_callback(self, callback: Callback) -> None: self.callbacks.append(callback) def run_before_call( - self, - inputs: dict, - parent: Call | None, - attributes: dict | None, - display_name: str | Callable[[Call], str], + self, inputs: dict, parent: Call | None, attributes: dict | None ) -> None: for callback in self.callbacks: if hasattr(callback, "before_call"): - callback.before_call(inputs, parent, attributes, display_name) + callback.before_call(inputs, parent, attributes) def run_before_yield(self, call: Call, value: Any) -> Any: for callback in self.callbacks: @@ -468,6 +460,8 @@ async def _execute_op_async( # Create the call call_time_display_name = __weave.get("display_name") if __weave else None inputs = inspect.signature(func).bind(*args, **kwargs).arguments + __op.lifecycle_handler.run_before_call(inputs, parent_call, attributes) + call = client.create_call( __op, inputs, @@ -475,8 +469,9 @@ async def _execute_op_async( display_name=call_time_display_name or __op.call_display_name, attributes=attributes, ) - __op.lifecycle_handler.run_before_call({}, None, None, "") - if is_async_generator or __should_accumulate: + + should_accumulate = __should_accumulate and __should_accumulate(call) + if is_async_generator or should_accumulate: @wraps(func) async def _wrapped_async_generator(): @@ -524,6 +519,13 @@ async def _wrapped_async_generator(): return res, call +def _is_context_manager(obj: Any) -> bool: + __enter__ = getattr(obj, "__enter__", None) + __exit__ = getattr(obj, "__exit__", None) + + return callable(__enter__) and callable(__exit__) + + def _execute_op( __op: Op, *args: Any, @@ -563,7 +565,8 @@ def _execute_op( ) __op.lifecycle_handler.run_before_call({}, None, None, "") - if is_generator or __should_accumulate: + should_accumulate = __should_accumulate and __should_accumulate(call) + if is_generator or should_accumulate: @wraps(func) def _wrapped_sync_generator(): @@ -651,7 +654,8 @@ def op( postprocess_output: PostprocessOutputFunc | None = None, callbacks: list[Callback] | None = None, reducers: list[Reducer] | None = None, - __should_accumulate: bool = False, + __should_accumulate: Callable[[Call], bool] + | None = None, # escape hatch for integrations ) -> Op: ... @@ -664,7 +668,8 @@ def op( postprocess_output: PostprocessOutputFunc | None = None, callbacks: list[Callback] | None = None, reducers: list[Reducer] | None = None, - __should_accumulate: bool = False, + __should_accumulate: Callable[[Call], bool] + | None = None, # escape hatch for integrations ) -> Callable[[Callable], Op]: ... @@ -677,7 +682,8 @@ def op( postprocess_output: PostprocessOutputFunc | None = None, callbacks: list[Callback] | None = None, reducers: list[Reducer] | None = None, - __should_accumulate: bool = False, + __should_accumulate: Callable[[Call], bool] + | None = None, # escape hatch for integrations ) -> Callable[[Callable], Op] | Op: """ A decorator to weave op-ify a function or method. Works for both sync and async. From 77f6c2937878bb8b1ffbdcb205a1d315b65ef393 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 20 Nov 2024 11:28:57 -0500 Subject: [PATCH 14/17] add should_accumulate escape hatch --- weave/trace/op.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/weave/trace/op.py b/weave/trace/op.py index ca5290cd4e76..7a6b4677e1de 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -140,11 +140,7 @@ class WeaveKwargs(TypedDict): class Callback(Protocol): def before_call( - self, - inputs: dict, - parent: Call | None, - attributes: dict | None, - display_name: str | Callable[[Call], str], + self, inputs: dict, parent: Call | None, attributes: dict | None ) -> None: ... def before_yield(self, call: Call, value: Any) -> None: ... def after_yield(self, call: Call, value: Any) -> None: ... From cf5f76b56cffd794701e1ad0be0255380596242f Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 20 Nov 2024 11:57:50 -0500 Subject: [PATCH 15/17] 1 more escape hatch --- weave/trace/op.py | 93 ++++++++++++++++++++++++++++++++++------------- 1 file changed, 68 insertions(+), 25 deletions(-) diff --git a/weave/trace/op.py b/weave/trace/op.py index 7a6b4677e1de..4678a863ea92 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -7,6 +7,7 @@ import sys import traceback from collections.abc import Coroutine, Mapping +from contextlib import contextmanager from dataclasses import dataclass from functools import partial, wraps from types import MethodType @@ -519,6 +520,8 @@ def _is_context_manager(obj: Any) -> bool: __enter__ = getattr(obj, "__enter__", None) __exit__ = getattr(obj, "__exit__", None) + print(__enter__, __exit__) + return callable(__enter__) and callable(__exit__) @@ -528,6 +531,7 @@ def _execute_op( __weave: WeaveKwargs | None = None, __should_raise: bool = True, __should_accumulate: bool = False, + __should_use_contextmanager: bool = False, **kwargs: Any, ) -> tuple[Any, Call]: func = __op.resolve_fn @@ -559,38 +563,79 @@ def _execute_op( display_name=call_time_display_name or __op.call_display_name, attributes=attributes, ) - __op.lifecycle_handler.run_before_call({}, None, None, "") + __op.lifecycle_handler.run_before_call(inputs, parent_call, attributes) should_accumulate = __should_accumulate and __should_accumulate(call) + should_use_contextmanager = ( + __should_use_contextmanager and __should_use_contextmanager(func) + ) + if is_generator or should_accumulate: + # TODO: Not sure if this is the right pattern. This hack is only + # used to support the Anthropic streaming context streaming case atm... + if should_use_contextmanager: - @wraps(func) - def _wrapped_sync_generator(): - try: - for val in func(*args, **kwargs): - __op.lifecycle_handler.run_before_yield(call, val) - yield val - __op.lifecycle_handler.run_after_yield(call, val) - except Exception as e: - exception = e - if __should_raise: - raise - else: + @contextmanager + @wraps(func) + def _wrapped_context_manager_yields_sync_generator(): exception = None - __op.lifecycle_handler.run_after_yield_all(call) - finally: + with func(*args, **kwargs) as original_gen: + + def _wrapped_sync_generator(): + nonlocal exception + try: + for val in original_gen: + __op.lifecycle_handler.run_before_yield(call, val) + yield val + __op.lifecycle_handler.run_after_yield(call, val) + except Exception as e: + exception = e + if __should_raise: + raise + else: + __op.lifecycle_handler.run_after_yield_all(call) + + yield _wrapped_sync_generator() + if __op.lifecycle_handler.has_finished: raise OpCallError("Should not call finish more than once") boxed_output = box.box(call.output) client.finish_call(call, boxed_output, exception=exception, op=__op) - # TODO: We choose to print the call link at the end, but for streaming cases it might be helpful to: - # 1. Print the link at the beginning of the stream - # 2. (maybe?): Update the call object periodically with the latest output if not call_context.get_current_call(): print_call_link(call) call_context.pop_call(call.id) - res = _wrapped_sync_generator() + res = _wrapped_context_manager_yields_sync_generator() + + else: + + @wraps(func) + def _wrapped_sync_generator(): + try: + for val in func(*args, **kwargs): + __op.lifecycle_handler.run_before_yield(call, val) + yield val + __op.lifecycle_handler.run_after_yield(call, val) + except Exception as e: + exception = e + if __should_raise: + raise + else: + exception = None + __op.lifecycle_handler.run_after_yield_all(call) + finally: + if __op.lifecycle_handler.has_finished: + raise OpCallError("Should not call finish more than once") + boxed_output = box.box(call.output) + client.finish_call(call, boxed_output, exception=exception, op=__op) + # TODO: We choose to print the call link at the end, but for streaming cases it might be helpful to: + # 1. Print the link at the beginning of the stream + # 2. (maybe?): Update the call object periodically with the latest output + if not call_context.get_current_call(): + print_call_link(call) + call_context.pop_call(call.id) + + res = _wrapped_sync_generator() else: # Regular sync function try: @@ -650,8 +695,6 @@ def op( postprocess_output: PostprocessOutputFunc | None = None, callbacks: list[Callback] | None = None, reducers: list[Reducer] | None = None, - __should_accumulate: Callable[[Call], bool] - | None = None, # escape hatch for integrations ) -> Op: ... @@ -664,8 +707,6 @@ def op( postprocess_output: PostprocessOutputFunc | None = None, callbacks: list[Callback] | None = None, reducers: list[Reducer] | None = None, - __should_accumulate: Callable[[Call], bool] - | None = None, # escape hatch for integrations ) -> Callable[[Callable], Op]: ... @@ -678,8 +719,9 @@ def op( postprocess_output: PostprocessOutputFunc | None = None, callbacks: list[Callback] | None = None, reducers: list[Reducer] | None = None, - __should_accumulate: Callable[[Call], bool] - | None = None, # escape hatch for integrations + # escape hatch for integrations + __should_accumulate: Callable[[Call], bool] | None = None, + __should_use_contextmanager: Callable[[Callable], bool] | None = None, ) -> Callable[[Callable], Op] | Op: """ A decorator to weave op-ify a function or method. Works for both sync and async. @@ -768,6 +810,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: *args, __should_raise=True, __should_accumulate=__should_accumulate, + __should_use_contextmanager=__should_use_contextmanager, **kwargs, ) return res From 89ed434deb77cc64d6e1163ca082b8edd19df30e Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 20 Nov 2024 22:00:29 -0500 Subject: [PATCH 16/17] test --- weave/trace/op.py | 458 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 377 insertions(+), 81 deletions(-) diff --git a/weave/trace/op.py b/weave/trace/op.py index 4678a863ea92..04152ca6d62d 100644 --- a/weave/trace/op.py +++ b/weave/trace/op.py @@ -7,7 +7,6 @@ import sys import traceback from collections.abc import Coroutine, Mapping -from contextlib import contextmanager from dataclasses import dataclass from functools import partial, wraps from types import MethodType @@ -139,17 +138,6 @@ class WeaveKwargs(TypedDict): display_name: str | None -class Callback(Protocol): - def before_call( - self, inputs: dict, parent: Call | None, attributes: dict | None - ) -> None: ... - def before_yield(self, call: Call, value: Any) -> None: ... - def after_yield(self, call: Call, value: Any) -> None: ... - def after_yield_all(self, call: Call) -> None: ... - def after_call(self, call: Call) -> None: ... - def after_error(self, call: Call, exc: Exception) -> None: ... - - class DebugCallback: def before_call( self, inputs: dict, parent: Call | None, attributes: dict | None @@ -176,6 +164,17 @@ def after_error(self, call: Call, exc: Exception) -> None: Acc = TypeVar("Acc") +class Callback(Protocol): + def before_call( + self, inputs: dict, parent: Call | None, attributes: dict | None + ) -> None: ... + def before_yield(self, call: Call, value: Any) -> None: ... + def after_yield(self, call: Call, value: Any) -> None: ... + def after_yield_all(self, call: Call) -> None: ... + def after_call(self, call: Call) -> None: ... + def after_error(self, call: Call, exc: Exception) -> None: ... + + class Reducer(Protocol, Generic[T, Acc]): """Any function that implements this can be automatically converted into a reducer callback. @@ -197,10 +196,10 @@ def __init__(self, reducer: Reducer[T, Acc]): self.acc = acc.default - def after_yield(self, call, val): + def after_yield(self, call: Call, val: T) -> None: self.acc = self.func(val, self.acc) - def after_yield_all(self, call): + def after_yield_all(self, call: Call) -> None: call.output = self.acc @@ -432,7 +431,8 @@ async def _execute_op_async( *args: Any, __weave: WeaveKwargs | None = None, __should_raise: bool = True, - __should_accumulate: bool = False, + __should_accumulate: Callable[[Call], bool] | None = None, + __should_use_contextmanager: Callable[[Callable], bool] | None = None, **kwargs: Any, ) -> tuple[Any, Call]: func = __op.resolve_fn @@ -468,32 +468,125 @@ async def _execute_op_async( ) should_accumulate = __should_accumulate and __should_accumulate(call) + should_use_contextmanager = ( + __should_use_contextmanager and __should_use_contextmanager(func) + ) if is_async_generator or should_accumulate: + if should_use_contextmanager: - @wraps(func) - async def _wrapped_async_generator(): - try: - async for val in func(*args, **kwargs): - __op.lifecycle_handler.run_before_yield(call, val) - yield val - __op.lifecycle_handler.run_after_yield(call, val) - except Exception as e: - exception = e - if __should_raise: - raise - else: - exception = None - __op.lifecycle_handler.run_after_yield_all(call) - finally: - if __op.lifecycle_handler.has_finished: - raise OpCallError("Should not call finish more than once") - boxed_output = box.box(call.output) - client.finish_call(call, boxed_output, exception=exception, op=__op) - if not call_context.get_current_call(): - print_call_link(call) - call_context.pop_call(call.id) - - res = _wrapped_async_generator() + class AsyncGeneratorContextManager: + def __init__(self, func, args, kwargs, op, call, should_raise): + self.func = func + self.args = args + self.kwargs = kwargs + self.op = op + self.call = call + self.should_raise = should_raise + self.exception = None + + async def __aenter__(self): + self.orig_gen = await func(*self.args, **self.kwargs).__aenter__() + return self._wrapped_async_generator() + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self.op.lifecycle_handler.has_finished: + raise OpCallError("Should not call finish more than once") + + boxed_output = box.box(self.call.output) + client = weave_client_context.require_weave_client() + client.finish_call( + self.call, boxed_output, exception=self.exception, op=self.op + ) + + if not call_context.get_current_call(): + print_call_link(self.call) + call_context.pop_call(self.call.id) + + return self.orig_gen.__aexit__(exc_type, exc_val, exc_tb) + + async def _wrapped_async_generator(self): + try: + async for val in self.orig_gen: + __op.lifecycle_handler.run_before_yield(call, val) + yield val + __op.lifecycle_handler.run_after_yield(call, val) + except Exception as e: + self.exception = e + if self.should_raise: + raise + else: + __op.lifecycle_handler.run_after_yield_all(call) + + def __getattr__(self, name): + # Forward unknown attrs to the original generator + if name not in { + "__aenter__", + "__aexit__", + "_wrapped_async_generator", + }: + return getattr(self.orig_gen, name) + + res = AsyncGeneratorContextManager( + func, args, kwargs, __op, call, __should_raise + ) + + # @asynccontextmanager + # @wraps(func) + # async def _wrapped_context_manager_yields_async_generator(): + # exception = None + # original_gen = await func(*args, **kwargs) + + # async def _wrapped_async_generator(): + # nonlocal exception + # try: + # async for val in original_gen: + # __op.lifecycle_handler.run_before_yield(call, val) + # yield val + # __op.lifecycle_handler.run_after_yield(call, val) + # except Exception as e: + # exception = e + # if __should_raise: + # raise + # else: + # __op.lifecycle_handler.run_after_yield_all(call) + + # yield _wrapped_async_generator() + + # if __op.lifecycle_handler.has_finished: + # raise OpCallError("Should not call finish more than once") + # boxed_output = box.box(call.output) + # client.finish_call(call, boxed_output, exception=exception, op=__op) + # if not call_context.get_current_call(): + # print_call_link(call) + # call_context.pop_call(call.id) + + # res = _wrapped_context_manager_yields_async_generator() + else: + + @wraps(func) + async def _wrapped_async_generator(): + try: + async for val in await func(*args, **kwargs): + __op.lifecycle_handler.run_before_yield(call, val) + yield val + __op.lifecycle_handler.run_after_yield(call, val) + except Exception as e: + exception = e + if __should_raise: + raise + else: + exception = None + __op.lifecycle_handler.run_after_yield_all(call) + finally: + if __op.lifecycle_handler.has_finished: + raise OpCallError("Should not call finish more than once") + boxed_output = box.box(call.output) + client.finish_call(call, boxed_output, exception=exception, op=__op) + if not call_context.get_current_call(): + print_call_link(call) + call_context.pop_call(call.id) + + res = _wrapped_async_generator() else: try: res = await func(*args, **kwargs) @@ -516,15 +609,6 @@ async def _wrapped_async_generator(): return res, call -def _is_context_manager(obj: Any) -> bool: - __enter__ = getattr(obj, "__enter__", None) - __exit__ = getattr(obj, "__exit__", None) - - print(__enter__, __exit__) - - return callable(__enter__) and callable(__exit__) - - def _execute_op( __op: Op, *args: Any, @@ -574,38 +658,140 @@ def _execute_op( # TODO: Not sure if this is the right pattern. This hack is only # used to support the Anthropic streaming context streaming case atm... if should_use_contextmanager: + # class SyncGeneratorContextManager: + # def __init__(self, func, args, kwargs, op, call, should_raise): + # self.func = func + # self.args = args + # self.kwargs = kwargs + # self.op = op + # self.call = call + # self.should_raise = should_raise + # self.exception = None + + # async def __aenter__(self): + # print("Entering aenter") + # self.orig_context = func(*self.args, **self.kwargs) + # self.orig_gen = await self.orig_context.__aenter__() + # self.gen = self._wrapped_async_generator() + # return self.gen + + # async def __aexit__(self, exc_type, exc_val, exc_tb): + # print("Entering aexit") + # if self.op.lifecycle_handler.has_finished: + # raise OpCallError("Should not call finish more than once") + + # boxed_output = box.box(self.call.output) + # client = weave_client_context.require_weave_client() + # client.finish_call( + # self.call, boxed_output, exception=self.exception, op=self.op + # ) + + # if not call_context.get_current_call(): + # print_call_link(self.call) + # call_context.pop_call(self.call.id) + + # return await self.orig_context.__aexit__(exc_type, exc_val, exc_tb) + + # async def _wrapped_async_generator(self): + # print("Entering wrapped async generator") + # try: + # async for val in self.orig_gen: + # self.op.lifecycle_handler.run_before_yield(call, val) + # yield val + # self.op.lifecycle_handler.run_after_yield(call, val) + # except Exception as e: + # self.exception = e + # if self.should_raise: + # raise + # else: + # self.op.lifecycle_handler.run_after_yield_all(call) + + # def __enter__(self): + # print("Entering enter") + # self.orig_gen = func(*self.args, **self.kwargs).__enter__() + # return self._wrapped_sync_generator() + + # def __exit__(self, exc_type, exc_val, exc_tb): + # print("Entering exit") + # if self.op.lifecycle_handler.has_finished: + # raise OpCallError("Should not call finish more than once") + + # boxed_output = box.box(self.call.output) + # client = weave_client_context.require_weave_client() + # client.finish_call( + # self.call, boxed_output, exception=self.exception, op=self.op + # ) + + # if not call_context.get_current_call(): + # print_call_link(self.call) + # call_context.pop_call(self.call.id) + + # return self.orig_gen.__exit__(exc_type, exc_val, exc_tb) + + # def _wrapped_sync_generator(self): + # print("Entering wrapped sync generator") + # try: + # for val in self.orig_gen: + # self.op.lifecycle_handler.run_before_yield(self.call, val) + # yield val + # self.op.lifecycle_handler.run_after_yield(self.call, val) + # except Exception as e: + # self.exception = e + # if self.should_raise: + # raise + # else: + # self.op.lifecycle_handler.run_after_yield_all(self.call) + + # def __getattr__(self, name): + # print(f"Getting {name}") + # # Forward unknown attrs to the original generator + # if name not in { + # "__enter__", + # "__aenter__", + # "__exit__", + # "__aexit__", + # "_wrapped_sync_generator", + # "_wrapped_async_generator", + # }: + # print(f"Forwarding {name} to original context") + # return getattr(self.orig_context, name) + # return getattr(self, name) + + res = SyncGeneratorContextManager( + func, args, kwargs, __op, call, __should_raise + ) - @contextmanager - @wraps(func) - def _wrapped_context_manager_yields_sync_generator(): - exception = None - with func(*args, **kwargs) as original_gen: - - def _wrapped_sync_generator(): - nonlocal exception - try: - for val in original_gen: - __op.lifecycle_handler.run_before_yield(call, val) - yield val - __op.lifecycle_handler.run_after_yield(call, val) - except Exception as e: - exception = e - if __should_raise: - raise - else: - __op.lifecycle_handler.run_after_yield_all(call) - - yield _wrapped_sync_generator() - - if __op.lifecycle_handler.has_finished: - raise OpCallError("Should not call finish more than once") - boxed_output = box.box(call.output) - client.finish_call(call, boxed_output, exception=exception, op=__op) - if not call_context.get_current_call(): - print_call_link(call) - call_context.pop_call(call.id) - - res = _wrapped_context_manager_yields_sync_generator() + # @contextmanager + # @wraps(func) + # def _wrapped_context_manager_yields_sync_generator(): + # exception = None + # with func(*args, **kwargs) as original_gen: + + # def _wrapped_sync_generator(): + # nonlocal exception + # try: + # for val in original_gen: + # __op.lifecycle_handler.run_before_yield(call, val) + # yield val + # __op.lifecycle_handler.run_after_yield(call, val) + # except Exception as e: + # exception = e + # if __should_raise: + # raise + # else: + # __op.lifecycle_handler.run_after_yield_all(call) + + # yield _wrapped_sync_generator() + + # if __op.lifecycle_handler.has_finished: + # raise OpCallError("Should not call finish more than once") + # boxed_output = box.box(call.output) + # client.finish_call(call, boxed_output, exception=exception, op=__op) + # if not call_context.get_current_call(): + # print_call_link(call) + # call_context.pop_call(call.id) + + # res = _wrapped_context_manager_yields_sync_generator() else: @@ -789,7 +975,12 @@ def create_wrapper(func: Callable) -> Op: @wraps(func) async def wrapper(*args: Any, **kwargs: Any) -> Any: # pyright: ignore[reportRedeclaration] res, _ = await _execute_op_async( - cast(Op, wrapper), *args, __should_raise=True, **kwargs + cast(Op, wrapper), + *args, + __should_raise=True, + __should_accumulate=__should_accumulate, + __should_use_contextmanager=__should_use_contextmanager, + **kwargs, ) return res elif is_async_generator: @@ -797,7 +988,12 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: # pyright: ignore[reportRe @wraps(func) async def wrapper(*args: Any, **kwargs: Any) -> Any: # pyright: ignore[reportRedeclaration] res, _ = await _execute_op_async( - cast(Op, wrapper), *args, __should_raise=True, **kwargs + cast(Op, wrapper), + *args, + __should_raise=True, + __should_accumulate=__should_accumulate, + __should_use_contextmanager=__should_use_contextmanager, + **kwargs, ) async for v in res: yield v @@ -928,4 +1124,104 @@ def as_op(fn: Callable) -> Op: return cast(Op, fn) +class SyncGeneratorContextManager: + def __init__(self, func, args, kwargs, op, call, should_raise): + self.func = func + self.args = args + self.kwargs = kwargs + self.op = op + self.call = call + self.should_raise = should_raise + self.exception = None + + async def __aenter__(self): + print("Entering aenter") + self.orig_context = self.func(*self.args, **self.kwargs) + self.orig_gen = await self.orig_context.__aenter__() + self.gen = self._wrapped_async_generator() + return self.gen + + async def __aexit__(self, exc_type, exc_val, exc_tb): + print("Entering aexit") + if self.op.lifecycle_handler.has_finished: + raise OpCallError("Should not call finish more than once") + + boxed_output = box.box(self.call.output) + client = weave_client_context.require_weave_client() + client.finish_call( + self.call, boxed_output, exception=self.exception, op=self.op + ) + + if not call_context.get_current_call(): + print_call_link(self.call) + call_context.pop_call(self.call.id) + + return await self.orig_context.__aexit__(exc_type, exc_val, exc_tb) + + async def _wrapped_async_generator(self): + print("Entering wrapped async generator") + try: + async for val in self.orig_gen: + self.op.lifecycle_handler.run_before_yield(call, val) + yield val + self.op.lifecycle_handler.run_after_yield(call, val) + except Exception as e: + self.exception = e + if self.should_raise: + raise + else: + self.op.lifecycle_handler.run_after_yield_all(call) + + def __enter__(self): + print("Entering enter") + self.orig_gen = self.func(*self.args, **self.kwargs).__enter__() + return self._wrapped_sync_generator() + + def __exit__(self, exc_type, exc_val, exc_tb): + print("Entering exit") + if self.op.lifecycle_handler.has_finished: + raise OpCallError("Should not call finish more than once") + + boxed_output = box.box(self.call.output) + client = weave_client_context.require_weave_client() + client.finish_call( + self.call, boxed_output, exception=self.exception, op=self.op + ) + + if not call_context.get_current_call(): + print_call_link(self.call) + call_context.pop_call(self.call.id) + + return self.orig_gen.__exit__(exc_type, exc_val, exc_tb) + + def _wrapped_sync_generator(self): + print("Entering wrapped sync generator") + try: + for val in self.orig_gen: + self.op.lifecycle_handler.run_before_yield(self.call, val) + yield val + self.op.lifecycle_handler.run_after_yield(self.call, val) + except Exception as e: + self.exception = e + if self.should_raise: + raise + else: + self.op.lifecycle_handler.run_after_yield_all(self.call) + + def __getattr__(self, name): + print(f"Getting {name}") + # Forward unknown attrs to the original generator + if name not in { + "__enter__", + "__aenter__", + "__exit__", + "__aexit__", + "_wrapped_sync_generator", + "_wrapped_async_generator", + }: + print(f"Forwarding {name} to original context") + return getattr(self.orig_context, name) + return getattr(self, name) + + __docspec__ = [call, calls] From ba5ff393352663ee0b2c1d116bf3833cf46910bd Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Thu, 21 Nov 2024 10:49:03 -0500 Subject: [PATCH 17/17] wip --- .../integrations/anthropic/anthropic_test.py | 4 + weave/integrations/anthropic/anthropic_sdk.py | 240 ++++++++---------- 2 files changed, 103 insertions(+), 141 deletions(-) diff --git a/tests/integrations/anthropic/anthropic_test.py b/tests/integrations/anthropic/anthropic_test.py index bfba91e4411e..935dbee84c82 100644 --- a/tests/integrations/anthropic/anthropic_test.py +++ b/tests/integrations/anthropic/anthropic_test.py @@ -174,6 +174,10 @@ async def test_async_anthropic_stream( assert call.exception is None and call.ended_at is not None output = call.output + + print(f"{output=}") + print(f"{message=}") + assert output.id == message.id assert output.model == message.model assert output.stop_reason == "end_turn" diff --git a/weave/integrations/anthropic/anthropic_sdk.py b/weave/integrations/anthropic/anthropic_sdk.py index 6e33e1a39062..a052a669ffd8 100644 --- a/weave/integrations/anthropic/anthropic_sdk.py +++ b/weave/integrations/anthropic/anthropic_sdk.py @@ -1,89 +1,97 @@ +from __future__ import annotations + import importlib -from collections.abc import AsyncIterator, Iterator from functools import wraps -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Optional, - Union, -) +from typing import Any, Callable + +from anthropic import MessageStopEvent import weave -from weave.trace.op_extensions.accumulator import _IteratorWrapper, add_accumulator +from weave.trace.op import Callback from weave.trace.patcher import MultiPatcher, SymbolPatcher +from weave.trace.weave_client import Call + + +def should_accumulate(call: Call) -> bool: + return bool(call.inputs.get("stream")) + + +class AnthropicCallback(Callback): + def __init__(self): + self.acc = None + + def after_yield(self, call: Call, value: Any) -> None: + from anthropic.types import ( + ContentBlockDeltaEvent, + Message, + MessageDeltaEvent, + TextBlock, + Usage, + ) + + print(f"{value=}, {self.acc=}") -if TYPE_CHECKING: - from anthropic.lib.streaming import MessageStream - from anthropic.types import Message, MessageStreamEvent - - -def anthropic_accumulator( - acc: Optional["Message"], - value: "MessageStreamEvent", -) -> "Message": - from anthropic.types import ( - ContentBlockDeltaEvent, - Message, - MessageDeltaEvent, - TextBlock, - Usage, - ) - - if acc is None: - if hasattr(value, "message"): - acc = Message( + if self.acc is None: + if not hasattr(value, "message"): + raise ValueError("Initial event must contain a message") + self.acc = Message( id=value.message.id, role=value.message.role, content=[], model=value.message.model, stop_reason=value.message.stop_reason, stop_sequence=value.message.stop_sequence, - type=value.message.type, # Include the type field + type=value.message.type, usage=Usage(input_tokens=0, output_tokens=0), ) - else: - raise ValueError("Initial event must contain a message") - # Merge in the usage info if available - if hasattr(value, "message") and value.message.usage is not None: - acc.usage.input_tokens += value.message.usage.input_tokens + # Merge in the usage info if available + if hasattr(value, "message") and value.message.usage is not None: + self.acc.usage.input_tokens += value.message.usage.input_tokens - # Accumulate the content if it's a ContentBlockDeltaEvent - if isinstance(value, ContentBlockDeltaEvent) and hasattr(value.delta, "text"): - if acc.content and isinstance(acc.content[-1], TextBlock): - acc.content[-1].text += value.delta.text - else: - acc.content.append(TextBlock(type="text", text=value.delta.text)) + # Accumulate the content if it's a ContentBlockDeltaEvent + if isinstance(value, ContentBlockDeltaEvent) and hasattr(value.delta, "text"): + if self.acc.content and isinstance(self.acc.content[-1], TextBlock): + self.acc.content[-1].text += value.delta.text + else: + self.acc.content.append(TextBlock(type="text", text=value.delta.text)) - # Handle MessageDeltaEvent for stop_reason and stop_sequence - if isinstance(value, MessageDeltaEvent): - if hasattr(value.delta, "stop_reason") and value.delta.stop_reason: - acc.stop_reason = value.delta.stop_reason - if hasattr(value.delta, "stop_sequence") and value.delta.stop_sequence: - acc.stop_sequence = value.delta.stop_sequence - if hasattr(value, "usage") and value.usage.output_tokens: - acc.usage.output_tokens = value.usage.output_tokens + # Handle MessageDeltaEvent for stop_reason and stop_sequence + if isinstance(value, MessageDeltaEvent): + if hasattr(value.delta, "stop_reason") and value.delta.stop_reason: + self.acc.stop_reason = value.delta.stop_reason + if hasattr(value.delta, "stop_sequence") and value.delta.stop_sequence: + self.acc.stop_sequence = value.delta.stop_sequence + if hasattr(value, "usage") and value.usage.output_tokens: + self.acc.usage.output_tokens = value.usage.output_tokens - return acc + def after_yield_all(self, call: Call) -> None: + call.output = self.acc -# Unlike other integrations, streaming is based on input flag -def should_use_accumulator(inputs: dict) -> bool: - return isinstance(inputs, dict) and bool(inputs.get("stream")) +class AnthropicStreamingCallback: + def __init__(self): + self.acc = None + def after_yield(self, call: Call, value: Any) -> None: + print(f"{value=}, {self.acc=}") -def create_wrapper_sync( - name: str, -) -> Callable[[Callable], Callable]: + if self.acc is None: + self.acc = "" + if isinstance(value, MessageStopEvent): + self.acc = value.message + + def after_yield_all(self, call: Call) -> None: + call.output = self.acc + + +def create_wrapper_sync(name: str) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: - "We need to do this so we can check if `stream` is used" - op = weave.op()(fn) - op.name = name # type: ignore - return add_accumulator( - op, # type: ignore - make_accumulator=lambda inputs: anthropic_accumulator, - should_accumulate=should_use_accumulator, + return weave.op( + fn, + name=name, + callbacks=[AnthropicCallback()], + __should_accumulate=should_accumulate, ) return wrapper @@ -92,9 +100,7 @@ def wrapper(fn: Callable) -> Callable: # Surprisingly, the async `client.chat.completions.create` does not pass # `inspect.iscoroutinefunction`, so we can't dispatch on it and must write # it manually here... -def create_wrapper_async( - name: str, -) -> Callable[[Callable], Callable]: +def create_wrapper_async(name: str) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: def _fn_wrapper(fn: Callable) -> Callable: @wraps(fn) @@ -104,84 +110,37 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any: return _async_wrapper "We need to do this so we can check if `stream` is used" - op = weave.op()(_fn_wrapper(fn)) - op.name = name # type: ignore - return add_accumulator( - op, # type: ignore - make_accumulator=lambda inputs: anthropic_accumulator, - should_accumulate=should_use_accumulator, + return weave.op( + _fn_wrapper(fn), + name=name, + callbacks=[AnthropicCallback()], + __should_accumulate=should_accumulate, + ) + + return wrapper + + +def create_wrapper_stream(name: str) -> Callable[[Callable], Callable]: + def wrapper(fn: Callable) -> Callable: + return weave.op( + fn, + name=name, + callbacks=[AnthropicStreamingCallback()], + __should_accumulate=lambda call: True, + __should_use_contextmanager=lambda f: True, ) return wrapper -## This part of the code is for dealing with the other way -## of streaming, by calling Messages.stream -## this has 2 options: event based or text based. -## This code handles both cases by patching the _IteratorWrapper -## and adding a text_stream property to it. - - -def anthropic_stream_accumulator( - acc: Optional["Message"], - value: "MessageStream", -) -> "Message": - from anthropic.lib.streaming._types import MessageStopEvent - - if acc is None: - acc = "" - if isinstance(value, MessageStopEvent): - acc = value.message - return acc - - -class AnthropicIteratorWrapper(_IteratorWrapper): - def __getattr__(self, name: str) -> Any: - """Delegate all other attributes to the wrapped iterator.""" - if name in [ - "_iterator_or_ctx_manager", - "_on_yield", - "_on_error", - "_on_close", - "_on_finished_called", - "_call_on_error_once", - "text_stream", - ]: - return object.__getattribute__(self, name) - return getattr(self._iterator_or_ctx_manager, name) - - def __stream_text__(self) -> Union[Iterator[str], AsyncIterator[str]]: - if isinstance(self._iterator_or_ctx_manager, AsyncIterator): - return self.__async_stream_text__() - else: - return self.__sync_stream_text__() - - def __sync_stream_text__(self) -> Iterator[str]: # type: ignore - for chunk in self: # type: ignore - if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta": # type: ignore - yield chunk.delta.text # type: ignore - - async def __async_stream_text__(self) -> AsyncIterator[str]: # type: ignore - async for chunk in self: # type: ignore - if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta": # type: ignore - yield chunk.delta.text # type: ignore - - @property - def text_stream(self) -> Union[Iterator[str], AsyncIterator[str]]: - return self.__stream_text__() - - -def create_stream_wrapper( - name: str, -) -> Callable[[Callable], Callable]: +def create_wrapper_async_stream(name: str) -> Callable[[Callable], Callable]: def wrapper(fn: Callable) -> Callable: - op = weave.op()(fn) - op.name = name # type: ignore - return add_accumulator( - op, # type: ignore - make_accumulator=lambda _: anthropic_stream_accumulator, - should_accumulate=lambda _: True, - iterator_wrapper=AnthropicIteratorWrapper, # type: ignore + return weave.op( + fn, + name=name, + callbacks=[AnthropicStreamingCallback()], + __should_accumulate=lambda call: True, + __should_use_contextmanager=lambda f: True, ) return wrapper @@ -189,7 +148,6 @@ def wrapper(fn: Callable) -> Callable: anthropic_patcher = MultiPatcher( [ - # Patch the sync messages.create method for all messages.create methods SymbolPatcher( lambda: importlib.import_module("anthropic.resources.messages"), "Messages.create", @@ -203,12 +161,12 @@ def wrapper(fn: Callable) -> Callable: SymbolPatcher( lambda: importlib.import_module("anthropic.resources.messages"), "Messages.stream", - create_stream_wrapper(name="anthropic.Messages.stream"), + create_wrapper_stream(name="anthropic.Messages.stream"), ), SymbolPatcher( lambda: importlib.import_module("anthropic.resources.messages"), "AsyncMessages.stream", - create_stream_wrapper(name="anthropic.AsyncMessages.stream"), + create_wrapper_async_stream(name="anthropic.AsyncMessages.stream"), ), ] )