From b844cffd02d921c9115807f1dcd2c5e1c72db0fd Mon Sep 17 00:00:00 2001 From: Brett Date: Thu, 10 Aug 2023 17:48:53 -0400 Subject: [PATCH] expose get/set_array_storage/compression to SerializationContext this allows converters to control array storage and compression --- asdf/_tests/test_serialization_context.py | 34 ++++++++++ asdf/extension/_serialization_context.py | 80 +++++++++++++++++++++++ 2 files changed, 114 insertions(+) diff --git a/asdf/_tests/test_serialization_context.py b/asdf/_tests/test_serialization_context.py index 85d455afb..5204e6df6 100644 --- a/asdf/_tests/test_serialization_context.py +++ b/asdf/_tests/test_serialization_context.py @@ -123,3 +123,37 @@ class Foo: # the key does not yet have an assigned object assert not key._is_valid() op_ctx.assign_blocks() + + +@pytest.mark.parametrize("block_access", [None, *list(BlockAccess)]) +def test_get_set_array_storage(block_access): + af = asdf.AsdfFile() + if block_access is None: + context = af._create_serialization_context() + else: + context = af._create_serialization_context(block_access) + arr = np.zeros(3) + storage = "external" + assert af.get_array_storage(arr) != storage + context.set_array_storage(arr, storage) + assert af.get_array_storage(arr) == storage + assert context.get_array_storage(arr) == storage + + +@pytest.mark.parametrize("block_access", [None, *list(BlockAccess)]) +def test_get_set_array_compression(block_access): + af = asdf.AsdfFile() + if block_access is None: + context = af._create_serialization_context() + else: + context = af._create_serialization_context(block_access) + arr = np.zeros(3) + compression = "bzp2" + kwargs = {"a": 1} + assert af.get_array_compression(arr) != compression + assert af.get_array_compression_kwargs(arr) != kwargs + context.set_array_compression(arr, compression, **kwargs) + assert af.get_array_compression(arr) == compression + assert af.get_array_compression_kwargs(arr) == kwargs + assert context.get_array_compression(arr) == compression + 
assert context.get_array_compression_kwargs(arr) == kwargs diff --git a/asdf/extension/_serialization_context.py b/asdf/extension/_serialization_context.py index 80ed8cdc8..95e2a0b67 100644 --- a/asdf/extension/_serialization_context.py +++ b/asdf/extension/_serialization_context.py @@ -147,6 +147,86 @@ def assign_object(self, obj): def assign_blocks(self): pass + def set_array_storage(self, arr, array_storage): + """ + Set the block type to use for the given array data. + + Parameters + ---------- + arr : numpy.ndarray + The array to set. If multiple views of the array are in + the tree, only the most recent block type setting will be + used, since all views share a single block. + + array_storage : str + Must be one of: + + - ``internal``: The default. The array data will be + stored in a binary block in the same ASDF file. + + - ``external``: Store the data in a binary block in a + separate ASDF file. + + - ``inline``: Store the data as YAML inline in the tree. + """ + self._blocks._set_array_storage(arr, array_storage) + + def get_array_storage(self, arr): + """ + Get the block type for the given array data. + + Parameters + ---------- + arr : numpy.ndarray + """ + return self._blocks._get_array_storage(arr) + + def set_array_compression(self, arr, compression, **compression_kwargs): + """ + Set the compression to use for the given array data. + + Parameters + ---------- + arr : numpy.ndarray + The array to set. If multiple views of the array are in + the tree, only the most recent compression setting will be + used, since all views share a single block. + + compression : str or None + Must be one of: + + - ``''`` or `None`: no compression + + - ``zlib``: Use zlib compression + + - ``bzp2``: Use bzip2 compression + + - ``lz4``: Use lz4 compression + + - ``input``: Use the same compression as in the file read. + If there is no prior file, acts as None. 
+ + """ + self._blocks._set_array_compression(arr, compression, **compression_kwargs) + + def get_array_compression(self, arr): + """ + Get the compression type for the given array data. + + Parameters + ---------- + arr : numpy.ndarray + + Returns + ------- + compression : str or None + """ + return self._blocks._get_array_compression(arr) + + def get_array_compression_kwargs(self, arr): + """Get the compression keyword arguments for the given array data.""" + return self._blocks._get_array_compression_kwargs(arr) + + class ReadBlocksContext(SerializationContext): """