From 973740a3b1dd488bbd547dfbdbb01c0e00f96114 Mon Sep 17 00:00:00 2001 From: jahwag <540380+jahwag@users.noreply.github.com> Date: Sat, 14 Sep 2024 09:21:26 +0200 Subject: [PATCH] feat: use ssh-agent to secure sessionKey at rest --- .claudesync/config.local.json | 7 - .gitignore | 1 + pyproject.toml | 9 +- requirements.txt | 3 +- .../configmanager/file_config_manager.py | 47 +++-- src/claudesync/session_key_manager.py | 182 ++++++++++++++++++ 6 files changed, 225 insertions(+), 24 deletions(-) delete mode 100644 .claudesync/config.local.json create mode 100644 src/claudesync/session_key_manager.py diff --git a/.claudesync/config.local.json b/.claudesync/config.local.json deleted file mode 100644 index 99af1c5..0000000 --- a/.claudesync/config.local.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "active_provider": "claude.ai", - "active_organization_id": "e731f9fc-edb4-420d-aa35-952e2ce77137", - "active_project_id": "726c9cf7-394f-43f7-aced-bfa64ad2e1fb", - "active_project_name": "ClaudeSync", - "default_sync_category": "all_files" -} \ No newline at end of file diff --git a/.gitignore b/.gitignore index 421f4ab..ad98fe7 100644 --- a/.gitignore +++ b/.gitignore @@ -171,5 +171,6 @@ config.json claudesync.log claude_chats some_value +.claudesync ROADMAP.md \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index b3bab46..6c6bedc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,13 +1,13 @@ [project] name = "claudesync" -version = "0.5.8" +version = "0.6.0" authors = [ - {name = "Jahziah Wagner", email = "jahziah.wagner+pypi@gmail.com"}, + {name = "Jahziah Wagner", email = "540380+jahwag@users.noreply.github.com"}, ] description = "A tool to synchronize local files with Claude.ai projects" license = {file = "LICENSE"} readme = "README.md" -requires-python = ">=3.7" +requires-python = ">=3.10" classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", @@ -27,7 +27,8 @@ dependencies = [ "crontab>=1.0.1", "python-crontab>=3.2.0", "Brotli>=1.1.0", - "anthropic>=0.34.2" + "anthropic>=0.34.2", + "cryptography>=3.4.7" ] keywords = [ "sync", diff --git a/requirements.txt b/requirements.txt index 23a627a..fc27d9f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,5 @@ claudesync>=0.5.4 crontab>=1.0.1 python-crontab>=3.2.0 Brotli>=1.1.0 -anthropic>=0.34.2 \ No newline at end of file +anthropic>=0.34.2 +cryptography>=3.4.7 \ No newline at end of file diff --git a/src/claudesync/configmanager/file_config_manager.py b/src/claudesync/configmanager/file_config_manager.py index 38d5d8e..3c4b171 100644 --- a/src/claudesync/configmanager/file_config_manager.py +++ b/src/claudesync/configmanager/file_config_manager.py @@ -1,9 +1,11 @@ import json +import logging import os from datetime import datetime from pathlib import Path from claudesync.configmanager.base_config_manager import BaseConfigManager +from claudesync.session_key_manager import SessionKeyManager class FileConfigManager(BaseConfigManager): @@ -29,6 +31,7 @@ def __init__(self): self.local_config = {} self.local_config_dir = None self._load_local_config() + self.session_key_manager = SessionKeyManager() def _load_global_config(self): """ @@ -169,42 +172,62 @@ def set_session_key(self, provider, session_key, expiry): session_key (str): The session key to set. expiry (datetime): The expiry datetime for the session key. """ - self.global_config_dir.mkdir(parents=True, exist_ok=True) - provider_key_file = self.global_config_dir / f"{provider}.key" - with open(provider_key_file, "w") as f: - json.dump( - {"session_key": session_key, "session_key_expiry": expiry.isoformat()}, - f, + try: + encrypted_session_key, encryption_method = ( + self.session_key_manager.encrypt_session_key(provider, session_key) ) - def get_session_key(self, providerName): + self.global_config_dir.mkdir(parents=True, exist_ok=True) + provider_key_file = self.global_config_dir / f"{provider}.key" + with open(provider_key_file, "w") as f: + json.dump( + { + "session_key": encrypted_session_key, + "session_key_encryption_method": encryption_method, + "session_key_expiry": expiry.isoformat(), + }, + f, + ) + except RuntimeError as e: + logging.error(f"Failed to encrypt session key: {str(e)}") + raise + + def get_session_key(self, provider): """ Retrieves the session key for the specified provider if it's still valid. Args: - providerName (str): The name of the provider. + provider (str): The name of the provider. Returns: tuple: A tuple containing the session key and expiry if valid, (None, None) otherwise. """ - provider_key_file = self.global_config_dir / f"{providerName}.key" + provider_key_file = self.global_config_dir / f"{provider}.key" if not provider_key_file.exists(): return None, None with open(provider_key_file, "r") as f: data = json.load(f) - session_key = data.get("session_key") + encrypted_key = data.get("session_key") + encryption_method = data.get("session_key_encryption_method") expiry_str = data.get("session_key_expiry") - if not session_key or not expiry_str: + if not encrypted_key or not expiry_str: return None, None expiry = datetime.fromisoformat(expiry_str) if datetime.now() > expiry: return None, None - return session_key, expiry + try: + session_key = self.session_key_manager.decrypt_session_key( + provider, encryption_method, encrypted_key + ) + return session_key, expiry + except RuntimeError as e: + logging.error(f"Failed to decrypt session key: {str(e)}") + return None, None def add_file_category(self, category_name, description, patterns): """ diff --git a/src/claudesync/session_key_manager.py b/src/claudesync/session_key_manager.py new file mode 100644 index 0000000..2b0f579 --- /dev/null +++ b/src/claudesync/session_key_manager.py @@ -0,0 +1,182 @@ +import os +import subprocess +import tempfile +import base64 +import logging +from pathlib import Path +from cryptography.fernet import Fernet +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC + + +class SessionKeyManager: + def __init__(self): + self.ssh_key_path = self._find_ssh_key() + self.logger = logging.getLogger(__name__) + + def _find_ssh_key(self): + ssh_dir = Path.home() / ".ssh" + key_names = ["id_ed25519", "id_rsa", "id_ecdsa"] + for key_name in key_names: + key_path = ssh_dir / key_name + if key_path.exists(): + return str(key_path) + return input("Enter the full path to your SSH private key: ") + + def _get_key_type(self): + try: + result = subprocess.run( + ["ssh-keygen", "-l", "-f", self.ssh_key_path], + capture_output=True, + text=True, + check=True, + ) + output = result.stdout.lower() + if "rsa" in output: + return "rsa" + elif "ecdsa" in output: + return "ecdsa" + elif "ed25519" in output: + return "ed25519" + else: + raise ValueError(f"Unsupported key type for {self.ssh_key_path}") + except subprocess.CalledProcessError as e: + self.logger.error(f"Failed to determine key type: {e}") + raise RuntimeError( + "Failed to determine SSH key type. Make sure the key file is valid and accessible." + ) + + def _derive_key_from_ssh_key(self): + with open(self.ssh_key_path, "rb") as key_file: + ssh_key_data = key_file.read() + + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=b"claudesync", # Using a fixed salt; consider using a secure random salt in production + iterations=100000, + ) + key = base64.urlsafe_b64encode(kdf.derive(ssh_key_data)) + return key + + def encrypt_session_key(self, provider, session_key): + key_type = self._get_key_type() + + if key_type == "rsa": + return self._encrypt_rsa(session_key) + else: # For ed25519 and ecdsa + return self._encrypt_symmetric(session_key) + + def _encrypt_rsa(self, session_key): + temp_file_path = None + pub_key_file_path = None + try: + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file: + temp_file.write(session_key) + temp_file_path = temp_file.name + + result = subprocess.run( + ["ssh-keygen", "-f", self.ssh_key_path, "-e", "-m", "PKCS8"], + capture_output=True, + text=True, + check=True, + ) + public_key = result.stdout + + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as pub_key_file: + pub_key_file.write(public_key) + pub_key_file_path = pub_key_file.name + + encrypted_output = subprocess.run( + [ + "openssl", + "pkeyutl", + "-encrypt", + "-pubin", + "-inkey", + pub_key_file_path, + "-in", + temp_file_path, + "-pkeyopt", + "rsa_padding_mode:oaep", + "-pkeyopt", + "rsa_oaep_md:sha256", + ], + capture_output=True, + check=True, + ) + + encrypted_session_key = base64.b64encode(encrypted_output.stdout).decode( + "utf-8" + ) + + return encrypted_session_key, "rsa" + except subprocess.CalledProcessError as e: + self.logger.error(f"Encryption failed: {e}") + raise RuntimeError( + "Failed to encrypt session key. Check if openssl and ssh-keygen are installed and the SSH key is valid." + ) + finally: + if temp_file_path and os.path.exists(temp_file_path): + os.unlink(temp_file_path) + if pub_key_file_path and os.path.exists(pub_key_file_path): + os.unlink(pub_key_file_path) + + def _encrypt_symmetric(self, session_key): + key = self._derive_key_from_ssh_key() + f = Fernet(key) + encrypted_session_key = f.encrypt(session_key.encode()).decode() + return encrypted_session_key, "symmetric" + + def decrypt_session_key(self, provider, encryption_method, encrypted_session_key): + if not encrypted_session_key or not encryption_method: + return None + + if encryption_method == "rsa": + return self._decrypt_rsa(encrypted_session_key) + elif encryption_method == "symmetric": + return self._decrypt_symmetric(encrypted_session_key) + else: + raise ValueError(f"Unknown encryption method: {encryption_method}") + + def _decrypt_rsa(self, encrypted_session_key): + temp_file_path = None + try: + with tempfile.NamedTemporaryFile(mode="wb+", delete=False) as temp_file: + temp_file.write(base64.b64decode(encrypted_session_key)) + temp_file_path = temp_file.name + + decrypted_output = subprocess.run( + [ + "openssl", + "pkeyutl", + "-decrypt", + "-inkey", + self.ssh_key_path, + "-in", + temp_file_path, + "-pkeyopt", + "rsa_padding_mode:oaep", + "-pkeyopt", + "rsa_oaep_md:sha256", + ], + capture_output=True, + text=True, + check=True, + ) + + return decrypted_output.stdout.strip() + + except subprocess.CalledProcessError as e: + self.logger.error(f"Decryption failed: {e}") + raise RuntimeError( + "Failed to decrypt session key. Make sure the SSH key is valid and matches the one used for encryption." + ) + finally: + if temp_file_path and os.path.exists(temp_file_path): + os.unlink(temp_file_path) + + def _decrypt_symmetric(self, encrypted_session_key): + key = self._derive_key_from_ssh_key() + f = Fernet(key) + return f.decrypt(encrypted_session_key.encode()).decode()