diff --git a/secret-store/secret_store.py b/secret-store/secret_store.py index b24be410..a5b4d23f 100644 --- a/secret-store/secret_store.py +++ b/secret-store/secret_store.py @@ -3,10 +3,10 @@ import logging import os from typing import Optional, Union +from inspect import cleandoc _logger = logging.getLogger(__name__) -# logging.basicConfig(level=logging.INFO, format="%(message)s") @dataclass(frozen=True) @@ -25,6 +25,10 @@ class Credentials: password: str +class PotentiallyIncorrectMasterPassword(Exception): + """Signal potentially incorrect master password.""" + + class Secrets: def __init__(self, db_file: str, master_password: str) -> None: self.db_file = db_file @@ -42,13 +46,27 @@ def close(self) -> None: def connection(self) -> sqlcipher.Connection: if self._con is None: - if not os.path.exists(self.db_file): + db_file_found = os.path.exists(self.db_file) + if not db_file_found: _logger.info(f"Creating file {self.db_file}") self._con = sqlcipher.connect(self.db_file) self._use_master_password() - self.create_tables() + self._initialize(db_file_found) return self._con + def _initialize(self, db_file_found: bool) -> None: + def create_table(table: Table) -> None: + _logger.info(f'Creating table "{table.name}".') + columns = " ,".join(table.columns) + self._cursor().execute(f"CREATE TABLE {table.name} (key, {columns})") + self.connection().commit() + + if db_file_found: + self._verify_access() + return + for table in (SECRETS_TABLE, CONFIG_ITEMS_TABLE): + create_table(table) + def _use_master_password(self) -> None: """ If database is unencrypted then this method encrypts it. @@ -56,32 +74,32 @@ def _use_master_password(self) -> None: """ if self.master_password is not None: sanitized = self.master_password.replace("'", "\\'") - self.cursor().execute(f"PRAGMA key = '{sanitized}'") + self._cursor().execute(f"PRAGMA key = '{sanitized}'") + + def _verify_access(self): + try: + self._cursor().execute("SELECT * FROM sqlite_master") + except sqlcipher.DatabaseError as ex: + print(f'exception {ex}') + if str(ex) == "file is not a database": + raise PotentiallyIncorrectMasterPassword( + cleandoc( + f""" + Cannot access + database file {self.db_file}. + This also happens if master password is incorrect. + """) + ) + else: + raise ex def _cursor(self) -> sqlcipher.Cursor: if self._cur is None: self._cur = self.connection().cursor() return self._cur - def has_table(self, name: str) -> bool: - res = self.cursor().execute("SELECT * FROM sqlite_master where name = ?", [name]) - return True if res and res.fetchone() else False - - # key represents a system, service, or application - def create_table(self, table: Table) -> None: - if self.has_table(table.name): - return - _logger.info(f'Creating table "{table.name}".') - columns = " ,".join(table.columns) - self.cursor().execute(f"CREATE TABLE {table.name} (key, {columns})") - self.connection().commit() - - def create_tables(self) -> None: - for table in (SECRETS_TABLE, CONFIG_ITEMS_TABLE): - self.create_table(table) - def _save_data(self, table: Table, key: str, data: list[str]) -> "Secrets": - cur = self.cursor() + cur = self._cursor() res = cur.execute(f"SELECT * FROM {table.name} WHERE key=?", [key]) if res and res.fetchone(): columns = ", ".join(f"{c}=?" for c in table.columns) @@ -89,9 +107,10 @@ def _save_data(self, table: Table, key: str, data: list[str]) -> "Secrets": f"UPDATE {table.name} SET {columns} WHERE key=?", data + [key]) else: - columns = ", ".join("?" for c in table.columns) + columns = ",".join(table.columns) + value_slots = ", ".join("?" for c in table.columns) cur.execute( - f"INSERT INTO {table.name} VALUES (?, {columns})", + f"INSERT INTO {table.name} (key,{columns}) VALUES (?, {value_slots})", [key] + data) self.connection().commit() return self @@ -103,6 +122,7 @@ def _save_data(self, table: Table, key: str, data: list[str]) -> "Secrets": # self._save_data(SECRETS_TABLE, key, [user, password]) def save(self, key: str, data: Union[str, Credentials]) -> "Secrets": + """key represents a system, service, or application""" if isinstance(data, str): return self._save_data(CONFIG_ITEMS_TABLE, key, [data]) if isinstance(data, Credentials): @@ -111,15 +131,15 @@ def save(self, key: str, data: Union[str, Credentials]) -> "Secrets": def _data(self, table: Table, key: str) -> Optional[list[str]]: columns = ", ".join(table.columns) - res = self.cursor().execute( + res = self._cursor().execute( f"SELECT {columns} FROM {table.name} WHERE key=?", [key]) return res.fetchone() if res else None - def get_credentials(self, key: str) -> Optional[Credentials]: - row = self._get_data(SECRETS_TABLE, key) + def credentials(self, key: str) -> Optional[Credentials]: + row = self._data(SECRETS_TABLE, key) return Credentials(row[0], row[1]) if row else None - def get_config_item(self, key: str) -> Optional[str]: - row = self._get_data(CONFIG_ITEMS_TABLE, key) + def config(self, key: str) -> Optional[str]: + row = self._data(CONFIG_ITEMS_TABLE, key) return row[0] if row else None diff --git a/secret-store/test/test_secret_store.py b/secret-store/test/test_secret_store.py index 6e802b8b..48f14ee2 100644 --- a/secret-store/test/test_secret_store.py +++ b/secret-store/test/test_secret_store.py @@ -1,7 +1,7 @@ import os import pytest from pathlib import Path -from secret_store import Secrets, Credentials +from secret_store import Credentials, PotentiallyIncorrectMasterPassword, Secrets from sqlcipher3 import dbapi2 as sqlcipher @@ -20,25 +20,25 @@ def test_no_database_file(secrets): def test_database_file_from_credentials(secrets): - assert secrets.get_credentials("a") is None + assert secrets.credentials("a") is None assert os.path.exists(secrets.db_file) def test_database_file_from_config_item(secrets): - assert secrets.get_config_item("a") is None + assert secrets.config("a") is None assert os.path.exists(secrets.db_file) def test_credentials(secrets): credentials = Credentials("user", "password") secrets.save("key", credentials).close() - assert secrets.get_credentials("key") == credentials + assert secrets.credentials("key") == credentials def test_config_item(secrets): config_item = "some configuration" secrets.save("key", config_item).close() - assert secrets.get_config_item("key") == config_item + assert secrets.config("key") == config_item def test_update_credentials(secrets): @@ -47,7 +47,7 @@ def test_update_credentials(secrets): other = Credentials("other", "changed") secrets.save("key", other) secrets.close() - assert secrets.get_credentials("key") == other + assert secrets.credentials("key") == other def test_update_config_item(secrets): @@ -55,13 +55,13 @@ def test_update_config_item(secrets): secrets.save("key", initial).close() other = "other value" secrets.save("key", other).close() - assert secrets.get_config_item("key") == other + assert secrets.config("key") == other def test_wrong_password(sample_file): secrets = Secrets(sample_file, "correct password") secrets.save("key", Credentials("usr", "pass")).close() invalid = Secrets(sample_file, "wrong password") - with pytest.raises(sqlcipher.DatabaseError) as ex: - invalid.get_credentials("key") - assert "file is not a database" == str(ex.value) + with pytest.raises(PotentiallyIncorrectMasterPassword) as ex: + invalid.credentials("key") + assert "master password is incorrect" in str(ex.value)