Skip to content

Commit

Permalink
Mem (2/3): redirect dispatcher to in-memory runner
Browse files Browse the repository at this point in the history
Make API endpoints RESTful

Cancel all dispatches upon shutdown
  • Loading branch information
cjao committed Jul 11, 2023
1 parent c7b87ed commit f5fc7fb
Show file tree
Hide file tree
Showing 14 changed files with 411 additions and 134 deletions.
19 changes: 19 additions & 0 deletions covalent/_api/apiclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,25 @@ def post(self, endpoint: str, **kwargs):

return r

def delete(self, endpoint: str, **kwargs):
    """Issue an HTTP DELETE against the dispatcher server.

    Args:
        endpoint: Path appended to the configured dispatcher address.
        **kwargs: Extra keyword arguments forwarded to ``requests``.

    Returns:
        The ``requests.Response`` from the server.

    Raises:
        requests.exceptions.ConnectionError: If the server is unreachable
            (a hint is printed before re-raising).
        requests.exceptions.HTTPError: On 4xx/5xx when ``auto_raise`` is set.
    """
    url = self.dispatcher_addr + endpoint
    headers = CovalentAPIClient.get_extra_headers()
    try:
        with requests.Session() as session:
            # Route traffic through the configured transport adapter, if any.
            if self.adapter:
                session.mount("http://", self.adapter)

            response = session.delete(url, headers=headers, **kwargs)

            if self.auto_raise:
                response.raise_for_status()
    except requests.exceptions.ConnectionError:
        print(
            f"The Covalent server cannot be reached at {url}. Local servers can be started using `covalent start` in the terminal. If you are using a remote Covalent server, contact your systems administrator to report an outage."
        )
        raise

    return response

