Skip to content

Commit

Permalink
Fixed additional review findings from nicoretti
Browse files Browse the repository at this point in the history
  • Loading branch information
ckunki committed Oct 6, 2023
1 parent 77caf8b commit 5c73ee9
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 37 deletions.
82 changes: 47 additions & 35 deletions secret-store/secret_store.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from dataclasses import dataclass
from sqlcipher3 import dbapi2 as sqlcipher
import contextlib
import logging
import os
from dataclasses import dataclass
from sqlcipher3 import dbapi2 as sqlcipher
from typing import Optional, Union
from inspect import cleandoc

Expand All @@ -25,21 +26,17 @@ class Credentials:
password: str


class PotentiallyIncorrectMasterPassword(Exception):
class InvalidPassword(Exception):
"""Signal potentially incorrect master password."""


class Secrets:
def __init__(self, db_file: str, master_password: str) -> None:
self.db_file = db_file
self.master_password = master_password
self._master_password = master_password
self._con = None
self._cur = None

def close(self) -> None:
if self._cur is not None:
self._cur.close()
self._cur = None
if self._con is not None:
self._con.close()
self._con = None
Expand All @@ -58,8 +55,8 @@ def _initialize(self, db_file_found: bool) -> None:
def create_table(table: Table) -> None:
_logger.info(f'Creating table "{table.name}".')
columns = " ,".join(table.columns)
self._cursor().execute(f"CREATE TABLE {table.name} (key, {columns})")
self.connection().commit()
with self._cursor() as cur:
cur.execute(f"CREATE TABLE {table.name} (key, {columns})")

if db_file_found:
self._verify_access()
Expand All @@ -72,54 +69,68 @@ def _use_master_password(self) -> None:
If database is unencrypted then this method encrypts it.
If database is already encrypted then this method enables to access the data.
"""
if self.master_password is not None:
sanitized = self.master_password.replace("'", "\\'")
self._cursor().execute(f"PRAGMA key = '{sanitized}'")
if self._master_password is not None:
sanitized = self._master_password.replace("'", "\\'")
with self._cursor() as cur:
cur.execute(f"PRAGMA key = '{sanitized}'")

def _verify_access(self):
try:
self._cursor().execute("SELECT * FROM sqlite_master")
with self._cursor() as cur:
cur.execute("SELECT * FROM sqlite_master")
except sqlcipher.DatabaseError as ex:
print(f'exception {ex}')
if str(ex) == "file is not a database":
raise PotentiallyIncorrectMasterPassword(
raise InvalidPassword(
cleandoc(
f"""
Cannot access
database file {self.db_file}.
This also happens if master password is incorrect.
""")
)
) from ex
else:
raise ex

@contextlib.contextmanager
def _cursor(self) -> sqlcipher.Cursor:
if self._cur is None:
self._cur = self.connection().cursor()
return self._cur
cur = self.connection().cursor()
try:
yield cur
finally:
self.connection().commit()
cur.close()

def _save_data(self, table: Table, key: str, data: list[str]) -> "Secrets":
cur = self._cursor()
res = cur.execute(f"SELECT * FROM {table.name} WHERE key=?", [key])
if res and res.fetchone():
def entry_exists(cur) -> None:
res = cur.execute(
f"SELECT * FROM {table.name} WHERE key=?",
[key])
return res and res.fetchone()

def update(cur) -> None:
columns = ", ".join(f"{c}=?" for c in table.columns)
cur.execute(
f"UPDATE {table.name} SET {columns} WHERE key=?",
data + [key])
else:

def insert(cur) -> None:
columns = ",".join(table.columns)
value_slots = ", ".join("?" for c in table.columns)
cur.execute(
f"INSERT INTO {table.name} (key,{columns}) VALUES (?, {value_slots})",
(
f"INSERT INTO {table.name}"
f" (key,{columns})"
f" VALUES (?, {value_slots})"
),
[key] + data)
self.connection().commit()
return self

# def save_config_item(self, key: str, item: str) -> None:
# self._save_data(CONFIG_ITEMS_TABLE, key, [item])
#
# def save_credentials(self, key: str, user: str, password: str) -> None:
# self._save_data(SECRETS_TABLE, key, [user, password])
with self._cursor() as cur:
if entry_exists(cur):
update(cur)
else:
insert(cur)
return self

def save(self, key: str, data: Union[str, Credentials]) -> "Secrets":
"""key represents a system, service, or application"""
Expand All @@ -131,10 +142,11 @@ def save(self, key: str, data: Union[str, Credentials]) -> "Secrets":

def _data(self, table: Table, key: str) -> Optional[list[str]]:
columns = ", ".join(table.columns)
res = self._cursor().execute(
f"SELECT {columns} FROM {table.name} WHERE key=?",
[key])
return res.fetchone() if res else None
with self._cursor() as cur:
res = cur.execute(
f"SELECT {columns} FROM {table.name} WHERE key=?",
[key])
return res.fetchone() if res else None

def credentials(self, key: str) -> Optional[Credentials]:
row = self._data(SECRETS_TABLE, key)
Expand Down
4 changes: 2 additions & 2 deletions secret-store/test/test_secret_store.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import pytest
from pathlib import Path
from secret_store import Credentials, PotentiallyIncorrectMasterPassword, Secrets
from secret_store import Credentials, InvalidPassword, Secrets
from sqlcipher3 import dbapi2 as sqlcipher


Expand Down Expand Up @@ -62,6 +62,6 @@ def test_wrong_password(sample_file):
secrets = Secrets(sample_file, "correct password")
secrets.save("key", Credentials("usr", "pass")).close()
invalid = Secrets(sample_file, "wrong password")
with pytest.raises(PotentiallyIncorrectMasterPassword) as ex:
with pytest.raises(InvalidPassword) as ex:
invalid.credentials("key")
assert "master password is incorrect" in str(ex.value)

0 comments on commit 5c73ee9

Please sign in to comment.