Skip to content

Commit

Permalink
Add keyed encryption
Browse files Browse the repository at this point in the history
  • Loading branch information
bboonstra committed Oct 23, 2024
1 parent 0b9d300 commit 4b259bc
Show file tree
Hide file tree
Showing 4 changed files with 399 additions and 102 deletions.
245 changes: 179 additions & 66 deletions effortless/effortless.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
import shutil
from effortless.configuration import EffortlessConfig
from effortless.search import Query
from cryptography.fernet import Fernet, InvalidToken

logger = logging.getLogger(__name__)


class EffortlessDB:
def __init__(self, db_name: str = "db"):
def __init__(self, db_name: str = "db", encryption_key: Optional[str] = None):
"""
Initialize an EffortlessDB instance.
Expand All @@ -23,13 +24,16 @@ def __init__(self, db_name: str = "db"):
Args:
db_name (str, optional): The name of the database. Defaults to "db".
encryption_key (str, optional): The encryption key for the database.
"""
self._config = EffortlessConfig()
self.set_storage(db_name)
self._autoconfigure()
self._operation_count = 0
self._backup_thread = None
self._encryption_key = None
self.set_encryption_key(encryption_key)

@property
def config(self):
Expand All @@ -45,6 +49,77 @@ def config(self, new_config: EffortlessConfig):
data["headers"] = self._config.to_dict()
self._write_db(data)

@property
def encryption_key(self) -> Optional[str]:
    """
    Return the encryption key currently held by this instance.

    Returns:
        Optional[str]: The stored key, or None when no key is set.
    """
    return self._encryption_key

@encryption_key.setter
def encryption_key(self, new_key: Optional[str]) -> None:
    """
    Assign a new encryption key using attribute syntax.

    This is a thin wrapper: the value is passed straight to
    set_encryption_key(), so any validation, re-encryption and errors
    come from that method.

    Args:
        new_key (Optional[str]): The key to set, or None.
    """
    self.set_encryption_key(new_key)

def set_encryption_key(self, new_key: Optional[str]) -> None:
    """
    Set the encryption key, re-encrypting the database when it is already
    flagged as encrypted.

    Behaviour by case:
      * ``new_key is None``: the in-memory key is cleared and the method
        returns. Nothing else is touched here (the config's ``encrypted``
        flag keeps whatever value it had).
      * ``self.config.encrypted`` is true: the database is re-encrypted
        through ``_reencrypt_db(old_key, new_key)``. If that call fails
        for any reason, the previous key is restored before the
        exception propagates.
      * otherwise: ``self.config.encrypted`` is set to True. This method
        itself performs no re-encryption in that case.

    Args:
        new_key (Optional[str]): The new encryption key, or None to clear
            the in-memory key.

    Raises:
        TypeError: If ``new_key`` is neither None nor a string.
        ValueError: Propagated from ``_reencrypt_db`` when the database
            cannot be decrypted with either the old or the new key.
        Exception: Any other error raised by ``_reencrypt_db`` is
            propagated unchanged, after the previous key is restored.
    """
    if new_key is None:
        self._encryption_key = None
        return

    if not isinstance(new_key, str):
        raise TypeError("Encryption key must be a string")

    previous_key = self._encryption_key
    # The new key has to be installed *before* re-encrypting: the write
    # step encrypts with self._encryption_key, not with an argument.
    self._encryption_key = new_key

    if not self.config.encrypted:
        self.config.encrypted = True
        return

    try:
        self._reencrypt_db(previous_key, new_key)
    except Exception:
        # Roll back on every failure, not just ValueError: an I/O or
        # parsing error would otherwise leave this instance holding a key
        # that the on-disk data was never encrypted with.
        self._encryption_key = previous_key
        raise

def _reencrypt_db(self, old_key: Optional[str], new_key: str) -> None:
    """
    Read the database with the old or new key and write it back encrypted.

    Internal helper used when the encryption key changes. The read step is
    given both keys to try, in the order (old_key, new_key). The write
    step is called with encryption forced on and receives no key argument
    from here.

    Args:
        old_key (Optional[str]): The previous encryption key, if any.
        new_key (str): The new encryption key.

    Raises:
        ValueError: If unable to decrypt the database with either the old
            or new key.
    """
    candidate_keys = [old_key, new_key]
    self._write_db(self._read_db(try_keys=candidate_keys), force_encrypt=True)

