Skip to content

Commit

Permalink
chore: Test all coroutine patterns and refactor op call (#2684)
Browse files Browse the repository at this point in the history
* tests!

* fixed it up

* small fix

* comments

* comments

* comments
  • Loading branch information
tssweeney authored Oct 11, 2024
1 parent 998cb18 commit e8b1f40
Show file tree
Hide file tree
Showing 4 changed files with 348 additions and 50 deletions.
220 changes: 220 additions & 0 deletions tests/trace/test_op_coroutines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
import asyncio
from typing import Coroutine

import pytest

import weave
from weave.trace.weave_client import Call


def test_sync_val(client):
@weave.op()
def sync_val():
return 1

res = sync_val()
assert res == 1
res, call = sync_val.call()
assert isinstance(call, Call)
assert res == 1


def test_sync_val_method(client):
class TestClass:
@weave.op()
def sync_val(self):
return 1

test_inst = TestClass()
res = test_inst.sync_val()
assert res == 1
res, call = test_inst.sync_val.call(test_inst)
assert isinstance(call, Call)
assert res == 1


@pytest.mark.asyncio
async def test_sync_coro(client):
@weave.op()
def sync_coro():
return asyncio.to_thread(lambda: 1)

res = sync_coro()
assert isinstance(res, Coroutine)
assert await res == 1
res, call = sync_coro.call()
assert isinstance(call, Call)
assert isinstance(res, Coroutine)
assert await res == 1


@pytest.mark.asyncio
async def test_sync_coro_method(client):
class TestClass:
@weave.op()
def sync_coro(self):
return asyncio.to_thread(lambda: 1)

test_inst = TestClass()
res = test_inst.sync_coro()
assert isinstance(res, Coroutine)
assert await res == 1
res, call = test_inst.sync_coro.call(test_inst)
assert isinstance(call, Call)
assert isinstance(res, Coroutine)
assert await res == 1


@pytest.mark.asyncio
async def test_async_coro(client):
@weave.op()
async def async_coro():
return asyncio.to_thread(lambda: 1)

res = async_coro()
assert isinstance(res, Coroutine)
res2 = await res
assert isinstance(res2, Coroutine)
assert await res2 == 1
res, call = await async_coro.call()
assert isinstance(call, Call)
assert isinstance(res, Coroutine)
assert await res == 1


@pytest.mark.asyncio
async def test_async_coro_method(client):
class TestClass:
@weave.op()
async def async_coro(self):
return asyncio.to_thread(lambda: 1)

test_inst = TestClass()

res = test_inst.async_coro()
assert isinstance(res, Coroutine)
res2 = await res
assert isinstance(res2, Coroutine)
assert await res2 == 1
res, call = await test_inst.async_coro.call(test_inst)
assert isinstance(call, Call)
assert isinstance(res, Coroutine)
assert await res == 1


@pytest.mark.asyncio
async def test_async_awaited_coro(client):
@weave.op()
async def async_awaited_coro():
return await asyncio.to_thread(lambda: 1)

res = async_awaited_coro()
assert isinstance(res, Coroutine)
assert await res == 1
res, call = await async_awaited_coro.call()
assert isinstance(call, Call)
assert res == 1


@pytest.mark.asyncio
async def test_async_awaited_coro_method(client):
class TestClass:
@weave.op()
async def async_awaited_coro(self):
return await asyncio.to_thread(lambda: 1)

test_inst = TestClass()
res = test_inst.async_awaited_coro()
assert isinstance(res, Coroutine)
assert await res == 1
res, call = await test_inst.async_awaited_coro.call(test_inst)
assert isinstance(call, Call)
assert res == 1


@pytest.mark.asyncio
async def test_async_val(client):
@weave.op()
async def async_val():
return 1

res = async_val()
assert isinstance(res, Coroutine)
assert await res == 1
res, call = await async_val.call()
assert isinstance(call, Call)
assert res == 1


@pytest.mark.asyncio
async def test_async_val_method(client):
class TestClass:
@weave.op()
async def async_val(self):
return 1

test_inst = TestClass()
res = test_inst.async_val()
assert isinstance(res, Coroutine)
assert await res == 1
res, call = await test_inst.async_val.call(test_inst)
assert isinstance(call, Call)
assert res == 1


def test_sync_with_exception(client):
@weave.op()
def sync_with_exception():
raise ValueError("test")

with pytest.raises(ValueError, match="test"):
sync_with_exception()
res, call = sync_with_exception.call()
assert isinstance(call, Call)
assert call.exception is not None
assert res is None


def test_sync_with_exception_method(client):
class TestClass:
@weave.op()
def sync_with_exception(self):
raise ValueError("test")

test_inst = TestClass()
with pytest.raises(ValueError, match="test"):
test_inst.sync_with_exception()
res, call = test_inst.sync_with_exception.call(test_inst)
assert isinstance(call, Call)
assert call.exception is not None
assert res is None


@pytest.mark.asyncio
async def test_async_with_exception(client):
@weave.op()
async def async_with_exception():
raise ValueError("test")

with pytest.raises(ValueError, match="test"):
await async_with_exception()
res, call = await async_with_exception.call()
assert isinstance(call, Call)
assert call.exception is not None
assert res is None


@pytest.mark.asyncio
async def test_async_with_exception_method(client):
class TestClass:
@weave.op()
async def async_with_exception(self):
raise ValueError("test")

test_inst = TestClass()
with pytest.raises(ValueError, match="test"):
await test_inst.async_with_exception()
res, call = await test_inst.async_with_exception.call(test_inst)
assert isinstance(call, Call)
assert call.exception is not None
assert res is None
2 changes: 1 addition & 1 deletion tests/trace/test_op_decorator_behaviour.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def test_sync_method_call(client, weave_obj, py_obj):
weave_obj_method2 = weave_obj_method_ref.get()

with pytest.raises(errors.OpCallError):
res2, call2 = py_obj.amethod.call(1)
res2, call2 = py_obj.method.call(1)


@pytest.mark.asyncio
Expand Down
Loading

0 comments on commit e8b1f40

Please sign in to comment.