Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds tests and a dataclass based serialization framework #7

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Python package

on:
push:
branches: [ master ]
pull_request:
branches: [ master ]

jobs:
build:

runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10"]

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
python -m pip install .
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest
1 change: 0 additions & 1 deletion gadgetron/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

__all__ = [
util,
legacy,
external,
examples,
Gadget,
Expand Down
2 changes: 2 additions & 0 deletions gadgetron/external/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@

from .connection import Connection
from .listen import listen
from .readers import read
from .writers import write

__all__ = [Connection, listen]
106 changes: 66 additions & 40 deletions gadgetron/external/connection.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@

import socket
import logging
from socket import MSG_WAITALL

import xml.etree.ElementTree as xml

import ismrmrd

from . import constants

from .readers import read, read_byte_string, read_acquisition, read_waveform, read_image
from .writers import write_acquisition, write_waveform, write_image
from .readers import read, read_byte_string
from .writers import write, write_byte_string

from ..types.image_array import ImageArray
from ..types.recon_data import ReconData
from ..types.acquisition_bucket import AcquisitionBucket

from ..types.image_array import ImageArray, read_image_array, write_image_array
from ..types.recon_data import ReconData, read_recon_data, write_recon_data
from ..types.acquisition_bucket import read_acquisition_bucket
from ..types.serialization import message_reader, message_writer
from ..types import serialization


class Connection:
Expand All @@ -27,34 +29,41 @@ def __init__(self, socket):
self.socket.settimeout(None)

def read(self, nbytes):
bytes = self.socket.recv(nbytes, socket.MSG_WAITALL)
while len(bytes) < nbytes:
bytes += self.socket.recv(nbytes - len(bytes),socket.MSG_WAITALL)
return bytes
bytedata = self.socket.recv(nbytes, MSG_WAITALL)
while len(bytedata) < nbytes:
bytedata += self.socket.recv(nbytes - len(bytedata), MSG_WAITALL)
return bytedata

def write(self, byte_array):
self.socket.sendall(byte_array)

def close(self):
end = constants.GadgetMessageIdentifier.pack(constants.GADGET_MESSAGE_CLOSE)
self.socket.send(end)
self.socket.close()

class Struct:
def __init__(self, **fields):
self.__dict__.update(fields)

def __init__(self, socket):
@staticmethod
def initiate_connection(socket, config, header):

connection = Connection(socket,config,header)
connection._write_config(config)
connection._write_header(header)

return connection

def __init__(self, socket, config=None, header=None):
self.socket = Connection.SocketWrapper(socket)

self.readers = Connection._default_readers()
self.writers = Connection._default_writers()

self.raw_bytes = Connection.Struct(config=None, header=None)
self.config, self.raw_bytes.config = self._read_config()
self.header, self.raw_bytes.header = self._read_header()
self.config = config if config is not None else self._read_config()
self.header = header if header is not None else self._read_header()

self.filters = []
self.__closed = False

def __next__(self):
return self.next()
Expand All @@ -63,6 +72,7 @@ def __enter__(self):
return self

def __exit__(self, *exception_info):
self.close()
self.socket.close()

def __iter__(self):
Expand Down Expand Up @@ -96,7 +106,7 @@ def add_reader(self, mid, reader, *args, **kwargs):
def add_writer(self, accepts, writer, *args, **kwargs):
""" Add a writer to the connection's writers.

:param accepts: Predicate used to determine if a writer accepts an item.
aparam accepts: Predicate used to determine if a writer accepts an item.
:param writer: Writer function to be called when `accepts` predicate returned truthy value.
:param args: Additional arguments. These are forwarded to the writer when it's called.
:param kwargs: Additional keyword-arguments. These are forwarded to the writer when it's called.
Expand Down Expand Up @@ -161,6 +171,12 @@ def next(self):

return mid, item

def close(self):
if not self.__closed:
end = constants.GadgetMessageIdentifier.pack(constants.GADGET_MESSAGE_CLOSE)
self.socket.write(end)
self.__closed = True

def _read_item(self):
message_identifier = self._read_message_identifier()

Expand All @@ -176,47 +192,57 @@ def _read_message_identifier(self):

def _read_config(self):
message_identifier = self._read_message_identifier()
assert(message_identifier == constants.GADGET_MESSAGE_CONFIG)
assert (message_identifier == constants.GADGET_MESSAGE_CONFIG)
config_bytes = read_byte_string(self.socket)

try:
parsed_config = xml.fromstring(config_bytes)
parsed_config = xml.fromstring(config_bytes)
except xml.ParseError as e:
logging.log(logging.WARN,"Config parsing failed with error message {}".format(e))
parsed_config = None
logging.warning(f"Config parsing failed with error message {e}")
parsed_config = None

return parsed_config

return parsed_config, config_bytes
def _write_config(self, config):
serialization.write(self.socket, constants.GADGET_MESSAGE_CONFIG, constants.GadgetMessageIdentifier)
write_byte_string(self.socket, xml.tostring(config, encoding='utf-8', method='xml'))

def _read_header(self):
message_identifier = self._read_message_identifier()
assert(message_identifier == constants.GADGET_MESSAGE_HEADER)
assert (message_identifier == constants.GADGET_MESSAGE_HEADER)
header_bytes = read_byte_string(self.socket)
return ismrmrd.xsd.CreateFromDocument(header_bytes), header_bytes
return ismrmrd.xsd.CreateFromDocument(header_bytes)

@ staticmethod
def _write_header(self, header: ismrmrd.xsd.ismrmrdHeader):
serialization.write(self.socket, constants.GADGET_MESSAGE_HEADER, constants.GadgetMessageIdentifier)
write_byte_string(self.socket, ismrmrd.xsd.ToXML(header).encode('utf-8'))

@staticmethod
def _default_readers():
return {
constants.GADGET_MESSAGE_CLOSE: Connection.stop_iteration,
constants.GADGET_MESSAGE_ISMRMRD_ACQUISITION: read_acquisition,
constants.GADGET_MESSAGE_ISMRMRD_WAVEFORM: read_waveform,
constants.GADGET_MESSAGE_ISMRMRD_IMAGE: read_image,
constants.GADGET_MESSAGE_IMAGE_ARRAY: read_image_array,
constants.GADGET_MESSAGE_RECON_DATA: read_recon_data,
constants.GADGET_MESSAGE_BUCKET: read_acquisition_bucket
constants.GADGET_MESSAGE_ISMRMRD_ACQUISITION: message_reader(ismrmrd.Acquisition),
constants.GADGET_MESSAGE_ISMRMRD_WAVEFORM: message_reader(ismrmrd.Waveform),
constants.GADGET_MESSAGE_ISMRMRD_IMAGE: message_reader(ismrmrd.Image),
constants.GADGET_MESSAGE_IMAGE_ARRAY: message_reader(ImageArray),
constants.GADGET_MESSAGE_RECON_DATA: message_reader(ReconData),
constants.GADGET_MESSAGE_BUCKET: message_reader(AcquisitionBucket)
}

@ staticmethod
@staticmethod
def _default_writers():
def create_writer(message_id, obj_type):
return lambda item: isinstance(item, obj_type), message_writer(message_id, obj_type)

return [
(lambda item: isinstance(item, ismrmrd.Acquisition), write_acquisition),
(lambda item: isinstance(item, ismrmrd.Waveform), write_waveform),
(lambda item: isinstance(item, ismrmrd.Image), write_image),
(lambda item: isinstance(item, ImageArray), write_image_array),
(lambda item: isinstance(item, ReconData), write_recon_data)
create_writer(constants.GADGET_MESSAGE_ISMRMRD_ACQUISITION, ismrmrd.Acquisition),
create_writer(constants.GADGET_MESSAGE_ISMRMRD_WAVEFORM, ismrmrd.Waveform),
create_writer(constants.GADGET_MESSAGE_ISMRMRD_IMAGE, ismrmrd.Image),
create_writer(constants.GADGET_MESSAGE_IMAGE_ARRAY, ImageArray),
create_writer(constants.GADGET_MESSAGE_RECON_DATA, ReconData)
]

@ staticmethod
@staticmethod
def stop_iteration(_):
logging.debug("Connection closed normally.")
raise StopIteration()

11 changes: 8 additions & 3 deletions gadgetron/external/listen.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@

import logging

import os
import socket
import logging

from . import connection




def wait_for_client_connection(port):

sock = socket.socket(family=socket.AF_INET6)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(('', port))
sock.listen(0)
conn, address = sock.accept()
sock.close()

logging.info(f"Accepted connection from client: {address}")

Expand All @@ -30,5 +33,7 @@ def listen(port, handler, *args, **kwargs):
logging.debug(f"Starting external Python module '{handler.__name__}' in state: [PASSIVE]")
logging.debug(f"Waiting for connection from client on port: {port}")

with connection.Connection(wait_for_client_connection(port)) as conn:
storage_address = kwargs.get('storage_address', os.environ.get("GADGETRON_STORAGE_ADDRESS", None))

with connection.Connection(wait_for_client_connection(port), storage_address) as conn:
handler(conn, *args, **kwargs)
Loading