From ebb518516ad36b3d4cde7157e78a167cbd0f8dd8 Mon Sep 17 00:00:00 2001
From: Lewis Chambers <47946387+lewis-chambers@users.noreply.github.com>
Date: Thu, 20 Jun 2024 09:01:18 +0100
Subject: [PATCH] Data source expansion (#6)
This pull request covers the inclusion of running the swarm from local data files and addresses the challenges in running the code from an EC2 instance that cannot reach the COSMOS-UK database.
* NEW: A class for looping through as .csv file with pandas
* NEW: A class for looping through a SQLite database (assumes that it's been sorted first)
* NEW: Updated CLI to allow new data sources to be selected.
* NEW: Updated documentation
---
.github/workflows/test.yml | 3 +
.gitignore | 6 +-
docs/index.rst | 30 +-
docs/source/cli.rst | 6 +
docs/source/iotswarm.scripts.rst | 2 +-
pyproject.toml | 22 +-
.../__assets__/data/build_database.py | 41 ++
src/iotswarm/__init__.py | 4 +-
src/iotswarm/db.py | 291 +++++++++++--
src/iotswarm/devices.py | 67 +--
src/iotswarm/messaging/aws.py | 60 +--
src/iotswarm/messaging/core.py | 33 +-
src/iotswarm/queries.py | 104 ++---
src/iotswarm/scripts/cli.py | 387 +++++++++++++-----
src/iotswarm/scripts/common.py | 113 +++++
src/iotswarm/utils.py | 54 +++
src/tests/data/ALIC1_4_ROWS.csv | 5 +
src/tests/data/build_database.py | 22 +
src/tests/test_cli.py | 63 +++
src/tests/test_db.py | 368 +++++++++++++++--
src/tests/test_devices.py | 121 +++---
src/tests/test_messaging.py | 65 ++-
22 files changed, 1456 insertions(+), 411 deletions(-)
create mode 100644 docs/source/cli.rst
create mode 100644 src/iotswarm/__assets__/data/build_database.py
create mode 100644 src/iotswarm/scripts/common.py
create mode 100644 src/tests/data/ALIC1_4_ROWS.csv
create mode 100644 src/tests/data/build_database.py
create mode 100644 src/tests/test_cli.py
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index f2d31a0..441f649 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -24,6 +24,9 @@ jobs:
python -m pip install --upgrade pip
pip install flake8
pip install .[test]
+ - name: Build test DB
+ run: |
+ python src/tests/data/build_database.py
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
diff --git a/.gitignore b/.gitignore
index c88052f..091353a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,7 +8,9 @@ dist/*
config.cfg
aws-auth
*.log
-*.certs
+*.certs*
!.github/*
docs/_*
-conf.sh
\ No newline at end of file
+*.sh
+**/*.csv
+**/*.db
\ No newline at end of file
diff --git a/docs/index.rst b/docs/index.rst
index 1cdd28a..bc97815 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -14,6 +14,7 @@
:caption: Contents:
self
+ source/cli
source/modules
genindex
modindex
@@ -57,7 +58,7 @@ intialise a swarm with data sent every 30 minutes like so:
.. code-block:: shell
iot-swarm cosmos --dsn="xxxxxx" --user="xxxxx" --password="*****" \
- mqtt "aws" LEVEL_1_SOILMET_30MIN "client_id" \
+ mqtt LEVEL_1_SOILMET_30MIN "client_id" \
--endpoint="xxxxxxx" \
--cert-path="C:\path\..." \
--key-path="C:\path\..." \
@@ -84,7 +85,7 @@ Then the CLI can be called more cleanly:
.. code-block:: shell
- iot-swarm cosmos mqtt "aws" LEVEL_1_SOILMET_30MIN "client_id" --sleep-time=1800 --swarm-name="my-swarm"
+ iot-swarm cosmos mqtt LEVEL_1_SOILMET_30MIN "client_id" --sleep-time=1800 --swarm-name="my-swarm"
------------------------
Using the Python Modules
@@ -163,6 +164,31 @@ The system expects config credentials for the MQTT endpoint and the COSMOS Oracl
.. include:: example-config.cfg
+-------------------------------------------
+Looping Through a local database / csv file
+-------------------------------------------
+
+This package now supports using a CSV file or local SQLite database as the data source.
+There are 2 modules to support it: `db.LoopingCsvDB` and `db.LoopingSQLite3`. Each of them
+initializes from a local file and loops through the data for a given site id. The database
+objects store an in memory cache of each site ID and it's current index in the database.
+Once the end is reached, it loops back to the start for that site.
+
+For use in FDRI, 6 months of data was downloaded to CSV from the COSMOS-UK database, but
+the files are too large to be included in this repo, so they are stored in the `ukceh-fdri`
+`S3` bucket on AWS. There are scripts for regenerating the `.db` file in this repo:
+
+* `./src/iotswarm/__assets__/data/build_database.py`
+* `./src/tests/data/build_database.py`
+
+To use the 'official' data, it should be downloaded from the `S3` bucket and placed in
+`./src/iotswarm/__assets__/data` before running the script. This will build a `.db` file
+sorted in datetime order that the `LoopingSQLite3` class can operate with.
+
+.. warning::
+ The looping database classes assume that their data files are sorted, and make no
+ attempt to sort it themselves.
+
Indices and tables
==================
diff --git a/docs/source/cli.rst b/docs/source/cli.rst
new file mode 100644
index 0000000..0341699
--- /dev/null
+++ b/docs/source/cli.rst
@@ -0,0 +1,6 @@
+CLI
+===
+
+.. click:: iotswarm.scripts.cli:main
+ :prog: iot-swarm
+ :nested: full
\ No newline at end of file
diff --git a/docs/source/iotswarm.scripts.rst b/docs/source/iotswarm.scripts.rst
index 42b1dd6..81dd67a 100644
--- a/docs/source/iotswarm.scripts.rst
+++ b/docs/source/iotswarm.scripts.rst
@@ -1,5 +1,5 @@
iotswarm.scripts package
-==================================
+========================
Submodules
----------
diff --git a/pyproject.toml b/pyproject.toml
index e23f65b..a69e0ae 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[build-system]
requires = ["setuptools >= 61.0", "autosemver"]
-build-backend = "setuptools.build_meta"
+# build-backend = "setuptools.build_meta"
[project]
dependencies = [
@@ -8,12 +8,14 @@ dependencies = [
"boto3",
"autosemver",
"config",
+ "click",
"docutils<0.17",
"awscli",
"awscrt",
+ "awsiotsdk",
"oracledb",
"backoff",
- "click",
+ "pandas",
]
name = "iot-swarm"
dynamic = ["version"]
@@ -26,14 +28,18 @@ docs = ["sphinx", "sphinx-copybutton", "sphinx-rtd-theme", "sphinx-click"]
[project.scripts]
iot-swarm = "iotswarm.scripts.cli:main"
+
[tool.setuptools.dynamic]
version = { attr = "iotswarm.__version__" }
+
[tool.setuptools.packages.find]
where = ["src"]
+include = ["iotswarm*"]
+exclude = ["tests*"]
[tool.setuptools.package-data]
-"*" = ["*.*"]
+"iotswarm.__assets__" = ["loggers.ini"]
[tool.pytest.ini_options]
@@ -49,4 +55,12 @@ markers = [
]
[tool.coverage.run]
-omit = ["*example.py", "*__init__.py", "queries.py", "loggers.py", "cli.py"]
+omit = [
+ "*example.py",
+ "*__init__.py",
+ "queries.py",
+ "loggers.py",
+ "**/scripts/*.py",
+ "**/build_database.py",
+ "utils.py",
+]
diff --git a/src/iotswarm/__assets__/data/build_database.py b/src/iotswarm/__assets__/data/build_database.py
new file mode 100644
index 0000000..96c5838
--- /dev/null
+++ b/src/iotswarm/__assets__/data/build_database.py
@@ -0,0 +1,41 @@
+"""This script is responsible for building an SQL data file from the CSV files used
+by the cosmos network.
+
+The files are stored in AWS S3 and should be downloaded into this directory before continuing.
+They are:
+ * LEVEL_1_NMDB_1HOUR_DATA_TABLE.csv
+ * LEVEL_1_SOILMET_30MIN_DATA_TABLE.csv
+ * LEVEL_1_PRECIP_1MIN_DATA_TABLE.csv
+ * LEVEL_1_PRECIP_RAINE_1MIN_DATA_TABLE.csv
+
+Once installed, run this script to generate the .db file.
+"""
+
+from iotswarm.utils import build_database_from_csv
+from pathlib import Path
+from glob import glob
+from iotswarm.queries import CosmosTable
+
+
+def main(
+ csv_dir: str | Path = Path(__file__).parent,
+ database_output: str | Path = Path(Path(__file__).parent, "cosmos.db"),
+):
+ """Reads exported cosmos DB files from CSV format. Assumes that the files
+ look like: LEVEL_1_SOILMET_30MIN_DATA_TABLE.csv
+
+ Args:
+ csv_dir: Directory where the csv files are stored.
+ database_output: Output destination of the csv_data.
+ """
+ csv_files = glob("*.csv", root_dir=csv_dir)
+
+ tables = [CosmosTable[x.removesuffix("_DATA_TABLE.csv")] for x in csv_files]
+
+ for table, file in zip(tables, csv_files):
+ file = Path(csv_dir, file)
+ build_database_from_csv(file, database_output, table.value, sort_by="DATE_TIME")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/iotswarm/__init__.py b/src/iotswarm/__init__.py
index fbfe123..0131fea 100644
--- a/src/iotswarm/__init__.py
+++ b/src/iotswarm/__init__.py
@@ -1,8 +1,6 @@
import autosemver
try:
- __version__ = autosemver.packaging.get_current_version(
- project_name="iot-device-simulator"
- )
+ __version__ = autosemver.packaging.get_current_version(project_name="iot-swarm")
except:
__version__ = "unknown version"
diff --git a/src/iotswarm/db.py b/src/iotswarm/db.py
index 1386ddd..e7297bc 100644
--- a/src/iotswarm/db.py
+++ b/src/iotswarm/db.py
@@ -4,7 +4,14 @@
import getpass
import logging
import abc
-from iotswarm.queries import CosmosQuery, CosmosSiteQuery
+from iotswarm.queries import (
+ CosmosQuery,
+ CosmosTable,
+)
+import pandas as pd
+from pathlib import Path
+from math import nan
+import sqlite3
logger = logging.getLogger(__name__)
@@ -44,12 +51,65 @@ def query_latest_from_site():
return []
-class Oracle(BaseDatabase):
+class CosmosDB(BaseDatabase):
+ """Base class for databases using COSMOS_UK data."""
+
+ connection: object
+ """Connection to database."""
+
+ site_data_query: CosmosQuery
+ """SQL query for retrieving a single record."""
+
+ site_id_query: CosmosQuery
+ """SQL query for retrieving list of site IDs"""
+
+ @staticmethod
+ def _validate_table(table: CosmosTable) -> None:
+ """Validates that the query is legal"""
+
+ if not isinstance(table, CosmosTable):
+ raise TypeError(
+ f"`table` must be a `{CosmosTable.__class__}` Enum, not a `{type(table)}`"
+ )
+
+ @staticmethod
+ def _fill_query(query: str, table: CosmosTable) -> str:
+ """Fills a query string with a CosmosTable enum."""
+
+ CosmosDB._validate_table(table)
+
+ return query.format(table=table.value)
+
+ @staticmethod
+ def _validate_max_sites(max_sites: int) -> int:
+ """Validates that a valid maximum sites is given:
+ Args:
+ max_sites: The maximum number of sites required.
+
+ Returns:
+ An integer 0 or more.
+ """
+
+ if max_sites is not None:
+ max_sites = int(max_sites)
+ if max_sites < 0:
+ raise ValueError(
+ f"`max_sites` must be 1 or more, or 0 for no maximum. Received: {max_sites}"
+ )
+
+ return max_sites
+
+
+class Oracle(CosmosDB):
"""Class for handling oracledb logic and retrieving values from DB."""
connection: oracledb.Connection
"""Connection to oracle database."""
+ site_data_query = CosmosQuery.ORACLE_LATEST_DATA
+
+ site_id_query = CosmosQuery.ORACLE_SITE_IDS
+
def __repr__(self):
parent_repr = (
super().__repr__().lstrip(f"{self.__class__.__name__}(").rstrip(")")
@@ -64,7 +124,14 @@ def __repr__(self):
)
@classmethod
- async def create(cls, dsn: str, user: str, password: str = None, **kwargs):
+ async def create(
+ cls,
+ dsn: str,
+ user: str,
+ password: str = None,
+ inherit_logger: logging.Logger | None = None,
+ **kwargs,
+ ):
"""Factory method for initialising the class.
Initialization is done through the `create() method`: `Oracle.create(...)`.
@@ -72,6 +139,7 @@ async def create(cls, dsn: str, user: str, password: str = None, **kwargs):
dsn: Oracle data source name.
user: Username used for query.
pw: User password for auth.
+ inherit_logger: Uses the given logger if provided
"""
if not password:
@@ -83,32 +151,31 @@ async def create(cls, dsn: str, user: str, password: str = None, **kwargs):
dsn=dsn, user=user, password=password
)
+ if inherit_logger is not None:
+ self._instance_logger = inherit_logger.getChild(self.__class__.__name__)
+ else:
+ self._instance_logger = logger.getChild(self.__class__.__name__)
+
self._instance_logger.info("Initialized Oracle connection.")
return self
- async def query_latest_from_site(self, site_id: str, query: CosmosQuery) -> dict:
+ async def query_latest_from_site(self, site_id: str, table: CosmosTable) -> dict:
"""Requests the latest data from a table for a specific site.
Args:
site_id: ID of the site to retrieve records from.
- query: Query to parse and submit.
+ table: A valid table from the database
Returns:
dict | None: A dict containing the database columns as keys, and the values as values.
Returns `None` if no data retrieved.
"""
- if not isinstance(query, CosmosQuery):
- raise TypeError(
- f"`query` must be a `CosmosQuery` Enum, not a `{type(query)}`"
- )
+ query = self._fill_query(self.site_data_query, table)
- async with self.connection.cursor() as cursor:
- await cursor.execute(
- query.value,
- mysite=site_id,
- )
+ with self.connection.cursor() as cursor:
+ await cursor.execute(query, site_id=site_id)
columns = [i[0] for i in cursor.description]
data = await cursor.fetchone()
@@ -119,23 +186,92 @@ async def query_latest_from_site(self, site_id: str, query: CosmosQuery) -> dict
return dict(zip(columns, data))
async def query_site_ids(
- self, query: CosmosSiteQuery, max_sites: int | None = None
+ self, table: CosmosTable, max_sites: int | None = None
) -> list:
"""query_site_ids returns a list of site IDs from COSMOS database
Args:
- query: The query to run.
+ table: A valid table from the database
max_sites: Maximum number of sites to retreive
Returns:
List[str]: A list of site ID strings.
"""
- if not isinstance(query, CosmosSiteQuery):
- raise TypeError(
- f"`query` must be a `CosmosSiteQuery` Enum, not a `{type(query)}`"
- )
+ max_sites = self._validate_max_sites(max_sites)
+
+ query = self._fill_query(self.site_id_query, table)
+
+ async with self.connection.cursor() as cursor:
+ await cursor.execute(query)
+
+ data = await cursor.fetchall()
+ if max_sites == 0:
+ data = [x[0] for x in data]
+ else:
+ data = [x[0] for x in data[:max_sites]]
+
+ if not data:
+ data = []
+
+ return data
+
+
+class LoopingCsvDB(BaseDatabase):
+ """A database that reads from csv files and loops through items
+ for a given table or site. The site and index is remembered via a
+ dictionary key and incremented each time data is requested."""
+
+ connection: pd.DataFrame
+ """Connection to the pd object holding data."""
+
+ cache: dict
+ """Cache object containing current index of each site queried."""
+
+ @staticmethod
+ def _get_connection(*args) -> pd.DataFrame:
+ """Gets the database connection."""
+ return pd.read_csv(*args)
+
+ def __init__(self, csv_file: str | Path):
+ """Initialises the database object.
+
+ Args:
+ csv_file: A pathlike object pointing to the datafile.
+ """
+
+ BaseDatabase.__init__(self)
+ self.connection = self._get_connection(csv_file)
+ self.cache = dict()
+
+ def query_latest_from_site(self, site_id: str) -> dict:
+ """Queries the datbase for a `SITE_ID` incrementing by 1 each time called
+ for a specific site. If the end is reached, it loops back to the start.
+
+ Args:
+ site_id: ID of the site to query for.
+ Returns:
+ A dict of the data row.
+ """
+
+ data = self.connection.query("SITE_ID == @site_id").replace({nan: None})
+
+ if site_id not in self.cache or self.cache[site_id] >= len(data):
+ self.cache[site_id] = 1
+ else:
+ self.cache[site_id] += 1
+
+ return data.iloc[self.cache[site_id] - 1].to_dict()
+
+ def query_site_ids(self, max_sites: int | None = None) -> list:
+ """query_site_ids returns a list of site IDs from the database
+
+ Args:
+ max_sites: Maximum number of sites to retreive
+ Returns:
+ List[str]: A list of site ID strings.
+ """
if max_sites is not None:
max_sites = int(max_sites)
if max_sites < 0:
@@ -143,10 +279,115 @@ async def query_site_ids(
f"`max_sites` must be 1 or more, or 0 for no maximum. Received: {max_sites}"
)
- async with self.connection.cursor() as cursor:
- await cursor.execute(query.value)
+ sites = self.connection["SITE_ID"].drop_duplicates().to_list()
- data = await cursor.fetchall()
+ if max_sites is not None and max_sites > 0:
+ sites = sites[:max_sites]
+
+ return sites
+
+
+class LoopingSQLite3(CosmosDB, LoopingCsvDB):
+ """A database that reads from .db files using sqlite3 and loops through
+ entries in sequential order. There is a script that generates the .db file
+ in the `__assets__/data` directory relative to this file. .csv datasets should
+ be downloaded from the accompanying S3 bucket before running."""
+
+ connection: sqlite3.Connection
+ """Connection to the database."""
+
+ site_data_query = CosmosQuery.SQLITE_LOOPED_DATA
+
+ site_id_query = CosmosQuery.SQLITE_SITE_IDS
+
+ @staticmethod
+ def _get_connection(*args) -> sqlite3.Connection:
+ """Gets a database connection."""
+
+ return sqlite3.connect(*args)
+
+ def __init__(self, db_file: str | Path):
+ """Initialises the database object.
+
+ Args:
+ csv_file: A pathlike object pointing to the datafile.
+ """
+ LoopingCsvDB.__init__(self, db_file)
+
+ self.cursor = self.connection.cursor()
+
+ def query_latest_from_site(self, site_id: str, table: CosmosTable) -> dict:
+ """Queries the datbase for a `SITE_ID` incrementing by 1 each time called
+ for a specific site. If the end is reached, it loops back to the start.
+
+ Args:
+ site_id: ID of the site to query for.
+ table: A valid table from the database
+ Returns:
+ A dict of the data row.
+ """
+ query = self._fill_query(self.site_data_query, table)
+
+ if site_id not in self.cache:
+ self.cache[site_id] = 0
+ else:
+ self.cache[site_id] += 1
+
+ data = self._query_latest_from_site(
+ query, {"site_id": site_id, "offset": self.cache[site_id]}
+ )
+
+ if data is None:
+ self.cache[site_id] = 0
+
+ data = self._query_latest_from_site(
+ query, {"site_id": site_id, "offset": self.cache[site_id]}
+ )
+
+ return data
+
+ def _query_latest_from_site(self, query, arg_dict: dict) -> dict:
+ """Requests the latest data from a table for a specific site.
+
+ Args:
+ table: A valid table from the database
+ arg_dict: Dictionary of query arguments.
+
+ Returns:
+ dict | None: A dict containing the database columns as keys, and the values as values.
+ Returns `None` if no data retrieved.
+ """
+
+ self.cursor.execute(query, arg_dict)
+
+ columns = [i[0] for i in self.cursor.description]
+ data = self.cursor.fetchone()
+
+ if not data:
+ return None
+
+ return dict(zip(columns, data))
+
+ def query_site_ids(self, table: CosmosTable, max_sites: int | None = None) -> list:
+ """query_site_ids returns a list of site IDs from COSMOS database
+
+ Args:
+ table: A valid table from the database
+ max_sites: Maximum number of sites to retreive
+
+ Returns:
+ List[str]: A list of site ID strings.
+ """
+
+ query = self._fill_query(self.site_id_query, table)
+
+ max_sites = self._validate_max_sites(max_sites)
+
+ try:
+ cursor = self.connection.cursor()
+ cursor.execute(query)
+
+ data = cursor.fetchall()
if max_sites == 0:
data = [x[0] for x in data]
else:
@@ -154,5 +395,7 @@ async def query_site_ids(
if not data:
data = []
+ finally:
+ cursor.close()
- return data
+ return data
diff --git a/src/iotswarm/devices.py b/src/iotswarm/devices.py
index 8c9c6ce..2b2cd1d 100644
--- a/src/iotswarm/devices.py
+++ b/src/iotswarm/devices.py
@@ -3,14 +3,14 @@
import asyncio
import logging
from iotswarm import __version__ as package_version
-from iotswarm.queries import CosmosQuery
-from iotswarm.db import BaseDatabase, Oracle
-from iotswarm.messaging.core import MessagingBaseClass
+from iotswarm.queries import CosmosTable
+from iotswarm.db import BaseDatabase, CosmosDB, Oracle, LoopingCsvDB, LoopingSQLite3
+from iotswarm.messaging.core import MessagingBaseClass, MockMessageConnection
from iotswarm.messaging.aws import IotCoreMQTTConnection
-import random
+from typing import List
from datetime import datetime
+import random
import enum
-from typing import List
logger = logging.getLogger(__name__)
@@ -45,8 +45,8 @@ class BaseDevice:
connection: MessagingBaseClass
"""Connection to the data receiver."""
- query: CosmosQuery
- """SQL query sent to database if Oracle type selected as `data_source`."""
+ table: CosmosTable
+ """SQL table used in queries if Oracle or LoopingSQLite3 selected as `data_source`."""
mqtt_base_topic: str
"""Base topic for mqtt topic."""
@@ -85,7 +85,7 @@ def __init__(
max_cycles: int | None = None,
delay_start: bool | None = None,
inherit_logger: logging.Logger | None = None,
- query: CosmosQuery | None = None,
+ table: CosmosTable | None = None,
mqtt_topic: str | None = None,
mqtt_prefix: str | None = None,
mqtt_suffix: str | None = None,
@@ -100,7 +100,7 @@ def __init__(
max_cycles: Maximum number of cycles before shutdown.
delay_start: Adds a random delay to first invocation from 0 - `sleep_time`.
inherit_logger: Override for the module logger.
- query: Sets the query used in COSMOS database. Ignored if `data_source` is not a Cosmos object.
+ table: A valid table from the database
mqtt_prefix: Prefixes the MQTT topic if MQTT messaging used.
mqtt_suffix: Suffixes the MQTT topic if MQTT messaging used.
"""
@@ -111,10 +111,18 @@ def __init__(
raise TypeError(
f"`data_source` must be a `BaseDatabase`. Received: {data_source}."
)
- if isinstance(data_source, Oracle) and query is None:
- raise ValueError(
- f"`query` must be provided if `data_source` is type `Oracle`."
- )
+ if isinstance(data_source, (CosmosDB)):
+
+ if table is None:
+ raise ValueError(
+ f"`table` must be provided if `data_source` is type `OracleDB`."
+ )
+ elif not isinstance(table, CosmosTable):
+ raise TypeError(
+ f'table must be a "{CosmosTable.__class__}", not "{type(table)}"'
+ )
+
+ self.table = table
self.data_source = data_source
if not isinstance(connection, MessagingBaseClass):
@@ -144,18 +152,11 @@ def __init__(
)
self.delay_start = delay_start
- if query is not None and isinstance(self.data_source, Oracle):
- if not isinstance(query, CosmosQuery):
- raise TypeError(
- f"`query` must be a `CosmosQuery`. Received: {type(query)}."
- )
- self.query = query
-
- if isinstance(connection, IotCoreMQTTConnection):
+ if isinstance(connection, (IotCoreMQTTConnection, MockMessageConnection)):
if mqtt_topic is not None:
self.mqtt_topic = str(mqtt_topic)
else:
- self.mqtt_topic = f"{self.device_type}/{self.device_id}"
+ self.mqtt_topic = f"{self.device_id}"
if mqtt_prefix is not None:
self.mqtt_prefix = str(mqtt_prefix)
@@ -190,16 +191,16 @@ def __repr__(self):
if self.delay_start != self.__class__.delay_start
else ""
)
- query_arg = (
- f", query={self.query.__class__.__name__}.{self.query.name}"
- if isinstance(self.data_source, Oracle)
+ table_arg = (
+ f", table={self.table.__class__.__name__}.{self.table.name}"
+ if isinstance(self.data_source, CosmosDB)
else ""
)
mqtt_topic_arg = (
f', mqtt_topic="{self.mqtt_base_topic}"'
if hasattr(self, "mqtt_base_topic")
- and self.mqtt_base_topic != f"{self.device_type}/{self.device_id}"
+ and self.mqtt_base_topic != self.device_id
else ""
)
@@ -222,7 +223,7 @@ def __repr__(self):
f"{sleep_time_arg}"
f"{max_cycles_arg}"
f"{delay_start_arg}"
- f"{query_arg}"
+ f"{table_arg}"
f"{mqtt_topic_arg}"
f"{mqtt_prefix_arg}"
f"{mqtt_suffix_arg}"
@@ -243,7 +244,7 @@ def _send_payload(self, payload: dict):
async def run(self):
"""The main invocation of the method. Expects a Oracle object to do work on
- and a query to retrieve. Runs asynchronously until `max_cycles` is reached.
+ and a table to retrieve. Runs asynchronously until `max_cycles` is reached.
Args:
message_connection: The message object to send data through
@@ -260,6 +261,9 @@ async def run(self):
if payload:
self._instance_logger.debug("Requesting payload submission.")
self._send_payload(payload)
+ self._instance_logger.info(
+ f'Message sent{f" to topic: {self.mqtt_topic}" if self.mqtt_topic else ""}'
+ )
else:
self._instance_logger.warning(f"No data found.")
@@ -273,9 +277,12 @@ async def _get_payload(self):
"""Method for grabbing the payload to send"""
if isinstance(self.data_source, Oracle):
return await self.data_source.query_latest_from_site(
- self.device_id, self.query
+ self.device_id, self.table
)
-
+ elif isinstance(self.data_source, LoopingSQLite3):
+ return self.data_source.query_latest_from_site(self.device_id, self.table)
+ elif isinstance(self.data_source, LoopingCsvDB):
+ return self.data_source.query_latest_from_site(self.device_id)
elif isinstance(self.data_source, BaseDatabase):
return self.data_source.query_latest_from_site()
diff --git a/src/iotswarm/messaging/aws.py b/src/iotswarm/messaging/aws.py
index 8205964..be1a708 100644
--- a/src/iotswarm/messaging/aws.py
+++ b/src/iotswarm/messaging/aws.py
@@ -2,6 +2,7 @@
import awscrt
from awscrt import mqtt
+from awsiot import mqtt_connection_builder
import awscrt.io
import json
from awscrt.exceptions import AwsCrtError
@@ -34,6 +35,7 @@ def __init__(
port: int | None = None,
clean_session: bool = False,
keep_alive_secs: int = 1200,
+ inherit_logger: logging.Logger | None = None,
**kwargs,
) -> None:
"""Initializes the class.
@@ -47,7 +49,7 @@ def __init__(
port: Port used by endpoint. Guesses correct port if not given.
clean_session: Builds a clean MQTT session if true. Defaults to False.
keep_alive_secs: Time to keep connection alive. Defaults to 1200.
- topic_prefix: A topic prefixed to MQTT topic, useful for attaching a "Basic Ingest" rule. Defaults to None.
+ inherit_logger: Override for the module logger.
"""
if not isinstance(endpoint, str):
@@ -88,41 +90,31 @@ def __init__(
if port < 0:
raise ValueError(f"`port` cannot be less than 0. Received: {port}.")
- socket_options = awscrt.io.SocketOptions()
- socket_options.connect_timeout_ms = 5000
- socket_options.keep_alive = False
- socket_options.keep_alive_timeout_secs = 0
- socket_options.keep_alive_interval_secs = 0
- socket_options.keep_alive_max_probes = 0
-
- client_bootstrap = awscrt.io.ClientBootstrap.get_or_create_static_default()
-
- tls_ctx = awscrt.io.ClientTlsContext(tls_ctx_options)
- mqtt_client = awscrt.mqtt.Client(client_bootstrap, tls_ctx)
-
- self.connection = awscrt.mqtt.Connection(
- client=mqtt_client,
+ self.connection = mqtt_connection_builder.mtls_from_path(
+ endpoint=endpoint,
+ port=port,
+ cert_filepath=cert_path,
+ pri_key_filepath=key_path,
+ ca_filepath=ca_cert_path,
on_connection_interrupted=self._on_connection_interrupted,
on_connection_resumed=self._on_connection_resumed,
client_id=client_id,
- host_name=endpoint,
- port=port,
+ proxy_options=None,
clean_session=clean_session,
- reconnect_min_timeout_secs=5,
- reconnect_max_timeout_secs=60,
keep_alive_secs=keep_alive_secs,
- ping_timeout_ms=3000,
- protocol_operation_timeout_ms=0,
- socket_options=socket_options,
- use_websockets=False,
on_connection_success=self._on_connection_success,
on_connection_failure=self._on_connection_failure,
on_connection_closed=self._on_connection_closed,
)
- self._instance_logger = logger.getChild(
- f"{self.__class__.__name__}.client-{client_id}"
- )
+ if inherit_logger is not None:
+ self._instance_logger = inherit_logger.getChild(
+ f"{self.__class__.__name__}.client-{client_id}"
+ )
+ else:
+ self._instance_logger = logger.getChild(
+ f"{self.__class__.__name__}.client-{client_id}"
+ )
def _on_connection_interrupted(
self, connection, error, **kwargs
@@ -182,23 +174,15 @@ def _disconnect(self): # pragma: no cover
disconnect_future = self.connection.disconnect()
disconnect_future.result()
- def send_message(
- self, message: dict, topic: str, use_logger: logging.Logger | None = None
- ) -> None:
+ def send_message(self, message: dict, topic: str) -> None:
"""Sends a message to the endpoint.
Args:
message: The message to send.
topic: MQTT topic to send message under.
- use_logger: Sends log message with requested logger.
"""
- if use_logger is not None and isinstance(use_logger, logging.Logger):
- use_logger = use_logger
- else:
- use_logger = self._instance_logger
-
if not message:
- use_logger.error(f'No message to send for topic: "{topic}".')
+ self._instance_logger.error(f'No message to send for topic: "{topic}".')
return
if self.connected_flag == False:
@@ -212,6 +196,4 @@ def send_message(
qos=mqtt.QoS.AT_LEAST_ONCE,
)
- use_logger.info(f'Sent {sys.getsizeof(payload)} bytes to "{topic}"')
-
- # self._disconnect()
+ self._instance_logger.debug(f'Sent {sys.getsizeof(payload)} bytes to "{topic}"')
diff --git a/src/iotswarm/messaging/core.py b/src/iotswarm/messaging/core.py
index d362b30..0957d93 100644
--- a/src/iotswarm/messaging/core.py
+++ b/src/iotswarm/messaging/core.py
@@ -2,8 +2,6 @@
from abc import abstractmethod
import logging
-logger = logging.getLogger(__name__)
-
class MessagingBaseClass(ABC):
"""MessagingBaseClass Base class for messaging implementation
@@ -14,9 +12,20 @@ class MessagingBaseClass(ABC):
_instance_logger: logging.Logger
"""Logger handle used by instance."""
- def __init__(self):
-
- self._instance_logger = logger.getChild(self.__class__.__name__)
+ def __init__(
+ self,
+ inherit_logger: logging.Logger | None = None,
+ ):
+ """Initialises the class.
+ Args:
+ inherit_logger: Override for the module logger.
+ """
+ if inherit_logger is not None:
+ self._instance_logger = inherit_logger.getChild(self.__class__.__name__)
+ else:
+ self._instance_logger = logging.getLogger(__name__).getChild(
+ self.__class__.__name__
+ )
@property
@abstractmethod
@@ -37,15 +46,7 @@ class MockMessageConnection(MessagingBaseClass):
connection: None = None
"""Connection object. Not needed in a mock but must be implemented"""
- def send_message(self, use_logger: logging.Logger | None = None):
- """Consumes requests to send a message but does nothing with it.
-
- Args:
- use_logger: Sends log message with requested logger."""
-
- if use_logger is not None and isinstance(use_logger, logging.Logger):
- use_logger = use_logger
- else:
- use_logger = self._instance_logger
+ def send_message(self, *_):
+ """Consumes requests to send a message but does nothing with it."""
- use_logger.info("Message was sent.")
+ self._instance_logger.debug("Message was sent.")
diff --git a/src/iotswarm/queries.py b/src/iotswarm/queries.py
index f1174b2..dfda5c8 100644
--- a/src/iotswarm/queries.py
+++ b/src/iotswarm/queries.py
@@ -5,112 +5,60 @@
@enum.unique
-class CosmosQuery(StrEnum):
- """Class containing permitted SQL queries for retrieving sensor data."""
+class CosmosTable(StrEnum):
+ """Enums of approved COSMOS database tables."""
- LEVEL_1_SOILMET_30MIN = """SELECT * FROM COSMOS.LEVEL1_SOILMET_30MIN
-WHERE site_id = :mysite
-ORDER BY date_time DESC
-FETCH NEXT 1 ROWS ONLY"""
+ LEVEL_1_SOILMET_30MIN = "LEVEL1_SOILMET_30MIN"
+ LEVEL_1_NMDB_1HOUR = "LEVEL1_NMDB_1HOUR"
+ LEVEL_1_PRECIP_1MIN = "LEVEL1_PRECIP_1MIN"
+ LEVEL_1_PRECIP_RAINE_1MIN = "LEVEL1_PRECIP_RAINE_1MIN"
- """Query for retreiving data from the LEVEL1_SOILMET_30MIN table, containing
- calibration the core telemetry from COSMOS sites.
-
- .. code-block:: sql
- SELECT * FROM COSMOS.LEVEL1_SOILMET_30MIN
- WHERE site_id = :mysite
- ORDER BY date_time DESC
- FETCH NEXT 1 ROWS ONLY
- """
+@enum.unique
+class CosmosQuery(StrEnum):
+ """Enums of common queries in each databasing language."""
- LEVEL_1_NMDB_1HOUR = """SELECT * FROM COSMOS.LEVEL1_NMDB_1HOUR
-WHERE site_id = :mysite
-ORDER BY date_time DESC
-FETCH NEXT 1 ROWS ONLY"""
+ SQLITE_LOOPED_DATA = """SELECT * FROM {table}
+WHERE site_id = :site_id
+LIMIT 1 OFFSET :offset"""
- """Query for retreiving data from the Level1_NMDB_1HOUR table, containing
- calibration data from the Neutron Monitor DataBase.
+ """Query for retreiving data from a given table in sqlite format.
.. code-block:: sql
- SELECT * FROM COSMOS.LEVEL1_NMDB_1HOUR
- WHERE site_id = :mysite
- ORDER BY date_time DESC
- FETCH NEXT 1 ROWS ONLY
+ SELECT * FROM
+ WHERE site_id = :site_id
+ LIMIT 1 OFFSET :offset
"""
- LEVEL_1_PRECIP_1MIN = """SELECT * FROM COSMOS.LEVEL1_PRECIP_1MIN
-WHERE site_id = :mysite
+ ORACLE_LATEST_DATA = """SELECT * FROM COSMOS.{table}
+WHERE site_id = :site_id
ORDER BY date_time DESC
FETCH NEXT 1 ROWS ONLY"""
- """Query for retreiving data from the LEVEL1_PRECIP_1MIN table, containing
- the standard precipitation telemetry from COSMOS sites.
+ """Query for retreiving data from a given table in oracle format.
.. code-block:: sql
- SELECT * FROM COSMOS.LEVEL1_PRECIP_1MIN
- WHERE site_id = :mysite
+ SELECT * FROM
ORDER BY date_time DESC
FETCH NEXT 1 ROWS ONLY
"""
- LEVEL_1_PRECIP_RAINE_1MIN = """SELECT * FROM COSMOS.LEVEL1_PRECIP_RAINE_1MIN
-WHERE site_id = :mysite
-ORDER BY date_time DESC
-FETCH NEXT 1 ROWS ONLY"""
-
- """Query for retreiving data from the LEVEL1_PRECIP_RAINE_1MIN table, containing
- the rain[e] precipitation telemetry from COSMOS sites.
-
- .. code-block:: sql
-
- SELECT * FROM COSMOS.LEVEL1_PRECIP_RAINE_1MIN
- WHERE site_id = :mysite
- ORDER BY date_time DESC
- FETCH NEXT 1 ROWS ONLY
- """
-
-
-@enum.unique
-class CosmosSiteQuery(StrEnum):
- """Contains permitted SQL queries for extracting site IDs from database."""
-
- LEVEL_1_SOILMET_30MIN = "SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_SOILMET_30MIN"
-
- """Queries unique site IDs from LEVEL1_SOILMET_30MIN.
-
- .. code-block:: sql
-
- SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_SOILMET_30MIN
- """
-
- LEVEL_1_NMDB_1HOUR = "SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_NMDB_1HOUR"
-
- """Queries unique site IDs from LEVEL1_NMDB_1HOUR table.
-
- .. code-block:: sql
-
- SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_NMDB_1HOUR
- """
-
- LEVEL_1_PRECIP_1MIN = "SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_PRECIP_1MIN"
+ SQLITE_SITE_IDS = "SELECT DISTINCT(site_id) FROM {table}"
- """Queries unique site IDs from the LEVEL1_PRECIP_1MIN table.
+ """Queries unique `site_id `s from a given table.
.. code-block:: sql
- SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_PRECIP_1MIN
+ SELECT DISTINCT(site_id) FROM
"""
- LEVEL_1_PRECIP_RAINE_1MIN = (
- "SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_PRECIP_RAINE_1MIN"
- )
+ ORACLE_SITE_IDS = "SELECT UNIQUE(site_id) FROM COSMOS.{table}"
- """Queries unique site IDs from the LEVEL1_PRECIP_RAINE_1MIN table.
+ """Queries unique `site_id `s from a given table.
.. code-block:: sql
- SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_PRECIP_RAINE_1MIN
+ SELECT UNQIUE(site_id) FROM
"""
diff --git a/src/iotswarm/scripts/cli.py b/src/iotswarm/scripts/cli.py
index 7d79274..912a7be 100644
--- a/src/iotswarm/scripts/cli.py
+++ b/src/iotswarm/scripts/cli.py
@@ -1,24 +1,19 @@
"""CLI exposed when the package is installed."""
import click
-from iotswarm import queries
+from iotswarm import __version__ as package_version
+from iotswarm.queries import CosmosTable
from iotswarm.devices import BaseDevice, CR1000XDevice
from iotswarm.swarm import Swarm
-from iotswarm.db import Oracle
+from iotswarm.db import Oracle, LoopingCsvDB, LoopingSQLite3
from iotswarm.messaging.core import MockMessageConnection
from iotswarm.messaging.aws import IotCoreMQTTConnection
+import iotswarm.scripts.common as cli_common
import asyncio
from pathlib import Path
import logging
-TABLES = [table.name for table in queries.CosmosQuery]
-
-
-@click.command
-@click.pass_context
-def test(ctx: click.Context):
- """Enables testing of cosmos group arguments."""
- print(ctx.obj)
+TABLE_NAMES = [table.name for table in CosmosTable]
@click.group()
@@ -27,18 +22,64 @@ def test(ctx: click.Context):
"--log-config",
type=click.Path(exists=True),
help="Path to a logging config file. Uses default if not given.",
+ default=Path(Path(__file__).parents[1], "__assets__", "loggers.ini"),
)
-def main(ctx: click.Context, log_config: Path):
+@click.option(
+ "--log-level",
+ type=click.Choice(
+ ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
+ ),
+ help="Overrides the logging level.",
+ envvar="IOT_SWARM_LOG_LEVEL",
+)
+def main(ctx: click.Context, log_config: Path, log_level: str):
"""Core group of the cli."""
ctx.ensure_object(dict)
- if not log_config:
- log_config = Path(Path(__file__).parents[1], "__assets__", "loggers.ini")
-
logging.config.fileConfig(fname=log_config)
+ logger = logging.getLogger(__name__)
+
+ if log_level:
+ logger.setLevel(log_level)
+ click.echo(f"Set log level to {log_level}.")
+
+ ctx.obj["logger"] = logger
+
+
+main.add_command(cli_common.test)
+
+
+@main.command
+def get_version():
+ """Gets the package version"""
+ click.echo(package_version)
+
+@main.command()
+@click.pass_context
+@cli_common.iotcore_options
+def test_mqtt(
+ ctx,
+ message,
+ topic,
+ client_id: str,
+ endpoint: str,
+ cert_path: str,
+ key_path: str,
+ ca_cert_path: str,
+):
+ """Tests that a basic message can be sent via mqtt."""
+
+ connection = IotCoreMQTTConnection(
+ endpoint=endpoint,
+ cert_path=cert_path,
+ key_path=key_path,
+ ca_cert_path=ca_cert_path,
+ client_id=client_id,
+ inherit_logger=ctx.obj["logger"],
+ )
-main.add_command(test)
+ connection.send_message(message, topic)
@main.group()
@@ -78,113 +119,262 @@ def cosmos(ctx: click.Context, site: str, dsn: str, user: str, password: str):
ctx.obj["sites"] = site
-cosmos.add_command(test)
+cosmos.add_command(cli_common.test)
@cosmos.command()
@click.pass_context
-@click.argument("query", type=click.Choice(TABLES))
+@click.argument("table", type=click.Choice(TABLE_NAMES))
@click.option("--max-sites", type=click.IntRange(min=0), default=0)
-def list_sites(ctx, query, max_sites):
- """Lists site IDs from the database from table QUERY."""
+def list_sites(ctx, table, max_sites):
+ """Lists unique `site_id` from an oracle database table."""
- async def _list_sites(ctx, query):
+ async def _list_sites():
oracle = await Oracle.create(
dsn=ctx.obj["credentials"]["dsn"],
user=ctx.obj["credentials"]["user"],
password=ctx.obj["credentials"]["password"],
+ inherit_logger=ctx.obj["logger"],
)
- sites = await oracle.query_site_ids(
- queries.CosmosSiteQuery[query], max_sites=max_sites
- )
+ sites = await oracle.query_site_ids(CosmosTable[table], max_sites=max_sites)
return sites
- click.echo(asyncio.run(_list_sites(ctx, query)))
+ click.echo(asyncio.run(_list_sites()))
@cosmos.command()
@click.pass_context
@click.argument(
- "provider",
- type=click.Choice(["aws"]),
-)
-@click.argument(
- "query",
- type=click.Choice(TABLES),
-)
-@click.argument(
- "client-id",
- type=click.STRING,
- required=True,
+ "table",
+ type=click.Choice(TABLE_NAMES),
)
+@cli_common.device_options
+@cli_common.iotcore_options
+@click.option("--dry", is_flag=True, default=False, help="Doesn't send out any data.")
+def mqtt(
+ ctx,
+ table,
+ endpoint,
+ cert_path,
+ key_path,
+ ca_cert_path,
+ client_id,
+ sleep_time,
+ max_cycles,
+ max_sites,
+ swarm_name,
+ delay_start,
+ mqtt_prefix,
+ mqtt_suffix,
+ dry,
+ device_type,
+):
+ """Sends The cosmos data via MQTT protocol using IoT Core.
+ Data is from the cosmos database TABLE and sent using CLIENT_ID.
+
+ Currently only supports sending through AWS IoT Core."""
+ table = CosmosTable[table]
+
+ async def _mqtt():
+ oracle = await Oracle.create(
+ dsn=ctx.obj["credentials"]["dsn"],
+ user=ctx.obj["credentials"]["user"],
+ password=ctx.obj["credentials"]["password"],
+ inherit_logger=ctx.obj["logger"],
+ )
+
+ sites = ctx.obj["sites"]
+ if len(sites) == 0:
+ sites = await oracle.query_site_ids(table, max_sites=max_sites)
+
+ if dry == True:
+ connection = MockMessageConnection(inherit_logger=ctx.obj["logger"])
+ else:
+ connection = IotCoreMQTTConnection(
+ endpoint=endpoint,
+ cert_path=cert_path,
+ key_path=key_path,
+ ca_cert_path=ca_cert_path,
+ client_id=client_id,
+ inherit_logger=ctx.obj["logger"],
+ )
+
+ if device_type == "basic":
+ DeviceClass = BaseDevice
+ elif device_type == "cr1000x":
+ DeviceClass = CR1000XDevice
+
+ site_devices = [
+ DeviceClass(
+ site,
+ oracle,
+ connection,
+ sleep_time=sleep_time,
+ table=table,
+ max_cycles=max_cycles,
+ delay_start=delay_start,
+ mqtt_prefix=mqtt_prefix,
+ mqtt_suffix=mqtt_suffix,
+ inherit_logger=ctx.obj["logger"],
+ )
+ for site in sites
+ ]
+
+ swarm = Swarm(site_devices, swarm_name)
+
+ await swarm.run()
+
+ asyncio.run(_mqtt())
+
+
+@main.group()
+@click.pass_context
@click.option(
- "--endpoint",
+ "--site",
type=click.STRING,
- required=True,
- envvar="IOT_SWARM_MQTT_ENDPOINT",
- help="Endpoint of the MQTT receiving host.",
+ multiple=True,
+ help="Adds a site to be initialized. Can be invoked multiple times for other sites."
+ " Grabs all sites from database query if none provided",
)
@click.option(
- "--cert-path",
+ "--file",
type=click.Path(exists=True),
required=True,
- envvar="IOT_SWARM_MQTT_CERT_PATH",
- help="Path to public key certificate for the device. Must match key assigned to the `--client-id` in the cloud provider.",
+ envvar="IOT_SWARM_CSV_DB",
+ help="*.csv file used to instantiate a pandas database.",
)
+def looping_csv(ctx, site, file):
+ """Instantiates a pandas dataframe from a csv file which is used as the database.
+ Responsibility falls on the user to ensure the correct file is selected."""
+
+ ctx.obj["db"] = LoopingCsvDB(file)
+ ctx.obj["sites"] = site
+
+
+looping_csv.add_command(cli_common.test)
+looping_csv.add_command(cli_common.list_sites)
+
+
+@looping_csv.command()
+@click.pass_context
+@cli_common.device_options
+@cli_common.iotcore_options
+@click.option("--dry", is_flag=True, default=False, help="Doesn't send out any data.")
+def mqtt(
+ ctx,
+ endpoint,
+ cert_path,
+ key_path,
+ ca_cert_path,
+ client_id,
+ sleep_time,
+ max_cycles,
+ max_sites,
+ swarm_name,
+ delay_start,
+ mqtt_prefix,
+ mqtt_suffix,
+ dry,
+ device_type,
+):
+ """Sends The cosmos data via MQTT protocol using IoT Core.
+ Data is collected from the db using QUERY and sent using CLIENT_ID.
+
+ Currently only supports sending through AWS IoT Core."""
+
+ async def _mqtt():
+
+ sites = ctx.obj["sites"]
+ db = ctx.obj["db"]
+ if len(sites) == 0:
+ sites = db.query_site_ids(max_sites=max_sites)
+
+ if dry == True:
+ connection = MockMessageConnection(inherit_logger=ctx.obj["logger"])
+ else:
+ connection = IotCoreMQTTConnection(
+ endpoint=endpoint,
+ cert_path=cert_path,
+ key_path=key_path,
+ ca_cert_path=ca_cert_path,
+ client_id=client_id,
+ inherit_logger=ctx.obj["logger"],
+ )
+
+ if device_type == "basic":
+ DeviceClass = BaseDevice
+ elif device_type == "cr1000x":
+ DeviceClass = CR1000XDevice
+
+ site_devices = [
+ DeviceClass(
+ site,
+ db,
+ connection,
+ sleep_time=sleep_time,
+ max_cycles=max_cycles,
+ delay_start=delay_start,
+ mqtt_prefix=mqtt_prefix,
+ mqtt_suffix=mqtt_suffix,
+ inherit_logger=ctx.obj["logger"],
+ )
+ for site in sites
+ ]
+
+ swarm = Swarm(site_devices, swarm_name)
+
+ await swarm.run()
+
+ asyncio.run(_mqtt())
+
+
+@main.group()
+@click.pass_context
@click.option(
- "--key-path",
- type=click.Path(exists=True),
- required=True,
- envvar="IOT_SWARM_MQTT_KEY_PATH",
- help="Path to the private key that pairs with the `--cert-path`.",
+ "--site",
+ type=click.STRING,
+ multiple=True,
+ help="Adds a site to be initialized. Can be invoked multiple times for other sites."
+ " Grabs all sites from database query if none provided",
)
@click.option(
- "--ca-cert-path",
+ "--file",
type=click.Path(exists=True),
required=True,
- envvar="IOT_SWARM_MQTT_CA_CERT_PATH",
- help="Path to the root Certificate Authority (CA) for the MQTT host.",
-)
-@click.option(
- "--sleep-time",
- type=click.INT,
- help="The number of seconds each site goes idle after sending a message.",
-)
-@click.option(
- "--max-cycles",
- type=click.IntRange(0),
- help="Maximum number message sending cycles. Runs forever if set to 0.",
-)
-@click.option(
- "--max-sites",
- type=click.IntRange(0),
- help="Maximum number of sites allowed to initialize. No limit if set to 0.",
-)
-@click.option(
- "--swarm-name", type=click.STRING, help="Name given to swarm. Appears in the logs."
+ envvar="IOT_SWARM_LOCAL_DB",
+ help="*.db file used to instantiate a sqlite3 database.",
)
-@click.option(
- "--delay-start",
- is_flag=True,
- default=False,
- help="Adds a random delay before the first message from each site up to `--sleep-time`.",
-)
-@click.option(
- "--mqtt-prefix",
- type=click.STRING,
- help="Prefixes the MQTT topic with a string. Can augment the calculated MQTT topic returned by each site.",
-)
-@click.option(
- "--mqtt-suffix",
- type=click.STRING,
- help="Suffixes the MQTT topic with a string. Can augment the calculated MQTT topic returned by each site.",
+def looping_sqlite3(ctx, site, file):
+ """Instantiates a sqlite3 database as sensor source.."""
+ ctx.obj["db"] = LoopingSQLite3(file)
+ ctx.obj["sites"] = site
+
+
+looping_sqlite3.add_command(cli_common.test)
+
+
+@looping_sqlite3.command
+@click.pass_context
+@click.option("--max-sites", type=click.IntRange(min=0), default=0)
+@click.argument(
+ "table",
+ type=click.Choice(TABLE_NAMES),
)
+def list_sites(ctx, max_sites, table):
+ """Prints the sites present in database."""
+ sites = ctx.obj["db"].query_site_ids(table, max_sites=max_sites)
+ click.echo(sites)
+
+
+@looping_sqlite3.command()
+@click.pass_context
+@cli_common.device_options
+@cli_common.iotcore_options
+@click.argument("table", type=click.Choice(TABLE_NAMES))
@click.option("--dry", is_flag=True, default=False, help="Doesn't send out any data.")
-@click.option("--device-type", type=click.Choice(["basic", "cr1000x"]), default="basic")
def mqtt(
ctx,
- provider,
- query,
+ table,
endpoint,
cert_path,
key_path,
@@ -200,34 +390,30 @@ def mqtt(
dry,
device_type,
):
- """Sends The cosmos data via MQTT protocol using PROVIDER.
+ """Sends The cosmos data via MQTT protocol using IoT Core.
Data is collected from the db using QUERY and sent using CLIENT_ID.
Currently only supports sending through AWS IoT Core."""
- async def _mqtt():
- oracle = await Oracle.create(
- dsn=ctx.obj["credentials"]["dsn"],
- user=ctx.obj["credentials"]["user"],
- password=ctx.obj["credentials"]["password"],
- )
+ table = CosmosTable[table]
- data_query = queries.CosmosQuery[query]
- site_query = queries.CosmosSiteQuery[query]
+ async def _mqtt():
sites = ctx.obj["sites"]
+ db = ctx.obj["db"]
if len(sites) == 0:
- sites = await oracle.query_site_ids(site_query, max_sites=max_sites)
+ sites = db.query_site_ids(table, max_sites=max_sites)
if dry == True:
- connection = MockMessageConnection()
- elif provider == "aws":
+ connection = MockMessageConnection(inherit_logger=ctx.obj["logger"])
+ else:
connection = IotCoreMQTTConnection(
endpoint=endpoint,
cert_path=cert_path,
key_path=key_path,
ca_cert_path=ca_cert_path,
client_id=client_id,
+ inherit_logger=ctx.obj["logger"],
)
if device_type == "basic":
@@ -238,14 +424,15 @@ async def _mqtt():
site_devices = [
DeviceClass(
site,
- oracle,
+ db,
connection,
sleep_time=sleep_time,
- query=data_query,
max_cycles=max_cycles,
delay_start=delay_start,
mqtt_prefix=mqtt_prefix,
mqtt_suffix=mqtt_suffix,
+ table=table,
+ inherit_logger=ctx.obj["logger"],
)
for site in sites
]
diff --git a/src/iotswarm/scripts/common.py b/src/iotswarm/scripts/common.py
new file mode 100644
index 0000000..1b8b1e0
--- /dev/null
+++ b/src/iotswarm/scripts/common.py
@@ -0,0 +1,113 @@
+"""Location for common CLI commands"""
+
+import click
+
+
+def device_options(function):
+ click.option(
+ "--sleep-time",
+ type=click.INT,
+ help="The number of seconds each site goes idle after sending a message.",
+ )(function)
+
+ click.option(
+ "--max-cycles",
+ type=click.IntRange(0),
+ help="Maximum number message sending cycles. Runs forever if set to 0.",
+ )(function)
+
+ click.option(
+ "--max-sites",
+ type=click.IntRange(0),
+ help="Maximum number of sites allowed to initialize. No limit if set to 0.",
+ )(function)
+
+ click.option(
+ "--swarm-name",
+ type=click.STRING,
+ help="Name given to swarm. Appears in the logs.",
+ )(function)
+
+ click.option(
+ "--delay-start",
+ is_flag=True,
+ default=False,
+ help="Adds a random delay before the first message from each site up to `--sleep-time`.",
+ )(function)
+
+ click.option(
+ "--device-type", type=click.Choice(["basic", "cr1000x"]), default="basic"
+ )(function)
+
+ return function
+
+
+def iotcore_options(function):
+ click.argument(
+ "client-id",
+ type=click.STRING,
+ required=True,
+ )(function)
+
+ click.option(
+ "--endpoint",
+ type=click.STRING,
+ required=True,
+ envvar="IOT_SWARM_MQTT_ENDPOINT",
+ help="Endpoint of the MQTT receiving host.",
+ )(function)
+
+ click.option(
+ "--cert-path",
+ type=click.Path(exists=True),
+ required=True,
+ envvar="IOT_SWARM_MQTT_CERT_PATH",
+ help="Path to public key certificate for the device. Must match key assigned to the `--client-id` in the cloud provider.",
+ )(function)
+
+ click.option(
+ "--key-path",
+ type=click.Path(exists=True),
+ required=True,
+ envvar="IOT_SWARM_MQTT_KEY_PATH",
+ help="Path to the private key that pairs with the `--cert-path`.",
+ )(function)
+
+ click.option(
+ "--ca-cert-path",
+ type=click.Path(exists=True),
+ required=True,
+ envvar="IOT_SWARM_MQTT_CA_CERT_PATH",
+ help="Path to the root Certificate Authority (CA) for the MQTT host.",
+ )(function)
+
+ click.option(
+ "--mqtt-prefix",
+ type=click.STRING,
+ help="Prefixes the MQTT topic with a string. Can augment the calculated MQTT topic returned by each site.",
+ )(function)
+
+ click.option(
+ "--mqtt-suffix",
+ type=click.STRING,
+ help="Suffixes the MQTT topic with a string. Can augment the calculated MQTT topic returned by each site.",
+ )(function)
+
+ return function
+
+
+@click.command
+@click.pass_context
+@click.option("--max-sites", type=click.IntRange(min=0), default=0)
+def list_sites(ctx, max_sites):
+ """Prints the sites present in database."""
+
+ sites = ctx.obj["db"].query_site_ids(max_sites=max_sites)
+ click.echo(sites)
+
+
+@click.command
+@click.pass_context
+def test(ctx: click.Context):
+ """Enables testing of cosmos group arguments."""
+ print(ctx.obj)
diff --git a/src/iotswarm/utils.py b/src/iotswarm/utils.py
index 6d9d988..3a935eb 100644
--- a/src/iotswarm/utils.py
+++ b/src/iotswarm/utils.py
@@ -1,6 +1,10 @@
"""Module for handling commonly reused utility functions."""
from datetime import date, datetime
+from pathlib import Path
+import pandas
+import sqlite3
+from glob import glob
def json_serial(obj: object):
@@ -13,3 +17,53 @@ def json_serial(obj: object):
return obj.__json__()
raise TypeError(f"Type {type(obj)} is not serializable.")
+
+
+def build_database_from_csv(
+ csv_file: str | Path,
+ database: str | Path,
+ table_name: str,
+ sort_by: str | None = None,
+ date_time_format: str = r"%d-%b-%y %H.%M.%S",
+) -> None:
+ """Adds a database table using a csv file with headers.
+
+ Args:
+ csv_file: A path to the csv.
+ database: Output destination of the database. File is created if not
+ existing.
+ table_name: Name of the table to add into database.
+ sort_by: Column to sort by
+ date_time_format: Format of datetime column
+ """
+
+ if not isinstance(csv_file, Path):
+ csv_file = Path(csv_file)
+
+ if not isinstance(database, Path):
+ database = Path(database)
+
+ if not csv_file.exists():
+ raise FileNotFoundError(f'csv_file does not exist: "{csv_file}"')
+
+ if not database.parent.exists():
+ raise NotADirectoryError(f'Database directory not found: "{database.parent}"')
+
+ with sqlite3.connect(database) as conn:
+ print(
+ f'Writing table: "{table_name}" from csv_file: "{csv_file}" to db: "{database}"'
+ )
+ print("Loading csv")
+ df = pandas.read_csv(csv_file)
+ print("Done")
+ print("Formatting dates")
+ df["DATE_TIME"] = pandas.to_datetime(df["DATE_TIME"], format=date_time_format)
+ print("Done")
+ if sort_by is not None:
+ print("Sorting.")
+ df = df.sort_values(by=sort_by)
+ print("Done")
+
+ print("Writing to db.")
+ df.to_sql(table_name, conn, if_exists="replace", index=False)
+ print("Writing complete.")
diff --git a/src/tests/data/ALIC1_4_ROWS.csv b/src/tests/data/ALIC1_4_ROWS.csv
new file mode 100644
index 0000000..55659cd
--- /dev/null
+++ b/src/tests/data/ALIC1_4_ROWS.csv
@@ -0,0 +1,5 @@
+"SITE_ID","DATE_TIME","WD","WD_STD","WS","WS_STD","PA","PA_STD","RH","RH_STD","TA","TA_STD","Q","Q_STD","SWIN","SWIN_STD","SWOUT","SWOUT_STD","LWIN_UNC","LWIN_UNC_STD","LWOUT_UNC","LWOUT_UNC_STD","LWIN","LWIN_STD","LWOUT","LWOUT_STD","TNR01C","TNR01C_STD","PRECIP","PRECIP_DIAG","SNOWD_DISTANCE_UNC","SNOWD_SIGNALQUALITY","CTS_MOD","CTS_BARE","CTS_MOD2","CTS_SNOW","PROFILE_VWC15","PROFILE_SOILEC15","PROFILE_VWC40","PROFILE_SOILEC40","PROFILE_VWC65","PROFILE_SOILEC65","TDT1_VWC","TDT1_TSOIL","TDT1_SOILPERM","TDT1_SOILEC","TDT2_VWC","TDT2_TSOIL","TDT2_SOILPERM","TDT2_SOILEC","STP_TSOIL50","STP_TSOIL20","STP_TSOIL10","STP_TSOIL5","STP_TSOIL2","G1","G1_STD","G1_MV","G1_CAL","G2","G2_STD","G2_MV","G2_CAL","BATTV","SCANS","RECORD","CTS_MOD_DIAG_1","CTS_MOD_DIAG_2","CTS_MOD_DIAG_3","CTS_MOD_DIAG_4","CTS_MOD_DIAG_5","UX","UZ","UY","STDEV_UX","STDEV_UY","STDEV_UZ","COV_UX_UY","COV_UX_UZ","COV_UY_UZ","CTS_SNOW_DIAG_1","CTS_SNOW_DIAG_2","CTS_SNOW_DIAG_3","CTS_SNOW_DIAG_4","CTS_SNOW_DIAG_5","SNOWD_DISTANCE_COR","WS_RES","HS","TAU","U_STAR","TS","STDEV_TS","COV_TS_UX","COV_TS_UY","COV_TS_UZ","RHO_A_MEAN","E_SAT_MEAN","METPAK_SAMPLES","CTS_MOD_RH","CTS_MOD_TEMP","TDT3_VWC","TDT3_TSOIL","TDT3_SOILEC","TDT3_SOILPERM","TDT4_VWC","TDT4_TSOIL","TDT4_SOILEC","TDT4_SOILPERM","TDT5_VWC","TDT5_TSOIL","TDT5_SOILEC","TDT5_SOILPERM","TDT6_VWC","TDT6_TSOIL","TDT6_SOILEC","TDT6_SOILPERM","TDT7_VWC","TDT7_TSOIL","TDT7_SOILEC","TDT7_SOILPERM","TDT8_VWC","TDT8_TSOIL","TDT8_SOILEC","TDT8_SOILPERM","TDT9_VWC","TDT9_TSOIL","TDT9_SOILEC","TDT9_SOILPERM","TDT10_VWC","TDT10_TSOIL","TDT10_SOILEC","TDT10_SOILPERM","CTS_SNOW_TEMP","CTS_SNOW_RH","CTS_MOD2_RH","CTS_MOD2_TEMP","SWIN_MULT","SWOUT_MULT","LWIN_MULT","LWOUT_MULT","PRECIP_TIPPING_A","PRECIP_TIPPING_B","PRECIP_TIPPING2","TBR_TIP","CTS_MOD_PERIOD","CTS_MOD2_PERIOD","CTS_SNOW_PERIOD","WM_NAN_COUNT","RH_INTERNAL"
+"ALIC1",01-JUN-24,353.4,28.38,2.397,1.309,1010.926,0.044,72.58,0.641,12.34,0.05,7.891,,-0.941,0.245,0.454,0.097,-17.14,2.736,-0.303,0.172,359.1,2.68,376,0.191,12.27,0.022,0,0,,,519,,,,,,,,,,43.84,12.72,27.98,0.48,35.03,12.61,20.29,0.24,11.91,12.37,12.45,12.45,12.44,-3.3016,0.10141,-0.15144,21.8,-3.27326,0.1105,-0.14011,23.36,13.21,180,11153,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.23,1.433,,2.1,18.9,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,60.024,52.2193,85.4701,90.4977,,,,,1800,,,99999.99999,
+"ALIC1",01-JUN-24,358,29.12,2.045,1.015,1011.057,0.049,70.59,0.34,12.23,0.139,7.623,,-2.57,1.666,0.552,0.318,-45.27,28,-1.624,1.119,330.1,28.91,373.8,2.085,12.1,0.203,0,0,,,536,,,,,,,,,,43.61,12.72,27.77,0.51,34.91,12.59,20.2,0.27,11.91,12.36,12.43,12.42,12.42,-2.78394,0.58028,-0.12777,21.77,-3.0528,0.35922,-0.13067,23.36,13.19,180,11154,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.231,1.424,,2.1,18.8,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,60.024,52.2193,85.4701,90.4977,,,,,1800,,,99999.99999,
+"ALIC1",01-JUN-24,354.8,26.46,1.531,0.902,1011.085,0.036,73,0.945,11.38,0.31,7.471,,-4.584,0.342,1.413,0.163,-76.38,4.869,-2.488,0.493,293,3.633,366.9,1.718,10.95,0.381,0,0,,,504,,,,,,,,,,43.16,12.7,27.35,0.54,34.91,12.59,20.2,0.25,11.89,12.34,12.41,12.39,12.38,-3.57444,0.69214,-0.16418,21.77,-3.62632,0.49502,-0.15522,23.36,13.22,180,11155,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.235,1.346,,2.1,18.6,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,60.024,52.2193,85.4701,90.4977,,,,,1799,,,99999.99999,
+"ALIC1",01-JUN-24,323.9,47.78,0.916,0.541,1011.082,0.038,76.2,1.033,10.42,0.285,7.34,,-4.493,0.27,1.759,0.13,-75.77,0.74,-2.514,0.463,287.7,1.127,360.9,1.949,9.8,0.328,0,0,,,461,,,,,,,,,,43.33,12.68,27.51,0.51,34.75,12.58,20.07,0.26,11.9,12.35,12.4,12.37,12.36,-6.14165,0.82193,-0.2821,21.77,-5.52521,0.57313,-0.2365,23.36,13.22,180,11156,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.239,1.262,,2.1,18.4,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,60.025,52.2193,85.4701,90.4977,,,,,1799,,,99999.99999,
\ No newline at end of file
diff --git a/src/tests/data/build_database.py b/src/tests/data/build_database.py
new file mode 100644
index 0000000..eef80be
--- /dev/null
+++ b/src/tests/data/build_database.py
@@ -0,0 +1,22 @@
+"""This script generates a .db SQLite file from the accompanying .csv data file.
+It contains just a few rows of data and is used for testing purposes only."""
+
+from iotswarm.utils import build_database_from_csv
+from pathlib import Path
+from iotswarm.queries import CosmosTable
+
+
+def main():
+ data_dir = Path(__file__).parent
+ data_file = Path(data_dir, "ALIC1_4_ROWS.csv")
+ database_file = Path(data_dir, "database.db")
+
+ data_table = CosmosTable.LEVEL_1_SOILMET_30MIN
+
+ build_database_from_csv(
+ data_file, database_file, data_table.value, date_time_format=r"%d-%b-%y"
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/tests/test_cli.py b/src/tests/test_cli.py
new file mode 100644
index 0000000..c771148
--- /dev/null
+++ b/src/tests/test_cli.py
@@ -0,0 +1,63 @@
+from click.testing import CliRunner
+from iotswarm.scripts import cli
+from parameterized import parameterized
+import re
+
+RUNNER = CliRunner()
+
+
+def test_main_ctx():
+ result = RUNNER.invoke(cli.main, ["test"])
+ assert not result.exception
+ assert result.output == "{'logger': }\n"
+
+
+def test_main_log_config():
+ with RUNNER.isolated_filesystem():
+ with open("logger.ini", "w") as f:
+ f.write(
+ """[loggers]
+keys=root
+
+[handlers]
+keys=consoleHandler
+
+[formatters]
+keys=sampleFormatter
+
+[logger_root]
+level=INFO
+handlers=consoleHandler
+
+[handler_consoleHandler]
+class=StreamHandler
+level=INFO
+formatter=sampleFormatter
+args=(sys.stdout,)
+
+[formatter_sampleFormatter]
+format=%(asctime)s - %(name)s - %(levelname)s - %(message)s"""
+ )
+ result = RUNNER.invoke(cli.main, ["--log-config", "logger.ini", "test"])
+ assert not result.exception
+ assert result.output == "{'logger': }\n"
+
+
+@parameterized.expand(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"])
+def test_log_level_set(log_level):
+
+ result = RUNNER.invoke(cli.main, ["--log-level", log_level, "test"])
+ assert not result.exception
+ assert (
+ result.output
+ == f"Set log level to {log_level}.\n{{'logger': }}\n"
+ )
+
+
+def test_get_version():
+ """Tests that the verison can be retrieved."""
+
+ result = RUNNER.invoke(cli.main, ["get-version"])
+ print(result.output)
+ assert not result.exception
+ assert re.match(r"^\d+\.\d+\.\d+$", result.output)
diff --git a/src/tests/test_db.py b/src/tests/test_db.py
index 0cbb27f..34e8675 100644
--- a/src/tests/test_db.py
+++ b/src/tests/test_db.py
@@ -1,21 +1,29 @@
import unittest
import pytest
import config
-import pathlib
+from pathlib import Path
from iotswarm import db
-from iotswarm.queries import CosmosQuery, CosmosSiteQuery
+from iotswarm.devices import BaseDevice
+from iotswarm.messaging.core import MockMessageConnection
+from iotswarm.queries import CosmosTable
+from iotswarm.swarm import Swarm
from parameterized import parameterized
import logging
from unittest.mock import patch
+import pandas as pd
+from glob import glob
+from math import isnan
+import sqlite3
-CONFIG_PATH = pathlib.Path(
- pathlib.Path(__file__).parents[1], "iotswarm", "__assets__", "config.cfg"
+CONFIG_PATH = Path(
+ Path(__file__).parents[1], "iotswarm", "__assets__", "config.cfg"
)
config_exists = pytest.mark.skipif(
not CONFIG_PATH.exists(),
reason="Config file `config.cfg` not found in root directory.",
)
+COSMOS_TABLES = list(CosmosTable)
class TestBaseDatabase(unittest.TestCase):
@patch.multiple(db.BaseDatabase, __abstractmethods__=set())
@@ -27,6 +35,7 @@ def test_instantiation(self):
self.assertIsNone(inst.query_latest_from_site())
+
class TestMockDB(unittest.TestCase):
def test_instantiation(self):
@@ -58,22 +67,23 @@ def test_logger_set(self, logger, expected):
self.assertEqual(inst._instance_logger.parent, expected)
- @parameterized.expand(
- [
- [None, "MockDB()"],
- [
- logging.getLogger("notroot"),
- "MockDB(inherit_logger=)",
- ],
- ]
- )
- def test__repr__(self, logger, expected):
+
+ def test__repr__no_logger(self):
- inst = db.MockDB(inherit_logger=logger)
+ inst = db.MockDB()
+
+ self.assertEqual(inst.__repr__(), "MockDB()")
- self.assertEqual(inst.__repr__(), expected)
+ def test__repr__logger_given(self):
+ logger = logging.getLogger("testdblogger")
+ logger.setLevel(logging.CRITICAL)
+ expected = "MockDB(inherit_logger=)"
+ mock = db.MockDB(inherit_logger=logger)
+ self.assertEqual(mock.__repr__(), expected)
+
+
class TestOracleDB(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
@@ -86,6 +96,7 @@ async def asyncSetUp(self):
creds["user"],
password=creds["password"],
)
+ self.table = CosmosTable.LEVEL_1_SOILMET_30MIN
async def asyncTearDown(self) -> None:
await self.oracle.connection.close()
@@ -103,25 +114,19 @@ async def test_instantiation(self):
async def test_latest_data_query(self):
site_id = "MORLY"
- query = CosmosQuery.LEVEL_1_SOILMET_30MIN
- row = await self.oracle.query_latest_from_site(site_id, query)
+ row = await self.oracle.query_latest_from_site(site_id, self.table)
self.assertEqual(row["SITE_ID"], site_id)
- @parameterized.expand([
- CosmosSiteQuery.LEVEL_1_NMDB_1HOUR,
- CosmosSiteQuery.LEVEL_1_SOILMET_30MIN,
- CosmosSiteQuery.LEVEL_1_PRECIP_1MIN,
- CosmosSiteQuery.LEVEL_1_PRECIP_RAINE_1MIN,
- ])
+ @parameterized.expand(COSMOS_TABLES)
@pytest.mark.oracle
@pytest.mark.asyncio
@pytest.mark.slow
@config_exists
- async def test_site_id_query(self,query):
+ async def test_site_id_query(self, table):
- sites = await self.oracle.query_site_ids(query)
+ sites = await self.oracle.query_site_ids(table)
self.assertIsInstance(sites, list)
@@ -137,9 +142,7 @@ async def test_site_id_query(self,query):
@config_exists
async def test_site_id_query_max_sites(self, max_sites):
- query = CosmosSiteQuery.LEVEL_1_SOILMET_30MIN
-
- sites = await self.oracle.query_site_ids(query, max_sites=max_sites)
+ sites = await self.oracle.query_site_ids(self.table, max_sites=max_sites)
self.assertEqual(len(sites), max_sites)
@@ -147,33 +150,33 @@ async def test_site_id_query_max_sites(self, max_sites):
@pytest.mark.asyncio
@pytest.mark.oracle
@config_exists
- async def test_bad_latest_data_query_type(self):
+ async def test_bad_latest_data_table_type(self):
site_id = "MORLY"
- query = "sql injection goes brr"
+ table = "sql injection goes brr"
with self.assertRaises(TypeError):
- await self.oracle.query_latest_from_site(site_id, query)
+ await self.oracle.query_latest_from_site(site_id, table)
@pytest.mark.asyncio
@pytest.mark.oracle
@config_exists
- async def test_bad_site_query_type(self):
+ async def test_bad_site_table_type(self):
- query = "sql injection goes brr"
+ table = "sql injection goes brr"
with self.assertRaises(TypeError):
- await self.oracle.query_site_ids(query)
+ await self.oracle.query_site_ids(table)
@parameterized.expand([-1, -100, "STRING"])
@pytest.mark.asyncio
@pytest.mark.oracle
@config_exists
- async def test_bad_site_query_max_sites_type(self, max_sites):
+ async def test_bad_site_table_max_sites_type(self, max_sites):
"""Tests bad values for max_sites."""
with self.assertRaises((TypeError,ValueError)):
- await self.oracle.query_site_ids(CosmosSiteQuery.LEVEL_1_SOILMET_30MIN, max_sites=max_sites)
+ await self.oracle.query_site_ids(self.table, max_sites=max_sites)
@pytest.mark.asyncio
@pytest.mark.oracle
@@ -181,7 +184,7 @@ async def test_bad_site_query_max_sites_type(self, max_sites):
async def test__repr__(self):
"""Tests string representation."""
- oracle1 = self.oracle = await db.Oracle.create(
+ oracle1 = await db.Oracle.create(
self.creds["dsn"],
self.creds["user"],
password=self.creds["password"],
@@ -192,7 +195,7 @@ async def test__repr__(self):
oracle1.__repr__(), expected1
)
- oracle2 = self.oracle = await db.Oracle.create(
+ oracle2 = await db.Oracle.create(
self.creds["dsn"],
self.creds["user"],
password=self.creds["password"],
@@ -207,6 +210,293 @@ async def test__repr__(self):
oracle2.__repr__(), expected2
)
+CSV_PATH = Path(Path(__file__).parents[1], "iotswarm", "__assets__", "data")
+CSV_DATA_FILES = [Path(x) for x in glob(str(Path(CSV_PATH, "*.csv")))]
+sqlite_db_exist = pytest.mark.skipif(not Path(CSV_PATH, "cosmos.db").exists(), reason="Local cosmos.db does not exist.")
+
+data_files_exist = pytest.mark.skipif(
+ not CSV_PATH.exists() or len(CSV_DATA_FILES) == 0,
+ reason="No data files are present"
+)
+
+class TestLoopingCsvDB(unittest.TestCase):
+ """Tests the LoopingCsvDB class."""
+
+ def setUp(self):
+ self.data_path = {v.name.removesuffix("_DATA_TABLE.csv"):v for v in CSV_DATA_FILES}
+
+ self.soilmet_table = db.LoopingCsvDB(self.data_path["LEVEL_1_SOILMET_30MIN"])
+ self.maxDiff = None
+
+ @data_files_exist
+ @pytest.mark.slow
+ def test_instantiation(self):
+ """Tests that the database can be instantiated."""
+
+ database = self.soilmet_table
+
+ self.assertIsInstance(database, db.LoopingCsvDB)
+ self.assertIsInstance(database, db.BaseDatabase)
+
+
+ self.assertIsInstance(database.cache, dict)
+ self.assertIsInstance(database.connection, pd.DataFrame)
+
+ @data_files_exist
+ @pytest.mark.slow
+ def test_site_data_return_value(self):
+ database = self.soilmet_table
+
+ site = "MORLY"
+
+ data = database.query_latest_from_site(site)
+
+ expected_cache = {site: 1}
+ self.assertDictEqual(database.cache, expected_cache)
+
+ self.assertIsInstance(data, dict)
+
+ @data_files_exist
+ @pytest.mark.slow
+ def test_multiple_sites_added_to_cache(self):
+ sites = ["ALIC1", "MORLY", "HOLLN","EUSTN"]
+
+ database = self.soilmet_table
+
+ data = [database.query_latest_from_site(x) for x in sites]
+
+ for i, site in enumerate(sites):
+ self.assertEqual(site, data[i]["SITE_ID"])
+ self.assertIn(site, database.cache)
+ self.assertEqual(database.cache[site], 1)
+
+ @data_files_exist
+ @pytest.mark.slow
+ def test_cache_incremented_on_each_request(self):
+ database = self.soilmet_table
+
+ site = "MORLY"
+
+ expected = 1
+
+ last_data = None
+ for _ in range(10):
+ data = database.query_latest_from_site(site)
+ self.assertNotEqual(last_data, data)
+ self.assertEqual(expected, database.cache[site])
+
+ last_data = data
+ expected += 1
+
+ self.assertEqual(expected, 11)
+
+ @data_files_exist
+ @pytest.mark.slow
+ def test_cache_counter_restarts_at_end(self):
+
+ short_table_path = Path(Path(__file__).parent, "data", "ALIC1_4_ROWS.csv")
+ database = db.LoopingCsvDB(short_table_path)
+
+ site = "ALIC1"
+
+ expected = [1,2,3,4,1]
+ data = []
+ for e in expected:
+ data.append(database.query_latest_from_site(site))
+
+ self.assertEqual(database.cache[site], e)
+
+ for key in data[0].keys():
+ try:
+ self.assertEqual(data[0][key], data[-1][key])
+ except AssertionError as err:
+ if not isnan(data[0][key]) and isnan(data[-1][key]):
+ raise(err)
+
+ self.assertEqual(len(expected), len(data))
+
+ @data_files_exist
+ @pytest.mark.slow
+ def test_site_ids_can_be_retrieved(self):
+ database = self.soilmet_table
+
+ site_ids_full = database.query_site_ids()
+ site_ids_exp_full = database.query_site_ids(max_sites=0)
+
+
+ self.assertIsInstance(site_ids_full, list)
+
+ self.assertGreater(len(site_ids_full), 0)
+ for site in site_ids_full:
+ self.assertIsInstance(site, str)
+
+ self.assertEqual(len(site_ids_full), len(site_ids_exp_full))
+
+ site_ids_limit = database.query_site_ids(max_sites=5)
+
+ self.assertEqual(len(site_ids_limit), 5)
+ self.assertGreater(len(site_ids_full), len(site_ids_limit))
+
+ with self.assertRaises(ValueError):
+
+ database.query_site_ids(max_sites=-1)
+
+class TestLoopingCsvDBEndToEnd(unittest.IsolatedAsyncioTestCase):
+ """Tests the LoopingCsvDB class."""
+
+ def setUp(self):
+ self.data_path = {v.name.removesuffix("_DATA_TABLE.csv"):v for v in CSV_DATA_FILES}
+ self.maxDiff = None
+
+ @data_files_exist
+ @pytest.mark.slow
+ async def test_flow_with_device_attached(self):
+ """Tests that data is looped through with a device making requests."""
+
+ database = db.LoopingCsvDB(self.data_path["LEVEL_1_SOILMET_30MIN"])
+ device = BaseDevice("ALIC1", database, MockMessageConnection(), sleep_time=0, max_cycles=5)
+
+ await device.run()
+
+ self.assertDictEqual(database.cache, {"ALIC1": 5})
+
+ @data_files_exist
+ @pytest.mark.slow
+ async def test_flow_with_swarm_attached(self):
+ """Tests that the database is looped through correctly with multiple sites in a swarm."""
+
+ database = db.LoopingCsvDB(self.data_path["LEVEL_1_SOILMET_30MIN"])
+ sites = ["MORLY", "ALIC1", "EUSTN"]
+ cycles = [1, 4, 6]
+ devices = [
+ BaseDevice(s, database, MockMessageConnection(), sleep_time=0, max_cycles=c)
+ for (s,c) in zip(sites, cycles)
+ ]
+
+ swarm = Swarm(devices)
+
+ await swarm.run()
+
+ self.assertDictEqual(database.cache, {"MORLY": 1, "ALIC1": 4, "EUSTN": 6})
+
+class TestSqliteDB(unittest.TestCase):
+
+ @sqlite_db_exist
+ def setUp(self):
+ self.db_path = Path(Path(__file__).parents[1], "iotswarm", "__assets__", "data", "cosmos.db")
+ self.table = CosmosTable.LEVEL_1_SOILMET_30MIN
+
+ if self.db_path.exists():
+ self.database = db.LoopingSQLite3(self.db_path)
+ self.maxDiff = None
+
+ @sqlite_db_exist
+ def test_instantiation(self):
+ self.assertIsInstance(self.database, db.LoopingSQLite3)
+ self.assertIsInstance(self.database.connection, sqlite3.Connection)
+
+ @sqlite_db_exist
+ def test_latest_data(self):
+
+ site_id = "MORLY"
+
+ data = self.database.query_latest_from_site(site_id, self.table)
+
+ self.assertIsInstance(data, dict)
+
+ @sqlite_db_exist
+ def test_site_id_query(self):
+
+ sites = self.database.query_site_ids(self.table)
+
+ self.assertGreater(len(sites), 0)
+
+ self.assertIsInstance(sites, list)
+
+ for site in sites:
+ self.assertIsInstance(site, str)
+
+ @sqlite_db_exist
+ def test_multiple_sites_added_to_cache(self):
+ sites = ["ALIC1", "MORLY", "HOLLN","EUSTN"]
+
+ data = [self.database.query_latest_from_site(x, self.table) for x in sites]
+
+ for i, site in enumerate(sites):
+ self.assertEqual(site, data[i]["SITE_ID"])
+ self.assertIn(site, self.database.cache)
+ self.assertEqual(self.database.cache[site], 0)
+
+ @sqlite_db_exist
+ def test_cache_incremented_on_each_request(self):
+ site = "MORLY"
+
+ last_data = {}
+ for i in range(3):
+ if i == 0:
+ self.assertEqual(self.database.cache, {})
+ else:
+ self.assertEqual(i-1, self.database.cache[site])
+ data = self.database.query_latest_from_site(site, self.table)
+ self.assertNotEqual(last_data, data)
+
+ last_data = data
+
+ @sqlite_db_exist
+ def test_cache_counter_restarts_at_end(self):
+ database = db.LoopingSQLite3(Path(Path(__file__).parent, "data", "database.db"))
+
+ site = "ALIC1"
+
+ expected = [0,1,2,3,0]
+ data = []
+ for e in expected:
+ data.append(database.query_latest_from_site(site, self.table))
+
+ self.assertEqual(database.cache[site], e)
+
+ self.assertEqual(data[0], data[-1])
+
+ self.assertEqual(len(expected), len(data))
+
+class TestLoopingSQLite3DBEndToEnd(unittest.IsolatedAsyncioTestCase):
+ """Tests the LoopingCsvDB class."""
+
+ @sqlite_db_exist
+ def setUp(self):
+ self.db_path = Path(Path(__file__).parents[1], "iotswarm", "__assets__", "data", "cosmos.db")
+ if self.db_path.exists():
+ self.database = db.LoopingSQLite3(self.db_path)
+ self.maxDiff = None
+ self.table = CosmosTable.LEVEL_1_PRECIP_1MIN
+
+ @sqlite_db_exist
+ async def test_flow_with_device_attached(self):
+ """Tests that data is looped through with a device making requests."""
+
+ device = BaseDevice("ALIC1", self.database, MockMessageConnection(), table=self.table, sleep_time=0, max_cycles=5)
+
+ await device.run()
+
+ self.assertDictEqual(self.database.cache, {"ALIC1": 4})
+
+ @sqlite_db_exist
+ async def test_flow_with_swarm_attached(self):
+ """Tests that the database is looped through correctly with multiple sites in a swarm."""
+
+ sites = ["MORLY", "ALIC1", "EUSTN"]
+ cycles = [1, 2, 3]
+ devices = [
+ BaseDevice(s, self.database, MockMessageConnection(), sleep_time=0, max_cycles=c,table=self.table)
+ for (s,c) in zip(sites, cycles)
+ ]
+
+ swarm = Swarm(devices)
+
+ await swarm.run()
+
+ self.assertDictEqual(self.database.cache, {"MORLY": 0, "ALIC1": 1, "EUSTN": 2})
+
if __name__ == "__main__":
unittest.main()
diff --git a/src/tests/test_devices.py b/src/tests/test_devices.py
index 1d68a52..efb61d3 100644
--- a/src/tests/test_devices.py
+++ b/src/tests/test_devices.py
@@ -5,24 +5,28 @@
import json
from iotswarm.utils import json_serial
from iotswarm.devices import BaseDevice, CR1000XDevice, CR1000XField
-from iotswarm.db import Oracle, BaseDatabase, MockDB
-from iotswarm.queries import CosmosQuery, CosmosSiteQuery
+from iotswarm.db import Oracle, BaseDatabase, MockDB, LoopingSQLite3
+from iotswarm.queries import CosmosQuery, CosmosTable
from iotswarm.messaging.core import MockMessageConnection, MessagingBaseClass
from iotswarm.messaging.aws import IotCoreMQTTConnection
from parameterized import parameterized
from unittest.mock import patch
-import pathlib
+from pathlib import Path
import config
from datetime import datetime, timedelta
-CONFIG_PATH = pathlib.Path(
- pathlib.Path(__file__).parents[1], "iotswarm", "__assets__", "config.cfg"
+CONFIG_PATH = Path(
+ Path(__file__).parents[1], "iotswarm", "__assets__", "config.cfg"
)
config_exists = pytest.mark.skipif(
not CONFIG_PATH.exists(),
reason="Config file `config.cfg` not found in root directory.",
)
+DATA_DIR = Path(Path(__file__).parents[1], "iotswarm", "__assets__", "data")
+sqlite_db_exist = pytest.mark.skipif(not Path(DATA_DIR, "cosmos.db").exists(), reason="Local cosmos.db does not exist.")
+
+
class TestBaseClass(unittest.IsolatedAsyncioTestCase):
def setUp(self):
@@ -169,26 +173,12 @@ def test__repr__(self, data_source, kwargs, expected):
self.assertEqual(repr(instance), expected)
- def test_query_not_set_for_non_oracle_db(self):
+ def test_table_not_set_for_non_oracle_db(self):
- inst = BaseDevice("test", MockDB(), MockMessageConnection(), query=CosmosQuery.LEVEL_1_SOILMET_30MIN)
+ inst = BaseDevice("test", MockDB(), MockMessageConnection(), table=CosmosTable.LEVEL_1_NMDB_1HOUR)
with self.assertRaises(AttributeError):
- inst.query
-
- def test_prefix_suffix_not_set_for_non_mqtt(self):
- "Tests that mqtt prefix and suffix not set for non MQTT messaging."
-
- inst = BaseDevice("site-1", self.data_source, MockMessageConnection(), mqtt_prefix="prefix", mqtt_suffix="suffix")
-
- with self.assertRaises(AttributeError):
- inst.mqtt_topic
-
- with self.assertRaises(AttributeError):
- inst.mqtt_prefix
-
- with self.assertRaises(AttributeError):
- inst.mqtt_suffix
+ inst.table
@parameterized.expand(
[
@@ -204,19 +194,6 @@ def test_logger_set(self, logger, expected):
self.assertEqual(inst._instance_logger.parent, expected)
- @parameterized.expand([
- [None, None, None],
- ["topic", None, None],
- ["topic", "prefix", None],
- ["topic", "prefix", "suffix"],
- ])
- def test__repr__mqtt_opts_no_mqtt_connection(self, topic, prefix, suffix):
- """Tests that the __repr__ method returns correctly with MQTT options set."""
- expected = 'BaseDevice("site-id", MockDB(), MockMessageConnection())'
- inst = BaseDevice("site-id", MockDB(), MockMessageConnection(), mqtt_topic=topic, mqtt_prefix=prefix, mqtt_suffix=suffix)
-
- self.assertEqual(inst.__repr__(), expected)
-
class TestBaseDeviceMQTTOptions(unittest.TestCase):
def setUp(self) -> None:
@@ -244,7 +221,7 @@ def test_mqtt_prefix_set(self, topic):
inst = BaseDevice("site", self.db, self.conn, mqtt_prefix=topic)
self.assertEqual(inst.mqtt_prefix, topic)
- self.assertEqual(inst.mqtt_topic, f"{topic}/base-device/site")
+ self.assertEqual(inst.mqtt_topic, f"{topic}/site")
@parameterized.expand(["this/topic", "1/1/1", "TOPICO!"])
@config_exists
@@ -254,7 +231,7 @@ def test_mqtt_suffix_set(self, topic):
inst = BaseDevice("site", self.db, self.conn, mqtt_suffix=topic)
self.assertEqual(inst.mqtt_suffix, topic)
- self.assertEqual(inst.mqtt_topic, f"base-device/site/{topic}")
+ self.assertEqual(inst.mqtt_topic, f"site/{topic}")
@parameterized.expand([["this/prefix", "this/suffix"], ["1/1/1", "2/2/2"], ["TOPICO!", "FOUR"]])
@config_exists
@@ -265,7 +242,7 @@ def test_mqtt_prefix_and_suffix_set(self, prefix, suffix):
self.assertEqual(inst.mqtt_suffix, suffix)
self.assertEqual(inst.mqtt_prefix, prefix)
- self.assertEqual(inst.mqtt_topic, f"{prefix}/base-device/site/{suffix}")
+ self.assertEqual(inst.mqtt_topic, f"{prefix}/site/{suffix}")
@config_exists
def test_default_mqtt_topic_set(self):
@@ -273,7 +250,7 @@ def test_default_mqtt_topic_set(self):
inst = BaseDevice("site-12", self.db, self.conn)
- self.assertEqual(inst.mqtt_topic, "base-device/site-12")
+ self.assertEqual(inst.mqtt_topic, "site-12")
@parameterized.expand([
[None, None, None, ""],
@@ -289,7 +266,9 @@ def test__repr__mqtt_opts_mqtt_connection(self, topic, prefix, suffix,expected_a
self.assertEqual(inst.__repr__(), expected)
+
class TestBaseDeviceOracleUsed(unittest.IsolatedAsyncioTestCase):
+
async def asyncSetUp(self):
cred_path = str(CONFIG_PATH)
creds = config.Config(cred_path)["oracle"]
@@ -300,53 +279,92 @@ async def asyncSetUp(self):
creds["user"],
password=creds["password"],
)
- self.query = CosmosQuery.LEVEL_1_SOILMET_30MIN
+ self.table = CosmosTable.LEVEL_1_SOILMET_30MIN
async def asyncTearDown(self) -> None:
await self.oracle.connection.close()
- @parameterized.expand([-1, -423.78, CosmosSiteQuery.LEVEL_1_NMDB_1HOUR, "Four", MockDB(), {"a": 1}])
+ @parameterized.expand([-1, -423.78, CosmosQuery.ORACLE_LATEST_DATA, "Four", MockDB(), {"a": 1}])
@config_exists
- def test_query_value_check(self, query):
+ @pytest.mark.oracle
+ def test_table_value_check(self, table):
with self.assertRaises(TypeError):
BaseDevice(
- "test_id", self.oracle, MockMessageConnection(), query=query
+ "test_id", self.oracle, MockMessageConnection(), table=table
)
@pytest.mark.oracle
@config_exists
- async def test_error_if_query_not_given(self):
+ async def test_error_if_table_not_given(self):
with self.assertRaises(ValueError):
BaseDevice("site", self.oracle, MockMessageConnection())
- inst = BaseDevice("site", self.oracle, MockMessageConnection(), query=self.query)
+ inst = BaseDevice("site", self.oracle, MockMessageConnection(), table=self.table)
- self.assertEqual(inst.query, self.query)
+ self.assertEqual(inst.table, self.table)
@pytest.mark.oracle
@config_exists
async def test__repr__oracle_data(self):
- inst_oracle = BaseDevice("site", self.oracle, MockMessageConnection(), query=self.query)
- exp_oracle = f'BaseDevice("site", Oracle("{self.creds['dsn']}"), MockMessageConnection(), query=CosmosQuery.{self.query.name})'
+ inst_oracle = BaseDevice("site", self.oracle, MockMessageConnection(), table=self.table)
+ exp_oracle = f'BaseDevice("site", Oracle("{self.creds['dsn']}"), MockMessageConnection(), table=CosmosTable.{self.table.name})'
self.assertEqual(inst_oracle.__repr__(), exp_oracle)
- inst_not_oracle = BaseDevice("site", MockDB(), MockMessageConnection(), query=self.query)
+ inst_not_oracle = BaseDevice("site", MockDB(), MockMessageConnection(), table=self.table)
exp_not_oracle = 'BaseDevice("site", MockDB(), MockMessageConnection())'
self.assertEqual(inst_not_oracle.__repr__(), exp_not_oracle)
with self.assertRaises(AttributeError):
- inst_not_oracle.query
+ inst_not_oracle.table
@pytest.mark.oracle
@config_exists
async def test__get_payload(self):
"""Tests that Cosmos payload retrieved."""
- inst = BaseDevice("MORLY", self.oracle, MockMessageConnection(), query=self.query)
+ inst = BaseDevice("MORLY", self.oracle, MockMessageConnection(), table=self.table)
+ print(inst)
+ payload = await inst._get_payload()
+
+ self.assertIsInstance(payload, dict)
+
+class TestBaseDevicesSQLite3Used(unittest.IsolatedAsyncioTestCase):
+
+ def setUp(self):
+ db_path = Path(Path(__file__).parents[1], "iotswarm","__assets__", "data", "cosmos.db")
+ if db_path.exists():
+ self.db = LoopingSQLite3(db_path)
+ self.table = CosmosTable.LEVEL_1_SOILMET_30MIN
+
+ @parameterized.expand([-1, -423.78, CosmosQuery.ORACLE_LATEST_DATA, "Four", MockDB(), {"a": 1}])
+ @sqlite_db_exist
+ def test_table_value_check(self, table):
+ with self.assertRaises(TypeError):
+ BaseDevice(
+ "test_id", self.db, MockMessageConnection(), table=table
+ )
+
+ @sqlite_db_exist
+ def test_error_if_table_not_given(self):
+
+ with self.assertRaises(ValueError):
+ BaseDevice("site", self.db, MockMessageConnection())
+
+
+ inst = BaseDevice("site", self.db, MockMessageConnection(), table=self.table)
+
+ self.assertEqual(inst.table, self.table)
+
+ @sqlite_db_exist
+ async def test__get_payload(self):
+ """Tests that Cosmos payload retrieved."""
+
+ inst = BaseDevice("MORLY", self.db, MockMessageConnection(), table=self.table)
+
payload = await inst._get_payload()
self.assertIsInstance(payload, dict)
@@ -448,6 +466,7 @@ async def test__get_payload(self):
self.assertIsInstance(payload, list)
self.assertEqual(len(payload),0)
+
class TestCr1000xDevice(unittest.TestCase):
"""Test suite for the CR1000X Device."""
diff --git a/src/tests/test_messaging.py b/src/tests/test_messaging.py
index 4fe9271..49e86fe 100644
--- a/src/tests/test_messaging.py
+++ b/src/tests/test_messaging.py
@@ -11,13 +11,20 @@
from parameterized import parameterized
import logging
-CONFIG_PATH = Path(
- Path(__file__).parents[1], "iotswarm", "__assets__", "config.cfg"
-)
+
+ASSETS_PATH = Path(Path(__file__).parents[1], "iotswarm", "__assets__")
+CONFIG_PATH = Path(ASSETS_PATH, "config.cfg")
+
config_exists = pytest.mark.skipif(
not CONFIG_PATH.exists(),
reason="Config file `config.cfg` not found in root directory.",
)
+certs_exist = pytest.mark.skipif(
+ not Path(ASSETS_PATH, ".certs", "cosmos_soilmet-certificate.pem.crt").exists()
+ or not Path(ASSETS_PATH, ".certs", "cosmos_soilmet-private.pem.key").exists()
+ or not Path(ASSETS_PATH, ".certs", "AmazonRootCA1.pem").exists(),
+ reason="IotCore certificates not present.",
+)
class TestBaseClass(unittest.TestCase):
@@ -43,24 +50,20 @@ def test_instantiation(self):
self.assertIsInstance(mock, MessagingBaseClass)
- def test_logger_used(self):
+ def test_no_logger_used(self):
- mock = MockMessageConnection()
+ with self.assertNoLogs():
+ mock = MockMessageConnection()
+ mock.send_message("")
- with self.assertLogs() as cm:
- mock.send_message()
- self.assertEqual(
- cm.output,
- [
- "INFO:iotswarm.messaging.core.MockMessageConnection:Message was sent."
- ],
- )
-
- with self.assertLogs() as cm:
- mock.send_message(use_logger=logging.getLogger("mine"))
+ def test_logger_used(self):
+ logger = logging.getLogger("testlogger")
+ with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
+ mock = MockMessageConnection(inherit_logger=logger)
+ mock.send_message("")
self.assertEqual(
cm.output,
- ["INFO:mine:Message was sent."],
+ ["DEBUG:testlogger.MockMessageConnection:Message was sent."],
)
@@ -73,6 +76,7 @@ def setUp(self) -> None:
self.config = config["iot_core"]
@config_exists
+ @certs_exist
def test_instantiation(self):
instance = IotCoreMQTTConnection(**self.config, client_id="test_id")
@@ -82,6 +86,7 @@ def test_instantiation(self):
self.assertIsInstance(instance.connection, awscrt.mqtt.Connection)
@config_exists
+ @certs_exist
def test_non_string_arguments(self):
with self.assertRaises(TypeError):
@@ -130,6 +135,7 @@ def test_non_string_arguments(self):
)
@config_exists
+ @certs_exist
def test_port(self):
# Expect one of defaults if no port given
@@ -151,6 +157,7 @@ def test_port(self):
@parameterized.expand([-4, {"f": 4}, "FOUR"])
@config_exists
+ @certs_exist
def test_bad_port_type(self, port):
with self.assertRaises((TypeError, ValueError)):
@@ -164,6 +171,7 @@ def test_bad_port_type(self, port):
)
@config_exists
+ @certs_exist
def test_clean_session_set(self):
expected = False
@@ -180,6 +188,7 @@ def test_clean_session_set(self):
@parameterized.expand([0, -1, "true", None])
@config_exists
+ @certs_exist
def test_bad_clean_session_type(self, clean_session):
with self.assertRaises(TypeError):
@@ -188,6 +197,7 @@ def test_bad_clean_session_type(self, clean_session):
)
@config_exists
+ @certs_exist
def test_keep_alive_secs_set(self):
# Test defualt is not none
instance = IotCoreMQTTConnection(**self.config, client_id="test_id")
@@ -200,8 +210,9 @@ def test_keep_alive_secs_set(self):
)
self.assertEqual(instance.connection.keep_alive_secs, expected)
- @parameterized.expand(["FOURTY", "True", None])
+ @parameterized.expand(["FOURTY", "True"])
@config_exists
+ @certs_exist
def test_bad_keep_alive_secs_type(self, secs):
with self.assertRaises(TypeError):
IotCoreMQTTConnection(
@@ -209,7 +220,8 @@ def test_bad_keep_alive_secs_type(self, secs):
)
@config_exists
- def test_logger_set(self):
+ @certs_exist
+ def test_no_logger_set(self):
inst = IotCoreMQTTConnection(**self.config, client_id="test_id")
expected = 'No message to send for topic: "mytopic".'
@@ -223,12 +235,21 @@ def test_logger_set(self):
],
)
- with self.assertLogs() as cm:
- inst.send_message(None, "mytopic", use_logger=logging.getLogger("mine"))
+ @config_exists
+ @certs_exist
+ def test_logger_set(self):
+ logger = logging.getLogger("mine")
+ inst = IotCoreMQTTConnection(
+ **self.config, client_id="test_id", inherit_logger=logger
+ )
+
+ expected = 'No message to send for topic: "mytopic".'
+ with self.assertLogs(logger=logger, level=logging.INFO) as cm:
+ inst.send_message(None, "mytopic")
self.assertEqual(
cm.output,
- [f"ERROR:mine:{expected}"],
+ [f"ERROR:mine.IotCoreMQTTConnection.client-test_id:{expected}"],
)