diff --git a/src/radical/utils/__init__.py b/src/radical/utils/__init__.py index 3b0ef13f..8e009ce9 100644 --- a/src/radical/utils/__init__.py +++ b/src/radical/utils/__init__.py @@ -77,11 +77,14 @@ from .profile import PROF_KEY_MAX from .json_io import read_json, read_json_str, write_json -from .json_io import parse_json, parse_json_str +from .json_io import parse_json, parse_json_str, dumps_json from .which import which from .tracer import trace, untrace from .get_version import get_version +from .serialize import to_json, from_json, to_msgpack, from_msgpack +from .serialize import register_serializable + # import various utility methods from .ids import * diff --git a/src/radical/utils/config.py b/src/radical/utils/config.py index f795f540..37c11f12 100644 --- a/src/radical/utils/config.py +++ b/src/radical/utils/config.py @@ -176,22 +176,23 @@ class Config(TypedDict): # -------------------------------------------------------------------------- # - def __init__(self, module=None, category=None, name=None, cfg=None, - from_dict=None, path=None, expand=True, env=None, - _internal=False): + def __init__(self, from_dict=None, + module=None, category=None, name=None, + cfg=None, path=None, expand=True, + env=None, _internal=False): """ Load a config (json) file from the module's config tree, and overload any user specific config settings if found. 
Parameters ---------- + from_dict: alias for cfg, to satisfy base class constructor module: used to determine the module's config file location - default: `radical.utils` category: name of config to be loaded from module's config path name: specify a specific configuration to be used path: path to app config json to be used for initialization cfg: application config dict to be used for initialization - from_dict: alias for cfg, to satisfy base class constructor expand: enable / disable environment var expansion - default: True env: environment dictionary to be used for expansion @@ -215,6 +216,12 @@ def __init__(self, module=None, category=None, name=None, cfg=None, configuration hierarchy. """ + # if the `from_dict` is given but is a string, we interpret it as + # `module` parameter. + if from_dict and isinstance(from_dict, str): + module = from_dict + from_dict = None + if from_dict: # if we could only overload constructors by signature... :-/ # As it is, we have to emulate that... diff --git a/src/radical/utils/json_io.py b/src/radical/utils/json_io.py index 855add1d..1b954757 100644 --- a/src/radical/utils/json_io.py +++ b/src/radical/utils/json_io.py @@ -6,9 +6,9 @@ import re -import json -from .misc import as_string, ru_open +from .serialize import to_json, from_json +from .misc import as_string, ru_open # ------------------------------------------------------------------------------ @@ -61,11 +61,22 @@ def write_json(data, fname): fname = tmp with ru_open(fname, 'w') as f: - json.dump(data, f, sort_keys=True, indent=4, ensure_ascii=False) + f.write(to_json(data)) f.write('\n') f.flush() +# ------------------------------------------------------------------------------ +# +def dumps_json(data): + ''' + thin wrapper around python's json write, for consistency of interface + + ''' + + return to_json(data) + + # ------------------------------------------------------------------------------ # def parse_json(json_str, filter_comments=True): @@ -77,16 +88,11 @@ def 
parse_json(json_str, filter_comments=True): are stripped from json before parsing ''' - if not filter_comments: - return json.loads(json_str) - - else: - content = '' - for line in json_str.split('\n'): - content += re.sub(r'^\s*#.*$', '', line) - content += '\n' + if filter_comments: + json_str = '\n'.join([re.sub(r'^\s*#.*$', '', line) + for line in json_str.split('\n')]) - return json.loads(content) + return from_json(json_str) # ------------------------------------------------------------------------------ diff --git a/src/radical/utils/serialize.py b/src/radical/utils/serialize.py new file mode 100644 index 00000000..dbace795 --- /dev/null +++ b/src/radical/utils/serialize.py @@ -0,0 +1,173 @@ + +import json +import msgpack + +from .typeddict import as_dict, TypedDict + +# ------------------------------------------------------------------------------ +# +class _CType: + + def __init__(self, ctype, encode, decode): + + self.ctype : type = ctype + self.encode: callable = encode + self.decode: callable = decode + + +_ctypes = dict() + + +# ------------------------------------------------------------------------------ +# +def register_serializable(cls, encode=None, decode=None): + ''' + register a class for json and msgpack serialization / deserialization. 
+ + Args: + cls (type): class type to register + encode (callable): converts class instance into encodable data structure + decode (callable): recreates the class instance from that data structure + ''' + + if encode is None: encode = cls + if decode is None: decode = cls + + _ctypes[cls.__name__] = _CType(cls, encode, decode) + +register_serializable(TypedDict) + + +# ------------------------------------------------------------------------------ +# +def _prep_typed_dict(d): + return as_dict(d, _annotate=True) + + +# ------------------------------------------------------------------------------ +# +class _json_encoder(json.JSONEncoder): + ''' + internal methods to encode registered classes to json + ''' + + def encode(self, o, *args, **kw): + tmp = as_dict(o, _annotate=True) + return super().encode(tmp, *args, **kw) + + def default(self, o): + # print('encode: %s' % o) + for cname,methods in _ctypes.items(): + if isinstance(o, methods.ctype): + return {'_type': cname, + 'as_str': methods.encode(o)} + return super().default(o) + + +# ------------------------------------------------------------------------------ +# +def _json_decoder(obj): + ''' + internal methods to decode registered classes from json + ''' + # print('decode: %s' % obj) + for cname, methods in _ctypes.items(): + # print('check %s' % cname) + if '_type' in obj and obj['_type'] == cname: + del obj['_type'] + # print('found %s' % cname) + if 'as_str' in obj: + return methods.decode(obj['as_str']) + return methods.decode(obj) + return obj + + +# ------------------------------------------------------------------------------ +# +def _msgpack_encoder(obj): + ''' + internal methods to encode registered classes to msgpack + ''' + for cname,methods in _ctypes.items(): + if isinstance(obj, methods.ctype): + return {'__%s__' % cname: True, 'as_str': methods.encode(obj)} + return obj + + +# ------------------------------------------------------------------------------ +# +def _msgpack_decoder(obj): + ''' + 
internal methods to decode registered classes from msgpack + ''' + for cname,methods in _ctypes.items(): + if '__%s__' % cname in obj: + return methods.decode(obj['as_str']) + return obj + + +# ------------------------------------------------------------------------------ +# +def to_json(data): + ''' + convert data to json, using registered classes for serialization + + Args: + data (object): data to be serialized + + Returns: + str: json serialized data + ''' + return json.dumps(data, sort_keys=True, indent=4, ensure_ascii=False, + cls=_json_encoder) + + +# ------------------------------------------------------------------------------ +# +def from_json(data): + ''' + convert json data to python data structures, using registered classes for + deserialization + + Args: + data (str): json data to be deserialized + + Returns: + object: deserialized data + ''' + return json.loads(data, object_hook=_json_decoder) + + +# ------------------------------------------------------------------------------ +# +def to_msgpack(data): + ''' + convert data to msgpack, using registered classes for serialization + + Args: + data (object): data to be serialized + + Returns: + bytes: msgpack serialized data + ''' + return msgpack.packb(data, default=_msgpack_encoder, use_bin_type=True) + + +# ------------------------------------------------------------------------------ +# +def from_msgpack(data): + ''' + convert msgpack data to python data structures, using registered classes for + deserialization + + Args: + data (bytes): msgpack data to be deserialized + + Returns: + object: deserialized data + ''' + return msgpack.unpackb(data, object_hook=_msgpack_decoder, raw=False) + + +# ------------------------------------------------------------------------------ + diff --git a/src/radical/utils/typeddict.py b/src/radical/utils/typeddict.py index 0cf9c284..aeaaa6e9 100644 --- a/src/radical/utils/typeddict.py +++ b/src/radical/utils/typeddict.py @@ -20,7 +20,7 @@ import copy import sys -from 
.misc import as_list, as_tuple, is_string +from .misc import as_list, as_tuple, is_string # ------------------------------------------------------------------------------ @@ -98,7 +98,16 @@ def __new__(mcs, name, bases, namespace): elif k not in namespace: namespace[k] = v - return super().__new__(mcs, name, bases, namespace) + _new_cls = super().__new__(mcs, name, bases, namespace) + + if _new_cls.__base__ is not dict: + + # register sub-classes + from .serialize import register_serializable + register_serializable(_new_cls) + + return _new_cls + # ------------------------------------------------------------------------------ @@ -138,6 +147,10 @@ def __init__(self, from_dict=None, **kwargs): `kwargs`). ''' + from .serialize import register_serializable + + register_serializable(self.__class__) + self.update(copy.deepcopy(self._defaults)) self.update(from_dict) @@ -288,15 +301,15 @@ def __getattr__(self, k): def __setattr__(self, k, v): - # if k.startswith('_'): - # return object.__setattr__(self, k, v) + if k.startswith('__'): + return object.__setattr__(self, k, v) self._data[k] = self._verify_setter(k, v) def __delattr__(self, k): - # if k.startswith('_'): - # return object.__delattr__(self, k) + if k.startswith('__'): + return object.__delattr__(self, k) del self._data[k] @@ -312,8 +325,8 @@ def __repr__(self): # -------------------------------------------------------------------------- # - def as_dict(self): - return as_dict(self._data) + def as_dict(self, _annotate=False): + return as_dict(self._data, _annotate) # -------------------------------------------------------------------------- @@ -483,21 +496,21 @@ def _query(self, key, default=None, last_key=True): # ------------------------------------------------------------------------------ # -def _as_dict_value(v): - return v.as_dict() if isinstance(v, TypedDict) else as_dict(v) - - -def as_dict(src): +def as_dict(src, _annotate=False): ''' Iterate given object, apply `as_dict()` to all typed values, and 
return the result (effectively a shallow copy). ''' - if isinstance(src, dict): - tgt = {k: _as_dict_value(v) for k, v in src.items()} + if isinstance(src, TypedDict): + tgt = {k: as_dict(v, _annotate) for k, v in src.items()} + if _annotate: + tgt['_type'] = type(src).__name__ + elif isinstance(src, dict): + tgt = {k: as_dict(v, _annotate) for k, v in src.items()} elif isinstance(src, list): - tgt = [_as_dict_value(x) for x in src] + tgt = [as_dict(x, _annotate) for x in src] elif isinstance(src, tuple): - tgt = tuple([_as_dict_value(x) for x in src]) + tgt = tuple([as_dict(x, _annotate) for x in src]) else: tgt = src return tgt diff --git a/src/radical/utils/zmq/client.py b/src/radical/utils/zmq/client.py index e23d17d2..4ec40aab 100644 --- a/src/radical/utils/zmq/client.py +++ b/src/radical/utils/zmq/client.py @@ -1,14 +1,14 @@ import zmq -import msgpack from typing import Any import threading as mt -from ..json_io import read_json -from ..misc import as_string -from .utils import no_intr, sock_connect +from ..json_io import read_json +from ..misc import as_string +from ..serialize import to_msgpack, from_msgpack +from .utils import no_intr, sock_connect # ------------------------------------------------------------------------------ @@ -61,14 +61,14 @@ def url(self) -> str: # def request(self, cmd: str, *args: Any, **kwargs: Any) -> Any: - req = msgpack.packb({'cmd' : cmd, - 'args' : args, - 'kwargs': kwargs}) + req = to_msgpack({'cmd' : cmd, + 'args' : args, + 'kwargs': kwargs}) no_intr(self._sock.send, req) msg = no_intr(self._sock.recv) - res = as_string(msgpack.unpackb(msg)) + res = as_string(from_msgpack(msg)) # FIXME: assert proper res structure diff --git a/src/radical/utils/zmq/message.py b/src/radical/utils/zmq/message.py index b903f45b..1fa30bdf 100644 --- a/src/radical/utils/zmq/message.py +++ b/src/radical/utils/zmq/message.py @@ -1,15 +1,16 @@ from typing import Dict, Any -import msgpack - from ..typeddict import TypedDict +from ..serialize import 
to_msgpack, from_msgpack # ------------------------------------------------------------------------------ # class Message(TypedDict): + # FIXME: register serialization methods for all message types + _schema = { '_msg_type': str, } @@ -48,11 +49,11 @@ def deserialize(data: Dict[str, Any]): def packb(self): - return msgpack.packb(self) + return to_msgpack(self) @staticmethod def unpackb(bdata): - return Message.deserialize(msgpack.unpackb(bdata)) + return Message.deserialize(from_msgpack(bdata)) # ------------------------------------------------------------------------------ diff --git a/src/radical/utils/zmq/pipe.py b/src/radical/utils/zmq/pipe.py index 09f71fa9..ec747fe3 100644 --- a/src/radical/utils/zmq/pipe.py +++ b/src/radical/utils/zmq/pipe.py @@ -1,6 +1,7 @@ import zmq -import msgpack + +from ..serialize import to_msgpack, from_msgpack MODE_PUSH = 'push' MODE_PULL = 'pull' @@ -121,7 +122,7 @@ def put(self, msg): ''' assert self._mode == MODE_PUSH - self._sock.send(msgpack.packb(msg)) + self._sock.send(to_msgpack(msg)) # -------------------------------------------------------------------------- @@ -132,7 +133,7 @@ def get(self): ''' assert self._mode == MODE_PULL - return msgpack.unpackb(self._sock.recv()) + return from_msgpack(self._sock.recv()) # -------------------------------------------------------------------------- @@ -150,7 +151,7 @@ def get_nowait(self, timeout: float = 0): socks = dict(self._poller.poll(timeout=int(timeout * 1000))) if self._sock in socks: - return msgpack.unpackb(self._sock.recv()) + return from_msgpack(self._sock.recv()) # ------------------------------------------------------------------------------ diff --git a/src/radical/utils/zmq/pubsub.py b/src/radical/utils/zmq/pubsub.py index 66f27475..a4e350bb 100644 --- a/src/radical/utils/zmq/pubsub.py +++ b/src/radical/utils/zmq/pubsub.py @@ -2,23 +2,23 @@ import zmq import time -import msgpack import threading as mt -from typing import Optional +from typing import Optional -from 
..atfork import atfork -from ..config import Config -from ..ids import generate_id, ID_CUSTOM -from ..url import Url -from ..misc import as_string, as_bytes, as_list, noop -from ..host import get_hostip -from ..logger import Logger -from ..profile import Profiler +from ..atfork import atfork +from ..config import Config +from ..ids import generate_id, ID_CUSTOM +from ..url import Url +from ..misc import as_string, as_bytes, as_list, noop +from ..host import get_hostip +from ..logger import Logger +from ..profile import Profiler +from ..serialize import to_msgpack, from_msgpack -from .bridge import Bridge -from .utils import no_intr +from .bridge import Bridge +from .utils import no_intr # ------------------------------------------------------------------------------ @@ -246,7 +246,7 @@ def put(self, topic, msg): # log_bulk(self._log, '-> %s' % topic, [msg]) btopic = as_bytes(topic.replace(' ', '_')) - bmsg = msgpack.packb(msg) + bmsg = to_msgpack(msg) data = btopic + b' ' + bmsg self._socket.send(data) @@ -273,7 +273,7 @@ def _get_nowait(socket, lock, timeout, log, prof): data = no_intr(socket.recv, flags=zmq.NOBLOCK) topic, bmsg = data.split(b' ', 1) - msg = msgpack.unpackb(bmsg) + msg = from_msgpack(bmsg) # log.debug(' <- %s: %s', topic, msg) @@ -497,7 +497,7 @@ def get(self): data = no_intr(self._sock.recv) topic, bmsg = data.split(b' ', 1) - msg = msgpack.unpackb(bmsg) + msg = from_msgpack(bmsg) # log_bulk(self._log, '<- %s' % topic, [msg]) @@ -519,7 +519,7 @@ def get_nowait(self, timeout=None): data = no_intr(self._sock.recv, flags=zmq.NOBLOCK) topic, bmsg = data.split(b' ', 1) - msg = msgpack.unpackb(bmsg) + msg = from_msgpack(bmsg) # log_bulk(self._log, '<- %s' % topic, [msg]) diff --git a/src/radical/utils/zmq/queue.py b/src/radical/utils/zmq/queue.py index 9914fb19..bb90ee05 100644 --- a/src/radical/utils/zmq/queue.py +++ b/src/radical/utils/zmq/queue.py @@ -3,24 +3,24 @@ import sys import zmq import time -import msgpack import threading as mt -from typing 
import Optional +from typing import Optional -from ..atfork import atfork -from ..config import Config -from ..ids import generate_id, ID_CUSTOM -from ..url import Url -from ..misc import as_string, as_bytes, as_list, noop -from ..host import get_hostip -from ..logger import Logger -from ..profile import Profiler -from ..debug import print_exception_trace +from ..atfork import atfork +from ..config import Config +from ..ids import generate_id, ID_CUSTOM +from ..url import Url +from ..misc import as_string, as_bytes, as_list, noop +from ..host import get_hostip +from ..logger import Logger +from ..profile import Profiler +from ..debug import print_exception_trace +from ..serialize import to_msgpack, from_msgpack -from .bridge import Bridge -from .utils import no_intr +from .bridge import Bridge +from .utils import no_intr # NOTE: the log bulk method is frequently called and slow # from .utils import log_bulk @@ -238,8 +238,8 @@ def _bridge_work(self): if len(data) != 2: raise RuntimeError('%d frames unsupported' % len(data)) - qname = as_string(msgpack.unpackb(data[0])) - msgs = msgpack.unpackb(data[1]) + qname = as_string(from_msgpack(data[0])) + msgs = from_msgpack(data[1]) # prof_bulk(self._prof, 'poll_put_recv', msgs) # log_bulk(self._log, '<> %s' % qname, msgs) # self._log.debug('put %s: %s ! 
', qname, len(msgs)) @@ -278,7 +278,7 @@ def _bridge_work(self): # log_bulk(self._log, '>< %s' % qname, msgs) - data = [msgpack.packb(qname), msgpack.packb(msgs)] + data = [to_msgpack(qname), to_msgpack(msgs)] active = True # self._log.debug('==== get %s: %s', qname, list(buf.keys())) @@ -380,7 +380,7 @@ def put(self, msgs, qname=None): qname = 'default' # log_bulk(self._log, '-> %s[%s]' % (self._channel, qname), msgs) - data = [msgpack.packb(qname), msgpack.packb(msgs)] + data = [to_msgpack(qname), to_msgpack(msgs)] with self._lock: no_intr(self._q.send_multipart, data) @@ -427,8 +427,8 @@ def _get_nowait(url, qname=None, timeout=None, uid=None): # timeout in ms data = list(no_intr(info['socket'].recv_multipart)) info['requested'] = False - qname = as_string(msgpack.unpackb(data[0])) - msgs = as_string(msgpack.unpackb(data[1])) + qname = as_string(from_msgpack(data[0])) + msgs = as_string(from_msgpack(data[1])) # log_bulk(logger, '<-1 %s [%s]' % (uid, qname), msgs) return msgs @@ -694,8 +694,8 @@ def get(self, qname=None): data = list(no_intr(self._q.recv_multipart)) self._requested = False - qname = msgpack.unpackb(data[0]) - msgs = msgpack.unpackb(data[1]) + qname = from_msgpack(data[0]) + msgs = from_msgpack(data[1]) # log_bulk(self._log, '<-2 %s [%s]' % (self._channel, qname), msgs) @@ -729,8 +729,8 @@ def get_nowait(self, qname=None, timeout=None): # timeout in ms data = list(no_intr(self._q.recv_multipart)) self._requested = False - qname = msgpack.unpackb(data[0]) - msgs = msgpack.unpackb(data[1]) + qname = from_msgpack(data[0]) + msgs = from_msgpack(data[1]) # log_bulk(self._log, '<-3 %s [%s]' % (self._channel, qname), msgs) return as_string(msgs) diff --git a/src/radical/utils/zmq/server.py b/src/radical/utils/zmq/server.py index 69eac886..74492104 100644 --- a/src/radical/utils/zmq/server.py +++ b/src/radical/utils/zmq/server.py @@ -1,20 +1,20 @@ import zmq -import msgpack import threading as mt -from typing import Optional, Union, Iterator, Any, Dict 
+from typing import Optional, Union, Iterator, Any, Dict -from ..ids import generate_id -from ..url import Url -from ..misc import as_string -from ..host import get_hostip -from ..logger import Logger -from ..profile import Profiler -from ..debug import get_exception_trace +from ..ids import generate_id +from ..url import Url +from ..misc import as_string +from ..host import get_hostip +from ..logger import Logger +from ..profile import Profiler +from ..debug import get_exception_trace +from ..serialize import to_msgpack, from_msgpack -from .utils import no_intr +from .utils import no_intr # -------------------------------------------------------------------------- @@ -284,7 +284,7 @@ def _work(self) -> None: try: data = no_intr(self._sock.recv) - req = as_string(msgpack.unpackb(data)) + req = as_string(from_msgpack(data)) self._log.debug('req: %s', str(req)[:128]) if not isinstance(req, dict): @@ -312,7 +312,7 @@ def _work(self) -> None: finally: if not rep: rep = self._error('server error') - no_intr(self._sock.send, msgpack.packb(rep)) + no_intr(self._sock.send, to_msgpack(rep)) self._log.debug('rep: %s', str(rep)[:128]) self._sock.close() diff --git a/tests/unittests/test_json.py b/tests/unittests/test_json.py index 01792533..138cd3a4 100644 --- a/tests/unittests/test_json.py +++ b/tests/unittests/test_json.py @@ -3,6 +3,8 @@ # noqa: E201 +# import radical.utils as ru + # ------------------------------------------------------------------------------ # run tests if called directly diff --git a/tests/unittests/test_serialization.py b/tests/unittests/test_serialization.py new file mode 100644 index 00000000..b685171f --- /dev/null +++ b/tests/unittests/test_serialization.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 + +__author__ = "Radical.Utils Development Team (Andre Merzky)" +__copyright__ = "Copyright 2024, RADICAL@Rutgers" +__license__ = "MIT" + +import radical.utils as ru + + +# ------------------------------------------------------------------------------ +# 
+def test_serialization():
+    '''round-trip an explicitly registered class through json and msgpack'''
+
+    class Complex:
+        def __init__(self, real, imag):
+            self.real = real
+            self.imag = imag
+
+        def __eq__(self, other):
+            # foreign types compare unequal instead of raising AttributeError
+            if not isinstance(other, Complex):
+                return NotImplemented
+            return self.real == other.real and self.imag == other.imag
+
+        def serialize(self):
+            return {'real': self.real, 'imag': self.imag}
+
+        @classmethod
+        def deserialize(cls, data):
+            return cls(data['real'], data['imag'])
+
+    ru.register_serializable(Complex, encode=Complex.serialize,
+                                      decode=Complex.deserialize)
+
+    old = {'foo': {'complex_number': Complex(1, 2)}}
+    new = ru.from_json(ru.to_json(old))           # json round trip
+    assert old == new
+
+    new = ru.from_msgpack(ru.to_msgpack(old))     # msgpack round trip
+    assert old == new
+
+
+# ------------------------------------------------------------------------------
+#
+def test_serialization_typed_dict():
+    '''nested `TypedDict` subclasses keep their types in a json round trip'''
+
+    class A(ru.TypedDict):
+        _schema = {'s': str, 'i': int}
+
+    class B(ru.TypedDict):
+        _schema = {'a': A}
+
+    old = B(a=A(s='buz', i=42))
+    new = ru.from_json(ru.to_json(old))
+
+    assert old == new
+    assert isinstance(old, B) and isinstance(new, B)
+    assert isinstance(old['a'], A) and isinstance(new['a'], A)
+
+
+# ------------------------------------------------------------------------------
+#
+if __name__ == '__main__':
+    test_serialization()
+    test_serialization_typed_dict()
+
+
+# ------------------------------------------------------------------------------
+