Skip to content

Commit

Permalink
feat: Encrypt session keys with SSH key for improved security (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
jahwag authored Sep 14, 2024
1 parent e6b938d commit 30c7717
Show file tree
Hide file tree
Showing 6 changed files with 225 additions and 24 deletions.
7 changes: 0 additions & 7 deletions .claudesync/config.local.json

This file was deleted.

1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -171,5 +171,6 @@ config.json
claudesync.log
claude_chats
some_value
.claudesync

ROADMAP.md
9 changes: 5 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
[project]
name = "claudesync"
version = "0.5.8"
version = "0.6.0"
authors = [
{name = "Jahziah Wagner", email = "jahziah.wagner+pypi@gmail.com"},
{name = "Jahziah Wagner", email = "[email protected].com"},
]
description = "A tool to synchronize local files with Claude.ai projects"
license = {file = "LICENSE"}
readme = "README.md"
requires-python = ">=3.7"
requires-python = ">=3.10"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
Expand All @@ -27,7 +27,8 @@ dependencies = [
"crontab>=1.0.1",
"python-crontab>=3.2.0",
"Brotli>=1.1.0",
"anthropic>=0.34.2"
"anthropic>=0.34.2",
"cryptography>=3.4.7"
]
keywords = [
"sync",
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ claudesync>=0.5.4
crontab>=1.0.1
python-crontab>=3.2.0
Brotli>=1.1.0
anthropic>=0.34.2
anthropic>=0.34.2
cryptography>=3.4.7
47 changes: 35 additions & 12 deletions src/claudesync/configmanager/file_config_manager.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import json
import logging
import os
from datetime import datetime
from pathlib import Path

from claudesync.configmanager.base_config_manager import BaseConfigManager
from claudesync.session_key_manager import SessionKeyManager


class FileConfigManager(BaseConfigManager):
Expand All @@ -29,6 +31,7 @@ def __init__(self):
self.local_config = {}
self.local_config_dir = None
self._load_local_config()
self.session_key_manager = SessionKeyManager()

def _load_global_config(self):
"""
Expand Down Expand Up @@ -169,42 +172,62 @@ def set_session_key(self, provider, session_key, expiry):
session_key (str): The session key to set.
expiry (datetime): The expiry datetime for the session key.
"""
self.global_config_dir.mkdir(parents=True, exist_ok=True)
provider_key_file = self.global_config_dir / f"{provider}.key"
with open(provider_key_file, "w") as f:
json.dump(
{"session_key": session_key, "session_key_expiry": expiry.isoformat()},
f,
try:
encrypted_session_key, encryption_method = (
self.session_key_manager.encrypt_session_key(provider, session_key)
)

def get_session_key(self, providerName):
self.global_config_dir.mkdir(parents=True, exist_ok=True)
provider_key_file = self.global_config_dir / f"{provider}.key"
with open(provider_key_file, "w") as f:
json.dump(
{
"session_key": encrypted_session_key,
"session_key_encryption_method": encryption_method,
"session_key_expiry": expiry.isoformat(),
},
f,
)
except RuntimeError as e:
logging.error(f"Failed to encrypt session key: {str(e)}")
raise

def get_session_key(self, provider):
"""
Retrieves the session key for the specified provider if it's still valid.
Args:
providerName (str): The name of the provider.
provider (str): The name of the provider.
Returns:
tuple: A tuple containing the session key and expiry if valid, (None, None) otherwise.
"""
provider_key_file = self.global_config_dir / f"{providerName}.key"
provider_key_file = self.global_config_dir / f"{provider}.key"
if not provider_key_file.exists():
return None, None

with open(provider_key_file, "r") as f:
data = json.load(f)

session_key = data.get("session_key")
encrypted_key = data.get("session_key")
encryption_method = data.get("session_key_encryption_method")
expiry_str = data.get("session_key_expiry")

if not session_key or not expiry_str:
if not encrypted_key or not expiry_str:
return None, None

expiry = datetime.fromisoformat(expiry_str)
if datetime.now() > expiry:
return None, None

return session_key, expiry
try:
session_key = self.session_key_manager.decrypt_session_key(
provider, encryption_method, encrypted_key
)
return session_key, expiry
except RuntimeError as e:
logging.error(f"Failed to decrypt session key: {str(e)}")
return None, None

def add_file_category(self, category_name, description, patterns):
"""
Expand Down
182 changes: 182 additions & 0 deletions src/claudesync/session_key_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
import os
import subprocess
import tempfile
import base64
import logging
from pathlib import Path
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC


class SessionKeyManager:
def __init__(self):
self.ssh_key_path = self._find_ssh_key()
self.logger = logging.getLogger(__name__)

def _find_ssh_key(self):
ssh_dir = Path.home() / ".ssh"
key_names = ["id_ed25519", "id_rsa", "id_ecdsa"]
for key_name in key_names:
key_path = ssh_dir / key_name
if key_path.exists():
return str(key_path)
return input("Enter the full path to your SSH private key: ")

def _get_key_type(self):
try:
result = subprocess.run(
["ssh-keygen", "-l", "-f", self.ssh_key_path],
capture_output=True,
text=True,
check=True,
)
output = result.stdout.lower()
if "rsa" in output:
return "rsa"
elif "ecdsa" in output:
return "ecdsa"
elif "ed25519" in output:
return "ed25519"
else:
raise ValueError(f"Unsupported key type for {self.ssh_key_path}")
except subprocess.CalledProcessError as e:
self.logger.error(f"Failed to determine key type: {e}")
raise RuntimeError(
"Failed to determine SSH key type. Make sure the key file is valid and accessible."
)

def _derive_key_from_ssh_key(self):
with open(self.ssh_key_path, "rb") as key_file:
ssh_key_data = key_file.read()

kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=b"claudesync", # Using a fixed salt; consider using a secure random salt in production
iterations=100000,
)
key = base64.urlsafe_b64encode(kdf.derive(ssh_key_data))
return key

def encrypt_session_key(self, provider, session_key):
key_type = self._get_key_type()

if key_type == "rsa":
return self._encrypt_rsa(session_key)
else: # For ed25519 and ecdsa
return self._encrypt_symmetric(session_key)

def _encrypt_rsa(self, session_key):
temp_file_path = None
pub_key_file_path = None
try:
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
temp_file.write(session_key)
temp_file_path = temp_file.name

result = subprocess.run(
["ssh-keygen", "-f", self.ssh_key_path, "-e", "-m", "PKCS8"],
capture_output=True,
text=True,
check=True,
)
public_key = result.stdout

with tempfile.NamedTemporaryFile(mode="w+", delete=False) as pub_key_file:
pub_key_file.write(public_key)
pub_key_file_path = pub_key_file.name

encrypted_output = subprocess.run(
[
"openssl",
"pkeyutl",
"-encrypt",
"-pubin",
"-inkey",
pub_key_file_path,
"-in",
temp_file_path,
"-pkeyopt",
"rsa_padding_mode:oaep",
"-pkeyopt",
"rsa_oaep_md:sha256",
],
capture_output=True,
check=True,
)

encrypted_session_key = base64.b64encode(encrypted_output.stdout).decode(
"utf-8"
)

return encrypted_session_key, "rsa"
except subprocess.CalledProcessError as e:
self.logger.error(f"Encryption failed: {e}")
raise RuntimeError(
"Failed to encrypt session key. Check if openssl and ssh-keygen are installed and the SSH key is valid."
)
finally:
if temp_file_path and os.path.exists(temp_file_path):
os.unlink(temp_file_path)
if pub_key_file_path and os.path.exists(pub_key_file_path):
os.unlink(pub_key_file_path)

def _encrypt_symmetric(self, session_key):
key = self._derive_key_from_ssh_key()
f = Fernet(key)
encrypted_session_key = f.encrypt(session_key.encode()).decode()
return encrypted_session_key, "symmetric"

def decrypt_session_key(self, provider, encryption_method, encrypted_session_key):
if not encrypted_session_key or not encryption_method:
return None

if encryption_method == "rsa":
return self._decrypt_rsa(encrypted_session_key)
elif encryption_method == "symmetric":
return self._decrypt_symmetric(encrypted_session_key)
else:
raise ValueError(f"Unknown encryption method: {encryption_method}")

def _decrypt_rsa(self, encrypted_session_key):
temp_file_path = None
try:
with tempfile.NamedTemporaryFile(mode="wb+", delete=False) as temp_file:
temp_file.write(base64.b64decode(encrypted_session_key))
temp_file_path = temp_file.name

decrypted_output = subprocess.run(
[
"openssl",
"pkeyutl",
"-decrypt",
"-inkey",
self.ssh_key_path,
"-in",
temp_file_path,
"-pkeyopt",
"rsa_padding_mode:oaep",
"-pkeyopt",
"rsa_oaep_md:sha256",
],
capture_output=True,
text=True,
check=True,
)

return decrypted_output.stdout.strip()

except subprocess.CalledProcessError as e:
self.logger.error(f"Decryption failed: {e}")
raise RuntimeError(
"Failed to decrypt session key. Make sure the SSH key is valid and matches the one used for encryption."
)
finally:
if temp_file_path and os.path.exists(temp_file_path):
os.unlink(temp_file_path)

def _decrypt_symmetric(self, encrypted_session_key):
key = self._derive_key_from_ssh_key()
f = Fernet(key)
return f.decrypt(encrypted_session_key.encode()).decode()

0 comments on commit 30c7717

Please sign in to comment.