diff --git a/covalent/_workflow/transportable_object.py b/covalent/_workflow/transportable_object.py index 4ba789662..a28490da4 100644 --- a/covalent/_workflow/transportable_object.py +++ b/covalent/_workflow/transportable_object.py @@ -17,6 +17,7 @@ """Transportable object module defining relevant classes and functions""" import base64 +import io import json import platform from typing import Any, Callable, Tuple @@ -31,76 +32,9 @@ BYTE_ORDER = "big" -class _TOArchive: - - """Archived transportable object.""" - - def __init__(self, header: bytes, object_string: bytes, data: bytes): - """ - Initialize TOArchive. - - Args: - header: Archived transportable object header. - object_string: Archived transportable object string. - data: Archived transportable object data. - - Returns: - None - """ - - self.header = header - self.object_string = object_string - self.data = data - - def cat(self) -> bytes: - """ - Concatenate TOArchive. - - Returns: - Concatenated TOArchive. - - """ - - header_size = len(self.header) - string_size = len(self.object_string) - data_offset = STRING_OFFSET_BYTES + DATA_OFFSET_BYTES + header_size + string_size - string_offset = STRING_OFFSET_BYTES + DATA_OFFSET_BYTES + header_size - - data_offset = data_offset.to_bytes(DATA_OFFSET_BYTES, BYTE_ORDER, signed=False) - string_offset = string_offset.to_bytes(STRING_OFFSET_BYTES, BYTE_ORDER, signed=False) - - return string_offset + data_offset + self.header + self.object_string + self.data - - @staticmethod - def load(serialized: bytes, header_only: bool, string_only: bool) -> "_TOArchive": - """ - Load TOArchive object from serialized bytes. - - Args: - serialized: Serialized transportable object. - header_only: Load header only. - string_only: Load string only. - - Returns: - Archived transportable object. - - """ - - string_offset = TOArchiveUtils.string_offset(serialized) - header = TOArchiveUtils.parse_header(serialized, string_offset) - object_string = b"" - data = b"" - - if not header_only: - data_offset = TOArchiveUtils.data_offset(serialized) - object_string = TOArchiveUtils.parse_string(serialized, string_offset, data_offset) - - if not string_only: - data = TOArchiveUtils.parse_data(serialized, data_offset) - return _TOArchive(header, object_string, data) - - class TOArchiveUtils: + """Utilities for reading serialized TransportableObjects""" + @staticmethod def data_offset(serialized: bytes) -> int: size64 = serialized[STRING_OFFSET_BYTES : STRING_OFFSET_BYTES + DATA_OFFSET_BYTES] @@ -120,24 +54,38 @@ def string_byte_range(serialized: bytes) -> Tuple[int, int]: @staticmethod def data_byte_range(serialized: bytes) -> Tuple[int, int]: - """Return byte range for the b64 picklebytes""" + """Return byte range for the picklebytes""" start_byte = TOArchiveUtils.data_offset(serialized) return start_byte, -1 @staticmethod - def parse_header(serialized: bytes, string_offset: int) -> bytes: + def header(serialized: bytes) -> dict: + string_offset = TOArchiveUtils.string_offset(serialized) header = serialized[HEADER_OFFSET:string_offset] - return header + return json.loads(header.decode("utf-8")) @staticmethod - def parse_string(serialized: bytes, string_offset: int, data_offset: int) -> bytes: + def string_segment(serialized: bytes) -> bytes: + string_offset = TOArchiveUtils.string_offset(serialized) + data_offset = TOArchiveUtils.data_offset(serialized) return serialized[string_offset:data_offset] @staticmethod - def parse_data(serialized: bytes, data_offset: int) -> bytes: + def data_segment(serialized: bytes) -> bytes: + data_offset = TOArchiveUtils.data_offset(serialized) return serialized[data_offset:] +class _ByteArrayFile: + """File-like interface for appending to a bytearray.""" + + def __init__(self, buf: bytearray): + self._buf = buf + + def write(self, data: bytes): + self._buf.extend(data) + + class TransportableObject: """ A function is converted to a transportable object by serializing it using cloudpickle @@ -150,13 +98,12 @@ class TransportableObject: """ def __init__(self, obj: Any) -> None: - b64object = base64.b64encode(cloudpickle.dumps(obj)) - object_string_u8 = str(obj).encode("utf-8") - - self._object = b64object.decode("utf-8") - self._object_string = object_string_u8.decode("utf-8") + _buffer_file = io.BytesIO() + self._buffer = bytearray() + self._buffer.extend(b"\0" * HEADER_OFFSET) - self._header = { + _header = { + "format": "0.1", "py_version": platform.python_version(), "cloudpickle_version": cloudpickle.__version__, "attrs": { @@ -164,24 +111,45 @@ def __init__(self, obj: Any) -> None: "name": getattr(obj, "__name__", ""), }, } + header_u8 = json.dumps(_header).encode("utf-8") + header_len = len(header_u8) + + object_string_u8 = str(obj).encode("utf-8") + object_string_len = len(object_string_u8) + + self._buffer.extend(header_u8) + self._buffer.extend(object_string_u8) + del object_string_u8 + cloudpickle.dump(obj, _ByteArrayFile(self._buffer)) + + # Write byte offsets + string_offset = HEADER_OFFSET + header_len + data_offset = string_offset + object_string_len + + string_offset_bytes = string_offset.to_bytes(STRING_OFFSET_BYTES, BYTE_ORDER) + data_offset_bytes = data_offset.to_bytes(DATA_OFFSET_BYTES, BYTE_ORDER) + self._buffer[:STRING_OFFSET_BYTES] = string_offset_bytes + self._buffer[STRING_OFFSET_BYTES:HEADER_OFFSET] = data_offset_bytes @property def python_version(self): - return self._header["py_version"] + return self.header["py_version"] @property def header(self): - return self._header + return TOArchiveUtils.header(self._buffer) @property def attrs(self): - return self._header["attrs"] + return self.header["attrs"] @property def object_string(self): # For compatibility with older Covalent try: - return self._object_string + return ( + TOArchiveUtils.string_segment(memoryview(self._buffer)).tobytes().decode("utf-8") + ) except AttributeError: return self.__dict__["object_string"] @@ -202,11 +170,15 @@ def get_deserialized(self) -> Callable: """ - return cloudpickle.loads(base64.b64decode(self._object.encode("utf-8"))) + return cloudpickle.loads(TOArchiveUtils.data_segment(memoryview(self._buffer))) def to_dict(self) -> dict: """Return a JSON-serializable dictionary representation of self""" - return {"type": "TransportableObject", "attributes": self.__dict__.copy()} + attr_dict = { + "buffer_b64": base64.b64encode(memoryview(self._buffer)).decode("utf-8"), + } + + return {"type": "TransportableObject", "attributes": attr_dict} @staticmethod def from_dict(object_dict) -> "TransportableObject": @@ -220,7 +192,7 @@ def from_dict(object_dict) -> "TransportableObject": """ sc = TransportableObject(None) - sc.__dict__ = object_dict["attributes"] + sc._buffer = base64.b64decode(object_dict["attributes"]["buffer_b64"].encode("utf-8")) return sc def get_serialized(self) -> str: @@ -234,7 +206,9 @@ def get_serialized(self) -> str: object: The serialized transportable object. """ - return self._object + # For backward compatibility + data_segment = TOArchiveUtils.data_segment(memoryview(self._buffer)) + return base64.b64encode(data_segment).decode("utf-8") def serialize(self) -> bytes: """ @@ -247,7 +221,7 @@ def serialize(self) -> bytes: pickled_object: The serialized object alongwith the python version. """ - return _to_archive(self).cat() + return self._buffer def serialize_to_json(self) -> str: """ @@ -296,9 +270,7 @@ def make_transportable(obj) -> "TransportableObject": return TransportableObject(obj) @staticmethod - def deserialize( - serialized: bytes, *, header_only: bool = False, string_only: bool = False - ) -> "TransportableObject": + def deserialize(serialized: bytes) -> "TransportableObject": """ Deserialize the transportable object. @@ -308,9 +280,9 @@ def deserialize( Returns: object: The deserialized transportable object. """ - - ar = _TOArchive.load(serialized, header_only, string_only) - return _from_archive(ar) + to = TransportableObject(None) + to._buffer = serialized + return to @staticmethod def deserialize_list(collection: list) -> list: @@ -357,44 +329,3 @@ def deserialize_dict(collection: dict) -> dict: else: raise TypeError("Couldn't deserialize collection") return new_dict - - -def _to_archive(to: TransportableObject) -> _TOArchive: - """ - Convert a TransportableObject to a _TOArchive. - - Args: - to: Transportable object to be converted. - - Returns: - Archived transportable object. - - """ - - header = json.dumps(to._header).encode("utf-8") - object_string = to._object_string.encode("utf-8") - data = to._object.encode("utf-8") - return _TOArchive(header=header, object_string=object_string, data=data) - - -def _from_archive(ar: _TOArchive) -> TransportableObject: - """ - Convert a _TOArchive to a TransportableObject. - - Args: - ar: Archived transportable object. - - Returns: - Transportable object. - - """ - - decoded_object_str = ar.object_string.decode("utf-8") - decoded_data = ar.data.decode("utf-8") - decoded_header = json.loads(ar.header.decode("utf-8")) - to = TransportableObject(None) - to._header = decoded_header - to._object_string = decoded_object_str or "" - to._object = decoded_data or "" - - return to diff --git a/tests/covalent_dispatcher_tests/_service/assets_test.py b/tests/covalent_dispatcher_tests/_service/assets_test.py index 5f704ca43..8b939c6c4 100644 --- a/tests/covalent_dispatcher_tests/_service/assets_test.py +++ b/tests/covalent_dispatcher_tests/_service/assets_test.py @@ -16,6 +16,7 @@ """Unit tests for the FastAPI asset endpoints""" +import base64 import tempfile from contextlib import contextmanager from typing import Generator @@ -704,7 +705,7 @@ def test_get_pickle_offsets(): start, end = _get_tobj_pickle_offsets(f"file://{write_file.name}") - assert data[start:].decode("utf-8") == tobj.get_serialized() + assert data[start:] == base64.b64decode(tobj.get_serialized().encode("utf-8")) def test_generate_partial_file_slice(): diff --git a/tests/covalent_tests/workflow/transport_test.py b/tests/covalent_tests/workflow/transport_test.py index 40de076c1..6430afa86 100644 --- a/tests/covalent_tests/workflow/transport_test.py +++ b/tests/covalent_tests/workflow/transport_test.py @@ -16,6 +16,7 @@ """Unit tests for transport graph.""" +import base64 import platform from unittest.mock import call @@ -87,27 +88,22 @@ def test_transportable_object_python_version(transportable_object): assert to.python_version == platform.python_version() -def test_transportable_object_eq(transportable_object): +def test_transportable_object_eq(): """Test the __eq__ magic method of TransportableObject""" - import copy - - to = transportable_object - to_new = TransportableObject(None) - to_new.__dict__ = copy.deepcopy(to.__dict__) - assert to.__eq__(to_new) - - to_new._header["py_version"] = "3.5.1" - assert not to.__eq__(to_new) - - assert not to.__eq__({}) + to = TransportableObject(1) + to_new = TransportableObject(1) + to_new_2 = TransportableObject(2) + assert to == to_new + assert to != to_new_2 + assert to != 1 def test_transportable_object_get_serialized(transportable_object): """Test serialized transportable object retrieval.""" to = transportable_object - assert to.get_serialized() == to._object + assert to.get_serialized() == base64.b64encode(cloudpickle.dumps(subtask)).decode("utf-8") def test_transportable_object_get_deserialized(transportable_object): @@ -124,15 +120,8 @@ def test_transportable_object_from_dict(transportable_object): to_new = TransportableObject.from_dict(object_dict) assert to == to_new - - -def test_transportable_object_to_dict_attributes(transportable_object): - """Test attributes from `to_dict` contain correct name and docstrings""" - - tr_dict = transportable_object.to_dict() - - assert tr_dict["attributes"]["_header"]["attrs"]["doc"] == subtask.__doc__ - assert tr_dict["attributes"]["_header"]["attrs"]["name"] == subtask.__name__ + assert to_new.header == to.header + assert to_new.object_string == to.object_string def test_transportable_object_serialize_to_json(transportable_object): @@ -148,7 +137,9 @@ def test_transportable_object_deserialize_from_json(transportable_object): to = transportable_object json_string = to.serialize_to_json() deserialized_to = TransportableObject.deserialize_from_json(json_string) - assert to.__dict__ == deserialized_to.__dict__ + assert to == deserialized_to + assert deserialized_to.header == to.header + assert deserialized_to.object_string == to.object_string def test_transportable_object_make_transportable_idempotent(transportable_object): @@ -169,29 +160,6 @@ def test_transportable_object_serialize_deserialize(transportable_object): assert new_to.python_version == to.python_version -def test_transportable_object_sedeser_string_only(): - """Test extracting string only from serialized to""" - x = 123 - to = TransportableObject(x) - - ser = to.serialize() - new_to = TransportableObject.deserialize(ser, string_only=True) - assert new_to.object_string == to.object_string - assert new_to._object == "" - - -def test_transportable_object_sedeser_header_only(): - """Test extracting header only from serialized to""" - x = 123 - to = TransportableObject(x) - - ser = to.serialize() - new_to = TransportableObject.deserialize(ser, header_only=True) - - assert new_to.object_string == "" - assert new_to._header - - def test_transportable_object_deserialize_list(transportable_object): deserialized = [1, 2, {"a": 3, "b": [4, 5]}] serialized_list = [