@staticmethod
def default_db():
"""
Expand Down Expand Up @@ -286,7 +361,9 @@ def wipe(self, wipe_readonly: bool = False) -> None:
)
self._update_config()

def _read_db(self) -> Dict[str, Any]:
def _read_db(
self, try_keys: Optional[List[Optional[str]]] = None
) -> Dict[str, Any]:
"""
Read the contents of the database file.
Expand Down Expand Up @@ -317,20 +394,43 @@ def _read_db(self) -> Dict[str, Any]:
headers = data["headers"]
content = data["content"]

if headers.get("compressed"):
content = self._decompress_data(content)

if headers.get("encrypted"):
content = self._decrypt_data(
content if isinstance(content, str) else json.dumps(content)
)
if try_keys is None:
try_keys = [self._encryption_key]

decrypted = False
if isinstance(content, str):
for key in try_keys:
if key is not None:
try:
content = self._decrypt_data(content, key)
decrypted = True
break
except InvalidToken:
continue
else:
logger.warning("Content is not encrypted despite encryption flag")

if not decrypted and headers.get("encrypted"):
raise ValueError("Unable to decrypt database with provided keys")

if headers.get("compressed"):
if isinstance(content, str):
content = self._decompress_data(content)
else:
logger.warning("Content is not compressed despite compression flag")

return {"headers": headers, "content": content}
except (IOError, json.JSONDecodeError) as e:
logger.error(f"Error reading database: {str(e)}")
raise

