diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f2d31a0..441f649 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -24,6 +24,9 @@ jobs: python -m pip install --upgrade pip pip install flake8 pip install .[test] + - name: Build test DB + run: | + python src/tests/data/build_database.py - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names diff --git a/.gitignore b/.gitignore index c88052f..091353a 100644 --- a/.gitignore +++ b/.gitignore @@ -8,7 +8,9 @@ dist/* config.cfg aws-auth *.log -*.certs +*.certs* !.github/* docs/_* -conf.sh \ No newline at end of file +*.sh +**/*.csv +**/*.db \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 1cdd28a..bc97815 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -14,6 +14,7 @@ :caption: Contents: self + source/cli source/modules genindex modindex @@ -57,7 +58,7 @@ intialise a swarm with data sent every 30 minutes like so: .. code-block:: shell iot-swarm cosmos --dsn="xxxxxx" --user="xxxxx" --password="*****" \ - mqtt "aws" LEVEL_1_SOILMET_30MIN "client_id" \ + mqtt LEVEL_1_SOILMET_30MIN "client_id" \ --endpoint="xxxxxxx" \ --cert-path="C:\path\..." \ --key-path="C:\path\..." \ @@ -84,7 +85,7 @@ Then the CLI can be called more cleanly: .. code-block:: shell - iot-swarm cosmos mqtt "aws" LEVEL_1_SOILMET_30MIN "client_id" --sleep-time=1800 --swarm-name="my-swarm" + iot-swarm cosmos mqtt LEVEL_1_SOILMET_30MIN "client_id" --sleep-time=1800 --swarm-name="my-swarm" ------------------------ Using the Python Modules @@ -163,6 +164,31 @@ The system expects config credentials for the MQTT endpoint and the COSMOS Oracl .. include:: example-config.cfg +------------------------------------------- +Looping Through a local database / csv file +------------------------------------------- + +This package now supports using a CSV file or local SQLite database as the data source. 
+There are 2 modules to support it: `db.LoopingCsvDB` and `db.LoopingSQLite3`. Each of them +initializes from a local file and loops through the data for a given site id. The database +objects store an in-memory cache of each site ID and its current index in the database. +Once the end is reached, it loops back to the start for that site. + +For use in FDRI, 6 months of data was downloaded to CSV from the COSMOS-UK database, but +the files are too large to be included in this repo, so they are stored in the `ukceh-fdri` +`S3` bucket on AWS. There are scripts for regenerating the `.db` file in this repo: + +* `./src/iotswarm/__assets__/data/build_database.py` +* `./src/tests/data/build_database.py` + +To use the 'official' data, it should be downloaded from the `S3` bucket and placed in +`./src/iotswarm/__assets__/data` before running the script. This will build a `.db` file +sorted in datetime order that the `LoopingSQLite3` class can operate with. + +.. warning:: + The looping database classes assume that their data files are sorted, and make no + attempt to sort them themselves. + Indices and tables ================== diff --git a/docs/source/cli.rst b/docs/source/cli.rst new file mode 100644 index 0000000..0341699 --- /dev/null +++ b/docs/source/cli.rst @@ -0,0 +1,6 @@ +CLI +=== + +..
click:: iotswarm.scripts.cli:main + :prog: iot-swarm + :nested: full \ No newline at end of file diff --git a/docs/source/iotswarm.scripts.rst b/docs/source/iotswarm.scripts.rst index 42b1dd6..81dd67a 100644 --- a/docs/source/iotswarm.scripts.rst +++ b/docs/source/iotswarm.scripts.rst @@ -1,5 +1,5 @@ iotswarm.scripts package -================================== +======================== Submodules ---------- diff --git a/pyproject.toml b/pyproject.toml index e23f65b..a69e0ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [build-system] requires = ["setuptools >= 61.0", "autosemver"] -build-backend = "setuptools.build_meta" +# build-backend = "setuptools.build_meta" [project] dependencies = [ @@ -8,12 +8,14 @@ dependencies = [ "boto3", "autosemver", "config", + "click", "docutils<0.17", "awscli", "awscrt", + "awsiotsdk", "oracledb", "backoff", - "click", + "pandas", ] name = "iot-swarm" dynamic = ["version"] @@ -26,14 +28,18 @@ docs = ["sphinx", "sphinx-copybutton", "sphinx-rtd-theme", "sphinx-click"] [project.scripts] iot-swarm = "iotswarm.scripts.cli:main" + [tool.setuptools.dynamic] version = { attr = "iotswarm.__version__" } + [tool.setuptools.packages.find] where = ["src"] +include = ["iotswarm*"] +exclude = ["tests*"] [tool.setuptools.package-data] -"*" = ["*.*"] +"iotswarm.__assets__" = ["loggers.ini"] [tool.pytest.ini_options] @@ -49,4 +55,12 @@ markers = [ ] [tool.coverage.run] -omit = ["*example.py", "*__init__.py", "queries.py", "loggers.py", "cli.py"] +omit = [ + "*example.py", + "*__init__.py", + "queries.py", + "loggers.py", + "**/scripts/*.py", + "**/build_database.py", + "utils.py", +] diff --git a/src/iotswarm/__assets__/data/build_database.py b/src/iotswarm/__assets__/data/build_database.py new file mode 100644 index 0000000..96c5838 --- /dev/null +++ b/src/iotswarm/__assets__/data/build_database.py @@ -0,0 +1,41 @@ +"""This script is responsible for building an SQL data file from the CSV files used +by the cosmos network. 
+ +The files are stored in AWS S3 and should be downloaded into this directory before continuing. +They are: + * LEVEL_1_NMDB_1HOUR_DATA_TABLE.csv + * LEVEL_1_SOILMET_30MIN_DATA_TABLE.csv + * LEVEL_1_PRECIP_1MIN_DATA_TABLE.csv + * LEVEL_1_PRECIP_RAINE_1MIN_DATA_TABLE.csv + +Once installed, run this script to generate the .db file. +""" + +from iotswarm.utils import build_database_from_csv +from pathlib import Path +from glob import glob +from iotswarm.queries import CosmosTable + + +def main( + csv_dir: str | Path = Path(__file__).parent, + database_output: str | Path = Path(Path(__file__).parent, "cosmos.db"), +): + """Reads exported cosmos DB files from CSV format. Assumes that the files + look like: LEVEL_1_SOILMET_30MIN_DATA_TABLE.csv + + Args: + csv_dir: Directory where the csv files are stored. + database_output: Output destination of the csv_data. + """ + csv_files = glob("*.csv", root_dir=csv_dir) + + tables = [CosmosTable[x.removesuffix("_DATA_TABLE.csv")] for x in csv_files] + + for table, file in zip(tables, csv_files): + file = Path(csv_dir, file) + build_database_from_csv(file, database_output, table.value, sort_by="DATE_TIME") + + +if __name__ == "__main__": + main() diff --git a/src/iotswarm/__init__.py b/src/iotswarm/__init__.py index fbfe123..0131fea 100644 --- a/src/iotswarm/__init__.py +++ b/src/iotswarm/__init__.py @@ -1,8 +1,6 @@ import autosemver try: - __version__ = autosemver.packaging.get_current_version( - project_name="iot-device-simulator" - ) + __version__ = autosemver.packaging.get_current_version(project_name="iot-swarm") except: __version__ = "unknown version" diff --git a/src/iotswarm/db.py b/src/iotswarm/db.py index 1386ddd..e7297bc 100644 --- a/src/iotswarm/db.py +++ b/src/iotswarm/db.py @@ -4,7 +4,14 @@ import getpass import logging import abc -from iotswarm.queries import CosmosQuery, CosmosSiteQuery +from iotswarm.queries import ( + CosmosQuery, + CosmosTable, +) +import pandas as pd +from pathlib import Path +from math import 
nan +import sqlite3 logger = logging.getLogger(__name__) @@ -44,12 +51,65 @@ def query_latest_from_site(): return [] -class Oracle(BaseDatabase): +class CosmosDB(BaseDatabase): + """Base class for databases using COSMOS_UK data.""" + + connection: object + """Connection to database.""" + + site_data_query: CosmosQuery + """SQL query for retrieving a single record.""" + + site_id_query: CosmosQuery + """SQL query for retrieving list of site IDs""" + + @staticmethod + def _validate_table(table: CosmosTable) -> None: + """Validates that the query is legal""" + + if not isinstance(table, CosmosTable): + raise TypeError( + f"`table` must be a `{CosmosTable.__class__}` Enum, not a `{type(table)}`" + ) + + @staticmethod + def _fill_query(query: str, table: CosmosTable) -> str: + """Fills a query string with a CosmosTable enum.""" + + CosmosDB._validate_table(table) + + return query.format(table=table.value) + + @staticmethod + def _validate_max_sites(max_sites: int) -> int: + """Validates that a valid maximum sites is given: + Args: + max_sites: The maximum number of sites required. + + Returns: + An integer 0 or more. + """ + + if max_sites is not None: + max_sites = int(max_sites) + if max_sites < 0: + raise ValueError( + f"`max_sites` must be 1 or more, or 0 for no maximum. 
Received: {max_sites}" + ) + + return max_sites + + +class Oracle(CosmosDB): """Class for handling oracledb logic and retrieving values from DB.""" connection: oracledb.Connection """Connection to oracle database.""" + site_data_query = CosmosQuery.ORACLE_LATEST_DATA + + site_id_query = CosmosQuery.ORACLE_SITE_IDS + def __repr__(self): parent_repr = ( super().__repr__().lstrip(f"{self.__class__.__name__}(").rstrip(")") @@ -64,7 +124,14 @@ def __repr__(self): ) @classmethod - async def create(cls, dsn: str, user: str, password: str = None, **kwargs): + async def create( + cls, + dsn: str, + user: str, + password: str = None, + inherit_logger: logging.Logger | None = None, + **kwargs, + ): """Factory method for initialising the class. Initialization is done through the `create() method`: `Oracle.create(...)`. @@ -72,6 +139,7 @@ async def create(cls, dsn: str, user: str, password: str = None, **kwargs): dsn: Oracle data source name. user: Username used for query. pw: User password for auth. + inherit_logger: Uses the given logger if provided """ if not password: @@ -83,32 +151,31 @@ async def create(cls, dsn: str, user: str, password: str = None, **kwargs): dsn=dsn, user=user, password=password ) + if inherit_logger is not None: + self._instance_logger = inherit_logger.getChild(self.__class__.__name__) + else: + self._instance_logger = logger.getChild(self.__class__.__name__) + self._instance_logger.info("Initialized Oracle connection.") return self - async def query_latest_from_site(self, site_id: str, query: CosmosQuery) -> dict: + async def query_latest_from_site(self, site_id: str, table: CosmosTable) -> dict: """Requests the latest data from a table for a specific site. Args: site_id: ID of the site to retrieve records from. - query: Query to parse and submit. + table: A valid table from the database Returns: dict | None: A dict containing the database columns as keys, and the values as values. Returns `None` if no data retrieved. 
""" - if not isinstance(query, CosmosQuery): - raise TypeError( - f"`query` must be a `CosmosQuery` Enum, not a `{type(query)}`" - ) + query = self._fill_query(self.site_data_query, table) - async with self.connection.cursor() as cursor: - await cursor.execute( - query.value, - mysite=site_id, - ) + with self.connection.cursor() as cursor: + await cursor.execute(query, site_id=site_id) columns = [i[0] for i in cursor.description] data = await cursor.fetchone() @@ -119,23 +186,92 @@ async def query_latest_from_site(self, site_id: str, query: CosmosQuery) -> dict return dict(zip(columns, data)) async def query_site_ids( - self, query: CosmosSiteQuery, max_sites: int | None = None + self, table: CosmosTable, max_sites: int | None = None ) -> list: """query_site_ids returns a list of site IDs from COSMOS database Args: - query: The query to run. + table: A valid table from the database max_sites: Maximum number of sites to retreive Returns: List[str]: A list of site ID strings. """ - if not isinstance(query, CosmosSiteQuery): - raise TypeError( - f"`query` must be a `CosmosSiteQuery` Enum, not a `{type(query)}`" - ) + max_sites = self._validate_max_sites(max_sites) + + query = self._fill_query(self.site_id_query, table) + + async with self.connection.cursor() as cursor: + await cursor.execute(query) + + data = await cursor.fetchall() + if max_sites == 0: + data = [x[0] for x in data] + else: + data = [x[0] for x in data[:max_sites]] + + if not data: + data = [] + + return data + + +class LoopingCsvDB(BaseDatabase): + """A database that reads from csv files and loops through items + for a given table or site. 
The site and index is remembered via a + dictionary key and incremented each time data is requested.""" + + connection: pd.DataFrame + """Connection to the pd object holding data.""" + + cache: dict + """Cache object containing current index of each site queried.""" + + @staticmethod + def _get_connection(*args) -> pd.DataFrame: + """Gets the database connection.""" + return pd.read_csv(*args) + + def __init__(self, csv_file: str | Path): + """Initialises the database object. + + Args: + csv_file: A pathlike object pointing to the datafile. + """ + + BaseDatabase.__init__(self) + self.connection = self._get_connection(csv_file) + self.cache = dict() + + def query_latest_from_site(self, site_id: str) -> dict: + """Queries the datbase for a `SITE_ID` incrementing by 1 each time called + for a specific site. If the end is reached, it loops back to the start. + + Args: + site_id: ID of the site to query for. + Returns: + A dict of the data row. + """ + + data = self.connection.query("SITE_ID == @site_id").replace({nan: None}) + + if site_id not in self.cache or self.cache[site_id] >= len(data): + self.cache[site_id] = 1 + else: + self.cache[site_id] += 1 + + return data.iloc[self.cache[site_id] - 1].to_dict() + + def query_site_ids(self, max_sites: int | None = None) -> list: + """query_site_ids returns a list of site IDs from the database + + Args: + max_sites: Maximum number of sites to retreive + Returns: + List[str]: A list of site ID strings. + """ if max_sites is not None: max_sites = int(max_sites) if max_sites < 0: @@ -143,10 +279,115 @@ async def query_site_ids( f"`max_sites` must be 1 or more, or 0 for no maximum. 
Received: {max_sites}" ) - async with self.connection.cursor() as cursor: - await cursor.execute(query.value) + sites = self.connection["SITE_ID"].drop_duplicates().to_list() - data = await cursor.fetchall() + if max_sites is not None and max_sites > 0: + sites = sites[:max_sites] + + return sites + + +class LoopingSQLite3(CosmosDB, LoopingCsvDB): + """A database that reads from .db files using sqlite3 and loops through + entries in sequential order. There is a script that generates the .db file + in the `__assets__/data` directory relative to this file. .csv datasets should + be downloaded from the accompanying S3 bucket before running.""" + + connection: sqlite3.Connection + """Connection to the database.""" + + site_data_query = CosmosQuery.SQLITE_LOOPED_DATA + + site_id_query = CosmosQuery.SQLITE_SITE_IDS + + @staticmethod + def _get_connection(*args) -> sqlite3.Connection: + """Gets a database connection.""" + + return sqlite3.connect(*args) + + def __init__(self, db_file: str | Path): + """Initialises the database object. + + Args: + csv_file: A pathlike object pointing to the datafile. + """ + LoopingCsvDB.__init__(self, db_file) + + self.cursor = self.connection.cursor() + + def query_latest_from_site(self, site_id: str, table: CosmosTable) -> dict: + """Queries the datbase for a `SITE_ID` incrementing by 1 each time called + for a specific site. If the end is reached, it loops back to the start. + + Args: + site_id: ID of the site to query for. + table: A valid table from the database + Returns: + A dict of the data row. 
+ """ + query = self._fill_query(self.site_data_query, table) + + if site_id not in self.cache: + self.cache[site_id] = 0 + else: + self.cache[site_id] += 1 + + data = self._query_latest_from_site( + query, {"site_id": site_id, "offset": self.cache[site_id]} + ) + + if data is None: + self.cache[site_id] = 0 + + data = self._query_latest_from_site( + query, {"site_id": site_id, "offset": self.cache[site_id]} + ) + + return data + + def _query_latest_from_site(self, query, arg_dict: dict) -> dict: + """Requests the latest data from a table for a specific site. + + Args: + table: A valid table from the database + arg_dict: Dictionary of query arguments. + + Returns: + dict | None: A dict containing the database columns as keys, and the values as values. + Returns `None` if no data retrieved. + """ + + self.cursor.execute(query, arg_dict) + + columns = [i[0] for i in self.cursor.description] + data = self.cursor.fetchone() + + if not data: + return None + + return dict(zip(columns, data)) + + def query_site_ids(self, table: CosmosTable, max_sites: int | None = None) -> list: + """query_site_ids returns a list of site IDs from COSMOS database + + Args: + table: A valid table from the database + max_sites: Maximum number of sites to retreive + + Returns: + List[str]: A list of site ID strings. 
+ """ + + query = self._fill_query(self.site_id_query, table) + + max_sites = self._validate_max_sites(max_sites) + + try: + cursor = self.connection.cursor() + cursor.execute(query) + + data = cursor.fetchall() if max_sites == 0: data = [x[0] for x in data] else: @@ -154,5 +395,7 @@ async def query_site_ids( if not data: data = [] + finally: + cursor.close() - return data + return data diff --git a/src/iotswarm/devices.py b/src/iotswarm/devices.py index 8c9c6ce..2b2cd1d 100644 --- a/src/iotswarm/devices.py +++ b/src/iotswarm/devices.py @@ -3,14 +3,14 @@ import asyncio import logging from iotswarm import __version__ as package_version -from iotswarm.queries import CosmosQuery -from iotswarm.db import BaseDatabase, Oracle -from iotswarm.messaging.core import MessagingBaseClass +from iotswarm.queries import CosmosTable +from iotswarm.db import BaseDatabase, CosmosDB, Oracle, LoopingCsvDB, LoopingSQLite3 +from iotswarm.messaging.core import MessagingBaseClass, MockMessageConnection from iotswarm.messaging.aws import IotCoreMQTTConnection -import random +from typing import List from datetime import datetime +import random import enum -from typing import List logger = logging.getLogger(__name__) @@ -45,8 +45,8 @@ class BaseDevice: connection: MessagingBaseClass """Connection to the data receiver.""" - query: CosmosQuery - """SQL query sent to database if Oracle type selected as `data_source`.""" + table: CosmosTable + """SQL table used in queries if Oracle or LoopingSQLite3 selected as `data_source`.""" mqtt_base_topic: str """Base topic for mqtt topic.""" @@ -85,7 +85,7 @@ def __init__( max_cycles: int | None = None, delay_start: bool | None = None, inherit_logger: logging.Logger | None = None, - query: CosmosQuery | None = None, + table: CosmosTable | None = None, mqtt_topic: str | None = None, mqtt_prefix: str | None = None, mqtt_suffix: str | None = None, @@ -100,7 +100,7 @@ def __init__( max_cycles: Maximum number of cycles before shutdown. 
delay_start: Adds a random delay to first invocation from 0 - `sleep_time`. inherit_logger: Override for the module logger. - query: Sets the query used in COSMOS database. Ignored if `data_source` is not a Cosmos object. + table: A valid table from the database mqtt_prefix: Prefixes the MQTT topic if MQTT messaging used. mqtt_suffix: Suffixes the MQTT topic if MQTT messaging used. """ @@ -111,10 +111,18 @@ def __init__( raise TypeError( f"`data_source` must be a `BaseDatabase`. Received: {data_source}." ) - if isinstance(data_source, Oracle) and query is None: - raise ValueError( - f"`query` must be provided if `data_source` is type `Oracle`." - ) + if isinstance(data_source, (CosmosDB)): + + if table is None: + raise ValueError( + f"`table` must be provided if `data_source` is type `OracleDB`." + ) + elif not isinstance(table, CosmosTable): + raise TypeError( + f'table must be a "{CosmosTable.__class__}", not "{type(table)}"' + ) + + self.table = table self.data_source = data_source if not isinstance(connection, MessagingBaseClass): @@ -144,18 +152,11 @@ def __init__( ) self.delay_start = delay_start - if query is not None and isinstance(self.data_source, Oracle): - if not isinstance(query, CosmosQuery): - raise TypeError( - f"`query` must be a `CosmosQuery`. Received: {type(query)}." 
- ) - self.query = query - - if isinstance(connection, IotCoreMQTTConnection): + if isinstance(connection, (IotCoreMQTTConnection, MockMessageConnection)): if mqtt_topic is not None: self.mqtt_topic = str(mqtt_topic) else: - self.mqtt_topic = f"{self.device_type}/{self.device_id}" + self.mqtt_topic = f"{self.device_id}" if mqtt_prefix is not None: self.mqtt_prefix = str(mqtt_prefix) @@ -190,16 +191,16 @@ def __repr__(self): if self.delay_start != self.__class__.delay_start else "" ) - query_arg = ( - f", query={self.query.__class__.__name__}.{self.query.name}" - if isinstance(self.data_source, Oracle) + table_arg = ( + f", table={self.table.__class__.__name__}.{self.table.name}" + if isinstance(self.data_source, CosmosDB) else "" ) mqtt_topic_arg = ( f', mqtt_topic="{self.mqtt_base_topic}"' if hasattr(self, "mqtt_base_topic") - and self.mqtt_base_topic != f"{self.device_type}/{self.device_id}" + and self.mqtt_base_topic != self.device_id else "" ) @@ -222,7 +223,7 @@ def __repr__(self): f"{sleep_time_arg}" f"{max_cycles_arg}" f"{delay_start_arg}" - f"{query_arg}" + f"{table_arg}" f"{mqtt_topic_arg}" f"{mqtt_prefix_arg}" f"{mqtt_suffix_arg}" @@ -243,7 +244,7 @@ def _send_payload(self, payload: dict): async def run(self): """The main invocation of the method. Expects a Oracle object to do work on - and a query to retrieve. Runs asynchronously until `max_cycles` is reached. + and a table to retrieve. Runs asynchronously until `max_cycles` is reached. 
Args: message_connection: The message object to send data through @@ -260,6 +261,9 @@ async def run(self): if payload: self._instance_logger.debug("Requesting payload submission.") self._send_payload(payload) + self._instance_logger.info( + f'Message sent{f" to topic: {self.mqtt_topic}" if self.mqtt_topic else ""}' + ) else: self._instance_logger.warning(f"No data found.") @@ -273,9 +277,12 @@ async def _get_payload(self): """Method for grabbing the payload to send""" if isinstance(self.data_source, Oracle): return await self.data_source.query_latest_from_site( - self.device_id, self.query + self.device_id, self.table ) - + elif isinstance(self.data_source, LoopingSQLite3): + return self.data_source.query_latest_from_site(self.device_id, self.table) + elif isinstance(self.data_source, LoopingCsvDB): + return self.data_source.query_latest_from_site(self.device_id) elif isinstance(self.data_source, BaseDatabase): return self.data_source.query_latest_from_site() diff --git a/src/iotswarm/messaging/aws.py b/src/iotswarm/messaging/aws.py index 8205964..be1a708 100644 --- a/src/iotswarm/messaging/aws.py +++ b/src/iotswarm/messaging/aws.py @@ -2,6 +2,7 @@ import awscrt from awscrt import mqtt +from awsiot import mqtt_connection_builder import awscrt.io import json from awscrt.exceptions import AwsCrtError @@ -34,6 +35,7 @@ def __init__( port: int | None = None, clean_session: bool = False, keep_alive_secs: int = 1200, + inherit_logger: logging.Logger | None = None, **kwargs, ) -> None: """Initializes the class. @@ -47,7 +49,7 @@ def __init__( port: Port used by endpoint. Guesses correct port if not given. clean_session: Builds a clean MQTT session if true. Defaults to False. keep_alive_secs: Time to keep connection alive. Defaults to 1200. - topic_prefix: A topic prefixed to MQTT topic, useful for attaching a "Basic Ingest" rule. Defaults to None. + inherit_logger: Override for the module logger. 
""" if not isinstance(endpoint, str): @@ -88,41 +90,31 @@ def __init__( if port < 0: raise ValueError(f"`port` cannot be less than 0. Received: {port}.") - socket_options = awscrt.io.SocketOptions() - socket_options.connect_timeout_ms = 5000 - socket_options.keep_alive = False - socket_options.keep_alive_timeout_secs = 0 - socket_options.keep_alive_interval_secs = 0 - socket_options.keep_alive_max_probes = 0 - - client_bootstrap = awscrt.io.ClientBootstrap.get_or_create_static_default() - - tls_ctx = awscrt.io.ClientTlsContext(tls_ctx_options) - mqtt_client = awscrt.mqtt.Client(client_bootstrap, tls_ctx) - - self.connection = awscrt.mqtt.Connection( - client=mqtt_client, + self.connection = mqtt_connection_builder.mtls_from_path( + endpoint=endpoint, + port=port, + cert_filepath=cert_path, + pri_key_filepath=key_path, + ca_filepath=ca_cert_path, on_connection_interrupted=self._on_connection_interrupted, on_connection_resumed=self._on_connection_resumed, client_id=client_id, - host_name=endpoint, - port=port, + proxy_options=None, clean_session=clean_session, - reconnect_min_timeout_secs=5, - reconnect_max_timeout_secs=60, keep_alive_secs=keep_alive_secs, - ping_timeout_ms=3000, - protocol_operation_timeout_ms=0, - socket_options=socket_options, - use_websockets=False, on_connection_success=self._on_connection_success, on_connection_failure=self._on_connection_failure, on_connection_closed=self._on_connection_closed, ) - self._instance_logger = logger.getChild( - f"{self.__class__.__name__}.client-{client_id}" - ) + if inherit_logger is not None: + self._instance_logger = inherit_logger.getChild( + f"{self.__class__.__name__}.client-{client_id}" + ) + else: + self._instance_logger = logger.getChild( + f"{self.__class__.__name__}.client-{client_id}" + ) def _on_connection_interrupted( self, connection, error, **kwargs @@ -182,23 +174,15 @@ def _disconnect(self): # pragma: no cover disconnect_future = self.connection.disconnect() disconnect_future.result() - def 
send_message( - self, message: dict, topic: str, use_logger: logging.Logger | None = None - ) -> None: + def send_message(self, message: dict, topic: str) -> None: """Sends a message to the endpoint. Args: message: The message to send. topic: MQTT topic to send message under. - use_logger: Sends log message with requested logger. """ - if use_logger is not None and isinstance(use_logger, logging.Logger): - use_logger = use_logger - else: - use_logger = self._instance_logger - if not message: - use_logger.error(f'No message to send for topic: "{topic}".') + self._instance_logger.error(f'No message to send for topic: "{topic}".') return if self.connected_flag == False: @@ -212,6 +196,4 @@ def send_message( qos=mqtt.QoS.AT_LEAST_ONCE, ) - use_logger.info(f'Sent {sys.getsizeof(payload)} bytes to "{topic}"') - - # self._disconnect() + self._instance_logger.debug(f'Sent {sys.getsizeof(payload)} bytes to "{topic}"') diff --git a/src/iotswarm/messaging/core.py b/src/iotswarm/messaging/core.py index d362b30..0957d93 100644 --- a/src/iotswarm/messaging/core.py +++ b/src/iotswarm/messaging/core.py @@ -2,8 +2,6 @@ from abc import abstractmethod import logging -logger = logging.getLogger(__name__) - class MessagingBaseClass(ABC): """MessagingBaseClass Base class for messaging implementation @@ -14,9 +12,20 @@ class MessagingBaseClass(ABC): _instance_logger: logging.Logger """Logger handle used by instance.""" - def __init__(self): - - self._instance_logger = logger.getChild(self.__class__.__name__) + def __init__( + self, + inherit_logger: logging.Logger | None = None, + ): + """Initialises the class. + Args: + inherit_logger: Override for the module logger. 
+ """ + if inherit_logger is not None: + self._instance_logger = inherit_logger.getChild(self.__class__.__name__) + else: + self._instance_logger = logging.getLogger(__name__).getChild( + self.__class__.__name__ + ) @property @abstractmethod @@ -37,15 +46,7 @@ class MockMessageConnection(MessagingBaseClass): connection: None = None """Connection object. Not needed in a mock but must be implemented""" - def send_message(self, use_logger: logging.Logger | None = None): - """Consumes requests to send a message but does nothing with it. - - Args: - use_logger: Sends log message with requested logger.""" - - if use_logger is not None and isinstance(use_logger, logging.Logger): - use_logger = use_logger - else: - use_logger = self._instance_logger + def send_message(self, *_): + """Consumes requests to send a message but does nothing with it.""" - use_logger.info("Message was sent.") + self._instance_logger.debug("Message was sent.") diff --git a/src/iotswarm/queries.py b/src/iotswarm/queries.py index f1174b2..dfda5c8 100644 --- a/src/iotswarm/queries.py +++ b/src/iotswarm/queries.py @@ -5,112 +5,60 @@ @enum.unique -class CosmosQuery(StrEnum): - """Class containing permitted SQL queries for retrieving sensor data.""" +class CosmosTable(StrEnum): + """Enums of approved COSMOS database tables.""" - LEVEL_1_SOILMET_30MIN = """SELECT * FROM COSMOS.LEVEL1_SOILMET_30MIN -WHERE site_id = :mysite -ORDER BY date_time DESC -FETCH NEXT 1 ROWS ONLY""" + LEVEL_1_SOILMET_30MIN = "LEVEL1_SOILMET_30MIN" + LEVEL_1_NMDB_1HOUR = "LEVEL1_NMDB_1HOUR" + LEVEL_1_PRECIP_1MIN = "LEVEL1_PRECIP_1MIN" + LEVEL_1_PRECIP_RAINE_1MIN = "LEVEL1_PRECIP_RAINE_1MIN" - """Query for retreiving data from the LEVEL1_SOILMET_30MIN table, containing - calibration the core telemetry from COSMOS sites. - - .. 
code-block:: sql - SELECT * FROM COSMOS.LEVEL1_SOILMET_30MIN - WHERE site_id = :mysite - ORDER BY date_time DESC - FETCH NEXT 1 ROWS ONLY - """ +@enum.unique +class CosmosQuery(StrEnum): + """Enums of common queries in each databasing language.""" - LEVEL_1_NMDB_1HOUR = """SELECT * FROM COSMOS.LEVEL1_NMDB_1HOUR -WHERE site_id = :mysite -ORDER BY date_time DESC -FETCH NEXT 1 ROWS ONLY""" + SQLITE_LOOPED_DATA = """SELECT * FROM {table} +WHERE site_id = :site_id +LIMIT 1 OFFSET :offset""" - """Query for retreiving data from the Level1_NMDB_1HOUR table, containing - calibration data from the Neutron Monitor DataBase. + """Query for retrieving data from a given table in sqlite format. .. code-block:: sql - SELECT * FROM COSMOS.LEVEL1_NMDB_1HOUR - WHERE site_id = :mysite - ORDER BY date_time DESC - FETCH NEXT 1 ROWS ONLY + SELECT * FROM + WHERE site_id = :site_id + LIMIT 1 OFFSET :offset """ - LEVEL_1_PRECIP_1MIN = """SELECT * FROM COSMOS.LEVEL1_PRECIP_1MIN -WHERE site_id = :mysite + ORACLE_LATEST_DATA = """SELECT * FROM COSMOS.{table} +WHERE site_id = :site_id ORDER BY date_time DESC FETCH NEXT 1 ROWS ONLY""" - """Query for retreiving data from the LEVEL1_PRECIP_1MIN table, containing - the standard precipitation telemetry from COSMOS sites. + """Query for retrieving data from a given table in oracle format. .. code-block:: sql - SELECT * FROM COSMOS.LEVEL1_PRECIP_1MIN - WHERE site_id = :mysite + SELECT * FROM
ORDER BY date_time DESC FETCH NEXT 1 ROWS ONLY """ - LEVEL_1_PRECIP_RAINE_1MIN = """SELECT * FROM COSMOS.LEVEL1_PRECIP_RAINE_1MIN -WHERE site_id = :mysite -ORDER BY date_time DESC -FETCH NEXT 1 ROWS ONLY""" - - """Query for retreiving data from the LEVEL1_PRECIP_RAINE_1MIN table, containing - the rain[e] precipitation telemetry from COSMOS sites. - - .. code-block:: sql - - SELECT * FROM COSMOS.LEVEL1_PRECIP_RAINE_1MIN - WHERE site_id = :mysite - ORDER BY date_time DESC - FETCH NEXT 1 ROWS ONLY - """ - - -@enum.unique -class CosmosSiteQuery(StrEnum): - """Contains permitted SQL queries for extracting site IDs from database.""" - - LEVEL_1_SOILMET_30MIN = "SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_SOILMET_30MIN" - - """Queries unique site IDs from LEVEL1_SOILMET_30MIN. - - .. code-block:: sql - - SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_SOILMET_30MIN - """ - - LEVEL_1_NMDB_1HOUR = "SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_NMDB_1HOUR" - - """Queries unique site IDs from LEVEL1_NMDB_1HOUR table. - - .. code-block:: sql - - SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_NMDB_1HOUR - """ - - LEVEL_1_PRECIP_1MIN = "SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_PRECIP_1MIN" + SQLITE_SITE_IDS = "SELECT DISTINCT(site_id) FROM {table}" - """Queries unique site IDs from the LEVEL1_PRECIP_1MIN table. + """Queries unique `site_id` values from a given table. .. code-block:: sql - SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_PRECIP_1MIN + SELECT DISTINCT(site_id) FROM
""" - LEVEL_1_PRECIP_RAINE_1MIN = ( - "SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_PRECIP_RAINE_1MIN" - ) + ORACLE_SITE_IDS = "SELECT UNIQUE(site_id) FROM COSMOS.{table}" - """Queries unique site IDs from the LEVEL1_PRECIP_RAINE_1MIN table. + """Queries unique `site_id` values from a given table. .. code-block:: sql - SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_PRECIP_RAINE_1MIN + SELECT UNIQUE(site_id) FROM
""" diff --git a/src/iotswarm/scripts/cli.py b/src/iotswarm/scripts/cli.py index 7d79274..912a7be 100644 --- a/src/iotswarm/scripts/cli.py +++ b/src/iotswarm/scripts/cli.py @@ -1,24 +1,19 @@ """CLI exposed when the package is installed.""" import click -from iotswarm import queries +from iotswarm import __version__ as package_version +from iotswarm.queries import CosmosTable from iotswarm.devices import BaseDevice, CR1000XDevice from iotswarm.swarm import Swarm -from iotswarm.db import Oracle +from iotswarm.db import Oracle, LoopingCsvDB, LoopingSQLite3 from iotswarm.messaging.core import MockMessageConnection from iotswarm.messaging.aws import IotCoreMQTTConnection +import iotswarm.scripts.common as cli_common import asyncio from pathlib import Path import logging -TABLES = [table.name for table in queries.CosmosQuery] - - -@click.command -@click.pass_context -def test(ctx: click.Context): - """Enables testing of cosmos group arguments.""" - print(ctx.obj) +TABLE_NAMES = [table.name for table in CosmosTable] @click.group() @@ -27,18 +22,64 @@ def test(ctx: click.Context): "--log-config", type=click.Path(exists=True), help="Path to a logging config file. 
Uses default if not given.", + default=Path(Path(__file__).parents[1], "__assets__", "loggers.ini"), ) -def main(ctx: click.Context, log_config: Path): +@click.option( + "--log-level", + type=click.Choice( + ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False + ), + help="Overrides the logging level.", + envvar="IOT_SWARM_LOG_LEVEL", +) +def main(ctx: click.Context, log_config: Path, log_level: str): """Core group of the cli.""" ctx.ensure_object(dict) - if not log_config: - log_config = Path(Path(__file__).parents[1], "__assets__", "loggers.ini") - logging.config.fileConfig(fname=log_config) + logger = logging.getLogger(__name__) + + if log_level: + logger.setLevel(log_level) + click.echo(f"Set log level to {log_level}.") + + ctx.obj["logger"] = logger + + +main.add_command(cli_common.test) + + +@main.command +def get_version(): + """Gets the package version""" + click.echo(package_version) + +@main.command() +@click.pass_context +@cli_common.iotcore_options +def test_mqtt( + ctx, + message, + topic, + client_id: str, + endpoint: str, + cert_path: str, + key_path: str, + ca_cert_path: str, +): + """Tests that a basic message can be sent via mqtt.""" + + connection = IotCoreMQTTConnection( + endpoint=endpoint, + cert_path=cert_path, + key_path=key_path, + ca_cert_path=ca_cert_path, + client_id=client_id, + inherit_logger=ctx.obj["logger"], + ) -main.add_command(test) + connection.send_message(message, topic) @main.group() @@ -78,113 +119,262 @@ def cosmos(ctx: click.Context, site: str, dsn: str, user: str, password: str): ctx.obj["sites"] = site -cosmos.add_command(test) +cosmos.add_command(cli_common.test) @cosmos.command() @click.pass_context -@click.argument("query", type=click.Choice(TABLES)) +@click.argument("table", type=click.Choice(TABLE_NAMES)) @click.option("--max-sites", type=click.IntRange(min=0), default=0) -def list_sites(ctx, query, max_sites): - """Lists site IDs from the database from table QUERY.""" +def list_sites(ctx, table, 
max_sites): + """Lists unique `site_id` from an oracle database table.""" - async def _list_sites(ctx, query): + async def _list_sites(): oracle = await Oracle.create( dsn=ctx.obj["credentials"]["dsn"], user=ctx.obj["credentials"]["user"], password=ctx.obj["credentials"]["password"], + inherit_logger=ctx.obj["logger"], ) - sites = await oracle.query_site_ids( - queries.CosmosSiteQuery[query], max_sites=max_sites - ) + sites = await oracle.query_site_ids(CosmosTable[table], max_sites=max_sites) return sites - click.echo(asyncio.run(_list_sites(ctx, query))) + click.echo(asyncio.run(_list_sites())) @cosmos.command() @click.pass_context @click.argument( - "provider", - type=click.Choice(["aws"]), -) -@click.argument( - "query", - type=click.Choice(TABLES), -) -@click.argument( - "client-id", - type=click.STRING, - required=True, + "table", + type=click.Choice(TABLE_NAMES), ) +@cli_common.device_options +@cli_common.iotcore_options +@click.option("--dry", is_flag=True, default=False, help="Doesn't send out any data.") +def mqtt( + ctx, + table, + endpoint, + cert_path, + key_path, + ca_cert_path, + client_id, + sleep_time, + max_cycles, + max_sites, + swarm_name, + delay_start, + mqtt_prefix, + mqtt_suffix, + dry, + device_type, +): + """Sends The cosmos data via MQTT protocol using IoT Core. + Data is from the cosmos database TABLE and sent using CLIENT_ID. 
+ + Currently only supports sending through AWS IoT Core.""" + table = CosmosTable[table] + + async def _mqtt(): + oracle = await Oracle.create( + dsn=ctx.obj["credentials"]["dsn"], + user=ctx.obj["credentials"]["user"], + password=ctx.obj["credentials"]["password"], + inherit_logger=ctx.obj["logger"], + ) + + sites = ctx.obj["sites"] + if len(sites) == 0: + sites = await oracle.query_site_ids(table, max_sites=max_sites) + + if dry == True: + connection = MockMessageConnection(inherit_logger=ctx.obj["logger"]) + else: + connection = IotCoreMQTTConnection( + endpoint=endpoint, + cert_path=cert_path, + key_path=key_path, + ca_cert_path=ca_cert_path, + client_id=client_id, + inherit_logger=ctx.obj["logger"], + ) + + if device_type == "basic": + DeviceClass = BaseDevice + elif device_type == "cr1000x": + DeviceClass = CR1000XDevice + + site_devices = [ + DeviceClass( + site, + oracle, + connection, + sleep_time=sleep_time, + table=table, + max_cycles=max_cycles, + delay_start=delay_start, + mqtt_prefix=mqtt_prefix, + mqtt_suffix=mqtt_suffix, + inherit_logger=ctx.obj["logger"], + ) + for site in sites + ] + + swarm = Swarm(site_devices, swarm_name) + + await swarm.run() + + asyncio.run(_mqtt()) + + +@main.group() +@click.pass_context @click.option( - "--endpoint", + "--site", type=click.STRING, - required=True, - envvar="IOT_SWARM_MQTT_ENDPOINT", - help="Endpoint of the MQTT receiving host.", + multiple=True, + help="Adds a site to be initialized. Can be invoked multiple times for other sites." + " Grabs all sites from database query if none provided", ) @click.option( - "--cert-path", + "--file", type=click.Path(exists=True), required=True, - envvar="IOT_SWARM_MQTT_CERT_PATH", - help="Path to public key certificate for the device. 
Must match key assigned to the `--client-id` in the cloud provider.", + envvar="IOT_SWARM_CSV_DB", + help="*.csv file used to instantiate a pandas database.", ) +def looping_csv(ctx, site, file): + """Instantiates a pandas dataframe from a csv file which is used as the database. + Responsibility falls on the user to ensure the correct file is selected.""" + + ctx.obj["db"] = LoopingCsvDB(file) + ctx.obj["sites"] = site + + +looping_csv.add_command(cli_common.test) +looping_csv.add_command(cli_common.list_sites) + + +@looping_csv.command() +@click.pass_context +@cli_common.device_options +@cli_common.iotcore_options +@click.option("--dry", is_flag=True, default=False, help="Doesn't send out any data.") +def mqtt( + ctx, + endpoint, + cert_path, + key_path, + ca_cert_path, + client_id, + sleep_time, + max_cycles, + max_sites, + swarm_name, + delay_start, + mqtt_prefix, + mqtt_suffix, + dry, + device_type, +): + """Sends The cosmos data via MQTT protocol using IoT Core. + Data is collected from the db using QUERY and sent using CLIENT_ID. 
+ + Currently only supports sending through AWS IoT Core.""" + + async def _mqtt(): + + sites = ctx.obj["sites"] + db = ctx.obj["db"] + if len(sites) == 0: + sites = db.query_site_ids(max_sites=max_sites) + + if dry == True: + connection = MockMessageConnection(inherit_logger=ctx.obj["logger"]) + else: + connection = IotCoreMQTTConnection( + endpoint=endpoint, + cert_path=cert_path, + key_path=key_path, + ca_cert_path=ca_cert_path, + client_id=client_id, + inherit_logger=ctx.obj["logger"], + ) + + if device_type == "basic": + DeviceClass = BaseDevice + elif device_type == "cr1000x": + DeviceClass = CR1000XDevice + + site_devices = [ + DeviceClass( + site, + db, + connection, + sleep_time=sleep_time, + max_cycles=max_cycles, + delay_start=delay_start, + mqtt_prefix=mqtt_prefix, + mqtt_suffix=mqtt_suffix, + inherit_logger=ctx.obj["logger"], + ) + for site in sites + ] + + swarm = Swarm(site_devices, swarm_name) + + await swarm.run() + + asyncio.run(_mqtt()) + + +@main.group() +@click.pass_context @click.option( - "--key-path", - type=click.Path(exists=True), - required=True, - envvar="IOT_SWARM_MQTT_KEY_PATH", - help="Path to the private key that pairs with the `--cert-path`.", + "--site", + type=click.STRING, + multiple=True, + help="Adds a site to be initialized. Can be invoked multiple times for other sites." + " Grabs all sites from database query if none provided", ) @click.option( - "--ca-cert-path", + "--file", type=click.Path(exists=True), required=True, - envvar="IOT_SWARM_MQTT_CA_CERT_PATH", - help="Path to the root Certificate Authority (CA) for the MQTT host.", -) -@click.option( - "--sleep-time", - type=click.INT, - help="The number of seconds each site goes idle after sending a message.", -) -@click.option( - "--max-cycles", - type=click.IntRange(0), - help="Maximum number message sending cycles. Runs forever if set to 0.", -) -@click.option( - "--max-sites", - type=click.IntRange(0), - help="Maximum number of sites allowed to initialize. 
No limit if set to 0.", -) -@click.option( - "--swarm-name", type=click.STRING, help="Name given to swarm. Appears in the logs." + envvar="IOT_SWARM_LOCAL_DB", + help="*.db file used to instantiate a sqlite3 database.", ) -@click.option( - "--delay-start", - is_flag=True, - default=False, - help="Adds a random delay before the first message from each site up to `--sleep-time`.", -) -@click.option( - "--mqtt-prefix", - type=click.STRING, - help="Prefixes the MQTT topic with a string. Can augment the calculated MQTT topic returned by each site.", -) -@click.option( - "--mqtt-suffix", - type=click.STRING, - help="Suffixes the MQTT topic with a string. Can augment the calculated MQTT topic returned by each site.", +def looping_sqlite3(ctx, site, file): + """Instantiates a sqlite3 database as sensor source..""" + ctx.obj["db"] = LoopingSQLite3(file) + ctx.obj["sites"] = site + + +looping_sqlite3.add_command(cli_common.test) + + +@looping_sqlite3.command +@click.pass_context +@click.option("--max-sites", type=click.IntRange(min=0), default=0) +@click.argument( + "table", + type=click.Choice(TABLE_NAMES), ) +def list_sites(ctx, max_sites, table): + """Prints the sites present in database.""" + sites = ctx.obj["db"].query_site_ids(table, max_sites=max_sites) + click.echo(sites) + + +@looping_sqlite3.command() +@click.pass_context +@cli_common.device_options +@cli_common.iotcore_options +@click.argument("table", type=click.Choice(TABLE_NAMES)) @click.option("--dry", is_flag=True, default=False, help="Doesn't send out any data.") -@click.option("--device-type", type=click.Choice(["basic", "cr1000x"]), default="basic") def mqtt( ctx, - provider, - query, + table, endpoint, cert_path, key_path, @@ -200,34 +390,30 @@ def mqtt( dry, device_type, ): - """Sends The cosmos data via MQTT protocol using PROVIDER. + """Sends The cosmos data via MQTT protocol using IoT Core. Data is collected from the db using QUERY and sent using CLIENT_ID. 
Currently only supports sending through AWS IoT Core.""" - async def _mqtt(): - oracle = await Oracle.create( - dsn=ctx.obj["credentials"]["dsn"], - user=ctx.obj["credentials"]["user"], - password=ctx.obj["credentials"]["password"], - ) + table = CosmosTable[table] - data_query = queries.CosmosQuery[query] - site_query = queries.CosmosSiteQuery[query] + async def _mqtt(): sites = ctx.obj["sites"] + db = ctx.obj["db"] if len(sites) == 0: - sites = await oracle.query_site_ids(site_query, max_sites=max_sites) + sites = db.query_site_ids(table, max_sites=max_sites) if dry == True: - connection = MockMessageConnection() - elif provider == "aws": + connection = MockMessageConnection(inherit_logger=ctx.obj["logger"]) + else: connection = IotCoreMQTTConnection( endpoint=endpoint, cert_path=cert_path, key_path=key_path, ca_cert_path=ca_cert_path, client_id=client_id, + inherit_logger=ctx.obj["logger"], ) if device_type == "basic": @@ -238,14 +424,15 @@ async def _mqtt(): site_devices = [ DeviceClass( site, - oracle, + db, connection, sleep_time=sleep_time, - query=data_query, max_cycles=max_cycles, delay_start=delay_start, mqtt_prefix=mqtt_prefix, mqtt_suffix=mqtt_suffix, + table=table, + inherit_logger=ctx.obj["logger"], ) for site in sites ] diff --git a/src/iotswarm/scripts/common.py b/src/iotswarm/scripts/common.py new file mode 100644 index 0000000..1b8b1e0 --- /dev/null +++ b/src/iotswarm/scripts/common.py @@ -0,0 +1,113 @@ +"""Location for common CLI commands""" + +import click + + +def device_options(function): + click.option( + "--sleep-time", + type=click.INT, + help="The number of seconds each site goes idle after sending a message.", + )(function) + + click.option( + "--max-cycles", + type=click.IntRange(0), + help="Maximum number message sending cycles. Runs forever if set to 0.", + )(function) + + click.option( + "--max-sites", + type=click.IntRange(0), + help="Maximum number of sites allowed to initialize. 
No limit if set to 0.", + )(function) + + click.option( + "--swarm-name", + type=click.STRING, + help="Name given to swarm. Appears in the logs.", + )(function) + + click.option( + "--delay-start", + is_flag=True, + default=False, + help="Adds a random delay before the first message from each site up to `--sleep-time`.", + )(function) + + click.option( + "--device-type", type=click.Choice(["basic", "cr1000x"]), default="basic" + )(function) + + return function + + +def iotcore_options(function): + click.argument( + "client-id", + type=click.STRING, + required=True, + )(function) + + click.option( + "--endpoint", + type=click.STRING, + required=True, + envvar="IOT_SWARM_MQTT_ENDPOINT", + help="Endpoint of the MQTT receiving host.", + )(function) + + click.option( + "--cert-path", + type=click.Path(exists=True), + required=True, + envvar="IOT_SWARM_MQTT_CERT_PATH", + help="Path to public key certificate for the device. Must match key assigned to the `--client-id` in the cloud provider.", + )(function) + + click.option( + "--key-path", + type=click.Path(exists=True), + required=True, + envvar="IOT_SWARM_MQTT_KEY_PATH", + help="Path to the private key that pairs with the `--cert-path`.", + )(function) + + click.option( + "--ca-cert-path", + type=click.Path(exists=True), + required=True, + envvar="IOT_SWARM_MQTT_CA_CERT_PATH", + help="Path to the root Certificate Authority (CA) for the MQTT host.", + )(function) + + click.option( + "--mqtt-prefix", + type=click.STRING, + help="Prefixes the MQTT topic with a string. Can augment the calculated MQTT topic returned by each site.", + )(function) + + click.option( + "--mqtt-suffix", + type=click.STRING, + help="Suffixes the MQTT topic with a string. 
Can augment the calculated MQTT topic returned by each site.", + )(function) + + return function + + +@click.command +@click.pass_context +@click.option("--max-sites", type=click.IntRange(min=0), default=0) +def list_sites(ctx, max_sites): + """Prints the sites present in database.""" + + sites = ctx.obj["db"].query_site_ids(max_sites=max_sites) + click.echo(sites) + + +@click.command +@click.pass_context +def test(ctx: click.Context): + """Enables testing of cosmos group arguments.""" + print(ctx.obj) diff --git a/src/iotswarm/utils.py b/src/iotswarm/utils.py index 6d9d988..3a935eb 100644 --- a/src/iotswarm/utils.py +++ b/src/iotswarm/utils.py @@ -1,6 +1,10 @@ """Module for handling commonly reused utility functions.""" from datetime import date, datetime +from pathlib import Path +import pandas +import sqlite3 +from glob import glob def json_serial(obj: object): @@ -13,3 +17,53 @@ def json_serial(obj: object): return obj.__json__() raise TypeError(f"Type {type(obj)} is not serializable.") + + +def build_database_from_csv( + csv_file: str | Path, + database: str | Path, + table_name: str, + sort_by: str | None = None, + date_time_format: str = r"%d-%b-%y %H.%M.%S", +) -> None: + """Adds a database table using a csv file with headers. + + Args: + csv_file: A path to the csv. + database: Output destination of the database. File is created if not + existing. + table_name: Name of the table to add into database. 
+ sort_by: Column to sort by + date_time_format: Format of datetime column + """ + + if not isinstance(csv_file, Path): + csv_file = Path(csv_file) + + if not isinstance(database, Path): + database = Path(database) + + if not csv_file.exists(): + raise FileNotFoundError(f'csv_file does not exist: "{csv_file}"') + + if not database.parent.exists(): + raise NotADirectoryError(f'Database directory not found: "{database.parent}"') + + with sqlite3.connect(database) as conn: + print( + f'Writing table: "{table_name}" from csv_file: "{csv_file}" to db: "{database}"' + ) + print("Loading csv") + df = pandas.read_csv(csv_file) + print("Done") + print("Formatting dates") + df["DATE_TIME"] = pandas.to_datetime(df["DATE_TIME"], format=date_time_format) + print("Done") + if sort_by is not None: + print("Sorting.") + df = df.sort_values(by=sort_by) + print("Done") + + print("Writing to db.") + df.to_sql(table_name, conn, if_exists="replace", index=False) + print("Writing complete.") diff --git a/src/tests/data/ALIC1_4_ROWS.csv b/src/tests/data/ALIC1_4_ROWS.csv new file mode 100644 index 0000000..55659cd --- /dev/null +++ b/src/tests/data/ALIC1_4_ROWS.csv @@ -0,0 +1,5 @@ 
+"SITE_ID","DATE_TIME","WD","WD_STD","WS","WS_STD","PA","PA_STD","RH","RH_STD","TA","TA_STD","Q","Q_STD","SWIN","SWIN_STD","SWOUT","SWOUT_STD","LWIN_UNC","LWIN_UNC_STD","LWOUT_UNC","LWOUT_UNC_STD","LWIN","LWIN_STD","LWOUT","LWOUT_STD","TNR01C","TNR01C_STD","PRECIP","PRECIP_DIAG","SNOWD_DISTANCE_UNC","SNOWD_SIGNALQUALITY","CTS_MOD","CTS_BARE","CTS_MOD2","CTS_SNOW","PROFILE_VWC15","PROFILE_SOILEC15","PROFILE_VWC40","PROFILE_SOILEC40","PROFILE_VWC65","PROFILE_SOILEC65","TDT1_VWC","TDT1_TSOIL","TDT1_SOILPERM","TDT1_SOILEC","TDT2_VWC","TDT2_TSOIL","TDT2_SOILPERM","TDT2_SOILEC","STP_TSOIL50","STP_TSOIL20","STP_TSOIL10","STP_TSOIL5","STP_TSOIL2","G1","G1_STD","G1_MV","G1_CAL","G2","G2_STD","G2_MV","G2_CAL","BATTV","SCANS","RECORD","CTS_MOD_DIAG_1","CTS_MOD_DIAG_2","CTS_MOD_DIAG_3","CTS_MOD_DIAG_4","CTS_MOD_DIAG_5","UX","UZ","UY","STDEV_UX","STDEV_UY","STDEV_UZ","COV_UX_UY","COV_UX_UZ","COV_UY_UZ","CTS_SNOW_DIAG_1","CTS_SNOW_DIAG_2","CTS_SNOW_DIAG_3","CTS_SNOW_DIAG_4","CTS_SNOW_DIAG_5","SNOWD_DISTANCE_COR","WS_RES","HS","TAU","U_STAR","TS","STDEV_TS","COV_TS_UX","COV_TS_UY","COV_TS_UZ","RHO_A_MEAN","E_SAT_MEAN","METPAK_SAMPLES","CTS_MOD_RH","CTS_MOD_TEMP","TDT3_VWC","TDT3_TSOIL","TDT3_SOILEC","TDT3_SOILPERM","TDT4_VWC","TDT4_TSOIL","TDT4_SOILEC","TDT4_SOILPERM","TDT5_VWC","TDT5_TSOIL","TDT5_SOILEC","TDT5_SOILPERM","TDT6_VWC","TDT6_TSOIL","TDT6_SOILEC","TDT6_SOILPERM","TDT7_VWC","TDT7_TSOIL","TDT7_SOILEC","TDT7_SOILPERM","TDT8_VWC","TDT8_TSOIL","TDT8_SOILEC","TDT8_SOILPERM","TDT9_VWC","TDT9_TSOIL","TDT9_SOILEC","TDT9_SOILPERM","TDT10_VWC","TDT10_TSOIL","TDT10_SOILEC","TDT10_SOILPERM","CTS_SNOW_TEMP","CTS_SNOW_RH","CTS_MOD2_RH","CTS_MOD2_TEMP","SWIN_MULT","SWOUT_MULT","LWIN_MULT","LWOUT_MULT","PRECIP_TIPPING_A","PRECIP_TIPPING_B","PRECIP_TIPPING2","TBR_TIP","CTS_MOD_PERIOD","CTS_MOD2_PERIOD","CTS_SNOW_PERIOD","WM_NAN_COUNT","RH_INTERNAL" 
+"ALIC1",01-JUN-24,353.4,28.38,2.397,1.309,1010.926,0.044,72.58,0.641,12.34,0.05,7.891,,-0.941,0.245,0.454,0.097,-17.14,2.736,-0.303,0.172,359.1,2.68,376,0.191,12.27,0.022,0,0,,,519,,,,,,,,,,43.84,12.72,27.98,0.48,35.03,12.61,20.29,0.24,11.91,12.37,12.45,12.45,12.44,-3.3016,0.10141,-0.15144,21.8,-3.27326,0.1105,-0.14011,23.36,13.21,180,11153,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.23,1.433,,2.1,18.9,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,60.024,52.2193,85.4701,90.4977,,,,,1800,,,99999.99999, +"ALIC1",01-JUN-24,358,29.12,2.045,1.015,1011.057,0.049,70.59,0.34,12.23,0.139,7.623,,-2.57,1.666,0.552,0.318,-45.27,28,-1.624,1.119,330.1,28.91,373.8,2.085,12.1,0.203,0,0,,,536,,,,,,,,,,43.61,12.72,27.77,0.51,34.91,12.59,20.2,0.27,11.91,12.36,12.43,12.42,12.42,-2.78394,0.58028,-0.12777,21.77,-3.0528,0.35922,-0.13067,23.36,13.19,180,11154,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.231,1.424,,2.1,18.8,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,60.024,52.2193,85.4701,90.4977,,,,,1800,,,99999.99999, +"ALIC1",01-JUN-24,354.8,26.46,1.531,0.902,1011.085,0.036,73,0.945,11.38,0.31,7.471,,-4.584,0.342,1.413,0.163,-76.38,4.869,-2.488,0.493,293,3.633,366.9,1.718,10.95,0.381,0,0,,,504,,,,,,,,,,43.16,12.7,27.35,0.54,34.91,12.59,20.2,0.25,11.89,12.34,12.41,12.39,12.38,-3.57444,0.69214,-0.16418,21.77,-3.62632,0.49502,-0.15522,23.36,13.22,180,11155,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.235,1.346,,2.1,18.6,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,60.024,52.2193,85.4701,90.4977,,,,,1799,,,99999.99999, +"ALIC1",01-JUN-24,323.9,47.78,0.916,0.541,1011.082,0.038,76.2,1.033,10.42,0.285,7.34,,-4.493,0.27,1.759,0.13,-75.77,0.74,-2.514,0.463,287.7,1.127,360.9,1.949,9.8,0.328,0,0,,,461,,,,,,,,,,43.33,12.68,27.51,0.51,34.75,12.58,20.07,0.26,11.9,12.35,12.4,12.37,12.36,-6.14165,0.82193,-0.2821,21.77,-5.52521,0.57313,-0.2365,23.36,13.22,180,11156,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.239,1.262,,2.1,18.4,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,60.025,52.2193,85.4701,90.4977,,,,,1799,,,99999.99999, \ No newline at end of file diff --git 
a/src/tests/data/build_database.py b/src/tests/data/build_database.py new file mode 100644 index 0000000..eef80be --- /dev/null +++ b/src/tests/data/build_database.py @@ -0,0 +1,22 @@ +"""This script generates a .db SQLite file from the accompanying .csv data file. +It contains just a few rows of data and is used for testing purposes only.""" + +from iotswarm.utils import build_database_from_csv +from pathlib import Path +from iotswarm.queries import CosmosTable + + +def main(): + data_dir = Path(__file__).parent + data_file = Path(data_dir, "ALIC1_4_ROWS.csv") + database_file = Path(data_dir, "database.db") + + data_table = CosmosTable.LEVEL_1_SOILMET_30MIN + + build_database_from_csv( + data_file, database_file, data_table.value, date_time_format=r"%d-%b-%y" + ) + + +if __name__ == "__main__": + main() diff --git a/src/tests/test_cli.py b/src/tests/test_cli.py new file mode 100644 index 0000000..c771148 --- /dev/null +++ b/src/tests/test_cli.py @@ -0,0 +1,63 @@ +from click.testing import CliRunner +from iotswarm.scripts import cli +from parameterized import parameterized +import re + +RUNNER = CliRunner() + + +def test_main_ctx(): + result = RUNNER.invoke(cli.main, ["test"]) + assert not result.exception + assert result.output == "{'logger': }\n" + + +def test_main_log_config(): + with RUNNER.isolated_filesystem(): + with open("logger.ini", "w") as f: + f.write( + """[loggers] +keys=root + +[handlers] +keys=consoleHandler + +[formatters] +keys=sampleFormatter + +[logger_root] +level=INFO +handlers=consoleHandler + +[handler_consoleHandler] +class=StreamHandler +level=INFO +formatter=sampleFormatter +args=(sys.stdout,) + +[formatter_sampleFormatter] +format=%(asctime)s - %(name)s - %(levelname)s - %(message)s""" + ) + result = RUNNER.invoke(cli.main, ["--log-config", "logger.ini", "test"]) + assert not result.exception + assert result.output == "{'logger': }\n" + + +@parameterized.expand(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]) +def 
test_log_level_set(log_level): + + result = RUNNER.invoke(cli.main, ["--log-level", log_level, "test"]) + assert not result.exception + assert ( + result.output + == f"Set log level to {log_level}.\n{{'logger': }}\n" + ) + + +def test_get_version(): + """Tests that the verison can be retrieved.""" + + result = RUNNER.invoke(cli.main, ["get-version"]) + print(result.output) + assert not result.exception + assert re.match(r"^\d+\.\d+\.\d+$", result.output) diff --git a/src/tests/test_db.py b/src/tests/test_db.py index 0cbb27f..34e8675 100644 --- a/src/tests/test_db.py +++ b/src/tests/test_db.py @@ -1,21 +1,29 @@ import unittest import pytest import config -import pathlib +from pathlib import Path from iotswarm import db -from iotswarm.queries import CosmosQuery, CosmosSiteQuery +from iotswarm.devices import BaseDevice +from iotswarm.messaging.core import MockMessageConnection +from iotswarm.queries import CosmosTable +from iotswarm.swarm import Swarm from parameterized import parameterized import logging from unittest.mock import patch +import pandas as pd +from glob import glob +from math import isnan +import sqlite3 -CONFIG_PATH = pathlib.Path( - pathlib.Path(__file__).parents[1], "iotswarm", "__assets__", "config.cfg" +CONFIG_PATH = Path( + Path(__file__).parents[1], "iotswarm", "__assets__", "config.cfg" ) config_exists = pytest.mark.skipif( not CONFIG_PATH.exists(), reason="Config file `config.cfg` not found in root directory.", ) +COSMOS_TABLES = list(CosmosTable) class TestBaseDatabase(unittest.TestCase): @patch.multiple(db.BaseDatabase, __abstractmethods__=set()) @@ -27,6 +35,7 @@ def test_instantiation(self): self.assertIsNone(inst.query_latest_from_site()) + class TestMockDB(unittest.TestCase): def test_instantiation(self): @@ -58,22 +67,23 @@ def test_logger_set(self, logger, expected): self.assertEqual(inst._instance_logger.parent, expected) - @parameterized.expand( - [ - [None, "MockDB()"], - [ - logging.getLogger("notroot"), - "MockDB(inherit_logger=)", 
- ], - ] - ) - def test__repr__(self, logger, expected): + + def test__repr__no_logger(self): - inst = db.MockDB(inherit_logger=logger) + inst = db.MockDB() + + self.assertEqual(inst.__repr__(), "MockDB()") - self.assertEqual(inst.__repr__(), expected) + def test__repr__logger_given(self): + logger = logging.getLogger("testdblogger") + logger.setLevel(logging.CRITICAL) + expected = "MockDB(inherit_logger=)" + mock = db.MockDB(inherit_logger=logger) + self.assertEqual(mock.__repr__(), expected) + + class TestOracleDB(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): @@ -86,6 +96,7 @@ async def asyncSetUp(self): creds["user"], password=creds["password"], ) + self.table = CosmosTable.LEVEL_1_SOILMET_30MIN async def asyncTearDown(self) -> None: await self.oracle.connection.close() @@ -103,25 +114,19 @@ async def test_instantiation(self): async def test_latest_data_query(self): site_id = "MORLY" - query = CosmosQuery.LEVEL_1_SOILMET_30MIN - row = await self.oracle.query_latest_from_site(site_id, query) + row = await self.oracle.query_latest_from_site(site_id, self.table) self.assertEqual(row["SITE_ID"], site_id) - @parameterized.expand([ - CosmosSiteQuery.LEVEL_1_NMDB_1HOUR, - CosmosSiteQuery.LEVEL_1_SOILMET_30MIN, - CosmosSiteQuery.LEVEL_1_PRECIP_1MIN, - CosmosSiteQuery.LEVEL_1_PRECIP_RAINE_1MIN, - ]) + @parameterized.expand(COSMOS_TABLES) @pytest.mark.oracle @pytest.mark.asyncio @pytest.mark.slow @config_exists - async def test_site_id_query(self,query): + async def test_site_id_query(self, table): - sites = await self.oracle.query_site_ids(query) + sites = await self.oracle.query_site_ids(table) self.assertIsInstance(sites, list) @@ -137,9 +142,7 @@ async def test_site_id_query(self,query): @config_exists async def test_site_id_query_max_sites(self, max_sites): - query = CosmosSiteQuery.LEVEL_1_SOILMET_30MIN - - sites = await self.oracle.query_site_ids(query, max_sites=max_sites) + sites = await self.oracle.query_site_ids(self.table, max_sites=max_sites) 
self.assertEqual(len(sites), max_sites) @@ -147,33 +150,33 @@ async def test_site_id_query_max_sites(self, max_sites): @pytest.mark.asyncio @pytest.mark.oracle @config_exists - async def test_bad_latest_data_query_type(self): + async def test_bad_latest_data_table_type(self): site_id = "MORLY" - query = "sql injection goes brr" + table = "sql injection goes brr" with self.assertRaises(TypeError): - await self.oracle.query_latest_from_site(site_id, query) + await self.oracle.query_latest_from_site(site_id, table) @pytest.mark.asyncio @pytest.mark.oracle @config_exists - async def test_bad_site_query_type(self): + async def test_bad_site_table_type(self): - query = "sql injection goes brr" + table = "sql injection goes brr" with self.assertRaises(TypeError): - await self.oracle.query_site_ids(query) + await self.oracle.query_site_ids(table) @parameterized.expand([-1, -100, "STRING"]) @pytest.mark.asyncio @pytest.mark.oracle @config_exists - async def test_bad_site_query_max_sites_type(self, max_sites): + async def test_bad_site_table_max_sites_type(self, max_sites): """Tests bad values for max_sites.""" with self.assertRaises((TypeError,ValueError)): - await self.oracle.query_site_ids(CosmosSiteQuery.LEVEL_1_SOILMET_30MIN, max_sites=max_sites) + await self.oracle.query_site_ids(self.table, max_sites=max_sites) @pytest.mark.asyncio @pytest.mark.oracle @@ -181,7 +184,7 @@ async def test_bad_site_query_max_sites_type(self, max_sites): async def test__repr__(self): """Tests string representation.""" - oracle1 = self.oracle = await db.Oracle.create( + oracle1 = await db.Oracle.create( self.creds["dsn"], self.creds["user"], password=self.creds["password"], @@ -192,7 +195,7 @@ async def test__repr__(self): oracle1.__repr__(), expected1 ) - oracle2 = self.oracle = await db.Oracle.create( + oracle2 = await db.Oracle.create( self.creds["dsn"], self.creds["user"], password=self.creds["password"], @@ -207,6 +210,293 @@ async def test__repr__(self): oracle2.__repr__(), expected2 
) +CSV_PATH = Path(Path(__file__).parents[1], "iotswarm", "__assets__", "data") +CSV_DATA_FILES = [Path(x) for x in glob(str(Path(CSV_PATH, "*.csv")))] +sqlite_db_exist = pytest.mark.skipif(not Path(CSV_PATH, "cosmos.db").exists(), reason="Local cosmos.db does not exist.") + +data_files_exist = pytest.mark.skipif( + not CSV_PATH.exists() or len(CSV_DATA_FILES) == 0, + reason="No data files are present" +) + +class TestLoopingCsvDB(unittest.TestCase): + """Tests the LoopingCsvDB class.""" + + def setUp(self): + self.data_path = {v.name.removesuffix("_DATA_TABLE.csv"):v for v in CSV_DATA_FILES} + + self.soilmet_table = db.LoopingCsvDB(self.data_path["LEVEL_1_SOILMET_30MIN"]) + self.maxDiff = None + + @data_files_exist + @pytest.mark.slow + def test_instantiation(self): + """Tests that the database can be instantiated.""" + + database = self.soilmet_table + + self.assertIsInstance(database, db.LoopingCsvDB) + self.assertIsInstance(database, db.BaseDatabase) + + + self.assertIsInstance(database.cache, dict) + self.assertIsInstance(database.connection, pd.DataFrame) + + @data_files_exist + @pytest.mark.slow + def test_site_data_return_value(self): + database = self.soilmet_table + + site = "MORLY" + + data = database.query_latest_from_site(site) + + expected_cache = {site: 1} + self.assertDictEqual(database.cache, expected_cache) + + self.assertIsInstance(data, dict) + + @data_files_exist + @pytest.mark.slow + def test_multiple_sites_added_to_cache(self): + sites = ["ALIC1", "MORLY", "HOLLN","EUSTN"] + + database = self.soilmet_table + + data = [database.query_latest_from_site(x) for x in sites] + + for i, site in enumerate(sites): + self.assertEqual(site, data[i]["SITE_ID"]) + self.assertIn(site, database.cache) + self.assertEqual(database.cache[site], 1) + + @data_files_exist + @pytest.mark.slow + def test_cache_incremented_on_each_request(self): + database = self.soilmet_table + + site = "MORLY" + + expected = 1 + + last_data = None + for _ in range(10): + data = 
database.query_latest_from_site(site) + self.assertNotEqual(last_data, data) + self.assertEqual(expected, database.cache[site]) + + last_data = data + expected += 1 + + self.assertEqual(expected, 11) + + @data_files_exist + @pytest.mark.slow + def test_cache_counter_restarts_at_end(self): + + short_table_path = Path(Path(__file__).parent, "data", "ALIC1_4_ROWS.csv") + database = db.LoopingCsvDB(short_table_path) + + site = "ALIC1" + + expected = [1,2,3,4,1] + data = [] + for e in expected: + data.append(database.query_latest_from_site(site)) + + self.assertEqual(database.cache[site], e) + + for key in data[0].keys(): + try: + self.assertEqual(data[0][key], data[-1][key]) + except AssertionError as err: + if not isnan(data[0][key]) and isnan(data[-1][key]): + raise(err) + + self.assertEqual(len(expected), len(data)) + + @data_files_exist + @pytest.mark.slow + def test_site_ids_can_be_retrieved(self): + database = self.soilmet_table + + site_ids_full = database.query_site_ids() + site_ids_exp_full = database.query_site_ids(max_sites=0) + + + self.assertIsInstance(site_ids_full, list) + + self.assertGreater(len(site_ids_full), 0) + for site in site_ids_full: + self.assertIsInstance(site, str) + + self.assertEqual(len(site_ids_full), len(site_ids_exp_full)) + + site_ids_limit = database.query_site_ids(max_sites=5) + + self.assertEqual(len(site_ids_limit), 5) + self.assertGreater(len(site_ids_full), len(site_ids_limit)) + + with self.assertRaises(ValueError): + + database.query_site_ids(max_sites=-1) + +class TestLoopingCsvDBEndToEnd(unittest.IsolatedAsyncioTestCase): + """Tests the LoopingCsvDB class.""" + + def setUp(self): + self.data_path = {v.name.removesuffix("_DATA_TABLE.csv"):v for v in CSV_DATA_FILES} + self.maxDiff = None + + @data_files_exist + @pytest.mark.slow + async def test_flow_with_device_attached(self): + """Tests that data is looped through with a device making requests.""" + + database = db.LoopingCsvDB(self.data_path["LEVEL_1_SOILMET_30MIN"]) + 
device = BaseDevice("ALIC1", database, MockMessageConnection(), sleep_time=0, max_cycles=5) + + await device.run() + + self.assertDictEqual(database.cache, {"ALIC1": 5}) + + @data_files_exist + @pytest.mark.slow + async def test_flow_with_swarm_attached(self): + """Tests that the database is looped through correctly with multiple sites in a swarm.""" + + database = db.LoopingCsvDB(self.data_path["LEVEL_1_SOILMET_30MIN"]) + sites = ["MORLY", "ALIC1", "EUSTN"] + cycles = [1, 4, 6] + devices = [ + BaseDevice(s, database, MockMessageConnection(), sleep_time=0, max_cycles=c) + for (s,c) in zip(sites, cycles) + ] + + swarm = Swarm(devices) + + await swarm.run() + + self.assertDictEqual(database.cache, {"MORLY": 1, "ALIC1": 4, "EUSTN": 6}) + +class TestSqliteDB(unittest.TestCase): + + @sqlite_db_exist + def setUp(self): + self.db_path = Path(Path(__file__).parents[1], "iotswarm", "__assets__", "data", "cosmos.db") + self.table = CosmosTable.LEVEL_1_SOILMET_30MIN + + if self.db_path.exists(): + self.database = db.LoopingSQLite3(self.db_path) + self.maxDiff = None + + @sqlite_db_exist + def test_instantiation(self): + self.assertIsInstance(self.database, db.LoopingSQLite3) + self.assertIsInstance(self.database.connection, sqlite3.Connection) + + @sqlite_db_exist + def test_latest_data(self): + + site_id = "MORLY" + + data = self.database.query_latest_from_site(site_id, self.table) + + self.assertIsInstance(data, dict) + + @sqlite_db_exist + def test_site_id_query(self): + + sites = self.database.query_site_ids(self.table) + + self.assertGreater(len(sites), 0) + + self.assertIsInstance(sites, list) + + for site in sites: + self.assertIsInstance(site, str) + + @sqlite_db_exist + def test_multiple_sites_added_to_cache(self): + sites = ["ALIC1", "MORLY", "HOLLN","EUSTN"] + + data = [self.database.query_latest_from_site(x, self.table) for x in sites] + + for i, site in enumerate(sites): + self.assertEqual(site, data[i]["SITE_ID"]) + self.assertIn(site, self.database.cache) + 
self.assertEqual(self.database.cache[site], 0) + + @sqlite_db_exist + def test_cache_incremented_on_each_request(self): + site = "MORLY" + + last_data = {} + for i in range(3): + if i == 0: + self.assertEqual(self.database.cache, {}) + else: + self.assertEqual(i-1, self.database.cache[site]) + data = self.database.query_latest_from_site(site, self.table) + self.assertNotEqual(last_data, data) + + last_data = data + + @sqlite_db_exist + def test_cache_counter_restarts_at_end(self): + database = db.LoopingSQLite3(Path(Path(__file__).parent, "data", "database.db")) + + site = "ALIC1" + + expected = [0,1,2,3,0] + data = [] + for e in expected: + data.append(database.query_latest_from_site(site, self.table)) + + self.assertEqual(database.cache[site], e) + + self.assertEqual(data[0], data[-1]) + + self.assertEqual(len(expected), len(data)) + +class TestLoopingSQLite3DBEndToEnd(unittest.IsolatedAsyncioTestCase): + """Tests the LoopingCsvDB class.""" + + @sqlite_db_exist + def setUp(self): + self.db_path = Path(Path(__file__).parents[1], "iotswarm", "__assets__", "data", "cosmos.db") + if self.db_path.exists(): + self.database = db.LoopingSQLite3(self.db_path) + self.maxDiff = None + self.table = CosmosTable.LEVEL_1_PRECIP_1MIN + + @sqlite_db_exist + async def test_flow_with_device_attached(self): + """Tests that data is looped through with a device making requests.""" + + device = BaseDevice("ALIC1", self.database, MockMessageConnection(), table=self.table, sleep_time=0, max_cycles=5) + + await device.run() + + self.assertDictEqual(self.database.cache, {"ALIC1": 4}) + + @sqlite_db_exist + async def test_flow_with_swarm_attached(self): + """Tests that the database is looped through correctly with multiple sites in a swarm.""" + + sites = ["MORLY", "ALIC1", "EUSTN"] + cycles = [1, 2, 3] + devices = [ + BaseDevice(s, self.database, MockMessageConnection(), sleep_time=0, max_cycles=c,table=self.table) + for (s,c) in zip(sites, cycles) + ] + + swarm = Swarm(devices) + + await 
swarm.run() + + self.assertDictEqual(self.database.cache, {"MORLY": 0, "ALIC1": 1, "EUSTN": 2}) + if __name__ == "__main__": unittest.main() diff --git a/src/tests/test_devices.py b/src/tests/test_devices.py index 1d68a52..efb61d3 100644 --- a/src/tests/test_devices.py +++ b/src/tests/test_devices.py @@ -5,24 +5,28 @@ import json from iotswarm.utils import json_serial from iotswarm.devices import BaseDevice, CR1000XDevice, CR1000XField -from iotswarm.db import Oracle, BaseDatabase, MockDB -from iotswarm.queries import CosmosQuery, CosmosSiteQuery +from iotswarm.db import Oracle, BaseDatabase, MockDB, LoopingSQLite3 +from iotswarm.queries import CosmosQuery, CosmosTable from iotswarm.messaging.core import MockMessageConnection, MessagingBaseClass from iotswarm.messaging.aws import IotCoreMQTTConnection from parameterized import parameterized from unittest.mock import patch -import pathlib +from pathlib import Path import config from datetime import datetime, timedelta -CONFIG_PATH = pathlib.Path( - pathlib.Path(__file__).parents[1], "iotswarm", "__assets__", "config.cfg" +CONFIG_PATH = Path( + Path(__file__).parents[1], "iotswarm", "__assets__", "config.cfg" ) config_exists = pytest.mark.skipif( not CONFIG_PATH.exists(), reason="Config file `config.cfg` not found in root directory.", ) +DATA_DIR = Path(Path(__file__).parents[1], "iotswarm", "__assets__", "data") +sqlite_db_exist = pytest.mark.skipif(not Path(DATA_DIR, "cosmos.db").exists(), reason="Local cosmos.db does not exist.") + + class TestBaseClass(unittest.IsolatedAsyncioTestCase): def setUp(self): @@ -169,26 +173,12 @@ def test__repr__(self, data_source, kwargs, expected): self.assertEqual(repr(instance), expected) - def test_query_not_set_for_non_oracle_db(self): + def test_table_not_set_for_non_oracle_db(self): - inst = BaseDevice("test", MockDB(), MockMessageConnection(), query=CosmosQuery.LEVEL_1_SOILMET_30MIN) + inst = BaseDevice("test", MockDB(), MockMessageConnection(), 
table=CosmosTable.LEVEL_1_NMDB_1HOUR) with self.assertRaises(AttributeError): - inst.query - - def test_prefix_suffix_not_set_for_non_mqtt(self): - "Tests that mqtt prefix and suffix not set for non MQTT messaging." - - inst = BaseDevice("site-1", self.data_source, MockMessageConnection(), mqtt_prefix="prefix", mqtt_suffix="suffix") - - with self.assertRaises(AttributeError): - inst.mqtt_topic - - with self.assertRaises(AttributeError): - inst.mqtt_prefix - - with self.assertRaises(AttributeError): - inst.mqtt_suffix + inst.table @parameterized.expand( [ @@ -204,19 +194,6 @@ def test_logger_set(self, logger, expected): self.assertEqual(inst._instance_logger.parent, expected) - @parameterized.expand([ - [None, None, None], - ["topic", None, None], - ["topic", "prefix", None], - ["topic", "prefix", "suffix"], - ]) - def test__repr__mqtt_opts_no_mqtt_connection(self, topic, prefix, suffix): - """Tests that the __repr__ method returns correctly with MQTT options set.""" - expected = 'BaseDevice("site-id", MockDB(), MockMessageConnection())' - inst = BaseDevice("site-id", MockDB(), MockMessageConnection(), mqtt_topic=topic, mqtt_prefix=prefix, mqtt_suffix=suffix) - - self.assertEqual(inst.__repr__(), expected) - class TestBaseDeviceMQTTOptions(unittest.TestCase): def setUp(self) -> None: @@ -244,7 +221,7 @@ def test_mqtt_prefix_set(self, topic): inst = BaseDevice("site", self.db, self.conn, mqtt_prefix=topic) self.assertEqual(inst.mqtt_prefix, topic) - self.assertEqual(inst.mqtt_topic, f"{topic}/base-device/site") + self.assertEqual(inst.mqtt_topic, f"{topic}/site") @parameterized.expand(["this/topic", "1/1/1", "TOPICO!"]) @config_exists @@ -254,7 +231,7 @@ def test_mqtt_suffix_set(self, topic): inst = BaseDevice("site", self.db, self.conn, mqtt_suffix=topic) self.assertEqual(inst.mqtt_suffix, topic) - self.assertEqual(inst.mqtt_topic, f"base-device/site/{topic}") + self.assertEqual(inst.mqtt_topic, f"site/{topic}") @parameterized.expand([["this/prefix", "this/suffix"], 
["1/1/1", "2/2/2"], ["TOPICO!", "FOUR"]]) @config_exists @@ -265,7 +242,7 @@ def test_mqtt_prefix_and_suffix_set(self, prefix, suffix): self.assertEqual(inst.mqtt_suffix, suffix) self.assertEqual(inst.mqtt_prefix, prefix) - self.assertEqual(inst.mqtt_topic, f"{prefix}/base-device/site/{suffix}") + self.assertEqual(inst.mqtt_topic, f"{prefix}/site/{suffix}") @config_exists def test_default_mqtt_topic_set(self): @@ -273,7 +250,7 @@ def test_default_mqtt_topic_set(self): inst = BaseDevice("site-12", self.db, self.conn) - self.assertEqual(inst.mqtt_topic, "base-device/site-12") + self.assertEqual(inst.mqtt_topic, "site-12") @parameterized.expand([ [None, None, None, ""], @@ -289,7 +266,9 @@ def test__repr__mqtt_opts_mqtt_connection(self, topic, prefix, suffix,expected_a self.assertEqual(inst.__repr__(), expected) + class TestBaseDeviceOracleUsed(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): cred_path = str(CONFIG_PATH) creds = config.Config(cred_path)["oracle"] @@ -300,53 +279,92 @@ async def asyncSetUp(self): creds["user"], password=creds["password"], ) - self.query = CosmosQuery.LEVEL_1_SOILMET_30MIN + self.table = CosmosTable.LEVEL_1_SOILMET_30MIN async def asyncTearDown(self) -> None: await self.oracle.connection.close() - @parameterized.expand([-1, -423.78, CosmosSiteQuery.LEVEL_1_NMDB_1HOUR, "Four", MockDB(), {"a": 1}]) + @parameterized.expand([-1, -423.78, CosmosQuery.ORACLE_LATEST_DATA, "Four", MockDB(), {"a": 1}]) @config_exists - def test_query_value_check(self, query): + @pytest.mark.oracle + def test_table_value_check(self, table): with self.assertRaises(TypeError): BaseDevice( - "test_id", self.oracle, MockMessageConnection(), query=query + "test_id", self.oracle, MockMessageConnection(), table=table ) @pytest.mark.oracle @config_exists - async def test_error_if_query_not_given(self): + async def test_error_if_table_not_given(self): with self.assertRaises(ValueError): BaseDevice("site", self.oracle, MockMessageConnection()) - inst = 
BaseDevice("site", self.oracle, MockMessageConnection(), query=self.query) + inst = BaseDevice("site", self.oracle, MockMessageConnection(), table=self.table) - self.assertEqual(inst.query, self.query) + self.assertEqual(inst.table, self.table) @pytest.mark.oracle @config_exists async def test__repr__oracle_data(self): - inst_oracle = BaseDevice("site", self.oracle, MockMessageConnection(), query=self.query) - exp_oracle = f'BaseDevice("site", Oracle("{self.creds['dsn']}"), MockMessageConnection(), query=CosmosQuery.{self.query.name})' + inst_oracle = BaseDevice("site", self.oracle, MockMessageConnection(), table=self.table) + exp_oracle = f'BaseDevice("site", Oracle("{self.creds['dsn']}"), MockMessageConnection(), table=CosmosTable.{self.table.name})' self.assertEqual(inst_oracle.__repr__(), exp_oracle) - inst_not_oracle = BaseDevice("site", MockDB(), MockMessageConnection(), query=self.query) + inst_not_oracle = BaseDevice("site", MockDB(), MockMessageConnection(), table=self.table) exp_not_oracle = 'BaseDevice("site", MockDB(), MockMessageConnection())' self.assertEqual(inst_not_oracle.__repr__(), exp_not_oracle) with self.assertRaises(AttributeError): - inst_not_oracle.query + inst_not_oracle.table @pytest.mark.oracle @config_exists async def test__get_payload(self): """Tests that Cosmos payload retrieved.""" - inst = BaseDevice("MORLY", self.oracle, MockMessageConnection(), query=self.query) + inst = BaseDevice("MORLY", self.oracle, MockMessageConnection(), table=self.table) + print(inst) + payload = await inst._get_payload() + + self.assertIsInstance(payload, dict) + +class TestBaseDevicesSQLite3Used(unittest.IsolatedAsyncioTestCase): + + def setUp(self): + db_path = Path(Path(__file__).parents[1], "iotswarm","__assets__", "data", "cosmos.db") + if db_path.exists(): + self.db = LoopingSQLite3(db_path) + self.table = CosmosTable.LEVEL_1_SOILMET_30MIN + + @parameterized.expand([-1, -423.78, CosmosQuery.ORACLE_LATEST_DATA, "Four", MockDB(), {"a": 1}]) + 
@sqlite_db_exist + def test_table_value_check(self, table): + with self.assertRaises(TypeError): + BaseDevice( + "test_id", self.db, MockMessageConnection(), table=table + ) + + @sqlite_db_exist + def test_error_if_table_not_given(self): + + with self.assertRaises(ValueError): + BaseDevice("site", self.db, MockMessageConnection()) + + + inst = BaseDevice("site", self.db, MockMessageConnection(), table=self.table) + + self.assertEqual(inst.table, self.table) + + @sqlite_db_exist + async def test__get_payload(self): + """Tests that Cosmos payload retrieved.""" + + inst = BaseDevice("MORLY", self.db, MockMessageConnection(), table=self.table) + payload = await inst._get_payload() self.assertIsInstance(payload, dict) @@ -448,6 +466,7 @@ async def test__get_payload(self): self.assertIsInstance(payload, list) self.assertEqual(len(payload),0) + class TestCr1000xDevice(unittest.TestCase): """Test suite for the CR1000X Device.""" diff --git a/src/tests/test_messaging.py b/src/tests/test_messaging.py index 4fe9271..49e86fe 100644 --- a/src/tests/test_messaging.py +++ b/src/tests/test_messaging.py @@ -11,13 +11,20 @@ from parameterized import parameterized import logging -CONFIG_PATH = Path( - Path(__file__).parents[1], "iotswarm", "__assets__", "config.cfg" -) + +ASSETS_PATH = Path(Path(__file__).parents[1], "iotswarm", "__assets__") +CONFIG_PATH = Path(ASSETS_PATH, "config.cfg") + config_exists = pytest.mark.skipif( not CONFIG_PATH.exists(), reason="Config file `config.cfg` not found in root directory.", ) +certs_exist = pytest.mark.skipif( + not Path(ASSETS_PATH, ".certs", "cosmos_soilmet-certificate.pem.crt").exists() + or not Path(ASSETS_PATH, ".certs", "cosmos_soilmet-private.pem.key").exists() + or not Path(ASSETS_PATH, ".certs", "AmazonRootCA1.pem").exists(), + reason="IotCore certificates not present.", +) class TestBaseClass(unittest.TestCase): @@ -43,24 +50,20 @@ def test_instantiation(self): self.assertIsInstance(mock, MessagingBaseClass) - def 
test_logger_used(self): + def test_no_logger_used(self): - mock = MockMessageConnection() + with self.assertNoLogs(): + mock = MockMessageConnection() + mock.send_message("") - with self.assertLogs() as cm: - mock.send_message() - self.assertEqual( - cm.output, - [ - "INFO:iotswarm.messaging.core.MockMessageConnection:Message was sent." - ], - ) - - with self.assertLogs() as cm: - mock.send_message(use_logger=logging.getLogger("mine")) + def test_logger_used(self): + logger = logging.getLogger("testlogger") + with self.assertLogs(logger=logger, level=logging.DEBUG) as cm: + mock = MockMessageConnection(inherit_logger=logger) + mock.send_message("") self.assertEqual( cm.output, - ["INFO:mine:Message was sent."], + ["DEBUG:testlogger.MockMessageConnection:Message was sent."], ) @@ -73,6 +76,7 @@ def setUp(self) -> None: self.config = config["iot_core"] @config_exists + @certs_exist def test_instantiation(self): instance = IotCoreMQTTConnection(**self.config, client_id="test_id") @@ -82,6 +86,7 @@ def test_instantiation(self): self.assertIsInstance(instance.connection, awscrt.mqtt.Connection) @config_exists + @certs_exist def test_non_string_arguments(self): with self.assertRaises(TypeError): @@ -130,6 +135,7 @@ def test_non_string_arguments(self): ) @config_exists + @certs_exist def test_port(self): # Expect one of defaults if no port given @@ -151,6 +157,7 @@ def test_port(self): @parameterized.expand([-4, {"f": 4}, "FOUR"]) @config_exists + @certs_exist def test_bad_port_type(self, port): with self.assertRaises((TypeError, ValueError)): @@ -164,6 +171,7 @@ def test_bad_port_type(self, port): ) @config_exists + @certs_exist def test_clean_session_set(self): expected = False @@ -180,6 +188,7 @@ def test_clean_session_set(self): @parameterized.expand([0, -1, "true", None]) @config_exists + @certs_exist def test_bad_clean_session_type(self, clean_session): with self.assertRaises(TypeError): @@ -188,6 +197,7 @@ def test_bad_clean_session_type(self, clean_session): ) 
@config_exists + @certs_exist def test_keep_alive_secs_set(self): # Test defualt is not none instance = IotCoreMQTTConnection(**self.config, client_id="test_id") @@ -200,8 +210,9 @@ def test_keep_alive_secs_set(self): ) self.assertEqual(instance.connection.keep_alive_secs, expected) - @parameterized.expand(["FOURTY", "True", None]) + @parameterized.expand(["FOURTY", "True"]) @config_exists + @certs_exist def test_bad_keep_alive_secs_type(self, secs): with self.assertRaises(TypeError): IotCoreMQTTConnection( @@ -209,7 +220,8 @@ def test_bad_keep_alive_secs_type(self, secs): ) @config_exists - def test_logger_set(self): + @certs_exist + def test_no_logger_set(self): inst = IotCoreMQTTConnection(**self.config, client_id="test_id") expected = 'No message to send for topic: "mytopic".' @@ -223,12 +235,21 @@ def test_logger_set(self): ], ) - with self.assertLogs() as cm: - inst.send_message(None, "mytopic", use_logger=logging.getLogger("mine")) + @config_exists + @certs_exist + def test_logger_set(self): + logger = logging.getLogger("mine") + inst = IotCoreMQTTConnection( + **self.config, client_id="test_id", inherit_logger=logger + ) + + expected = 'No message to send for topic: "mytopic".' + with self.assertLogs(logger=logger, level=logging.INFO) as cm: + inst.send_message(None, "mytopic") self.assertEqual( cm.output, - [f"ERROR:mine:{expected}"], + [f"ERROR:mine.IotCoreMQTTConnection.client-test_id:{expected}"], )