Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
adrnswanberg committed Dec 11, 2024
1 parent 7ce644f commit 0cd0fc2
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 110 deletions.
68 changes: 34 additions & 34 deletions tests/trace/test_client_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def get_client_project_id(client: weave_client.WeaveClient) -> str:


def test_simple_op(client):
@weave.op()
@weave.op
def my_op(a: int) -> int:
return a + 1

Expand Down Expand Up @@ -229,7 +229,7 @@ def test_call_read_not_found(client):


def test_graph_call_ordering(client):
@weave.op()
@weave.op
def my_op(a: int) -> int:
return a + 1

Expand Down Expand Up @@ -263,27 +263,27 @@ def simple_line_call_bootstrap(init_wandb: bool = False) -> OpCallSpec:
class Number(weave.Object):
value: int

@weave.op()
@weave.op
def adder(a: Number) -> Number:
return Number(value=a.value + a.value)

adder_v0 = adder

@weave.op() # type: ignore
@weave.op # type: ignore
def adder(a: Number, b) -> Number:
return Number(value=a.value + b)

@weave.op()
@weave.op
def subtractor(a: Number, b) -> Number:
return Number(value=a.value - b)

@weave.op()
@weave.op
def multiplier(
a: Number, b
) -> int: # intentionally deviant in returning plain int - so that we have a different type
return a.value * b

@weave.op()
@weave.op
def liner(m: Number, b, x) -> Number:
return adder(Number(value=multiplier(m, x)), b)

Expand Down Expand Up @@ -691,7 +691,7 @@ def test_trace_call_query_offset(client):


def test_trace_call_sort(client):
@weave.op()
@weave.op
def basic_op(in_val: dict, delay) -> dict:
import time

Expand Down Expand Up @@ -727,7 +727,7 @@ def test_trace_call_sort_with_mixed_types(client):
# SQLite does not support sorting over mixed types in a column, so we skip this test
return

@weave.op()
@weave.op
def basic_op(in_val: dict) -> dict:
import time

Expand Down Expand Up @@ -769,7 +769,7 @@ def basic_op(in_val: dict) -> dict:
def test_trace_call_filter(client):
is_sqlite = client_is_sqlite(client)

@weave.op()
@weave.op
def basic_op(in_val: dict, delay) -> dict:
return in_val

Expand Down Expand Up @@ -1160,7 +1160,7 @@ def basic_op(in_val: dict, delay) -> dict:


def test_ops_with_default_params(client):
@weave.op()
@weave.op
def op_with_default(a: int, b: int = 10) -> int:
return a + b

Expand Down Expand Up @@ -1234,7 +1234,7 @@ class BaseTypeC(BaseTypeB):


def test_attributes_on_ops(client):
@weave.op()
@weave.op
def op_with_attrs(a: int, b: int) -> int:
return a + b

Expand Down Expand Up @@ -1277,7 +1277,7 @@ def test_dataclass_support(client):
class MyDataclass:
val: int

@weave.op()
@weave.op
def dataclass_maker(a: MyDataclass, b: MyDataclass) -> MyDataclass:
return MyDataclass(a.val + b.val)

Expand Down Expand Up @@ -1322,7 +1322,7 @@ def dataclass_maker(a: MyDataclass, b: MyDataclass) -> MyDataclass:


def test_op_retrieval(client):
@weave.op()
@weave.op
def my_op(a: int) -> int:
return a + 1

Expand All @@ -1336,7 +1336,7 @@ def test_bound_op_retrieval(client):
class CustomType(weave.Object):
a: int

@weave.op()
@weave.op
def op_with_custom_type(self, v):
return self.a + v

Expand All @@ -1359,7 +1359,7 @@ def test_bound_op_retrieval_no_self(client):
class CustomTypeWithoutSelf(weave.Object):
a: int

@weave.op()
@weave.op
def op_with_custom_type(me, v):
return me.a + v

Expand Down Expand Up @@ -1387,7 +1387,7 @@ def test_dataset_row_ref(client):


