Skip to content

Commit

Permalink
Fixed more review findings.
Browse files Browse the repository at this point in the history
  • Loading branch information
ckunki committed Oct 6, 2023
1 parent 1f412c8 commit 77caf8b
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 39 deletions.
78 changes: 49 additions & 29 deletions secret-store/secret_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import logging
import os
from typing import Optional, Union
from inspect import cleandoc


_logger = logging.getLogger(__name__)
# logging.basicConfig(level=logging.INFO, format="%(message)s")


@dataclass(frozen=True)
Expand All @@ -25,6 +25,10 @@ class Credentials:
password: str


class PotentiallyIncorrectMasterPassword(Exception):
"""Signal potentially incorrect master password."""


class Secrets:
def __init__(self, db_file: str, master_password: str) -> None:
self.db_file = db_file
Expand All @@ -42,56 +46,71 @@ def close(self) -> None:

def connection(self) -> sqlcipher.Connection:
if self._con is None:
if not os.path.exists(self.db_file):
db_file_found = os.path.exists(self.db_file)
if not db_file_found:
_logger.info(f"Creating file {self.db_file}")
self._con = sqlcipher.connect(self.db_file)
self._use_master_password()
self.create_tables()
self._initialize(db_file_found)
return self._con

def _initialize(self, db_file_found: bool) -> None:
def create_table(table: Table) -> None:
_logger.info(f'Creating table "{table.name}".')
columns = " ,".join(table.columns)
self._cursor().execute(f"CREATE TABLE {table.name} (key, {columns})")
self.connection().commit()

if db_file_found:
self._verify_access()
return
for table in (SECRETS_TABLE, CONFIG_ITEMS_TABLE):
create_table(table)

def _use_master_password(self) -> None:
"""
If database is unencrypted then this method encrypts it.
If database is already encrypted then this method enables to access the data.
"""
if self.master_password is not None:
sanitized = self.master_password.replace("'", "\\'")
self.cursor().execute(f"PRAGMA key = '{sanitized}'")
self._cursor().execute(f"PRAGMA key = '{sanitized}'")

def _verify_access(self):
try:
self._cursor().execute("SELECT * FROM sqlite_master")
except sqlcipher.DatabaseError as ex:
print(f'exception {ex}')
if str(ex) == "file is not a database":
raise PotentiallyIncorrectMasterPassword(
cleandoc(
f"""
Cannot access
database file {self.db_file}.
This also happens if master password is incorrect.
""")
)
else:
raise ex

def _cursor(self) -> sqlcipher.Cursor:
if self._cur is None:
self._cur = self.connection().cursor()
return self._cur

def has_table(self, name: str) -> bool:
res = self.cursor().execute("SELECT * FROM sqlite_master where name = ?", [name])
return True if res and res.fetchone() else False

# key represents a system, service, or application
def create_table(self, table: Table) -> None:
if self.has_table(table.name):
return
_logger.info(f'Creating table "{table.name}".')
columns = " ,".join(table.columns)
self.cursor().execute(f"CREATE TABLE {table.name} (key, {columns})")
self.connection().commit()

def create_tables(self) -> None:
for table in (SECRETS_TABLE, CONFIG_ITEMS_TABLE):
self.create_table(table)

def _save_data(self, table: Table, key: str, data: list[str]) -> "Secrets":
cur = self.cursor()
cur = self._cursor()
res = cur.execute(f"SELECT * FROM {table.name} WHERE key=?", [key])
if res and res.fetchone():
columns = ", ".join(f"{c}=?" for c in table.columns)
cur.execute(
f"UPDATE {table.name} SET {columns} WHERE key=?",
data + [key])
else:
columns = ", ".join("?" for c in table.columns)
columns = ",".join(table.columns)
value_slots = ", ".join("?" for c in table.columns)
cur.execute(
f"INSERT INTO {table.name} VALUES (?, {columns})",
f"INSERT INTO {table.name} (key,{columns}) VALUES (?, {value_slots})",
[key] + data)
self.connection().commit()
return self
Expand All @@ -103,6 +122,7 @@ def _save_data(self, table: Table, key: str, data: list[str]) -> "Secrets":
# self._save_data(SECRETS_TABLE, key, [user, password])

def save(self, key: str, data: Union[str, Credentials]) -> "Secrets":
"""key represents a system, service, or application"""
if isinstance(data, str):
return self._save_data(CONFIG_ITEMS_TABLE, key, [data])
if isinstance(data, Credentials):
Expand All @@ -111,15 +131,15 @@ def save(self, key: str, data: Union[str, Credentials]) -> "Secrets":

def _data(self, table: Table, key: str) -> Optional[list[str]]:
columns = ", ".join(table.columns)
res = self.cursor().execute(
res = self._cursor().execute(
f"SELECT {columns} FROM {table.name} WHERE key=?",
[key])
return res.fetchone() if res else None

def get_credentials(self, key: str) -> Optional[Credentials]:
row = self._get_data(SECRETS_TABLE, key)
def credentials(self, key: str) -> Optional[Credentials]:
row = self._data(SECRETS_TABLE, key)
return Credentials(row[0], row[1]) if row else None

def get_config_item(self, key: str) -> Optional[str]:
row = self._get_data(CONFIG_ITEMS_TABLE, key)
def config(self, key: str) -> Optional[str]:
row = self._data(CONFIG_ITEMS_TABLE, key)
return row[0] if row else None
20 changes: 10 additions & 10 deletions secret-store/test/test_secret_store.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import pytest
from pathlib import Path
from secret_store import Secrets, Credentials
from secret_store import Credentials, PotentiallyIncorrectMasterPassword, Secrets
from sqlcipher3 import dbapi2 as sqlcipher


Expand All @@ -20,25 +20,25 @@ def test_no_database_file(secrets):


def test_database_file_from_credentials(secrets):
assert secrets.get_credentials("a") is None
assert secrets.credentials("a") is None
assert os.path.exists(secrets.db_file)


def test_database_file_from_config_item(secrets):
assert secrets.get_config_item("a") is None
assert secrets.config("a") is None
assert os.path.exists(secrets.db_file)


def test_credentials(secrets):
credentials = Credentials("user", "password")
secrets.save("key", credentials).close()
assert secrets.get_credentials("key") == credentials
assert secrets.credentials("key") == credentials


def test_config_item(secrets):
config_item = "some configuration"
secrets.save("key", config_item).close()
assert secrets.get_config_item("key") == config_item
assert secrets.config("key") == config_item


def test_update_credentials(secrets):
Expand All @@ -47,21 +47,21 @@ def test_update_credentials(secrets):
other = Credentials("other", "changed")
secrets.save("key", other)
secrets.close()
assert secrets.get_credentials("key") == other
assert secrets.credentials("key") == other


def test_update_config_item(secrets):
initial = "initial value"
secrets.save("key", initial).close()
other = "other value"
secrets.save("key", other).close()
assert secrets.get_config_item("key") == other
assert secrets.config("key") == other


def test_wrong_password(sample_file):
secrets = Secrets(sample_file, "correct password")
secrets.save("key", Credentials("usr", "pass")).close()
invalid = Secrets(sample_file, "wrong password")
with pytest.raises(sqlcipher.DatabaseError) as ex:
invalid.get_credentials("key")
assert "file is not a database" == str(ex.value)
with pytest.raises(PotentiallyIncorrectMasterPassword) as ex:
invalid.credentials("key")
assert "master password is incorrect" in str(ex.value)

0 comments on commit 77caf8b

Please sign in to comment.