diff --git a/covalent/_file_transfer/strategies/shutil_strategy.py b/covalent/_file_transfer/strategies/shutil_strategy.py
index cf8713d49..319d47d04 100644
--- a/covalent/_file_transfer/strategies/shutil_strategy.py
+++ b/covalent/_file_transfer/strategies/shutil_strategy.py
@@ -34,6 +34,17 @@ def __init__(
 
     # return callable to copy files in the local file system
     def cp(self, from_file: File, to_file: File = File()) -> None:
+        """
+        Get a callable that copies a file from one location to another locally
+
+        Args:
+            from_file: File to copy from
+            to_file: File to copy to. Defaults to File().
+
+        Returns:
+            A callable that copies a file from one location to another locally
+        """
+
         def callable():
             shutil.copyfile(from_file.filepath, to_file.filepath)
 
diff --git a/covalent/_results_manager/result.py b/covalent/_results_manager/result.py
index 5012f8d26..3c11db025 100644
--- a/covalent/_results_manager/result.py
+++ b/covalent/_results_manager/result.py
@@ -63,7 +63,9 @@ class Result:
     """
 
    NEW_OBJ = RESULT_STATUS.NEW_OBJECT
-    PENDING_REUSE = RESULT_STATUS.PENDING_REUSE
+    PENDING_REUSE = (
+        RESULT_STATUS.PENDING_REUSE
+    )  # Facilitates reuse of previous electrons in the new dispatcher design
     COMPLETED = RESULT_STATUS.COMPLETED
     POSTPROCESSING = RESULT_STATUS.POSTPROCESSING
     PENDING_POSTPROCESSING = RESULT_STATUS.PENDING_POSTPROCESSING
@@ -106,8 +108,8 @@ def __str__(self):
         pattern = re.compile(regex)
         m = pattern.match(input_string)
         if m:
-            arg_str_repr = m.group(1).rstrip(",")
-            kwarg_str_repr = m.group(2)
+            arg_str_repr = m[1].rstrip(",")
+            kwarg_str_repr = m[2]
         else:
             arg_str_repr = str(None)
             kwarg_str_repr = str(None)
diff --git a/covalent/_serialize/common.py b/covalent/_serialize/common.py
index 27fd29fc8..142640752 100644
--- a/covalent/_serialize/common.py
+++ b/covalent/_serialize/common.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+""" Serialization/Deserialization methods for Assets """
 
 import hashlib
 import json
@@ -26,17 +27,41 @@
 from .._shared_files.schemas.asset import AssetSchema
 from .._workflow.transportable_object import TransportableObject
 
+__all__ = [
+    "AssetType",
+    "save_asset",
+    "load_asset",
+]
+
+
 CHECKSUM_ALGORITHM = "sha"
 
 
 class AssetType(Enum):
-    OBJECT = 0
-    TRANSPORTABLE = 1
+    """
+    Enum for the type of Asset data
+
+    """
+
+    OBJECT = 0  # Fallback to cloudpickling
+    TRANSPORTABLE = 1  # Custom TO serialization
     JSONABLE = 2
-    TEXT = 3
+    TEXT = 3  # Mainly for stdout, stderr, docstrings, etc.
 
 
 def serialize_asset(data: Any, data_type: AssetType) -> bytes:
+    """
+    Serialize the asset data
+
+    Args:
+        data: Data to serialize
+        data_type: Type of the Asset data to serialize
+
+    Returns:
+        Serialized data as bytes
+
+    """
+
     if data_type == AssetType.OBJECT:
         return cloudpickle.dumps(data)
     elif data_type == AssetType.TRANSPORTABLE:
@@ -50,6 +75,18 @@ def serialize_asset(data: Any, data_type: AssetType) -> bytes:
 
 
 def deserialize_asset(data: bytes, data_type: AssetType) -> Any:
+    """
+    Deserialize the asset data
+
+    Args:
+        data: Data to deserialize
+        data_type: Type of the Asset data to deserialize
+
+    Returns:
+        Deserialized data
+
+    """
+
     if data_type == AssetType.OBJECT:
         return cloudpickle.loads(data)
     elif data_type == AssetType.TRANSPORTABLE:
@@ -63,10 +100,35 @@ def deserialize_asset(data: bytes, data_type: AssetType) -> Any:
 
 
 def _sha1_asset(data: bytes) -> str:
+    """
+    Compute the sha1 checksum of the asset data
+
+    Args:
+        data: Data to compute checksum for
+
+    Returns:
+        sha1 checksum of the data
+
+    """
+
     return hashlib.sha1(data).hexdigest()
 
 
 def save_asset(data: Any, data_type: AssetType, storage_path: str, filename: str) -> AssetSchema:
+    """
+    Save the asset data to the storage path
+
+    Args:
+        data: Data to save
+        data_type: Type of the Asset data to save
+        storage_path: Path to save the data to
+        filename: Name of the file to save the data to
+
+    Returns:
+        AssetSchema object containing metadata about the saved data
+
+    """
+
     scheme = "file"
 
     serialized = serialize_asset(data, data_type)
@@ -80,16 +142,26 @@ def save_asset(data: Any, data_type: AssetType, storage_path: str, filename: str
 
 
 def load_asset(asset_meta: AssetSchema, data_type: AssetType) -> Any:
+    """
+    Load the asset data from the storage path
+
+    Args:
+        asset_meta: Metadata about the asset to load
+        data_type: Type of the Asset data to load
+
+    Returns:
+        Asset data
+
+    """
+
     scheme_prefix = "file://"
     uri = asset_meta.uri
 
     if not uri:
         return None
 
-    if uri.startswith(scheme_prefix):
-        path = uri[len(scheme_prefix) :]
-    else:
-        path = uri
+    path = uri[len(scheme_prefix) :] if uri.startswith(scheme_prefix) else uri
+
     with open(path, "rb") as f:
         data = f.read()
     return deserialize_asset(data, data_type)
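Note: `save_asset` and `load_asset` above are now the module's public entry points (per `__all__`). A minimal round-trip sketch using only the signatures shown in this diff; the temporary directory and TEXT payload are illustrative:

import tempfile

from covalent._serialize.common import AssetType, load_asset, save_asset

with tempfile.TemporaryDirectory() as tmp:
    # save_asset serializes, writes <tmp>/stdout.txt, and returns an
    # AssetSchema whose uri carries the file:// scheme
    meta = save_asset("hello from stdout", AssetType.TEXT, tmp, "stdout.txt")
    # load_asset strips the file:// prefix and deserializes the bytes
    assert load_asset(meta, AssetType.TEXT) == "hello from stdout"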
diff --git a/covalent/_serialize/electron.py b/covalent/_serialize/electron.py
index 7289b0c08..00c3d58d4 100644
--- a/covalent/_serialize/electron.py
+++ b/covalent/_serialize/electron.py
@@ -29,6 +29,12 @@
 from .._workflow.transportable_object import TransportableObject
 from .common import AssetType, load_asset, save_asset
 
+__all__ = [
+    "serialize_node",
+    "deserialize_node",
+]
+
+
 ASSET_TYPES = {
     "function": AssetType.TRANSPORTABLE,
     "function_string": AssetType.TEXT,
@@ -52,11 +58,11 @@ def _serialize_node_metadata(node_attrs: dict, node_storage_path: str) -> Electr
 
     # Optional
     status = node_attrs.get("status", RESULT_STATUS.NEW_OBJECT)
-    start_time = node_attrs.get("start_time", None)
+    start_time = node_attrs.get("start_time")
     if start_time:
         start_time = start_time.isoformat()
 
-    end_time = node_attrs.get("end_time", None)
+    end_time = node_attrs.get("end_time")
     if end_time:
         end_time = end_time.isoformat()
diff --git a/covalent/_serialize/lattice.py b/covalent/_serialize/lattice.py
index 20a61f392..5aefbb61c 100644
--- a/covalent/_serialize/lattice.py
+++ b/covalent/_serialize/lattice.py
@@ -29,6 +29,12 @@
 from .common import AssetType, load_asset, save_asset
 from .transport_graph import deserialize_transport_graph, serialize_transport_graph
 
+__all__ = [
+    "serialize_lattice",
+    "deserialize_lattice",
+]
+
+
 ASSET_TYPES = {
     "workflow_function": AssetType.TRANSPORTABLE,
     "workflow_function_string": AssetType.TEXT,
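`serialize_lattice` and `deserialize_lattice` are now exported explicitly. Their signatures are not shown in this diff, so the sketch below assumes `serialize_lattice(lattice, storage_path)` by analogy with `serialize_transport_graph` further down:

import os

import covalent as ct
from covalent._serialize.lattice import deserialize_lattice, serialize_lattice

@ct.electron
def add(x, y):
    return x + y

@ct.lattice
def workflow(x):
    return add(x, 1)

workflow.build_graph(2)  # populate the transport graph before serializing
os.makedirs("/tmp/assets", exist_ok=True)
manifest = serialize_lattice(workflow, "/tmp/assets")  # assumed signature
lat = deserialize_lattice(manifest)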
diff --git a/covalent/_serialize/result.py b/covalent/_serialize/result.py
index 3086144bf..612b10a60 100644
--- a/covalent/_serialize/result.py
+++ b/covalent/_serialize/result.py
@@ -30,6 +30,15 @@
 from .common import AssetType, load_asset, save_asset
 from .lattice import deserialize_lattice, serialize_lattice
 
+__all__ = [
+    "serialize_result",
+    "deserialize_result",
+    "strip_local_uris",
+    "merge_response_manifest",
+    "extract_assets",
+]
+
+
 ASSET_TYPES = {
     "error": AssetType.TEXT,
     "result": AssetType.TRANSPORTABLE,
@@ -136,6 +145,7 @@ def merge_response_manifest(manifest: ResultSchema, response: ResultSchema) -> R
         response: The manifest returned from `/register`.
 
     Returns: A combined manifest with asset `remote_uri`s populated.
+
     """
 
     manifest.metadata.dispatch_id = response.metadata.dispatch_id
@@ -170,32 +180,28 @@ def merge_response_manifest(manifest: ResultSchema, response: ResultSchema) -> R
 
 
 def extract_assets(manifest: ResultSchema) -> List[AssetSchema]:
-    """Extract all of the asset metadata from a manifest dictionary.
+    """
+    Extract all of the asset metadata from a manifest dictionary.
 
     Args:
         manifest: A result manifest
 
     Returns:
         A list of assets
-    """
 
-    assets = []
+    """
 
     # workflow-level assets
     dispatch_assets = manifest.assets
-    for key, asset in dispatch_assets:
-        assets.append(asset)
-
+    assets = [asset for key, asset in dispatch_assets]
     lattice = manifest.lattice
     lattice_assets = lattice.assets
-    for key, asset in lattice_assets:
-        assets.append(asset)
+    assets.extend(asset for key, asset in lattice_assets)
 
     # Node assets
     tg = lattice.transport_graph
     nodes = tg.nodes
     for node in nodes:
         node_assets = node.assets
-        for key, asset in node_assets:
-            assets.append(asset)
+        assets.extend(asset for key, asset in node_assets)
 
     return assets
diff --git a/covalent/_serialize/transport_graph.py b/covalent/_serialize/transport_graph.py
index dc83bad9e..a7ce04eec 100644
--- a/covalent/_serialize/transport_graph.py
+++ b/covalent/_serialize/transport_graph.py
@@ -28,8 +28,26 @@
 from .._workflow.transport import _TransportGraph
 from .electron import deserialize_node, serialize_node
 
+__all__ = [
+    "serialize_transport_graph",
+    "deserialize_transport_graph",
+]
+
+
+def _serialize_edge(source: int, target: int, attrs: dict) -> EdgeSchema:
+    """
+    Serialize an edge in a graph
+
+    Args:
+        source: Source node
+        target: Target node
+        attrs: Edge attributes
+
+    Returns:
+        Serialized EdgeSchema object
+
+    """
 
-def serialize_edge(source: int, target: int, attrs: dict) -> EdgeSchema:
     meta = EdgeMetadata(
         edge_name=attrs["edge_name"],
         param_type=attrs.get("param_type"),
@@ -38,7 +56,18 @@
     return EdgeSchema(source=source, target=target, metadata=meta)
 
 
-def deserialize_edge(e: EdgeSchema) -> dict:
+def _deserialize_edge(e: EdgeSchema) -> dict:
+    """
+    Deserialize an EdgeSchema into a dictionary
+
+    Args:
+        e: EdgeSchema
+
+    Returns:
+        Deserialized dictionary
+
+    """
+
     return {
         "source": e.source,
         "target": e.target,
@@ -47,6 +76,18 @@
 
 
 def _serialize_nodes(g: nx.MultiDiGraph, storage_path: str) -> List[ElectronSchema]:
+    """
+    Serialize nodes in a graph
+
+    Args:
+        g: NetworkX graph
+        storage_path: Path to store serialized object
+
+    Returns:
+        Serialized nodes
+
+    """
+
     results = []
     base_path = Path(storage_path)
     for i in g.nodes:
@@ -57,14 +98,37 @@ def _serialize_nodes(g: nx.MultiDiGraph, storage_path: str) -> List[ElectronSche
 
 
 def _serialize_edges(g: nx.MultiDiGraph) -> List[EdgeSchema]:
+    """
+    Serialize edges in a graph
+
+    Args:
+        g: NetworkX graph
+
+    Returns:
+        Serialized edges
+
+    """
+
     results = []
     for edge in g.edges:
         source, target, key = edge
-        results.append(serialize_edge(source, target, g.edges[edge]))
+        results.append(_serialize_edge(source, target, g.edges[edge]))
     return results
 
 
 def serialize_transport_graph(tg, storage_path: str) -> TransportGraphSchema:
+    """
+    Serialize a TransportGraph object into a TransportGraphSchema
+
+    Args:
+        tg: TransportGraph object
+        storage_path: Path to store serialized object
+
+    Returns:
+        Serialized TransportGraphSchema object
+
+    """
+
     g = tg.get_internal_graph_copy()
     return TransportGraphSchema(
         nodes=_serialize_nodes(g, storage_path),
@@ -73,10 +137,21 @@ def serialize_transport_graph(tg, storage_path: str) -> TransportGraphSchema:
 
 
 def deserialize_transport_graph(t: TransportGraphSchema) -> _TransportGraph:
+    """
+    Deserialize a TransportGraphSchema into a TransportGraph object
+
+    Args:
+        t: TransportGraphSchema
+
+    Returns:
+        Deserialized TransportGraph object
+
+    """
+
     tg = _TransportGraph()
     g = tg._graph
     nodes = [deserialize_node(n) for n in t.nodes]
-    edges = [deserialize_edge(e) for e in t.links]
+    edges = [_deserialize_edge(e) for e in t.links]
     for node in nodes:
         node_id = node["id"]
         attrs = node["attrs"]
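With the per-section loops in `extract_assets` collapsed into comprehensions, the function still returns one flat list covering workflow-, lattice-, and node-level assets. Typical consumption, assuming `manifest` is a `ResultSchema` produced by `serialize_result`:

from covalent._serialize.result import extract_assets

assets = extract_assets(manifest)
# AssetSchema exposes a uri field (see load_asset above)
local_uris = [a.uri for a in assets if a.uri]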
diff --git a/covalent/_shared_files/defaults.py b/covalent/_shared_files/defaults.py
index 3a077a08d..3fdcaf010 100644
--- a/covalent/_shared_files/defaults.py
+++ b/covalent/_shared_files/defaults.py
@@ -71,7 +71,9 @@ def get_default_sdk_config():
         "multistage_dispatch": "false"
         if os.environ.get("COVALENT_DISABLE_MULTISTAGE_DISPATCH") == "1"
         else "true",
-        "results_dir": os.environ.get("COVALENT_RESULTS_DIR")
+        "results_dir": os.environ.get(
+            "COVALENT_RESULTS_DIR"
+        )  # COVALENT_RESULTS_DIR is where the client downloads workflow artifacts during get_result() which is different from COVALENT_DATA_DIR
         or (
             (os.environ.get("XDG_CACHE_HOME") or (os.environ["HOME"] + "/.cache"))
             + "/covalent/results"
diff --git a/covalent/_shared_files/schemas/common.py b/covalent/_shared_files/schemas/common.py
index de1b20ada..0c4a226d6 100644
--- a/covalent/_shared_files/schemas/common.py
+++ b/covalent/_shared_files/schemas/common.py
@@ -24,8 +24,10 @@
 class StatusEnum(str, Enum):
     NEW_OBJECT = str(RESULT_STATUS.NEW_OBJECT)
     STARTING = str(RESULT_STATUS.STARTING)
-    PENDING_REUSE = str(RESULT_STATUS.PENDING_REUSE)
-    PENDING_REPLACEMENT = str(RESULT_STATUS.PENDING_REPLACEMENT)
+    PENDING_REUSE = str(RESULT_STATUS.PENDING_REUSE)  # For redispatch in the new dispatcher design
+    PENDING_REPLACEMENT = str(
+        RESULT_STATUS.PENDING_REPLACEMENT
+    )  # For redispatch in the new dispatcher design
     COMPLETED = str(RESULT_STATUS.COMPLETED)
     POSTPROCESSING = str(RESULT_STATUS.POSTPROCESSING)
     FAILED = str(RESULT_STATUS.FAILED)
diff --git a/covalent/_shared_files/schemas/electron.py b/covalent/_shared_files/schemas/electron.py
index 3b79793cc..8e769edca 100644
--- a/covalent/_shared_files/schemas/electron.py
+++ b/covalent/_shared_files/schemas/electron.py
@@ -86,7 +86,7 @@ class ElectronAssets(BaseModel):
     stdout: Optional[AssetSchema] = None
     stderr: Optional[AssetSchema] = None
 
-    # electron_metadata
+    # electron_metadata attached by the user
     deps: AssetSchema
     call_before: AssetSchema
     call_after: AssetSchema
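The inline comment on `results_dir` describes a three-step fallback. The same precedence, spelled out on its own: `COVALENT_RESULTS_DIR` first, then `$XDG_CACHE_HOME/covalent/results`, then `~/.cache/covalent/results`:

import os

# Same logic as the config entry above, shown standalone
results_dir = os.environ.get("COVALENT_RESULTS_DIR") or (
    (os.environ.get("XDG_CACHE_HOME") or (os.environ["HOME"] + "/.cache"))
    + "/covalent/results"
)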
--- a/covalent/_shared_files/util_classes.py
+++ b/covalent/_shared_files/util_classes.py
@@ -47,8 +47,10 @@ def __ne__(self, __value: object) -> bool:
 class RESULT_STATUS:
     NEW_OBJECT = Status("NEW_OBJECT")
     STARTING = Status("STARTING")  # Dispatch level
-    PENDING_REUSE = Status("PENDING_REUSE")  # For redispatch
-    PENDING_REPLACEMENT = Status("PENDING_REPLACEMENT")  # For redispatch
+    PENDING_REUSE = Status("PENDING_REUSE")  # For redispatch in the new dispatcher design
+    PENDING_REPLACEMENT = Status(
+        "PENDING_REPLACEMENT"
+    )  # For redispatch in the new dispatcher design
     COMPLETED = Status("COMPLETED")
     POSTPROCESSING = Status("POSTPROCESSING")
     PENDING_POSTPROCESSING = Status("PENDING_POSTPROCESSING")
diff --git a/covalent/_workflow/transportable_object.py b/covalent/_workflow/transportable_object.py
index 632e31ca8..4285daed0 100644
--- a/covalent/_workflow/transportable_object.py
+++ b/covalent/_workflow/transportable_object.py
@@ -32,6 +32,10 @@
 
 
 class _TOArchive:
+    """
+    Archived TransportableObject
+    """
+
     def __init__(self, header: bytes, object_string: bytes, data: bytes):
         self.header = header
         self.object_string = object_string
@@ -48,18 +52,18 @@ def cat(self) -> bytes:
 
         return string_offset + data_offset + self.header + self.object_string + self.data
 
-    def load(serialized: bytes, header_only: bool, string_only: bool) -> "_TOArchive":
-        string_offset = TOArchiveUtils.string_offset(serialized)
-        header = TOArchiveUtils.parse_header(serialized, string_offset)
+    def load(self, header_only: bool, string_only: bool) -> "_TOArchive":
+        string_offset = TOArchiveUtils.string_offset(self)
+        header = TOArchiveUtils.parse_header(self, string_offset)
         object_string = b""
         data = b""
 
         if not header_only:
-            data_offset = TOArchiveUtils.data_offset(serialized)
-            object_string = TOArchiveUtils.parse_string(serialized, string_offset, data_offset)
+            data_offset = TOArchiveUtils.data_offset(self)
+            object_string = TOArchiveUtils.parse_string(self, string_offset, data_offset)
 
             if not string_only:
-                data = TOArchiveUtils.parse_data(serialized, data_offset)
+                data = TOArchiveUtils.parse_data(self, data_offset)
 
         return _TOArchive(header, object_string, data)
 
@@ -319,7 +323,7 @@ def _from_archive(ar: _TOArchive) -> TransportableObject:
     decoded_header = json.loads(ar.header.decode("utf-8"))
     to = TransportableObject(None)
     to._header = decoded_header
-    to._object_string = decoded_object_str if decoded_object_str else ""
-    to._object = decoded_data if decoded_data else ""
+    to._object_string = decoded_object_str or ""
+    to._object = decoded_data or ""
 
     return to
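`_TOArchive.load` is internal; the layout it parses (string offset, data offset, header, object string, data) backs the public `TransportableObject` round trip, which this refactor leaves unchanged. A quick check, assuming the usual public API:

from covalent import TransportableObject

to = TransportableObject(42)
data = to.serialize()  # bytes in the _TOArchive layout described above
assert TransportableObject.deserialize(data).get_deserialized() == 42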
diff --git a/covalent/executor/executor_plugins/dask.py b/covalent/executor/executor_plugins/dask.py
index cffa9d323..5e628bdfd 100644
--- a/covalent/executor/executor_plugins/dask.py
+++ b/covalent/executor/executor_plugins/dask.py
@@ -31,6 +31,7 @@
 # Relative imports are not allowed in executor plugins
 from covalent._shared_files.config import get_config
 from covalent._shared_files.exceptions import TaskCancelledError
+from covalent._shared_files.utils import _address_client_mapper
 from covalent.executor.base import AsyncBaseExecutor
 from covalent.executor.utils.wrappers import io_wrapper as dask_wrapper
 
@@ -54,9 +55,6 @@
     "create_unique_workdir": False,
 }
 
-# Temporary
-_address_client_mapper = {}
-
 
 class DaskExecutor(AsyncBaseExecutor):
     """
@@ -141,8 +139,8 @@ async def run(self, function: Callable, args: List, kwargs: Dict, task_metadata:
 
         try:
             result, worker_stdout, worker_stderr, tb = await future
-        except CancelledError:
-            raise TaskCancelledError()
+        except CancelledError as e:
+            raise TaskCancelledError() from e
 
         print(worker_stdout, end="", file=self.task_stdout)
         print(worker_stderr, end="", file=self.task_stderr)
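The `from e` added above chains the executor's `TaskCancelledError` to the originating `CancelledError`, so tracebacks show both exceptions instead of dropping the cancellation context. A small illustration of what the chaining preserves:

from asyncio import CancelledError

from covalent._shared_files.exceptions import TaskCancelledError

try:
    try:
        raise CancelledError()
    except CancelledError as e:
        raise TaskCancelledError() from e
except TaskCancelledError as exc:
    assert isinstance(exc.__cause__, CancelledError)  # preserved by `from e`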