Skip to content

Commit

Permalink
Add backward compatibility layer for deserialization
Browse files Browse the repository at this point in the history
  • Loading branch information
cjao committed Apr 8, 2024
1 parent e890bc5 commit 2236798
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 4 deletions.
58 changes: 54 additions & 4 deletions covalent/_workflow/transportable_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import base64
import json
import platform
from typing import Any, Callable, Tuple
from typing import Any, Callable, Dict, Tuple

import cloudpickle

Expand All @@ -29,6 +29,7 @@
DATA_OFFSET_BYTES = 8
HEADER_OFFSET = STRING_OFFSET_BYTES + DATA_OFFSET_BYTES
BYTE_ORDER = "big"
TOBJ_FMT_STR = "0.1"


class TOArchiveUtils:
Expand Down Expand Up @@ -103,7 +104,7 @@ def __init__(self, obj: Any) -> None:
self._buffer.extend(b"\0" * HEADER_OFFSET)

_header = {
"format": "0.1",
"format": TOBJ_FMT_STR,
"ver": {
"python": platform.python_version(),
"cloudpickle": cloudpickle.__version__,
Expand All @@ -125,7 +126,7 @@ def __init__(self, obj: Any) -> None:
self._buffer.extend(object_string_u8)
del object_string_u8

# Append picklebytes
# Append picklebytes (not base64-encoded)
cloudpickle.dump(obj, _ByteArrayFile(self._buffer))

# Write byte offsets
Expand Down Expand Up @@ -287,9 +288,58 @@ def deserialize(serialized: bytes) -> "TransportableObject":
object: The deserialized transportable object.
"""
to = TransportableObject(None)
to._buffer = serialized
header = TOArchiveUtils.header(serialized)

# For backward compatibility
if header.get("format") is None:
# Re-encode TObj serialized using older versions of the SDK,
# characterized by the lack of a "format" field in the
# header. TObj was previously serialized as
# [offsets][header][string][b64-encoded picklebytes],
# whereas starting from format 0.1 we store them as
# [offsets][header][string][picklebytes].
to._buffer = TransportableObject._upgrade_tobj_format(serialized, header)
else:
to._buffer = serialized
return to

@staticmethod
def _upgrade_tobj_format(serialized: bytes, header: Dict) -> bytes:
    """Re-encode a serialized TObj in the newer (0.1) format.

    This involves stamping a format version into the header and
    base64-decoding the data segment into raw picklebytes. Because the
    header sits at the beginning of the byte array and its length
    changes, the string and data offsets need to be recomputed.

    Args:
        serialized: Byte stream produced by an older SDK, laid out as
            [offsets][header][string][b64-encoded picklebytes].
        header: The already-parsed header dict; mutated in place to
            record the new format version.

    Returns:
        The re-encoded byte stream in the 0.1 layout:
        [offsets][header][string][picklebytes].
    """
    buf = bytearray()

    # Upgrade header and recompute byte offsets
    header["format"] = TOBJ_FMT_STR
    serialized_header = json.dumps(header).encode("utf-8")
    string_offset = HEADER_OFFSET + len(serialized_header)

    # This is just a view into the bytearray and consumes
    # negligible space on its own.
    string_segment = TOArchiveUtils.string_segment(serialized)

    data_offset = string_offset + len(string_segment)
    string_offset_bytes = string_offset.to_bytes(STRING_OFFSET_BYTES, BYTE_ORDER)
    data_offset_bytes = data_offset.to_bytes(DATA_OFFSET_BYTES, BYTE_ORDER)

    # Reserve space for the offsets, then write them in place
    buf.extend(b"\0" * HEADER_OFFSET)
    buf[:STRING_OFFSET_BYTES] = string_offset_bytes
    buf[STRING_OFFSET_BYTES:HEADER_OFFSET] = data_offset_bytes

    buf.extend(serialized_header)
    buf.extend(string_segment)

    # base64-decode the data segment into raw picklebytes
    buf.extend(base64.b64decode(TOArchiveUtils.data_segment(serialized)))

    # Return immutable bytes: the annotated return type is `bytes`, and
    # deserialize() stores plain bytes buffers in the non-legacy path.
    return bytes(buf)

@staticmethod
def deserialize_list(collection: list) -> list:
"""
Expand Down
119 changes: 119 additions & 0 deletions tests/covalent_tests/workflow/transport_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""Unit tests for transport graph."""

import base64
import json
import platform
from unittest.mock import call

Expand All @@ -33,6 +34,11 @@
encode_metadata,
pickle_modules_by_value,
)
from covalent._workflow.transportable_object import (
BYTE_ORDER,
DATA_OFFSET_BYTES,
STRING_OFFSET_BYTES,
)
from covalent.executor import LocalExecutor
from covalent.triggers import BaseTrigger

Expand Down Expand Up @@ -81,6 +87,109 @@ def workflow_transport_graph():
return tg


# For testing TObj back-compat -- copied from earlier SDK
class _TOArchive:
    """Archived transportable object (legacy byte layout)."""

    def __init__(self, header: bytes, object_string: bytes, data: bytes):
        """Store the three archive segments.

        Args:
            header: Archived transportable object header.
            object_string: Archived transportable object string.
            data: Archived transportable object data.
        """
        self.header = header
        self.object_string = object_string
        self.data = data

    def cat(self) -> bytes:
        """Concatenate the archive into a single byte stream.

        Layout: [string offset][data offset][header][string][data].

        Returns:
            The concatenated archive bytes.
        """
        string_offset = STRING_OFFSET_BYTES + DATA_OFFSET_BYTES + len(self.header)
        data_offset = string_offset + len(self.object_string)

        offsets = string_offset.to_bytes(STRING_OFFSET_BYTES, BYTE_ORDER, signed=False)
        offsets += data_offset.to_bytes(DATA_OFFSET_BYTES, BYTE_ORDER, signed=False)

        return offsets + self.header + self.object_string + self.data


# Copied from previous SDK version
class LegacyTransportableObject:
    """Legacy transportable-object wrapper from an earlier SDK.

    The wrapped object is cloudpickled and base64-encoded at
    construction time; ``serialize`` emits the old archive layout
    ([offsets][header][string][b64-encoded picklebytes]) that predates
    format 0.1, along with version info from the client's machine.
    """

    def __init__(self, obj) -> None:
        # Old format keeps the pickle payload base64-encoded as text.
        self._object = base64.b64encode(cloudpickle.dumps(obj)).decode("utf-8")
        self._object_string = str(obj).encode("utf-8").decode("utf-8")

        # Key order matters here: json.dumps preserves insertion order,
        # and the archived header bytes depend on it.
        self._header = {
            "py_version": platform.python_version(),
            "cloudpickle_version": cloudpickle.__version__,
            "attrs": {
                "doc": getattr(obj, "__doc__", ""),
                "name": getattr(obj, "__name__", ""),
            },
        }

    # For testing TObj back-compat
    @staticmethod
    def _to_archive(to) -> _TOArchive:
        """Pack a LegacyTransportableObject into a _TOArchive.

        Args:
            to: Transportable object to be converted.

        Returns:
            Archived transportable object.
        """
        return _TOArchive(
            header=json.dumps(to._header).encode("utf-8"),
            object_string=to._object_string.encode("utf-8"),
            data=to._object.encode("utf-8"),
        )

    def serialize(self) -> bytes:
        """Serialize in the legacy archive layout.

        Returns:
            The serialized object along with the python version info.
        """
        return self._to_archive(self).cat()


def test_transportable_object_python_version(transportable_object):
"""Test that the transportable object retrieves the correct python version."""

Expand Down Expand Up @@ -190,6 +299,16 @@ def test_transportable_object_deserialize_dict(transportable_object):
assert TransportableObject.deserialize_dict(serialized_dict) == deserialized


def test_tobj_deserialize_back_compat():
    """Old-format archives deserialize and round-trip through the new format."""
    payload = {"a": 5}
    legacy_bytes = LegacyTransportableObject(payload).serialize()

    # The new SDK must transparently upgrade the legacy byte layout.
    upgraded = TransportableObject.deserialize(legacy_bytes)
    assert upgraded.get_deserialized() == payload

    # Re-serializing the upgraded object must also round-trip.
    round_tripped = TransportableObject.deserialize(upgraded.serialize())
    assert round_tripped.get_deserialized() == payload


def test_transport_graph_initialization():
"""Test the initialization of an empty transport graph."""

Expand Down

0 comments on commit 2236798

Please sign in to comment.