diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml new file mode 100644 index 0000000..4a79063 --- /dev/null +++ b/.github/workflows/python-package.yml @@ -0,0 +1,40 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: Python package + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.8", "3.9", "3.10"] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install flake8 pytest + python -m pip install . + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Test with pytest + run: | + pytest diff --git a/gadgetron/__init__.py b/gadgetron/__init__.py index 658b224..1f56999 100644 --- a/gadgetron/__init__.py +++ b/gadgetron/__init__.py @@ -21,7 +21,6 @@ __all__ = [ util, - legacy, external, examples, Gadget, diff --git a/gadgetron/external/__init__.py b/gadgetron/external/__init__.py index a18f3dd..be25728 100644 --- a/gadgetron/external/__init__.py +++ b/gadgetron/external/__init__.py @@ -1,5 +1,7 @@ from .connection import Connection from .listen import listen +from .readers import read +from .writers import write __all__ = [Connection, listen] diff --git a/gadgetron/external/connection.py b/gadgetron/external/connection.py index 9a50c29..8af80cf 100644 --- a/gadgetron/external/connection.py +++ b/gadgetron/external/connection.py @@ -1,6 +1,5 @@ - -import socket import logging +from socket import MSG_WAITALL import xml.etree.ElementTree as xml @@ -8,12 +7,15 @@ from . 
import constants -from .readers import read, read_byte_string, read_acquisition, read_waveform, read_image -from .writers import write_acquisition, write_waveform, write_image +from .readers import read, read_byte_string +from .writers import write, write_byte_string + +from ..types.image_array import ImageArray +from ..types.recon_data import ReconData +from ..types.acquisition_bucket import AcquisitionBucket -from ..types.image_array import ImageArray, read_image_array, write_image_array -from ..types.recon_data import ReconData, read_recon_data, write_recon_data -from ..types.acquisition_bucket import read_acquisition_bucket +from ..types.serialization import message_reader, message_writer +from ..types import serialization class Connection: @@ -27,34 +29,41 @@ def __init__(self, socket): self.socket.settimeout(None) def read(self, nbytes): - bytes = self.socket.recv(nbytes, socket.MSG_WAITALL) - while len(bytes) < nbytes: - bytes += self.socket.recv(nbytes - len(bytes),socket.MSG_WAITALL) - return bytes + bytedata = self.socket.recv(nbytes, MSG_WAITALL) + while len(bytedata) < nbytes: + bytedata += self.socket.recv(nbytes - len(bytedata), MSG_WAITALL) + return bytedata def write(self, byte_array): self.socket.sendall(byte_array) def close(self): - end = constants.GadgetMessageIdentifier.pack(constants.GADGET_MESSAGE_CLOSE) - self.socket.send(end) self.socket.close() class Struct: def __init__(self, **fields): self.__dict__.update(fields) - def __init__(self, socket): + @staticmethod + def initiate_connection(socket, config, header): + + connection = Connection(socket,config,header) + connection._write_config(config) + connection._write_header(header) + + return connection + + def __init__(self, socket, config=None, header=None): self.socket = Connection.SocketWrapper(socket) self.readers = Connection._default_readers() self.writers = Connection._default_writers() - self.raw_bytes = Connection.Struct(config=None, header=None) - self.config, self.raw_bytes.config = self._read_config() - self.header, self.raw_bytes.header = self._read_header() + self.config = config if config is not None else self._read_config() + self.header = header if header is not None else self._read_header() self.filters = [] + self.__closed = False def __next__(self): return self.next() @@ -63,6 +72,7 @@ def __enter__(self): return self def __exit__(self, *exception_info): + self.close() self.socket.close() def __iter__(self): @@ -96,7 +106,7 @@ def add_reader(self, mid, reader, *args, **kwargs): def add_writer(self, accepts, writer, *args, **kwargs): """ Add a writer to the connection's writers. - :param accepts: Predicate used to determine if a writer accepts an item. + aparam accepts: Predicate used to determine if a writer accepts an item. :param writer: Writer function to be called when `accepts` predicate returned truthy value. :param args: Additional arguments. These are forwarded to the writer when it's called. :param kwargs: Additional keyword-arguments. These are forwarded to the writer when it's called. 
@@ -161,6 +171,12 @@ def next(self): return mid, item + def close(self): + if not self.__closed: + end = constants.GadgetMessageIdentifier.pack(constants.GADGET_MESSAGE_CLOSE) + self.socket.write(end) + self.__closed = True + def _read_item(self): message_identifier = self._read_message_identifier() @@ -176,47 +192,57 @@ def _read_message_identifier(self): def _read_config(self): message_identifier = self._read_message_identifier() - assert(message_identifier == constants.GADGET_MESSAGE_CONFIG) + assert (message_identifier == constants.GADGET_MESSAGE_CONFIG) config_bytes = read_byte_string(self.socket) try: - parsed_config = xml.fromstring(config_bytes) + parsed_config = xml.fromstring(config_bytes) except xml.ParseError as e: - logging.log(logging.WARN,"Config parsing failed with error message {}".format(e)) - parsed_config = None + logging.warning(f"Config parsing failed with error message {e}") + parsed_config = None + + return parsed_config - return parsed_config, config_bytes + def _write_config(self, config): + serialization.write(self.socket, constants.GADGET_MESSAGE_CONFIG, constants.GadgetMessageIdentifier) + write_byte_string(self.socket, xml.tostring(config, encoding='utf-8', method='xml')) def _read_header(self): message_identifier = self._read_message_identifier() - assert(message_identifier == constants.GADGET_MESSAGE_HEADER) + assert (message_identifier == constants.GADGET_MESSAGE_HEADER) header_bytes = read_byte_string(self.socket) - return ismrmrd.xsd.CreateFromDocument(header_bytes), header_bytes + return ismrmrd.xsd.CreateFromDocument(header_bytes) - @ staticmethod + def _write_header(self, header: ismrmrd.xsd.ismrmrdHeader): + serialization.write(self.socket, constants.GADGET_MESSAGE_HEADER, constants.GadgetMessageIdentifier) + write_byte_string(self.socket, ismrmrd.xsd.ToXML(header).encode('utf-8')) + + @staticmethod def _default_readers(): return { constants.GADGET_MESSAGE_CLOSE: Connection.stop_iteration, - constants.GADGET_MESSAGE_ISMRMRD_ACQUISITION: read_acquisition, - constants.GADGET_MESSAGE_ISMRMRD_WAVEFORM: read_waveform, - constants.GADGET_MESSAGE_ISMRMRD_IMAGE: read_image, - constants.GADGET_MESSAGE_IMAGE_ARRAY: read_image_array, - constants.GADGET_MESSAGE_RECON_DATA: read_recon_data, - constants.GADGET_MESSAGE_BUCKET: read_acquisition_bucket + constants.GADGET_MESSAGE_ISMRMRD_ACQUISITION: message_reader(ismrmrd.Acquisition), + constants.GADGET_MESSAGE_ISMRMRD_WAVEFORM: message_reader(ismrmrd.Waveform), + constants.GADGET_MESSAGE_ISMRMRD_IMAGE: message_reader(ismrmrd.Image), + constants.GADGET_MESSAGE_IMAGE_ARRAY: message_reader(ImageArray), + constants.GADGET_MESSAGE_RECON_DATA: message_reader(ReconData), + constants.GADGET_MESSAGE_BUCKET: message_reader(AcquisitionBucket) } - @ staticmethod + @staticmethod def _default_writers(): + def create_writer(message_id, obj_type): + return lambda item: isinstance(item, obj_type), message_writer(message_id, obj_type) + return [ - (lambda item: isinstance(item, ismrmrd.Acquisition), write_acquisition), - (lambda item: isinstance(item, ismrmrd.Waveform), write_waveform), - (lambda item: isinstance(item, ismrmrd.Image), write_image), - (lambda item: isinstance(item, ImageArray), write_image_array), - (lambda item: isinstance(item, ReconData), write_recon_data) + create_writer(constants.GADGET_MESSAGE_ISMRMRD_ACQUISITION, ismrmrd.Acquisition), + create_writer(constants.GADGET_MESSAGE_ISMRMRD_WAVEFORM, ismrmrd.Waveform), + create_writer(constants.GADGET_MESSAGE_ISMRMRD_IMAGE, ismrmrd.Image), + 
create_writer(constants.GADGET_MESSAGE_IMAGE_ARRAY, ImageArray), + create_writer(constants.GADGET_MESSAGE_RECON_DATA, ReconData) ] - @ staticmethod + @staticmethod def stop_iteration(_): logging.debug("Connection closed normally.") raise StopIteration() - diff --git a/gadgetron/external/listen.py b/gadgetron/external/listen.py index 7ad90f0..b58cbd7 100644 --- a/gadgetron/external/listen.py +++ b/gadgetron/external/listen.py @@ -1,11 +1,13 @@ -import logging - +import os import socket +import logging from . import connection + + def wait_for_client_connection(port): sock = socket.socket(family=socket.AF_INET6) @@ -13,6 +15,7 @@ def wait_for_client_connection(port): sock.bind(('', port)) sock.listen(0) conn, address = sock.accept() + sock.close() logging.info(f"Accepted connection from client: {address}") @@ -30,5 +33,7 @@ def listen(port, handler, *args, **kwargs): logging.debug(f"Starting external Python module '{handler.__name__}' in state: [PASSIVE]") logging.debug(f"Waiting for connection from client on port: {port}") - with connection.Connection(wait_for_client_connection(port)) as conn: + storage_address = kwargs.get('storage_address', os.environ.get("GADGETRON_STORAGE_ADDRESS", None)) + + with connection.Connection(wait_for_client_connection(port), storage_address) as conn: handler(conn, *args, **kwargs) diff --git a/gadgetron/external/readers.py b/gadgetron/external/readers.py index a4bd1b2..46e3468 100644 --- a/gadgetron/external/readers.py +++ b/gadgetron/external/readers.py @@ -1,4 +1,6 @@ import ctypes +import numpy as np + import ismrmrd import functools @@ -7,46 +9,95 @@ from . import constants +from ..types.serialization import reader, read, isstruct, isgeneric, NDArray, inheritsfrom, isoptional, Vector +from typing import get_args, List, Optional +import dataclasses + + +@reader(predicate=inheritsfrom(np.number)) +def read_numpy_number(source, class_type): + dtype = np.dtype(class_type) + return np.frombuffer(source.read(dtype.itemsize), dtype=dtype).item() + + +@reader(predicate=dataclasses.is_dataclass) +def read_dataclass(source, class_type): + return class_type(*(read(source, dim.type) for dim in dataclasses.fields(class_type))) + + +@reader(predicate=isstruct) +def read_struct(source, struct_type): + return struct_type.unpack(source.read(struct_type.size))[0] -def read(source, type): - return type.unpack(source.read(type.size))[0] +@reader(predicate=inheritsfrom(ctypes.Structure)) +def read_cstruct(source, obj_type): + return obj_type.from_buffer_copy(source.read(ctypes.sizeof(obj_type))) -def read_optional(source, continuation, *args, **kwargs): - is_present = read(source, constants.bool) - return continuation(source, *args, **kwargs) if is_present else None +@reader(predicate=isgeneric(set)) +def read_set(source, obj_type): + return set(read_list(source, obj_type)) -def read_vector(source, numpy_type=numpy.uint64): - size = read(source, constants.uint64) - dtype = numpy.dtype(numpy_type) - return numpy.frombuffer(source.read(size * dtype.itemsize), dtype=dtype) +@reader(predicate=isgeneric(Vector)) +def read_vector(source, obj_type): + subtype = get_args(obj_type)[0] + size = read_struct(source, constants.uint64) + dtype = np.dtype(subtype) + if dtype == object or not dtype.isbuiltin or subtype == str: + return np.array([read(source, subtype) for s in range(size)],dtype=object) + else: + return np.frombuffer(source.read(size * dtype.itemsize), dtype=dtype) -def read_array(source, numpy_type=numpy.uint64): - dtype = numpy.dtype(numpy_type) - dimensions = 
read_vector(source) - elements = int(functools.reduce(lambda a, b: a * b, dimensions)) - return numpy.reshape(numpy.frombuffer(source.read(elements * dtype.itemsize), dtype=dtype), dimensions, order='F') +@reader(predicate=isgeneric(list)) +def read_list(source, obj_type): + subtype = get_args(obj_type)[0] + size = read_struct(source, constants.uint64) + dtype = np.dtype(subtype) -def read_object_array(source, read_object): - dimensions = read_vector(source) - elements = int(functools.reduce(lambda a, b: a * b, dimensions)) - return numpy.reshape(numpy.asarray([read_object(source) for _ in range(elements)], dtype=object), dimensions, - order='F') + if dtype == object or not dtype.isbuiltin or subtype == str: + return [read(source, subtype) for s in range(size)] + else: + return list(np.frombuffer(source.read(size * dtype.itemsize), dtype=dtype)) +@reader(predicate=isgeneric(NDArray)) +def read_array(source, obj_type): + subtype = get_args(obj_type)[0] + dtype = np.dtype(subtype) + dimensions = read(source, List[np.uint64]) + elements = np.prod(dimensions) + if dtype == object or not dtype.isbuiltin: + + return np.reshape(np.asarray([read(source, subtype) for _ in range(elements)], dtype=object), dimensions, + order='F') + else: + return np.reshape(np.frombuffer(source.read(int(elements) * dtype.itemsize), dtype=dtype), dimensions, + order='F') + + +@reader(predicate=isoptional) +def read_optional(source, obj_type): + subtype = get_args(obj_type)[0] + is_present = read_struct(source, constants.bool) + return read(source, subtype) if is_present else None + + +@reader(ismrmrd.ImageHeader) def read_image_header(source): header_bytes = source.read(ctypes.sizeof(ismrmrd.ImageHeader)) return ismrmrd.ImageHeader.from_buffer_copy(header_bytes) +@reader(ismrmrd.AcquisitionHeader) def read_acquisition_header(source): header_bytes = source.read(ctypes.sizeof(ismrmrd.AcquisitionHeader)) return ismrmrd.AcquisitionHeader.from_buffer_copy(header_bytes) +@reader(ismrmrd.WaveformHeader) def read_waveform_header(source): header_bytes = source.read(ctypes.sizeof(ismrmrd.WaveformHeader)) return ismrmrd.Waveform.from_buffer_copy(header_bytes) @@ -62,13 +113,21 @@ def read_byte_string(source, type=constants.uint32): return byte_string +@reader(ismrmrd.Acquisition) def read_acquisition(source): return ismrmrd.Acquisition.deserialize_from(source.read) +@reader(ismrmrd.Waveform) def read_waveform(source): return ismrmrd.Waveform.deserialize_from(source.read) +@reader(data_type=ismrmrd.Image) def read_image(source): return ismrmrd.Image.deserialize_from(source.read) + + +@reader(str) +def read_str(source): + return read_byte_string(source, constants.uint64).rstrip(b'\0').decode('utf-8') diff --git a/gadgetron/external/writers.py b/gadgetron/external/writers.py index e4f58b1..43dffbf 100644 --- a/gadgetron/external/writers.py +++ b/gadgetron/external/writers.py @@ -1,40 +1,91 @@ +import ctypes +import ismrmrd import numpy as np -from ..external import constants +from . 
import constants +from ..types.serialization import writer, NDArray, isstruct, isgeneric, write, inheritsfrom, isoptional, Vector +from typing import Optional, List, get_args +import dataclasses -def write_optional(destination, optional, continuation, *args, **kwargs): +@writer(predicate=inheritsfrom(np.number)) +def write_numpy_number(source, number, num_type): + source.write(num_type(number).tobytes()) + + +@writer(predicate=dataclasses.is_dataclass) +def write_dataclass(source, dataclass_obj, class_type): + for dim in dataclasses.fields(class_type): + write(source, getattr(dataclass_obj, dim.name), dim.type) + + +@writer(predicate=isstruct) +def write_struct(destination, value, struct_type): + destination.write(struct_type.pack(value)) + + +@writer(predicate=inheritsfrom(ctypes.Structure)) +def write_cstruct(destination, value, obj_type): + destination.write(value) + + +@writer(predicate=isoptional) +def write_optional(destination, optional, obj_type): + subtype = get_args(obj_type)[0] if optional is None: destination.write(constants.bool.pack(False)) else: destination.write(constants.bool.pack(True)) - continuation(destination, optional, *args, **kwargs) + write(destination, optional, subtype) -def write_vector(destination, values, type=constants.uint64): +@writer(predicate=isgeneric(set)) +def write_set(destination, values, obj_type): + subtype = get_args(obj_type)[0] destination.write(constants.uint64.pack(len(values))) for val in values: - destination.write(type.pack(val)) + write(destination, val, subtype) -def write_array(destination, array, dtype): - write_vector(destination, array.shape) - array_view = np.array(array,dtype=dtype,copy=False) - destination.write(array_view.tobytes(order='F')) +@writer(predicate=isgeneric(Vector)) +def write_vector(destination, values, obj_type): + subtype = get_args(obj_type)[0] + destination.write(constants.uint64.pack(len(values))) + __writer_array_content(destination,values,subtype) + +@writer(predicate=isgeneric(list)) +def write_list(destination, values, obj_type): + subtype = get_args(obj_type)[0] + destination.write(constants.uint64.pack(len(values))) + for val in values: + write(destination, val, subtype) -def write_object_array(destination, array, writer, *args, **kwargs): - write_vector(destination, array.shape) - for item in np.nditer(array, ('refs_ok', 'zerosize_ok'), order='F'): - item = item.item() # Get rid of the numpy 0-dimensional array. - writer(destination, item, *args, **kwargs) +def __writer_array_content(destination, array, data_type): + dtype = np.dtype(data_type) + if dtype == object or not dtype.isbuiltin: + for item in np.nditer(array, ('refs_ok', 'zerosize_ok'), order='F'): + item = item.item() # Get rid of the numpy 0-dimensional array. 
+ write(destination, item, data_type) + else: + array_view = np.array(array, dtype=dtype, copy=False) + destination.write(array_view.tobytes(order='F')) + +@writer(predicate=isgeneric(NDArray)) +def write_array(destination, array, array_type): + write_list(destination, array.shape, List[np.uint64]) + subtype = get_args(array_type)[0] + __writer_array_content(destination,array,subtype) + +@writer(ismrmrd.AcquisitionHeader) def write_acquisition_header(destination, header): destination.write(header) +@writer(ismrmrd.ImageHeader) def write_image_header(destination, header): destination.write(header) @@ -44,19 +95,20 @@ def write_byte_string(destination, byte_string, type=constants.uint32): destination.write(byte_string) +@writer(ismrmrd.Acquisition) def write_acquisition(destination, acquisition): - message_id_bytes = constants.GadgetMessageIdentifier.pack(constants.GADGET_MESSAGE_ISMRMRD_ACQUISITION) - destination.write(message_id_bytes) acquisition.serialize_into(destination.write) +@writer(ismrmrd.Waveform) def write_waveform(destination, waveform): - message_id_bytes = constants.GadgetMessageIdentifier.pack(constants.GADGET_MESSAGE_ISMRMRD_WAVEFORM) - destination.write(message_id_bytes) waveform.serialize_into(destination.write) +@writer(ismrmrd.Image) def write_image(destination, image): - message_id_bytes = constants.GadgetMessageIdentifier.pack(constants.GADGET_MESSAGE_ISMRMRD_IMAGE) - destination.write(message_id_bytes) image.serialize_into(destination.write) + +@writer(str) +def write_str(destination, string : str ): + write_byte_string(destination, string.encode("utf-8"), type=constants.uint64) \ No newline at end of file diff --git a/gadgetron/legacy/gadget.py b/gadgetron/legacy/gadget.py index 6c207f2..7fc6987 100644 --- a/gadgetron/legacy/gadget.py +++ b/gadgetron/legacy/gadget.py @@ -67,7 +67,7 @@ def __init__(self, *args, **kwargs): def handle(self, connection): self.connection = connection self.params = _parse_params(connection.config) - self.process_config(connection.raw_bytes.header) + self.process_config(connection.header.toXML()) def invoke_process(process, args): if not args: diff --git a/gadgetron/types/__init__.py b/gadgetron/types/__init__.py index 8c69a96..b81f703 100644 --- a/gadgetron/types/__init__.py +++ b/gadgetron/types/__init__.py @@ -2,5 +2,4 @@ from .image_array import ImageArray from .acquisition_bucket import AcquisitionBucket from .recon_data import ReconData - __all__ = [ImageArray] diff --git a/gadgetron/types/acquisition_bucket.py b/gadgetron/types/acquisition_bucket.py index 38a2226..845d756 100644 --- a/gadgetron/types/acquisition_bucket.py +++ b/gadgetron/types/acquisition_bucket.py @@ -1,29 +1,27 @@ - import ctypes -import logging +import ismrmrd import numpy as np from ismrmrd import Acquisition, Waveform -from ..external.constants import uint64 -from gadgetron.external.readers import read, read_acquisition_header, read_vector, read_waveform_header -from gadgetron.external.writers import write_optional, write_array, write_object_array, write_acquisition_header +from gadgetron.external.constants import uint64 +from gadgetron.types.serialization import NDArray, read, reader +from typing import Optional, List, Set +from dataclasses import dataclass, field +@dataclass class AcquisitionBucketStats: - - def __init__(self, kspace_encode_step_1={}, kspace_encode_step_2={}, slice={}, phase={}, contrast={}, repetition={}, - set={}, segment={}, average={}): - self.kspace_encode_step_1 = kspace_encode_step_1 - self.kspace_encode_step_2 = kspace_encode_step_2 - 
self.contrast = contrast - self.slice = slice - self.phase = phase - self.repetition = repetition - self.segment = segment - self.average = average - self.set = set + kspace_encode_step_1: Set[np.uint16] = field(default_factory=set) + kspace_encode_step_2: Set[np.uint16] = field(default_factory=set) + contrast: Set[np.uint16] = field(default_factory=set) + slice: Set[np.uint16] = field(default_factory=set) + phase: Set[np.uint16] = field(default_factory=set) + repetition: Set[np.uint16] = field(default_factory=set) + segment: Set[np.uint16] = field(default_factory=set) + average: Set[np.uint16] = field(default_factory=set) + set: Set[np.uint16] = field(default_factory=set) class AcquisitionBucket: @@ -69,48 +67,40 @@ class bucket_meta(ctypes.Structure): ] -def read_bucketstats(source): - count = read(source, uint64) - return [AcquisitionBucketStats(*[{s for s in read_vector(source, np.uint16)} - for _ in range(9)]) - for _ in range(count)] - - -def read_waveforms(source, sizes): - headers = [read_waveform_header(source) for _ in range(sizes.count)] - data_arrays = [read_data_as_array(source, np.uint32, (header.channels, header.number_of_samples)) +def __read_waveforms(source, sizes): + headers = [read(source, ismrmrd.WaveformHeader) for _ in range(sizes.count)] + data_arrays = [__read_data_as_array(source, np.uint32, (header.channels, header.number_of_samples)) for header in headers] return [Waveform(head, data) for head, data in zip(headers, data_arrays)] -def read_data_as_array(source, data_type, shape): +def __read_data_as_array(source, data_type, shape): dtype = np.dtype(data_type) bytesize = np.prod(shape) * dtype.itemsize return np.reshape(np.frombuffer(source.read(bytesize), dtype), shape) -def read_acquisitions(source, sizes): - headers = [read_acquisition_header(source) for _ in range(sizes.count)] +def __read_acquisitions(source, sizes): + headers = [read(source, ismrmrd.AcquisitionHeader) for _ in range(sizes.count)] - trajectories = [read_data_as_array(source, np.float32, (head.number_of_samples, head.trajectory_dimensions)) + trajectories = [__read_data_as_array(source, np.float32, (head.number_of_samples, head.trajectory_dimensions)) if head.trajectory_dimensions > 0 else None for head in headers] - acqs = [read_data_as_array(source, np.complex64, (head.active_channels, head.number_of_samples)) + acqs = [__read_data_as_array(source, np.complex64, (head.active_channels, head.number_of_samples)) for head in headers] return [Acquisition(header, data, trajectory) for header, data, trajectory in zip(headers, acqs, trajectories)] +@reader(AcquisitionBucket) def read_acquisition_bucket(source): meta = bucket_meta.from_buffer_copy(source.read(ctypes.sizeof(bucket_meta))) return AcquisitionBucket( - read_acquisitions(source, meta.data), - read_bucketstats(source), - read_acquisitions(source, meta.reference), - read_bucketstats(source), - read_waveforms(source, meta.waveforms) + __read_acquisitions(source, meta.data), + read(source, List[AcquisitionBucketStats]), + __read_acquisitions(source, meta.reference), + read(source, List[AcquisitionBucketStats]), + __read_waveforms(source, meta.waveforms) ) - - diff --git a/gadgetron/types/image_array.py b/gadgetron/types/image_array.py index b043068..381be39 100644 --- a/gadgetron/types/image_array.py +++ b/gadgetron/types/image_array.py @@ -1,69 +1,18 @@ +import ismrmrd import numpy as np -from ..external import readers +from ..external import readers, constants from ..external import writers -from ..external import constants +import 
dataclasses +from ..types.serialization import NDArray, Vector +from typing import List, Optional +@dataclasses.dataclass class ImageArray: - def __init__(self, data=None, headers=None, meta=None, waveform=None, acq_headers=None): - self.data = data - self.headers = headers - self.meta = meta - self.waveform = waveform - self.acq_headers = acq_headers + data: NDArray[np.complex64] = np.zeros(0,dtype=np.complex64) + headers: NDArray[ismrmrd.ImageHeader] = np.array([],dtype=object) + meta: List[str] = dataclasses.field(default_factory=list) + waveform: Optional[Vector[ismrmrd.Waveform]] = None + acq_headers: Optional[Vector[ismrmrd.AcquisitionHeader]] = None - - -def read_meta_container(source): - return readers.read_byte_string(source, constants.uint64).decode('ascii') - - -def read_meta_container_vector(source): - size = readers.read(source, constants.uint64) - return [read_meta_container(source) for _ in range(size)] - - -def read_waveforms(source): - size = readers.read(source, constants.uint64) - return [readers.read_waveform(source) for _ in range(size)] - - -def read_image_array(source): - return ImageArray( - data=readers.read_array(source, np.complex64), - headers=readers.read_object_array(source, readers.read_image_header), - meta=read_meta_container_vector(source), - waveform=readers.read_optional(source, read_waveforms), - acq_headers=readers.read_optional( - source, readers.read_object_array, readers.read_acquisition_header) - ) - - -def write_meta_container(destination, container): - writers.write_byte_string( - destination, container.encode('ascii'), constants.uint64) - - -def write_meta_container_vector(destination, containers): - destination.write(constants.uint64.pack(len(containers))) - for container in containers: - write_meta_container(destination, container) - - -def write_waveforms(destination, waveforms): - destination.write(constants.uint64.pack(len(waveforms))) - for waveform in waveforms: - writers.write_waveform(destination, waveform) - - -def write_image_array(destination, image_array): - destination.write(constants.GadgetMessageIdentifier.pack( - constants.GADGET_MESSAGE_IMAGE_ARRAY)) - writers.write_array(destination, image_array.data, np.complex64) - writers.write_object_array( - destination, image_array.headers, writers.write_image_header) - write_meta_container_vector(destination, image_array.meta) - writers.write_optional(destination, image_array.waveform, write_waveforms) - writers.write_optional(destination, image_array.acq_headers, - writers.write_object_array, writers.write_acquisition_header) diff --git a/gadgetron/types/recon_data.py b/gadgetron/types/recon_data.py index 8858a2c..477e0c3 100644 --- a/gadgetron/types/recon_data.py +++ b/gadgetron/types/recon_data.py @@ -1,11 +1,13 @@ - -import numpy +import ismrmrd +import numpy as np import struct import ctypes -from gadgetron.external.readers import read, read_optional, read_array, read_object_array, read_acquisition_header -from gadgetron.external.writers import write_optional, write_array, write_object_array, write_acquisition_header from gadgetron.external.constants import uint64, GadgetMessageIdentifier, GADGET_MESSAGE_RECON_DATA +from gadgetron.types.serialization import NDArray +import dataclasses +from typing import Optional, List + uint16 = struct.Struct('=1.15.1', 'ismrmrd>=1.6.2', 'pyFFTW>=0.11', - 'multimethod >= 1.0' + 'multimethod>=1.0', + 'requests>=2.24' ] ) diff --git a/test/random_data.py b/test/random_data.py new file mode 100644 index 0000000..e441b08 --- /dev/null +++ 
b/test/random_data.py @@ -0,0 +1,130 @@ + +import ismrmrd + +import numpy as np +import numpy.random + + + +def random_32bit_float(): + return numpy.random.rand(1).astype(np.float32) + +def random_int(dtype,size=None): + dinfo = np.iinfo(dtype) + return numpy.random.randint(dinfo.min, dinfo.max,dtype=dtype,size=size) + +def random_tuple(size, random_fn): + return tuple([random_fn() for _ in range(0, size)]) + + +def create_random_acquisition_properties(): + return { + 'flags': np.random.randint(0, 1 << 64,dtype=np.uint64), + 'measurement_uid': np.random.randint(0, 1 << 32,dtype=np.uint32), + 'scan_counter': np.random.randint(0, 1 << 32,dtype=np.uint32 ), + 'acquisition_time_stamp': random_int(np.uint32), + 'physiology_time_stamp': random_tuple(3, lambda: random_int(np.uint32)), + 'available_channels': random_int(np.uint16), + 'channel_mask': random_tuple(16, lambda: np.random.randint(0, 1 << 64,dtype=np.uint64)), + 'discard_pre': random_int(np.uint16), + 'discard_post': random_int(np.uint16), + 'center_sample': random_int(np.uint16), + 'encoding_space_ref': random_int(np.uint16), + 'sample_time_us': random_32bit_float(), + 'position': random_tuple(3, random_32bit_float), + 'read_dir': random_tuple(3, random_32bit_float), + 'phase_dir': random_tuple(3, random_32bit_float), + 'slice_dir': random_tuple(3, random_32bit_float), + 'patient_table_position': random_tuple(3, random_32bit_float), + 'idx': ismrmrd.EncodingCounters(), + 'user_int': random_tuple(8, lambda: random_int(np.int32)), + 'user_float': random_tuple(8, random_32bit_float) + } + + +def create_random_image_properties(): + return { + 'flags': random_int(np.uint64), + 'measurement_uid': random_int(np.uint32), + 'field_of_view': random_tuple(3, random_32bit_float), + 'position': random_tuple(3, random_32bit_float), + 'read_dir': random_tuple(3, random_32bit_float), + 'phase_dir': random_tuple(3, random_32bit_float), + 'slice_dir': random_tuple(3, random_32bit_float), + 'patient_table_position': random_tuple(3, random_32bit_float), + 'average': random_int(np.uint16), + 'slice': random_int(np.uint16), + 'contrast': random_int(np.uint16), + 'phase': random_int(np.uint16), + 'repetition': random_int(np.uint16), + 'set': random_int(np.uint16), + 'acquisition_time_stamp': random_int(np.uint32), + 'physiology_time_stamp': random_tuple(3, lambda: random_int(np.uint32)), + 'image_index': random_int(np.uint16), + 'image_series_index': random_int(np.uint16), + 'user_int': random_tuple(8, lambda: random_int(np.int32)), + 'user_float': random_tuple(8, random_32bit_float), + } + + +def create_random_waveform_properties(): + return { + 'flags': random_int(np.uint64), + 'measurement_uid': random_int(np.uint32), + 'waveform_id': random_int(np.uint16), + 'scan_counter': random_int(np.uint32), + 'time_stamp': random_int(np.uint32), + 'sample_time_us': random_32bit_float() + } + + +def create_random_array(shape, dtype): + array = numpy.random.random_sample(shape) + return array.astype(dtype) + + +def create_random_data(shape=(32, 256)): + array = numpy.random.random_sample(shape) + 1j * numpy.random.random_sample(shape) + return array.astype(np.complex64) + + +def create_random_trajectory(shape=(256, 2)): + return create_random_array(shape, dtype=np.float32) + + +def create_random_waveform_data(shape=(32, 256)): + data = numpy.np.random.randint(0, 1 << 32, size=shape) + return data.astype(np.uint32) + + +def create_random_acquisition(): + data = create_random_data((32, 256)) + traj = create_random_trajectory((256, 2)) + header = 
create_random_acquisition_properties()
+
+    return ismrmrd.Acquisition.from_array(data, traj, **header)
+
+
+def create_random_image():
+
+    data = create_random_array((256, 256), dtype=np.float32)
+    header = create_random_image_properties()
+
+    image = ismrmrd.Image.from_array(data, **header)
+    image.meta = {f"Random_{i}" : random_int(np.int64) for i in range(10)}
+
+    return image
+
+def create_random_image_header():
+    return ismrmrd.ImageHeader(**create_random_image_properties())
+
+def create_random_acquisition_header():
+    return ismrmrd.AcquisitionHeader(**create_random_acquisition_properties())
+
+
+def create_random_waveform():
+
+    data = random_int(np.uint32, size=(4, 256))
+    header = create_random_waveform_properties()
+
+    return ismrmrd.Waveform.from_array(data, **header)
diff --git a/test/test_connection.py b/test/test_connection.py
new file mode 100644
index 0000000..ffc0db8
--- /dev/null
+++ b/test/test_connection.py
@@ -0,0 +1,130 @@
+
+import socket
+from ismrmrd.xsd import ismrmrdHeader
+import ismrmrd
+import gadgetron
+import numpy as np
+from gadgetron.external import Connection
+import concurrent.futures as cf
+import xml.etree.ElementTree as xml
+import random_data as rd
+
+def client(socket, testdata):
+    config = xml.fromstring("<_/>")
+    header = ismrmrd.xsd.CreateFromDocument(sample_header)
+    with Connection.initiate_connection(socket, config, header) as connection:
+        connection.send(testdata)
+        connection.close()
+        id, item = next(connection)
+
+        assert item == testdata
+
+def parrot_server(socket):
+    with Connection(socket) as connection:
+        for item in connection:
+            connection.send(item)
+
+def run_connection_test(testdata):
+    sock1, sock2 = socket.socketpair()
+    sock1.setblocking(True)
+    sock2.setblocking(True)
+    with cf.ProcessPoolExecutor(max_workers=2) as executor:
+        future1 = executor.submit(client, sock1, testdata)
+        future2 = executor.submit(parrot_server, sock2)
+        future1.result()
+        future2.result()
+
+def test_acquisitions():
+    acq = rd.create_random_acquisition()
+    run_connection_test(acq)
+
+def test_images():
+    img = rd.create_random_image()
+    run_connection_test(img)
+
+def test_waveforms():
+    wav = rd.create_random_waveform()
+    run_connection_test(wav)
+
+
+sample_header = '''<?xml version="1.0"?>
+<ismrmrdHeader xmlns="http://www.ismrm.org/ISMRMRD">
+    <subjectInformation>
+        <patientName>phantom</patientName>
+        <patientWeight_kg>70.3068</patientWeight_kg>
+    </subjectInformation>
+    <acquisitionSystemInformation>
+        <systemVendor>SIEMENS</systemVendor>
+        <systemModel>Avanto</systemModel>
+        <systemFieldStrength_T>1.494</systemFieldStrength_T>
+        <receiverChannels>32</receiverChannels>
+        <relativeReceiverNoiseBandwidth>0.79</relativeReceiverNoiseBandwidth>
+    </acquisitionSystemInformation>
+    <experimentalConditions>
+        <H1resonanceFrequency_Hz>63642459</H1resonanceFrequency_Hz>
+    </experimentalConditions>
+    <encoding>
+        <trajectory>cartesian</trajectory>
+        <encodedSpace>
+            <matrixSize>
+                <x>256</x>
+                <y>140</y>
+                <z>80</z>
+            </matrixSize>
+            <fieldOfView_mm>
+                <x>600</x>
+                <y>328.153125</y>
+                <z>160</z>
+            </fieldOfView_mm>
+        </encodedSpace>
+        <reconSpace>
+            <matrixSize>
+                <x>128</x>
+                <y>116</y>
+                <z>64</z>
+            </matrixSize>
+            <fieldOfView_mm>
+                <x>300</x>
+                <y>271.875</y>
+                <z>128</z>
+            </fieldOfView_mm>
+        </reconSpace>
+        <encodingLimits>
+            <kspace_encoding_step_1>
+                <minimum>0</minimum>
+                <maximum>83</maximum>
+                <center>28</center>
+            </kspace_encoding_step_1>
+            <kspace_encoding_step_2>
+                <minimum>0</minimum>
+                <maximum>45</maximum>
+                <center>20</center>
+            </kspace_encoding_step_2>
+            <slice>
+                <minimum>0</minimum>
+                <maximum>0</maximum>
+                <center>0</center>
+            </slice>
+            <set>
+                <minimum>0</minimum>
+                <maximum>0</maximum>
+                <center>0</center>
+            </set>
+        </encodingLimits>
+        <parallelImaging>
+            <accelerationFactor>
+                <kspace_encoding_step_1>1</kspace_encoding_step_1>
+                <kspace_encoding_step_2>1</kspace_encoding_step_2>
+            </accelerationFactor>
+            <calibrationMode>other</calibrationMode>
+        </parallelImaging>
+    </encoding>
+    <sequenceParameters>
+        <TR>4.6</TR>
+        <TE>2.35</TE>
+        <TI>300</TI>
+    </sequenceParameters>
+</ismrmrdHeader>
''' + diff --git a/test/types/test_types.py b/test/types/test_types.py new file mode 100644 index 0000000..c776f5c --- /dev/null +++ b/test/types/test_types.py @@ -0,0 +1,72 @@ +from gadgetron.types.serialization import NDArray +from gadgetron.types import serialization +from gadgetron.types.recon_data import SamplingDescription +from gadgetron.types.image_array import ImageArray + +import ismrmrd +import numpy as np +import dataclasses + +from typing import List, Optional + +from io import BytesIO + + +@dataclasses.dataclass +class SimpleTestClass: + data: NDArray[np.float32] + mdata: NDArray[np.float64] + headers: List[ismrmrd.AcquisitionHeader] + + +def roundtrip_serialization(obj, obj_type): + buffer = BytesIO() + serialization.write(buffer, obj, obj_type) + buffer.seek(0) + return serialization.read(buffer, obj_type) + + +def test_dataclass(): + a = SimpleTestClass(np.zeros((2, 2), dtype=np.float32), np.ones((1, 1, 3), dtype=np.float64), []) + b = roundtrip_serialization(a, SimpleTestClass) + + assert np.equal(a.data, b.data, casting='no').all() + assert np.equal(a.mdata, b.mdata, casting='no').all() + assert a.headers == b.headers + + +def test_acquisition(): + data = np.array(np.random.normal(0, 10, size=(12, 128)), dtype=np.complex64) + traj = np.array(np.random.normal(0, 10, size=(128, 2)), dtype=np.float32) + + a = ismrmrd.Acquisition.from_array(data, traj) + + b = roundtrip_serialization(a, ismrmrd.Acquisition) + assert a == b + + +def test_optional(): + a = np.array(np.random.random((1, 2, 3)), dtype=np.complex128) + + b = roundtrip_serialization(a, Optional[NDArray[np.complex128]]) + + assert np.equal(a, b, casting='no').all() + + +def test_sampling_description(): + a = SamplingDescription() + b = roundtrip_serialization(a, SamplingDescription) + + for field in a._fields_: + assert (np.array(getattr(a, field[0])) == np.array(getattr(b, field[0]))).all() + + +def test_image_array(): + a = ImageArray() + a.acq_headers = np.array([ismrmrd.AcquisitionHeader() for k in range(20)], dtype=object) + + b = roundtrip_serialization(a, ImageArray) + + assert np.equal(a.data, b.data, casting='no').all() + assert np.equal(a.headers, b.headers, casting='no').all() + assert a.meta == b.meta
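The external API exercised by these tests can also be driven as a stand-alone module. The following is a minimal sketch and not part of the diff: the file name, handler name and port are placeholders, while gadgetron.external.listen, connection iteration and connection.send are the entry points shown in test_connection.py above.

# passthrough_example.py -- minimal sketch of an external Python module using
# the reworked gadgetron.external API; handler name and port are arbitrary.
import gadgetron


def passthrough(connection):
    # Iterating the connection yields items already deserialized by the
    # registered readers (Acquisition, Waveform, Image, ImageArray, ReconData,
    # AcquisitionBucket); send() picks the matching writer by item type.
    for item in connection:
        connection.send(item)


if __name__ == '__main__':
    # Wait for a Gadgetron client on an arbitrary port and hand the accepted
    # connection to the handler, mirroring parrot_server in the tests.
    gadgetron.external.listen(9002, passthrough)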