@classmethod
def get_extra_headers(headers: Dict) -> Dict:
# This is expected to be a JSONified dictionary
Expand Down
8 changes: 4 additions & 4 deletions covalent/_dispatcher_plugins/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,8 @@ def start(
if dispatcher_addr is None:
dispatcher_addr = format_server_url()

endpoint = f"/api/v1/dispatch/start/{dispatch_id}"
r = APIClient(dispatcher_addr).put(endpoint)
endpoint = f"/api/v1/dispatch/{dispatch_id}"
r = APIClient(dispatcher_addr).post(endpoint)
r.raise_for_status()
return r.content.decode("utf-8").strip().replace('"', "")

Expand Down Expand Up @@ -546,7 +546,7 @@ def register_manifest(
else:
stripped = manifest

endpoint = "/api/v1/dispatch/register"
endpoint = "/api/v1/dispatch"

if parent_dispatch_id:
endpoint = f"{endpoint}?parent_dispatch_id={parent_dispatch_id}"
Expand Down Expand Up @@ -578,7 +578,7 @@ def register_derived_manifest(
# We don't yet support pulling assets for redispatch
stripped = strip_local_uris(manifest)

endpoint = f"/api/v1/dispatch/register/{dispatch_id}"
endpoint = f"/api/v1/dispatch/{dispatch_id}/redispatch"

params = {"reuse_previous_results": reuse_previous_results}
r = APIClient(dispatcher_addr).post(endpoint, data=stripped.json(), params=params)
Expand Down
6 changes: 3 additions & 3 deletions covalent/_results_manager/results_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,12 @@ def cancel(dispatch_id: str, task_ids: List[int] = None, dispatcher_addr: str =
task_ids = []

api_client = CovalentAPIClient(dispatcher_addr)
endpoint = "/api/v1/dispatch/cancel"
endpoint = f"/api/v1/dispatch/{dispatch_id}"

if isinstance(task_ids, int):
task_ids = [task_ids]

r = api_client.post(endpoint, json={"dispatch_id": dispatch_id, "task_ids": task_ids})
r = api_client.delete(endpoint, params={"task_ids": task_ids})
return r.content.decode("utf-8").strip().replace('"', "")


Expand Down Expand Up @@ -176,7 +176,7 @@ def _get_result_export_from_dispatcher(
adapter = HTTPAdapter(max_retries=Retry(total=retries, backoff_factor=1))
api_client = CovalentAPIClient(dispatcher_addr, adapter=adapter, auto_raise=False)

endpoint = "/api/v1/dispatch/export/" + dispatch_id
endpoint = f"/api/v1/dispatch/{dispatch_id}"
response = api_client.get(
endpoint,
params={"wait": wait, "status_only": status_only},
Expand Down
30 changes: 24 additions & 6 deletions covalent_dispatcher/_core/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
from covalent._shared_files.util_classes import RESULT_STATUS

from . import data_manager as datasvc
from . import runner_ng
from . import runner

# from . import runner_ng
from .data_modules import job_manager as jbmgr
from .dispatcher_modules.caches import _pending_parents, _sorted_task_groups, _unresolved_tasks

Expand Down Expand Up @@ -224,13 +226,28 @@ async def _submit_task_group(dispatch_id: str, sorted_nodes: List[int], task_gro
app_log.debug(f"Using new runner for task group {task_group_id}")

known_nodes = list(set(known_nodes))
coro = runner_ng.run_abstract_task_group(

task_spec = task_specs[0]
abstract_inputs = {"args": task_spec["args_ids"], "kwargs": task_spec["kwargs_ids"]}

# Temporarily redirect to in-memory runner (this is incompatible with task packing)
if len(task_specs) > 1:
raise RuntimeError("Task packing is not supported yet.")

coro = runner.run_abstract_task(
dispatch_id=dispatch_id,
task_group_id=task_group_id,
task_seq=task_specs,
known_nodes=known_nodes,
node_id=task_group_id,
node_name=node_name,
abstract_inputs=abstract_inputs,
selected_executor=[selected_executor, selected_executor_data],
)
# coro = runner_ng.run_abstract_task_group(
# dispatch_id=dispatch_id,
# task_group_id=task_group_id,
# task_seq=task_specs,
# known_nodes=known_nodes,
# selected_executor=[selected_executor, selected_executor_data],
# )

asyncio.create_task(coro)
else:
Expand Down Expand Up @@ -337,7 +354,8 @@ async def cancel_dispatch(dispatch_id: str, task_ids: List[int] = []) -> None:
app_log.debug(f"Cancelling dispatch {dispatch_id}")

await jbmgr.set_cancel_requested(dispatch_id, task_ids)
await runner_ng.cancel_tasks(dispatch_id, task_ids)
await runner.cancel_tasks(dispatch_id, task_ids)
# await runner_ng.cancel_tasks(dispatch_id, task_ids)

# Recursively cancel running sublattice dispatches
attrs = await datasvc.electron.get_bulk(dispatch_id, task_ids, ["sub_dispatch_id"])
Expand Down
2 changes: 1 addition & 1 deletion covalent_dispatcher/_core/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from covalent._shared_files.config import get_config
from covalent._shared_files.util_classes import RESULT_STATUS
from covalent._workflow import DepsBash, DepsCall, DepsPip
from covalent.executor.utils.wrappers import wrapper_fn
from covalent.executor.base import wrapper_fn

from . import data_manager as datasvc
from .runner_modules import executor_proxy
Expand Down
10 changes: 6 additions & 4 deletions covalent_dispatcher/_dal/utils/uri_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@ class URIFilterPolicy(enum.Enum):
def _srv_asset_uri(
    uri: str, attrs: dict, scope: AssetScope, dispatch_id: str, node_id: Optional[int], key: str
) -> str:
    """Translate an asset reference into its RESTful server-side URI.

    Args:
        uri: Original asset URI (unused; replaced by the computed value).
        attrs: Asset attributes (currently unused, kept for filter-policy
            signature compatibility).
        scope: Whether the asset belongs to the dispatch, lattice, or an electron.
        dispatch_id: The dispatch's unique id.
        node_id: Electron node id; only meaningful for electron-scoped assets.
        key: The asset's key within its scope.

    Returns:
        The fully qualified asset URI under ``/api/v1/dispatch/{dispatch_id}``.
    """
    base = f"{SERVER_URL}/api/v1/dispatch/{dispatch_id}"

    if scope == AssetScope.DISPATCH:
        return f"{base}/assets/{key}"
    if scope == AssetScope.LATTICE:
        return f"{base}/lattice/assets/{key}"
    # Any other scope is electron-level and requires the node id.
    return f"{base}/electron/{node_id}/assets/{key}"


Expand Down
122 changes: 90 additions & 32 deletions covalent_dispatcher/_service/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,27 @@
import asyncio
import json
from contextlib import asynccontextmanager
from typing import Optional, Union
from typing import List, Optional, Union
from uuid import UUID

from fastapi import APIRouter, FastAPI, HTTPException, Request
from fastapi import APIRouter, FastAPI, HTTPException, Query, Request
from fastapi.responses import JSONResponse

import covalent_dispatcher.entry_point as dispatcher
from covalent._shared_files import logger
from covalent._shared_files.schemas.result import ResultSchema
from covalent._shared_files.util_classes import RESULT_STATUS
from covalent_dispatcher._core import dispatcher as core_dispatcher
from covalent_dispatcher._core import runner_ng as core_runner

from .._dal.exporters.result import export_result_manifest
from .._dal.result import Result, get_result_object
from .._db.datastore import workflow_db
from .._db.dispatchdb import DispatchDB
from .heartbeat import Heartbeat
from .models import ExportResponseSchema

app_log = logger.app_log
log_stack_info = logger.log_stack_info
# from covalent_dispatcher._core import runner_ng as core_runner

router: APIRouter = APIRouter()

app_log = logger.app_log
log_stack_info = logger.log_stack_info
Expand All @@ -65,9 +63,9 @@ async def lifespan(app: FastAPI):
_background_tasks.add(fut)
fut.add_done_callback(_background_tasks.discard)

# Runner event queue and listener
core_runner._job_events = asyncio.Queue()
core_runner._job_event_listener = asyncio.create_task(core_runner._listen_for_job_events())
# # Runner event queue and listener
# core_runner._job_events = asyncio.Queue()
# core_runner._job_event_listener = asyncio.create_task(core_runner._listen_for_job_events())

# Dispatcher event queue and listener
core_dispatcher._global_status_queue = asyncio.Queue()
Expand All @@ -77,8 +75,32 @@ async def lifespan(app: FastAPI):

yield

# Cancel all scheduled and running dispatches
for status in [
RESULT_STATUS.NEW_OBJECT,
RESULT_STATUS.RUNNING,
]:
await cancel_all_with_status(status)

core_dispatcher._global_event_listener.cancel()
core_runner._job_event_listener.cancel()
# core_runner._job_event_listener.cancel()

Heartbeat.stop()


async def cancel_all_with_status(status: RESULT_STATUS):
    """Cancel every dispatch currently in the specified status.

    Invoked at server shutdown to stop scheduled (NEW_OBJECT) and RUNNING
    dispatches before the event loop is torn down.

    Args:
        status: The result status whose dispatches should be cancelled.
    """
    with workflow_db.session() as session:
        # Look up only the dispatch ids of records matching the status.
        matches = Result.get_db_records(
            session,
            keys=["dispatch_id"],
            equality_filters={"status": str(status)},
            membership_filters={},
        )
        for match in matches:
            await dispatcher.cancel_running_dispatch(match.attrs["dispatch_id"])


@router.post("/dispatch/submit")
Expand Down Expand Up @@ -106,26 +128,24 @@ async def submit(request: Request) -> UUID:
) from e


@router.post("/dispatch/cancel")
async def cancel(request: Request) -> str:
@router.delete("/dispatch/{dispatch_id}")
async def cancel(dispatch_id: str, task_ids: List[int] = Query([])) -> str:
"""
Function to accept the cancel request of
a dispatch.
Args:
None
dispatch_id: ID of the dispatch
task_ids: (Query) Optional list of specific task ids to cancel.
An empty list will cause all tasks to be cancelled.
Returns:
Fast API Response object confirming that the dispatch
has been cancelled.
"""

data = await request.json()

dispatch_id = data["dispatch_id"]
task_ids = data["task_ids"]

await dispatcher.cancel_running_dispatch(dispatch_id, task_ids)
print("DEBUG: task_ids", task_ids)
if task_ids:
return f"Cancelled tasks {task_ids} in dispatch {dispatch_id}."
else:
Expand All @@ -138,10 +158,19 @@ def db_path() -> str:
return json.dumps(db_path)


@router.post("/dispatch/register")
@router.post("/dispatch", status_code=201)
async def register(
manifest: ResultSchema, parent_dispatch_id: Union[str, None] = None
) -> ResultSchema:
"""Register a dispatch in the database.
Args:
manifest: Declares all metadata and assets in the workflow
parent_dispatch_id: The parent dispatch id if registering a sublattice dispatch
Returns:
The manifest with `dispatch_id` and remote URIs for each asset populated.
"""
try:
return await dispatcher.register_dispatch(manifest, parent_dispatch_id)
except Exception as e:
Expand All @@ -152,12 +181,23 @@ async def register(
) from e


@router.post("/dispatch/register/{dispatch_id}")
@router.post("/dispatch/{dispatch_id}/redispatch", status_code=201)
async def register_redispatch(
manifest: ResultSchema,
dispatch_id: str,
reuse_previous_results: bool = False,
):
"""Register a redispatch in the database.
Args:
manifest: Declares all metadata and assets in the workflow
dispatch_id: The original dispatch's id.
reuse_previous_results: Whether to try reusing the results of
previously completed electrons.
Returns:
The manifest with `dispatch_id` and remote URIs for each asset populated.
"""
try:
return await dispatcher.register_redispatch(
manifest,
Expand All @@ -172,25 +212,43 @@ async def register_redispatch(
) from e


@router.put("/dispatch/start/{dispatch_id}")
@router.post("/dispatch/{dispatch_id}", status_code=202)
async def start(dispatch_id: str):
try:
fut = asyncio.create_task(dispatcher.start_dispatch(dispatch_id))
_background_tasks.add(fut)
fut.add_done_callback(_background_tasks.discard)
"""Start a previously registered (re-)dispatch.
return dispatch_id
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Failed to start workflow: {e}",
) from e
Args:
`dispatch_id`: The dispatch's unique id.
Returns:
`dispatch_id`
"""
fut = asyncio.create_task(dispatcher.start_dispatch(dispatch_id))
_background_tasks.add(fut)
fut.add_done_callback(_background_tasks.discard)

@router.get("/dispatch/export/{dispatch_id}")
return dispatch_id


@router.get("/dispatch/{dispatch_id}")
async def export_result(
dispatch_id: str, wait: Optional[bool] = False, status_only: Optional[bool] = False
) -> ExportResponseSchema:
"""Export all metadata about a registered dispatch
Args:
`dispatch_id`: The dispatch's unique id.
Returns:
{
id: `dispatch_id`,
status: status,
result_export: manifest for the result
}
The manifest `result_export` has the same schema as that which is
submitted to `/register`.
"""
loop = asyncio.get_running_loop()
return await loop.run_in_executor(
None,
Expand Down
Loading

0 comments on commit f5fc7fb

Please sign in to comment.