def test_tuple_support(client):
@weave.op()
@weave.op
def tuple_maker(a, b):
return (a, b)

Expand All @@ -1411,7 +1411,7 @@ def tuple_maker(a, b):


def test_namedtuple_support(client):
@weave.op()
@weave.op
def tuple_maker(a, b):
return (a, b)

Expand Down Expand Up @@ -1442,7 +1442,7 @@ def test_named_reuse(client):
d_ref = weave.publish(d, "test_dataset")
dataset = weave.ref(d_ref.uri()).get()

@weave.op()
@weave.op
async def dummy_score(output):
return 1

Expand Down Expand Up @@ -1489,7 +1489,7 @@ class MyUnknownClassB:
def __init__(self, b_val) -> None:
self.b_val = b_val

@weave.op()
@weave.op
def op_with_unknown_types(a: MyUnknownClassA, b: float) -> MyUnknownClassB:
return MyUnknownClassB(a.a_val + b)

Expand Down Expand Up @@ -1564,19 +1564,19 @@ def init_weave_get_server_patched(api_key):


def test_single_primitive_output(client):
@weave.op()
@weave.op
def single_int_output(a: int) -> int:
return a

@weave.op()
@weave.op
def single_bool_output(a: int) -> bool:
return a == 1

@weave.op()
@weave.op
def single_none_output(a: int) -> None:
return None

@weave.op()
@weave.op
def dict_output(a: int, b: bool, c: None) -> dict:
return {"a": a, "b": b, "c": c}

Expand Down Expand Up @@ -1669,30 +1669,30 @@ def test_mapped_execution(client, mapper):

events = []

@weave.op()
@weave.op
def op_a(a: int) -> int:
events.append("A(S):" + str(a))
time.sleep(0.3)
events.append("A(E):" + str(a))
return a

@weave.op()
@weave.op
def op_b(b: int) -> int:
events.append("B(S):" + str(b))
time.sleep(0.2)
res = op_a(b)
events.append("B(E):" + str(b))
return res

@weave.op()
@weave.op
def op_c(c: int) -> int:
events.append("C(S):" + str(c))
time.sleep(0.1)
res = op_b(c)
events.append("C(E):" + str(c))
return res

@weave.op()
@weave.op
def op_mapper(vals):
return mapper(op_c, vals)

Expand Down Expand Up @@ -2238,7 +2238,7 @@ def calculate(a: int, b: int) -> dict[str, Any]:

@pytest.mark.skip("Not implemented: filter / sort through refs")
def test_sort_and_filter_through_refs(client):
@weave.op()
@weave.op
def test_op(label, val):
return val

Expand Down Expand Up @@ -2356,7 +2356,7 @@ def test_obj(val):


def test_in_operation(client):
@weave.op()
@weave.op
def test_op(label, val):
return val

Expand Down Expand Up @@ -2501,7 +2501,7 @@ def func(x):


class BasicModel(weave.Model):
@weave.op()
@weave.op
def predict(self, x):
return {"answer": "42"}

Expand Down Expand Up @@ -2547,7 +2547,7 @@ class SimpleObject(weave.Object):
class NestedObject(weave.Object):
b: SimpleObject

@weave.op()
@weave.op
def return_nested_object(nested_obj: NestedObject):
return nested_obj

Expand Down Expand Up @@ -3099,7 +3099,7 @@ def test_op_sampling_inheritance(client):
parent_calls = 0
child_calls = 0

@weave.op()
@weave.op
def child_op(x: int) -> int:
nonlocal child_calls
child_calls += 1
Expand Down Expand Up @@ -3135,7 +3135,7 @@ def test_op_sampling_inheritance_async(client):
parent_calls = 0
child_calls = 0

@weave.op()
@weave.op
async def child_op(x: int) -> int:
nonlocal child_calls
child_calls += 1
Expand Down
Loading

0 comments on commit 0cd0fc2

Please sign in to comment.