Skip to content

Commit

Permalink
Patched test failures
Browse files Browse the repository at this point in the history
  • Loading branch information
lewis-chambers committed May 30, 2024
1 parent d4cc6f5 commit f89636e
Show file tree
Hide file tree
Showing 10 changed files with 386 additions and 83 deletions.
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dependencies = [
"awscli",
"awscrt",
"oracledb",
"backoff",
]
name = "iot-device-simulator"
dynamic = ["version"]
Expand Down Expand Up @@ -43,3 +44,6 @@ markers = [
"asyncio: Tests asynchronous functions.",
"oracle: Requires oracle connection and required config credentials",
]

[tool.coverage.run]
omit = ["*example.py", "*__init__.py"]
2 changes: 1 addition & 1 deletion src/iotdevicesimulator/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async def create(

self = cls()

if inherit_logger:
if inherit_logger is not None:
self._instance_logger = inherit_logger.getChild("db")
else:
self._instance_logger = logger
Expand Down
11 changes: 6 additions & 5 deletions src/iotdevicesimulator/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,26 +42,26 @@ def __init__(self, site_id: str,*, sleep_time: int|None=None, max_cycles: int|No

self.site_id = str(site_id)

if inherit_logger:
if inherit_logger is not None:
self._instance_logger = inherit_logger.getChild(f"site-{self.site_id}")
else:
self._instance_logger = logger.getChild(self.site_id)

if max_cycles:
if max_cycles is not None:
max_cycles = int(max_cycles)
if max_cycles <= 0 and max_cycles != -1:
raise ValueError(f"`max_cycles` must be 1 or more, or -1 for no maximum. Received: {max_cycles}")

self.max_cycles = max_cycles

if sleep_time:
if sleep_time is not None:
sleep_time = int(sleep_time)
if sleep_time < 0:
raise ValueError(f"`sleep_time` must be 0 or more. Received: {sleep_time}")

self.sleep_time = sleep_time

if delay_first_cycle:
if delay_first_cycle is not None:
if not isinstance(delay_first_cycle, bool):
raise TypeError(
f"`delay_first_cycle` must be a bool. Received: {delay_first_cycle}."
Expand All @@ -77,7 +77,8 @@ def __repr__(self):
def __str__(self):
return f"Site ID: \"{self.site_id}\", Sleep Time: {self.sleep_time}, Max Cycles: {self.max_cycles}, Cycle: {self.cycle}"

async def run(self, oracle: Oracle, query: CosmosQuery, message_connection: IotCoreMQTTConnection):
async def run(self, oracle: Oracle, query: CosmosQuery,
message_connection: IotCoreMQTTConnection):
"""The main invocation of the method. Expects a Oracle object to do work on
and a query to retrieve. Runs asynchronously until `max_cycles` is reached.
Expand Down
102 changes: 55 additions & 47 deletions src/iotdevicesimulator/mqtt/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,22 @@
import awscrt
from awscrt import mqtt
import awscrt.io
import sys
import time
import json
import config
from pathlib import Path
from awscrt.exceptions import AwsCrtError
from iotdevicesimulator.mqtt.core import MessagingBaseClass
import backoff
import logging

logger = logging.getLogger(__name__)

class IotCoreMQTTConnection:

class IotCoreMQTTConnection(MessagingBaseClass):
"""Handles MQTT communication to AWS IoT Core."""

connection: awscrt.mqtt.Connection
connection: awscrt.mqtt.Connection | None = None
"""A connection to the MQTT endpoint."""

topic_prefix: str | None = None
"""Prefix attached to the send topic. Can attach \"Basic Ingest\" rules this way."""

def __init__(
self,
endpoint: str,
Expand All @@ -31,7 +30,6 @@ def __init__(
port: int | None = None,
clean_session: bool = False,
keep_alive_secs: int = 1200,
topic_prefix: str | None = None,
**kwargs,
) -> None:
"""Initializes the class.
Expand All @@ -48,21 +46,43 @@ def __init__(
topic_prefix: A topic prefixed to MQTT topic, useful for attaching a "Basic Ingest" rule. Defaults to None.
"""

if topic_prefix:
self.topic_prefix = str(topic_prefix)
if not isinstance(endpoint, str):
raise TypeError(f"`endpoint` must be a `str`, not {type(endpoint)}")

if not isinstance(cert_path, str):
raise TypeError(f"`cert_path` must be a `str`, not {type(cert_path)}")

if not isinstance(key_path, str):
raise TypeError(f"`key_path` must be a `str`, not {type(key_path)}")

if not isinstance(ca_cert_path, str):
raise TypeError(f"`ca_cert_path` must be a `str`, not {type(ca_cert_path)}")

if not isinstance(client_id, str):
raise TypeError(f"`client_id` must be a `str`, not {type(client_id)}")

if not isinstance(clean_session, bool):
raise TypeError(
f"`clean_session` must be a bool, not {type(clean_session)}."
)

tls_ctx_options = awscrt.io.TlsContextOptions.create_client_with_mtls_from_path(
cert_path, key_path
)

tls_ctx_options.override_default_trust_store_from_path(ca_cert_path)

if not port:
if port is None:
if awscrt.io.is_alpn_available():
port = 443
tls_ctx_options.alpn_list = ["x-amzn-mqtt-ca"]
else:
port = 8883
else:
port = int(port)

if port < 0:
raise ValueError(f"`port` cannot be less than 0. Received: {port}.")

socket_options = awscrt.io.SocketOptions()
socket_options.connect_timeout_ms = 5000
Expand All @@ -71,8 +91,6 @@ def __init__(
socket_options.keep_alive_interval_secs = 0
socket_options.keep_alive_max_probes = 0

username = None

client_bootstrap = awscrt.io.ClientBootstrap.get_or_create_static_default()

tls_ctx = awscrt.io.ClientTlsContext(tls_ctx_options)
Expand All @@ -91,25 +109,22 @@ def __init__(
keep_alive_secs=keep_alive_secs,
ping_timeout_ms=3000,
protocol_operation_timeout_ms=0,
will=None,
username=username,
password=None,
socket_options=socket_options,
use_websockets=False,
websocket_handshake_transform=None,
proxy_options=None,
on_connection_success=self._on_connection_success,
on_connection_failure=self._on_connection_failure,
on_connection_closed=self._on_connection_closed,
)

@staticmethod
def _on_connection_interrupted(connection, error, **kwargs):
def _on_connection_interrupted(connection, error, **kwargs): # pragma: no cover
"""Callback when connection accidentally lost."""
print("Connection interrupted. error: {}".format(error))

@staticmethod
def _on_connection_resumed(connection, return_code, session_present, **kwargs):
def _on_connection_resumed(
connection, return_code, session_present, **kwargs
): # pragma: no cover
"""Callback when an interrupted connection is re-established."""

print(
Expand All @@ -119,7 +134,7 @@ def _on_connection_resumed(connection, return_code, session_present, **kwargs):
)

@staticmethod
def _on_connection_success(connection, callback_data):
def _on_connection_success(connection, callback_data): # pragma: no cover
"""Callback when the connection successfully connects."""

assert isinstance(callback_data, mqtt.OnConnectionSuccessData)
Expand All @@ -130,44 +145,39 @@ def _on_connection_success(connection, callback_data):
)

@staticmethod
def _on_connection_failure(connection, callback_data):
def _on_connection_failure(connection, callback_data): # pragma: no cover
"""Callback when a connection attempt fails."""

assert isinstance(callback_data, mqtt.OnConnectionFailureData)
print("Connection failed with error code: {}".format(callback_data.error))

@staticmethod
def _on_connection_closed(connection, callback_data):
def _on_connection_closed(connection, callback_data): # pragma: no cover
"""Callback when a connection has been disconnected or shutdown successfully"""
print("Connection closed")

def send_message(self, message: str, topic: str, count: int = 1):
@backoff.on_exception(backoff.expo, exception=AwsCrtError, logger=logger)
def _connect(self):
connect_future = self.connection.connect()
connect_future.result()
print("Connected!")

@backoff.on_exception(backoff.expo, exception=AwsCrtError, logger=logger)
def _disconnect(self):
print("Disconnecting...")
disconnect_future = self.connection.disconnect()
disconnect_future.result()

def send_message(self, message: str, topic: str, count: int = 1) -> None:
"""Sends a message to the endpoint.
Args:
message: The message to send.
topic: MQTT topic to send message under.
cound: How many times to repeat the message. If 0, it sends forever.
count: How many times to repeat the message. If 0, it sends forever.
"""

if self.topic_prefix:
topic = f"{self.topic_prefix}/{topic}"

retry_count = 0

while retry_count < 10:
connect_future = self.connection.connect()

# Future.result() waits until a result is available
try:
connect_future.result()
break
except AwsCrtError:
print(f"Could not connect. Attempt {retry_count+1}/10")
retry_count += 1
time.sleep(2 * retry_count)

print("Connected!")
self._connect()

# Publish message to server desired number of times.
# This step is skipped if message is blank.
Expand Down Expand Up @@ -195,6 +205,4 @@ def send_message(self, message: str, topic: str, count: int = 1):
time.sleep(1)
publish_count += 1

print("Disconnecting...")
disconnect_future = self.connection.disconnect()
disconnect_future.result()
self._disconnect()
30 changes: 30 additions & 0 deletions src/iotdevicesimulator/mqtt/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from abc import ABC
from abc import abstractmethod


class MessagingBaseClass(ABC):
"""MessagingBaseClass Base class for messaging implementation
All messaging classes implement this interface.
"""

@property
@abstractmethod
def connection(self):
"""A property for the connection object where messages are sent."""

@abstractmethod
def send_message(self):
"""Method for sending the message."""


class MockMessageConnection(MessagingBaseClass):
"""Mock implementation of base class. Consumes `send_message` calls but does no work."""

connection: None = None
"""Connection object. Not needed in a mock but must be implemented"""

@staticmethod
def send_message(*args, **kwargs):
"""Consumes requests to send a message but does nothing with it."""
pass
Loading

0 comments on commit f89636e

Please sign in to comment.