From ebb518516ad36b3d4cde7157e78a167cbd0f8dd8 Mon Sep 17 00:00:00 2001 From: Lewis Chambers <47946387+lewis-chambers@users.noreply.github.com> Date: Thu, 20 Jun 2024 09:01:18 +0100 Subject: [PATCH] Data source expansion (#6) This pull request covers the inclusion of running the swarm from local data files and addresses the challenges in running the code from an EC2 instance that cannot reach the COSMOS-UK database. * NEW: A class for looping through as .csv file with pandas * NEW: A class for looping through a SQLite database (assumes that it's been sorted first) * NEW: Updated CLI to allow new data sources to be selected. * NEW: Updated documentation --- .github/workflows/test.yml | 3 + .gitignore | 6 +- docs/index.rst | 30 +- docs/source/cli.rst | 6 + docs/source/iotswarm.scripts.rst | 2 +- pyproject.toml | 22 +- .../__assets__/data/build_database.py | 41 ++ src/iotswarm/__init__.py | 4 +- src/iotswarm/db.py | 291 +++++++++++-- src/iotswarm/devices.py | 67 +-- src/iotswarm/messaging/aws.py | 60 +-- src/iotswarm/messaging/core.py | 33 +- src/iotswarm/queries.py | 104 ++--- src/iotswarm/scripts/cli.py | 387 +++++++++++++----- src/iotswarm/scripts/common.py | 113 +++++ src/iotswarm/utils.py | 54 +++ src/tests/data/ALIC1_4_ROWS.csv | 5 + src/tests/data/build_database.py | 22 + src/tests/test_cli.py | 63 +++ src/tests/test_db.py | 368 +++++++++++++++-- src/tests/test_devices.py | 121 +++--- src/tests/test_messaging.py | 65 ++- 22 files changed, 1456 insertions(+), 411 deletions(-) create mode 100644 docs/source/cli.rst create mode 100644 src/iotswarm/__assets__/data/build_database.py create mode 100644 src/iotswarm/scripts/common.py create mode 100644 src/tests/data/ALIC1_4_ROWS.csv create mode 100644 src/tests/data/build_database.py create mode 100644 src/tests/test_cli.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f2d31a0..441f649 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -24,6 +24,9 @@ jobs: python 
-m pip install --upgrade pip pip install flake8 pip install .[test] + - name: Build test DB + run: | + python src/tests/data/build_database.py - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names diff --git a/.gitignore b/.gitignore index c88052f..091353a 100644 --- a/.gitignore +++ b/.gitignore @@ -8,7 +8,9 @@ dist/* config.cfg aws-auth *.log -*.certs +*.certs* !.github/* docs/_* -conf.sh \ No newline at end of file +*.sh +**/*.csv +**/*.db \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 1cdd28a..bc97815 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -14,6 +14,7 @@ :caption: Contents: self + source/cli source/modules genindex modindex @@ -57,7 +58,7 @@ intialise a swarm with data sent every 30 minutes like so: .. code-block:: shell iot-swarm cosmos --dsn="xxxxxx" --user="xxxxx" --password="*****" \ - mqtt "aws" LEVEL_1_SOILMET_30MIN "client_id" \ + mqtt LEVEL_1_SOILMET_30MIN "client_id" \ --endpoint="xxxxxxx" \ --cert-path="C:\path\..." \ --key-path="C:\path\..." \ @@ -84,7 +85,7 @@ Then the CLI can be called more cleanly: .. code-block:: shell - iot-swarm cosmos mqtt "aws" LEVEL_1_SOILMET_30MIN "client_id" --sleep-time=1800 --swarm-name="my-swarm" + iot-swarm cosmos mqtt LEVEL_1_SOILMET_30MIN "client_id" --sleep-time=1800 --swarm-name="my-swarm" ------------------------ Using the Python Modules @@ -163,6 +164,31 @@ The system expects config credentials for the MQTT endpoint and the COSMOS Oracl .. include:: example-config.cfg +------------------------------------------- +Looping Through a local database / csv file +------------------------------------------- + +This package now supports using a CSV file or local SQLite database as the data source. +There are 2 modules to support it: `db.LoopingCsvDB` and `db.LoopingSQLite3`. Each of them +initializes from a local file and loops through the data for a given site id. 
The database +objects store an in memory cache of each site ID and it's current index in the database. +Once the end is reached, it loops back to the start for that site. + +For use in FDRI, 6 months of data was downloaded to CSV from the COSMOS-UK database, but +the files are too large to be included in this repo, so they are stored in the `ukceh-fdri` +`S3` bucket on AWS. There are scripts for regenerating the `.db` file in this repo: + +* `./src/iotswarm/__assets__/data/build_database.py` +* `./src/tests/data/build_database.py` + +To use the 'official' data, it should be downloaded from the `S3` bucket and placed in +`./src/iotswarm/__assets__/data` before running the script. This will build a `.db` file +sorted in datetime order that the `LoopingSQLite3` class can operate with. + +.. warning:: + The looping database classes assume that their data files are sorted, and make no + attempt to sort it themselves. + Indices and tables ================== diff --git a/docs/source/cli.rst b/docs/source/cli.rst new file mode 100644 index 0000000..0341699 --- /dev/null +++ b/docs/source/cli.rst @@ -0,0 +1,6 @@ +CLI +=== + +.. 
click:: iotswarm.scripts.cli:main + :prog: iot-swarm + :nested: full \ No newline at end of file diff --git a/docs/source/iotswarm.scripts.rst b/docs/source/iotswarm.scripts.rst index 42b1dd6..81dd67a 100644 --- a/docs/source/iotswarm.scripts.rst +++ b/docs/source/iotswarm.scripts.rst @@ -1,5 +1,5 @@ iotswarm.scripts package -================================== +======================== Submodules ---------- diff --git a/pyproject.toml b/pyproject.toml index e23f65b..a69e0ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [build-system] requires = ["setuptools >= 61.0", "autosemver"] -build-backend = "setuptools.build_meta" +# build-backend = "setuptools.build_meta" [project] dependencies = [ @@ -8,12 +8,14 @@ dependencies = [ "boto3", "autosemver", "config", + "click", "docutils<0.17", "awscli", "awscrt", + "awsiotsdk", "oracledb", "backoff", - "click", + "pandas", ] name = "iot-swarm" dynamic = ["version"] @@ -26,14 +28,18 @@ docs = ["sphinx", "sphinx-copybutton", "sphinx-rtd-theme", "sphinx-click"] [project.scripts] iot-swarm = "iotswarm.scripts.cli:main" + [tool.setuptools.dynamic] version = { attr = "iotswarm.__version__" } + [tool.setuptools.packages.find] where = ["src"] +include = ["iotswarm*"] +exclude = ["tests*"] [tool.setuptools.package-data] -"*" = ["*.*"] +"iotswarm.__assets__" = ["loggers.ini"] [tool.pytest.ini_options] @@ -49,4 +55,12 @@ markers = [ ] [tool.coverage.run] -omit = ["*example.py", "*__init__.py", "queries.py", "loggers.py", "cli.py"] +omit = [ + "*example.py", + "*__init__.py", + "queries.py", + "loggers.py", + "**/scripts/*.py", + "**/build_database.py", + "utils.py", +] diff --git a/src/iotswarm/__assets__/data/build_database.py b/src/iotswarm/__assets__/data/build_database.py new file mode 100644 index 0000000..96c5838 --- /dev/null +++ b/src/iotswarm/__assets__/data/build_database.py @@ -0,0 +1,41 @@ +"""This script is responsible for building an SQL data file from the CSV files used +by the cosmos network. 
+ +The files are stored in AWS S3 and should be downloaded into this directory before continuing. +They are: + * LEVEL_1_NMDB_1HOUR_DATA_TABLE.csv + * LEVEL_1_SOILMET_30MIN_DATA_TABLE.csv + * LEVEL_1_PRECIP_1MIN_DATA_TABLE.csv + * LEVEL_1_PRECIP_RAINE_1MIN_DATA_TABLE.csv + +Once installed, run this script to generate the .db file. +""" + +from iotswarm.utils import build_database_from_csv +from pathlib import Path +from glob import glob +from iotswarm.queries import CosmosTable + + +def main( + csv_dir: str | Path = Path(__file__).parent, + database_output: str | Path = Path(Path(__file__).parent, "cosmos.db"), +): + """Reads exported cosmos DB files from CSV format. Assumes that the files + look like: LEVEL_1_SOILMET_30MIN_DATA_TABLE.csv + + Args: + csv_dir: Directory where the csv files are stored. + database_output: Output destination of the csv_data. + """ + csv_files = glob("*.csv", root_dir=csv_dir) + + tables = [CosmosTable[x.removesuffix("_DATA_TABLE.csv")] for x in csv_files] + + for table, file in zip(tables, csv_files): + file = Path(csv_dir, file) + build_database_from_csv(file, database_output, table.value, sort_by="DATE_TIME") + + +if __name__ == "__main__": + main() diff --git a/src/iotswarm/__init__.py b/src/iotswarm/__init__.py index fbfe123..0131fea 100644 --- a/src/iotswarm/__init__.py +++ b/src/iotswarm/__init__.py @@ -1,8 +1,6 @@ import autosemver try: - __version__ = autosemver.packaging.get_current_version( - project_name="iot-device-simulator" - ) + __version__ = autosemver.packaging.get_current_version(project_name="iot-swarm") except: __version__ = "unknown version" diff --git a/src/iotswarm/db.py b/src/iotswarm/db.py index 1386ddd..e7297bc 100644 --- a/src/iotswarm/db.py +++ b/src/iotswarm/db.py @@ -4,7 +4,14 @@ import getpass import logging import abc -from iotswarm.queries import CosmosQuery, CosmosSiteQuery +from iotswarm.queries import ( + CosmosQuery, + CosmosTable, +) +import pandas as pd +from pathlib import Path +from math import 
nan +import sqlite3 logger = logging.getLogger(__name__) @@ -44,12 +51,65 @@ def query_latest_from_site(): return [] -class Oracle(BaseDatabase): +class CosmosDB(BaseDatabase): + """Base class for databases using COSMOS_UK data.""" + + connection: object + """Connection to database.""" + + site_data_query: CosmosQuery + """SQL query for retrieving a single record.""" + + site_id_query: CosmosQuery + """SQL query for retrieving list of site IDs""" + + @staticmethod + def _validate_table(table: CosmosTable) -> None: + """Validates that the query is legal""" + + if not isinstance(table, CosmosTable): + raise TypeError( + f"`table` must be a `{CosmosTable.__class__}` Enum, not a `{type(table)}`" + ) + + @staticmethod + def _fill_query(query: str, table: CosmosTable) -> str: + """Fills a query string with a CosmosTable enum.""" + + CosmosDB._validate_table(table) + + return query.format(table=table.value) + + @staticmethod + def _validate_max_sites(max_sites: int) -> int: + """Validates that a valid maximum sites is given: + Args: + max_sites: The maximum number of sites required. + + Returns: + An integer 0 or more. + """ + + if max_sites is not None: + max_sites = int(max_sites) + if max_sites < 0: + raise ValueError( + f"`max_sites` must be 1 or more, or 0 for no maximum. 
Received: {max_sites}" + ) + + return max_sites + + +class Oracle(CosmosDB): """Class for handling oracledb logic and retrieving values from DB.""" connection: oracledb.Connection """Connection to oracle database.""" + site_data_query = CosmosQuery.ORACLE_LATEST_DATA + + site_id_query = CosmosQuery.ORACLE_SITE_IDS + def __repr__(self): parent_repr = ( super().__repr__().lstrip(f"{self.__class__.__name__}(").rstrip(")") @@ -64,7 +124,14 @@ def __repr__(self): ) @classmethod - async def create(cls, dsn: str, user: str, password: str = None, **kwargs): + async def create( + cls, + dsn: str, + user: str, + password: str = None, + inherit_logger: logging.Logger | None = None, + **kwargs, + ): """Factory method for initialising the class. Initialization is done through the `create() method`: `Oracle.create(...)`. @@ -72,6 +139,7 @@ async def create(cls, dsn: str, user: str, password: str = None, **kwargs): dsn: Oracle data source name. user: Username used for query. pw: User password for auth. + inherit_logger: Uses the given logger if provided """ if not password: @@ -83,32 +151,31 @@ async def create(cls, dsn: str, user: str, password: str = None, **kwargs): dsn=dsn, user=user, password=password ) + if inherit_logger is not None: + self._instance_logger = inherit_logger.getChild(self.__class__.__name__) + else: + self._instance_logger = logger.getChild(self.__class__.__name__) + self._instance_logger.info("Initialized Oracle connection.") return self - async def query_latest_from_site(self, site_id: str, query: CosmosQuery) -> dict: + async def query_latest_from_site(self, site_id: str, table: CosmosTable) -> dict: """Requests the latest data from a table for a specific site. Args: site_id: ID of the site to retrieve records from. - query: Query to parse and submit. + table: A valid table from the database Returns: dict | None: A dict containing the database columns as keys, and the values as values. Returns `None` if no data retrieved. 
""" - if not isinstance(query, CosmosQuery): - raise TypeError( - f"`query` must be a `CosmosQuery` Enum, not a `{type(query)}`" - ) + query = self._fill_query(self.site_data_query, table) - async with self.connection.cursor() as cursor: - await cursor.execute( - query.value, - mysite=site_id, - ) + with self.connection.cursor() as cursor: + await cursor.execute(query, site_id=site_id) columns = [i[0] for i in cursor.description] data = await cursor.fetchone() @@ -119,23 +186,92 @@ async def query_latest_from_site(self, site_id: str, query: CosmosQuery) -> dict return dict(zip(columns, data)) async def query_site_ids( - self, query: CosmosSiteQuery, max_sites: int | None = None + self, table: CosmosTable, max_sites: int | None = None ) -> list: """query_site_ids returns a list of site IDs from COSMOS database Args: - query: The query to run. + table: A valid table from the database max_sites: Maximum number of sites to retreive Returns: List[str]: A list of site ID strings. """ - if not isinstance(query, CosmosSiteQuery): - raise TypeError( - f"`query` must be a `CosmosSiteQuery` Enum, not a `{type(query)}`" - ) + max_sites = self._validate_max_sites(max_sites) + + query = self._fill_query(self.site_id_query, table) + + async with self.connection.cursor() as cursor: + await cursor.execute(query) + + data = await cursor.fetchall() + if max_sites == 0: + data = [x[0] for x in data] + else: + data = [x[0] for x in data[:max_sites]] + + if not data: + data = [] + + return data + + +class LoopingCsvDB(BaseDatabase): + """A database that reads from csv files and loops through items + for a given table or site. 
The site and index is remembered via a + dictionary key and incremented each time data is requested.""" + + connection: pd.DataFrame + """Connection to the pd object holding data.""" + + cache: dict + """Cache object containing current index of each site queried.""" + + @staticmethod + def _get_connection(*args) -> pd.DataFrame: + """Gets the database connection.""" + return pd.read_csv(*args) + + def __init__(self, csv_file: str | Path): + """Initialises the database object. + + Args: + csv_file: A pathlike object pointing to the datafile. + """ + + BaseDatabase.__init__(self) + self.connection = self._get_connection(csv_file) + self.cache = dict() + + def query_latest_from_site(self, site_id: str) -> dict: + """Queries the datbase for a `SITE_ID` incrementing by 1 each time called + for a specific site. If the end is reached, it loops back to the start. + + Args: + site_id: ID of the site to query for. + Returns: + A dict of the data row. + """ + + data = self.connection.query("SITE_ID == @site_id").replace({nan: None}) + + if site_id not in self.cache or self.cache[site_id] >= len(data): + self.cache[site_id] = 1 + else: + self.cache[site_id] += 1 + + return data.iloc[self.cache[site_id] - 1].to_dict() + + def query_site_ids(self, max_sites: int | None = None) -> list: + """query_site_ids returns a list of site IDs from the database + + Args: + max_sites: Maximum number of sites to retreive + Returns: + List[str]: A list of site ID strings. + """ if max_sites is not None: max_sites = int(max_sites) if max_sites < 0: @@ -143,10 +279,115 @@ async def query_site_ids( f"`max_sites` must be 1 or more, or 0 for no maximum. 
Received: {max_sites}" ) - async with self.connection.cursor() as cursor: - await cursor.execute(query.value) + sites = self.connection["SITE_ID"].drop_duplicates().to_list() - data = await cursor.fetchall() + if max_sites is not None and max_sites > 0: + sites = sites[:max_sites] + + return sites + + +class LoopingSQLite3(CosmosDB, LoopingCsvDB): + """A database that reads from .db files using sqlite3 and loops through + entries in sequential order. There is a script that generates the .db file + in the `__assets__/data` directory relative to this file. .csv datasets should + be downloaded from the accompanying S3 bucket before running.""" + + connection: sqlite3.Connection + """Connection to the database.""" + + site_data_query = CosmosQuery.SQLITE_LOOPED_DATA + + site_id_query = CosmosQuery.SQLITE_SITE_IDS + + @staticmethod + def _get_connection(*args) -> sqlite3.Connection: + """Gets a database connection.""" + + return sqlite3.connect(*args) + + def __init__(self, db_file: str | Path): + """Initialises the database object. + + Args: + csv_file: A pathlike object pointing to the datafile. + """ + LoopingCsvDB.__init__(self, db_file) + + self.cursor = self.connection.cursor() + + def query_latest_from_site(self, site_id: str, table: CosmosTable) -> dict: + """Queries the datbase for a `SITE_ID` incrementing by 1 each time called + for a specific site. If the end is reached, it loops back to the start. + + Args: + site_id: ID of the site to query for. + table: A valid table from the database + Returns: + A dict of the data row. 
+ """ + query = self._fill_query(self.site_data_query, table) + + if site_id not in self.cache: + self.cache[site_id] = 0 + else: + self.cache[site_id] += 1 + + data = self._query_latest_from_site( + query, {"site_id": site_id, "offset": self.cache[site_id]} + ) + + if data is None: + self.cache[site_id] = 0 + + data = self._query_latest_from_site( + query, {"site_id": site_id, "offset": self.cache[site_id]} + ) + + return data + + def _query_latest_from_site(self, query, arg_dict: dict) -> dict: + """Requests the latest data from a table for a specific site. + + Args: + table: A valid table from the database + arg_dict: Dictionary of query arguments. + + Returns: + dict | None: A dict containing the database columns as keys, and the values as values. + Returns `None` if no data retrieved. + """ + + self.cursor.execute(query, arg_dict) + + columns = [i[0] for i in self.cursor.description] + data = self.cursor.fetchone() + + if not data: + return None + + return dict(zip(columns, data)) + + def query_site_ids(self, table: CosmosTable, max_sites: int | None = None) -> list: + """query_site_ids returns a list of site IDs from COSMOS database + + Args: + table: A valid table from the database + max_sites: Maximum number of sites to retreive + + Returns: + List[str]: A list of site ID strings. 
+ """ + + query = self._fill_query(self.site_id_query, table) + + max_sites = self._validate_max_sites(max_sites) + + try: + cursor = self.connection.cursor() + cursor.execute(query) + + data = cursor.fetchall() if max_sites == 0: data = [x[0] for x in data] else: @@ -154,5 +395,7 @@ async def query_site_ids( if not data: data = [] + finally: + cursor.close() - return data + return data diff --git a/src/iotswarm/devices.py b/src/iotswarm/devices.py index 8c9c6ce..2b2cd1d 100644 --- a/src/iotswarm/devices.py +++ b/src/iotswarm/devices.py @@ -3,14 +3,14 @@ import asyncio import logging from iotswarm import __version__ as package_version -from iotswarm.queries import CosmosQuery -from iotswarm.db import BaseDatabase, Oracle -from iotswarm.messaging.core import MessagingBaseClass +from iotswarm.queries import CosmosTable +from iotswarm.db import BaseDatabase, CosmosDB, Oracle, LoopingCsvDB, LoopingSQLite3 +from iotswarm.messaging.core import MessagingBaseClass, MockMessageConnection from iotswarm.messaging.aws import IotCoreMQTTConnection -import random +from typing import List from datetime import datetime +import random import enum -from typing import List logger = logging.getLogger(__name__) @@ -45,8 +45,8 @@ class BaseDevice: connection: MessagingBaseClass """Connection to the data receiver.""" - query: CosmosQuery - """SQL query sent to database if Oracle type selected as `data_source`.""" + table: CosmosTable + """SQL table used in queries if Oracle or LoopingSQLite3 selected as `data_source`.""" mqtt_base_topic: str """Base topic for mqtt topic.""" @@ -85,7 +85,7 @@ def __init__( max_cycles: int | None = None, delay_start: bool | None = None, inherit_logger: logging.Logger | None = None, - query: CosmosQuery | None = None, + table: CosmosTable | None = None, mqtt_topic: str | None = None, mqtt_prefix: str | None = None, mqtt_suffix: str | None = None, @@ -100,7 +100,7 @@ def __init__( max_cycles: Maximum number of cycles before shutdown. 
delay_start: Adds a random delay to first invocation from 0 - `sleep_time`. inherit_logger: Override for the module logger. - query: Sets the query used in COSMOS database. Ignored if `data_source` is not a Cosmos object. + table: A valid table from the database mqtt_prefix: Prefixes the MQTT topic if MQTT messaging used. mqtt_suffix: Suffixes the MQTT topic if MQTT messaging used. """ @@ -111,10 +111,18 @@ def __init__( raise TypeError( f"`data_source` must be a `BaseDatabase`. Received: {data_source}." ) - if isinstance(data_source, Oracle) and query is None: - raise ValueError( - f"`query` must be provided if `data_source` is type `Oracle`." - ) + if isinstance(data_source, (CosmosDB)): + + if table is None: + raise ValueError( + f"`table` must be provided if `data_source` is type `OracleDB`." + ) + elif not isinstance(table, CosmosTable): + raise TypeError( + f'table must be a "{CosmosTable.__class__}", not "{type(table)}"' + ) + + self.table = table self.data_source = data_source if not isinstance(connection, MessagingBaseClass): @@ -144,18 +152,11 @@ def __init__( ) self.delay_start = delay_start - if query is not None and isinstance(self.data_source, Oracle): - if not isinstance(query, CosmosQuery): - raise TypeError( - f"`query` must be a `CosmosQuery`. Received: {type(query)}." 
- ) - self.query = query - - if isinstance(connection, IotCoreMQTTConnection): + if isinstance(connection, (IotCoreMQTTConnection, MockMessageConnection)): if mqtt_topic is not None: self.mqtt_topic = str(mqtt_topic) else: - self.mqtt_topic = f"{self.device_type}/{self.device_id}" + self.mqtt_topic = f"{self.device_id}" if mqtt_prefix is not None: self.mqtt_prefix = str(mqtt_prefix) @@ -190,16 +191,16 @@ def __repr__(self): if self.delay_start != self.__class__.delay_start else "" ) - query_arg = ( - f", query={self.query.__class__.__name__}.{self.query.name}" - if isinstance(self.data_source, Oracle) + table_arg = ( + f", table={self.table.__class__.__name__}.{self.table.name}" + if isinstance(self.data_source, CosmosDB) else "" ) mqtt_topic_arg = ( f', mqtt_topic="{self.mqtt_base_topic}"' if hasattr(self, "mqtt_base_topic") - and self.mqtt_base_topic != f"{self.device_type}/{self.device_id}" + and self.mqtt_base_topic != self.device_id else "" ) @@ -222,7 +223,7 @@ def __repr__(self): f"{sleep_time_arg}" f"{max_cycles_arg}" f"{delay_start_arg}" - f"{query_arg}" + f"{table_arg}" f"{mqtt_topic_arg}" f"{mqtt_prefix_arg}" f"{mqtt_suffix_arg}" @@ -243,7 +244,7 @@ def _send_payload(self, payload: dict): async def run(self): """The main invocation of the method. Expects a Oracle object to do work on - and a query to retrieve. Runs asynchronously until `max_cycles` is reached. + and a table to retrieve. Runs asynchronously until `max_cycles` is reached. 
Args: message_connection: The message object to send data through @@ -260,6 +261,9 @@ async def run(self): if payload: self._instance_logger.debug("Requesting payload submission.") self._send_payload(payload) + self._instance_logger.info( + f'Message sent{f" to topic: {self.mqtt_topic}" if self.mqtt_topic else ""}' + ) else: self._instance_logger.warning(f"No data found.") @@ -273,9 +277,12 @@ async def _get_payload(self): """Method for grabbing the payload to send""" if isinstance(self.data_source, Oracle): return await self.data_source.query_latest_from_site( - self.device_id, self.query + self.device_id, self.table ) - + elif isinstance(self.data_source, LoopingSQLite3): + return self.data_source.query_latest_from_site(self.device_id, self.table) + elif isinstance(self.data_source, LoopingCsvDB): + return self.data_source.query_latest_from_site(self.device_id) elif isinstance(self.data_source, BaseDatabase): return self.data_source.query_latest_from_site() diff --git a/src/iotswarm/messaging/aws.py b/src/iotswarm/messaging/aws.py index 8205964..be1a708 100644 --- a/src/iotswarm/messaging/aws.py +++ b/src/iotswarm/messaging/aws.py @@ -2,6 +2,7 @@ import awscrt from awscrt import mqtt +from awsiot import mqtt_connection_builder import awscrt.io import json from awscrt.exceptions import AwsCrtError @@ -34,6 +35,7 @@ def __init__( port: int | None = None, clean_session: bool = False, keep_alive_secs: int = 1200, + inherit_logger: logging.Logger | None = None, **kwargs, ) -> None: """Initializes the class. @@ -47,7 +49,7 @@ def __init__( port: Port used by endpoint. Guesses correct port if not given. clean_session: Builds a clean MQTT session if true. Defaults to False. keep_alive_secs: Time to keep connection alive. Defaults to 1200. - topic_prefix: A topic prefixed to MQTT topic, useful for attaching a "Basic Ingest" rule. Defaults to None. + inherit_logger: Override for the module logger. 
""" if not isinstance(endpoint, str): @@ -88,41 +90,31 @@ def __init__( if port < 0: raise ValueError(f"`port` cannot be less than 0. Received: {port}.") - socket_options = awscrt.io.SocketOptions() - socket_options.connect_timeout_ms = 5000 - socket_options.keep_alive = False - socket_options.keep_alive_timeout_secs = 0 - socket_options.keep_alive_interval_secs = 0 - socket_options.keep_alive_max_probes = 0 - - client_bootstrap = awscrt.io.ClientBootstrap.get_or_create_static_default() - - tls_ctx = awscrt.io.ClientTlsContext(tls_ctx_options) - mqtt_client = awscrt.mqtt.Client(client_bootstrap, tls_ctx) - - self.connection = awscrt.mqtt.Connection( - client=mqtt_client, + self.connection = mqtt_connection_builder.mtls_from_path( + endpoint=endpoint, + port=port, + cert_filepath=cert_path, + pri_key_filepath=key_path, + ca_filepath=ca_cert_path, on_connection_interrupted=self._on_connection_interrupted, on_connection_resumed=self._on_connection_resumed, client_id=client_id, - host_name=endpoint, - port=port, + proxy_options=None, clean_session=clean_session, - reconnect_min_timeout_secs=5, - reconnect_max_timeout_secs=60, keep_alive_secs=keep_alive_secs, - ping_timeout_ms=3000, - protocol_operation_timeout_ms=0, - socket_options=socket_options, - use_websockets=False, on_connection_success=self._on_connection_success, on_connection_failure=self._on_connection_failure, on_connection_closed=self._on_connection_closed, ) - self._instance_logger = logger.getChild( - f"{self.__class__.__name__}.client-{client_id}" - ) + if inherit_logger is not None: + self._instance_logger = inherit_logger.getChild( + f"{self.__class__.__name__}.client-{client_id}" + ) + else: + self._instance_logger = logger.getChild( + f"{self.__class__.__name__}.client-{client_id}" + ) def _on_connection_interrupted( self, connection, error, **kwargs @@ -182,23 +174,15 @@ def _disconnect(self): # pragma: no cover disconnect_future = self.connection.disconnect() disconnect_future.result() - def 
send_message( - self, message: dict, topic: str, use_logger: logging.Logger | None = None - ) -> None: + def send_message(self, message: dict, topic: str) -> None: """Sends a message to the endpoint. Args: message: The message to send. topic: MQTT topic to send message under. - use_logger: Sends log message with requested logger. """ - if use_logger is not None and isinstance(use_logger, logging.Logger): - use_logger = use_logger - else: - use_logger = self._instance_logger - if not message: - use_logger.error(f'No message to send for topic: "{topic}".') + self._instance_logger.error(f'No message to send for topic: "{topic}".') return if self.connected_flag == False: @@ -212,6 +196,4 @@ def send_message( qos=mqtt.QoS.AT_LEAST_ONCE, ) - use_logger.info(f'Sent {sys.getsizeof(payload)} bytes to "{topic}"') - - # self._disconnect() + self._instance_logger.debug(f'Sent {sys.getsizeof(payload)} bytes to "{topic}"') diff --git a/src/iotswarm/messaging/core.py b/src/iotswarm/messaging/core.py index d362b30..0957d93 100644 --- a/src/iotswarm/messaging/core.py +++ b/src/iotswarm/messaging/core.py @@ -2,8 +2,6 @@ from abc import abstractmethod import logging -logger = logging.getLogger(__name__) - class MessagingBaseClass(ABC): """MessagingBaseClass Base class for messaging implementation @@ -14,9 +12,20 @@ class MessagingBaseClass(ABC): _instance_logger: logging.Logger """Logger handle used by instance.""" - def __init__(self): - - self._instance_logger = logger.getChild(self.__class__.__name__) + def __init__( + self, + inherit_logger: logging.Logger | None = None, + ): + """Initialises the class. + Args: + inherit_logger: Override for the module logger. 
+ """ + if inherit_logger is not None: + self._instance_logger = inherit_logger.getChild(self.__class__.__name__) + else: + self._instance_logger = logging.getLogger(__name__).getChild( + self.__class__.__name__ + ) @property @abstractmethod @@ -37,15 +46,7 @@ class MockMessageConnection(MessagingBaseClass): connection: None = None """Connection object. Not needed in a mock but must be implemented""" - def send_message(self, use_logger: logging.Logger | None = None): - """Consumes requests to send a message but does nothing with it. - - Args: - use_logger: Sends log message with requested logger.""" - - if use_logger is not None and isinstance(use_logger, logging.Logger): - use_logger = use_logger - else: - use_logger = self._instance_logger + def send_message(self, *_): + """Consumes requests to send a message but does nothing with it.""" - use_logger.info("Message was sent.") + self._instance_logger.debug("Message was sent.") diff --git a/src/iotswarm/queries.py b/src/iotswarm/queries.py index f1174b2..dfda5c8 100644 --- a/src/iotswarm/queries.py +++ b/src/iotswarm/queries.py @@ -5,112 +5,60 @@ @enum.unique -class CosmosQuery(StrEnum): - """Class containing permitted SQL queries for retrieving sensor data.""" +class CosmosTable(StrEnum): + """Enums of approved COSMOS database tables.""" - LEVEL_1_SOILMET_30MIN = """SELECT * FROM COSMOS.LEVEL1_SOILMET_30MIN -WHERE site_id = :mysite -ORDER BY date_time DESC -FETCH NEXT 1 ROWS ONLY""" + LEVEL_1_SOILMET_30MIN = "LEVEL1_SOILMET_30MIN" + LEVEL_1_NMDB_1HOUR = "LEVEL1_NMDB_1HOUR" + LEVEL_1_PRECIP_1MIN = "LEVEL1_PRECIP_1MIN" + LEVEL_1_PRECIP_RAINE_1MIN = "LEVEL1_PRECIP_RAINE_1MIN" - """Query for retreiving data from the LEVEL1_SOILMET_30MIN table, containing - calibration the core telemetry from COSMOS sites. - - .. 
code-block:: sql - SELECT * FROM COSMOS.LEVEL1_SOILMET_30MIN - WHERE site_id = :mysite - ORDER BY date_time DESC - FETCH NEXT 1 ROWS ONLY - """ +@enum.unique +class CosmosQuery(StrEnum): + """Enums of common queries in each databasing language.""" - LEVEL_1_NMDB_1HOUR = """SELECT * FROM COSMOS.LEVEL1_NMDB_1HOUR -WHERE site_id = :mysite -ORDER BY date_time DESC -FETCH NEXT 1 ROWS ONLY""" + SQLITE_LOOPED_DATA = """SELECT * FROM {table} +WHERE site_id = :site_id +LIMIT 1 OFFSET :offset""" - """Query for retreiving data from the Level1_NMDB_1HOUR table, containing - calibration data from the Neutron Monitor DataBase. + """Query for retreiving data from a given table in sqlite format. .. code-block:: sql - SELECT * FROM COSMOS.LEVEL1_NMDB_1HOUR - WHERE site_id = :mysite - ORDER BY date_time DESC - FETCH NEXT 1 ROWS ONLY + SELECT * FROM + WHERE site_id = :site_id + LIMIT 1 OFFSET :offset """ - LEVEL_1_PRECIP_1MIN = """SELECT * FROM COSMOS.LEVEL1_PRECIP_1MIN -WHERE site_id = :mysite + ORACLE_LATEST_DATA = """SELECT * FROM COSMOS.{table} +WHERE site_id = :site_id ORDER BY date_time DESC FETCH NEXT 1 ROWS ONLY""" - """Query for retreiving data from the LEVEL1_PRECIP_1MIN table, containing - the standard precipitation telemetry from COSMOS sites. + """Query for retreiving data from a given table in oracle format. .. code-block:: sql - SELECT * FROM COSMOS.LEVEL1_PRECIP_1MIN - WHERE site_id = :mysite + SELECT * FROM
ORDER BY date_time DESC FETCH NEXT 1 ROWS ONLY """ - LEVEL_1_PRECIP_RAINE_1MIN = """SELECT * FROM COSMOS.LEVEL1_PRECIP_RAINE_1MIN -WHERE site_id = :mysite -ORDER BY date_time DESC -FETCH NEXT 1 ROWS ONLY""" - - """Query for retreiving data from the LEVEL1_PRECIP_RAINE_1MIN table, containing - the rain[e] precipitation telemetry from COSMOS sites. - - .. code-block:: sql - - SELECT * FROM COSMOS.LEVEL1_PRECIP_RAINE_1MIN - WHERE site_id = :mysite - ORDER BY date_time DESC - FETCH NEXT 1 ROWS ONLY - """ - - -@enum.unique -class CosmosSiteQuery(StrEnum): - """Contains permitted SQL queries for extracting site IDs from database.""" - - LEVEL_1_SOILMET_30MIN = "SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_SOILMET_30MIN" - - """Queries unique site IDs from LEVEL1_SOILMET_30MIN. - - .. code-block:: sql - - SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_SOILMET_30MIN - """ - - LEVEL_1_NMDB_1HOUR = "SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_NMDB_1HOUR" - - """Queries unique site IDs from LEVEL1_NMDB_1HOUR table. - - .. code-block:: sql - - SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_NMDB_1HOUR - """ - - LEVEL_1_PRECIP_1MIN = "SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_PRECIP_1MIN" + SQLITE_SITE_IDS = "SELECT DISTINCT(site_id) FROM {table}" - """Queries unique site IDs from the LEVEL1_PRECIP_1MIN table. + """Queries unique `site_id `s from a given table. .. code-block:: sql - SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_PRECIP_1MIN + SELECT DISTINCT(site_id) FROM
""" - LEVEL_1_PRECIP_RAINE_1MIN = ( - "SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_PRECIP_RAINE_1MIN" - ) + ORACLE_SITE_IDS = "SELECT UNIQUE(site_id) FROM COSMOS.{table}" - """Queries unique site IDs from the LEVEL1_PRECIP_RAINE_1MIN table. + """Queries unique `site_id `s from a given table. .. code-block:: sql - SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_PRECIP_RAINE_1MIN + SELECT UNQIUE(site_id) FROM
""" diff --git a/src/iotswarm/scripts/cli.py b/src/iotswarm/scripts/cli.py index 7d79274..912a7be 100644 --- a/src/iotswarm/scripts/cli.py +++ b/src/iotswarm/scripts/cli.py @@ -1,24 +1,19 @@ """CLI exposed when the package is installed.""" import click -from iotswarm import queries +from iotswarm import __version__ as package_version +from iotswarm.queries import CosmosTable from iotswarm.devices import BaseDevice, CR1000XDevice from iotswarm.swarm import Swarm -from iotswarm.db import Oracle +from iotswarm.db import Oracle, LoopingCsvDB, LoopingSQLite3 from iotswarm.messaging.core import MockMessageConnection from iotswarm.messaging.aws import IotCoreMQTTConnection +import iotswarm.scripts.common as cli_common import asyncio from pathlib import Path import logging -TABLES = [table.name for table in queries.CosmosQuery] - - -@click.command -@click.pass_context -def test(ctx: click.Context): - """Enables testing of cosmos group arguments.""" - print(ctx.obj) +TABLE_NAMES = [table.name for table in CosmosTable] @click.group() @@ -27,18 +22,64 @@ def test(ctx: click.Context): "--log-config", type=click.Path(exists=True), help="Path to a logging config file. 
Uses default if not given.", + default=Path(Path(__file__).parents[1], "__assets__", "loggers.ini"), ) -def main(ctx: click.Context, log_config: Path): +@click.option( + "--log-level", + type=click.Choice( + ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False + ), + help="Overrides the logging level.", + envvar="IOT_SWARM_LOG_LEVEL", +) +def main(ctx: click.Context, log_config: Path, log_level: str): """Core group of the cli.""" ctx.ensure_object(dict) - if not log_config: - log_config = Path(Path(__file__).parents[1], "__assets__", "loggers.ini") - logging.config.fileConfig(fname=log_config) + logger = logging.getLogger(__name__) + + if log_level: + logger.setLevel(log_level) + click.echo(f"Set log level to {log_level}.") + + ctx.obj["logger"] = logger + + +main.add_command(cli_common.test) + + +@main.command +def get_version(): + """Gets the package version""" + click.echo(package_version) + +@main.command() +@click.pass_context +@cli_common.iotcore_options +def test_mqtt( + ctx, + message, + topic, + client_id: str, + endpoint: str, + cert_path: str, + key_path: str, + ca_cert_path: str, +): + """Tests that a basic message can be sent via mqtt.""" + + connection = IotCoreMQTTConnection( + endpoint=endpoint, + cert_path=cert_path, + key_path=key_path, + ca_cert_path=ca_cert_path, + client_id=client_id, + inherit_logger=ctx.obj["logger"], + ) -main.add_command(test) + connection.send_message(message, topic) @main.group() @@ -78,113 +119,262 @@ def cosmos(ctx: click.Context, site: str, dsn: str, user: str, password: str): ctx.obj["sites"] = site -cosmos.add_command(test) +cosmos.add_command(cli_common.test) @cosmos.command() @click.pass_context -@click.argument("query", type=click.Choice(TABLES)) +@click.argument("table", type=click.Choice(TABLE_NAMES)) @click.option("--max-sites", type=click.IntRange(min=0), default=0) -def list_sites(ctx, query, max_sites): - """Lists site IDs from the database from table QUERY.""" +def list_sites(ctx, table, 
max_sites): + """Lists unique `site_id` from an oracle database table.""" - async def _list_sites(ctx, query): + async def _list_sites(): oracle = await Oracle.create( dsn=ctx.obj["credentials"]["dsn"], user=ctx.obj["credentials"]["user"], password=ctx.obj["credentials"]["password"], + inherit_logger=ctx.obj["logger"], ) - sites = await oracle.query_site_ids( - queries.CosmosSiteQuery[query], max_sites=max_sites - ) + sites = await oracle.query_site_ids(CosmosTable[table], max_sites=max_sites) return sites - click.echo(asyncio.run(_list_sites(ctx, query))) + click.echo(asyncio.run(_list_sites())) @cosmos.command() @click.pass_context @click.argument( - "provider", - type=click.Choice(["aws"]), -) -@click.argument( - "query", - type=click.Choice(TABLES), -) -@click.argument( - "client-id", - type=click.STRING, - required=True, + "table", + type=click.Choice(TABLE_NAMES), ) +@cli_common.device_options +@cli_common.iotcore_options +@click.option("--dry", is_flag=True, default=False, help="Doesn't send out any data.") +def mqtt( + ctx, + table, + endpoint, + cert_path, + key_path, + ca_cert_path, + client_id, + sleep_time, + max_cycles, + max_sites, + swarm_name, + delay_start, + mqtt_prefix, + mqtt_suffix, + dry, + device_type, +): + """Sends The cosmos data via MQTT protocol using IoT Core. + Data is from the cosmos database TABLE and sent using CLIENT_ID. 
+ + Currently only supports sending through AWS IoT Core.""" + table = CosmosTable[table] + + async def _mqtt(): + oracle = await Oracle.create( + dsn=ctx.obj["credentials"]["dsn"], + user=ctx.obj["credentials"]["user"], + password=ctx.obj["credentials"]["password"], + inherit_logger=ctx.obj["logger"], + ) + + sites = ctx.obj["sites"] + if len(sites) == 0: + sites = await oracle.query_site_ids(table, max_sites=max_sites) + + if dry == True: + connection = MockMessageConnection(inherit_logger=ctx.obj["logger"]) + else: + connection = IotCoreMQTTConnection( + endpoint=endpoint, + cert_path=cert_path, + key_path=key_path, + ca_cert_path=ca_cert_path, + client_id=client_id, + inherit_logger=ctx.obj["logger"], + ) + + if device_type == "basic": + DeviceClass = BaseDevice + elif device_type == "cr1000x": + DeviceClass = CR1000XDevice + + site_devices = [ + DeviceClass( + site, + oracle, + connection, + sleep_time=sleep_time, + table=table, + max_cycles=max_cycles, + delay_start=delay_start, + mqtt_prefix=mqtt_prefix, + mqtt_suffix=mqtt_suffix, + inherit_logger=ctx.obj["logger"], + ) + for site in sites + ] + + swarm = Swarm(site_devices, swarm_name) + + await swarm.run() + + asyncio.run(_mqtt()) + + +@main.group() +@click.pass_context @click.option( - "--endpoint", + "--site", type=click.STRING, - required=True, - envvar="IOT_SWARM_MQTT_ENDPOINT", - help="Endpoint of the MQTT receiving host.", + multiple=True, + help="Adds a site to be initialized. Can be invoked multiple times for other sites." + " Grabs all sites from database query if none provided", ) @click.option( - "--cert-path", + "--file", type=click.Path(exists=True), required=True, - envvar="IOT_SWARM_MQTT_CERT_PATH", - help="Path to public key certificate for the device. 
Must match key assigned to the `--client-id` in the cloud provider.", + envvar="IOT_SWARM_CSV_DB", + help="*.csv file used to instantiate a pandas database.", ) +def looping_csv(ctx, site, file): + """Instantiates a pandas dataframe from a csv file which is used as the database. + Responsibility falls on the user to ensure the correct file is selected.""" + + ctx.obj["db"] = LoopingCsvDB(file) + ctx.obj["sites"] = site + + +looping_csv.add_command(cli_common.test) +looping_csv.add_command(cli_common.list_sites) + + +@looping_csv.command() +@click.pass_context +@cli_common.device_options +@cli_common.iotcore_options +@click.option("--dry", is_flag=True, default=False, help="Doesn't send out any data.") +def mqtt( + ctx, + endpoint, + cert_path, + key_path, + ca_cert_path, + client_id, + sleep_time, + max_cycles, + max_sites, + swarm_name, + delay_start, + mqtt_prefix, + mqtt_suffix, + dry, + device_type, +): + """Sends The cosmos data via MQTT protocol using IoT Core. + Data is collected from the db using QUERY and sent using CLIENT_ID. 
+ + Currently only supports sending through AWS IoT Core.""" + + async def _mqtt(): + + sites = ctx.obj["sites"] + db = ctx.obj["db"] + if len(sites) == 0: + sites = db.query_site_ids(max_sites=max_sites) + + if dry == True: + connection = MockMessageConnection(inherit_logger=ctx.obj["logger"]) + else: + connection = IotCoreMQTTConnection( + endpoint=endpoint, + cert_path=cert_path, + key_path=key_path, + ca_cert_path=ca_cert_path, + client_id=client_id, + inherit_logger=ctx.obj["logger"], + ) + + if device_type == "basic": + DeviceClass = BaseDevice + elif device_type == "cr1000x": + DeviceClass = CR1000XDevice + + site_devices = [ + DeviceClass( + site, + db, + connection, + sleep_time=sleep_time, + max_cycles=max_cycles, + delay_start=delay_start, + mqtt_prefix=mqtt_prefix, + mqtt_suffix=mqtt_suffix, + inherit_logger=ctx.obj["logger"], + ) + for site in sites + ] + + swarm = Swarm(site_devices, swarm_name) + + await swarm.run() + + asyncio.run(_mqtt()) + + +@main.group() +@click.pass_context @click.option( - "--key-path", - type=click.Path(exists=True), - required=True, - envvar="IOT_SWARM_MQTT_KEY_PATH", - help="Path to the private key that pairs with the `--cert-path`.", + "--site", + type=click.STRING, + multiple=True, + help="Adds a site to be initialized. Can be invoked multiple times for other sites." + " Grabs all sites from database query if none provided", ) @click.option( - "--ca-cert-path", + "--file", type=click.Path(exists=True), required=True, - envvar="IOT_SWARM_MQTT_CA_CERT_PATH", - help="Path to the root Certificate Authority (CA) for the MQTT host.", -) -@click.option( - "--sleep-time", - type=click.INT, - help="The number of seconds each site goes idle after sending a message.", -) -@click.option( - "--max-cycles", - type=click.IntRange(0), - help="Maximum number message sending cycles. Runs forever if set to 0.", -) -@click.option( - "--max-sites", - type=click.IntRange(0), - help="Maximum number of sites allowed to initialize. 
No limit if set to 0.", -) -@click.option( - "--swarm-name", type=click.STRING, help="Name given to swarm. Appears in the logs." + envvar="IOT_SWARM_LOCAL_DB", + help="*.db file used to instantiate a sqlite3 database.", ) -@click.option( - "--delay-start", - is_flag=True, - default=False, - help="Adds a random delay before the first message from each site up to `--sleep-time`.", -) -@click.option( - "--mqtt-prefix", - type=click.STRING, - help="Prefixes the MQTT topic with a string. Can augment the calculated MQTT topic returned by each site.", -) -@click.option( - "--mqtt-suffix", - type=click.STRING, - help="Suffixes the MQTT topic with a string. Can augment the calculated MQTT topic returned by each site.", +def looping_sqlite3(ctx, site, file): + """Instantiates a sqlite3 database as sensor source..""" + ctx.obj["db"] = LoopingSQLite3(file) + ctx.obj["sites"] = site + + +looping_sqlite3.add_command(cli_common.test) + + +@looping_sqlite3.command +@click.pass_context +@click.option("--max-sites", type=click.IntRange(min=0), default=0) +@click.argument( + "table", + type=click.Choice(TABLE_NAMES), ) +def list_sites(ctx, max_sites, table): + """Prints the sites present in database.""" + sites = ctx.obj["db"].query_site_ids(table, max_sites=max_sites) + click.echo(sites) + + +@looping_sqlite3.command() +@click.pass_context +@cli_common.device_options +@cli_common.iotcore_options +@click.argument("table", type=click.Choice(TABLE_NAMES)) @click.option("--dry", is_flag=True, default=False, help="Doesn't send out any data.") -@click.option("--device-type", type=click.Choice(["basic", "cr1000x"]), default="basic") def mqtt( ctx, - provider, - query, + table, endpoint, cert_path, key_path, @@ -200,34 +390,30 @@ def mqtt( dry, device_type, ): - """Sends The cosmos data via MQTT protocol using PROVIDER. + """Sends The cosmos data via MQTT protocol using IoT Core. Data is collected from the db using QUERY and sent using CLIENT_ID. 
Currently only supports sending through AWS IoT Core.""" - async def _mqtt(): - oracle = await Oracle.create( - dsn=ctx.obj["credentials"]["dsn"], - user=ctx.obj["credentials"]["user"], - password=ctx.obj["credentials"]["password"], - ) + table = CosmosTable[table] - data_query = queries.CosmosQuery[query] - site_query = queries.CosmosSiteQuery[query] + async def _mqtt(): sites = ctx.obj["sites"] + db = ctx.obj["db"] if len(sites) == 0: - sites = await oracle.query_site_ids(site_query, max_sites=max_sites) + sites = db.query_site_ids(table, max_sites=max_sites) if dry == True: - connection = MockMessageConnection() - elif provider == "aws": + connection = MockMessageConnection(inherit_logger=ctx.obj["logger"]) + else: connection = IotCoreMQTTConnection( endpoint=endpoint, cert_path=cert_path, key_path=key_path, ca_cert_path=ca_cert_path, client_id=client_id, + inherit_logger=ctx.obj["logger"], ) if device_type == "basic": @@ -238,14 +424,15 @@ async def _mqtt(): site_devices = [ DeviceClass( site, - oracle, + db, connection, sleep_time=sleep_time, - query=data_query, max_cycles=max_cycles, delay_start=delay_start, mqtt_prefix=mqtt_prefix, mqtt_suffix=mqtt_suffix, + table=table, + inherit_logger=ctx.obj["logger"], ) for site in sites ] diff --git a/src/iotswarm/scripts/common.py b/src/iotswarm/scripts/common.py new file mode 100644 index 0000000..1b8b1e0 --- /dev/null +++ b/src/iotswarm/scripts/common.py @@ -0,0 +1,113 @@ +"""Location for common CLI commands""" + +import click + + +def device_options(function): + click.option( + "--sleep-time", + type=click.INT, + help="The number of seconds each site goes idle after sending a message.", + )(function) + + click.option( + "--max-cycles", + type=click.IntRange(0), + help="Maximum number message sending cycles. Runs forever if set to 0.", + )(function) + + click.option( + "--max-sites", + type=click.IntRange(0), + help="Maximum number of sites allowed to initialize. 
No limit if set to 0.", + )(function) + + click.option( + "--swarm-name", + type=click.STRING, + help="Name given to swarm. Appears in the logs.", + )(function) + + click.option( + "--delay-start", + is_flag=True, + default=False, + help="Adds a random delay before the first message from each site up to `--sleep-time`.", + )(function) + + click.option( + "--device-type", type=click.Choice(["basic", "cr1000x"]), default="basic" + )(function) + + return function + + +def iotcore_options(function): + click.argument( + "client-id", + type=click.STRING, + required=True, + )(function) + + click.option( + "--endpoint", + type=click.STRING, + required=True, + envvar="IOT_SWARM_MQTT_ENDPOINT", + help="Endpoint of the MQTT receiving host.", + )(function) + + click.option( + "--cert-path", + type=click.Path(exists=True), + required=True, + envvar="IOT_SWARM_MQTT_CERT_PATH", + help="Path to public key certificate for the device. Must match key assigned to the `--client-id` in the cloud provider.", + )(function) + + click.option( + "--key-path", + type=click.Path(exists=True), + required=True, + envvar="IOT_SWARM_MQTT_KEY_PATH", + help="Path to the private key that pairs with the `--cert-path`.", + )(function) + + click.option( + "--ca-cert-path", + type=click.Path(exists=True), + required=True, + envvar="IOT_SWARM_MQTT_CA_CERT_PATH", + help="Path to the root Certificate Authority (CA) for the MQTT host.", + )(function) + + click.option( + "--mqtt-prefix", + type=click.STRING, + help="Prefixes the MQTT topic with a string. Can augment the calculated MQTT topic returned by each site.", + )(function) + + click.option( + "--mqtt-suffix", + type=click.STRING, + help="Suffixes the MQTT topic with a string. 
Can augment the calculated MQTT topic returned by each site.", + )(function) + + return function + + +@click.command +@click.pass_context +@click.option("--max-sites", type=click.IntRange(min=0), default=0) +def list_sites(ctx, max_sites): + """Prints the sites present in database.""" + + sites = ctx.obj["db"].query_site_ids(max_sites=max_sites) + click.echo(sites) + + +@click.command +@click.pass_context +def test(ctx: click.Context): + """Enables testing of cosmos group arguments.""" + print(ctx.obj) diff --git a/src/iotswarm/utils.py b/src/iotswarm/utils.py index 6d9d988..3a935eb 100644 --- a/src/iotswarm/utils.py +++ b/src/iotswarm/utils.py @@ -1,6 +1,10 @@ """Module for handling commonly reused utility functions.""" from datetime import date, datetime +from pathlib import Path +import pandas +import sqlite3 +from glob import glob def json_serial(obj: object): @@ -13,3 +17,53 @@ def json_serial(obj: object): return obj.__json__() raise TypeError(f"Type {type(obj)} is not serializable.") + + +def build_database_from_csv( + csv_file: str | Path, + database: str | Path, + table_name: str, + sort_by: str | None = None, + date_time_format: str = r"%d-%b-%y %H.%M.%S", +) -> None: + """Adds a database table using a csv file with headers. + + Args: + csv_file: A path to the csv. + database: Output destination of the database. File is created if not + existing. + table_name: Name of the table to add into database. 
+ sort_by: Column to sort by + date_time_format: Format of datetime column + """ + + if not isinstance(csv_file, Path): + csv_file = Path(csv_file) + + if not isinstance(database, Path): + database = Path(database) + + if not csv_file.exists(): + raise FileNotFoundError(f'csv_file does not exist: "{csv_file}"') + + if not database.parent.exists(): + raise NotADirectoryError(f'Database directory not found: "{database.parent}"') + + with sqlite3.connect(database) as conn: + print( + f'Writing table: "{table_name}" from csv_file: "{csv_file}" to db: "{database}"' + ) + print("Loading csv") + df = pandas.read_csv(csv_file) + print("Done") + print("Formatting dates") + df["DATE_TIME"] = pandas.to_datetime(df["DATE_TIME"], format=date_time_format) + print("Done") + if sort_by is not None: + print("Sorting.") + df = df.sort_values(by=sort_by) + print("Done") + + print("Writing to db.") + df.to_sql(table_name, conn, if_exists="replace", index=False) + print("Writing complete.") diff --git a/src/tests/data/ALIC1_4_ROWS.csv b/src/tests/data/ALIC1_4_ROWS.csv new file mode 100644 index 0000000..55659cd --- /dev/null +++ b/src/tests/data/ALIC1_4_ROWS.csv @@ -0,0 +1,5 @@ 
+"SITE_ID","DATE_TIME","WD","WD_STD","WS","WS_STD","PA","PA_STD","RH","RH_STD","TA","TA_STD","Q","Q_STD","SWIN","SWIN_STD","SWOUT","SWOUT_STD","LWIN_UNC","LWIN_UNC_STD","LWOUT_UNC","LWOUT_UNC_STD","LWIN","LWIN_STD","LWOUT","LWOUT_STD","TNR01C","TNR01C_STD","PRECIP","PRECIP_DIAG","SNOWD_DISTANCE_UNC","SNOWD_SIGNALQUALITY","CTS_MOD","CTS_BARE","CTS_MOD2","CTS_SNOW","PROFILE_VWC15","PROFILE_SOILEC15","PROFILE_VWC40","PROFILE_SOILEC40","PROFILE_VWC65","PROFILE_SOILEC65","TDT1_VWC","TDT1_TSOIL","TDT1_SOILPERM","TDT1_SOILEC","TDT2_VWC","TDT2_TSOIL","TDT2_SOILPERM","TDT2_SOILEC","STP_TSOIL50","STP_TSOIL20","STP_TSOIL10","STP_TSOIL5","STP_TSOIL2","G1","G1_STD","G1_MV","G1_CAL","G2","G2_STD","G2_MV","G2_CAL","BATTV","SCANS","RECORD","CTS_MOD_DIAG_1","CTS_MOD_DIAG_2","CTS_MOD_DIAG_3","CTS_MOD_DIAG_4","CTS_MOD_DIAG_5","UX","UZ","UY","STDEV_UX","STDEV_UY","STDEV_UZ","COV_UX_UY","COV_UX_UZ","COV_UY_UZ","CTS_SNOW_DIAG_1","CTS_SNOW_DIAG_2","CTS_SNOW_DIAG_3","CTS_SNOW_DIAG_4","CTS_SNOW_DIAG_5","SNOWD_DISTANCE_COR","WS_RES","HS","TAU","U_STAR","TS","STDEV_TS","COV_TS_UX","COV_TS_UY","COV_TS_UZ","RHO_A_MEAN","E_SAT_MEAN","METPAK_SAMPLES","CTS_MOD_RH","CTS_MOD_TEMP","TDT3_VWC","TDT3_TSOIL","TDT3_SOILEC","TDT3_SOILPERM","TDT4_VWC","TDT4_TSOIL","TDT4_SOILEC","TDT4_SOILPERM","TDT5_VWC","TDT5_TSOIL","TDT5_SOILEC","TDT5_SOILPERM","TDT6_VWC","TDT6_TSOIL","TDT6_SOILEC","TDT6_SOILPERM","TDT7_VWC","TDT7_TSOIL","TDT7_SOILEC","TDT7_SOILPERM","TDT8_VWC","TDT8_TSOIL","TDT8_SOILEC","TDT8_SOILPERM","TDT9_VWC","TDT9_TSOIL","TDT9_SOILEC","TDT9_SOILPERM","TDT10_VWC","TDT10_TSOIL","TDT10_SOILEC","TDT10_SOILPERM","CTS_SNOW_TEMP","CTS_SNOW_RH","CTS_MOD2_RH","CTS_MOD2_TEMP","SWIN_MULT","SWOUT_MULT","LWIN_MULT","LWOUT_MULT","PRECIP_TIPPING_A","PRECIP_TIPPING_B","PRECIP_TIPPING2","TBR_TIP","CTS_MOD_PERIOD","CTS_MOD2_PERIOD","CTS_SNOW_PERIOD","WM_NAN_COUNT","RH_INTERNAL" 
+"ALIC1",01-JUN-24,353.4,28.38,2.397,1.309,1010.926,0.044,72.58,0.641,12.34,0.05,7.891,,-0.941,0.245,0.454,0.097,-17.14,2.736,-0.303,0.172,359.1,2.68,376,0.191,12.27,0.022,0,0,,,519,,,,,,,,,,43.84,12.72,27.98,0.48,35.03,12.61,20.29,0.24,11.91,12.37,12.45,12.45,12.44,-3.3016,0.10141,-0.15144,21.8,-3.27326,0.1105,-0.14011,23.36,13.21,180,11153,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.23,1.433,,2.1,18.9,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,60.024,52.2193,85.4701,90.4977,,,,,1800,,,99999.99999, +"ALIC1",01-JUN-24,358,29.12,2.045,1.015,1011.057,0.049,70.59,0.34,12.23,0.139,7.623,,-2.57,1.666,0.552,0.318,-45.27,28,-1.624,1.119,330.1,28.91,373.8,2.085,12.1,0.203,0,0,,,536,,,,,,,,,,43.61,12.72,27.77,0.51,34.91,12.59,20.2,0.27,11.91,12.36,12.43,12.42,12.42,-2.78394,0.58028,-0.12777,21.77,-3.0528,0.35922,-0.13067,23.36,13.19,180,11154,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.231,1.424,,2.1,18.8,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,60.024,52.2193,85.4701,90.4977,,,,,1800,,,99999.99999, +"ALIC1",01-JUN-24,354.8,26.46,1.531,0.902,1011.085,0.036,73,0.945,11.38,0.31,7.471,,-4.584,0.342,1.413,0.163,-76.38,4.869,-2.488,0.493,293,3.633,366.9,1.718,10.95,0.381,0,0,,,504,,,,,,,,,,43.16,12.7,27.35,0.54,34.91,12.59,20.2,0.25,11.89,12.34,12.41,12.39,12.38,-3.57444,0.69214,-0.16418,21.77,-3.62632,0.49502,-0.15522,23.36,13.22,180,11155,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.235,1.346,,2.1,18.6,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,60.024,52.2193,85.4701,90.4977,,,,,1799,,,99999.99999, +"ALIC1",01-JUN-24,323.9,47.78,0.916,0.541,1011.082,0.038,76.2,1.033,10.42,0.285,7.34,,-4.493,0.27,1.759,0.13,-75.77,0.74,-2.514,0.463,287.7,1.127,360.9,1.949,9.8,0.328,0,0,,,461,,,,,,,,,,43.33,12.68,27.51,0.51,34.75,12.58,20.07,0.26,11.9,12.35,12.4,12.37,12.36,-6.14165,0.82193,-0.2821,21.77,-5.52521,0.57313,-0.2365,23.36,13.22,180,11156,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.239,1.262,,2.1,18.4,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,60.025,52.2193,85.4701,90.4977,,,,,1799,,,99999.99999, \ No newline at end of file diff --git 
a/src/tests/data/build_database.py b/src/tests/data/build_database.py new file mode 100644 index 0000000..eef80be --- /dev/null +++ b/src/tests/data/build_database.py @@ -0,0 +1,22 @@ +"""This script generates a .db SQLite file from the accompanying .csv data file. +It contains just a few rows of data and is used for testing purposes only.""" + +from iotswarm.utils import build_database_from_csv +from pathlib import Path +from iotswarm.queries import CosmosTable + + +def main(): + data_dir = Path(__file__).parent + data_file = Path(data_dir, "ALIC1_4_ROWS.csv") + database_file = Path(data_dir, "database.db") + + data_table = CosmosTable.LEVEL_1_SOILMET_30MIN + + build_database_from_csv( + data_file, database_file, data_table.value, date_time_format=r"%d-%b-%y" + ) + + +if __name__ == "__main__": + main() diff --git a/src/tests/test_cli.py b/src/tests/test_cli.py new file mode 100644 index 0000000..c771148 --- /dev/null +++ b/src/tests/test_cli.py @@ -0,0 +1,63 @@ +from click.testing import CliRunner +from iotswarm.scripts import cli +from parameterized import parameterized +import re + +RUNNER = CliRunner() + + +def test_main_ctx(): + result = RUNNER.invoke(cli.main, ["test"]) + assert not result.exception + assert result.output == "{'logger': }\n" + + +def test_main_log_config(): + with RUNNER.isolated_filesystem(): + with open("logger.ini", "w") as f: + f.write( + """[loggers] +keys=root + +[handlers] +keys=consoleHandler + +[formatters] +keys=sampleFormatter + +[logger_root] +level=INFO +handlers=consoleHandler + +[handler_consoleHandler] +class=StreamHandler +level=INFO +formatter=sampleFormatter +args=(sys.stdout,) + +[formatter_sampleFormatter] +format=%(asctime)s - %(name)s - %(levelname)s - %(message)s""" + ) + result = RUNNER.invoke(cli.main, ["--log-config", "logger.ini", "test"]) + assert not result.exception + assert result.output == "{'logger': }\n" + + +@parameterized.expand(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]) +def 
test_log_level_set(log_level): + + result = RUNNER.invoke(cli.main, ["--log-level", log_level, "test"]) + assert not result.exception + assert ( + result.output + == f"Set log level to {log_level}.\n{{'logger': }}\n" + ) + + +def test_get_version(): + """Tests that the verison can be retrieved.""" + + result = RUNNER.invoke(cli.main, ["get-version"]) + print(result.output) + assert not result.exception + assert re.match(r"^\d+\.\d+\.\d+$", result.output) diff --git a/src/tests/test_db.py b/src/tests/test_db.py index 0cbb27f..34e8675 100644 --- a/src/tests/test_db.py +++ b/src/tests/test_db.py @@ -1,21 +1,29 @@ import unittest import pytest import config -import pathlib +from pathlib import Path from iotswarm import db -from iotswarm.queries import CosmosQuery, CosmosSiteQuery +from iotswarm.devices import BaseDevice +from iotswarm.messaging.core import MockMessageConnection +from iotswarm.queries import CosmosTable +from iotswarm.swarm import Swarm from parameterized import parameterized import logging from unittest.mock import patch +import pandas as pd +from glob import glob +from math import isnan +import sqlite3 -CONFIG_PATH = pathlib.Path( - pathlib.Path(__file__).parents[1], "iotswarm", "__assets__", "config.cfg" +CONFIG_PATH = Path( + Path(__file__).parents[1], "iotswarm", "__assets__", "config.cfg" ) config_exists = pytest.mark.skipif( not CONFIG_PATH.exists(), reason="Config file `config.cfg` not found in root directory.", ) +COSMOS_TABLES = list(CosmosTable) class TestBaseDatabase(unittest.TestCase): @patch.multiple(db.BaseDatabase, __abstractmethods__=set()) @@ -27,6 +35,7 @@ def test_instantiation(self): self.assertIsNone(inst.query_latest_from_site()) + class TestMockDB(unittest.TestCase): def test_instantiation(self): @@ -58,22 +67,23 @@ def test_logger_set(self, logger, expected): self.assertEqual(inst._instance_logger.parent, expected) - @parameterized.expand( - [ - [None, "MockDB()"], - [ - logging.getLogger("notroot"), - "MockDB(inherit_logger=)", 
- ], - ] - ) - def test__repr__(self, logger, expected): + + def test__repr__no_logger(self): - inst = db.MockDB(inherit_logger=logger) + inst = db.MockDB() + + self.assertEqual(inst.__repr__(), "MockDB()") - self.assertEqual(inst.__repr__(), expected) + def test__repr__logger_given(self): + logger = logging.getLogger("testdblogger") + logger.setLevel(logging.CRITICAL) + expected = "MockDB(inherit_logger=)" + mock = db.MockDB(inherit_logger=logger) + self.assertEqual(mock.__repr__(), expected) + + class TestOracleDB(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): @@ -86,6 +96,7 @@ async def asyncSetUp(self): creds["user"], password=creds["password"], ) + self.table = CosmosTable.LEVEL_1_SOILMET_30MIN async def asyncTearDown(self) -> None: await self.oracle.connection.close() @@ -103,25 +114,19 @@ async def test_instantiation(self): async def test_latest_data_query(self): site_id = "MORLY" - query = CosmosQuery.LEVEL_1_SOILMET_30MIN - row = await self.oracle.query_latest_from_site(site_id, query) + row = await self.oracle.query_latest_from_site(site_id, self.table) self.assertEqual(row["SITE_ID"], site_id) - @parameterized.expand([ - CosmosSiteQuery.LEVEL_1_NMDB_1HOUR, - CosmosSiteQuery.LEVEL_1_SOILMET_30MIN, - CosmosSiteQuery.LEVEL_1_PRECIP_1MIN, - CosmosSiteQuery.LEVEL_1_PRECIP_RAINE_1MIN, - ]) + @parameterized.expand(COSMOS_TABLES) @pytest.mark.oracle @pytest.mark.asyncio @pytest.mark.slow @config_exists - async def test_site_id_query(self,query): + async def test_site_id_query(self, table): - sites = await self.oracle.query_site_ids(query) + sites = await self.oracle.query_site_ids(table) self.assertIsInstance(sites, list) @@ -137,9 +142,7 @@ async def test_site_id_query(self,query): @config_exists async def test_site_id_query_max_sites(self, max_sites): - query = CosmosSiteQuery.LEVEL_1_SOILMET_30MIN - - sites = await self.oracle.query_site_ids(query, max_sites=max_sites) + sites = await self.oracle.query_site_ids(self.table, max_sites=max_sites) 
self.assertEqual(len(sites), max_sites) @@ -147,33 +150,33 @@ async def test_site_id_query_max_sites(self, max_sites): @pytest.mark.asyncio @pytest.mark.oracle @config_exists - async def test_bad_latest_data_query_type(self): + async def test_bad_latest_data_table_type(self): site_id = "MORLY" - query = "sql injection goes brr" + table = "sql injection goes brr" with self.assertRaises(TypeError): - await self.oracle.query_latest_from_site(site_id, query) + await self.oracle.query_latest_from_site(site_id, table) @pytest.mark.asyncio @pytest.mark.oracle @config_exists - async def test_bad_site_query_type(self): + async def test_bad_site_table_type(self): - query = "sql injection goes brr" + table = "sql injection goes brr" with self.assertRaises(TypeError): - await self.oracle.query_site_ids(query) + await self.oracle.query_site_ids(table) @parameterized.expand([-1, -100, "STRING"]) @pytest.mark.asyncio @pytest.mark.oracle @config_exists - async def test_bad_site_query_max_sites_type(self, max_sites): + async def test_bad_site_table_max_sites_type(self, max_sites): """Tests bad values for max_sites.""" with self.assertRaises((TypeError,ValueError)): - await self.oracle.query_site_ids(CosmosSiteQuery.LEVEL_1_SOILMET_30MIN, max_sites=max_sites) + await self.oracle.query_site_ids(self.table, max_sites=max_sites) @pytest.mark.asyncio @pytest.mark.oracle @@ -181,7 +184,7 @@ async def test_bad_site_query_max_sites_type(self, max_sites): async def test__repr__(self): """Tests string representation.""" - oracle1 = self.oracle = await db.Oracle.create( + oracle1 = await db.Oracle.create( self.creds["dsn"], self.creds["user"], password=self.creds["password"], @@ -192,7 +195,7 @@ async def test__repr__(self): oracle1.__repr__(), expected1 ) - oracle2 = self.oracle = await db.Oracle.create( + oracle2 = await db.Oracle.create( self.creds["dsn"], self.creds["user"], password=self.creds["password"], @@ -207,6 +210,293 @@ async def test__repr__(self): oracle2.__repr__(), expected2 
) +CSV_PATH = Path(Path(__file__).parents[1], "iotswarm", "__assets__", "data") +CSV_DATA_FILES = [Path(x) for x in glob(str(Path(CSV_PATH, "*.csv")))] +sqlite_db_exist = pytest.mark.skipif(not Path(CSV_PATH, "cosmos.db").exists(), reason="Local cosmos.db does not exist.") + +data_files_exist = pytest.mark.skipif( + not CSV_PATH.exists() or len(CSV_DATA_FILES) == 0, + reason="No data files are present" +) + +class TestLoopingCsvDB(unittest.TestCase): + """Tests the LoopingCsvDB class.""" + + def setUp(self): + self.data_path = {v.name.removesuffix("_DATA_TABLE.csv"):v for v in CSV_DATA_FILES} + + self.soilmet_table = db.LoopingCsvDB(self.data_path["LEVEL_1_SOILMET_30MIN"]) + self.maxDiff = None + + @data_files_exist + @pytest.mark.slow + def test_instantiation(self): + """Tests that the database can be instantiated.""" + + database = self.soilmet_table + + self.assertIsInstance(database, db.LoopingCsvDB) + self.assertIsInstance(database, db.BaseDatabase) + + + self.assertIsInstance(database.cache, dict) + self.assertIsInstance(database.connection, pd.DataFrame) + + @data_files_exist + @pytest.mark.slow + def test_site_data_return_value(self): + database = self.soilmet_table + + site = "MORLY" + + data = database.query_latest_from_site(site) + + expected_cache = {site: 1} + self.assertDictEqual(database.cache, expected_cache) + + self.assertIsInstance(data, dict) + + @data_files_exist + @pytest.mark.slow + def test_multiple_sites_added_to_cache(self): + sites = ["ALIC1", "MORLY", "HOLLN","EUSTN"] + + database = self.soilmet_table + + data = [database.query_latest_from_site(x) for x in sites] + + for i, site in enumerate(sites): + self.assertEqual(site, data[i]["SITE_ID"]) + self.assertIn(site, database.cache) + self.assertEqual(database.cache[site], 1) + + @data_files_exist + @pytest.mark.slow + def test_cache_incremented_on_each_request(self): + database = self.soilmet_table + + site = "MORLY" + + expected = 1 + + last_data = None + for _ in range(10): + data = 
database.query_latest_from_site(site)
+            self.assertNotEqual(last_data, data)
+            self.assertEqual(expected, database.cache[site])
+
+            last_data = data
+            expected += 1
+
+        self.assertEqual(expected, 11)
+
+    @data_files_exist
+    @pytest.mark.slow
+    def test_cache_counter_restarts_at_end(self):
+
+        short_table_path = Path(Path(__file__).parent, "data", "ALIC1_4_ROWS.csv")
+        database = db.LoopingCsvDB(short_table_path)
+
+        site = "ALIC1"
+
+        expected = [1,2,3,4,1]
+        data = []
+        for e in expected:
+            data.append(database.query_latest_from_site(site))
+
+            self.assertEqual(database.cache[site], e)
+
+        for key in data[0].keys():
+            try:
+                self.assertEqual(data[0][key], data[-1][key])
+            except AssertionError as err:
+                # Suppress the mismatch only when BOTH values are NaN
+                # (NaN != NaN makes assertEqual fail on equal rows);
+                # any other mismatch is a real failure and must re-raise.
+                if not (isnan(data[0][key]) and isnan(data[-1][key])):
+                    raise(err)
+
+        self.assertEqual(len(expected), len(data))
+
+    @data_files_exist
+    @pytest.mark.slow
+    def test_site_ids_can_be_retrieved(self):
+        database = self.soilmet_table
+
+        site_ids_full = database.query_site_ids()
+        site_ids_exp_full = database.query_site_ids(max_sites=0)
+
+
+        self.assertIsInstance(site_ids_full, list)
+
+        self.assertGreater(len(site_ids_full), 0)
+        for site in site_ids_full:
+            self.assertIsInstance(site, str)
+
+        self.assertEqual(len(site_ids_full), len(site_ids_exp_full))
+
+        site_ids_limit = database.query_site_ids(max_sites=5)
+
+        self.assertEqual(len(site_ids_limit), 5)
+        self.assertGreater(len(site_ids_full), len(site_ids_limit))
+
+        with self.assertRaises(ValueError):
+
+            database.query_site_ids(max_sites=-1)
+
+class TestLoopingCsvDBEndToEnd(unittest.IsolatedAsyncioTestCase):
+    """Tests the LoopingCsvDB class."""
+
+    def setUp(self):
+        self.data_path = {v.name.removesuffix("_DATA_TABLE.csv"):v for v in CSV_DATA_FILES}
+        self.maxDiff = None
+
+    @data_files_exist
+    @pytest.mark.slow
+    async def test_flow_with_device_attached(self):
+        """Tests that data is looped through with a device making requests."""
+
+        database = db.LoopingCsvDB(self.data_path["LEVEL_1_SOILMET_30MIN"])
+        
device = BaseDevice("ALIC1", database, MockMessageConnection(), sleep_time=0, max_cycles=5) + + await device.run() + + self.assertDictEqual(database.cache, {"ALIC1": 5}) + + @data_files_exist + @pytest.mark.slow + async def test_flow_with_swarm_attached(self): + """Tests that the database is looped through correctly with multiple sites in a swarm.""" + + database = db.LoopingCsvDB(self.data_path["LEVEL_1_SOILMET_30MIN"]) + sites = ["MORLY", "ALIC1", "EUSTN"] + cycles = [1, 4, 6] + devices = [ + BaseDevice(s, database, MockMessageConnection(), sleep_time=0, max_cycles=c) + for (s,c) in zip(sites, cycles) + ] + + swarm = Swarm(devices) + + await swarm.run() + + self.assertDictEqual(database.cache, {"MORLY": 1, "ALIC1": 4, "EUSTN": 6}) + +class TestSqliteDB(unittest.TestCase): + + @sqlite_db_exist + def setUp(self): + self.db_path = Path(Path(__file__).parents[1], "iotswarm", "__assets__", "data", "cosmos.db") + self.table = CosmosTable.LEVEL_1_SOILMET_30MIN + + if self.db_path.exists(): + self.database = db.LoopingSQLite3(self.db_path) + self.maxDiff = None + + @sqlite_db_exist + def test_instantiation(self): + self.assertIsInstance(self.database, db.LoopingSQLite3) + self.assertIsInstance(self.database.connection, sqlite3.Connection) + + @sqlite_db_exist + def test_latest_data(self): + + site_id = "MORLY" + + data = self.database.query_latest_from_site(site_id, self.table) + + self.assertIsInstance(data, dict) + + @sqlite_db_exist + def test_site_id_query(self): + + sites = self.database.query_site_ids(self.table) + + self.assertGreater(len(sites), 0) + + self.assertIsInstance(sites, list) + + for site in sites: + self.assertIsInstance(site, str) + + @sqlite_db_exist + def test_multiple_sites_added_to_cache(self): + sites = ["ALIC1", "MORLY", "HOLLN","EUSTN"] + + data = [self.database.query_latest_from_site(x, self.table) for x in sites] + + for i, site in enumerate(sites): + self.assertEqual(site, data[i]["SITE_ID"]) + self.assertIn(site, self.database.cache) + 
self.assertEqual(self.database.cache[site], 0) + + @sqlite_db_exist + def test_cache_incremented_on_each_request(self): + site = "MORLY" + + last_data = {} + for i in range(3): + if i == 0: + self.assertEqual(self.database.cache, {}) + else: + self.assertEqual(i-1, self.database.cache[site]) + data = self.database.query_latest_from_site(site, self.table) + self.assertNotEqual(last_data, data) + + last_data = data + + @sqlite_db_exist + def test_cache_counter_restarts_at_end(self): + database = db.LoopingSQLite3(Path(Path(__file__).parent, "data", "database.db")) + + site = "ALIC1" + + expected = [0,1,2,3,0] + data = [] + for e in expected: + data.append(database.query_latest_from_site(site, self.table)) + + self.assertEqual(database.cache[site], e) + + self.assertEqual(data[0], data[-1]) + + self.assertEqual(len(expected), len(data)) + +class TestLoopingSQLite3DBEndToEnd(unittest.IsolatedAsyncioTestCase): + """Tests the LoopingCsvDB class.""" + + @sqlite_db_exist + def setUp(self): + self.db_path = Path(Path(__file__).parents[1], "iotswarm", "__assets__", "data", "cosmos.db") + if self.db_path.exists(): + self.database = db.LoopingSQLite3(self.db_path) + self.maxDiff = None + self.table = CosmosTable.LEVEL_1_PRECIP_1MIN + + @sqlite_db_exist + async def test_flow_with_device_attached(self): + """Tests that data is looped through with a device making requests.""" + + device = BaseDevice("ALIC1", self.database, MockMessageConnection(), table=self.table, sleep_time=0, max_cycles=5) + + await device.run() + + self.assertDictEqual(self.database.cache, {"ALIC1": 4}) + + @sqlite_db_exist + async def test_flow_with_swarm_attached(self): + """Tests that the database is looped through correctly with multiple sites in a swarm.""" + + sites = ["MORLY", "ALIC1", "EUSTN"] + cycles = [1, 2, 3] + devices = [ + BaseDevice(s, self.database, MockMessageConnection(), sleep_time=0, max_cycles=c,table=self.table) + for (s,c) in zip(sites, cycles) + ] + + swarm = Swarm(devices) + + await 
swarm.run() + + self.assertDictEqual(self.database.cache, {"MORLY": 0, "ALIC1": 1, "EUSTN": 2}) + if __name__ == "__main__": unittest.main() diff --git a/src/tests/test_devices.py b/src/tests/test_devices.py index 1d68a52..efb61d3 100644 --- a/src/tests/test_devices.py +++ b/src/tests/test_devices.py @@ -5,24 +5,28 @@ import json from iotswarm.utils import json_serial from iotswarm.devices import BaseDevice, CR1000XDevice, CR1000XField -from iotswarm.db import Oracle, BaseDatabase, MockDB -from iotswarm.queries import CosmosQuery, CosmosSiteQuery +from iotswarm.db import Oracle, BaseDatabase, MockDB, LoopingSQLite3 +from iotswarm.queries import CosmosQuery, CosmosTable from iotswarm.messaging.core import MockMessageConnection, MessagingBaseClass from iotswarm.messaging.aws import IotCoreMQTTConnection from parameterized import parameterized from unittest.mock import patch -import pathlib +from pathlib import Path import config from datetime import datetime, timedelta -CONFIG_PATH = pathlib.Path( - pathlib.Path(__file__).parents[1], "iotswarm", "__assets__", "config.cfg" +CONFIG_PATH = Path( + Path(__file__).parents[1], "iotswarm", "__assets__", "config.cfg" ) config_exists = pytest.mark.skipif( not CONFIG_PATH.exists(), reason="Config file `config.cfg` not found in root directory.", ) +DATA_DIR = Path(Path(__file__).parents[1], "iotswarm", "__assets__", "data") +sqlite_db_exist = pytest.mark.skipif(not Path(DATA_DIR, "cosmos.db").exists(), reason="Local cosmos.db does not exist.") + + class TestBaseClass(unittest.IsolatedAsyncioTestCase): def setUp(self): @@ -169,26 +173,12 @@ def test__repr__(self, data_source, kwargs, expected): self.assertEqual(repr(instance), expected) - def test_query_not_set_for_non_oracle_db(self): + def test_table_not_set_for_non_oracle_db(self): - inst = BaseDevice("test", MockDB(), MockMessageConnection(), query=CosmosQuery.LEVEL_1_SOILMET_30MIN) + inst = BaseDevice("test", MockDB(), MockMessageConnection(), 
table=CosmosTable.LEVEL_1_NMDB_1HOUR) with self.assertRaises(AttributeError): - inst.query - - def test_prefix_suffix_not_set_for_non_mqtt(self): - "Tests that mqtt prefix and suffix not set for non MQTT messaging." - - inst = BaseDevice("site-1", self.data_source, MockMessageConnection(), mqtt_prefix="prefix", mqtt_suffix="suffix") - - with self.assertRaises(AttributeError): - inst.mqtt_topic - - with self.assertRaises(AttributeError): - inst.mqtt_prefix - - with self.assertRaises(AttributeError): - inst.mqtt_suffix + inst.table @parameterized.expand( [ @@ -204,19 +194,6 @@ def test_logger_set(self, logger, expected): self.assertEqual(inst._instance_logger.parent, expected) - @parameterized.expand([ - [None, None, None], - ["topic", None, None], - ["topic", "prefix", None], - ["topic", "prefix", "suffix"], - ]) - def test__repr__mqtt_opts_no_mqtt_connection(self, topic, prefix, suffix): - """Tests that the __repr__ method returns correctly with MQTT options set.""" - expected = 'BaseDevice("site-id", MockDB(), MockMessageConnection())' - inst = BaseDevice("site-id", MockDB(), MockMessageConnection(), mqtt_topic=topic, mqtt_prefix=prefix, mqtt_suffix=suffix) - - self.assertEqual(inst.__repr__(), expected) - class TestBaseDeviceMQTTOptions(unittest.TestCase): def setUp(self) -> None: @@ -244,7 +221,7 @@ def test_mqtt_prefix_set(self, topic): inst = BaseDevice("site", self.db, self.conn, mqtt_prefix=topic) self.assertEqual(inst.mqtt_prefix, topic) - self.assertEqual(inst.mqtt_topic, f"{topic}/base-device/site") + self.assertEqual(inst.mqtt_topic, f"{topic}/site") @parameterized.expand(["this/topic", "1/1/1", "TOPICO!"]) @config_exists @@ -254,7 +231,7 @@ def test_mqtt_suffix_set(self, topic): inst = BaseDevice("site", self.db, self.conn, mqtt_suffix=topic) self.assertEqual(inst.mqtt_suffix, topic) - self.assertEqual(inst.mqtt_topic, f"base-device/site/{topic}") + self.assertEqual(inst.mqtt_topic, f"site/{topic}") @parameterized.expand([["this/prefix", "this/suffix"], 
["1/1/1", "2/2/2"], ["TOPICO!", "FOUR"]]) @config_exists @@ -265,7 +242,7 @@ def test_mqtt_prefix_and_suffix_set(self, prefix, suffix): self.assertEqual(inst.mqtt_suffix, suffix) self.assertEqual(inst.mqtt_prefix, prefix) - self.assertEqual(inst.mqtt_topic, f"{prefix}/base-device/site/{suffix}") + self.assertEqual(inst.mqtt_topic, f"{prefix}/site/{suffix}") @config_exists def test_default_mqtt_topic_set(self): @@ -273,7 +250,7 @@ def test_default_mqtt_topic_set(self): inst = BaseDevice("site-12", self.db, self.conn) - self.assertEqual(inst.mqtt_topic, "base-device/site-12") + self.assertEqual(inst.mqtt_topic, "site-12") @parameterized.expand([ [None, None, None, ""], @@ -289,7 +266,9 @@ def test__repr__mqtt_opts_mqtt_connection(self, topic, prefix, suffix,expected_a self.assertEqual(inst.__repr__(), expected) + class TestBaseDeviceOracleUsed(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): cred_path = str(CONFIG_PATH) creds = config.Config(cred_path)["oracle"] @@ -300,53 +279,92 @@ async def asyncSetUp(self): creds["user"], password=creds["password"], ) - self.query = CosmosQuery.LEVEL_1_SOILMET_30MIN + self.table = CosmosTable.LEVEL_1_SOILMET_30MIN async def asyncTearDown(self) -> None: await self.oracle.connection.close() - @parameterized.expand([-1, -423.78, CosmosSiteQuery.LEVEL_1_NMDB_1HOUR, "Four", MockDB(), {"a": 1}]) + @parameterized.expand([-1, -423.78, CosmosQuery.ORACLE_LATEST_DATA, "Four", MockDB(), {"a": 1}]) @config_exists - def test_query_value_check(self, query): + @pytest.mark.oracle + def test_table_value_check(self, table): with self.assertRaises(TypeError): BaseDevice( - "test_id", self.oracle, MockMessageConnection(), query=query + "test_id", self.oracle, MockMessageConnection(), table=table ) @pytest.mark.oracle @config_exists - async def test_error_if_query_not_given(self): + async def test_error_if_table_not_given(self): with self.assertRaises(ValueError): BaseDevice("site", self.oracle, MockMessageConnection()) - inst = 
BaseDevice("site", self.oracle, MockMessageConnection(), query=self.query) + inst = BaseDevice("site", self.oracle, MockMessageConnection(), table=self.table) - self.assertEqual(inst.query, self.query) + self.assertEqual(inst.table, self.table) @pytest.mark.oracle @config_exists async def test__repr__oracle_data(self): - inst_oracle = BaseDevice("site", self.oracle, MockMessageConnection(), query=self.query) - exp_oracle = f'BaseDevice("site", Oracle("{self.creds['dsn']}"), MockMessageConnection(), query=CosmosQuery.{self.query.name})' + inst_oracle = BaseDevice("site", self.oracle, MockMessageConnection(), table=self.table) + exp_oracle = f'BaseDevice("site", Oracle("{self.creds['dsn']}"), MockMessageConnection(), table=CosmosTable.{self.table.name})' self.assertEqual(inst_oracle.__repr__(), exp_oracle) - inst_not_oracle = BaseDevice("site", MockDB(), MockMessageConnection(), query=self.query) + inst_not_oracle = BaseDevice("site", MockDB(), MockMessageConnection(), table=self.table) exp_not_oracle = 'BaseDevice("site", MockDB(), MockMessageConnection())' self.assertEqual(inst_not_oracle.__repr__(), exp_not_oracle) with self.assertRaises(AttributeError): - inst_not_oracle.query + inst_not_oracle.table @pytest.mark.oracle @config_exists async def test__get_payload(self): """Tests that Cosmos payload retrieved.""" - inst = BaseDevice("MORLY", self.oracle, MockMessageConnection(), query=self.query) + inst = BaseDevice("MORLY", self.oracle, MockMessageConnection(), table=self.table) + print(inst) + payload = await inst._get_payload() + + self.assertIsInstance(payload, dict) + +class TestBaseDevicesSQLite3Used(unittest.IsolatedAsyncioTestCase): + + def setUp(self): + db_path = Path(Path(__file__).parents[1], "iotswarm","__assets__", "data", "cosmos.db") + if db_path.exists(): + self.db = LoopingSQLite3(db_path) + self.table = CosmosTable.LEVEL_1_SOILMET_30MIN + + @parameterized.expand([-1, -423.78, CosmosQuery.ORACLE_LATEST_DATA, "Four", MockDB(), {"a": 1}]) + 
@sqlite_db_exist + def test_table_value_check(self, table): + with self.assertRaises(TypeError): + BaseDevice( + "test_id", self.db, MockMessageConnection(), table=table + ) + + @sqlite_db_exist + def test_error_if_table_not_given(self): + + with self.assertRaises(ValueError): + BaseDevice("site", self.db, MockMessageConnection()) + + + inst = BaseDevice("site", self.db, MockMessageConnection(), table=self.table) + + self.assertEqual(inst.table, self.table) + + @sqlite_db_exist + async def test__get_payload(self): + """Tests that Cosmos payload retrieved.""" + + inst = BaseDevice("MORLY", self.db, MockMessageConnection(), table=self.table) + payload = await inst._get_payload() self.assertIsInstance(payload, dict) @@ -448,6 +466,7 @@ async def test__get_payload(self): self.assertIsInstance(payload, list) self.assertEqual(len(payload),0) + class TestCr1000xDevice(unittest.TestCase): """Test suite for the CR1000X Device.""" diff --git a/src/tests/test_messaging.py b/src/tests/test_messaging.py index 4fe9271..49e86fe 100644 --- a/src/tests/test_messaging.py +++ b/src/tests/test_messaging.py @@ -11,13 +11,20 @@ from parameterized import parameterized import logging -CONFIG_PATH = Path( - Path(__file__).parents[1], "iotswarm", "__assets__", "config.cfg" -) + +ASSETS_PATH = Path(Path(__file__).parents[1], "iotswarm", "__assets__") +CONFIG_PATH = Path(ASSETS_PATH, "config.cfg") + config_exists = pytest.mark.skipif( not CONFIG_PATH.exists(), reason="Config file `config.cfg` not found in root directory.", ) +certs_exist = pytest.mark.skipif( + not Path(ASSETS_PATH, ".certs", "cosmos_soilmet-certificate.pem.crt").exists() + or not Path(ASSETS_PATH, ".certs", "cosmos_soilmet-private.pem.key").exists() + or not Path(ASSETS_PATH, ".certs", "AmazonRootCA1.pem").exists(), + reason="IotCore certificates not present.", +) class TestBaseClass(unittest.TestCase): @@ -43,24 +50,20 @@ def test_instantiation(self): self.assertIsInstance(mock, MessagingBaseClass) - def 
test_logger_used(self): + def test_no_logger_used(self): - mock = MockMessageConnection() + with self.assertNoLogs(): + mock = MockMessageConnection() + mock.send_message("") - with self.assertLogs() as cm: - mock.send_message() - self.assertEqual( - cm.output, - [ - "INFO:iotswarm.messaging.core.MockMessageConnection:Message was sent." - ], - ) - - with self.assertLogs() as cm: - mock.send_message(use_logger=logging.getLogger("mine")) + def test_logger_used(self): + logger = logging.getLogger("testlogger") + with self.assertLogs(logger=logger, level=logging.DEBUG) as cm: + mock = MockMessageConnection(inherit_logger=logger) + mock.send_message("") self.assertEqual( cm.output, - ["INFO:mine:Message was sent."], + ["DEBUG:testlogger.MockMessageConnection:Message was sent."], ) @@ -73,6 +76,7 @@ def setUp(self) -> None: self.config = config["iot_core"] @config_exists + @certs_exist def test_instantiation(self): instance = IotCoreMQTTConnection(**self.config, client_id="test_id") @@ -82,6 +86,7 @@ def test_instantiation(self): self.assertIsInstance(instance.connection, awscrt.mqtt.Connection) @config_exists + @certs_exist def test_non_string_arguments(self): with self.assertRaises(TypeError): @@ -130,6 +135,7 @@ def test_non_string_arguments(self): ) @config_exists + @certs_exist def test_port(self): # Expect one of defaults if no port given @@ -151,6 +157,7 @@ def test_port(self): @parameterized.expand([-4, {"f": 4}, "FOUR"]) @config_exists + @certs_exist def test_bad_port_type(self, port): with self.assertRaises((TypeError, ValueError)): @@ -164,6 +171,7 @@ def test_bad_port_type(self, port): ) @config_exists + @certs_exist def test_clean_session_set(self): expected = False @@ -180,6 +188,7 @@ def test_clean_session_set(self): @parameterized.expand([0, -1, "true", None]) @config_exists + @certs_exist def test_bad_clean_session_type(self, clean_session): with self.assertRaises(TypeError): @@ -188,6 +197,7 @@ def test_bad_clean_session_type(self, clean_session): ) 
@config_exists + @certs_exist def test_keep_alive_secs_set(self): # Test defualt is not none instance = IotCoreMQTTConnection(**self.config, client_id="test_id") @@ -200,8 +210,9 @@ def test_keep_alive_secs_set(self): ) self.assertEqual(instance.connection.keep_alive_secs, expected) - @parameterized.expand(["FOURTY", "True", None]) + @parameterized.expand(["FOURTY", "True"]) @config_exists + @certs_exist def test_bad_keep_alive_secs_type(self, secs): with self.assertRaises(TypeError): IotCoreMQTTConnection( @@ -209,7 +220,8 @@ def test_bad_keep_alive_secs_type(self, secs): ) @config_exists - def test_logger_set(self): + @certs_exist + def test_no_logger_set(self): inst = IotCoreMQTTConnection(**self.config, client_id="test_id") expected = 'No message to send for topic: "mytopic".' @@ -223,12 +235,21 @@ def test_logger_set(self): ], ) - with self.assertLogs() as cm: - inst.send_message(None, "mytopic", use_logger=logging.getLogger("mine")) + @config_exists + @certs_exist + def test_logger_set(self): + logger = logging.getLogger("mine") + inst = IotCoreMQTTConnection( + **self.config, client_id="test_id", inherit_logger=logger + ) + + expected = 'No message to send for topic: "mytopic".' + with self.assertLogs(logger=logger, level=logging.INFO) as cm: + inst.send_message(None, "mytopic") self.assertEqual( cm.output, - [f"ERROR:mine:{expected}"], + [f"ERROR:mine.IotCoreMQTTConnection.client-test_id:{expected}"], )