def _write_db(self, data: Dict[str, Any], write_in_readonly: bool = False) -> None:
def _write_db(
self,
data: Dict[str, Any],
write_in_readonly: bool = False,
force_encrypt: bool = False,
) -> None:
"""
Write data to the database file.
Expand All @@ -356,11 +456,16 @@ def _write_db(self, data: Dict[str, Any], write_in_readonly: bool = False) -> No
headers = data["headers"]
content = data["content"]

if headers.get("encrypted"):
content = self._encrypt_data(content)

if headers.get("compressed"):
content = self._compress_data(json.dumps(content))
content = self._compress_data(content)

if headers.get("encrypted") or force_encrypt:
if self._encryption_key is None:
raise ValueError(
"Encryption key is required to write encrypted data"
)
content = self._encrypt_data(content, self._encryption_key)
headers["encrypted"] = True

final_data = json.dumps(
{"headers": headers, "content": content}, indent=2
Expand Down Expand Up @@ -451,100 +556,79 @@ def _backup(self) -> bool:

return False

def _compress_data(self, data: Union[str, Dict[str, Any]]) -> str:
def _compress_data(self, data: Union[str, List[Dict[str, Any]]]) -> str:
"""
Compress the given data and return as a base64-encoded string.
Args:
data (Union[str, Dict[str, Any]]): The data to be compressed.
data (Union[str, List[Dict[str, Any]]]): The data to be compressed.
Returns:
str: A base64-encoded string of the compressed data.
"""
if isinstance(data, dict):
if isinstance(data, list):
data = json.dumps(data)
compressed = zlib.compress(data.encode())
return base64.b64encode(compressed).decode()

def _decompress_data(self, data: str) -> Dict[str, Any]:
def _decompress_data(self, data: str) -> List[Dict[str, Any]]:
"""
Decompress the given base64-encoded string data.
Decompress the given string data.
This method decodes the base64 string, decompresses it using zlib,
and then parses it as JSON.
This method decompresses the data using zlib, and then parses it as JSON.
Args:
data (str): The base64-encoded compressed data.
data (str): The compressed data as a string.
Returns:
Dict[str, Any]: The decompressed and parsed data.
List[Dict[str, Any]]: The decompressed and parsed data.
"""
compressed = base64.b64decode(data.encode())
return json.loads(zlib.decompress(compressed).decode())
decompressed = zlib.decompress(base64.b64decode(data))
return json.loads(decompressed.decode())

def _encrypt_data(self, data: Dict[str, Any]) -> str:
def _encrypt_data(self, data: Union[str, Dict[str, Any]], key: str) -> str:
"""
Encrypt the given data and return as a base64-encoded string.
Note: TODO This is a placeholder method and does not actually perform encryption.
It should be implemented with proper encryption algorithms in a production environment.
Encrypt the given data using the provided key.
Args:
data (Dict[str, Any]): The data to be encrypted.
data (Union[str, Dict[str, Any]]): The data to encrypt.
key (str): The encryption key.
Returns:
str: A base64-encoded string of the "encrypted" data.
str: The encrypted data as a string.
"""
# TODO: Implement actual encryption
return base64.b64encode(json.dumps(data).encode()).decode()
fernet = Fernet(self._get_fernet_key(key))
return fernet.encrypt(json.dumps(data).encode()).decode()

def _decrypt_data(self, data: str) -> Dict[str, Any]:
def _decrypt_data(self, data: str, key: str) -> Dict[str, Any]:
"""
Decrypt the given base64-encoded string data.
Note: TODO This is a placeholder method and does not actually perform decryption.
It should be implemented with proper decryption algorithms in a production environment.
Decrypt the given data using the provided key.
Args:
data (str): The base64-encoded "encrypted" data.
data (str): The encrypted data.
key (str): The decryption key.
Returns:
Dict[str, Any]: The "decrypted" and parsed data.
"""
# TODO: Implement actual decryption
return json.loads(base64.b64decode(data.encode()).decode())
Dict[str, Any]: The decrypted data.
def _encrypt_value(self, value: Any) -> Any:
"""
Encrypt a single value.
Note: TODO This is a placeholder method and does not actually perform encryption.
It should be implemented with proper encryption algorithms in a production environment.
Args:
value (Any): The value to be encrypted.
Returns:
Any: The "encrypted" value (currently unchanged).
Raises:
InvalidToken: If the decryption fails due to an invalid key.
"""
# TODO: Implement encryption
return value
fernet = Fernet(self._get_fernet_key(key))
return json.loads(fernet.decrypt(data.encode()).decode())

def _decrypt_value(self, value: Any) -> Any:
@staticmethod
def _get_fernet_key(key: str) -> bytes:
"""
Decrypt a single value.
Note: TODO This is a placeholder method and does not actually perform decryption.
It should be implemented with proper decryption algorithms in a production environment.
Generate a Fernet-compatible key from the given string key.
Args:
value (Any): The value to be decrypted.
key (str): The original key string.
Returns:
Any: The "decrypted" value (currently unchanged).
bytes: A URL-safe base64-encoded 32-byte key for Fernet.
"""
# TODO: Implement decryption
return value
return base64.urlsafe_b64encode(key.encode().ljust(32)[:32])

def update(self, update_data: Dict[str, Any], condition: Query) -> bool:
"""
Expand Down Expand Up @@ -657,5 +741,34 @@ def erase(self, condition: Query) -> int:

return removed_count

def search(self, query: Query) -> Optional[Dict[str, Any]]:
    """
    Return the single database entry matching the query, if there is one.

    Runs filter() with the given query and inspects how many entries come
    back: none yields None, exactly one yields that entry, and anything
    more is treated as an error.

    Args:
        query (Query): A Query object defining the search criteria.

    Returns:
        Optional[Dict[str, Any]]: The matching entry, or None if nothing
        matches.

    Raises:
        ValueError: If more than one entry matches the query.
    """
    hits = self.filter(query)
    hit_count = len(hits)

    if hit_count == 0:
        return None
    if hit_count == 1:
        return hits[0]

    raise ValueError(
        "More than one entry matches the given query. Use filter() if you want to retrieve multiple entries."
    )


# Instance created with the constructor's default arguments (no arguments are
# passed here). NOTE(review): this appears to sit at module level, in which
# case it is constructed as a side effect of importing the module — confirm.
db = EffortlessDB()

4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
"Topic :: Software Development :: Libraries :: Application Frameworks",
],
python_requires=">=3.9",
install_requires=[],
install_requires=[
"cryptography>=41.0.0",
],
keywords="database, effortless, simple storage, beginner, easy, db",
project_urls={
"Bug Tracker": "https://github.com/bboonstra/Effortless/issues",
Expand Down
Loading

0 comments on commit 4b259bc

Please sign in to comment.