diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index f2d31a0..441f649 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -24,6 +24,9 @@ jobs:
python -m pip install --upgrade pip
pip install flake8
pip install .[test]
+ - name: Build test DB
+ run: |
+ python src/tests/data/build_database.py
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
diff --git a/.gitignore b/.gitignore
index c88052f..091353a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,7 +8,9 @@ dist/*
config.cfg
aws-auth
*.log
-*.certs
+*.certs*
!.github/*
docs/_*
-conf.sh
\ No newline at end of file
+*.sh
+**/*.csv
+**/*.db
\ No newline at end of file
diff --git a/docs/index.rst b/docs/index.rst
index 1cdd28a..bc97815 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -14,6 +14,7 @@
:caption: Contents:
self
+ source/cli
source/modules
genindex
modindex
@@ -57,7 +58,7 @@ intialise a swarm with data sent every 30 minutes like so:
.. code-block:: shell
iot-swarm cosmos --dsn="xxxxxx" --user="xxxxx" --password="*****" \
- mqtt "aws" LEVEL_1_SOILMET_30MIN "client_id" \
+ mqtt LEVEL_1_SOILMET_30MIN "client_id" \
--endpoint="xxxxxxx" \
--cert-path="C:\path\..." \
--key-path="C:\path\..." \
@@ -84,7 +85,7 @@ Then the CLI can be called more cleanly:
.. code-block:: shell
- iot-swarm cosmos mqtt "aws" LEVEL_1_SOILMET_30MIN "client_id" --sleep-time=1800 --swarm-name="my-swarm"
+ iot-swarm cosmos mqtt LEVEL_1_SOILMET_30MIN "client_id" --sleep-time=1800 --swarm-name="my-swarm"
------------------------
Using the Python Modules
@@ -163,6 +164,31 @@ The system expects config credentials for the MQTT endpoint and the COSMOS Oracl
.. include:: example-config.cfg
+-------------------------------------------
+Looping Through a local database / csv file
+-------------------------------------------
+
+This package now supports using a CSV file or local SQLite database as the data source.
+There are 2 modules to support it: `db.LoopingCsvDB` and `db.LoopingSQLite3`. Each of them
+initializes from a local file and loops through the data for a given site id. The database
+objects store an in-memory cache of each site ID and its current index in the database.
+Once the end is reached, it loops back to the start for that site.
+
+For use in FDRI, 6 months of data was downloaded to CSV from the COSMOS-UK database, but
+the files are too large to be included in this repo, so they are stored in the `ukceh-fdri`
+`S3` bucket on AWS. There are scripts for regenerating the `.db` file in this repo:
+
+* `./src/iotswarm/__assets__/data/build_database.py`
+* `./src/tests/data/build_database.py`
+
+To use the 'official' data, it should be downloaded from the `S3` bucket and placed in
+`./src/iotswarm/__assets__/data` before running the script. This will build a `.db` file
+sorted in datetime order that the `LoopingSQLite3` class can operate with.
+
+.. warning::
+ The looping database classes assume that their data files are sorted, and make no
+ attempt to sort it themselves.
+
Indices and tables
==================
diff --git a/docs/source/cli.rst b/docs/source/cli.rst
new file mode 100644
index 0000000..0341699
--- /dev/null
+++ b/docs/source/cli.rst
@@ -0,0 +1,6 @@
+CLI
+===
+
+.. click:: iotswarm.scripts.cli:main
+ :prog: iot-swarm
+ :nested: full
\ No newline at end of file
diff --git a/docs/source/iotswarm.scripts.rst b/docs/source/iotswarm.scripts.rst
index 42b1dd6..81dd67a 100644
--- a/docs/source/iotswarm.scripts.rst
+++ b/docs/source/iotswarm.scripts.rst
@@ -1,5 +1,5 @@
iotswarm.scripts package
-==================================
+========================
Submodules
----------
diff --git a/pyproject.toml b/pyproject.toml
index e23f65b..a69e0ae 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[build-system]
requires = ["setuptools >= 61.0", "autosemver"]
-build-backend = "setuptools.build_meta"
+# build-backend = "setuptools.build_meta"
[project]
dependencies = [
@@ -8,12 +8,14 @@ dependencies = [
"boto3",
"autosemver",
"config",
+ "click",
"docutils<0.17",
"awscli",
"awscrt",
+ "awsiotsdk",
"oracledb",
"backoff",
- "click",
+ "pandas",
]
name = "iot-swarm"
dynamic = ["version"]
@@ -26,14 +28,18 @@ docs = ["sphinx", "sphinx-copybutton", "sphinx-rtd-theme", "sphinx-click"]
[project.scripts]
iot-swarm = "iotswarm.scripts.cli:main"
+
[tool.setuptools.dynamic]
version = { attr = "iotswarm.__version__" }
+
[tool.setuptools.packages.find]
where = ["src"]
+include = ["iotswarm*"]
+exclude = ["tests*"]
[tool.setuptools.package-data]
-"*" = ["*.*"]
+"iotswarm.__assets__" = ["loggers.ini"]
[tool.pytest.ini_options]
@@ -49,4 +55,12 @@ markers = [
]
[tool.coverage.run]
-omit = ["*example.py", "*__init__.py", "queries.py", "loggers.py", "cli.py"]
+omit = [
+ "*example.py",
+ "*__init__.py",
+ "queries.py",
+ "loggers.py",
+ "**/scripts/*.py",
+ "**/build_database.py",
+ "utils.py",
+]
diff --git a/src/iotswarm/__assets__/data/build_database.py b/src/iotswarm/__assets__/data/build_database.py
new file mode 100644
index 0000000..96c5838
--- /dev/null
+++ b/src/iotswarm/__assets__/data/build_database.py
@@ -0,0 +1,41 @@
+"""This script is responsible for building an SQL data file from the CSV files used
+by the cosmos network.
+
+The files are stored in AWS S3 and should be downloaded into this directory before continuing.
+They are:
+ * LEVEL_1_NMDB_1HOUR_DATA_TABLE.csv
+ * LEVEL_1_SOILMET_30MIN_DATA_TABLE.csv
+ * LEVEL_1_PRECIP_1MIN_DATA_TABLE.csv
+ * LEVEL_1_PRECIP_RAINE_1MIN_DATA_TABLE.csv
+
+Once installed, run this script to generate the .db file.
+"""
+
+from iotswarm.utils import build_database_from_csv
+from pathlib import Path
+from glob import glob
+from iotswarm.queries import CosmosTable
+
+
+def main(
+ csv_dir: str | Path = Path(__file__).parent,
+ database_output: str | Path = Path(Path(__file__).parent, "cosmos.db"),
+):
+ """Reads exported cosmos DB files from CSV format. Assumes that the files
+ look like: LEVEL_1_SOILMET_30MIN_DATA_TABLE.csv
+
+ Args:
+ csv_dir: Directory where the csv files are stored.
+ database_output: Output destination of the csv_data.
+ """
+ csv_files = glob("*.csv", root_dir=csv_dir)
+
+ tables = [CosmosTable[x.removesuffix("_DATA_TABLE.csv")] for x in csv_files]
+
+ for table, file in zip(tables, csv_files):
+ file = Path(csv_dir, file)
+ build_database_from_csv(file, database_output, table.value, sort_by="DATE_TIME")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/iotswarm/__init__.py b/src/iotswarm/__init__.py
index fbfe123..0131fea 100644
--- a/src/iotswarm/__init__.py
+++ b/src/iotswarm/__init__.py
@@ -1,8 +1,6 @@
import autosemver
try:
- __version__ = autosemver.packaging.get_current_version(
- project_name="iot-device-simulator"
- )
+ __version__ = autosemver.packaging.get_current_version(project_name="iot-swarm")
except:
__version__ = "unknown version"
diff --git a/src/iotswarm/db.py b/src/iotswarm/db.py
index 1386ddd..e7297bc 100644
--- a/src/iotswarm/db.py
+++ b/src/iotswarm/db.py
@@ -4,7 +4,14 @@
import getpass
import logging
import abc
-from iotswarm.queries import CosmosQuery, CosmosSiteQuery
+from iotswarm.queries import (
+ CosmosQuery,
+ CosmosTable,
+)
+import pandas as pd
+from pathlib import Path
+from math import nan
+import sqlite3
logger = logging.getLogger(__name__)
@@ -44,12 +51,65 @@ def query_latest_from_site():
return []
-class Oracle(BaseDatabase):
+class CosmosDB(BaseDatabase):
+ """Base class for databases using COSMOS_UK data."""
+
+ connection: object
+ """Connection to database."""
+
+ site_data_query: CosmosQuery
+ """SQL query for retrieving a single record."""
+
+ site_id_query: CosmosQuery
+ """SQL query for retrieving list of site IDs"""
+
+ @staticmethod
+ def _validate_table(table: CosmosTable) -> None:
+ """Validates that the query is legal"""
+
+ if not isinstance(table, CosmosTable):
+ raise TypeError(
+ f"`table` must be a `{CosmosTable.__class__}` Enum, not a `{type(table)}`"
+ )
+
+ @staticmethod
+ def _fill_query(query: str, table: CosmosTable) -> str:
+ """Fills a query string with a CosmosTable enum."""
+
+ CosmosDB._validate_table(table)
+
+ return query.format(table=table.value)
+
+ @staticmethod
+ def _validate_max_sites(max_sites: int) -> int:
+ """Validates that a valid maximum sites is given:
+ Args:
+ max_sites: The maximum number of sites required.
+
+ Returns:
+ An integer 0 or more.
+ """
+
+ if max_sites is not None:
+ max_sites = int(max_sites)
+ if max_sites < 0:
+ raise ValueError(
+ f"`max_sites` must be 1 or more, or 0 for no maximum. Received: {max_sites}"
+ )
+
+ return max_sites
+
+
+class Oracle(CosmosDB):
"""Class for handling oracledb logic and retrieving values from DB."""
connection: oracledb.Connection
"""Connection to oracle database."""
+ site_data_query = CosmosQuery.ORACLE_LATEST_DATA
+
+ site_id_query = CosmosQuery.ORACLE_SITE_IDS
+
def __repr__(self):
parent_repr = (
super().__repr__().lstrip(f"{self.__class__.__name__}(").rstrip(")")
@@ -64,7 +124,14 @@ def __repr__(self):
)
@classmethod
- async def create(cls, dsn: str, user: str, password: str = None, **kwargs):
+ async def create(
+ cls,
+ dsn: str,
+ user: str,
+ password: str = None,
+ inherit_logger: logging.Logger | None = None,
+ **kwargs,
+ ):
"""Factory method for initialising the class.
Initialization is done through the `create() method`: `Oracle.create(...)`.
@@ -72,6 +139,7 @@ async def create(cls, dsn: str, user: str, password: str = None, **kwargs):
dsn: Oracle data source name.
user: Username used for query.
pw: User password for auth.
+ inherit_logger: Uses the given logger if provided
"""
if not password:
@@ -83,32 +151,31 @@ async def create(cls, dsn: str, user: str, password: str = None, **kwargs):
dsn=dsn, user=user, password=password
)
+ if inherit_logger is not None:
+ self._instance_logger = inherit_logger.getChild(self.__class__.__name__)
+ else:
+ self._instance_logger = logger.getChild(self.__class__.__name__)
+
self._instance_logger.info("Initialized Oracle connection.")
return self
- async def query_latest_from_site(self, site_id: str, query: CosmosQuery) -> dict:
+ async def query_latest_from_site(self, site_id: str, table: CosmosTable) -> dict:
"""Requests the latest data from a table for a specific site.
Args:
site_id: ID of the site to retrieve records from.
- query: Query to parse and submit.
+ table: A valid table from the database
Returns:
dict | None: A dict containing the database columns as keys, and the values as values.
Returns `None` if no data retrieved.
"""
- if not isinstance(query, CosmosQuery):
- raise TypeError(
- f"`query` must be a `CosmosQuery` Enum, not a `{type(query)}`"
- )
+ query = self._fill_query(self.site_data_query, table)
- async with self.connection.cursor() as cursor:
- await cursor.execute(
- query.value,
- mysite=site_id,
- )
+ with self.connection.cursor() as cursor:
+ await cursor.execute(query, site_id=site_id)
columns = [i[0] for i in cursor.description]
data = await cursor.fetchone()
@@ -119,23 +186,92 @@ async def query_latest_from_site(self, site_id: str, query: CosmosQuery) -> dict
return dict(zip(columns, data))
async def query_site_ids(
- self, query: CosmosSiteQuery, max_sites: int | None = None
+ self, table: CosmosTable, max_sites: int | None = None
) -> list:
"""query_site_ids returns a list of site IDs from COSMOS database
Args:
- query: The query to run.
+ table: A valid table from the database
max_sites: Maximum number of sites to retreive
Returns:
List[str]: A list of site ID strings.
"""
- if not isinstance(query, CosmosSiteQuery):
- raise TypeError(
- f"`query` must be a `CosmosSiteQuery` Enum, not a `{type(query)}`"
- )
+ max_sites = self._validate_max_sites(max_sites)
+
+ query = self._fill_query(self.site_id_query, table)
+
+ async with self.connection.cursor() as cursor:
+ await cursor.execute(query)
+
+ data = await cursor.fetchall()
+ if max_sites == 0:
+ data = [x[0] for x in data]
+ else:
+ data = [x[0] for x in data[:max_sites]]
+
+ if not data:
+ data = []
+
+ return data
+
+
+class LoopingCsvDB(BaseDatabase):
+ """A database that reads from csv files and loops through items
+ for a given table or site. The site and index is remembered via a
+ dictionary key and incremented each time data is requested."""
+
+ connection: pd.DataFrame
+ """Connection to the pd object holding data."""
+
+ cache: dict
+ """Cache object containing current index of each site queried."""
+
+ @staticmethod
+ def _get_connection(*args) -> pd.DataFrame:
+ """Gets the database connection."""
+ return pd.read_csv(*args)
+
+ def __init__(self, csv_file: str | Path):
+ """Initialises the database object.
+
+ Args:
+ csv_file: A pathlike object pointing to the datafile.
+ """
+
+ BaseDatabase.__init__(self)
+ self.connection = self._get_connection(csv_file)
+ self.cache = dict()
+
+ def query_latest_from_site(self, site_id: str) -> dict:
+        """Queries the database for a `SITE_ID` incrementing by 1 each time called
+ for a specific site. If the end is reached, it loops back to the start.
+
+ Args:
+ site_id: ID of the site to query for.
+ Returns:
+ A dict of the data row.
+ """
+
+ data = self.connection.query("SITE_ID == @site_id").replace({nan: None})
+
+ if site_id not in self.cache or self.cache[site_id] >= len(data):
+ self.cache[site_id] = 1
+ else:
+ self.cache[site_id] += 1
+
+ return data.iloc[self.cache[site_id] - 1].to_dict()
+
+ def query_site_ids(self, max_sites: int | None = None) -> list:
+ """query_site_ids returns a list of site IDs from the database
+
+ Args:
+            max_sites: Maximum number of sites to retrieve
+ Returns:
+ List[str]: A list of site ID strings.
+ """
if max_sites is not None:
max_sites = int(max_sites)
if max_sites < 0:
@@ -143,10 +279,115 @@ async def query_site_ids(
f"`max_sites` must be 1 or more, or 0 for no maximum. Received: {max_sites}"
)
- async with self.connection.cursor() as cursor:
- await cursor.execute(query.value)
+ sites = self.connection["SITE_ID"].drop_duplicates().to_list()
- data = await cursor.fetchall()
+ if max_sites is not None and max_sites > 0:
+ sites = sites[:max_sites]
+
+ return sites
+
+
+class LoopingSQLite3(CosmosDB, LoopingCsvDB):
+ """A database that reads from .db files using sqlite3 and loops through
+ entries in sequential order. There is a script that generates the .db file
+ in the `__assets__/data` directory relative to this file. .csv datasets should
+ be downloaded from the accompanying S3 bucket before running."""
+
+ connection: sqlite3.Connection
+ """Connection to the database."""
+
+ site_data_query = CosmosQuery.SQLITE_LOOPED_DATA
+
+ site_id_query = CosmosQuery.SQLITE_SITE_IDS
+
+ @staticmethod
+ def _get_connection(*args) -> sqlite3.Connection:
+ """Gets a database connection."""
+
+ return sqlite3.connect(*args)
+
+ def __init__(self, db_file: str | Path):
+ """Initialises the database object.
+
+ Args:
+            db_file: A pathlike object pointing to the SQLite database file.
+ """
+ LoopingCsvDB.__init__(self, db_file)
+
+ self.cursor = self.connection.cursor()
+
+ def query_latest_from_site(self, site_id: str, table: CosmosTable) -> dict:
+        """Queries the database for a `SITE_ID` incrementing by 1 each time called
+ for a specific site. If the end is reached, it loops back to the start.
+
+ Args:
+ site_id: ID of the site to query for.
+ table: A valid table from the database
+ Returns:
+ A dict of the data row.
+ """
+ query = self._fill_query(self.site_data_query, table)
+
+ if site_id not in self.cache:
+ self.cache[site_id] = 0
+ else:
+ self.cache[site_id] += 1
+
+ data = self._query_latest_from_site(
+ query, {"site_id": site_id, "offset": self.cache[site_id]}
+ )
+
+ if data is None:
+ self.cache[site_id] = 0
+
+ data = self._query_latest_from_site(
+ query, {"site_id": site_id, "offset": self.cache[site_id]}
+ )
+
+ return data
+
+ def _query_latest_from_site(self, query, arg_dict: dict) -> dict:
+ """Requests the latest data from a table for a specific site.
+
+ Args:
+ table: A valid table from the database
+ arg_dict: Dictionary of query arguments.
+
+ Returns:
+ dict | None: A dict containing the database columns as keys, and the values as values.
+ Returns `None` if no data retrieved.
+ """
+
+ self.cursor.execute(query, arg_dict)
+
+ columns = [i[0] for i in self.cursor.description]
+ data = self.cursor.fetchone()
+
+ if not data:
+ return None
+
+ return dict(zip(columns, data))
+
+ def query_site_ids(self, table: CosmosTable, max_sites: int | None = None) -> list:
+ """query_site_ids returns a list of site IDs from COSMOS database
+
+ Args:
+ table: A valid table from the database
+            max_sites: Maximum number of sites to retrieve
+
+ Returns:
+ List[str]: A list of site ID strings.
+ """
+
+ query = self._fill_query(self.site_id_query, table)
+
+ max_sites = self._validate_max_sites(max_sites)
+
+ try:
+ cursor = self.connection.cursor()
+ cursor.execute(query)
+
+ data = cursor.fetchall()
if max_sites == 0:
data = [x[0] for x in data]
else:
@@ -154,5 +395,7 @@ async def query_site_ids(
if not data:
data = []
+ finally:
+ cursor.close()
- return data
+ return data
diff --git a/src/iotswarm/devices.py b/src/iotswarm/devices.py
index 8c9c6ce..2b2cd1d 100644
--- a/src/iotswarm/devices.py
+++ b/src/iotswarm/devices.py
@@ -3,14 +3,14 @@
import asyncio
import logging
from iotswarm import __version__ as package_version
-from iotswarm.queries import CosmosQuery
-from iotswarm.db import BaseDatabase, Oracle
-from iotswarm.messaging.core import MessagingBaseClass
+from iotswarm.queries import CosmosTable
+from iotswarm.db import BaseDatabase, CosmosDB, Oracle, LoopingCsvDB, LoopingSQLite3
+from iotswarm.messaging.core import MessagingBaseClass, MockMessageConnection
from iotswarm.messaging.aws import IotCoreMQTTConnection
-import random
+from typing import List
from datetime import datetime
+import random
import enum
-from typing import List
logger = logging.getLogger(__name__)
@@ -45,8 +45,8 @@ class BaseDevice:
connection: MessagingBaseClass
"""Connection to the data receiver."""
- query: CosmosQuery
- """SQL query sent to database if Oracle type selected as `data_source`."""
+ table: CosmosTable
+ """SQL table used in queries if Oracle or LoopingSQLite3 selected as `data_source`."""
mqtt_base_topic: str
"""Base topic for mqtt topic."""
@@ -85,7 +85,7 @@ def __init__(
max_cycles: int | None = None,
delay_start: bool | None = None,
inherit_logger: logging.Logger | None = None,
- query: CosmosQuery | None = None,
+ table: CosmosTable | None = None,
mqtt_topic: str | None = None,
mqtt_prefix: str | None = None,
mqtt_suffix: str | None = None,
@@ -100,7 +100,7 @@ def __init__(
max_cycles: Maximum number of cycles before shutdown.
delay_start: Adds a random delay to first invocation from 0 - `sleep_time`.
inherit_logger: Override for the module logger.
- query: Sets the query used in COSMOS database. Ignored if `data_source` is not a Cosmos object.
+ table: A valid table from the database
mqtt_prefix: Prefixes the MQTT topic if MQTT messaging used.
mqtt_suffix: Suffixes the MQTT topic if MQTT messaging used.
"""
@@ -111,10 +111,18 @@ def __init__(
raise TypeError(
f"`data_source` must be a `BaseDatabase`. Received: {data_source}."
)
- if isinstance(data_source, Oracle) and query is None:
- raise ValueError(
- f"`query` must be provided if `data_source` is type `Oracle`."
- )
+ if isinstance(data_source, (CosmosDB)):
+
+ if table is None:
+ raise ValueError(
+ f"`table` must be provided if `data_source` is type `OracleDB`."
+ )
+ elif not isinstance(table, CosmosTable):
+ raise TypeError(
+ f'table must be a "{CosmosTable.__class__}", not "{type(table)}"'
+ )
+
+ self.table = table
self.data_source = data_source
if not isinstance(connection, MessagingBaseClass):
@@ -144,18 +152,11 @@ def __init__(
)
self.delay_start = delay_start
- if query is not None and isinstance(self.data_source, Oracle):
- if not isinstance(query, CosmosQuery):
- raise TypeError(
- f"`query` must be a `CosmosQuery`. Received: {type(query)}."
- )
- self.query = query
-
- if isinstance(connection, IotCoreMQTTConnection):
+ if isinstance(connection, (IotCoreMQTTConnection, MockMessageConnection)):
if mqtt_topic is not None:
self.mqtt_topic = str(mqtt_topic)
else:
- self.mqtt_topic = f"{self.device_type}/{self.device_id}"
+ self.mqtt_topic = f"{self.device_id}"
if mqtt_prefix is not None:
self.mqtt_prefix = str(mqtt_prefix)
@@ -190,16 +191,16 @@ def __repr__(self):
if self.delay_start != self.__class__.delay_start
else ""
)
- query_arg = (
- f", query={self.query.__class__.__name__}.{self.query.name}"
- if isinstance(self.data_source, Oracle)
+ table_arg = (
+ f", table={self.table.__class__.__name__}.{self.table.name}"
+ if isinstance(self.data_source, CosmosDB)
else ""
)
mqtt_topic_arg = (
f', mqtt_topic="{self.mqtt_base_topic}"'
if hasattr(self, "mqtt_base_topic")
- and self.mqtt_base_topic != f"{self.device_type}/{self.device_id}"
+ and self.mqtt_base_topic != self.device_id
else ""
)
@@ -222,7 +223,7 @@ def __repr__(self):
f"{sleep_time_arg}"
f"{max_cycles_arg}"
f"{delay_start_arg}"
- f"{query_arg}"
+ f"{table_arg}"
f"{mqtt_topic_arg}"
f"{mqtt_prefix_arg}"
f"{mqtt_suffix_arg}"
@@ -243,7 +244,7 @@ def _send_payload(self, payload: dict):
async def run(self):
"""The main invocation of the method. Expects a Oracle object to do work on
- and a query to retrieve. Runs asynchronously until `max_cycles` is reached.
+ and a table to retrieve. Runs asynchronously until `max_cycles` is reached.
Args:
message_connection: The message object to send data through
@@ -260,6 +261,9 @@ async def run(self):
if payload:
self._instance_logger.debug("Requesting payload submission.")
self._send_payload(payload)
+ self._instance_logger.info(
+ f'Message sent{f" to topic: {self.mqtt_topic}" if self.mqtt_topic else ""}'
+ )
else:
self._instance_logger.warning(f"No data found.")
@@ -273,9 +277,12 @@ async def _get_payload(self):
"""Method for grabbing the payload to send"""
if isinstance(self.data_source, Oracle):
return await self.data_source.query_latest_from_site(
- self.device_id, self.query
+ self.device_id, self.table
)
-
+ elif isinstance(self.data_source, LoopingSQLite3):
+ return self.data_source.query_latest_from_site(self.device_id, self.table)
+ elif isinstance(self.data_source, LoopingCsvDB):
+ return self.data_source.query_latest_from_site(self.device_id)
elif isinstance(self.data_source, BaseDatabase):
return self.data_source.query_latest_from_site()
diff --git a/src/iotswarm/messaging/aws.py b/src/iotswarm/messaging/aws.py
index 8205964..be1a708 100644
--- a/src/iotswarm/messaging/aws.py
+++ b/src/iotswarm/messaging/aws.py
@@ -2,6 +2,7 @@
import awscrt
from awscrt import mqtt
+from awsiot import mqtt_connection_builder
import awscrt.io
import json
from awscrt.exceptions import AwsCrtError
@@ -34,6 +35,7 @@ def __init__(
port: int | None = None,
clean_session: bool = False,
keep_alive_secs: int = 1200,
+ inherit_logger: logging.Logger | None = None,
**kwargs,
) -> None:
"""Initializes the class.
@@ -47,7 +49,7 @@ def __init__(
port: Port used by endpoint. Guesses correct port if not given.
clean_session: Builds a clean MQTT session if true. Defaults to False.
keep_alive_secs: Time to keep connection alive. Defaults to 1200.
- topic_prefix: A topic prefixed to MQTT topic, useful for attaching a "Basic Ingest" rule. Defaults to None.
+ inherit_logger: Override for the module logger.
"""
if not isinstance(endpoint, str):
@@ -88,41 +90,31 @@ def __init__(
if port < 0:
raise ValueError(f"`port` cannot be less than 0. Received: {port}.")
- socket_options = awscrt.io.SocketOptions()
- socket_options.connect_timeout_ms = 5000
- socket_options.keep_alive = False
- socket_options.keep_alive_timeout_secs = 0
- socket_options.keep_alive_interval_secs = 0
- socket_options.keep_alive_max_probes = 0
-
- client_bootstrap = awscrt.io.ClientBootstrap.get_or_create_static_default()
-
- tls_ctx = awscrt.io.ClientTlsContext(tls_ctx_options)
- mqtt_client = awscrt.mqtt.Client(client_bootstrap, tls_ctx)
-
- self.connection = awscrt.mqtt.Connection(
- client=mqtt_client,
+ self.connection = mqtt_connection_builder.mtls_from_path(
+ endpoint=endpoint,
+ port=port,
+ cert_filepath=cert_path,
+ pri_key_filepath=key_path,
+ ca_filepath=ca_cert_path,
on_connection_interrupted=self._on_connection_interrupted,
on_connection_resumed=self._on_connection_resumed,
client_id=client_id,
- host_name=endpoint,
- port=port,
+ proxy_options=None,
clean_session=clean_session,
- reconnect_min_timeout_secs=5,
- reconnect_max_timeout_secs=60,
keep_alive_secs=keep_alive_secs,
- ping_timeout_ms=3000,
- protocol_operation_timeout_ms=0,
- socket_options=socket_options,
- use_websockets=False,
on_connection_success=self._on_connection_success,
on_connection_failure=self._on_connection_failure,
on_connection_closed=self._on_connection_closed,
)
- self._instance_logger = logger.getChild(
- f"{self.__class__.__name__}.client-{client_id}"
- )
+ if inherit_logger is not None:
+ self._instance_logger = inherit_logger.getChild(
+ f"{self.__class__.__name__}.client-{client_id}"
+ )
+ else:
+ self._instance_logger = logger.getChild(
+ f"{self.__class__.__name__}.client-{client_id}"
+ )
def _on_connection_interrupted(
self, connection, error, **kwargs
@@ -182,23 +174,15 @@ def _disconnect(self): # pragma: no cover
disconnect_future = self.connection.disconnect()
disconnect_future.result()
- def send_message(
- self, message: dict, topic: str, use_logger: logging.Logger | None = None
- ) -> None:
+ def send_message(self, message: dict, topic: str) -> None:
"""Sends a message to the endpoint.
Args:
message: The message to send.
topic: MQTT topic to send message under.
- use_logger: Sends log message with requested logger.
"""
- if use_logger is not None and isinstance(use_logger, logging.Logger):
- use_logger = use_logger
- else:
- use_logger = self._instance_logger
-
if not message:
- use_logger.error(f'No message to send for topic: "{topic}".')
+ self._instance_logger.error(f'No message to send for topic: "{topic}".')
return
if self.connected_flag == False:
@@ -212,6 +196,4 @@ def send_message(
qos=mqtt.QoS.AT_LEAST_ONCE,
)
- use_logger.info(f'Sent {sys.getsizeof(payload)} bytes to "{topic}"')
-
- # self._disconnect()
+ self._instance_logger.debug(f'Sent {sys.getsizeof(payload)} bytes to "{topic}"')
diff --git a/src/iotswarm/messaging/core.py b/src/iotswarm/messaging/core.py
index d362b30..0957d93 100644
--- a/src/iotswarm/messaging/core.py
+++ b/src/iotswarm/messaging/core.py
@@ -2,8 +2,6 @@
from abc import abstractmethod
import logging
-logger = logging.getLogger(__name__)
-
class MessagingBaseClass(ABC):
"""MessagingBaseClass Base class for messaging implementation
@@ -14,9 +12,20 @@ class MessagingBaseClass(ABC):
_instance_logger: logging.Logger
"""Logger handle used by instance."""
- def __init__(self):
-
- self._instance_logger = logger.getChild(self.__class__.__name__)
+ def __init__(
+ self,
+ inherit_logger: logging.Logger | None = None,
+ ):
+ """Initialises the class.
+ Args:
+ inherit_logger: Override for the module logger.
+ """
+ if inherit_logger is not None:
+ self._instance_logger = inherit_logger.getChild(self.__class__.__name__)
+ else:
+ self._instance_logger = logging.getLogger(__name__).getChild(
+ self.__class__.__name__
+ )
@property
@abstractmethod
@@ -37,15 +46,7 @@ class MockMessageConnection(MessagingBaseClass):
connection: None = None
"""Connection object. Not needed in a mock but must be implemented"""
- def send_message(self, use_logger: logging.Logger | None = None):
- """Consumes requests to send a message but does nothing with it.
-
- Args:
- use_logger: Sends log message with requested logger."""
-
- if use_logger is not None and isinstance(use_logger, logging.Logger):
- use_logger = use_logger
- else:
- use_logger = self._instance_logger
+ def send_message(self, *_):
+ """Consumes requests to send a message but does nothing with it."""
- use_logger.info("Message was sent.")
+ self._instance_logger.debug("Message was sent.")
diff --git a/src/iotswarm/queries.py b/src/iotswarm/queries.py
index f1174b2..dfda5c8 100644
--- a/src/iotswarm/queries.py
+++ b/src/iotswarm/queries.py
@@ -5,112 +5,60 @@
@enum.unique
-class CosmosQuery(StrEnum):
- """Class containing permitted SQL queries for retrieving sensor data."""
+class CosmosTable(StrEnum):
+ """Enums of approved COSMOS database tables."""
- LEVEL_1_SOILMET_30MIN = """SELECT * FROM COSMOS.LEVEL1_SOILMET_30MIN
-WHERE site_id = :mysite
-ORDER BY date_time DESC
-FETCH NEXT 1 ROWS ONLY"""
+ LEVEL_1_SOILMET_30MIN = "LEVEL1_SOILMET_30MIN"
+ LEVEL_1_NMDB_1HOUR = "LEVEL1_NMDB_1HOUR"
+ LEVEL_1_PRECIP_1MIN = "LEVEL1_PRECIP_1MIN"
+ LEVEL_1_PRECIP_RAINE_1MIN = "LEVEL1_PRECIP_RAINE_1MIN"
- """Query for retreiving data from the LEVEL1_SOILMET_30MIN table, containing
- calibration the core telemetry from COSMOS sites.
-
- .. code-block:: sql
- SELECT * FROM COSMOS.LEVEL1_SOILMET_30MIN
- WHERE site_id = :mysite
- ORDER BY date_time DESC
- FETCH NEXT 1 ROWS ONLY
- """
+@enum.unique
+class CosmosQuery(StrEnum):
+    """Enums of common queries in each database dialect."""
- LEVEL_1_NMDB_1HOUR = """SELECT * FROM COSMOS.LEVEL1_NMDB_1HOUR
-WHERE site_id = :mysite
-ORDER BY date_time DESC
-FETCH NEXT 1 ROWS ONLY"""
+ SQLITE_LOOPED_DATA = """SELECT * FROM {table}
+WHERE site_id = :site_id
+LIMIT 1 OFFSET :offset"""
- """Query for retreiving data from the Level1_NMDB_1HOUR table, containing
- calibration data from the Neutron Monitor DataBase.
+    """Query for retrieving data from a given table in SQLite format.
.. code-block:: sql
- SELECT * FROM COSMOS.LEVEL1_NMDB_1HOUR
- WHERE site_id = :mysite
- ORDER BY date_time DESC
- FETCH NEXT 1 ROWS ONLY
+        SELECT * FROM {table}
+ WHERE site_id = :site_id
+ LIMIT 1 OFFSET :offset
"""
- LEVEL_1_PRECIP_1MIN = """SELECT * FROM COSMOS.LEVEL1_PRECIP_1MIN
-WHERE site_id = :mysite
+ ORACLE_LATEST_DATA = """SELECT * FROM COSMOS.{table}
+WHERE site_id = :site_id
ORDER BY date_time DESC
FETCH NEXT 1 ROWS ONLY"""
- """Query for retreiving data from the LEVEL1_PRECIP_1MIN table, containing
- the standard precipitation telemetry from COSMOS sites.
+    """Query for retrieving data from a given table in Oracle format.
.. code-block:: sql
- SELECT * FROM COSMOS.LEVEL1_PRECIP_1MIN
- WHERE site_id = :mysite
+        SELECT * FROM COSMOS.{table} WHERE site_id = :site_id
ORDER BY date_time DESC
FETCH NEXT 1 ROWS ONLY
"""
- LEVEL_1_PRECIP_RAINE_1MIN = """SELECT * FROM COSMOS.LEVEL1_PRECIP_RAINE_1MIN
-WHERE site_id = :mysite
-ORDER BY date_time DESC
-FETCH NEXT 1 ROWS ONLY"""
-
- """Query for retreiving data from the LEVEL1_PRECIP_RAINE_1MIN table, containing
- the rain[e] precipitation telemetry from COSMOS sites.
-
- .. code-block:: sql
-
- SELECT * FROM COSMOS.LEVEL1_PRECIP_RAINE_1MIN
- WHERE site_id = :mysite
- ORDER BY date_time DESC
- FETCH NEXT 1 ROWS ONLY
- """
-
-
-@enum.unique
-class CosmosSiteQuery(StrEnum):
- """Contains permitted SQL queries for extracting site IDs from database."""
-
- LEVEL_1_SOILMET_30MIN = "SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_SOILMET_30MIN"
-
- """Queries unique site IDs from LEVEL1_SOILMET_30MIN.
-
- .. code-block:: sql
-
- SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_SOILMET_30MIN
- """
-
- LEVEL_1_NMDB_1HOUR = "SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_NMDB_1HOUR"
-
- """Queries unique site IDs from LEVEL1_NMDB_1HOUR table.
-
- .. code-block:: sql
-
- SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_NMDB_1HOUR
- """
-
- LEVEL_1_PRECIP_1MIN = "SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_PRECIP_1MIN"
+ SQLITE_SITE_IDS = "SELECT DISTINCT(site_id) FROM {table}"
- """Queries unique site IDs from the LEVEL1_PRECIP_1MIN table.
+    """Queries unique `site_id`s from a given table.
.. code-block:: sql
- SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_PRECIP_1MIN
+    SELECT DISTINCT(site_id) FROM {table}
"""
- LEVEL_1_PRECIP_RAINE_1MIN = (
- "SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_PRECIP_RAINE_1MIN"
- )
+ ORACLE_SITE_IDS = "SELECT UNIQUE(site_id) FROM COSMOS.{table}"
- """Queries unique site IDs from the LEVEL1_PRECIP_RAINE_1MIN table.
+    """Queries unique `site_id`s from a given table.
.. code-block:: sql
- SELECT UNIQUE(site_id) FROM COSMOS.LEVEL1_PRECIP_RAINE_1MIN
+    SELECT UNIQUE(site_id) FROM COSMOS.{table}
"""
diff --git a/src/iotswarm/scripts/cli.py b/src/iotswarm/scripts/cli.py
index 7d79274..912a7be 100644
--- a/src/iotswarm/scripts/cli.py
+++ b/src/iotswarm/scripts/cli.py
@@ -1,24 +1,19 @@
"""CLI exposed when the package is installed."""
import click
-from iotswarm import queries
+from iotswarm import __version__ as package_version
+from iotswarm.queries import CosmosTable
from iotswarm.devices import BaseDevice, CR1000XDevice
from iotswarm.swarm import Swarm
-from iotswarm.db import Oracle
+from iotswarm.db import Oracle, LoopingCsvDB, LoopingSQLite3
from iotswarm.messaging.core import MockMessageConnection
from iotswarm.messaging.aws import IotCoreMQTTConnection
+import iotswarm.scripts.common as cli_common
import asyncio
from pathlib import Path
import logging
-TABLES = [table.name for table in queries.CosmosQuery]
-
-
-@click.command
-@click.pass_context
-def test(ctx: click.Context):
- """Enables testing of cosmos group arguments."""
- print(ctx.obj)
+TABLE_NAMES = [table.name for table in CosmosTable]
@click.group()
@@ -27,18 +22,64 @@ def test(ctx: click.Context):
"--log-config",
type=click.Path(exists=True),
help="Path to a logging config file. Uses default if not given.",
+ default=Path(Path(__file__).parents[1], "__assets__", "loggers.ini"),
)
-def main(ctx: click.Context, log_config: Path):
+@click.option(
+ "--log-level",
+ type=click.Choice(
+ ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
+ ),
+ help="Overrides the logging level.",
+ envvar="IOT_SWARM_LOG_LEVEL",
+)
+def main(ctx: click.Context, log_config: Path, log_level: str):
"""Core group of the cli."""
ctx.ensure_object(dict)
- if not log_config:
- log_config = Path(Path(__file__).parents[1], "__assets__", "loggers.ini")
-
logging.config.fileConfig(fname=log_config)
+ logger = logging.getLogger(__name__)
+
+ if log_level:
+ logger.setLevel(log_level)
+ click.echo(f"Set log level to {log_level}.")
+
+ ctx.obj["logger"] = logger
+
+
+main.add_command(cli_common.test)
+
+
+@main.command
+def get_version():
+ """Gets the package version"""
+ click.echo(package_version)
+
+@main.command()
+@click.pass_context
+@cli_common.iotcore_options
+def test_mqtt(
+ ctx,
+ message,
+ topic,
+ client_id: str,
+ endpoint: str,
+ cert_path: str,
+ key_path: str,
+ ca_cert_path: str,
+):
+ """Tests that a basic message can be sent via mqtt."""
+
+ connection = IotCoreMQTTConnection(
+ endpoint=endpoint,
+ cert_path=cert_path,
+ key_path=key_path,
+ ca_cert_path=ca_cert_path,
+ client_id=client_id,
+ inherit_logger=ctx.obj["logger"],
+ )
-main.add_command(test)
+ connection.send_message(message, topic)
@main.group()
@@ -78,113 +119,262 @@ def cosmos(ctx: click.Context, site: str, dsn: str, user: str, password: str):
ctx.obj["sites"] = site
-cosmos.add_command(test)
+cosmos.add_command(cli_common.test)
@cosmos.command()
@click.pass_context
-@click.argument("query", type=click.Choice(TABLES))
+@click.argument("table", type=click.Choice(TABLE_NAMES))
@click.option("--max-sites", type=click.IntRange(min=0), default=0)
-def list_sites(ctx, query, max_sites):
- """Lists site IDs from the database from table QUERY."""
+def list_sites(ctx, table, max_sites):
+ """Lists unique `site_id` from an oracle database table."""
- async def _list_sites(ctx, query):
+ async def _list_sites():
oracle = await Oracle.create(
dsn=ctx.obj["credentials"]["dsn"],
user=ctx.obj["credentials"]["user"],
password=ctx.obj["credentials"]["password"],
+ inherit_logger=ctx.obj["logger"],
)
- sites = await oracle.query_site_ids(
- queries.CosmosSiteQuery[query], max_sites=max_sites
- )
+ sites = await oracle.query_site_ids(CosmosTable[table], max_sites=max_sites)
return sites
- click.echo(asyncio.run(_list_sites(ctx, query)))
+ click.echo(asyncio.run(_list_sites()))
@cosmos.command()
@click.pass_context
@click.argument(
- "provider",
- type=click.Choice(["aws"]),
-)
-@click.argument(
- "query",
- type=click.Choice(TABLES),
-)
-@click.argument(
- "client-id",
- type=click.STRING,
- required=True,
+ "table",
+ type=click.Choice(TABLE_NAMES),
)
+@cli_common.device_options
+@cli_common.iotcore_options
+@click.option("--dry", is_flag=True, default=False, help="Doesn't send out any data.")
+def mqtt(
+ ctx,
+ table,
+ endpoint,
+ cert_path,
+ key_path,
+ ca_cert_path,
+ client_id,
+ sleep_time,
+ max_cycles,
+ max_sites,
+ swarm_name,
+ delay_start,
+ mqtt_prefix,
+ mqtt_suffix,
+ dry,
+ device_type,
+):
+ """Sends The cosmos data via MQTT protocol using IoT Core.
+ Data is from the cosmos database TABLE and sent using CLIENT_ID.
+
+ Currently only supports sending through AWS IoT Core."""
+ table = CosmosTable[table]
+
+ async def _mqtt():
+ oracle = await Oracle.create(
+ dsn=ctx.obj["credentials"]["dsn"],
+ user=ctx.obj["credentials"]["user"],
+ password=ctx.obj["credentials"]["password"],
+ inherit_logger=ctx.obj["logger"],
+ )
+
+ sites = ctx.obj["sites"]
+ if len(sites) == 0:
+ sites = await oracle.query_site_ids(table, max_sites=max_sites)
+
+ if dry == True:
+ connection = MockMessageConnection(inherit_logger=ctx.obj["logger"])
+ else:
+ connection = IotCoreMQTTConnection(
+ endpoint=endpoint,
+ cert_path=cert_path,
+ key_path=key_path,
+ ca_cert_path=ca_cert_path,
+ client_id=client_id,
+ inherit_logger=ctx.obj["logger"],
+ )
+
+ if device_type == "basic":
+ DeviceClass = BaseDevice
+ elif device_type == "cr1000x":
+ DeviceClass = CR1000XDevice
+
+ site_devices = [
+ DeviceClass(
+ site,
+ oracle,
+ connection,
+ sleep_time=sleep_time,
+ table=table,
+ max_cycles=max_cycles,
+ delay_start=delay_start,
+ mqtt_prefix=mqtt_prefix,
+ mqtt_suffix=mqtt_suffix,
+ inherit_logger=ctx.obj["logger"],
+ )
+ for site in sites
+ ]
+
+ swarm = Swarm(site_devices, swarm_name)
+
+ await swarm.run()
+
+ asyncio.run(_mqtt())
+
+
+@main.group()
+@click.pass_context
@click.option(
- "--endpoint",
+ "--site",
type=click.STRING,
- required=True,
- envvar="IOT_SWARM_MQTT_ENDPOINT",
- help="Endpoint of the MQTT receiving host.",
+ multiple=True,
+ help="Adds a site to be initialized. Can be invoked multiple times for other sites."
+ " Grabs all sites from database query if none provided",
)
@click.option(
- "--cert-path",
+ "--file",
type=click.Path(exists=True),
required=True,
- envvar="IOT_SWARM_MQTT_CERT_PATH",
- help="Path to public key certificate for the device. Must match key assigned to the `--client-id` in the cloud provider.",
+ envvar="IOT_SWARM_CSV_DB",
+ help="*.csv file used to instantiate a pandas database.",
)
+def looping_csv(ctx, site, file):
+ """Instantiates a pandas dataframe from a csv file which is used as the database.
+ Responsibility falls on the user to ensure the correct file is selected."""
+
+ ctx.obj["db"] = LoopingCsvDB(file)
+ ctx.obj["sites"] = site
+
+
+looping_csv.add_command(cli_common.test)
+looping_csv.add_command(cli_common.list_sites)
+
+
+@looping_csv.command()
+@click.pass_context
+@cli_common.device_options
+@cli_common.iotcore_options
+@click.option("--dry", is_flag=True, default=False, help="Doesn't send out any data.")
+def mqtt(
+ ctx,
+ endpoint,
+ cert_path,
+ key_path,
+ ca_cert_path,
+ client_id,
+ sleep_time,
+ max_cycles,
+ max_sites,
+ swarm_name,
+ delay_start,
+ mqtt_prefix,
+ mqtt_suffix,
+ dry,
+ device_type,
+):
+ """Sends The cosmos data via MQTT protocol using IoT Core.
+ Data is collected from the db using QUERY and sent using CLIENT_ID.
+
+ Currently only supports sending through AWS IoT Core."""
+
+ async def _mqtt():
+
+ sites = ctx.obj["sites"]
+ db = ctx.obj["db"]
+ if len(sites) == 0:
+ sites = db.query_site_ids(max_sites=max_sites)
+
+ if dry == True:
+ connection = MockMessageConnection(inherit_logger=ctx.obj["logger"])
+ else:
+ connection = IotCoreMQTTConnection(
+ endpoint=endpoint,
+ cert_path=cert_path,
+ key_path=key_path,
+ ca_cert_path=ca_cert_path,
+ client_id=client_id,
+ inherit_logger=ctx.obj["logger"],
+ )
+
+ if device_type == "basic":
+ DeviceClass = BaseDevice
+ elif device_type == "cr1000x":
+ DeviceClass = CR1000XDevice
+
+ site_devices = [
+ DeviceClass(
+ site,
+ db,
+ connection,
+ sleep_time=sleep_time,
+ max_cycles=max_cycles,
+ delay_start=delay_start,
+ mqtt_prefix=mqtt_prefix,
+ mqtt_suffix=mqtt_suffix,
+ inherit_logger=ctx.obj["logger"],
+ )
+ for site in sites
+ ]
+
+ swarm = Swarm(site_devices, swarm_name)
+
+ await swarm.run()
+
+ asyncio.run(_mqtt())
+
+
+@main.group()
+@click.pass_context
@click.option(
- "--key-path",
- type=click.Path(exists=True),
- required=True,
- envvar="IOT_SWARM_MQTT_KEY_PATH",
- help="Path to the private key that pairs with the `--cert-path`.",
+ "--site",
+ type=click.STRING,
+ multiple=True,
+ help="Adds a site to be initialized. Can be invoked multiple times for other sites."
+ " Grabs all sites from database query if none provided",
)
@click.option(
- "--ca-cert-path",
+ "--file",
type=click.Path(exists=True),
required=True,
- envvar="IOT_SWARM_MQTT_CA_CERT_PATH",
- help="Path to the root Certificate Authority (CA) for the MQTT host.",
-)
-@click.option(
- "--sleep-time",
- type=click.INT,
- help="The number of seconds each site goes idle after sending a message.",
-)
-@click.option(
- "--max-cycles",
- type=click.IntRange(0),
- help="Maximum number message sending cycles. Runs forever if set to 0.",
-)
-@click.option(
- "--max-sites",
- type=click.IntRange(0),
- help="Maximum number of sites allowed to initialize. No limit if set to 0.",
-)
-@click.option(
- "--swarm-name", type=click.STRING, help="Name given to swarm. Appears in the logs."
+ envvar="IOT_SWARM_LOCAL_DB",
+ help="*.db file used to instantiate a sqlite3 database.",
)
-@click.option(
- "--delay-start",
- is_flag=True,
- default=False,
- help="Adds a random delay before the first message from each site up to `--sleep-time`.",
-)
-@click.option(
- "--mqtt-prefix",
- type=click.STRING,
- help="Prefixes the MQTT topic with a string. Can augment the calculated MQTT topic returned by each site.",
-)
-@click.option(
- "--mqtt-suffix",
- type=click.STRING,
- help="Suffixes the MQTT topic with a string. Can augment the calculated MQTT topic returned by each site.",
+def looping_sqlite3(ctx, site, file):
+ """Instantiates a sqlite3 database as sensor source.."""
+ ctx.obj["db"] = LoopingSQLite3(file)
+ ctx.obj["sites"] = site
+
+
+looping_sqlite3.add_command(cli_common.test)
+
+
+@looping_sqlite3.command
+@click.pass_context
+@click.option("--max-sites", type=click.IntRange(min=0), default=0)
+@click.argument(
+ "table",
+ type=click.Choice(TABLE_NAMES),
)
+def list_sites(ctx, max_sites, table):
+ """Prints the sites present in database."""
+ sites = ctx.obj["db"].query_site_ids(table, max_sites=max_sites)
+ click.echo(sites)
+
+
+@looping_sqlite3.command()
+@click.pass_context
+@cli_common.device_options
+@cli_common.iotcore_options
+@click.argument("table", type=click.Choice(TABLE_NAMES))
@click.option("--dry", is_flag=True, default=False, help="Doesn't send out any data.")
-@click.option("--device-type", type=click.Choice(["basic", "cr1000x"]), default="basic")
def mqtt(
ctx,
- provider,
- query,
+ table,
endpoint,
cert_path,
key_path,
@@ -200,34 +390,30 @@ def mqtt(
dry,
device_type,
):
- """Sends The cosmos data via MQTT protocol using PROVIDER.
+ """Sends The cosmos data via MQTT protocol using IoT Core.
Data is collected from the db using QUERY and sent using CLIENT_ID.
Currently only supports sending through AWS IoT Core."""
- async def _mqtt():
- oracle = await Oracle.create(
- dsn=ctx.obj["credentials"]["dsn"],
- user=ctx.obj["credentials"]["user"],
- password=ctx.obj["credentials"]["password"],
- )
+ table = CosmosTable[table]
- data_query = queries.CosmosQuery[query]
- site_query = queries.CosmosSiteQuery[query]
+ async def _mqtt():
sites = ctx.obj["sites"]
+ db = ctx.obj["db"]
if len(sites) == 0:
- sites = await oracle.query_site_ids(site_query, max_sites=max_sites)
+ sites = db.query_site_ids(table, max_sites=max_sites)
if dry == True:
- connection = MockMessageConnection()
- elif provider == "aws":
+ connection = MockMessageConnection(inherit_logger=ctx.obj["logger"])
+ else:
connection = IotCoreMQTTConnection(
endpoint=endpoint,
cert_path=cert_path,
key_path=key_path,
ca_cert_path=ca_cert_path,
client_id=client_id,
+ inherit_logger=ctx.obj["logger"],
)
if device_type == "basic":
@@ -238,14 +424,15 @@ async def _mqtt():
site_devices = [
DeviceClass(
site,
- oracle,
+ db,
connection,
sleep_time=sleep_time,
- query=data_query,
max_cycles=max_cycles,
delay_start=delay_start,
mqtt_prefix=mqtt_prefix,
mqtt_suffix=mqtt_suffix,
+ table=table,
+ inherit_logger=ctx.obj["logger"],
)
for site in sites
]
diff --git a/src/iotswarm/scripts/common.py b/src/iotswarm/scripts/common.py
new file mode 100644
index 0000000..1b8b1e0
--- /dev/null
+++ b/src/iotswarm/scripts/common.py
@@ -0,0 +1,113 @@
+"""Location for common CLI commands"""
+
+import click
+
+
+def device_options(function):
+ click.option(
+ "--sleep-time",
+ type=click.INT,
+ help="The number of seconds each site goes idle after sending a message.",
+ )(function)
+
+ click.option(
+ "--max-cycles",
+ type=click.IntRange(0),
+ help="Maximum number message sending cycles. Runs forever if set to 0.",
+ )(function)
+
+ click.option(
+ "--max-sites",
+ type=click.IntRange(0),
+ help="Maximum number of sites allowed to initialize. No limit if set to 0.",
+ )(function)
+
+ click.option(
+ "--swarm-name",
+ type=click.STRING,
+ help="Name given to swarm. Appears in the logs.",
+ )(function)
+
+ click.option(
+ "--delay-start",
+ is_flag=True,
+ default=False,
+ help="Adds a random delay before the first message from each site up to `--sleep-time`.",
+ )(function)
+
+ click.option(
+ "--device-type", type=click.Choice(["basic", "cr1000x"]), default="basic"
+ )(function)
+
+ return function
+
+
+def iotcore_options(function):
+ click.argument(
+ "client-id",
+ type=click.STRING,
+ required=True,
+ )(function)
+
+ click.option(
+ "--endpoint",
+ type=click.STRING,
+ required=True,
+ envvar="IOT_SWARM_MQTT_ENDPOINT",
+ help="Endpoint of the MQTT receiving host.",
+ )(function)
+
+ click.option(
+ "--cert-path",
+ type=click.Path(exists=True),
+ required=True,
+ envvar="IOT_SWARM_MQTT_CERT_PATH",
+ help="Path to public key certificate for the device. Must match key assigned to the `--client-id` in the cloud provider.",
+ )(function)
+
+ click.option(
+ "--key-path",
+ type=click.Path(exists=True),
+ required=True,
+ envvar="IOT_SWARM_MQTT_KEY_PATH",
+ help="Path to the private key that pairs with the `--cert-path`.",
+ )(function)
+
+ click.option(
+ "--ca-cert-path",
+ type=click.Path(exists=True),
+ required=True,
+ envvar="IOT_SWARM_MQTT_CA_CERT_PATH",
+ help="Path to the root Certificate Authority (CA) for the MQTT host.",
+ )(function)
+
+ click.option(
+ "--mqtt-prefix",
+ type=click.STRING,
+ help="Prefixes the MQTT topic with a string. Can augment the calculated MQTT topic returned by each site.",
+ )(function)
+
+ click.option(
+ "--mqtt-suffix",
+ type=click.STRING,
+ help="Suffixes the MQTT topic with a string. Can augment the calculated MQTT topic returned by each site.",
+ )(function)
+
+ return function
+
+
+@click.command
+@click.pass_context
+@click.option("--max-sites", type=click.IntRange(min=0), default=0)
+def list_sites(ctx, max_sites):
+ """Prints the sites present in database."""
+
+ sites = ctx.obj["db"].query_site_ids(max_sites=max_sites)
+ click.echo(sites)
+
+
+@click.command
+@click.pass_context
+def test(ctx: click.Context):
+ """Enables testing of cosmos group arguments."""
+ print(ctx.obj)
diff --git a/src/iotswarm/utils.py b/src/iotswarm/utils.py
index 6d9d988..3a935eb 100644
--- a/src/iotswarm/utils.py
+++ b/src/iotswarm/utils.py
@@ -1,6 +1,10 @@
"""Module for handling commonly reused utility functions."""
from datetime import date, datetime
+from pathlib import Path
+import pandas
+import sqlite3
+from glob import glob
def json_serial(obj: object):
@@ -13,3 +17,53 @@ def json_serial(obj: object):
return obj.__json__()
raise TypeError(f"Type {type(obj)} is not serializable.")
+
+
+def build_database_from_csv(
+ csv_file: str | Path,
+ database: str | Path,
+ table_name: str,
+ sort_by: str | None = None,
+ date_time_format: str = r"%d-%b-%y %H.%M.%S",
+) -> None:
+ """Adds a database table using a csv file with headers.
+
+ Args:
+ csv_file: A path to the csv.
+ database: Output destination of the database. File is created if not
+ existing.
+ table_name: Name of the table to add into database.
+ sort_by: Column to sort by
+ date_time_format: Format of datetime column
+ """
+
+ if not isinstance(csv_file, Path):
+ csv_file = Path(csv_file)
+
+ if not isinstance(database, Path):
+ database = Path(database)
+
+ if not csv_file.exists():
+ raise FileNotFoundError(f'csv_file does not exist: "{csv_file}"')
+
+ if not database.parent.exists():
+ raise NotADirectoryError(f'Database directory not found: "{database.parent}"')
+
+ with sqlite3.connect(database) as conn:
+ print(
+ f'Writing table: "{table_name}" from csv_file: "{csv_file}" to db: "{database}"'
+ )
+ print("Loading csv")
+ df = pandas.read_csv(csv_file)
+ print("Done")
+ print("Formatting dates")
+ df["DATE_TIME"] = pandas.to_datetime(df["DATE_TIME"], format=date_time_format)
+ print("Done")
+ if sort_by is not None:
+ print("Sorting.")
+ df = df.sort_values(by=sort_by)
+ print("Done")
+
+ print("Writing to db.")
+ df.to_sql(table_name, conn, if_exists="replace", index=False)
+ print("Writing complete.")
diff --git a/src/tests/data/ALIC1_4_ROWS.csv b/src/tests/data/ALIC1_4_ROWS.csv
new file mode 100644
index 0000000..55659cd
--- /dev/null
+++ b/src/tests/data/ALIC1_4_ROWS.csv
@@ -0,0 +1,5 @@
+"SITE_ID","DATE_TIME","WD","WD_STD","WS","WS_STD","PA","PA_STD","RH","RH_STD","TA","TA_STD","Q","Q_STD","SWIN","SWIN_STD","SWOUT","SWOUT_STD","LWIN_UNC","LWIN_UNC_STD","LWOUT_UNC","LWOUT_UNC_STD","LWIN","LWIN_STD","LWOUT","LWOUT_STD","TNR01C","TNR01C_STD","PRECIP","PRECIP_DIAG","SNOWD_DISTANCE_UNC","SNOWD_SIGNALQUALITY","CTS_MOD","CTS_BARE","CTS_MOD2","CTS_SNOW","PROFILE_VWC15","PROFILE_SOILEC15","PROFILE_VWC40","PROFILE_SOILEC40","PROFILE_VWC65","PROFILE_SOILEC65","TDT1_VWC","TDT1_TSOIL","TDT1_SOILPERM","TDT1_SOILEC","TDT2_VWC","TDT2_TSOIL","TDT2_SOILPERM","TDT2_SOILEC","STP_TSOIL50","STP_TSOIL20","STP_TSOIL10","STP_TSOIL5","STP_TSOIL2","G1","G1_STD","G1_MV","G1_CAL","G2","G2_STD","G2_MV","G2_CAL","BATTV","SCANS","RECORD","CTS_MOD_DIAG_1","CTS_MOD_DIAG_2","CTS_MOD_DIAG_3","CTS_MOD_DIAG_4","CTS_MOD_DIAG_5","UX","UZ","UY","STDEV_UX","STDEV_UY","STDEV_UZ","COV_UX_UY","COV_UX_UZ","COV_UY_UZ","CTS_SNOW_DIAG_1","CTS_SNOW_DIAG_2","CTS_SNOW_DIAG_3","CTS_SNOW_DIAG_4","CTS_SNOW_DIAG_5","SNOWD_DISTANCE_COR","WS_RES","HS","TAU","U_STAR","TS","STDEV_TS","COV_TS_UX","COV_TS_UY","COV_TS_UZ","RHO_A_MEAN","E_SAT_MEAN","METPAK_SAMPLES","CTS_MOD_RH","CTS_MOD_TEMP","TDT3_VWC","TDT3_TSOIL","TDT3_SOILEC","TDT3_SOILPERM","TDT4_VWC","TDT4_TSOIL","TDT4_SOILEC","TDT4_SOILPERM","TDT5_VWC","TDT5_TSOIL","TDT5_SOILEC","TDT5_SOILPERM","TDT6_VWC","TDT6_TSOIL","TDT6_SOILEC","TDT6_SOILPERM","TDT7_VWC","TDT7_TSOIL","TDT7_SOILEC","TDT7_SOILPERM","TDT8_VWC","TDT8_TSOIL","TDT8_SOILEC","TDT8_SOILPERM","TDT9_VWC","TDT9_TSOIL","TDT9_SOILEC","TDT9_SOILPERM","TDT10_VWC","TDT10_TSOIL","TDT10_SOILEC","TDT10_SOILPERM","CTS_SNOW_TEMP","CTS_SNOW_RH","CTS_MOD2_RH","CTS_MOD2_TEMP","SWIN_MULT","SWOUT_MULT","LWIN_MULT","LWOUT_MULT","PRECIP_TIPPING_A","PRECIP_TIPPING_B","PRECIP_TIPPING2","TBR_TIP","CTS_MOD_PERIOD","CTS_MOD2_PERIOD","CTS_SNOW_PERIOD","WM_NAN_COUNT","RH_INTERNAL"
+"ALIC1",01-JUN-24,353.4,28.38,2.397,1.309,1010.926,0.044,72.58,0.641,12.34,0.05,7.891,,-0.941,0.245,0.454,0.097,-17.14,2.736,-0.303,0.172,359.1,2.68,376,0.191,12.27,0.022,0,0,,,519,,,,,,,,,,43.84,12.72,27.98,0.48,35.03,12.61,20.29,0.24,11.91,12.37,12.45,12.45,12.44,-3.3016,0.10141,-0.15144,21.8,-3.27326,0.1105,-0.14011,23.36,13.21,180,11153,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.23,1.433,,2.1,18.9,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,60.024,52.2193,85.4701,90.4977,,,,,1800,,,99999.99999,
+"ALIC1",01-JUN-24,358,29.12,2.045,1.015,1011.057,0.049,70.59,0.34,12.23,0.139,7.623,,-2.57,1.666,0.552,0.318,-45.27,28,-1.624,1.119,330.1,28.91,373.8,2.085,12.1,0.203,0,0,,,536,,,,,,,,,,43.61,12.72,27.77,0.51,34.91,12.59,20.2,0.27,11.91,12.36,12.43,12.42,12.42,-2.78394,0.58028,-0.12777,21.77,-3.0528,0.35922,-0.13067,23.36,13.19,180,11154,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.231,1.424,,2.1,18.8,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,60.024,52.2193,85.4701,90.4977,,,,,1800,,,99999.99999,
+"ALIC1",01-JUN-24,354.8,26.46,1.531,0.902,1011.085,0.036,73,0.945,11.38,0.31,7.471,,-4.584,0.342,1.413,0.163,-76.38,4.869,-2.488,0.493,293,3.633,366.9,1.718,10.95,0.381,0,0,,,504,,,,,,,,,,43.16,12.7,27.35,0.54,34.91,12.59,20.2,0.25,11.89,12.34,12.41,12.39,12.38,-3.57444,0.69214,-0.16418,21.77,-3.62632,0.49502,-0.15522,23.36,13.22,180,11155,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.235,1.346,,2.1,18.6,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,60.024,52.2193,85.4701,90.4977,,,,,1799,,,99999.99999,
+"ALIC1",01-JUN-24,323.9,47.78,0.916,0.541,1011.082,0.038,76.2,1.033,10.42,0.285,7.34,,-4.493,0.27,1.759,0.13,-75.77,0.74,-2.514,0.463,287.7,1.127,360.9,1.949,9.8,0.328,0,0,,,461,,,,,,,,,,43.33,12.68,27.51,0.51,34.75,12.58,20.07,0.26,11.9,12.35,12.4,12.37,12.36,-6.14165,0.82193,-0.2821,21.77,-5.52521,0.57313,-0.2365,23.36,13.22,180,11156,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.239,1.262,,2.1,18.4,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,60.025,52.2193,85.4701,90.4977,,,,,1799,,,99999.99999,
\ No newline at end of file
diff --git a/src/tests/data/build_database.py b/src/tests/data/build_database.py
new file mode 100644
index 0000000..eef80be
--- /dev/null
+++ b/src/tests/data/build_database.py
@@ -0,0 +1,22 @@
+"""This script generates a .db SQLite file from the accompanying .csv data file.
+It contains just a few rows of data and is used for testing purposes only."""
+
+from iotswarm.utils import build_database_from_csv
+from pathlib import Path
+from iotswarm.queries import CosmosTable
+
+
+def main():
+ data_dir = Path(__file__).parent
+ data_file = Path(data_dir, "ALIC1_4_ROWS.csv")
+ database_file = Path(data_dir, "database.db")
+
+ data_table = CosmosTable.LEVEL_1_SOILMET_30MIN
+
+ build_database_from_csv(
+ data_file, database_file, data_table.value, date_time_format=r"%d-%b-%y"
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/tests/test_cli.py b/src/tests/test_cli.py
new file mode 100644
index 0000000..c771148
--- /dev/null
+++ b/src/tests/test_cli.py
@@ -0,0 +1,63 @@
+from click.testing import CliRunner
+from iotswarm.scripts import cli
+from parameterized import parameterized
+import re
+
+RUNNER = CliRunner()
+
+
+def test_main_ctx():
+ result = RUNNER.invoke(cli.main, ["test"])
+ assert not result.exception
+ assert result.output == "{'logger': }\n"
+
+
+def test_main_log_config():
+ with RUNNER.isolated_filesystem():
+ with open("logger.ini", "w") as f:
+ f.write(
+ """[loggers]
+keys=root
+
+[handlers]
+keys=consoleHandler
+
+[formatters]
+keys=sampleFormatter
+
+[logger_root]
+level=INFO
+handlers=consoleHandler
+
+[handler_consoleHandler]
+class=StreamHandler
+level=INFO
+formatter=sampleFormatter
+args=(sys.stdout,)
+
+[formatter_sampleFormatter]
+format=%(asctime)s - %(name)s - %(levelname)s - %(message)s"""
+ )
+ result = RUNNER.invoke(cli.main, ["--log-config", "logger.ini", "test"])
+ assert not result.exception
+ assert result.output == "{'logger': }\n"
+
+
+@parameterized.expand(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"])
+def test_log_level_set(log_level):
+
+ result = RUNNER.invoke(cli.main, ["--log-level", log_level, "test"])
+ assert not result.exception
+ assert (
+ result.output
+ == f"Set log level to {log_level}.\n{{'logger': }}\n"
+ )
+
+
+def test_get_version():
+ """Tests that the verison can be retrieved."""
+
+ result = RUNNER.invoke(cli.main, ["get-version"])
+ print(result.output)
+ assert not result.exception
+ assert re.match(r"^\d+\.\d+\.\d+$", result.output)
diff --git a/src/tests/test_db.py b/src/tests/test_db.py
index 0cbb27f..34e8675 100644
--- a/src/tests/test_db.py
+++ b/src/tests/test_db.py
@@ -1,21 +1,29 @@
import unittest
import pytest
import config
-import pathlib
+from pathlib import Path
from iotswarm import db
-from iotswarm.queries import CosmosQuery, CosmosSiteQuery
+from iotswarm.devices import BaseDevice
+from iotswarm.messaging.core import MockMessageConnection
+from iotswarm.queries import CosmosTable
+from iotswarm.swarm import Swarm
from parameterized import parameterized
import logging
from unittest.mock import patch
+import pandas as pd
+from glob import glob
+from math import isnan
+import sqlite3
-CONFIG_PATH = pathlib.Path(
- pathlib.Path(__file__).parents[1], "iotswarm", "__assets__", "config.cfg"
+CONFIG_PATH = Path(
+ Path(__file__).parents[1], "iotswarm", "__assets__", "config.cfg"
)
config_exists = pytest.mark.skipif(
not CONFIG_PATH.exists(),
reason="Config file `config.cfg` not found in root directory.",
)
+COSMOS_TABLES = list(CosmosTable)
class TestBaseDatabase(unittest.TestCase):
@patch.multiple(db.BaseDatabase, __abstractmethods__=set())
@@ -27,6 +35,7 @@ def test_instantiation(self):
self.assertIsNone(inst.query_latest_from_site())
+
class TestMockDB(unittest.TestCase):
def test_instantiation(self):
@@ -58,22 +67,23 @@ def test_logger_set(self, logger, expected):
self.assertEqual(inst._instance_logger.parent, expected)
- @parameterized.expand(
- [
- [None, "MockDB()"],
- [
- logging.getLogger("notroot"),
- "MockDB(inherit_logger=)",
- ],
- ]
- )
- def test__repr__(self, logger, expected):
+
+ def test__repr__no_logger(self):
- inst = db.MockDB(inherit_logger=logger)
+ inst = db.MockDB()
+
+ self.assertEqual(inst.__repr__(), "MockDB()")
- self.assertEqual(inst.__repr__(), expected)
+ def test__repr__logger_given(self):
+ logger = logging.getLogger("testdblogger")
+ logger.setLevel(logging.CRITICAL)
+ expected = "MockDB(inherit_logger=)"
+ mock = db.MockDB(inherit_logger=logger)
+ self.assertEqual(mock.__repr__(), expected)
+
+
class TestOracleDB(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
@@ -86,6 +96,7 @@ async def asyncSetUp(self):
creds["user"],
password=creds["password"],
)
+ self.table = CosmosTable.LEVEL_1_SOILMET_30MIN
async def asyncTearDown(self) -> None:
await self.oracle.connection.close()
@@ -103,25 +114,19 @@ async def test_instantiation(self):
async def test_latest_data_query(self):
site_id = "MORLY"
- query = CosmosQuery.LEVEL_1_SOILMET_30MIN
- row = await self.oracle.query_latest_from_site(site_id, query)
+ row = await self.oracle.query_latest_from_site(site_id, self.table)
self.assertEqual(row["SITE_ID"], site_id)
- @parameterized.expand([
- CosmosSiteQuery.LEVEL_1_NMDB_1HOUR,
- CosmosSiteQuery.LEVEL_1_SOILMET_30MIN,
- CosmosSiteQuery.LEVEL_1_PRECIP_1MIN,
- CosmosSiteQuery.LEVEL_1_PRECIP_RAINE_1MIN,
- ])
+ @parameterized.expand(COSMOS_TABLES)
@pytest.mark.oracle
@pytest.mark.asyncio
@pytest.mark.slow
@config_exists
- async def test_site_id_query(self,query):
+ async def test_site_id_query(self, table):
- sites = await self.oracle.query_site_ids(query)
+ sites = await self.oracle.query_site_ids(table)
self.assertIsInstance(sites, list)
@@ -137,9 +142,7 @@ async def test_site_id_query(self,query):
@config_exists
async def test_site_id_query_max_sites(self, max_sites):
- query = CosmosSiteQuery.LEVEL_1_SOILMET_30MIN
-
- sites = await self.oracle.query_site_ids(query, max_sites=max_sites)
+ sites = await self.oracle.query_site_ids(self.table, max_sites=max_sites)
self.assertEqual(len(sites), max_sites)
@@ -147,33 +150,33 @@ async def test_site_id_query_max_sites(self, max_sites):
@pytest.mark.asyncio
@pytest.mark.oracle
@config_exists
- async def test_bad_latest_data_query_type(self):
+ async def test_bad_latest_data_table_type(self):
site_id = "MORLY"
- query = "sql injection goes brr"
+ table = "sql injection goes brr"
with self.assertRaises(TypeError):
- await self.oracle.query_latest_from_site(site_id, query)
+ await self.oracle.query_latest_from_site(site_id, table)
@pytest.mark.asyncio
@pytest.mark.oracle
@config_exists
- async def test_bad_site_query_type(self):
+ async def test_bad_site_table_type(self):
- query = "sql injection goes brr"
+ table = "sql injection goes brr"
with self.assertRaises(TypeError):
- await self.oracle.query_site_ids(query)
+ await self.oracle.query_site_ids(table)
@parameterized.expand([-1, -100, "STRING"])
@pytest.mark.asyncio
@pytest.mark.oracle
@config_exists
- async def test_bad_site_query_max_sites_type(self, max_sites):
+ async def test_bad_site_table_max_sites_type(self, max_sites):
"""Tests bad values for max_sites."""
with self.assertRaises((TypeError,ValueError)):
- await self.oracle.query_site_ids(CosmosSiteQuery.LEVEL_1_SOILMET_30MIN, max_sites=max_sites)
+ await self.oracle.query_site_ids(self.table, max_sites=max_sites)
@pytest.mark.asyncio
@pytest.mark.oracle
@@ -181,7 +184,7 @@ async def test_bad_site_query_max_sites_type(self, max_sites):
async def test__repr__(self):
"""Tests string representation."""
- oracle1 = self.oracle = await db.Oracle.create(
+ oracle1 = await db.Oracle.create(
self.creds["dsn"],
self.creds["user"],
password=self.creds["password"],
@@ -192,7 +195,7 @@ async def test__repr__(self):
oracle1.__repr__(), expected1
)
- oracle2 = self.oracle = await db.Oracle.create(
+ oracle2 = await db.Oracle.create(
self.creds["dsn"],
self.creds["user"],
password=self.creds["password"],
@@ -207,6 +210,293 @@ async def test__repr__(self):
oracle2.__repr__(), expected2
)
+CSV_PATH = Path(Path(__file__).parents[1], "iotswarm", "__assets__", "data")
+CSV_DATA_FILES = [Path(x) for x in glob(str(Path(CSV_PATH, "*.csv")))]
+sqlite_db_exist = pytest.mark.skipif(not Path(CSV_PATH, "cosmos.db").exists(), reason="Local cosmos.db does not exist.")
+
+data_files_exist = pytest.mark.skipif(
+ not CSV_PATH.exists() or len(CSV_DATA_FILES) == 0,
+ reason="No data files are present"
+)
+
+class TestLoopingCsvDB(unittest.TestCase):
+ """Tests the LoopingCsvDB class."""
+
+ def setUp(self):
+ self.data_path = {v.name.removesuffix("_DATA_TABLE.csv"):v for v in CSV_DATA_FILES}
+
+ self.soilmet_table = db.LoopingCsvDB(self.data_path["LEVEL_1_SOILMET_30MIN"])
+ self.maxDiff = None
+
+ @data_files_exist
+ @pytest.mark.slow
+ def test_instantiation(self):
+ """Tests that the database can be instantiated."""
+
+ database = self.soilmet_table
+
+ self.assertIsInstance(database, db.LoopingCsvDB)
+ self.assertIsInstance(database, db.BaseDatabase)
+
+
+ self.assertIsInstance(database.cache, dict)
+ self.assertIsInstance(database.connection, pd.DataFrame)
+
+ @data_files_exist
+ @pytest.mark.slow
+ def test_site_data_return_value(self):
+ database = self.soilmet_table
+
+ site = "MORLY"
+
+ data = database.query_latest_from_site(site)
+
+ expected_cache = {site: 1}
+ self.assertDictEqual(database.cache, expected_cache)
+
+ self.assertIsInstance(data, dict)
+
+ @data_files_exist
+ @pytest.mark.slow
+ def test_multiple_sites_added_to_cache(self):
+ sites = ["ALIC1", "MORLY", "HOLLN","EUSTN"]
+
+ database = self.soilmet_table
+
+ data = [database.query_latest_from_site(x) for x in sites]
+
+ for i, site in enumerate(sites):
+ self.assertEqual(site, data[i]["SITE_ID"])
+ self.assertIn(site, database.cache)
+ self.assertEqual(database.cache[site], 1)
+
+ @data_files_exist
+ @pytest.mark.slow
+ def test_cache_incremented_on_each_request(self):
+ database = self.soilmet_table
+
+ site = "MORLY"
+
+ expected = 1
+
+ last_data = None
+ for _ in range(10):
+ data = database.query_latest_from_site(site)
+ self.assertNotEqual(last_data, data)
+ self.assertEqual(expected, database.cache[site])
+
+ last_data = data
+ expected += 1
+
+ self.assertEqual(expected, 11)
+
+ @data_files_exist
+ @pytest.mark.slow
+ def test_cache_counter_restarts_at_end(self):
+
+ short_table_path = Path(Path(__file__).parent, "data", "ALIC1_4_ROWS.csv")
+ database = db.LoopingCsvDB(short_table_path)
+
+ site = "ALIC1"
+
+ expected = [1,2,3,4,1]
+ data = []
+ for e in expected:
+ data.append(database.query_latest_from_site(site))
+
+ self.assertEqual(database.cache[site], e)
+
+ for key in data[0].keys():
+ try:
+ self.assertEqual(data[0][key], data[-1][key])
+ except AssertionError as err:
+                if not (isnan(data[0][key]) and isnan(data[-1][key])):
+                    raise err
+
+ self.assertEqual(len(expected), len(data))
+
+ @data_files_exist
+ @pytest.mark.slow
+ def test_site_ids_can_be_retrieved(self):
+ database = self.soilmet_table
+
+ site_ids_full = database.query_site_ids()
+ site_ids_exp_full = database.query_site_ids(max_sites=0)
+
+
+ self.assertIsInstance(site_ids_full, list)
+
+ self.assertGreater(len(site_ids_full), 0)
+ for site in site_ids_full:
+ self.assertIsInstance(site, str)
+
+ self.assertEqual(len(site_ids_full), len(site_ids_exp_full))
+
+ site_ids_limit = database.query_site_ids(max_sites=5)
+
+ self.assertEqual(len(site_ids_limit), 5)
+ self.assertGreater(len(site_ids_full), len(site_ids_limit))
+
+ with self.assertRaises(ValueError):
+
+ database.query_site_ids(max_sites=-1)
+
+class TestLoopingCsvDBEndToEnd(unittest.IsolatedAsyncioTestCase):
+    """End-to-end tests for the LoopingCsvDB class."""
+
+ def setUp(self):
+ self.data_path = {v.name.removesuffix("_DATA_TABLE.csv"):v for v in CSV_DATA_FILES}
+ self.maxDiff = None
+
+ @data_files_exist
+ @pytest.mark.slow
+ async def test_flow_with_device_attached(self):
+ """Tests that data is looped through with a device making requests."""
+
+ database = db.LoopingCsvDB(self.data_path["LEVEL_1_SOILMET_30MIN"])
+ device = BaseDevice("ALIC1", database, MockMessageConnection(), sleep_time=0, max_cycles=5)
+
+ await device.run()
+
+ self.assertDictEqual(database.cache, {"ALIC1": 5})
+
+ @data_files_exist
+ @pytest.mark.slow
+ async def test_flow_with_swarm_attached(self):
+ """Tests that the database is looped through correctly with multiple sites in a swarm."""
+
+ database = db.LoopingCsvDB(self.data_path["LEVEL_1_SOILMET_30MIN"])
+ sites = ["MORLY", "ALIC1", "EUSTN"]
+ cycles = [1, 4, 6]
+ devices = [
+ BaseDevice(s, database, MockMessageConnection(), sleep_time=0, max_cycles=c)
+ for (s,c) in zip(sites, cycles)
+ ]
+
+ swarm = Swarm(devices)
+
+ await swarm.run()
+
+ self.assertDictEqual(database.cache, {"MORLY": 1, "ALIC1": 4, "EUSTN": 6})
+
+class TestSqliteDB(unittest.TestCase):
+
+ @sqlite_db_exist
+ def setUp(self):
+ self.db_path = Path(Path(__file__).parents[1], "iotswarm", "__assets__", "data", "cosmos.db")
+ self.table = CosmosTable.LEVEL_1_SOILMET_30MIN
+
+ if self.db_path.exists():
+ self.database = db.LoopingSQLite3(self.db_path)
+ self.maxDiff = None
+
+ @sqlite_db_exist
+ def test_instantiation(self):
+ self.assertIsInstance(self.database, db.LoopingSQLite3)
+ self.assertIsInstance(self.database.connection, sqlite3.Connection)
+
+ @sqlite_db_exist
+ def test_latest_data(self):
+
+ site_id = "MORLY"
+
+ data = self.database.query_latest_from_site(site_id, self.table)
+
+ self.assertIsInstance(data, dict)
+
+ @sqlite_db_exist
+ def test_site_id_query(self):
+
+ sites = self.database.query_site_ids(self.table)
+
+ self.assertGreater(len(sites), 0)
+
+ self.assertIsInstance(sites, list)
+
+ for site in sites:
+ self.assertIsInstance(site, str)
+
+ @sqlite_db_exist
+ def test_multiple_sites_added_to_cache(self):
+ sites = ["ALIC1", "MORLY", "HOLLN","EUSTN"]
+
+ data = [self.database.query_latest_from_site(x, self.table) for x in sites]
+
+ for i, site in enumerate(sites):
+ self.assertEqual(site, data[i]["SITE_ID"])
+ self.assertIn(site, self.database.cache)
+ self.assertEqual(self.database.cache[site], 0)
+
+ @sqlite_db_exist
+ def test_cache_incremented_on_each_request(self):
+ site = "MORLY"
+
+ last_data = {}
+ for i in range(3):
+ if i == 0:
+ self.assertEqual(self.database.cache, {})
+ else:
+ self.assertEqual(i-1, self.database.cache[site])
+ data = self.database.query_latest_from_site(site, self.table)
+ self.assertNotEqual(last_data, data)
+
+ last_data = data
+
+ @sqlite_db_exist
+ def test_cache_counter_restarts_at_end(self):
+ database = db.LoopingSQLite3(Path(Path(__file__).parent, "data", "database.db"))
+
+ site = "ALIC1"
+
+ expected = [0,1,2,3,0]
+ data = []
+ for e in expected:
+ data.append(database.query_latest_from_site(site, self.table))
+
+ self.assertEqual(database.cache[site], e)
+
+ self.assertEqual(data[0], data[-1])
+
+ self.assertEqual(len(expected), len(data))
+
+class TestLoopingSQLite3DBEndToEnd(unittest.IsolatedAsyncioTestCase):
+    """End-to-end tests for the LoopingSQLite3 class."""
+
+ @sqlite_db_exist
+ def setUp(self):
+ self.db_path = Path(Path(__file__).parents[1], "iotswarm", "__assets__", "data", "cosmos.db")
+ if self.db_path.exists():
+ self.database = db.LoopingSQLite3(self.db_path)
+ self.maxDiff = None
+ self.table = CosmosTable.LEVEL_1_PRECIP_1MIN
+
+ @sqlite_db_exist
+ async def test_flow_with_device_attached(self):
+ """Tests that data is looped through with a device making requests."""
+
+ device = BaseDevice("ALIC1", self.database, MockMessageConnection(), table=self.table, sleep_time=0, max_cycles=5)
+
+ await device.run()
+
+ self.assertDictEqual(self.database.cache, {"ALIC1": 4})
+
+ @sqlite_db_exist
+ async def test_flow_with_swarm_attached(self):
+ """Tests that the database is looped through correctly with multiple sites in a swarm."""
+
+ sites = ["MORLY", "ALIC1", "EUSTN"]
+ cycles = [1, 2, 3]
+ devices = [
+ BaseDevice(s, self.database, MockMessageConnection(), sleep_time=0, max_cycles=c,table=self.table)
+ for (s,c) in zip(sites, cycles)
+ ]
+
+ swarm = Swarm(devices)
+
+ await swarm.run()
+
+ self.assertDictEqual(self.database.cache, {"MORLY": 0, "ALIC1": 1, "EUSTN": 2})
+
if __name__ == "__main__":
unittest.main()
diff --git a/src/tests/test_devices.py b/src/tests/test_devices.py
index 1d68a52..efb61d3 100644
--- a/src/tests/test_devices.py
+++ b/src/tests/test_devices.py
@@ -5,24 +5,28 @@
import json
from iotswarm.utils import json_serial
from iotswarm.devices import BaseDevice, CR1000XDevice, CR1000XField
-from iotswarm.db import Oracle, BaseDatabase, MockDB
-from iotswarm.queries import CosmosQuery, CosmosSiteQuery
+from iotswarm.db import Oracle, BaseDatabase, MockDB, LoopingSQLite3
+from iotswarm.queries import CosmosQuery, CosmosTable
from iotswarm.messaging.core import MockMessageConnection, MessagingBaseClass
from iotswarm.messaging.aws import IotCoreMQTTConnection
from parameterized import parameterized
from unittest.mock import patch
-import pathlib
+from pathlib import Path
import config
from datetime import datetime, timedelta
-CONFIG_PATH = pathlib.Path(
- pathlib.Path(__file__).parents[1], "iotswarm", "__assets__", "config.cfg"
+CONFIG_PATH = Path(
+ Path(__file__).parents[1], "iotswarm", "__assets__", "config.cfg"
)
config_exists = pytest.mark.skipif(
not CONFIG_PATH.exists(),
reason="Config file `config.cfg` not found in root directory.",
)
+DATA_DIR = Path(Path(__file__).parents[1], "iotswarm", "__assets__", "data")
+sqlite_db_exist = pytest.mark.skipif(not Path(DATA_DIR, "cosmos.db").exists(), reason="Local cosmos.db does not exist.")
+
+
class TestBaseClass(unittest.IsolatedAsyncioTestCase):
def setUp(self):
@@ -169,26 +173,12 @@ def test__repr__(self, data_source, kwargs, expected):
self.assertEqual(repr(instance), expected)
- def test_query_not_set_for_non_oracle_db(self):
+ def test_table_not_set_for_non_oracle_db(self):
- inst = BaseDevice("test", MockDB(), MockMessageConnection(), query=CosmosQuery.LEVEL_1_SOILMET_30MIN)
+ inst = BaseDevice("test", MockDB(), MockMessageConnection(), table=CosmosTable.LEVEL_1_NMDB_1HOUR)
with self.assertRaises(AttributeError):
- inst.query
-
- def test_prefix_suffix_not_set_for_non_mqtt(self):
- "Tests that mqtt prefix and suffix not set for non MQTT messaging."
-
- inst = BaseDevice("site-1", self.data_source, MockMessageConnection(), mqtt_prefix="prefix", mqtt_suffix="suffix")
-
- with self.assertRaises(AttributeError):
- inst.mqtt_topic
-
- with self.assertRaises(AttributeError):
- inst.mqtt_prefix
-
- with self.assertRaises(AttributeError):
- inst.mqtt_suffix
+ inst.table
@parameterized.expand(
[
@@ -204,19 +194,6 @@ def test_logger_set(self, logger, expected):
self.assertEqual(inst._instance_logger.parent, expected)
- @parameterized.expand([
- [None, None, None],
- ["topic", None, None],
- ["topic", "prefix", None],
- ["topic", "prefix", "suffix"],
- ])
- def test__repr__mqtt_opts_no_mqtt_connection(self, topic, prefix, suffix):
- """Tests that the __repr__ method returns correctly with MQTT options set."""
- expected = 'BaseDevice("site-id", MockDB(), MockMessageConnection())'
- inst = BaseDevice("site-id", MockDB(), MockMessageConnection(), mqtt_topic=topic, mqtt_prefix=prefix, mqtt_suffix=suffix)
-
- self.assertEqual(inst.__repr__(), expected)
-
class TestBaseDeviceMQTTOptions(unittest.TestCase):
def setUp(self) -> None:
@@ -244,7 +221,7 @@ def test_mqtt_prefix_set(self, topic):
inst = BaseDevice("site", self.db, self.conn, mqtt_prefix=topic)
self.assertEqual(inst.mqtt_prefix, topic)
- self.assertEqual(inst.mqtt_topic, f"{topic}/base-device/site")
+ self.assertEqual(inst.mqtt_topic, f"{topic}/site")
@parameterized.expand(["this/topic", "1/1/1", "TOPICO!"])
@config_exists
@@ -254,7 +231,7 @@ def test_mqtt_suffix_set(self, topic):
inst = BaseDevice("site", self.db, self.conn, mqtt_suffix=topic)
self.assertEqual(inst.mqtt_suffix, topic)
- self.assertEqual(inst.mqtt_topic, f"base-device/site/{topic}")
+ self.assertEqual(inst.mqtt_topic, f"site/{topic}")
@parameterized.expand([["this/prefix", "this/suffix"], ["1/1/1", "2/2/2"], ["TOPICO!", "FOUR"]])
@config_exists
@@ -265,7 +242,7 @@ def test_mqtt_prefix_and_suffix_set(self, prefix, suffix):
self.assertEqual(inst.mqtt_suffix, suffix)
self.assertEqual(inst.mqtt_prefix, prefix)
- self.assertEqual(inst.mqtt_topic, f"{prefix}/base-device/site/{suffix}")
+ self.assertEqual(inst.mqtt_topic, f"{prefix}/site/{suffix}")
@config_exists
def test_default_mqtt_topic_set(self):
@@ -273,7 +250,7 @@ def test_default_mqtt_topic_set(self):
inst = BaseDevice("site-12", self.db, self.conn)
- self.assertEqual(inst.mqtt_topic, "base-device/site-12")
+ self.assertEqual(inst.mqtt_topic, "site-12")
@parameterized.expand([
[None, None, None, ""],
@@ -289,7 +266,9 @@ def test__repr__mqtt_opts_mqtt_connection(self, topic, prefix, suffix,expected_a
self.assertEqual(inst.__repr__(), expected)
+
class TestBaseDeviceOracleUsed(unittest.IsolatedAsyncioTestCase):
+
async def asyncSetUp(self):
cred_path = str(CONFIG_PATH)
creds = config.Config(cred_path)["oracle"]
@@ -300,53 +279,92 @@ async def asyncSetUp(self):
creds["user"],
password=creds["password"],
)
- self.query = CosmosQuery.LEVEL_1_SOILMET_30MIN
+ self.table = CosmosTable.LEVEL_1_SOILMET_30MIN
async def asyncTearDown(self) -> None:
await self.oracle.connection.close()
- @parameterized.expand([-1, -423.78, CosmosSiteQuery.LEVEL_1_NMDB_1HOUR, "Four", MockDB(), {"a": 1}])
+ @parameterized.expand([-1, -423.78, CosmosQuery.ORACLE_LATEST_DATA, "Four", MockDB(), {"a": 1}])
@config_exists
- def test_query_value_check(self, query):
+ @pytest.mark.oracle
+ def test_table_value_check(self, table):
with self.assertRaises(TypeError):
BaseDevice(
- "test_id", self.oracle, MockMessageConnection(), query=query
+ "test_id", self.oracle, MockMessageConnection(), table=table
)
@pytest.mark.oracle
@config_exists
- async def test_error_if_query_not_given(self):
+ async def test_error_if_table_not_given(self):
with self.assertRaises(ValueError):
BaseDevice("site", self.oracle, MockMessageConnection())
- inst = BaseDevice("site", self.oracle, MockMessageConnection(), query=self.query)
+ inst = BaseDevice("site", self.oracle, MockMessageConnection(), table=self.table)
- self.assertEqual(inst.query, self.query)
+ self.assertEqual(inst.table, self.table)
@pytest.mark.oracle
@config_exists
async def test__repr__oracle_data(self):
- inst_oracle = BaseDevice("site", self.oracle, MockMessageConnection(), query=self.query)
- exp_oracle = f'BaseDevice("site", Oracle("{self.creds['dsn']}"), MockMessageConnection(), query=CosmosQuery.{self.query.name})'
+ inst_oracle = BaseDevice("site", self.oracle, MockMessageConnection(), table=self.table)
+        exp_oracle = f'BaseDevice("site", Oracle("{self.creds["dsn"]}"), MockMessageConnection(), table=CosmosTable.{self.table.name})'
self.assertEqual(inst_oracle.__repr__(), exp_oracle)
- inst_not_oracle = BaseDevice("site", MockDB(), MockMessageConnection(), query=self.query)
+ inst_not_oracle = BaseDevice("site", MockDB(), MockMessageConnection(), table=self.table)
exp_not_oracle = 'BaseDevice("site", MockDB(), MockMessageConnection())'
self.assertEqual(inst_not_oracle.__repr__(), exp_not_oracle)
with self.assertRaises(AttributeError):
- inst_not_oracle.query
+ inst_not_oracle.table
@pytest.mark.oracle
@config_exists
async def test__get_payload(self):
"""Tests that Cosmos payload retrieved."""
- inst = BaseDevice("MORLY", self.oracle, MockMessageConnection(), query=self.query)
+ inst = BaseDevice("MORLY", self.oracle, MockMessageConnection(), table=self.table)
+ print(inst)
+ payload = await inst._get_payload()
+
+ self.assertIsInstance(payload, dict)
+
+class TestBaseDevicesSQLite3Used(unittest.IsolatedAsyncioTestCase):
+
+ def setUp(self):
+ db_path = Path(Path(__file__).parents[1], "iotswarm","__assets__", "data", "cosmos.db")
+ if db_path.exists():
+ self.db = LoopingSQLite3(db_path)
+ self.table = CosmosTable.LEVEL_1_SOILMET_30MIN
+
+ @parameterized.expand([-1, -423.78, CosmosQuery.ORACLE_LATEST_DATA, "Four", MockDB(), {"a": 1}])
+ @sqlite_db_exist
+ def test_table_value_check(self, table):
+ with self.assertRaises(TypeError):
+ BaseDevice(
+ "test_id", self.db, MockMessageConnection(), table=table
+ )
+
+ @sqlite_db_exist
+ def test_error_if_table_not_given(self):
+
+ with self.assertRaises(ValueError):
+ BaseDevice("site", self.db, MockMessageConnection())
+
+
+ inst = BaseDevice("site", self.db, MockMessageConnection(), table=self.table)
+
+ self.assertEqual(inst.table, self.table)
+
+ @sqlite_db_exist
+ async def test__get_payload(self):
+ """Tests that Cosmos payload retrieved."""
+
+ inst = BaseDevice("MORLY", self.db, MockMessageConnection(), table=self.table)
+
payload = await inst._get_payload()
self.assertIsInstance(payload, dict)
@@ -448,6 +466,7 @@ async def test__get_payload(self):
self.assertIsInstance(payload, list)
self.assertEqual(len(payload),0)
+
class TestCr1000xDevice(unittest.TestCase):
"""Test suite for the CR1000X Device."""
diff --git a/src/tests/test_messaging.py b/src/tests/test_messaging.py
index 4fe9271..49e86fe 100644
--- a/src/tests/test_messaging.py
+++ b/src/tests/test_messaging.py
@@ -11,13 +11,20 @@
from parameterized import parameterized
import logging
-CONFIG_PATH = Path(
- Path(__file__).parents[1], "iotswarm", "__assets__", "config.cfg"
-)
+
+ASSETS_PATH = Path(Path(__file__).parents[1], "iotswarm", "__assets__")
+CONFIG_PATH = Path(ASSETS_PATH, "config.cfg")
+
config_exists = pytest.mark.skipif(
not CONFIG_PATH.exists(),
reason="Config file `config.cfg` not found in root directory.",
)
+certs_exist = pytest.mark.skipif(
+ not Path(ASSETS_PATH, ".certs", "cosmos_soilmet-certificate.pem.crt").exists()
+ or not Path(ASSETS_PATH, ".certs", "cosmos_soilmet-private.pem.key").exists()
+ or not Path(ASSETS_PATH, ".certs", "AmazonRootCA1.pem").exists(),
+ reason="IotCore certificates not present.",
+)
class TestBaseClass(unittest.TestCase):
@@ -43,24 +50,20 @@ def test_instantiation(self):
self.assertIsInstance(mock, MessagingBaseClass)
- def test_logger_used(self):
+ def test_no_logger_used(self):
- mock = MockMessageConnection()
+ with self.assertNoLogs():
+ mock = MockMessageConnection()
+ mock.send_message("")
- with self.assertLogs() as cm:
- mock.send_message()
- self.assertEqual(
- cm.output,
- [
- "INFO:iotswarm.messaging.core.MockMessageConnection:Message was sent."
- ],
- )
-
- with self.assertLogs() as cm:
- mock.send_message(use_logger=logging.getLogger("mine"))
+ def test_logger_used(self):
+ logger = logging.getLogger("testlogger")
+ with self.assertLogs(logger=logger, level=logging.DEBUG) as cm:
+ mock = MockMessageConnection(inherit_logger=logger)
+ mock.send_message("")
self.assertEqual(
cm.output,
- ["INFO:mine:Message was sent."],
+ ["DEBUG:testlogger.MockMessageConnection:Message was sent."],
)
@@ -73,6 +76,7 @@ def setUp(self) -> None:
self.config = config["iot_core"]
@config_exists
+ @certs_exist
def test_instantiation(self):
instance = IotCoreMQTTConnection(**self.config, client_id="test_id")
@@ -82,6 +86,7 @@ def test_instantiation(self):
self.assertIsInstance(instance.connection, awscrt.mqtt.Connection)
@config_exists
+ @certs_exist
def test_non_string_arguments(self):
with self.assertRaises(TypeError):
@@ -130,6 +135,7 @@ def test_non_string_arguments(self):
)
@config_exists
+ @certs_exist
def test_port(self):
# Expect one of defaults if no port given
@@ -151,6 +157,7 @@ def test_port(self):
@parameterized.expand([-4, {"f": 4}, "FOUR"])
@config_exists
+ @certs_exist
def test_bad_port_type(self, port):
with self.assertRaises((TypeError, ValueError)):
@@ -164,6 +171,7 @@ def test_bad_port_type(self, port):
)
@config_exists
+ @certs_exist
def test_clean_session_set(self):
expected = False
@@ -180,6 +188,7 @@ def test_clean_session_set(self):
@parameterized.expand([0, -1, "true", None])
@config_exists
+ @certs_exist
def test_bad_clean_session_type(self, clean_session):
with self.assertRaises(TypeError):
@@ -188,6 +197,7 @@ def test_bad_clean_session_type(self, clean_session):
)
@config_exists
+ @certs_exist
def test_keep_alive_secs_set(self):
# Test defualt is not none
instance = IotCoreMQTTConnection(**self.config, client_id="test_id")
@@ -200,8 +210,9 @@ def test_keep_alive_secs_set(self):
)
self.assertEqual(instance.connection.keep_alive_secs, expected)
- @parameterized.expand(["FOURTY", "True", None])
+ @parameterized.expand(["FOURTY", "True"])
@config_exists
+ @certs_exist
def test_bad_keep_alive_secs_type(self, secs):
with self.assertRaises(TypeError):
IotCoreMQTTConnection(
@@ -209,7 +220,8 @@ def test_bad_keep_alive_secs_type(self, secs):
)
@config_exists
- def test_logger_set(self):
+ @certs_exist
+ def test_no_logger_set(self):
inst = IotCoreMQTTConnection(**self.config, client_id="test_id")
expected = 'No message to send for topic: "mytopic".'
@@ -223,12 +235,21 @@ def test_logger_set(self):
],
)
- with self.assertLogs() as cm:
- inst.send_message(None, "mytopic", use_logger=logging.getLogger("mine"))
+ @config_exists
+ @certs_exist
+ def test_logger_set(self):
+ logger = logging.getLogger("mine")
+ inst = IotCoreMQTTConnection(
+ **self.config, client_id="test_id", inherit_logger=logger
+ )
+
+ expected = 'No message to send for topic: "mytopic".'
+ with self.assertLogs(logger=logger, level=logging.INFO) as cm:
+ inst.send_message(None, "mytopic")
self.assertEqual(
cm.output,
- [f"ERROR:mine:{expected}"],
+ [f"ERROR:mine.IotCoreMQTTConnection.client-test_id:{expected}"],
)