From 2cb41639d395aca69f3565cce958ffff9ad2807c Mon Sep 17 00:00:00 2001 From: Cameron Jackson Date: Fri, 8 Mar 2024 07:35:02 -0800 Subject: [PATCH] Issue #517 - TCP input support - Adds a TCP client and server option to input streams - Updates docunmentation with new input specifications - Adds client specific tests - Adds additional stream tests - Formatting --- .pre-commit-config.yaml | 4 +- ait/core/server/client.py | 212 +++++++++++++++++++++++-- ait/core/server/server.py | 121 +++++++------- ait/core/server/stream.py | 119 +++++++++++++- doc/source/server_architecture.rst | 39 ++++- tests/ait/core/server/test_client.py | 86 ++++++++++ tests/ait/core/server/test_stream.py | 228 ++++++++++++++++++++++----- 7 files changed, 694 insertions(+), 115 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ab8d5986..d2cbdaf6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: rev: v2.6.0 hooks: - id: reorder-python-imports - files: ^src/|test/ + files: ^ait/|tests/ - repo: local hooks: @@ -27,7 +27,7 @@ repos: - id: black name: black entry: black - files: ^src/|test/ + files: ^ait/|tests/ language: system types: [python] diff --git a/ait/core/server/client.py b/ait/core/server/client.py index 6c673caf..0dbbcb2f 100644 --- a/ait/core/server/client.py +++ b/ait/core/server/client.py @@ -1,7 +1,6 @@ -import gevent -import gevent.socket -import gevent.server as gs import gevent.monkey +import gevent.server as gs +import gevent.socket gevent.monkey.patch_all() @@ -27,13 +26,12 @@ def __init__( zmq_proxy_xpub_url=ait.SERVER_DEFAULT_XPUB_URL, **kwargs, ): - self.context = zmq_context # open PUB socket & connect to broker self.pub = self.context.socket(zmq.PUB) self.pub.connect(zmq_proxy_xsub_url.replace("*", "localhost")) - if 'listener' in kwargs and isinstance(kwargs['listener'], int) : - kwargs['listener'] = "127.0.0.1:"+str(kwargs['listener']) + if "listener" in kwargs and isinstance(kwargs["listener"], int): + kwargs["listener"] = "127.0.0.1:" + str(kwargs["listener"]) # calls gevent.Greenlet or gs.DatagramServer __init__ super(ZMQClient, self).__init__(**kwargs) @@ -89,7 +87,6 @@ def __init__( zmq_proxy_xpub_url=ait.SERVER_DEFAULT_XPUB_URL, **kwargs, ): - super(ZMQInputClient, self).__init__( zmq_context, zmq_proxy_xsub_url, zmq_proxy_xpub_url ) @@ -134,7 +131,6 @@ def __init__( zmq_proxy_xpub_url=ait.SERVER_DEFAULT_XPUB_URL, **kwargs, ): - super(PortOutputClient, self).__init__( zmq_context, zmq_proxy_xsub_url, zmq_proxy_xpub_url ) @@ -162,7 +158,6 @@ def __init__( zmq_proxy_xpub_url=ait.SERVER_DEFAULT_XPUB_URL, **kwargs, ): - if "input" in kwargs and type(kwargs["input"][0]) is int: super(PortInputClient, self).__init__( zmq_context, @@ -180,3 +175,202 @@ def handle(self, packet, address): # This function provided for gs.DatagramServer class log.debug("{} received message from port {}".format(self, address)) self.process(packet) + + +class TCPInputServer(ZMQClient, gs.StreamServer): + """ + This class is similar to PortInputClient except its TCP instead of UDP. + """ + + def __init__( + self, + zmq_context, + zmq_proxy_xsub_url=ait.SERVER_DEFAULT_XSUB_URL, + zmq_proxy_xpub_url=ait.SERVER_DEFAULT_XPUB_URL, + buffer=1024, + **kwargs, + ): + self.cur_socket = None + self.buffer = buffer + if "input" in kwargs: + if ( + type(kwargs["input"]) not in [tuple, list] + or kwargs["input"][0].lower() != "server" + or type(kwargs["input"][1]) != int + ): + raise ( + ValueError( + "TCPInputServer input must be tuple|list of (str,int) e.g. ('server',1234)" + ) + ) + + self.sub = gevent.socket.socket( + gevent.socket.AF_INET, gevent.socket.SOCK_STREAM + ) + super(TCPInputServer, self).__init__( + zmq_context, + zmq_proxy_xsub_url, + zmq_proxy_xpub_url, + listener=("127.0.0.1", kwargs["input"][1]), + ) + else: + raise ( + ValueError( + "TCPInputServer input must be tuple|list of (str,int) e.g. ('server',1234)" + ) + ) + + def handle(self, socket, address): + self.cur_socket = socket + with socket: + while True: + data = socket.recv(self.buffer) + if not data: + break + log.debug("{} received message from port {}".format(self, address)) + self.process(data) + + +class TCPInputClient(ZMQClient): + """ + This class creates a TCP input client. Unlike TCPInputServer and PortInputClient, + this class will proactively initiate a connection with an input source and begin + receiving data from that source. This class does not inherit directly from gevent + servers and thus implements its own housekeeping functions. It also implements a + start function that spawns a process to stay consistent with the behavior of + TCPInputServer and PortInputClient. + + """ + + def __init__( + self, + zmq_context, + zmq_proxy_xsub_url=ait.SERVER_DEFAULT_XSUB_URL, + zmq_proxy_xpub_url=ait.SERVER_DEFAULT_XPUB_URL, + connection_reattempts=5, + buffer=1024, + **kwargs, + ): + self.connection_reattempts = connection_reattempts + self.buffer = buffer + self.connection_status = -1 + self.proc = None + + if "buffer" in kwargs and type(kwargs["buffer"]) == int: + self.buffer = kwargs["buffer"] + + if "input" in kwargs: + super(TCPInputClient, self).__init__( + zmq_context, zmq_proxy_xsub_url, zmq_proxy_xpub_url + ) + if ( + type(kwargs["input"]) not in [tuple, list] + or type(kwargs["input"][0]) != str + or type(kwargs["input"][1]) != int + ): + raise ( + ValueError( + "TCPInputClient 'input' must be tuple|list of (str,int) e.g. ('127.0.0.1',1234)" + ) + ) + self.sub = gevent.socket.socket( + gevent.socket.AF_INET, gevent.socket.SOCK_STREAM + ) + + self.host = kwargs["input"][0] + self.port = kwargs["input"][1] + self.address = tuple(kwargs["input"]) + + else: + raise ( + ValueError( + "TCPInputClient 'input' must be tuple of (str,int) e.g. ('127.0.0.1',1234)" + ) + ) + + def __exit__(self): + try: + if self.sub: + self.sub.close() + if self.proc: + self.proc.kill() + except Exception as e: + log.error(e) + + def __del__(self): + try: + if self.sub: + self.sub.close() + if self.proc: + self.proc.kill() + except Exception as e: + log.error(e) + + def __repr__(self): + return "<%s at %s %s>" % ( + type(self).__name__, + hex(id(self)), + self._formatinfo(), + ) + + def __str__(self): + return "<%s %s>" % (type(self).__name__, self._formatinfo()) + + def start(self): + self.proc = gevent.spawn(self._client).join() + + def _connect(self): + while self.connection_reattempts: + try: + res = self.sub.connect_ex((self.host, self.port)) + if res == 0: + self.connection_reattempts = 5 + return res + else: + self.connection_reattempts -= 1 + gevent.sleep(1) + except Exception as e: + log.error(e) + self.connection_reattempts -= 1 + gevent.sleep(1) + + def _exit(self): + try: + if self.sub: + self.sub.close() + if self.proc: + self.proc.kill() + except Exception as e: + log.error(e) + + def _client(self): + self.connection_status = self._connect() + if self.connection_status != 0: + log.error( + f"Unable to connect to client: {self.address[0]}:{self.address[1]}" + ) + self._exit() + while True: + packet = self.sub.recv(self.buffer) + if not packet: + gevent.sleep(1) + log.info( + f"Trying to reconnect to client: {self.address[0]}:{self.address[1]}" + ) + if self._connect() != 0: + log.error( + f"Unable to connect to client: {self.address[0]}:{self.address[1]}" + ) + self._exit() + self.process(packet) + + def _formatinfo(self): + result = "" + try: + if isinstance(self.address, tuple) and len(self.address) == 2: + result += "address=%s:%s" % self.address + else: + result += "address=%s" % (self.address,) + except Exception as ex: + result += str(ex) or "" + return result diff --git a/ait/core/server/server.py b/ait/core/server/server.py index 29bb8e6d..cd1e47da 100644 --- a/ait/core/server/server.py +++ b/ait/core/server/server.py @@ -1,17 +1,24 @@ -import gevent -import gevent.monkey - -from importlib import import_module import sys import traceback +from importlib import import_module + +import gevent.monkey import ait.core.server -from .stream import PortInputStream, ZMQStream, PortOutputStream -from .config import ZmqConfig from .broker import Broker -from .plugin import PluginType, Plugin, PluginConfig +from .config import ZmqConfig +from .plugin import Plugin +from .plugin import PluginConfig +from .plugin import PluginType from .process import PluginsProcess -from ait.core import log, cfg +from .stream import input_stream_factory +from .stream import PortInputStream +from .stream import PortOutputStream +from .stream import TCPInputClientStream +from .stream import TCPInputServerStream +from .stream import ZMQStream +from ait.core import cfg +from ait.core import log gevent.monkey.patch_all() @@ -121,7 +128,11 @@ def _load_streams(self): try: if stream_type == "inbound": strm = self._create_inbound_stream(s["stream"]) - if type(strm) == PortInputStream: + if ( + type(strm) == PortInputStream + or type(strm) == TCPInputClientStream + or type(strm) == TCPInputServerStream + ): self.servers.append(strm) else: self.inbound_streams.append(strm) @@ -131,8 +142,10 @@ def _load_streams(self): log.info(f"Added {stream_type} stream {strm}") except Exception: exc_type, value, tb = sys.exc_info() - log.error(f"{exc_type} creating {stream_type} stream " - f"{index}: {value}") + log.error( + f"{exc_type} creating {stream_type} stream " + f"{index}: {value}" + ) if not self.inbound_streams and not self.servers: log.warn(err_msgs["inbound"]) @@ -222,8 +235,10 @@ def _get_stream_name(self, config): + self.plugins ) ]: - raise ValueError(f"Duplicate stream name '{name}' encountered. " - "Stream names must be unique.") + raise ValueError( + f"Duplicate stream name '{name}' encountered. " + "Stream names must be unique." + ) return name @@ -234,8 +249,9 @@ def _get_stream_handlers(self, config, name): for handler in config["handlers"]: hndlr = self._create_handler(handler) stream_handlers.append(hndlr) - log.info(f"Created handler {type(hndlr).__name__} for " - f"stream {name}") + log.info( + f"Created handler {type(hndlr).__name__} for " f"stream {name}" + ) else: log.warn(f"No handlers specified for stream {name}") @@ -264,20 +280,12 @@ def _create_inbound_stream(self, config=None): # Create ZMQ args re-using the Broker's context zmq_args_dict = self._create_zmq_args(True) - if type(stream_input[0]) is int: - return PortInputStream( - name, - stream_input, - stream_handlers, - zmq_args=zmq_args_dict, - ) - else: - return ZMQStream( - name, - stream_input, - stream_handlers, - zmq_args=zmq_args_dict, - ) + return input_stream_factory( + name, + stream_input, + stream_handlers, + zmq_args=zmq_args_dict, + ) def _create_outbound_stream(self, config=None): """ @@ -317,8 +325,10 @@ def _create_outbound_stream(self, config=None): ) else: if stream_output is not None: - log.warn(f"Output of stream {name} is not an integer port. " - "Stream outputs can only be ports.") + log.warn( + f"Output of stream {name} is not an integer port. " + "Stream outputs can only be ports." + ) ostream = ZMQStream( name, stream_input, @@ -378,46 +388,52 @@ def _load_plugins(self): # that indicates that plugin will run in a separate process # with that id. Multiple plugins can specify the same value # which allows them to all run within a process together - process_namespace = ait_cfg_plugin.pop('process_id', None) - plugin_type = PluginType.STANDARD if process_namespace is \ - None else PluginType.PROCESS + process_namespace = ait_cfg_plugin.pop("process_id", None) + plugin_type = ( + PluginType.STANDARD + if process_namespace is None + else PluginType.PROCESS + ) if plugin_type == PluginType.PROCESS: - # Plugin will run in a separate process (possibly with other # plugins) try: # Check if the namespace has already been created plugins_process = self.plugin_process_dict.get( - process_namespace, None) + process_namespace, None + ) # If not, then create it and add to managed dict if plugins_process is None: plugins_process = PluginsProcess(process_namespace) - self.plugin_process_dict[process_namespace] = \ - plugins_process + self.plugin_process_dict[ + process_namespace + ] = plugins_process # Convert ait config section to PluginConfig instance - plugin_info = self._create_plugin_info(ait_cfg_plugin, - False) + plugin_info = self._create_plugin_info(ait_cfg_plugin, False) # If successful, then add it to the process if plugin_info is not None: plugins_process.add_plugin_info(plugin_info) - log.info("Added config for deferred plugin " - f"{plugin_info.name} to plugin-process " - f"'{process_namespace}'") + log.info( + "Added config for deferred plugin " + f"{plugin_info.name} to plugin-process " + f"'{process_namespace}'" + ) except Exception: exc_type, exc_msg, tb = sys.exc_info() - log.error(f"{exc_type} creating plugin config {index} " - f"for process-id '{process_namespace}': " - f"{exc_msg}") + log.error( + f"{exc_type} creating plugin config {index} " + f"for process-id '{process_namespace}': " + f"{exc_msg}" + ) log.error(traceback.format_exc()) else: - # Plugin will run in current process's greenlet set try: plugin = self._create_plugin(ait_cfg_plugin) @@ -427,12 +443,12 @@ def _load_plugins(self): except Exception: exc_type, value, tb = sys.exc_info() - log.error(f"{exc_type} creating plugin {index}: " - f"{value}") + log.error(f"{exc_type} creating plugin {index}: " f"{value}") if not self.plugins and not self.plugin_process_dict: - log.warn("No valid plugin configurations found. No plugins" - " will be added.") + log.warn( + "No valid plugin configurations found. No plugins" " will be added." + ) def _create_zmq_args(self, reuse_broker_context): """ @@ -490,7 +506,6 @@ def _create_plugin_info(self, ait_plugin_config, reuse_broker_context): zmq_args = self._create_zmq_args(reuse_broker_context) # Create Plugin config (which checks for required args) - plugin_config = PluginConfig.build_from_ait_config(ait_plugin_config, - zmq_args) + plugin_config = PluginConfig.build_from_ait_config(ait_plugin_config, zmq_args) return plugin_config diff --git a/ait/core/server/stream.py b/ait/core/server/stream.py index dd789953..cc1cdda8 100644 --- a/ait/core/server/stream.py +++ b/ait/core/server/stream.py @@ -1,8 +1,12 @@ import ait.core.log -from .client import ZMQInputClient, PortInputClient, PortOutputClient +from .client import PortInputClient +from .client import PortOutputClient +from .client import TCPInputClient +from .client import TCPInputServer +from .client import ZMQInputClient -class Stream(): +class Stream: """ This is the base Stream class that all streams will inherit from. It calls its handlers to execute on all input messages sequentially, @@ -69,7 +73,6 @@ def process(self, input_data, topic=None): """ for handler in self.handlers: output = handler.handle(input_data) - if output: input_data = output else: @@ -79,7 +82,6 @@ def process(self, input_data, topic=None): ) ait.core.log.info(msg) return - self.publish(input_data) def valid_workflow(self): @@ -99,6 +101,97 @@ def valid_workflow(self): return True +def input_stream_factory(name, inputs, handlers, zmq_args=None): + """ + This factory preempts the creating of streams directly. It accepts + the same args as any given stream class and then based primarily on the + values in 'inputs' decides on the appropriate stream to instantiate and + then returns it. + """ + stream = None + + # Stream specs in the form: + # - stream: + # name: telem_stream_udp_server + # input: + # - 24000 + # - stream: + # name: telem_stream_zmq_server + # input: + # - foo_zmq + if len(inputs) == 1: + if type(inputs[0]) is int and 1024 <= inputs[0] <= 65535: + stream = PortInputStream(name, inputs, handlers, zmq_args=zmq_args) + elif type(inputs[0]) is str: + stream = ZMQStream(name, inputs, handlers, zmq_args=zmq_args) + else: + raise ValueError( + "Input stream specification with 1 arg must be [ {port_num|str} ]" + ) + + # Stream specs in the form: + # - stream: + # name: telem_stream_udp_server + # input: + # - 'UDP' + # - 'server' + # - 24000 + # - stream: + # name: telem_stream_tcp_server + # input: + # - 'TCP' + # - 'server' + # - 24000 + # - stream: + # name: telem_stream_tcp_client + # input: + # - 'TCP' + # - '1.2.3.4' + # - 24000 + elif len(inputs) == 3: + if type(inputs[0]) is str and inputs[0].upper() == "TCP": + if type(inputs[1]) is str and inputs[1].lower() == "server": + if type(inputs[2]) is int and 1024 <= inputs[2] <= 65535: + stream = TCPInputServerStream(name, inputs[1:], handlers, zmq_args) + else: + raise ValueError( + "Input stream specification with 3 args must be [ {'TCP'|'UDP'}, {'server'|ip_address}, {port_num} ]" + ) + elif type(inputs[1]) is str and inputs[1].lower() != "server": + if type(inputs[2]) is int and 1024 <= inputs[2] <= 65535: + stream = TCPInputClientStream(name, inputs[1:], handlers, zmq_args) + else: + raise ValueError( + "Input stream specification with 3 args must be [ {'TCP'|'UDP'}, {'server'|ip_address}, {port_num} ]" + ) + else: + raise ValueError( + "Input stream specification with 3 args must be [ {'TCP'|'UDP'}, {'server'|ip_address}, {port_num} ]" + ) + elif type(inputs[0]) is str and inputs[0].upper() == "UDP": + if type(inputs[1]) is str and inputs[1].lower() == "server": + if type(inputs[2]) is int and 1024 <= inputs[2] <= 65535: + stream = PortInputStream( + name, inputs[2:], handlers, zmq_args=zmq_args + ) + else: + raise ValueError( + "Input stream specification with 3 args must be [ {'TCP'|'UDP'}, {'server'|ip_address}, {port_num} ]" + ) + else: + raise NotImplementedError("UDP client not supported atm") + else: + raise ValueError( + "Input stream specification with 3 args must be [ {'TCP'|'UDP'}, {'server'|ip_address}, {port_num} ]" + ) + else: + raise ValueError("Input stream specification must contain either 1 or 3 args") + + if stream is None: + raise ValueError("Input stream specification invalid") + return stream + + class PortInputStream(Stream, PortInputClient): """ This stream type listens for messages from a UDP port and publishes to a ZMQ socket. @@ -108,6 +201,24 @@ def __init__(self, name, inputs, handlers, zmq_args=None): super(PortInputStream, self).__init__(name, inputs, handlers, zmq_args) +class TCPInputServerStream(Stream, TCPInputServer): + """ + This stream type listens for messages from a TCP port and publishes to a ZMQ socket. + """ + + def __init__(self, name, inputs, handlers, zmq_args=None): + super(TCPInputServerStream, self).__init__(name, inputs, handlers, zmq_args) + + +class TCPInputClientStream(Stream, TCPInputClient): + """ + This stream type connects to a TCP server and publishes to a ZMQ socket. + """ + + def __init__(self, name, inputs, handlers, zmq_args=None): + super(TCPInputClientStream, self).__init__(name, inputs, handlers, zmq_args) + + class ZMQStream(Stream, ZMQInputClient): """ This stream type listens for messages from another stream or plugin and publishes diff --git a/doc/source/server_architecture.rst b/doc/source/server_architecture.rst index 895d1ef1..ac6e0e72 100644 --- a/doc/source/server_architecture.rst +++ b/doc/source/server_architecture.rst @@ -61,7 +61,7 @@ AIT provides a number of default plugins. Check the `Plugins API documentation < Streams ^^^^^^^ - Streams must be listed under either **inbound-streams** or **outbound-streams**, and must have a **name**. -- **Inbound streams** can have an integer port or inbound streams as their **input**. Inbound streams can have multiple inputs. A port input should always be listed as the first input to an inbound stream. +- **Inbound streams** can have an address specification or inbound streams as their **input**. Inbound streams can have multiple inputs. - The server sets up an input stream that emits properly formed telemetry packet messages over a globally configured topic. This is used internally by the ground script API for telemetry monitoring. The input streams that pass data to this stream must output data in the Packet UID annotated format that the core packet handlers use. The input streams used can be configured via the **server.api-telemetry-streams** field. If no configuration is provided the server will default to all valid input streams if possible. See :ref:`the Ground Script API documentation ` for additional information. @@ -72,7 +72,7 @@ Streams - The server exposes an entry point for commands submitted by other processes. During initialization, this entry point will be connected to a single outbound stream, either explicitly declared by the stream (by setting the **command-subscriber** field; see :ref:`example config below `), or decided by the server (select the first outbound stream in the configuration file). - Streams can have any number of **handlers**. A stream passes each received *packet* through its handlers in order and publishes the result. -- There are several stream classes that inherit from the base stream class. These child classes exist for handling the input and output of streams differently based on whether the inputs/output are ports or other streams and plugins. The appropriate stream type will be instantiated based on whether the stream is an inbound or outbound stream and based on the inputs/output specified in the stream's configs. If the input type of an inbound stream is an integer, it will be assumed to be a port. If it is a string, it will be assumed to be another stream name or plugin. Only outbound streams can have an output, and the output must be a port, not another stream or plugin. +- There are several stream classes that inherit from the base stream class. These child classes exist for handling the input and output of streams differently based on whether the inputs/output are remote hosts, ports or other streams and plugins. The appropriate stream type will be instantiated based on whether the stream is an inbound or outbound stream and based on the inputs/output specified in the stream's configs. Only outbound streams can have an output, and the output must be a port, not another stream or plugin. .. _Stream_config: @@ -86,17 +86,48 @@ Example configuration: input: - 3077 + # UDP Input Server - stream: - name: telem_port_in_stream + name: telem_port_in_stream_1 input: - 3076 handlers: - my_custom_handlers.TestbedTelemHandler + # UDP Input Server + - stream: + name: telem_port_in_stream_2 + input: + - "UDP" + - "server" + - 3077 + handlers: + - my_custom_handlers.TestbedTelemHandler + + # TCP Input Server + - stream: + name: telem_port_in_stream_3 + input: + - "TCP" + - "server" + - 3078 + handlers: + - my_custom_handlers.TestbedTelemHandler + + # TCP Input Client + - stream: + name: telem_port_in_stream_4 + input: + - "TCP" + - "1.2.3.4" + - 3079 + handlers: + - my_custom_handlers.TestbedTelemHandler + - stream: name: telem_testbed_stream input: - - telem_port_in_stream + - telem_port_in_stream_1 handlers: - name: ait.server.handlers.PacketHandler packet: 1553_HS_Packet diff --git a/tests/ait/core/server/test_client.py b/tests/ait/core/server/test_client.py index e69de29b..29671ea6 100644 --- a/tests/ait/core/server/test_client.py +++ b/tests/ait/core/server/test_client.py @@ -0,0 +1,86 @@ +import gevent + +from ait.core.server.broker import Broker +from ait.core.server.client import TCPInputClient +from ait.core.server.client import TCPInputServer + +broker = Broker() +TEST_BYTES = "Howdy".encode() +TEST_PORT = 6666 + + +class SimpleServer(gevent.server.StreamServer): + def handle(self, socket, address): + socket.sendall(TEST_BYTES) + + +class TCPServer(TCPInputServer): + def __init__(self, name, inputs, **kwargs): + super(TCPServer, self).__init__(broker.context, input=inputs) + + def process(self, input_data): + self.cur_socket.sendall(input_data) + + +class TCPClient(TCPInputClient): + def __init__(self, name, inputs, **kwargs): + super(TCPClient, self).__init__(broker.context, input=inputs) + self.input_data = None + + def process(self, input_data): + self.input_data = input_data + self._exit() + + +class TestTCPServer: + def setup_method(self): + self.server = TCPServer("test_tcp_server", inputs=["server", TEST_PORT]) + self.server.start() + self.client = gevent.socket.create_connection(("127.0.0.1", TEST_PORT)) + + def teardown_method(self): + self.server.stop() + self.client.close() + + def test_TCP_server(self): + nbytes = self.client.send(TEST_BYTES) + response = self.client.recv(len(TEST_BYTES)) + assert nbytes == len(TEST_BYTES) + assert response == TEST_BYTES + + def test_null_send(self): + nbytes1 = self.client.send(b"") + nbytes2 = self.client.send(TEST_BYTES) + response = self.client.recv(len(TEST_BYTES)) + assert nbytes1 == 0 + assert nbytes2 == len(TEST_BYTES) + assert response == TEST_BYTES + + def test_weird_buffer(self): + self.server.buffer = len(TEST_BYTES) - 1 + nbytes2 = self.client.send(TEST_BYTES) + response = self.client.recv(len(TEST_BYTES)) + assert nbytes2 == len(TEST_BYTES) + assert response == TEST_BYTES + + +class TestTCPClient: + def setup_method(self): + self.server = SimpleServer(("127.0.0.1", 0)) + self.server.start() + self.client = TCPClient( + "test_tcp_client", inputs=["127.0.0.1", self.server.server_port] + ) + + def teardown_method(self): + self.server.stop() + + def test_TCP_client(self): + self.client.start() + assert self.client.input_data == TEST_BYTES + + def test_bad_connection(self): + self.client.port = 1 + self.client.connection_reattempts = 2 + self.client.start() + assert self.client.connection_status != 0 diff --git a/tests/ait/core/server/test_stream.py b/tests/ait/core/server/test_stream.py index 6d89a190..ec39d5d1 100644 --- a/tests/ait/core/server/test_stream.py +++ b/tests/ait/core/server/test_stream.py @@ -1,69 +1,211 @@ from unittest import mock +import gevent import pytest import zmq.green -import ait.core from ait.core.server.broker import Broker from ait.core.server.handlers import PacketHandler +from ait.core.server.stream import input_stream_factory +from ait.core.server.stream import PortInputStream +from ait.core.server.stream import TCPInputClientStream +from ait.core.server.stream import TCPInputServerStream from ait.core.server.stream import ZMQStream +broker = Broker() + + class TestStream: + invalid_stream_args = [ + "some_stream", + "input_stream", + [ + PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER"), + PacketHandler(input_type=int, packet="CCSDS_HEADER"), + ], + {"zmq_context": broker}, + ] + test_data = [ + ( + "zmq", + { + "name": "some_zmq_stream", + "inputs": ["input_stream"], + "handlers_len": 1, + "handler_type": PacketHandler, + "broker_context": broker.context, + "sub_type": zmq.green.core._Socket, + "pub_type": zmq.green.core._Socket, + "repr": "", + }, + ), + ( + "udp_server", + { + "name": "some_udp_stream", + "inputs": [1234], + "handlers_len": 1, + "handler_type": PacketHandler, + "broker_context": broker.context, + "sub_type": gevent._socket3.socket, + "pub_type": zmq.green.core._Socket, + "repr": "", + }, + ), + ( + "tcp_server", + { + "name": "some_tcp_stream_server", + "inputs": ["server", 1234], + "handlers_len": 1, + "handler_type": PacketHandler, + "broker_context": broker.context, + "sub_type": gevent._socket3.socket, + "pub_type": zmq.green.core._Socket, + "repr": "", + }, + ), + ( + "tcp_client", + { + "name": "some_tcp_stream_client", + "inputs": ["127.0.0.1", 1234], + "handlers_len": 1, + "handler_type": PacketHandler, + "broker_context": broker.context, + "sub_type": gevent._socket3.socket, + "pub_type": zmq.green.core._Socket, + "repr": "", + }, + ), + ] + def setup_method(self): - self.broker = Broker() - self.stream = ZMQStream( - "some_stream", - ["input_stream"], - [PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER")], - zmq_args={"zmq_context": self.broker.context}, - ) - self.stream.handlers = [ - PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER") - ] + self.streams = { + "zmq": ZMQStream( + "some_zmq_stream", + ["input_stream"], + [PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER")], + zmq_args={"zmq_context": broker.context}, + ), + "udp_server": PortInputStream( + "some_udp_stream", + [1234], + [PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER")], + zmq_args={"zmq_context": broker.context}, + ), + "tcp_server": TCPInputServerStream( + "some_tcp_stream_server", + ["server", 1234], + [PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER")], + zmq_args={"zmq_context": broker.context}, + ), + "tcp_client": TCPInputClientStream( + "some_tcp_stream_client", + ["127.0.0.1", 1234], + [PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER")], + zmq_args={"zmq_context": broker.context}, + ), + } + for stream in self.streams.values(): + stream.handlers = [ + PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER") + ] - def test_stream_creation(self): - assert self.stream.name is "some_stream" - assert self.stream.inputs == ["input_stream"] - assert len(self.stream.handlers) == 1 - assert type(self.stream.handlers[0]) == PacketHandler - assert self.stream.context == self.broker.context - assert type(self.stream.pub) == zmq.green.core._Socket - assert type(self.stream.sub) == zmq.green.core._Socket + @pytest.mark.parametrize("stream,expected", test_data) + def test_stream_creation(self, stream, expected): + assert self.streams[stream].name is expected["name"] + assert self.streams[stream].inputs == expected["inputs"] + assert len(self.streams[stream].handlers) == expected["handlers_len"] + assert type(self.streams[stream].handlers[0]) == expected["handler_type"] + assert self.streams[stream].context == expected["broker_context"] + assert type(self.streams[stream].pub) == expected["pub_type"] + assert type(self.streams[stream].sub) == expected["sub_type"] - def test_repr(self): - assert self.stream.__repr__() == "" + @pytest.mark.parametrize("stream,expected", test_data) + def test_repr(self, stream, expected): + assert self.streams[stream].__repr__() == expected["repr"] + @pytest.mark.parametrize("stream,_", test_data) @mock.patch.object(PacketHandler, "handle") - def test_process(self, execute_handler_mock): - self.stream.process("input_data") + def test_process(self, execute_handler_mock, stream, _): + self.streams[stream].process("input_data") execute_handler_mock.assert_called_with("input_data") - def test_valid_workflow_one_handler(self): - assert self.stream.valid_workflow() is True + @pytest.mark.parametrize("stream,_", test_data) + def test_valid_workflow_one_handler(self, stream, _): + assert self.streams[stream].valid_workflow() is True - def test_valid_workflow_more_handlers(self): - self.stream.handlers.append( + @pytest.mark.parametrize("stream,_", test_data) + def test_valid_workflow_more_handlers(self, stream, _): + self.streams[stream].handlers.append( PacketHandler(input_type=str, packet="CCSDS_HEADER") ) - assert self.stream.valid_workflow() is True + assert self.streams[stream].valid_workflow() is True - def test_invalid_workflow_more_handlers(self): - self.stream.handlers.append( + @pytest.mark.parametrize("stream,_", test_data) + def test_invalid_workflow_more_handlers(self, stream, _): + self.streams[stream].handlers.append( PacketHandler(input_type=int, packet="CCSDS_HEADER") ) - assert self.stream.valid_workflow() is False + assert self.streams[stream].valid_workflow() is False - def test_stream_creation_invalid_workflow(self): + @pytest.mark.parametrize( + "stream,args", + [ + (ZMQStream, invalid_stream_args), + (PortInputStream, invalid_stream_args), + (TCPInputServerStream, invalid_stream_args), + (TCPInputClientStream, invalid_stream_args), + ], + ) + def test_stream_creation_invalid_workflow(self, stream, args): with pytest.raises(ValueError): - ZMQStream( - "some_stream", - "input_stream", - [ - PacketHandler( - input_type=int, output_type=str, packet="CCSDS_HEADER" - ), - PacketHandler(input_type=int, packet="CCSDS_HEADER"), - ], - zmq_args={"zmq_context": self.broker.context}, - ) + stream(*args) + + @pytest.mark.parametrize( + "args,expected", + [ + (["TCP", "127.0.0.1", 1234], TCPInputClientStream), + (["TCP", "server", 1234], TCPInputServerStream), + ([1234], PortInputStream), + (["UDP", "server", 1234], PortInputStream), + (["FOO"], ZMQStream), + ], + ) + def test_valid_stream_factory(self, args, expected): + full_args = [ + "foo", + args, + [PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER")], + {"zmq_context": broker.context}, + ] + stream = input_stream_factory(*full_args) + assert isinstance(stream, expected) + + @pytest.mark.parametrize( + "args,expected", + [ + (["TCP", "127.0.0.1", "1234"], ValueError), + (["TCP", "127.0.0.1", 1], ValueError), + (["TCP", "server", "1234"], ValueError), + (["TCP", "server", 1], ValueError), + (["TCP", 1, 1024], ValueError), + (["UDP", "127.0.0.1", "1234"], NotImplementedError), + (["UDP", "server", "1234"], ValueError), + (["UDP", "server", 1], ValueError), + (["FOO", "server", 1024], ValueError), + (["server", 1234], ValueError), + ([1], ValueError), + ], + ) + def test_invalid_stream_factory(self, args, expected): + full_args = [ + "foo", + args, + [PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER")], + {"zmq_context": broker.context}, + ] + with pytest.raises(expected): + _ = input_stream_factory(*full_args)