diff --git a/secret-store/secret_store.py b/secret-store/secret_store.py index a5b4d23f..4a582eaa 100644 --- a/secret-store/secret_store.py +++ b/secret-store/secret_store.py @@ -1,7 +1,8 @@ -from dataclasses import dataclass -from sqlcipher3 import dbapi2 as sqlcipher +import contextlib import logging import os +from dataclasses import dataclass +from sqlcipher3 import dbapi2 as sqlcipher from typing import Optional, Union from inspect import cleandoc @@ -25,21 +26,17 @@ class Credentials: password: str -class PotentiallyIncorrectMasterPassword(Exception): +class InvalidPassword(Exception): """Signal potentially incorrect master password.""" class Secrets: def __init__(self, db_file: str, master_password: str) -> None: self.db_file = db_file - self.master_password = master_password + self._master_password = master_password self._con = None - self._cur = None def close(self) -> None: - if self._cur is not None: - self._cur.close() - self._cur = None if self._con is not None: self._con.close() self._con = None @@ -58,8 +55,8 @@ def _initialize(self, db_file_found: bool) -> None: def create_table(table: Table) -> None: _logger.info(f'Creating table "{table.name}".') columns = " ,".join(table.columns) - self._cursor().execute(f"CREATE TABLE {table.name} (key, {columns})") - self.connection().commit() + with self._cursor() as cur: + cur.execute(f"CREATE TABLE {table.name} (key, {columns})") if db_file_found: self._verify_access() @@ -72,54 +69,68 @@ def _use_master_password(self) -> None: If database is unencrypted then this method encrypts it. If database is already encrypted then this method enables to access the data. """ - if self.master_password is not None: - sanitized = self.master_password.replace("'", "\\'") - self._cursor().execute(f"PRAGMA key = '{sanitized}'") + if self._master_password is not None: + sanitized = self._master_password.replace("'", "\\'") + with self._cursor() as cur: + cur.execute(f"PRAGMA key = '{sanitized}'") def _verify_access(self): try: - self._cursor().execute("SELECT * FROM sqlite_master") + with self._cursor() as cur: + cur.execute("SELECT * FROM sqlite_master") except sqlcipher.DatabaseError as ex: print(f'exception {ex}') if str(ex) == "file is not a database": - raise PotentiallyIncorrectMasterPassword( + raise InvalidPassword( cleandoc( f""" Cannot access database file {self.db_file}. This also happens if master password is incorrect. """) - ) + ) from ex else: raise ex + @contextlib.contextmanager def _cursor(self) -> sqlcipher.Cursor: - if self._cur is None: - self._cur = self.connection().cursor() - return self._cur + cur = self.connection().cursor() + try: + yield cur + finally: + self.connection().commit() + cur.close() def _save_data(self, table: Table, key: str, data: list[str]) -> "Secrets": - cur = self._cursor() - res = cur.execute(f"SELECT * FROM {table.name} WHERE key=?", [key]) - if res and res.fetchone(): + def entry_exists(cur) -> None: + res = cur.execute( + f"SELECT * FROM {table.name} WHERE key=?", + [key]) + return res and res.fetchone() + + def update(cur) -> None: columns = ", ".join(f"{c}=?" for c in table.columns) cur.execute( f"UPDATE {table.name} SET {columns} WHERE key=?", data + [key]) - else: + + def insert(cur) -> None: columns = ",".join(table.columns) value_slots = ", ".join("?" for c in table.columns) cur.execute( - f"INSERT INTO {table.name} (key,{columns}) VALUES (?, {value_slots})", + ( + f"INSERT INTO {table.name}" + f" (key,{columns})" + f" VALUES (?, {value_slots})" + ), [key] + data) - self.connection().commit() - return self - # def save_config_item(self, key: str, item: str) -> None: - # self._save_data(CONFIG_ITEMS_TABLE, key, [item]) - # - # def save_credentials(self, key: str, user: str, password: str) -> None: - # self._save_data(SECRETS_TABLE, key, [user, password]) + with self._cursor() as cur: + if entry_exists(cur): + update(cur) + else: + insert(cur) + return self def save(self, key: str, data: Union[str, Credentials]) -> "Secrets": """key represents a system, service, or application""" @@ -131,10 +142,11 @@ def save(self, key: str, data: Union[str, Credentials]) -> "Secrets": def _data(self, table: Table, key: str) -> Optional[list[str]]: columns = ", ".join(table.columns) - res = self._cursor().execute( - f"SELECT {columns} FROM {table.name} WHERE key=?", - [key]) - return res.fetchone() if res else None + with self._cursor() as cur: + res = cur.execute( + f"SELECT {columns} FROM {table.name} WHERE key=?", + [key]) + return res.fetchone() if res else None def credentials(self, key: str) -> Optional[Credentials]: row = self._data(SECRETS_TABLE, key) diff --git a/secret-store/test/test_secret_store.py b/secret-store/test/test_secret_store.py index 48f14ee2..246e83ea 100644 --- a/secret-store/test/test_secret_store.py +++ b/secret-store/test/test_secret_store.py @@ -1,7 +1,7 @@ import os import pytest from pathlib import Path -from secret_store import Credentials, PotentiallyIncorrectMasterPassword, Secrets +from secret_store import Credentials, InvalidPassword, Secrets from sqlcipher3 import dbapi2 as sqlcipher @@ -62,6 +62,6 @@ def test_wrong_password(sample_file): secrets = Secrets(sample_file, "correct password") secrets.save("key", Credentials("usr", "pass")).close() invalid = Secrets(sample_file, "wrong password") - with pytest.raises(PotentiallyIncorrectMasterPassword) as ex: + with pytest.raises(InvalidPassword) as ex: invalid.credentials("key") assert "master password is incorrect" in str(ex.value)