added local executor tests
kessler-frost committed Nov 9, 2023
1 parent e66cd82 commit 1ea6d3e
Showing 3 changed files with 192 additions and 25 deletions.
12 changes: 1 addition & 11 deletions covalent/executor/executor_plugins/local.py
@@ -26,6 +26,7 @@
from enum import Enum
from typing import Any, Callable, Dict, List, Optional

import requests
from pydantic import BaseModel

from covalent._shared_files import TaskCancelledError, TaskRuntimeError, logger
@@ -169,7 +170,6 @@ def _send(
stdout_uri = os.path.join(self.cache_dir, f"stdout_{dispatch_id}-{node_id}.txt")
stderr_uri = os.path.join(self.cache_dir, f"stderr_{dispatch_id}-{node_id}.txt")
output_uris.append((result_uri, stdout_uri, stderr_uri))
# future = dask_client.submit(lambda x: x**3, 3)

server_url = format_server_url()

@@ -185,8 +185,6 @@
)

def handle_cancelled(fut):
import requests

app_log.debug(f"In done callback for {dispatch_id}:{gid}, future {fut}")
if fut.cancelled():
for task_id in task_ids:
Expand All @@ -195,8 +193,6 @@ def handle_cancelled(fut):

future.add_done_callback(handle_cancelled)

return 42

def _receive(self, task_group_metadata: Dict, data: Any) -> List[TaskUpdate]:
# Returns (output_uri, stdout_uri, stderr_uri,
# exception_raised)
@@ -207,9 +203,6 @@ def _receive(self, task_group_metadata: Dict, data: Any) -> List[TaskUpdate]:

task_results = []

# if len(task_ids) > 1:
# raise RuntimeError("Task packing is not yet supported")

for task_id in task_ids:
# Handle the case where the job was cancelled before the task started running
app_log.debug(f"Receive called for task {dispatch_id}:{task_id} with data {data}")
@@ -276,6 +269,3 @@ async def receive(self, task_group_metadata: Dict, data: Any) -> List[TaskUpdate]:
task_group_metadata,
data,
)

def get_upload_uri(self, task_group_metadata: Dict, object_key: str):
return ""
28 changes: 17 additions & 11 deletions tests/covalent_dispatcher_tests/_core/runner_ng_test.py
@@ -20,6 +20,8 @@


import asyncio
import datetime
import sys
from unittest.mock import AsyncMock

import pytest
@@ -79,8 +81,6 @@ def test_db():
def get_mock_result() -> Result:
"""Construct a mock result object corresponding to a lattice."""

import sys

@ct.electron(executor="local")
def task(x):
print(f"stdout: {x}")
@@ -109,11 +109,17 @@ def get_mock_srvresult(sdkres, test_db) -> SRVResult:


@pytest.mark.asyncio
async def test_submit_abstract_task_group(mocker):
import datetime

@pytest.mark.parametrize(
"task_cancelled",
[False, True],
)
async def test_submit_abstract_task_group(mocker, task_cancelled):
me = MockManagedExecutor()
me.send = AsyncMock(return_value="42")

if task_cancelled:
me.send = AsyncMock(side_effect=TaskCancelledError())
else:
me.send = AsyncMock(return_value="42")

mocker.patch(
"covalent_dispatcher._core.runner_ng.datamgr.electron.get",
@@ -242,7 +248,11 @@ async def test_submit_abstract_task_group(mocker):
ResourceMap(**resources),
task_group_metadata,
)
assert send_retval == "42"

if task_cancelled:
assert send_retval is None
else:
assert send_retval == "42"


@pytest.mark.asyncio
@@ -300,8 +310,6 @@ async def test_submit_requires_opt_in(mocker):

@pytest.mark.asyncio
async def test_get_task_result(mocker):
import datetime

me = MockManagedExecutor()
asset_uri = "file:///tmp/asset.pkl"
mock_task_result = {
@@ -436,8 +444,6 @@ async def test_poll_status(mocker):

@pytest.mark.asyncio
async def test_event_listener(mocker):
import datetime

ts = datetime.datetime.now()
node_result = {
"node_id": 0,
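
Aside: the new task_cancelled parametrization above relies on AsyncMock's side_effect semantics — an exception instance supplied as side_effect is raised when the mock is awaited, while return_value is returned otherwise. A tiny standalone illustration follows; the exception class is a stand-in, not Covalent's TaskCancelledError.

import asyncio
from unittest.mock import AsyncMock


class FakeCancelledError(Exception):
    # Stand-in for covalent's TaskCancelledError, used only for illustration.
    pass


async def call_send(send):
    # Mirrors what the runner test asserts: a cancelled send yields None,
    # a successful send yields the executor's job handle.
    try:
        return await send("task-spec")
    except FakeCancelledError:
        return None


async def main():
    ok = AsyncMock(return_value="42")
    cancelled = AsyncMock(side_effect=FakeCancelledError())
    assert await call_send(ok) == "42"
    assert await call_send(cancelled) is None


asyncio.run(main())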
177 changes: 174 additions & 3 deletions tests/covalent_tests/executor/executor_plugins/local_test.py
@@ -21,7 +21,7 @@
import os
import tempfile
from functools import partial
from unittest.mock import MagicMock
from unittest.mock import MagicMock, Mock, patch

import pytest

@@ -31,10 +31,13 @@
from covalent._workflow.transport import TransportableObject
from covalent.executor.executor_plugins.local import (
_EXECUTOR_PLUGIN_DEFAULTS,
RESULT_STATUS,
LocalExecutor,
StatusEnum,
TaskSpec,
run_task_from_uris,
)
from covalent.executor.schemas import ResourceMap
from covalent.executor.utils.serialize import serialize_node_asset
from covalent.executor.utils.wrappers import wrapper_fn

@@ -242,7 +245,7 @@ def test_run_task_from_uris(mocker):
def task(x, y):
return x + y

dispatch_id = "test_dask_send_receive"
dispatch_id = "test_local_send_receive"
node_id = 0
task_group_id = 0
server_url = "http://localhost:48008"
@@ -371,7 +374,7 @@ def test_run_task_from_uris_exception(mocker):
def task(x, y):
assert False

dispatch_id = "test_dask_send_receive"
dispatch_id = "test_local_send_receive"
node_id = 0
task_group_id = 0
server_url = "http://localhost:48008"
Expand Down Expand Up @@ -486,3 +489,171 @@ def mock_req_post(url, files):
with open(summary_file_path, "r") as f:
summary = json.load(f)
assert summary["exception_occurred"] is True


# Mocks for external dependencies
@pytest.fixture
def mock_os_path_join():
with patch("os.path.join", return_value="mock_path") as mock:
yield mock


@pytest.fixture
def mock_format_server_url():
with patch(
"covalent.executor.executor_plugins.local.format_server_url",
return_value="mock_server_url",
) as mock:
yield mock


@pytest.fixture
def mock_future():
mock = Mock()
mock.cancelled.return_value = False
return mock


@pytest.fixture
def mock_proc_pool_submit(mock_future):
with patch(
"covalent.executor.executor_plugins.local.proc_pool.submit", return_value=mock_future
) as mock:
yield mock
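
# Aside: the fixtures above use the usual patch-as-a-fixture idiom — the patch stays
# active while the test body runs and is undone when the generator resumes. Stripped
# to its essentials (generic target, nothing Covalent-specific), the idiom looks like:

import time
from unittest.mock import patch


@pytest.fixture
def mock_sleep():
    # The patch is active for the duration of the test, then reverted.
    with patch("time.sleep", return_value=None) as mock:
        yield mock


def test_sleep_is_patched(mock_sleep):
    time.sleep(5)  # returns immediately because time.sleep is a Mock
    mock_sleep.assert_called_once_with(5)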


# Test cases
test_cases = [
# Happy path
{
"id": "happy_path",
"task_specs": [
TaskSpec(
function_id=0,
args_ids=[1],
kwargs_ids={"y": 2},
deps_id="deps",
call_before_id="call_before",
call_after_id="call_after",
)
],
"resources": ResourceMap(
functions={0: "mock_function_uri"},
inputs={1: "mock_input_uri"},
deps={"deps": "mock_deps_uri"},
),
"task_group_metadata": {"dispatch_id": "1", "node_ids": ["1"], "task_group_id": "1"},
"expected_output_uris": [("mock_path", "mock_path", "mock_path")],
"expected_server_url": "mock_server_url",
"expected_future_cancelled": False,
},
{
"id": "future_cancelled",
"task_specs": [
TaskSpec(
function_id=0,
args_ids=[1],
kwargs_ids={"y": 2},
deps_id="deps",
call_before_id="call_before",
call_after_id="call_after",
)
],
"resources": ResourceMap(
functions={0: "mock_function_uri"},
inputs={1: "mock_input_uri"},
deps={"deps": "mock_deps_uri"},
),
"task_group_metadata": {"dispatch_id": "1", "node_ids": ["1"], "task_group_id": "1"},
"expected_output_uris": [("mock_path", "mock_path", "mock_path")],
"expected_server_url": "mock_server_url",
"expected_future_cancelled": True,
},
]


@pytest.mark.parametrize("test_case", test_cases, ids=[tc["id"] for tc in test_cases])
def test_send(
test_case,
mock_os_path_join,
mock_format_server_url,
mock_future,
mock_proc_pool_submit,
):
"""Test the send function of LocalExecutor"""

local_exec = LocalExecutor()

# Arrange
local_exec.cache_dir = "mock_cache_dir"
mock_future.cancelled.return_value = test_case["expected_future_cancelled"]

# Act
local_exec._send(
test_case["task_specs"],
test_case["resources"],
test_case["task_group_metadata"],
)

# Assert
mock_os_path_join.assert_called()
mock_format_server_url.assert_called_once_with()
mock_proc_pool_submit.assert_called_once_with(
run_task_from_uris,
list(map(lambda t: t.dict(), test_case["task_specs"])),
test_case["resources"].dict(),
test_case["expected_output_uris"],
"mock_cache_dir",
test_case["task_group_metadata"],
test_case["expected_server_url"],
)
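
# Aside: the ids=[tc["id"] for tc in test_cases] argument used here (and again for
# test_receive below) only affects how pytest labels each parametrized case in its
# output. A tiny standalone example of the same idiom, with made-up case data:

squares = [
    {"id": "happy_path", "value": 3, "expected": 9},
    {"id": "zero", "value": 0, "expected": 0},
]


@pytest.mark.parametrize("case", squares, ids=[c["id"] for c in squares])
def test_square(case):
    # Reported as test_square[happy_path] and test_square[zero].
    assert case["value"] ** 2 == case["expected"]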


# Test data
test_data = [
# Happy path tests
{
"id": "HP1",
"task_group_metadata": {"dispatch_id": "1", "node_ids": ["1", "2"]},
"data": {"status": StatusEnum.COMPLETED},
"expected_status": StatusEnum.COMPLETED,
},
{
"id": "HP2",
"task_group_metadata": {"dispatch_id": "2", "node_ids": ["3", "4"]},
"data": {"status": StatusEnum.FAILED},
"expected_status": StatusEnum.FAILED,
},
# Edge case tests
{
"id": "EC1",
"task_group_metadata": {"dispatch_id": "3", "node_ids": []},
"data": {"status": StatusEnum.COMPLETED},
"expected_status": StatusEnum.COMPLETED,
},
{
"id": "EC2",
"task_group_metadata": {"dispatch_id": "4", "node_ids": ["5"]},
"data": None,
"expected_status": RESULT_STATUS.CANCELLED,
},
]


@pytest.mark.parametrize("test_case", test_data, ids=[tc["id"] for tc in test_data])
def test_receive(test_case):
"""Test the receive function of LocalExecutor"""

local_exec = LocalExecutor()

# Arrange
task_group_metadata = test_case["task_group_metadata"]
data = test_case["data"]
expected_status = test_case["expected_status"]

# Act
task_results = local_exec._receive(task_group_metadata, data)

# Assert
for task_result in task_results:
assert task_result.status == expected_status
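
The EC2 case above pins down the plugin's behaviour when no payload arrives: a None data argument is treated as cancellation, otherwise the payload's status is applied to every node in the group. A hypothetical helper capturing just that mapping (illustrative only, not the plugin's implementation):

from enum import Enum


class Status(str, Enum):
    # Illustrative stand-in for covalent's StatusEnum / RESULT_STATUS values.
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    CANCELLED = "CANCELLED"


def statuses_for_group(node_ids, data):
    # None means the job never produced results (e.g. cancelled before running);
    # otherwise every node in the group carries the status from the payload.
    status = Status.CANCELLED if data is None else data["status"]
    return {node_id: status for node_id in node_ids}


print(statuses_for_group(["5"], None))                               # all CANCELLED
print(statuses_for_group(["1", "2"], {"status": Status.COMPLETED}))  # all COMPLETED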
