Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cjao committed Feb 28, 2024
1 parent bdecc52 commit d112463
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ async def mock_get_incoming_edges(dispatch_id, node_id):

# dict-type inputs

# Nodes 0=task, 1=:electron_dict:, 2=1, 3=2
# Nodes 0=task, 1=:electron_dict:, 2=["a" (3), "b" (4)], 5=[1 (6), 2 (7)]
dict_workflow.build_graph({"a": 1, "b": 2})
abstract_args = {"a": 2, "b": 3}
tg = dict_workflow.transport_graph
Expand All @@ -173,9 +173,23 @@ async def mock_get_incoming_edges(dispatch_id, node_id):
)

task_inputs = await _get_abstract_task_inputs(
result_object.dispatch_id, 1, tg.get_node_value(1, "name")
result_object.dispatch_id, 1, tg.get_node_value(2, "name")
)
expected_inputs = {"args": [2, 5], "kwargs": {}}

assert task_inputs == expected_inputs

task_inputs = await _get_abstract_task_inputs(
result_object.dispatch_id, 2, tg.get_node_value(2, "name")
)
expected_inputs = {"args": [3, 4], "kwargs": {}}

assert task_inputs == expected_inputs

task_inputs = await _get_abstract_task_inputs(
result_object.dispatch_id, 5, tg.get_node_value(5, "name")
)
expected_inputs = {"args": [], "kwargs": abstract_args}
expected_inputs = {"args": [6, 7], "kwargs": {}}

assert task_inputs == expected_inputs

Expand Down
21 changes: 16 additions & 5 deletions tests/covalent_dispatcher_tests/_core/execution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,20 +167,31 @@ def multivar_workflow(x, y):
# dict-type inputs

dict_workflow.build_graph({"a": 1, "b": 2})
serialized_args = {"a": ct.TransportableObject(1), "b": ct.TransportableObject(2)}

# Nodes 0=task, 1=:electron_dict:, 2=1, 3=2
# Nodes 0=task, 1=:electron_dict:, 2=["a" (3), "b" (4)], 5=[1 (6), 2 (7)]

sdkres = Result(lattice=dict_workflow, dispatch_id="asdf_dict_workflow")
result_object = get_mock_srvresult(sdkres, test_db)
tg = result_object.lattice.transport_graph
tg.set_node_value(2, "output", ct.TransportableObject(1))
tg.set_node_value(3, "output", ct.TransportableObject(2))
tg.set_node_value(3, "output", ct.TransportableObject("a"))
tg.set_node_value(4, "output", ct.TransportableObject("b"))
tg.set_node_value(6, "output", ct.TransportableObject(1))
tg.set_node_value(7, "output", ct.TransportableObject(2))

mock_get_result = mocker.patch(
"covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=result_object
)
task_inputs = await _get_task_inputs(1, tg.get_node_value(1, "name"), result_object)
expected_inputs = {"args": [], "kwargs": serialized_args}

serialized_args = [ct.TransportableObject("a"), ct.TransportableObject("b")]
task_inputs = await _get_task_inputs(2, tg.get_node_value(1, "name"), result_object)
expected_inputs = {"args": serialized_args, "kwargs": {}}

assert task_inputs == expected_inputs

serialized_args = [ct.TransportableObject(1), ct.TransportableObject(2)]
task_inputs = await _get_task_inputs(5, tg.get_node_value(1, "name"), result_object)
expected_inputs = {"args": serialized_args, "kwargs": {}}

assert task_inputs == expected_inputs

Expand Down
28 changes: 20 additions & 8 deletions tests/covalent_tests/workflow/electron_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,18 +377,30 @@ def workflow(x):
g = workflow.transport_graph._graph

# Account for postprocessing node
assert list(g.nodes) == [0, 1, 2, 3, 4]
assert list(g.nodes) == [0, 1, 2, 3, 4, 5, 6, 7, 8]
fn = g.nodes[1]["function"].get_deserialized()
assert fn(x=2, y=5, z=7) == {"x": 2, "y": 5, "z": 7}
assert g.nodes[2]["value"].get_deserialized() == 5
assert g.nodes[3]["value"].get_deserialized() == 7
assert fn(["x", "y", "z"], [2, 5, 7]) == {"x": 2, "y": 5, "z": 7}
fn = g.nodes[2]["function"].get_deserialized()
assert fn("x", "y") == ["x", "y"]
keys = [g.nodes[3]["value"].get_deserialized(), g.nodes[4]["value"].get_deserialized()]
fn = g.nodes[5]["function"].get_deserialized()
assert fn(2, 3) == [2, 3]
vals = [g.nodes[6]["value"].get_deserialized(), g.nodes[7]["value"].get_deserialized()]
assert keys == ["x", "y"]
assert vals == [5, 7]
assert set(g.edges) == {
(1, 0, 0),
(2, 1, 0),
(3, 1, 0),
(0, 4, 0),
(0, 4, 1),
(1, 4, 0),
(3, 2, 0),
(4, 2, 0),
(5, 1, 0),
(6, 5, 0),
(7, 5, 0),
(0, 8, 0),
(0, 8, 1),
(1, 8, 0),
(2, 8, 0),
(5, 8, 0),
}


Expand Down
3 changes: 2 additions & 1 deletion tests/functional_tests/workflow_stack_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,8 @@ def workflow(x):
res_1 = sum_values(x)
return square(res_1)

dispatch_id = ct.dispatch(workflow)({"x": 1, "y": 2, "z": 3})
# Check that non-string keys are allowed
dispatch_id = ct.dispatch(workflow)({"x": 1, "y": 2, 3: 3})

res_obj = rm.get_result(dispatch_id, wait=True)

Expand Down

0 comments on commit d112463

Please sign in to comment.