Skip to content

Commit

Permalink
1 more escape hatch
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewtruong committed Nov 20, 2024
1 parent 346ad69 commit 8b1b01c
Showing 1 changed file with 68 additions and 25 deletions.
93 changes: 68 additions & 25 deletions weave/trace/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sys
import traceback
from collections.abc import Coroutine, Mapping
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial, wraps
from types import MethodType
Expand Down Expand Up @@ -519,6 +520,8 @@ def _is_context_manager(obj: Any) -> bool:
__enter__ = getattr(obj, "__enter__", None)
__exit__ = getattr(obj, "__exit__", None)

print(__enter__, __exit__)

return callable(__enter__) and callable(__exit__)


Expand All @@ -528,6 +531,7 @@ def _execute_op(
__weave: WeaveKwargs | None = None,
__should_raise: bool = True,
__should_accumulate: bool = False,
__should_use_contextmanager: bool = False,
**kwargs: Any,
) -> tuple[Any, Call]:
func = __op.resolve_fn
Expand Down Expand Up @@ -559,38 +563,79 @@ def _execute_op(
display_name=call_time_display_name or __op.call_display_name,
attributes=attributes,
)
__op.lifecycle_handler.run_before_call({}, None, None, "")
__op.lifecycle_handler.run_before_call(inputs, parent_call, attributes)

should_accumulate = __should_accumulate and __should_accumulate(call)
should_use_contextmanager = (
__should_use_contextmanager and __should_use_contextmanager(func)
)

if is_generator or should_accumulate:
# TODO: Not sure if this is the right pattern. This hack is only
# used to support the Anthropic streaming context streaming case atm...
if should_use_contextmanager:

@wraps(func)
def _wrapped_sync_generator():
try:
for val in func(*args, **kwargs):
__op.lifecycle_handler.run_before_yield(call, val)
yield val
__op.lifecycle_handler.run_after_yield(call, val)
except Exception as e:
exception = e
if __should_raise:
raise
else:
@contextmanager
@wraps(func)
def _wrapped_context_manager_yields_sync_generator():
exception = None
__op.lifecycle_handler.run_after_yield_all(call)
finally:
with func(*args, **kwargs) as original_gen:

def _wrapped_sync_generator():
nonlocal exception
try:
for val in original_gen:
__op.lifecycle_handler.run_before_yield(call, val)
yield val
__op.lifecycle_handler.run_after_yield(call, val)
except Exception as e:
exception = e
if __should_raise:
raise
else:
__op.lifecycle_handler.run_after_yield_all(call)

yield _wrapped_sync_generator()

if __op.lifecycle_handler.has_finished:
raise OpCallError("Should not call finish more than once")
boxed_output = box.box(call.output)
client.finish_call(call, boxed_output, exception=exception, op=__op)
# TODO: We choose to print the call link at the end, but for streaming cases it might be helpful to:
# 1. Print the link at the beginning of the stream
# 2. (maybe?): Update the call object periodically with the latest output
if not call_context.get_current_call():
print_call_link(call)
call_context.pop_call(call.id)

res = _wrapped_sync_generator()
res = _wrapped_context_manager_yields_sync_generator()

else:

@wraps(func)
def _wrapped_sync_generator():
try:
for val in func(*args, **kwargs):
__op.lifecycle_handler.run_before_yield(call, val)
yield val
__op.lifecycle_handler.run_after_yield(call, val)
except Exception as e:
exception = e
if __should_raise:
raise
else:
exception = None
__op.lifecycle_handler.run_after_yield_all(call)
finally:
if __op.lifecycle_handler.has_finished:
raise OpCallError("Should not call finish more than once")
boxed_output = box.box(call.output)
client.finish_call(call, boxed_output, exception=exception, op=__op)
# TODO: We choose to print the call link at the end, but for streaming cases it might be helpful to:
# 1. Print the link at the beginning of the stream
# 2. (maybe?): Update the call object periodically with the latest output
if not call_context.get_current_call():
print_call_link(call)
call_context.pop_call(call.id)

res = _wrapped_sync_generator()
else:
# Regular sync function
try:
Expand Down Expand Up @@ -650,8 +695,6 @@ def op(
postprocess_output: PostprocessOutputFunc | None = None,
callbacks: list[Callback] | None = None,
reducers: list[Reducer] | None = None,
__should_accumulate: Callable[[Call], bool]
| None = None, # escape hatch for integrations
) -> Op: ...


Expand All @@ -664,8 +707,6 @@ def op(
postprocess_output: PostprocessOutputFunc | None = None,
callbacks: list[Callback] | None = None,
reducers: list[Reducer] | None = None,
__should_accumulate: Callable[[Call], bool]
| None = None, # escape hatch for integrations
) -> Callable[[Callable], Op]: ...


Expand All @@ -678,8 +719,9 @@ def op(
postprocess_output: PostprocessOutputFunc | None = None,
callbacks: list[Callback] | None = None,
reducers: list[Reducer] | None = None,
__should_accumulate: Callable[[Call], bool]
| None = None, # escape hatch for integrations
# escape hatch for integrations
__should_accumulate: Callable[[Call], bool] | None = None,
__should_use_contextmanager: Callable[[Callable], bool] | None = None,
) -> Callable[[Callable], Op] | Op:
"""
A decorator to weave op-ify a function or method. Works for both sync and async.
Expand Down Expand Up @@ -768,6 +810,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
*args,
__should_raise=True,
__should_accumulate=__should_accumulate,
__should_use_contextmanager=__should_use_contextmanager,
**kwargs,
)
return res
Expand Down

0 comments on commit 8b1b01c

Please sign in to comment.