diff --git a/CHANGELOG.md b/CHANGELOG.md index e3e6a1376..bcc0ebfc3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [UNRELEASED] +### Fixed + +- `get_result(wait=True)` will wait as long as needed + ## [0.234.1-rc.0] - 2024-05-10 ### Authors diff --git a/covalent/_results_manager/results_manager.py b/covalent/_results_manager/results_manager.py index 4c751206a..75ba88633 100644 --- a/covalent/_results_manager/results_manager.py +++ b/covalent/_results_manager/results_manager.py @@ -19,12 +19,11 @@ import contextlib import os +import time from pathlib import Path -from typing import Dict, List, Optional +from typing import List, Optional from furl import furl -from requests.adapters import HTTPAdapter -from urllib3.util import Retry from .._api.apiclient import CovalentAPIClient from .._serialize.common import load_asset @@ -40,9 +39,9 @@ from .._shared_files.exceptions import MissingLatticeRecordError from .._shared_files.schemas.asset import AssetSchema from .._shared_files.schemas.result import ResultSchema +from .._shared_files.util_classes import RESULT_STATUS, Status from .._shared_files.utils import copy_file_locally, format_server_url from .result import Result -from .wait import EXTREME app_log = logger.app_log log_stack_info = logger.log_stack_info @@ -139,12 +138,20 @@ def cancel(dispatch_id: str, task_ids: List[int] = None, dispatcher_addr: str = # Multi-part +def _query_dispatch_status(dispatch_id: str, api_client: CovalentAPIClient): + endpoint = "/api/v2/dispatches" + resp = api_client.get(endpoint, params={"dispatch_id": dispatch_id, "status_only": True}) + resp.raise_for_status() + dispatches = resp.json()["dispatches"] + if len(dispatches) == 0: + raise MissingLatticeRecordError + + return dispatches[0]["status"] + + def _get_result_export_from_dispatcher( - dispatch_id: str, - wait: bool = False, - status_only: bool = False, - dispatcher_addr: str = None, -) -> Dict: + dispatch_id: str, api_client: CovalentAPIClient +) -> ResultSchema: """ Internal function to get the results of a dispatch from the server without checking if it is ready to read. @@ -161,24 +168,21 @@ def _get_result_export_from_dispatcher( MissingLatticeRecordError: If the result is not found. """ - if dispatcher_addr is None: - dispatcher_addr = format_server_url() + # if dispatcher_addr is None: + # dispatcher_addr = format_server_url() - retries = int(EXTREME) if wait else 5 + # retries = int(EXTREME) if wait else 5 - adapter = HTTPAdapter(max_retries=Retry(total=retries, backoff_factor=1)) - api_client = CovalentAPIClient(dispatcher_addr, adapter=adapter, auto_raise=False) + # adapter = HTTPAdapter(max_retries=Retry(total=retries, backoff_factor=1)) + # api_client = CovalentAPIClient(dispatcher_addr, adapter=adapter, auto_raise=False) endpoint = f"/api/v2/dispatches/{dispatch_id}" - response = api_client.get( - endpoint, - params={"wait": wait, "status_only": status_only}, - ) + response = api_client.get(endpoint) if response.status_code == 404: raise MissingLatticeRecordError response.raise_for_status() export = response.json() - return export + return ResultSchema.model_validate(export) # Function to download default assets @@ -346,11 +350,17 @@ def from_dispatch_id( wait: bool = False, dispatcher_addr: str = None, ) -> "ResultManager": - export = _get_result_export_from_dispatcher( - dispatch_id, wait, status_only=False, dispatcher_addr=dispatcher_addr - ) + if dispatcher_addr is None: + dispatcher_addr = format_server_url() - manifest = ResultSchema.model_validate(export["result_export"]) + api_client = CovalentAPIClient(dispatcher_addr) + if wait: + status = Status(_query_dispatch_status(dispatch_id, api_client)) + while not RESULT_STATUS.is_terminal(status): + time.sleep(1) + status = Status(_query_dispatch_status(dispatch_id, api_client)) + + manifest = _get_result_export_from_dispatcher(dispatch_id, api_client) # sort the nodes manifest.lattice.transport_graph.nodes.sort(key=lambda x: x.id) @@ -408,14 +418,15 @@ def _get_result_multistage( """ + if dispatcher_addr is None: + dispatcher_addr = format_server_url() + + api_client = CovalentAPIClient(dispatcher_addr) try: if status_only: - return _get_result_export_from_dispatcher( - dispatch_id=dispatch_id, - wait=wait, - status_only=status_only, - dispatcher_addr=dispatcher_addr, - ) + status = _query_dispatch_status(dispatch_id, api_client) + return {"id": dispatch_id, "status": status} + rm = get_result_manager(dispatch_id, results_dir, wait, dispatcher_addr) _get_default_assets(rm) @@ -496,23 +507,14 @@ def get_result( The Result object from the Covalent server """ - max_attempts = int(os.getenv("COVALENT_GET_RESULT_RETRIES", 10)) - num_attempts = 0 - while num_attempts < max_attempts: - try: - return _get_result_multistage( - dispatch_id=dispatch_id, - wait=wait, - dispatcher_addr=dispatcher_addr, - status_only=status_only, - results_dir=results_dir, - workflow_output=workflow_output, - intermediate_outputs=intermediate_outputs, - sublattice_results=sublattice_results, - qelectron_db=qelectron_db, - ) - - except RecursionError as re: - app_log.error(re) - num_attempts += 1 - raise RuntimeError("Timed out waiting for result. Please retry or check dispatch.") + return _get_result_multistage( + dispatch_id=dispatch_id, + wait=wait, + dispatcher_addr=dispatcher_addr, + status_only=status_only, + results_dir=results_dir, + workflow_output=workflow_output, + intermediate_outputs=intermediate_outputs, + sublattice_results=sublattice_results, + qelectron_db=qelectron_db, + ) diff --git a/covalent/triggers/base.py b/covalent/triggers/base.py index 2eb49a434..341cd7c74 100644 --- a/covalent/triggers/base.py +++ b/covalent/triggers/base.py @@ -15,8 +15,6 @@ # limitations under the License. -import asyncio -import json from abc import abstractmethod import requests @@ -108,17 +106,12 @@ def _get_status(self) -> Status: """ if self.use_internal_funcs: - from covalent_dispatcher._service.app import export_result + from covalent_dispatcher._service.app import get_dispatches_bulk - response = asyncio.run_coroutine_threadsafe( - export_result(self.lattice_dispatch_id, status_only=True), - self.event_loop, - ).result() - - if isinstance(response, dict): - return response["status"] - - return json.loads(response.body.decode()).get("status") + response = get_dispatches_bulk( + dispatch_id=[self.lattice_dispatch_id], status_only=True + ) + return response.dispatches[0].status from .. import get_result diff --git a/covalent_dispatcher/_dal/controller.py b/covalent_dispatcher/_dal/controller.py index 3e682b979..34a1af792 100644 --- a/covalent_dispatcher/_dal/controller.py +++ b/covalent_dispatcher/_dal/controller.py @@ -17,10 +17,12 @@ from __future__ import annotations -from typing import Generic, Type, TypeVar +from typing import Callable, Generic, List, Optional, Sequence, Type, TypeVar, Union from sqlalchemy import select, update -from sqlalchemy.orm import Session, load_only +from sqlalchemy.engine import Row +from sqlalchemy.orm import Session +from sqlalchemy.sql.expression import Select, desc from .._db import models @@ -54,7 +56,12 @@ def get( equality_filters: dict, membership_filters: dict, for_update: bool = False, - ): + sort_fields: List[str] = [], + reverse: bool = True, + offset: int = 0, + max_items: Optional[int] = None, + custom_filter: Optional[Callable[[Select], Select]] = None, + ) -> Union[Sequence[Row], Sequence[T]]: """Bulk ORM-enabled SELECT. Args: @@ -64,19 +71,42 @@ def get( membership_filters: Dict{field_name: value_list} for_update: Whether to lock the selected rows + Returns: + A list of SQLAlchemy Rows or whole ORM entities depending + on whether only a subset of fields is specified. + """ - stmt = select(cls.model) + if len(fields) > 0: + entities = [getattr(cls.model, attr) for attr in fields] + stmt = select(*entities) + else: + stmt = select(cls.model) + for attr, val in equality_filters.items(): stmt = stmt.where(getattr(cls.model, attr) == val) for attr, vals in membership_filters.items(): stmt = stmt.where(getattr(cls.model, attr).in_(vals)) - if len(fields) > 0: - attrs = [getattr(cls.model, f) for f in fields] - stmt = stmt.options(load_only(*attrs)) if for_update: stmt = stmt.with_for_update() - - return session.scalars(stmt).all() + for attr in sort_fields: + if reverse: + stmt = stmt.order_by(desc(getattr(cls.model, attr))) + else: + stmt = stmt.order_by(getattr(cls.model, attr)) + + stmt = stmt.offset(offset) + if max_items: + stmt = stmt.limit(max_items) + + if custom_filter is not None: + stmt = custom_filter(stmt) + + if len(fields) == 0: + # Return whole ORM entities + return session.scalars(stmt).all() + else: + # Return a named tuple containing the selected cols + return session.execute(stmt).all() @classmethod def get_by_primary_key( diff --git a/covalent_dispatcher/_dal/result.py b/covalent_dispatcher/_dal/result.py index a9378558c..72c1c2def 100644 --- a/covalent_dispatcher/_dal/result.py +++ b/covalent_dispatcher/_dal/result.py @@ -45,6 +45,38 @@ class ResultMeta(Record[models.Lattice]): model = models.Lattice + @staticmethod + def filter_root_lattices(stmt): + stmt = stmt.where(models.Lattice.root_dispatch_id == models.Lattice.dispatch_id) + return stmt + + @classmethod + def get_toplevel_dispatches( + cls, + session: Session, + *, + fields: list, + equality_filters: dict, + membership_filters: dict, + for_update: bool = False, + sort_fields: List[str] = [], + reverse: bool = True, + offset: int = 0, + max_items: int = 10, + ): + return cls.get( + session=session, + fields=fields, + equality_filters=equality_filters, + membership_filters=membership_filters, + for_update=for_update, + sort_fields=sort_fields, + reverse=reverse, + offset=offset, + max_items=max_items, + custom_filter=ResultMeta.filter_root_lattices, + ) + class ResultAsset(Record[models.LatticeAsset]): model = models.LatticeAsset @@ -175,7 +207,7 @@ def _update_dispatch( with self.session() as session: electron_rec = Electron.get_db_records( session, - keys={"id", "parent_lattice_id"}, + keys=ELECTRON_KEYS, equality_filters={"id": self._electron_id}, membership_filters={}, )[0] @@ -343,7 +375,7 @@ def _get_incomplete_nodes(self): A dictionary {"failed": [node_ids], "cancelled": [node_ids]} """ with self.session() as session: - query_keys = {"parent_lattice_id", "node_id", "name", "status"} + query_keys = {"id", "parent_lattice_id", "node_id", "name", "status"} records = Electron.get_db_records( session, keys=query_keys, diff --git a/covalent_dispatcher/_service/app.py b/covalent_dispatcher/_service/app.py index 9a9c7d460..321f9b1ea 100644 --- a/covalent_dispatcher/_service/app.py +++ b/covalent_dispatcher/_service/app.py @@ -20,11 +20,12 @@ import asyncio import json from contextlib import asynccontextmanager -from typing import List, Optional, Union +from typing import List, Union from uuid import UUID -from fastapi import APIRouter, FastAPI, HTTPException, Request +from fastapi import APIRouter, FastAPI, HTTPException, Query, Request from fastapi.responses import JSONResponse +from typing_extensions import Annotated import covalent_dispatcher.entry_point as dispatcher from covalent._shared_files import logger @@ -38,7 +39,16 @@ from .._db.datastore import workflow_db from .._db.dispatchdb import DispatchDB from .heartbeat import Heartbeat -from .models import DispatchStatusSetSchema, ExportResponseSchema, TargetDispatchStatus +from .models import ( + BulkDispatchGetSchema, + BulkGetMetadata, + Cursor, + DispatchLinks, + DispatchStatusSetSchema, + DispatchSummary, + Link, + TargetDispatchStatus, +) app_log = logger.app_log log_stack_info = logger.log_stack_info @@ -266,74 +276,77 @@ async def set_dispatch_status(dispatch_id: str, desired_status: DispatchStatusSe return await cancel(dispatch_id, desired_status.task_ids) -@router.get("/dispatches/{dispatch_id}") -async def export_result( - dispatch_id: str, wait: Optional[bool] = False, status_only: Optional[bool] = False -) -> ExportResponseSchema: - """Export all metadata about a registered dispatch - - Args: - `dispatch_id`: The dispatch's unique id. +@router.get("/dispatches", response_model_exclude_unset=True) +def get_dispatches_bulk( + dispatch_id: Annotated[Union[List[str], None], Query()] = None, + status: Annotated[Union[List[str], None], Query()] = None, + latest: bool = True, + offset: int = 0, + count: int = 10, + status_only: bool = False, +) -> BulkDispatchGetSchema: + dispatch_controller = Result.meta_type - Returns: - { - id: `dispatch_id`, - status: status, - result_export: manifest for the result - } + if status_only: + fields = ["dispatch_id", "status"] + else: + fields = [ + "dispatch_id", + "root_dispatch_id", + "status", + "name", + "electron_num", + "completed_electron_num", + "created_at", + "updated_at", + "completed_at", + ] + + summaries = [] + with workflow_db.session() as session: + in_filters = {} + if dispatch_id is not None: + in_filters["dispatch_id"] = dispatch_id + if status is not None: + in_filters["status"] = status - The manifest `result_export` has the same schema as that which is - submitted to `/register`. + results = dispatch_controller.get( + session, + fields=fields, + equality_filters={"is_active": True}, + membership_filters=in_filters, + sort_fields=["created_at"], + reverse=latest, + offset=offset, + max_items=count, + ) + for res in results: + dispatch_id = res.dispatch_id + summary = DispatchSummary.model_validate(res) + if not status_only: + links = DispatchLinks(manifest=Link(href=f"./{dispatch_id}")) + summary.links = links + summaries.append(summary) - """ - loop = asyncio.get_running_loop() - return await loop.run_in_executor( - None, - _export_result_sync, - dispatch_id, - wait, - status_only, - ) + bulk_meta = BulkGetMetadata(count=len(results), links=Cursor()) + return BulkDispatchGetSchema(dispatches=summaries, metadata=bulk_meta) -def _export_result_sync( - dispatch_id: str, wait: Optional[bool] = False, status_only: Optional[bool] = False -) -> ExportResponseSchema: +@router.get("/dispatches/{dispatch_id}") +def export_manifest(dispatch_id: str) -> ResultSchema: result_object = _try_get_result_object(dispatch_id) if not result_object: return JSONResponse( status_code=404, content={"message": f"The requested dispatch ID {dispatch_id} was not found."}, ) - status = str(result_object.get_value("status", refresh=False)) - if not wait or status in [ - str(RESULT_STATUS.COMPLETED), - str(RESULT_STATUS.FAILED), - str(RESULT_STATUS.CANCELLED), - ]: - output = { - "id": dispatch_id, - "status": status, - } - if not status_only: - output["result_export"] = export_result_manifest(dispatch_id) - - return output - - response = JSONResponse( - status_code=503, - content={"message": "Result not ready to read yet. Please wait for a couple of seconds."}, - headers={"Retry-After": "2"}, - ) - return response + return export_result_manifest(dispatch_id) def _try_get_result_object(dispatch_id: str) -> Union[Result, None]: try: - res = get_result_object( - dispatch_id, bare=True, keys=["id", "dispatch_id", "status"], lattice_keys=["id"] - ) + res = get_result_object(dispatch_id, bare=True) except KeyError: res = None return res diff --git a/covalent_dispatcher/_service/models.py b/covalent_dispatcher/_service/models.py index 2d2f7db10..06e922e8b 100644 --- a/covalent_dispatcher/_service/models.py +++ b/covalent_dispatcher/_service/models.py @@ -17,12 +17,13 @@ """FastAPI models for /api/v1/resultv2 endpoints""" import re +from datetime import datetime from enum import Enum from typing import List, Optional -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict -from covalent._shared_files.schemas.result import ResultSchema +from covalent._shared_files.schemas.result import StatusEnum range_regex = "bytes=([0-9]+)-([0-9]*)" range_pattern = re.compile(range_regex) @@ -60,12 +61,6 @@ class ElectronAssetKey(str, Enum): qelectron_db = "qelectron_db" -class ExportResponseSchema(BaseModel): - id: str - status: str - result_export: Optional[ResultSchema] = None - - class AssetRepresentation(str, Enum): string = "string" b64pickle = "object" @@ -82,3 +77,42 @@ class DispatchStatusSetSchema(BaseModel): # For cancellation, an optional list of task ids to cancel task_ids: Optional[List] = [] + + +class Link(BaseModel): + href: str + + +class Cursor(BaseModel): + back: Optional[Link] = None + forward: Optional[Link] = None + + +class BulkGetMetadata(BaseModel): + count: int + links: Cursor + + +class DispatchLinks(BaseModel): + manifest: Link + + +class DispatchSummary(BaseModel): + model_config = ConfigDict(from_attributes=True) + + dispatch_id: str + root_dispatch_id: Optional[str] = None + status: StatusEnum + name: Optional[str] = None + electron_num: Optional[int] = None + completed_electron_num: Optional[int] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + completed_at: Optional[datetime] = None + + links: Optional[DispatchLinks] = None + + +class BulkDispatchGetSchema(BaseModel): + dispatches: List[DispatchSummary] + metadata: BulkGetMetadata diff --git a/tests/covalent_dispatcher_tests/_dal/result_test.py b/tests/covalent_dispatcher_tests/_dal/result_test.py index 5b2ec19fa..cb858ebb5 100644 --- a/tests/covalent_dispatcher_tests/_dal/result_test.py +++ b/tests/covalent_dispatcher_tests/_dal/result_test.py @@ -551,3 +551,84 @@ def test_result_filters_parent_electron_updates(test_db, mocker): assert third_update assert subl_node.get_value("output").get_deserialized() == 42 + + +def test_result_controller_bulk_get(test_db, mocker): + record_1 = models.Lattice( + dispatch_id="dispatch_1", + root_dispatch_id="dispatch_1", + name="dispatch_1", + status="NEW_OBJECT", + electron_num=5, + completed_electron_num=0, + ) + + record_2 = models.Lattice( + dispatch_id="dispatch_2", + root_dispatch_id="dispatch_2", + name="dispatch_2", + status="NEW_OBJECT", + electron_num=25, + completed_electron_num=0, + ) + + record_3 = models.Lattice( + dispatch_id="dispatch_3", + root_dispatch_id="dispatch_2", + name="dispatch_3", + status="COMPLETED", + electron_num=25, + completed_electron_num=25, + ) + + with test_db.session() as session: + session.add(record_1) + session.add(record_2) + session.add(record_3) + session.commit() + + dispatch_controller = Result.meta_type + + with test_db.session() as session: + results = dispatch_controller.get( + session, + fields=["dispatch_id"], + equality_filters={}, + membership_filters={}, + ) + assert len(results) == 3 + + with test_db.session() as session: + results = dispatch_controller.get_toplevel_dispatches( + session, + fields=["dispatch_id"], + equality_filters={}, + membership_filters={}, + ) + assert len(results) == 2 + + with test_db.session() as session: + results = dispatch_controller.get( + session, + fields=["dispatch_id"], + equality_filters={}, + membership_filters={}, + sort_fields=["name"], + reverse=False, + max_items=1, + ) + assert len(results) == 1 + assert results[0].dispatch_id == "dispatch_1" + + with test_db.session() as session: + results = dispatch_controller.get( + session, + fields=["dispatch_id"], + equality_filters={}, + membership_filters={}, + sort_fields=["name"], + max_items=2, + offset=1, + ) + assert len(results) == 2 + assert results[0].dispatch_id == "dispatch_2" diff --git a/tests/covalent_dispatcher_tests/_service/app_test.py b/tests/covalent_dispatcher_tests/_service/app_test.py index 7877fe673..c7522e8e1 100644 --- a/tests/covalent_dispatcher_tests/_service/app_test.py +++ b/tests/covalent_dispatcher_tests/_service/app_test.py @@ -262,8 +262,8 @@ def test_start(mocker, app, client): assert resp.json() == dispatch_id -def test_export_result_nowait(mocker, app, client, mock_manifest): - dispatch_id = "test_export_result" +def test_export_manifest(mocker, app, client, mock_manifest): + dispatch_id = "test_export_manifest" mock_result_object = MagicMock() mock_result_object.get_value = MagicMock(return_value=str(RESULT_STATUS.NEW_OBJECT)) mocker.patch( @@ -274,29 +274,11 @@ def test_export_result_nowait(mocker, app, client, mock_manifest): ) mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") resp = client.get(f"/api/v2/dispatches/{dispatch_id}") - assert resp.status_code == 200 - assert resp.json()["id"] == dispatch_id - assert resp.json()["status"] == str(RESULT_STATUS.NEW_OBJECT) - assert resp.json()["result_export"] == json.loads(mock_manifest.json()) - - -def test_export_result_wait_not_ready(mocker, app, client, mock_manifest): - dispatch_id = "test_export_result" - mock_result_object = MagicMock() - mock_result_object.get_value = MagicMock(return_value=str(RESULT_STATUS.RUNNING)) - mocker.patch( - "covalent_dispatcher._service.app._try_get_result_object", return_value=mock_result_object - ) - mock_export = mocker.patch( - "covalent_dispatcher._service.app.export_result_manifest", return_value=mock_manifest - ) - mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") - resp = client.get(f"/api/v2/dispatches/{dispatch_id}", params={"wait": True}) - assert resp.status_code == 503 + assert resp.json() == json.loads(mock_manifest.json()) -def test_export_result_bad_dispatch_id(mocker, app, client, mock_manifest): - dispatch_id = "test_export_result" +def test_export_manifest_bad_dispatch_id(mocker, app, client, mock_manifest): + dispatch_id = "test_export_manifest" mock_result_object = MagicMock() mock_result_object.get_value = MagicMock(return_value=str(RESULT_STATUS.NEW_OBJECT)) mocker.patch("covalent_dispatcher._service.app._try_get_result_object", return_value=None) diff --git a/tests/covalent_tests/results_manager_tests/results_manager_test.py b/tests/covalent_tests/results_manager_tests/results_manager_test.py index 06ba8ece0..72c0260d9 100644 --- a/tests/covalent_tests/results_manager_tests/results_manager_test.py +++ b/tests/covalent_tests/results_manager_tests/results_manager_test.py @@ -22,9 +22,9 @@ from unittest.mock import MagicMock import pytest -from requests import Response import covalent as ct +from covalent._api.apiclient import CovalentAPIClient from covalent._results_manager.results_manager import ( MissingLatticeRecordError, Result, @@ -105,24 +105,23 @@ def test_cancel_with_multiple_task_ids(mocker): def test_result_export(mocker): + import json + with tempfile.TemporaryDirectory() as staging_dir: test_manifest = get_test_manifest(staging_dir) dispatch_id = "test_result_export" - mock_body = {"id": "test_result_export", "status": "COMPLETED"} - + mock_response_body = json.loads(test_manifest.model_dump_json()) mock_client = MagicMock() - mock_response = Response() + mock_response = MagicMock() mock_response.status_code = 200 - mock_response.json = MagicMock(return_value=mock_body) + mock_response.json = MagicMock(return_value=mock_response_body) mocker.patch("covalent._api.apiclient.requests.Session.get", return_value=mock_response) - + apiclient = CovalentAPIClient("http://localhost:48008") endpoint = f"/api/v2/dispatches/{dispatch_id}" - assert mock_body == _get_result_export_from_dispatcher( - dispatch_id, wait=False, status_only=True - ) + assert test_manifest == _get_result_export_from_dispatcher(dispatch_id, apiclient) def test_result_manager_assets_local_copies(): @@ -176,11 +175,7 @@ def test_get_result(mocker): # local file copies from server_dir to results_dir. manifest = get_test_manifest(server_dir) - mock_result_export = { - "id": dispatch_id, - "status": "COMPLETED", - "result_export": manifest.dict(), - } + mock_result_export = manifest mocker.patch( "covalent._results_manager.results_manager._get_result_export_from_dispatcher", return_value=mock_result_export, @@ -208,17 +203,9 @@ def test_get_result_sublattice(mocker): # Sublattice manifest sub_manifest = get_test_manifest(server_dir_sub) - mock_result_export = { - "id": dispatch_id, - "status": "COMPLETED", - "result_export": manifest.dict(), - } + mock_result_export = manifest - mock_subresult_export = { - "id": sub_dispatch_id, - "status": "COMPLETED", - "result_export": sub_manifest.dict(), - } + mock_subresult_export = sub_manifest exports = {dispatch_id: mock_result_export, sub_dispatch_id: mock_subresult_export} @@ -277,10 +264,10 @@ def test_get_result_RecursionError(mocker): def test_get_status_only(mocker): """Check get_result when status_only=True""" - dispatch_id = "test_get_result_st" + dispatch_id = "test_get_status_only" mock_get_result_export = mocker.patch( - "covalent._results_manager.results_manager._get_result_export_from_dispatcher", - return_value={"id": dispatch_id, "status": "RUNNING"}, + "covalent._results_manager.results_manager._query_dispatch_status", + return_value="RUNNING", ) status_report = get_result(dispatch_id, status_only=True) diff --git a/tests/covalent_tests/triggers/base_test.py b/tests/covalent_tests/triggers/base_test.py index 0ca0c8d7c..b70aee295 100644 --- a/tests/covalent_tests/triggers/base_test.py +++ b/tests/covalent_tests/triggers/base_test.py @@ -17,7 +17,6 @@ from unittest import mock import pytest -from fastapi.responses import JSONResponse from covalent.triggers import BaseTrigger @@ -46,7 +45,6 @@ def test_register(mocker): @pytest.mark.parametrize( "use_internal_func, mock_status", [ - (True, JSONResponse("mock")), (True, {"status": "mocked-status"}), (False, {"status": "mocked-status"}), ], @@ -61,27 +59,15 @@ def test_get_status(mocker, use_internal_func, mock_status): base_trigger.use_internal_funcs = use_internal_func if use_internal_func: - mocker.patch("covalent_dispatcher._service.app.export_result") - - mock_fut_res = mock.Mock() - mock_fut_res.result.return_value = mock_status - mock_run_coro = mocker.patch( - "covalent.triggers.base.asyncio.run_coroutine_threadsafe", return_value=mock_fut_res + mock_bulk_get_res = mock.Mock() + mock_bulk_get_res.dispatches = [mock.Mock()] + mock_bulk_get_res.dispatches[0].status = mock_status["status"] + mocker.patch( + "covalent_dispatcher._service.app.get_dispatches_bulk", return_value=mock_bulk_get_res ) - if not isinstance(mock_status, dict): - mock_json_loads = mocker.patch( - "covalent.triggers.base.json.loads", return_value={"status": "mocked-status"} - ) - status = base_trigger._get_status() - mock_run_coro.assert_called_once() - mock_fut_res.result.assert_called_once() - - if not isinstance(mock_status, dict): - mock_json_loads.assert_called_once() - else: mock_get_status = mocker.patch("covalent.get_result", return_value=mock_status) status = base_trigger._get_status()