From 391a74c407b7b9435fbce8dcf175244010fc0df1 Mon Sep 17 00:00:00 2001 From: Theodore Chang Date: Mon, 14 Oct 2024 21:41:50 +0200 Subject: [PATCH] Use abstract unpacker class --- pyproject.toml | 5 -- src/msglc/__init__.py | 10 +++- src/msglc/config.py | 10 +++- src/msglc/generate.py | 18 ++++-- src/msglc/reader.py | 129 +++++++++++++++++++++++++++++----------- src/msglc/toc.py | 8 ++- src/msglc/unpacker.py | 41 +++++++++++++ src/msglc/utility.py | 8 ++- src/msglc/writer.py | 36 ++++++++--- tests/test_benchmark.py | 4 +- 10 files changed, 209 insertions(+), 60 deletions(-) create mode 100644 src/msglc/unpacker.py diff --git a/pyproject.toml b/pyproject.toml index 56754d2..8657877 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,6 @@ dev = [ "pytest-cov", "pytest-benchmark", "pytest-asyncio", - "black", "ruff", ] numpy = [ @@ -46,7 +45,3 @@ numpy = [ "Homepage" = "https://github.com/TLCFEM/msglc" "Bug Reports" = "https://github.com/TLCFEM/msglc/issuess" "Source" = "https://github.com/TLCFEM/msglc" - -[tool.black] -line-length = 120 -fast = true diff --git a/src/msglc/__init__.py b/src/msglc/__init__.py index 3b1074d..2ceb39d 100644 --- a/src/msglc/__init__.py +++ b/src/msglc/__init__.py @@ -44,7 +44,11 @@ class FileInfo: def combine( - archive: str | BytesIO, files: FileInfo | list[FileInfo], *, mode: Literal["a", "w"] = "w", validate: bool = True + archive: str | BytesIO, + files: FileInfo | list[FileInfo], + *, + mode: Literal["a", "w"] = "w", + validate: bool = True, ): """ This function is used to combine the multiple serialized files into a single archive. @@ -98,7 +102,9 @@ def _iter(path: str | BinaryIO): combiner.write(_iter(file.path), file.name) -def append(archive: str | BytesIO, files: FileInfo | list[FileInfo], *, validate: bool = True): +def append( + archive: str | BytesIO, files: FileInfo | list[FileInfo], *, validate: bool = True +): """ This function is used to append the multiple serialized files to an existing single archive. 
diff --git a/src/msglc/config.py b/src/msglc/config.py index 2182be7..1110122 100644 --- a/src/msglc/config.py +++ b/src/msglc/config.py @@ -63,7 +63,10 @@ def configure( This function is used to configure the settings. It accepts any number of keyword arguments. The function updates the values of the configuration parameters if they are provided in the arguments. """ - if isinstance(small_obj_optimization_threshold, int) and small_obj_optimization_threshold > 0: + if ( + isinstance(small_obj_optimization_threshold, int) + and small_obj_optimization_threshold > 0 + ): config.small_obj_optimization_threshold = small_obj_optimization_threshold if config.trivial_size > config.small_obj_optimization_threshold: config.trivial_size = config.small_obj_optimization_threshold @@ -77,7 +80,10 @@ def configure( if isinstance(fast_loading, bool): config.fast_loading = fast_loading - if isinstance(fast_loading_threshold, (int, float)) and 0 <= fast_loading_threshold <= 1: + if ( + isinstance(fast_loading_threshold, (int, float)) + and 0 <= fast_loading_threshold <= 1 + ): config.fast_loading_threshold = fast_loading_threshold if isinstance(trivial_size, int) and trivial_size > 0: diff --git a/src/msglc/generate.py b/src/msglc/generate.py index 961c38b..dfea153 100644 --- a/src/msglc/generate.py +++ b/src/msglc/generate.py @@ -25,13 +25,18 @@ from msglc import dump from msglc.config import configure from msglc.reader import LazyDict, LazyList, LazyStats, LazyReader +from msglc.unpacker import Unpacker def generate_random_json(depth=10, width=4, simple=False): seed = random.random() def generate_token(): - return "".join(random.choices(string.ascii_letters + string.digits, k=random.randint(5, 10))) + return "".join( + random.choices( + string.ascii_letters + string.digits, k=random.randint(5, 10) + ) + ) if depth == 0 or (simple and seed < 0.1): return random.choice( @@ -44,7 +49,10 @@ def generate_token(): ) if seed < 0.7: - return {generate_token(): generate_random_json(depth - 
1, width, True) for _ in range(width)} + return { + generate_token(): generate_random_json(depth - 1, width, True) + for _ in range(width) + } if seed < 0.95 or not simple: return [generate_random_json(depth - 1, width, True) for _ in range(width)] @@ -106,13 +114,15 @@ def generate(*, depth=6, width=11, threshold=23): p.map(_dump, step) -def compare(mode, size: int = 13, total: int = 5, unpacker=None): +def compare(mode, size: int = 13, total: int = 5, unpacker: Unpacker = None): accumulator: int = 0 with open("path.txt", "r") as f: if mode > 0: counter = LazyStats() - with LazyReader(f"archive_{size}.msg", counter=counter, unpacker=unpacker) as reader: + with LazyReader( + f"archive_{size}.msg", counter=counter, unpacker=unpacker + ) as reader: while p := f.readline(): accumulator += 1 if accumulator == 10**total: diff --git a/src/msglc/reader.py b/src/msglc/reader.py index 61a016a..8e42f89 100644 --- a/src/msglc/reader.py +++ b/src/msglc/reader.py @@ -20,13 +20,13 @@ from io import BytesIO, BufferedReader from bitarray import bitarray -from msgpack import Unpacker # type: ignore -from msgspec.msgpack import Decoder +import msgpack from .config import config, increment_gc_counter, decrement_gc_counter, BufferReader from .index import normalise_index, to_index from .utility import MockIO from .writer import LazyWriter +from .unpacker import Unpacker, MsgpackUnpacker def to_obj(v): @@ -84,17 +84,17 @@ def __init__( *, counter: LazyStats | None = None, cached: bool = True, - unpacker: Unpacker | Decoder | None = None, + unpacker: Unpacker | None = None, ): self._buffer: BufferReader = buffer self._offset: int = offset # start of original data self._counter: LazyStats | None = counter self._cached: bool = cached - self._unpacker: Unpacker | Decoder - if isinstance(unpacker, (Unpacker, Decoder)): + self._unpacker: Unpacker + if isinstance(unpacker, Unpacker): self._unpacker = unpacker elif unpacker is None: - self._unpacker = Unpacker() + self._unpacker = 
MsgpackUnpacker() else: raise TypeError("Need a valid unpacker.") @@ -121,11 +121,7 @@ def _readb(self, start: int, end: int): return self._buffer.read(size) def _unpack(self, data: bytes): - if isinstance(self._unpacker, Decoder): - return self._unpacker.decode(data) - - self._unpacker.feed(data) - return self._unpacker.unpack() + return self._unpacker.decode(data) def _read(self, start: int, end: int): return self._unpack(self._readb(start, end)) @@ -133,7 +129,11 @@ def _read(self, start: int, end: int): def _child(self, toc: dict | int): self._accessed_items += 1 - params: dict = {"counter": self._counter, "cached": self._cached, "unpacker": self._unpacker} + params: dict = { + "counter": self._counter, + "cached": self._cached, + "unpacker": self._unpacker, + } # {"t": {"name1": start_pos, "name2": start_pos}} # this is used in combined archives @@ -144,8 +144,13 @@ def _child(self, toc: dict | int): if (child_toc := toc.get("t", None)) is None: # {"p": [start_pos, end_pos]} # this is used in small objects - if 2 == len(child_pos := toc["p"]) and all(isinstance(x, int) for x in child_pos): - if isinstance(data := self._read(*child_pos), bytes) and b"multiarray" in data[:40]: + if 2 == len(child_pos := toc["p"]) and all( + isinstance(x, int) for x in child_pos + ): + if ( + isinstance(data := self._read(*child_pos), bytes) + and b"multiarray" in data[:40] + ): return pickle.loads(data) return data @@ -167,7 +172,10 @@ def _child(self, toc: dict | int): @property def _fast_loading(self): - return config.fast_loading and self._accessed_items < config.fast_loading_threshold * len(self) + return ( + config.fast_loading + and self._accessed_items < config.fast_loading_threshold * len(self) + ) def to_obj(self): raise NotImplementedError @@ -188,10 +196,14 @@ def __init__( *, counter: LazyStats | None = None, cached: bool = True, - unpacker: Unpacker | Decoder | None = None, + unpacker: Unpacker | None = None, ): - super().__init__(buffer, offset, counter=counter, 
cached=cached, unpacker=unpacker) - self._toc: list | None = toc.get("t", None) # if None, it's a list of small objects + super().__init__( + buffer, offset, counter=counter, cached=cached, unpacker=unpacker + ) + self._toc: list | None = toc.get( + "t", None + ) # if None, it's a list of small objects self._pos: list = toc.get("p", None) # noqa # if None, it comes from a combined archive self._index: int = 0 self._cache: list = [None] * len(self) @@ -206,7 +218,11 @@ def __init__( self._size_list.append(total_size) def __repr__(self): - return f"LazyList[{len(self)}]" if config.simple_repr or not self._cached else self.to_obj().__repr__() + return ( + f"LazyList[{len(self)}]" + if config.simple_repr or not self._cached + else self.to_obj().__repr__() + ) def _lookup_index(self, index: int) -> int: low: int = 0 @@ -224,7 +240,7 @@ def _lookup_index(self, index: int) -> int: low = mid def _all(self, start: int, end: int) -> list: - return list(Unpacker(BytesIO(self._readb(start, end)))) + return list(msgpack.Unpacker(BytesIO(self._readb(start, end)))) def __getitem__(self, index): index_range: list | range @@ -250,9 +266,14 @@ def __getitem__(self, index): self._cache[item] = self._child(self._toc[item]) else: lookup_index: int = self._lookup_index(item) - num_start, num_end = self._size_list[lookup_index], self._size_list[lookup_index + 1] + num_start, num_end = ( + self._size_list[lookup_index], + self._size_list[lookup_index + 1], + ) self._mask[num_start:num_end] = 1 - self._cache[num_start:num_end] = self._all(*self._pos[lookup_index][1:]) + self._cache[num_start:num_end] = self._all( + *self._pos[lookup_index][1:] + ) return self._cache[index] @@ -263,7 +284,10 @@ def __getitem__(self, index): self._cache[item] = self._child(self._toc[item]) else: lookup_index = self._lookup_index(item) - num_start, num_end = self._size_list[lookup_index], self._size_list[lookup_index + 1] + num_start, num_end = ( + self._size_list[lookup_index], + self._size_list[lookup_index 
+ 1], + ) self._cache[num_start:num_end] = self._all(*self._pos[lookup_index][1:]) result = self._cache[index] @@ -283,7 +307,11 @@ def __next__(self): return item def __len__(self): - return self._toc.__len__() if self._toc is not None else sum(x[0] for x in self._pos) + return ( + self._toc.__len__() + if self._toc is not None + else sum(x[0] for x in self._pos) + ) def to_obj(self): if not self._cached: @@ -325,16 +353,22 @@ def __init__( *, counter: LazyStats | None = None, cached: bool = True, - unpacker: Unpacker | Decoder | None = None, + unpacker: Unpacker | None = None, ): - super().__init__(buffer, offset, counter=counter, cached=cached, unpacker=unpacker) + super().__init__( + buffer, offset, counter=counter, cached=cached, unpacker=unpacker + ) self._toc: dict = toc["t"] self._pos: list = toc.get("p", None) # noqa # if empty, it comes from a combined archive self._cache: dict = {} self._full_loaded: bool = False def __repr__(self): - return f"LazyDict[{len(self)}]" if config.simple_repr or not self._cached else self.to_obj().__repr__() + return ( + f"LazyDict[{len(self)}]" + if config.simple_repr or not self._cached + else self.to_obj().__repr__() + ) def __getitem__(self, key): if not self._cached: @@ -390,7 +424,7 @@ def __init__( *, counter: LazyStats | None = None, cached: bool = True, - unpacker: Unpacker | Decoder | None = None, + unpacker: Unpacker | None = None, ): """ :param buffer_or_path: the buffer or path to the file @@ -408,7 +442,11 @@ def __init__( else: raise ValueError("Expecting a buffer or path.") - sep_a, sep_b, sep_c = LazyWriter.magic_len(), LazyWriter.magic_len() + 10, LazyWriter.magic_len() + 20 + sep_a, sep_b, sep_c = ( + LazyWriter.magic_len(), + LazyWriter.magic_len() + 10, + LazyWriter.magic_len() + 20, + ) # keep the buffer unchanged in case of failure original_pos: int = buffer.tell() @@ -418,7 +456,13 @@ def __init__( if header[:sep_a] != LazyWriter.magic: raise ValueError("Invalid file format.") - 
super().__init__(buffer, original_pos + sep_c, counter=counter, cached=cached, unpacker=unpacker) + super().__init__( + buffer, + original_pos + sep_c, + counter=counter, + cached=cached, + unpacker=unpacker, + ) toc_start: int = self._unpack(header[sep_a:sep_b].lstrip(b"\0")) toc_size: int = self._unpack(header[sep_b:sep_c].lstrip(b"\0")) @@ -430,7 +474,11 @@ def __repr__(self): if isinstance(self._buffer_or_path, str): file_path = " (" + self._buffer_or_path + ")" - return f"LazyReader{file_path}" if config.simple_repr or not self._cached else self.to_obj().__repr__() + return ( + f"LazyReader{file_path}" + if config.simple_repr or not self._cached + else self.to_obj().__repr__() + ) def __enter__(self): increment_gc_counter() @@ -492,7 +540,9 @@ def read(self, path: str | list | slice | None = None): target = self._obj for key in (v for v in path_stack if v != ""): target = target[ - to_index(key, len(target)) if isinstance(key, str) and isinstance(target, (list, LazyList)) else key + to_index(key, len(target)) + if isinstance(key, str) and isinstance(target, (list, LazyList)) + else key ] return target @@ -510,7 +560,11 @@ def visit(self, path: str = ""): """ target = self._obj for key in (v for v in path.split("/") if v != ""): - target = target[to_index(key, len(target)) if isinstance(target, (list, LazyList)) else key] + target = target[ + to_index(key, len(target)) + if isinstance(target, (list, LazyList)) + else key + ] return target async def async_read(self, path: str | list | slice | None = None): @@ -542,7 +596,9 @@ async def async_read(self, path: str | list | slice | None = None): for key in (v for v in path_stack if v != ""): target = await async_get( target, - to_index(key, len(target)) if isinstance(key, str) and isinstance(target, (list, LazyList)) else key, + to_index(key, len(target)) + if isinstance(key, str) and isinstance(target, (list, LazyList)) + else key, ) return target @@ -561,7 +617,10 @@ async def async_visit(self, path: str = ""): 
target = self._obj for key in (v for v in path.split("/") if v != ""): target = await async_get( - target, to_index(key, len(target)) if isinstance(target, (list, LazyList)) else key + target, + to_index(key, len(target)) + if isinstance(target, (list, LazyList)) + else key, ) return target diff --git a/src/msglc/toc.py b/src/msglc/toc.py index 1493687..bdfa030 100644 --- a/src/msglc/toc.py +++ b/src/msglc/toc.py @@ -39,7 +39,9 @@ class Node: class TOC: - def __init__(self, *, packer: Packer, buffer: BytesIO | BinaryIO, transform: callable = None): # type: ignore + def __init__( + self, *, packer: Packer, buffer: BytesIO | BinaryIO, transform: callable = None + ): # type: ignore self._buffer: BytesIO | BinaryIO = buffer self._packer: Packer = packer self._initial_pos = self._buffer.tell() @@ -113,7 +115,9 @@ def _generate(_start: int) -> Node: accu_list.append(v) accu_size += v.p[1] - v.p[0] if accu_size > config.small_obj_optimization_threshold: - groups.append((len(accu_list), accu_list[0].p[0], accu_list[-1].p[1])) + groups.append( + (len(accu_list), accu_list[0].p[0], accu_list[-1].p[1]) + ) accu_list = [] accu_size = 0 diff --git a/src/msglc/unpacker.py b/src/msglc/unpacker.py new file mode 100644 index 0000000..c27e88c --- /dev/null +++ b/src/msglc/unpacker.py @@ -0,0 +1,41 @@ +# Copyright (C) 2024 Theodore Chang +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <https://www.gnu.org/licenses/>.
+ +from abc import abstractmethod +import msgpack +import msgspec + + +class Unpacker: + @abstractmethod + def decode(self, data): + raise NotImplementedError + + +class MsgpackUnpacker(Unpacker): + def __init__(self): + self._unpacker = msgpack.Unpacker() + + def decode(self, data): + self._unpacker.feed(data) + return self._unpacker.unpack() + + +class MsgspecUnpacker(Unpacker): + def __init__(self): + self._unpacker = msgspec.msgpack.Decoder() + + def decode(self, data): + return self._unpacker.decode(data) diff --git a/src/msglc/utility.py b/src/msglc/utility.py index 1cd162d..7929845 100644 --- a/src/msglc/utility.py +++ b/src/msglc/utility.py @@ -20,7 +20,13 @@ class MockIO: - def __init__(self, path: str | BytesIO, mode: str, seek_delay: float = 0, read_speed: int | list = 1 * 2**20): + def __init__( + self, + path: str | BytesIO, + mode: str, + seek_delay: float = 0, + read_speed: int | list = 1 * 2**20, + ): self._path = path if isinstance(path, str) else None self._io = open(path, mode) if isinstance(path, str) else path self._seek_delay: float = seek_delay diff --git a/src/msglc/writer.py b/src/msglc/writer.py index 4d79feb..6603886 100644 --- a/src/msglc/writer.py +++ b/src/msglc/writer.py @@ -21,7 +21,13 @@ from msgpack import Packer, packb, unpackb # type: ignore -from .config import config, increment_gc_counter, decrement_gc_counter, BufferWriter, max_magic_len +from .config import ( + config, + increment_gc_counter, + decrement_gc_counter, + BufferWriter, + max_magic_len, +) from .toc import TOC @@ -54,7 +60,9 @@ def __enter__(self): increment_gc_counter() if isinstance(self._buffer_or_path, str): - self._buffer = open(self._buffer_or_path, "wb", buffering=config.write_buffer_size) + self._buffer = open( + self._buffer_or_path, "wb", buffering=config.write_buffer_size + ) elif isinstance(self._buffer_or_path, (BytesIO, BufferedReader)): self._buffer = self._buffer_or_path else: @@ -100,7 +108,9 @@ def write(self, obj) -> None: class LazyCombiner: - 
def __init__(self, buffer_or_path: str | BufferWriter, *, mode: Literal["a", "w"] = "w"): + def __init__( + self, buffer_or_path: str | BufferWriter, *, mode: Literal["a", "w"] = "w" + ): """ :param buffer_or_path: target buffer or file path :param mode: mode of operation, 'w' for write and 'a' for append @@ -116,8 +126,14 @@ def __init__(self, buffer_or_path: str | BufferWriter, *, mode: Literal["a", "w" def __enter__(self): if isinstance(self._buffer_or_path, str): - mode: str = "wb" if not os.path.exists(self._buffer_or_path) or self._mode == "w" else "r+b" - self._buffer = open(self._buffer_or_path, mode, buffering=config.write_buffer_size) + mode: str = ( + "wb" + if not os.path.exists(self._buffer_or_path) or self._mode == "w" + else "r+b" + ) + self._buffer = open( + self._buffer_or_path, mode, buffering=config.write_buffer_size + ) elif isinstance(self._buffer_or_path, (BytesIO, BufferedReader)): self._buffer = self._buffer_or_path if self._mode == "a": @@ -131,7 +147,11 @@ def __enter__(self): self._buffer.write(b"\0" * 20) self._file_start = self._buffer.tell() else: - sep_a, sep_b, sep_c = LazyWriter.magic_len(), LazyWriter.magic_len() + 10, LazyWriter.magic_len() + 20 + sep_a, sep_b, sep_c = ( + LazyWriter.magic_len(), + LazyWriter.magic_len() + 10, + LazyWriter.magic_len() + 20, + ) ini_position: int = self._buffer.tell() header: bytes = self._buffer.read(sep_c) @@ -141,7 +161,9 @@ def _raise_invalid(msg: str): raise ValueError(msg) if header[:sep_a] != LazyWriter.magic: - _raise_invalid("Invalid file format, cannot append to the current file.") + _raise_invalid( + "Invalid file format, cannot append to the current file." 
+ ) toc_start: int = unpackb(header[sep_a:sep_b].lstrip(b"\0")) toc_size: int = unpackb(header[sep_b:sep_c].lstrip(b"\0")) diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 1026b59..1079a37 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -16,11 +16,11 @@ import random import pytest -from msgspec.msgpack import Decoder from msglc import dump, config from msglc.generate import generate_random_json, find_all_paths, goto_path, generate, compare from msglc.reader import LazyStats, LazyReader +from msglc.unpacker import MsgpackUnpacker, MsgspecUnpacker def test_random_benchmark(monkeypatch, tmpdir): @@ -112,7 +112,7 @@ def prepare(tmpdir_factory): @pytest.mark.parametrize("size", [x for x in range(13, 25)]) @pytest.mark.parametrize("total", [0, 1, 2, 3, 4]) -@pytest.mark.parametrize("unpacker", [None, Decoder()]) +@pytest.mark.parametrize("unpacker", [MsgpackUnpacker(), MsgspecUnpacker()]) def test_matrix(prepare, benchmark, size, total, unpacker): with prepare.as_cwd(): benchmark(compare, 1, size, total, unpacker)