diff --git a/.gitignore b/.gitignore index b1b513e..7d8bcf8 100644 --- a/.gitignore +++ b/.gitignore @@ -168,4 +168,5 @@ __pycache__ # claude claude.sync config.json -claudesync.log \ No newline at end of file +claudesync.log +chats \ No newline at end of file diff --git a/README.md b/README.md index 06c77bf..b350ddf 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,7 @@ .o..P' `Y8P' ``` +[![Python package](https://github.com/jahwag/ClaudeSync/actions/workflows/python-package.yml/badge.svg)](https://github.com/jahwag/ClaudeSync/actions/workflows/python-package.yml) ![License](https://img.shields.io/badge/License-MIT-blue.svg) [![PyPI version](https://badge.fury.io/py/claudesync.svg)](https://badge.fury.io/py/claudesync) @@ -27,6 +28,9 @@ ClaudeSync bridges the gap between your local development environment and Claude - Seamless integration with your existing workflow - Optional two-way synchronization support - Configuration management through CLI +- Chat and artifact synchronization and management + +**Important Note**: ClaudeSync requires a Claude.ai Professional plan to function properly. Make sure you have an active Professional subscription before using this tool. ## Important Disclaimers @@ -85,6 +89,12 @@ ClaudeSync bridges the gap between your local development environment and Claude - List remote files: `claudesync ls` - Sync files: `claudesync sync` +### Chat Management +- List chats: `claudesync chat ls` +- Sync chats and artifacts: `claudesync chat sync` +- Delete chats: `claudesync chat rm` +- Delete all chats: `claudesync chat rm -a` + ### Configuration - View current status: `claudesync status` - Set configuration values: `claudesync config set <key> <value>` @@ -139,23 +149,21 @@ ClaudeSync offers two providers for interacting with the Claude.ai API: ### Troubleshooting #### 403 Forbidden Error -If you encounter a 403 Forbidden error when using ClaudeSync, it might be due to an issue with the session key or API access. As a workaround, you can try using the `claude.ai-curl` provider: +If you encounter a 403 Forbidden error when using ClaudeSync, it might be due to an issue with the session key or API access. Here are some steps to resolve this: -1. Ensure cURL is installed on your system (see note above for Windows users). - -2. Logout from your current session: +1. Ensure you have an active Claude.ai Professional plan subscription. +2. Try logging out and logging in again: ```bash claudesync api logout + claudesync api login claude.ai ``` - -3. Login using the claude.ai-curl provider: +3. If the issue persists, you can try using the claude.ai-curl provider as a workaround: ```bash + claudesync api logout claudesync api login claude.ai-curl ``` -4. Try your operation again. - -If the issue persists, please check your network connection and ensure that you have the necessary permissions to access Claude.ai. +If you continue to experience issues, please check your network connection and ensure that you have the necessary permissions to access Claude.ai.
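For orientation on the new chat features documented above: `claudesync chat sync` (and the chat step of the combined `claudesync sync`) writes each conversation under `<local_path>/chats/<chat-uuid>/`, with a `metadata.json`, one JSON file per message, and an `artifacts/` folder for anything extracted from assistant replies. Below is a minimal sketch of inspecting that layout after a sync, assuming the directory structure produced by `chat_sync.sync_chats` in this changeset; the project path is hypothetical.

```python
import json
from pathlib import Path

# Layout assumed from chat_sync.sync_chats in this changeset:
#   <local_path>/chats/<chat-uuid>/metadata.json
#   <local_path>/chats/<chat-uuid>/<message-uuid>.json
#   <local_path>/chats/<chat-uuid>/artifacts/<identifier>.<ext>
local_path = Path("~/my-claude-project").expanduser()  # hypothetical local_path

for chat_dir in sorted((local_path / "chats").iterdir()):
    if not chat_dir.is_dir():
        continue
    # Chat-level metadata saved by sync_chats
    metadata = json.loads((chat_dir / "metadata.json").read_text())
    # Every other .json file in the folder is one message of the conversation
    messages = [p for p in chat_dir.glob("*.json") if p.name != "metadata.json"]
    artifacts_dir = chat_dir / "artifacts"
    artifacts = list(artifacts_dir.iterdir()) if artifacts_dir.exists() else []
    print(f"{metadata.get('name', 'Unnamed')}: {len(messages)} messages, {len(artifacts)} artifacts")
```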
## Contributing diff --git a/pyproject.toml b/pyproject.toml index 9ca26be..cdb3d19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "claudesync" -version = "0.3.8" +version = "0.3.9" authors = [ {name = "Jahziah Wagner", email = "jahziah.wagner+pypi@gmail.com"}, ] @@ -26,6 +26,7 @@ dependencies = [ "pytest", "pytest-cov", "click_completion", + "tqdm", ] [project.urls] diff --git a/requirements.txt b/requirements.txt index 526b5ee..af10441 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ crontab>=1.0.1 setuptools>=65.5.1 pytest>=8.2.2 pytest-cov>=5.0.0 -click_completion>=0.5.2 \ No newline at end of file +click_completion>=0.5.2 +tqdm>=4.66.4 \ No newline at end of file diff --git a/src/claudesync/chat_sync.py b/src/claudesync/chat_sync.py new file mode 100644 index 0000000..2630c0f --- /dev/null +++ b/src/claudesync/chat_sync.py @@ -0,0 +1,162 @@ +import json +import logging +import os +import re + +from tqdm import tqdm + +from .exceptions import ConfigurationError + +logger = logging.getLogger(__name__) + + +def sync_chats(provider, config, sync_all=False): + """ + Synchronize chats and their artifacts from the remote source. + + This function fetches all chats for the active organization, saves their metadata, + messages, and extracts any artifacts found in the assistant's messages. + + Args: + provider: The API provider instance. + config: The configuration manager instance. + sync_all (bool): If True, sync all chats regardless of project. If False, only sync chats for the active project. + + Raises: + ConfigurationError: If required configuration settings are missing. + """ + # Get the local_path for chats + local_path = config.get("local_path") + if not local_path: + raise ConfigurationError( + "Local path not set. Use 'claudesync project select' or 'claudesync project create' to set it." + ) + + # Create chats directory within local_path + chat_destination = os.path.join(local_path, "chats") + os.makedirs(chat_destination, exist_ok=True) + + # Get the active organization ID + organization_id = config.get("active_organization_id") + if not organization_id: + raise ConfigurationError( + "No active organization set. Please select an organization." + ) + + # Get the active project ID + active_project_id = config.get("active_project_id") + if not active_project_id and not sync_all: + raise ConfigurationError( + "No active project set. Please select a project or use the -a flag to sync all chats." 
+ ) + + # Fetch all chats for the organization + logger.debug(f"Fetching chats for organization {organization_id}") + chats = provider.get_chat_conversations(organization_id) + logger.debug(f"Found {len(chats)} chats") + + # Process each chat + for chat in tqdm(chats, desc="Syncing chats"): + # Check if the chat belongs to the active project or if we're syncing all chats + if sync_all or ( + chat.get("project") and chat["project"].get("uuid") == active_project_id + ): + logger.info(f"Processing chat {chat['uuid']}") + chat_folder = os.path.join(chat_destination, chat["uuid"]) + os.makedirs(chat_folder, exist_ok=True) + + # Save chat metadata + with open(os.path.join(chat_folder, "metadata.json"), "w") as f: + json.dump(chat, f, indent=2) + + # Fetch full chat conversation + logger.debug(f"Fetching full conversation for chat {chat['uuid']}") + full_chat = provider.get_chat_conversation(organization_id, chat["uuid"]) + + # Process each message in the chat + for message in full_chat["chat_messages"]: + # Save the message + message_file = os.path.join(chat_folder, f"{message['uuid']}.json") + with open(message_file, "w") as f: + json.dump(message, f, indent=2) + + # Handle artifacts in assistant messages + if message["sender"] == "assistant": + artifacts = extract_artifacts(message["text"]) + if artifacts: + logger.info( + f"Found {len(artifacts)} artifacts in message {message['uuid']}" + ) + artifact_folder = os.path.join(chat_folder, "artifacts") + os.makedirs(artifact_folder, exist_ok=True) + for artifact in artifacts: + # Save each artifact + artifact_file = os.path.join( + artifact_folder, + f"{artifact['identifier']}.{get_file_extension(artifact['type'])}", + ) + with open(artifact_file, "w") as f: + f.write(artifact["content"]) + else: + logger.debug( + f"Skipping chat {chat['uuid']} as it doesn't belong to the active project" + ) + + logger.debug(f"Chats and artifacts synchronized to {chat_destination}") + + +def get_file_extension(artifact_type): + """ + Get the appropriate file extension for a given artifact type. + + Args: + artifact_type (str): The MIME type of the artifact. + + Returns: + str: The corresponding file extension. + """ + type_to_extension = { + "text/html": "html", + "application/vnd.ant.code": "txt", + "image/svg+xml": "svg", + "application/vnd.ant.mermaid": "mmd", + "application/vnd.ant.react": "jsx", + } + return type_to_extension.get(artifact_type, "txt") + + +def extract_artifacts(text): + """ + Extract artifacts from the given text. + + This function searches for antArtifact tags in the text and extracts + the artifact information, including identifier, type, and content. + + Args: + text (str): The text to search for artifacts. + + Returns: + list: A list of dictionaries containing artifact information. 
+ """ + artifacts = [] + + # Regular expression to match the tags and extract their attributes and content + pattern = re.compile( + r'([\s\S]*?)', + re.MULTILINE, + ) + + # Find all matches in the text + matches = pattern.findall(text) + + for match in matches: + identifier, artifact_type, title, content = match + artifacts.append( + { + "identifier": identifier, + "type": artifact_type, + "content": content.strip(), + } + ) + + return artifacts diff --git a/src/claudesync/cli/chat.py b/src/claudesync/cli/chat.py new file mode 100644 index 0000000..763885b --- /dev/null +++ b/src/claudesync/cli/chat.py @@ -0,0 +1,138 @@ +import click +import logging +from ..exceptions import ProviderError +from ..utils import handle_errors, validate_and_get_provider +from ..chat_sync import sync_chats + +logger = logging.getLogger(__name__) + + +@click.group() +def chat(): + """Manage and synchronize chats.""" + pass + + +@chat.command() +@click.pass_obj +@handle_errors +def sync(config): + """Synchronize chats and their artifacts from the remote source.""" + provider = validate_and_get_provider(config) + sync_chats(provider, config) + + +@chat.command() +@click.pass_obj +@handle_errors +def ls(config): + """List all chats.""" + provider = validate_and_get_provider(config) + organization_id = config.get("active_organization_id") + chats = provider.get_chat_conversations(organization_id) + + for chat in chats: + project = chat.get("project") + project_name = project.get("name") if project else "" + click.echo( + f"UUID: {chat.get('uuid', 'Unknown')}, " + f"Name: {chat.get('name', 'Unnamed')}, " + f"Project: {project_name}, " + f"Updated: {chat.get('updated_at', 'Unknown')}" + ) + + +@chat.command() +@click.option("-a", "--all", "delete_all", is_flag=True, help="Delete all chats") +@click.pass_obj +@handle_errors +def rm(config, delete_all): + """Delete chats. Use -a to delete all chats, or select a chat to delete.""" + provider = validate_and_get_provider(config) + organization_id = config.get("active_organization_id") + + if delete_all: + delete_all_chats(provider, organization_id) + else: + delete_single_chat(provider, organization_id) + + +def delete_chats(provider, organization_id, uuids): + """Delete a list of chats by their UUIDs.""" + try: + result = provider.delete_chat(organization_id, uuids) + return len(result), 0 + except ProviderError as e: + logger.error(f"Error deleting chats: {str(e)}") + click.echo(f"Error occurred while deleting chats: {str(e)}") + return 0, len(uuids) + + +def delete_all_chats(provider, organization_id): + """Delete all chats for the given organization.""" + if click.confirm("Are you sure you want to delete all chats?"): + total_deleted = 0 + with click.progressbar(length=100, label="Deleting chats") as bar: + while True: + chats = provider.get_chat_conversations(organization_id) + if not chats: + break + uuids_to_delete = [chat["uuid"] for chat in chats[:50]] + deleted, _ = delete_chats(provider, organization_id, uuids_to_delete) + total_deleted += deleted + bar.update(len(uuids_to_delete)) + click.echo(f"Chat deletion complete. 
Total chats deleted: {total_deleted}") + + +def delete_single_chat(provider, organization_id): + """Delete a single chat selected by the user.""" + chats = provider.get_chat_conversations(organization_id) + if not chats: + click.echo("No chats found.") + return + + display_chat_list(chats) + selected_chat = get_chat_selection(chats) + if selected_chat: + confirm_and_delete_chat(provider, organization_id, selected_chat) + + +def display_chat_list(chats): + """Display a list of chats to the user.""" + click.echo("Available chats:") + for idx, chat in enumerate(chats, 1): + project = chat.get("project") + project_name = project.get("name") if project else "" + click.echo( + f"{idx}. Name: {chat.get('name', 'Unnamed')}, " + f"Project: {project_name}, Updated: {chat.get('updated_at', 'Unknown')}" + ) + + +def get_chat_selection(chats): + """Get a valid chat selection from the user.""" + while True: + selection = click.prompt( + "Enter the number of the chat to delete (or 'q' to quit)", type=str + ) + if selection.lower() == "q": + return None + try: + selection = int(selection) + if 1 <= selection <= len(chats): + return chats[selection - 1] + click.echo("Invalid selection. Please try again.") + except ValueError: + click.echo("Invalid input. Please enter a number or 'q' to quit.") + + +def confirm_and_delete_chat(provider, organization_id, chat): + """Confirm deletion with the user and delete the selected chat.""" + if click.confirm( + f"Are you sure you want to delete the chat '{chat.get('name', 'Unnamed')}'?" + ): + deleted, _ = delete_chats(provider, organization_id, [chat["uuid"]]) + if deleted: + click.echo(f"Successfully deleted chat: {chat.get('name', 'Unnamed')}") + else: + click.echo(f"Failed to delete chat: {chat.get('name', 'Unnamed')}") diff --git a/src/claudesync/cli/main.py b/src/claudesync/cli/main.py index 925556f..769e389 100644 --- a/src/claudesync/cli/main.py +++ b/src/claudesync/cli/main.py @@ -2,12 +2,18 @@ import click_completion import click_completion.core +from claudesync.cli.chat import chat from claudesync.config_manager import ConfigManager from .api import api from .organization import organization from .project import project from .sync import ls, sync, schedule from .config import config +import logging + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) click_completion.init() @@ -55,6 +61,7 @@ def status(config): cli.add_command(sync) cli.add_command(schedule) cli.add_command(config) +cli.add_command(chat) if __name__ == "__main__": cli() diff --git a/src/claudesync/cli/sync.py b/src/claudesync/cli/sync.py index 70cfb83..a7e8d3a 100644 --- a/src/claudesync/cli/sync.py +++ b/src/claudesync/cli/sync.py @@ -7,6 +7,7 @@ from claudesync.utils import get_local_files from ..utils import handle_errors, validate_and_get_provider from ..syncmanager import SyncManager +from ..chat_sync import sync_chats @click.command() @@ -34,21 +35,21 @@ def ls(config): @click.pass_obj @handle_errors def sync(config): - """Synchronize local files with the active remote project.""" + """Synchronize both projects and chats.""" provider = validate_and_get_provider(config) - local_path = config.get("local_path") - - validate_local_path(local_path) + # Sync projects sync_manager = SyncManager(provider, config) remote_files = provider.list_files( sync_manager.active_organization_id, sync_manager.active_project_id ) - local_files = get_local_files(local_path) - + local_files = get_local_files(config.get("local_path")) 
sync_manager.sync(local_files, remote_files) + click.echo("Project sync completed successfully.") - click.echo("Sync completed successfully.") + # Sync chats + sync_chats(provider, config) + click.echo("Chat sync completed successfully.") def validate_local_path(local_path): diff --git a/src/claudesync/config_manager.py b/src/claudesync/config_manager.py index ee4ac2c..50472f3 100644 --- a/src/claudesync/config_manager.py +++ b/src/claudesync/config_manager.py @@ -4,10 +4,11 @@ class ConfigManager: """ - A class to manage configuration settings for the application. + A class to manage configuration settings for the ClaudeSync application. This class handles loading, saving, and accessing configuration settings from a JSON file. - It ensures that default values are set for certain keys if they are not present in the configuration file. + It ensures that default values are set for certain keys if they are not present in the configuration file, + and handles the expansion of user home directory paths. Attributes: config_dir (Path): The directory where the configuration file is stored. @@ -17,51 +18,54 @@ class ConfigManager: def __init__(self): """ - Initializes the ConfigManager instance by setting up the configuration directory and file paths, - and loading the current configuration from the file, applying default values as necessary. + Initializes the ConfigManager instance. + + Sets up the configuration directory and file paths, and loads the current configuration from the file. """ self.config_dir = Path.home() / ".claudesync" self.config_file = self.config_dir / "config.json" self.config = self._load_config() + def _get_default_config(self): + """ + Returns the default configuration dictionary. + + This method centralizes the default configuration settings, making it easier to manage and update defaults. + + Returns: + dict: The default configuration settings. + """ + return { + "log_level": "INFO", + "upload_delay": 0.5, + "max_file_size": 32 * 1024, # Default 32 KB + "two_way_sync": False, # Default to False + } + def _load_config(self): """ Loads the configuration from the JSON file, applying default values for missing keys. - If the configuration file does not exist, it creates the directory (if necessary) and returns a dictionary - with default values. + If the configuration file does not exist, + it creates the directory (if necessary) and returns the default configuration. + For existing configurations, it ensures all default values are present and expands user home directory paths. Returns: - dict: The loaded configuration with default values for missing keys. + dict: The loaded configuration with default values for missing keys and expanded paths. 
""" if not self.config_file.exists(): self.config_dir.mkdir(parents=True, exist_ok=True) - return { - "log_level": "INFO", - "upload_delay": 0.5, - "max_file_size": 32 * 1024, # Default 32 KB - "headers": { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:128.0) Gecko/20100101 Firefox/128.0", - "Origin": "https://claude.ai", - }, - "two_way_sync": False, # Default to False - } + return self._get_default_config() + with open(self.config_file, "r") as f: config = json.load(f) - # Ensure all default values are present - defaults = { - "log_level": "INFO", - "upload_delay": 0.5, - "max_file_size": 32 * 1024, - "headers": { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:128.0) Gecko/20100101 Firefox/128.0", - "Origin": "https://claude.ai", - }, - "two_way_sync": False, - } + defaults = self._get_default_config() for key, value in defaults.items(): if key not in config: config[key] = value + elif key == "chat_destination": + # Expand user home directory for path-based settings + config[key] = str(Path(config[key]).expanduser()) return config def _save_config(self): @@ -91,47 +95,16 @@ def set(self, key, value): """ Sets a configuration value and saves the configuration. + For path-based settings (chat_destination), this method expands the user's home directory. + Args: key (str): The key for the configuration setting to set. value (any): The value to set for the given key. - This method updates the configuration with the provided key-value pair and then saves the configuration - to the file. + This method updates the configuration with the provided key-value pair and then saves the configuration to the file. """ + if key == "chat_destination": + # Expand user home directory for path-based settings + value = str(Path(value).expanduser()) self.config[key] = value self._save_config() - - def update_headers(self, new_headers): - """ - Updates the headers configuration with new values. - - Args: - new_headers (dict): A dictionary containing the new header key-value pairs to update or add. - - This method updates the existing headers with the new values provided, adds any new headers, - and then saves the updated configuration to the file. - """ - self.config.setdefault("headers", {}).update(new_headers) - self._save_config() - - def get_headers(self): - """ - Retrieves the current headers configuration. - - Returns: - dict: The current headers configuration. - """ - return self.config.get("headers", {}) - - def update_cookies(self, new_cookies): - """ - Updates the cookies configuration with new values. - - Args: - new_cookies (dict): A dictionary containing the new cookie key-value pairs to update or add. - - This method updates the existing cookies with the new values provided, adds any new cookies, - and then saves the updated configuration to the file. - """ - self.config.setdefault("cookies", {}).update(new_cookies) - self._save_config() diff --git a/src/claudesync/providers/base_claude_ai.py b/src/claudesync/providers/base_claude_ai.py new file mode 100644 index 0000000..170b54b --- /dev/null +++ b/src/claudesync/providers/base_claude_ai.py @@ -0,0 +1,122 @@ +import click +from .base_provider import BaseProvider +from ..exceptions import ProviderError + + +class BaseClaudeAIProvider(BaseProvider): + BASE_URL = "https://claude.ai/api" + + def __init__(self, session_key=None): + self.session_key = session_key + + def login(self): + click.echo("To obtain your session key, please follow these steps:") + click.echo("1. 
Open your web browser and go to https://claude.ai") + click.echo("2. Log in to your Claude account if you haven't already") + click.echo("3. Once logged in, open your browser's developer tools:") + click.echo(" - Chrome/Edge: Press F12 or Ctrl+Shift+I (Cmd+Option+I on Mac)") + click.echo(" - Firefox: Press F12 or Ctrl+Shift+I (Cmd+Option+I on Mac)") + click.echo( + " - Safari: Enable developer tools in Preferences > Advanced, then press Cmd+Option+I" + ) + click.echo( + "4. In the developer tools, go to the 'Application' tab (Chrome/Edge) or 'Storage' tab (Firefox)" + ) + click.echo( + "5. In the left sidebar, expand 'Cookies' and select 'https://claude.ai'" + ) + click.echo("6. Find the cookie named 'sessionKey' and copy its value") + self.session_key = click.prompt("Please enter your sessionKey", type=str) + return self.session_key + + def get_organizations(self): + response = self._make_request("GET", "/organizations") + if not response: + raise ProviderError("Unable to retrieve organization information") + return [{"id": org["uuid"], "name": org["name"]} for org in response] + + def get_projects(self, organization_id, include_archived=False): + response = self._make_request( + "GET", f"/organizations/{organization_id}/projects" + ) + projects = [ + { + "id": project["uuid"], + "name": project["name"], + "archived_at": project.get("archived_at"), + } + for project in response + if include_archived or project.get("archived_at") is None + ] + return projects + + def list_files(self, organization_id, project_id): + response = self._make_request( + "GET", f"/organizations/{organization_id}/projects/{project_id}/docs" + ) + return [ + { + "uuid": file["uuid"], + "file_name": file["file_name"], + "content": file["content"], + "created_at": file["created_at"], + } + for file in response + ] + + def upload_file(self, organization_id, project_id, file_name, content): + data = {"file_name": file_name, "content": content} + return self._make_request( + "POST", f"/organizations/{organization_id}/projects/{project_id}/docs", data + ) + + def delete_file(self, organization_id, project_id, file_uuid): + return self._make_request( + "DELETE", + f"/organizations/{organization_id}/projects/{project_id}/docs/{file_uuid}", + ) + + def archive_project(self, organization_id, project_id): + data = {"is_archived": True} + return self._make_request( + "PUT", f"/organizations/{organization_id}/projects/{project_id}", data + ) + + def create_project(self, organization_id, name, description=""): + data = {"name": name, "description": description, "is_private": True} + return self._make_request( + "POST", f"/organizations/{organization_id}/projects", data + ) + + def get_chat_conversations(self, organization_id): + return self._make_request( + "GET", f"/organizations/{organization_id}/chat_conversations" + ) + + def get_published_artifacts(self, organization_id): + return self._make_request( + "GET", f"/organizations/{organization_id}/published_artifacts" + ) + + def get_chat_conversation(self, organization_id, conversation_id): + return self._make_request( + "GET", + f"/organizations/{organization_id}/chat_conversations/{conversation_id}?rendering_mode=raw", + ) + + def get_artifact_content(self, organization_id, artifact_uuid): + artifacts = self._make_request( + "GET", f"/organizations/{organization_id}/published_artifacts" + ) + for artifact in artifacts: + if artifact["published_artifact_uuid"] == artifact_uuid: + return artifact.get("artifact_content", "") + raise ProviderError(f"Artifact with UUID 
{artifact_uuid} not found") + + def delete_chat(self, organization_id, conversation_uuids): + endpoint = f"/organizations/{organization_id}/chat_conversations/delete_many" + data = {"conversation_uuids": conversation_uuids} + return self._make_request("POST", endpoint, data) + + def _make_request(self, method, endpoint, data=None): + raise NotImplementedError("This method should be implemented by subclasses") diff --git a/src/claudesync/providers/base_provider.py b/src/claudesync/providers/base_provider.py index a9d3f21..b83e156 100644 --- a/src/claudesync/providers/base_provider.py +++ b/src/claudesync/providers/base_provider.py @@ -43,3 +43,28 @@ def archive_project(self, organization_id, project_id): def create_project(self, organization_id, name, description=""): """Create a new project within a specified organization.""" pass + + @abstractmethod + def get_chat_conversations(self, organization_id): + """Retrieve a list of chat conversations for a specified organization.""" + pass + + @abstractmethod + def get_published_artifacts(self, organization_id): + """Retrieve a list of published artifacts for a specified organization.""" + pass + + @abstractmethod + def get_chat_conversation(self, organization_id, conversation_id): + """Retrieve the full content of a specific chat conversation.""" + pass + + @abstractmethod + def get_artifact_content(self, organization_id, artifact_uuid): + """Retrieve the full content of a specific published artifact.""" + pass + + @abstractmethod + def delete_chat(self, organization_id, conversation_uuids): + """Delete specified chats for a given organization.""" + pass diff --git a/src/claudesync/providers/claude_ai.py b/src/claudesync/providers/claude_ai.py index c76d83c..115ba27 100644 --- a/src/claudesync/providers/claude_ai.py +++ b/src/claudesync/providers/claude_ai.py @@ -1,120 +1,61 @@ -# src/claudesync/providers/claude_ai.py - import json import logging - -import click import requests - -from .base_provider import BaseProvider +from .base_claude_ai import BaseClaudeAIProvider from ..config_manager import ConfigManager from ..exceptions import ProviderError logger = logging.getLogger(__name__) -class ClaudeAIProvider(BaseProvider): - """ - A provider class for interacting with the Claude AI API. - - This class encapsulates methods for performing API operations such as logging in, retrieving organizations, - projects, and files, as well as uploading and deleting files. It uses a session key for authentication, - which can be obtained through the login method. - - Attributes: - BASE_URL (str): The base URL for the Claude AI API. - session_key (str, optional): The session key used for authentication with the API. - config (ConfigManager): An instance of ConfigManager to manage application configuration. - """ - - BASE_URL = "https://claude.ai/api" - +class ClaudeAIProvider(BaseClaudeAIProvider): def __init__(self, session_key=None): - """ - Initializes the ClaudeAIProvider instance. - - Sets up the session key if provided, initializes the configuration manager, and configures logging - based on the configuration. - - Args: - session_key (str, optional): The session key used for authentication. Defaults to None. - """ - self.session_key = session_key + super().__init__(session_key) self.config = ConfigManager() self._configure_logging() def _configure_logging(self): - """ - Configures the logging level for the application based on the configuration. 
- This method sets the global logging configuration to the level specified in the application's configuration. - If the log level is not specified in the configuration, it defaults to "INFO". - It ensures that all log messages across the application are handled at the configured log level. - """ + log_level = self.config.get("log_level", "INFO") + logging.basicConfig(level=getattr(logging, log_level)) + logger.setLevel(getattr(logging, log_level)) - log_level = self.config.get( - "log_level", "INFO" - ) # Retrieve log level from config, default to "INFO" - logging.basicConfig( - level=getattr(logging, log_level) - ) # Set global logging configuration - logger.setLevel( - getattr(logging, log_level) - ) # Set logger instance to the specified log level - - def _make_request(self, method, endpoint, **kwargs): + def _make_request(self, method, endpoint, data=None): url = f"{self.BASE_URL}{endpoint}" - headers = self.config.get_headers() - cookies = self.config.get("cookies", {}) - - # Update headers - headers.update( - { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:128.0) Gecko/20100101 Firefox/128.0", - "Origin": "https://claude.ai", - "Referer": "https://claude.ai/projects", - "Accept": "*/*", - "Accept-Encoding": "gzip, deflate, br, zstd", - "Accept-Language": "en-US,en;q=0.5", - "anthropic-client-sha": "unknown", - "anthropic-client-version": "unknown", - "Connection": "keep-alive", - "Sec-Fetch-Dest": "empty", - "Sec-Fetch-Mode": "cors", - "Sec-Fetch-Site": "same-origin", - } - ) - - # Merge cookies - cookies.update( - { - "sessionKey": self.session_key, - "CH-prefers-color-scheme": "dark", - "anthropic-consent-preferences": '{"analytics":true,"marketing":true}', - } - ) - - if "headers" in kwargs: - headers.update(kwargs.pop("headers")) + headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:128.0) Gecko/20100101 Firefox/128.0", + "Origin": "https://claude.ai", + "Referer": "https://claude.ai/projects", + "Accept": "*/*", + "Accept-Encoding": "gzip, deflate, zstd", + "Accept-Language": "en-US,en;q=0.5", + "anthropic-client-sha": "unknown", + "anthropic-client-version": "unknown", + "Connection": "keep-alive", + "Sec-Fetch-Dest": "empty", + "Sec-Fetch-Mode": "cors", + "Sec-Fetch-Site": "same-origin", + } + + cookies = { + "sessionKey": self.session_key, + "CH-prefers-color-scheme": "dark", + "anthropic-consent-preferences": '{"analytics":true,"marketing":true}', + } try: logger.debug(f"Making {method} request to {url}") logger.debug(f"Headers: {headers}") logger.debug(f"Cookies: {cookies}") - if "data" in kwargs: - logger.debug(f"Request data: {kwargs['data']}") + if data: + logger.debug(f"Request data: {data}") response = requests.request( - method, url, headers=headers, cookies=cookies, **kwargs + method, url, headers=headers, cookies=cookies, json=data ) logger.debug(f"Response status code: {response.status_code}") logger.debug(f"Response headers: {response.headers}") - logger.debug( - f"Response content: {response.text[:1000]}..." 
- ) # Log first 1000 characters of response - - # Update cookies with any new values from the response - self.config.update_cookies(response.cookies.get_dict()) + logger.debug(f"Response content: {response.text[:1000]}...") if response.status_code == 403: error_msg = ( @@ -132,12 +73,7 @@ def _make_request(self, method, endpoint, **kwargs): if not response.content: return None - try: - return response.json() - except json.JSONDecodeError as json_err: - logger.error(f"Failed to parse JSON response: {str(json_err)}") - logger.error(f"Response content: {response.text}") - raise ProviderError(f"Invalid JSON response from API: {str(json_err)}") + return response.json() except requests.RequestException as e: logger.error(f"Request failed: {str(e)}") @@ -146,207 +82,7 @@ def _make_request(self, method, endpoint, **kwargs): logger.error(f"Response headers: {e.response.headers}") logger.error(f"Response content: {e.response.text}") raise ProviderError(f"API request failed: {str(e)}") - - def login(self): - """ - Guides the user through obtaining a session key from the Claude AI website. - - This method provides step-by-step instructions for the user to log in to the Claude AI website, - access the developer tools of their browser, navigate to the cookies section, and retrieve the - 'sessionKey' cookie value. It then prompts the user to enter this session key, which is stored - in the instance for future requests. - - Returns: - str: The session key entered by the user. - """ - click.echo("To obtain your session key, please follow these steps:") - click.echo("1. Open your web browser and go to https://claude.ai") - click.echo("2. Log in to your Claude account if you haven't already") - click.echo("3. Once logged in, open your browser's developer tools:") - click.echo(" - Chrome/Edge: Press F12 or Ctrl+Shift+I (Cmd+Option+I on Mac)") - click.echo(" - Firefox: Press F12 or Ctrl+Shift+I (Cmd+Option+I on Mac)") - click.echo( - " - Safari: Enable developer tools in Preferences > Advanced, then press Cmd+Option+I" - ) - click.echo( - "4. In the developer tools, go to the 'Application' tab (Chrome/Edge) or 'Storage' tab (Firefox)" - ) - click.echo( - "5. In the left sidebar, expand 'Cookies' and select 'https://claude.ai'" - ) - click.echo("6. Find the cookie named 'sessionKey' and copy its value") - self.session_key = click.prompt("Please enter your sessionKey", type=str) - return self.session_key - - def get_organizations(self): - """ - Retrieves a list of organizations the user is a member of. - - This method sends a GET request to the '/bootstrap' endpoint to fetch account information, - including memberships in organizations. It parses the response to extract and return - organization IDs and names. - - Raises: - ProviderError: If the account information does not contain 'account' or 'memberships' keys, - indicating an issue with retrieving organization information. - - Returns: - list of dict: A list of dictionaries, each containing the 'id' and 'name' of an organization. - """ - organizations = self._make_request("GET", "/organizations") - if not organizations: - raise ProviderError("Unable to retrieve organization information") - - return [ - { - "id": org["uuid"], - "name": org["name"], - } - for org in organizations - ] - - def get_projects(self, organization_id, include_archived=False): - """ - Retrieves a list of projects for a specified organization. - - This method sends a GET request to fetch all projects associated with a given organization ID. 
- It then filters these projects based on the `include_archived` parameter. If `include_archived` - is False (default), only active projects are returned. If True, both active and archived projects - are returned. - - Args: - organization_id (str): The unique identifier for the organization. - include_archived (bool, optional): Flag to include archived projects in the result. Defaults to False. - - Returns: - list of dict: A list of dictionaries, each representing a project with its ID, name, and archival status. - """ - projects = self._make_request( - "GET", f"/organizations/{organization_id}/projects" - ) - filtered_projects = [ - { - "id": project["uuid"], - "name": project["name"], - "archived_at": project.get("archived_at"), - } - for project in projects - if include_archived or project.get("archived_at") is None - ] - return filtered_projects - - def list_files(self, organization_id, project_id): - """ - Lists all files within a specified project and organization. - - This method sends a GET request to the Claude AI API to retrieve all documents associated with a given project - within an organization. It then formats the response into a list of dictionaries, each representing a file with - its unique identifier, file name, content, and creation date. - - Args: - organization_id (str): The unique identifier for the organization. - project_id (str): The unique identifier for the project within the organization. - - Returns: - list of dict: A list of dictionaries, each containing details of a file such as its UUID, file name, - content, and the date it was created. - """ - files = self._make_request( - "GET", f"/organizations/{organization_id}/projects/{project_id}/docs" - ) - return [ - { - "uuid": file["uuid"], - "file_name": file["file_name"], - "content": file["content"], - "created_at": file["created_at"], - } - for file in files - ] - - def upload_file(self, organization_id, project_id, file_name, content): - """ - Uploads a file to a specified project within an organization. - - This method sends a POST request to the Claude AI API to upload a file with the given name and content - to a specified project within an organization. The file's metadata, including its name and content, - is sent as JSON in the request body. - - Args: - organization_id (str): The unique identifier for the organization. - project_id (str): The unique identifier for the project within the organization. - file_name (str): The name of the file to be uploaded. - content (str): The content of the file to be uploaded. - - Returns: - dict: The response from the server after the file upload operation, typically including details - about the uploaded file such as its ID, name, and a confirmation of the upload status. - """ - return self._make_request( - "POST", - f"/organizations/{organization_id}/projects/{project_id}/docs", - json={"file_name": file_name, "content": content}, - ) - - def delete_file(self, organization_id, project_id, file_uuid): - """ - Deletes a file from a specified project within an organization. - - This method sends a DELETE request to the Claude AI API to remove a file, identified by its UUID, - from a specified project within an organization. The organization and project are identified by their - respective unique identifiers. - - Args: - organization_id (str): The unique identifier for the organization. - project_id (str): The unique identifier for the project within the organization. - file_uuid (str): The unique identifier (UUID) of the file to be deleted. 
- - Returns: - dict: The response from the server after the file deletion operation, typically confirming the deletion. - """ - return self._make_request( - "DELETE", - f"/organizations/{organization_id}/projects/{project_id}/docs/{file_uuid}", - ) - - def archive_project(self, organization_id, project_id): - """ - Archives a specified project within an organization. - - This method sends a PUT request to the Claude AI API to change the archival status of a specified project - to archive. The project and organization are identified by their respective unique identifiers. - - Args: - organization_id (str): The unique identifier for the organization. - project_id (str): The unique identifier for the project within the organization. - - Returns: - dict: The response from the server after the archival operation, typically confirming the archival status. - """ - return self._make_request( - "PUT", - f"/organizations/{organization_id}/projects/{project_id}", - json={"is_archived": True}, - ) - - def create_project(self, organization_id, name, description=""): - """ - Creates a new project within a specified organization. - - This method sends a POST request to the Claude AI API to create a new project with the given name, - description, and sets it as private within the specified organization. The project's name, description, - and privacy status are sent as JSON in the request body. - - Args: - organization_id (str): The unique identifier for the organization. - name (str): The name of the project to be created. - description (str, optional): A description of the project. Defaults to an empty string. - - Returns: - dict: The response from the server after the project creation operation, typically including details - about the created project such as its ID, name, and a confirmation of the creation status. - """ - data = {"name": name, "description": description, "is_private": True} - return self._make_request( - "POST", f"/organizations/{organization_id}/projects", json=data - ) + except json.JSONDecodeError as json_err: + logger.error(f"Failed to parse JSON response: {str(json_err)}") + logger.error(f"Response content: {response.text}") + raise ProviderError(f"Invalid JSON response from API: {str(json_err)}") diff --git a/src/claudesync/providers/claude_ai_curl.py b/src/claudesync/providers/claude_ai_curl.py index b1a3d57..7dc792b 100644 --- a/src/claudesync/providers/claude_ai_curl.py +++ b/src/claudesync/providers/claude_ai_curl.py @@ -1,52 +1,11 @@ import json import subprocess -import click -from .base_provider import BaseProvider +from .base_claude_ai import BaseClaudeAIProvider from ..exceptions import ProviderError -class ClaudeAICurlProvider(BaseProvider): - """ - A provider class for interacting with the Claude AI API using cURL. - - This class encapsulates methods for performing API operations such as logging in, retrieving organizations, - projects, and files, as well as uploading and deleting files. It uses cURL commands for HTTP requests and - a session key for authentication. - - Attributes: - BASE_URL (str): The base URL for the Claude AI API. - session_key (str, optional): The session key used for authentication with the API. - """ - - BASE_URL = "https://claude.ai/api" - - def __init__(self, session_key=None): - """ - Initializes the ClaudeAICurlProvider instance. - - Args: - session_key (str, optional): The session key used for authentication. Defaults to None. 
- """ - self.session_key = session_key - - def _execute_curl(self, method, endpoint, data=None): - """ - Executes a cURL command to make an HTTP request to the Claude AI API. - - This method constructs and executes a cURL command based on the provided method, endpoint, and data. - It handles the response and potential errors from the cURL execution. - - Args: - method (str): The HTTP method for the request (e.g., "GET", "POST", "PUT", "DELETE"). - endpoint (str): The API endpoint to call, relative to the BASE_URL. - data (dict, optional): The data to send with the request for POST or PUT methods. Defaults to None. - - Returns: - dict: The JSON-decoded response from the API. - - Raises: - ProviderError: If the cURL command fails, returns no output, or the response cannot be parsed as JSON. - """ +class ClaudeAICurlProvider(BaseClaudeAIProvider): + def _make_request(self, method, endpoint, data=None): url = f"{self.BASE_URL}{endpoint}" headers = [ "-H", @@ -63,7 +22,7 @@ def _execute_curl(self, method, endpoint, data=None): "--compressed", "-s", "-S", - ] # Add -s for silent mode, -S to show errors + ] command.extend(headers) if method != "GET": @@ -98,195 +57,3 @@ def _execute_curl(self, method, endpoint, data=None): "This might be due to non-UTF-8 characters in the response." ) raise ProviderError(error_message) - - def login(self): - """ - Guides the user through obtaining a session key from the Claude AI website. - - This method provides step-by-step instructions for the user to log in to the Claude AI website, - access the developer tools of their browser, navigate to the cookies section, and retrieve the - 'sessionKey' cookie value. It then prompts the user to enter this session key, which is stored - in the instance for future requests. - - Returns: - str: The session key entered by the user. - """ - click.echo("To obtain your session key, please follow these steps:") - click.echo("1. Open your web browser and go to https://claude.ai") - click.echo("2. Log in to your Claude account if you haven't already") - click.echo("3. Once logged in, open your browser's developer tools:") - click.echo(" - Chrome/Edge: Press F12 or Ctrl+Shift+I (Cmd+Option+I on Mac)") - click.echo(" - Firefox: Press F12 or Ctrl+Shift+I (Cmd+Option+I on Mac)") - click.echo( - " - Safari: Enable developer tools in Preferences > Advanced, then press Cmd+Option+I" - ) - click.echo( - "4. In the developer tools, go to the 'Application' tab (Chrome/Edge) or 'Storage' tab (Firefox)" - ) - click.echo( - "5. In the left sidebar, expand 'Cookies' and select 'https://claude.ai'" - ) - click.echo("6. Find the cookie named 'sessionKey' and copy its value") - self.session_key = click.prompt("Please enter your sessionKey", type=str) - return self.session_key - - def get_organizations(self): - """ - Retrieves a list of organizations the user is a member of. - - This method sends a GET request to the '/organizations' endpoint to fetch account information, - including memberships in organizations. It parses the response to extract and return - organization IDs and names. - - Returns: - list of dict: A list of dictionaries, each containing the 'id' and 'name' of an organization. - - Raises: - ProviderError: If there's an issue with retrieving organization information. 
- """ - response = self._execute_curl("GET", "/organizations") - if not response: - raise ProviderError("Unable to retrieve organization information") - return [{"id": org["uuid"], "name": org["name"]} for org in response] - - def get_projects(self, organization_id, include_archived=False): - """ - Retrieves a list of projects for a specified organization. - - This method sends a GET request to fetch all projects associated with a given organization ID. - It then filters these projects based on the `include_archived` parameter. - - Args: - organization_id (str): The unique identifier for the organization. - include_archived (bool, optional): Flag to include archived projects in the result. Defaults to False. - - Returns: - list of dict: A list of dictionaries, each representing a project with its ID, name, and archival status. - """ - response = self._execute_curl( - "GET", f"/organizations/{organization_id}/projects" - ) - projects = [ - { - "id": project["uuid"], - "name": project["name"], - "archived_at": project.get("archived_at"), - } - for project in response - if include_archived or project.get("archived_at") is None - ] - return projects - - def list_files(self, organization_id, project_id): - """ - Lists all files within a specified project and organization. - - This method sends a GET request to the Claude AI API to retrieve all documents associated with a given project - within an organization. It then formats the response into a list of dictionaries, each representing a file with - its unique identifier, file name, content, and creation date. - - Args: - organization_id (str): The unique identifier for the organization. - project_id (str): The unique identifier for the project within the organization. - - Returns: - list of dict: A list of dictionaries, each containing details of a file such as its UUID, file name, - content, and the date it was created. - """ - response = self._execute_curl( - "GET", f"/organizations/{organization_id}/projects/{project_id}/docs" - ) - return [ - { - "uuid": file["uuid"], - "file_name": file["file_name"], - "content": file["content"], - "created_at": file["created_at"], - } - for file in response - ] - - def upload_file(self, organization_id, project_id, file_name, content): - """ - Uploads a file to a specified project within an organization. - - This method sends a POST request to the Claude AI API to upload a file with the given name and content - to a specified project within an organization. The file's metadata, including its name and content, - is sent as JSON in the request body. - - Args: - organization_id (str): The unique identifier for the organization. - project_id (str): The unique identifier for the project within the organization. - file_name (str): The name of the file to be uploaded. - content (str): The content of the file to be uploaded. - - Returns: - dict: The response from the server after the file upload operation, typically including details - about the uploaded file such as its ID, name, and a confirmation of the upload status. - """ - data = {"file_name": file_name, "content": content} - return self._execute_curl( - "POST", f"/organizations/{organization_id}/projects/{project_id}/docs", data - ) - - def delete_file(self, organization_id, project_id, file_uuid): - """ - Deletes a file from a specified project within an organization. - - This method sends a DELETE request to the Claude AI API to remove a file, identified by its UUID, - from a specified project within an organization. 
The organization and project are identified by their - respective unique identifiers. - - Args: - organization_id (str): The unique identifier for the organization. - project_id (str): The unique identifier for the project within the organization. - file_uuid (str): The unique identifier (UUID) of the file to be deleted. - - Returns: - dict: The response from the server after the file deletion operation, typically confirming the deletion. - """ - return self._execute_curl( - "DELETE", - f"/organizations/{organization_id}/projects/{project_id}/docs/{file_uuid}", - ) - - def archive_project(self, organization_id, project_id): - """ - Archives a specified project within an organization. - - This method sends a PUT request to the Claude AI API to change the archival status of a specified project - to archive. The project and organization are identified by their respective unique identifiers. - - Args: - organization_id (str): The unique identifier for the organization. - project_id (str): The unique identifier for the project within the organization. - - Returns: - dict: The response from the server after the archival operation, typically confirming the archival status. - """ - data = {"is_archived": True} - return self._execute_curl( - "PUT", f"/organizations/{organization_id}/projects/{project_id}", data - ) - - def create_project(self, organization_id, name, description=""): - """ - Creates a new project within a specified organization. - - This method sends a POST request to the Claude AI API to create a new project with the given name, - description, and sets it as private within the specified organization. The project's name, description, - and privacy status are sent as JSON in the request body. - - Args: - organization_id (str): The unique identifier for the organization. - name (str): The name of the project to be created. - description (str, optional): A description of the project. Defaults to an empty string. - - Returns: - dict: The response from the server after the project creation operation, typically including details - about the created project such as its ID, name, and a confirmation of the creation status. 
- """ - data = {"name": name, "description": description, "is_private": True} - return self._execute_curl( - "POST", f"/organizations/{organization_id}/projects", data - ) diff --git a/src/claudesync/syncmanager.py b/src/claudesync/syncmanager.py index 067488d..e168713 100644 --- a/src/claudesync/syncmanager.py +++ b/src/claudesync/syncmanager.py @@ -3,6 +3,7 @@ from datetime import datetime, timezone import click +from tqdm import tqdm from claudesync.utils import compute_md5_hash @@ -48,48 +49,39 @@ def sync(self, local_files, remote_files): remote_files_to_delete = set(rf["file_name"] for rf in remote_files) synced_files = set() - self.sync_local_to_remote( - local_files, remote_files, remote_files_to_delete, synced_files - ) + with tqdm(total=len(local_files), desc="Syncing local to remote") as pbar: + for local_file, local_checksum in local_files.items(): + remote_file = next( + (rf for rf in remote_files if rf["file_name"] == local_file), None + ) + if remote_file: + self.update_existing_file( + local_file, + local_checksum, + remote_file, + remote_files_to_delete, + synced_files, + ) + else: + self.upload_new_file(local_file, synced_files) + pbar.update(1) + self.update_local_timestamps(remote_files, synced_files) if self.two_way_sync: - self.sync_remote_to_local( - remote_files, remote_files_to_delete, synced_files - ) - - self.delete_remote_files(remote_files_to_delete, remote_files) - - def sync_local_to_remote( - self, local_files, remote_files, remote_files_to_delete, synced_files - ): - """ - Synchronize local files to the remote project. - - This method checks each local file against the remote files. If a file exists on the remote, - it updates the file if there are changes. If the file does not exist on the remote, it uploads - the new file. - - Args: - local_files (dict): Dictionary of local file names and their corresponding checksums. - remote_files (list): List of dictionaries representing remote files. - remote_files_to_delete (set): Set of remote file names to be considered for deletion. - synced_files (set): Set of file names that have been synchronized. 
- """ - for local_file, local_checksum in local_files.items(): - remote_file = next( - (rf for rf in remote_files if rf["file_name"] == local_file), None - ) - if remote_file: - self.update_existing_file( - local_file, - local_checksum, - remote_file, - remote_files_to_delete, - synced_files, - ) - else: - self.upload_new_file(local_file, synced_files) + with tqdm(total=len(remote_files), desc="Syncing remote to local") as pbar: + for remote_file in remote_files: + self.sync_remote_to_local( + remote_file, remote_files_to_delete, synced_files + ) + pbar.update(1) + + with tqdm( + total=len(remote_files_to_delete), desc="Deleting remote files" + ) as pbar: + for file_to_delete in list(remote_files_to_delete): + self.delete_remote_files(file_to_delete, remote_files) + pbar.update(1) def update_existing_file( self, @@ -115,16 +107,24 @@ def update_existing_file( remote_checksum = compute_md5_hash(remote_file["content"]) if local_checksum != remote_checksum: click.echo(f"Updating {local_file} on remote...") - self.provider.delete_file( - self.active_organization_id, self.active_project_id, remote_file["uuid"] - ) - with open( - os.path.join(self.local_path, local_file), "r", encoding="utf-8" - ) as file: - content = file.read() - self.provider.upload_file( - self.active_organization_id, self.active_project_id, local_file, content - ) + with tqdm(total=2, desc=f"Updating {local_file}", leave=False) as pbar: + self.provider.delete_file( + self.active_organization_id, + self.active_project_id, + remote_file["uuid"], + ) + pbar.update(1) + with open( + os.path.join(self.local_path, local_file), "r", encoding="utf-8" + ) as file: + content = file.read() + self.provider.upload_file( + self.active_organization_id, + self.active_project_id, + local_file, + content, + ) + pbar.update(1) time.sleep(self.upload_delay) synced_files.add(local_file) remote_files_to_delete.remove(local_file) @@ -144,9 +144,11 @@ def upload_new_file(self, local_file, synced_files): os.path.join(self.local_path, local_file), "r", encoding="utf-8" ) as file: content = file.read() - self.provider.upload_file( - self.active_organization_id, self.active_project_id, local_file, content - ) + with tqdm(total=1, desc=f"Uploading {local_file}", leave=False) as pbar: + self.provider.upload_file( + self.active_organization_id, self.active_project_id, local_file, content + ) + pbar.update(1) time.sleep(self.upload_delay) synced_files.add(local_file) @@ -161,41 +163,41 @@ def update_local_timestamps(self, remote_files, synced_files): remote_files (list): List of dictionaries representing remote files. synced_files (set): Set of file names that have been synchronized. 
""" - for remote_file in remote_files: - if remote_file["file_name"] in synced_files: - local_file_path = os.path.join( - self.local_path, remote_file["file_name"] - ) - if os.path.exists(local_file_path): - remote_timestamp = datetime.fromisoformat( - remote_file["created_at"].replace("Z", "+00:00") - ).timestamp() - os.utime(local_file_path, (remote_timestamp, remote_timestamp)) - click.echo(f"Updated timestamp on local file {local_file_path}") - - def sync_remote_to_local(self, remote_files, remote_files_to_delete, synced_files): + with tqdm(total=len(synced_files), desc="Updating local timestamps") as pbar: + for remote_file in remote_files: + if remote_file["file_name"] in synced_files: + local_file_path = os.path.join( + self.local_path, remote_file["file_name"] + ) + if os.path.exists(local_file_path): + remote_timestamp = datetime.fromisoformat( + remote_file["created_at"].replace("Z", "+00:00") + ).timestamp() + os.utime(local_file_path, (remote_timestamp, remote_timestamp)) + click.echo(f"Updated timestamp on local file {local_file_path}") + pbar.update(1) + + def sync_remote_to_local(self, remote_file, remote_files_to_delete, synced_files): """ - Synchronize remote files to the local project (two-way sync). + Synchronize a remote file to the local project (two-way sync). - This method checks each remote file against the local files. If a file exists locally, - it updates the file if the remote version is newer. If the file does not exist locally, - it creates a new local file from the remote file. + This method checks if the remote file exists locally. If it does, it updates the file + if the remote version is newer. If it doesn't exist locally, it creates a new local file. Args: - remote_files (list): List of dictionaries representing remote files. + remote_file (dict): Dictionary representing the remote file. remote_files_to_delete (set): Set of remote file names to be considered for deletion. synced_files (set): Set of file names that have been synchronized. 
""" - for remote_file in remote_files: - local_file_path = os.path.join(self.local_path, remote_file["file_name"]) - if os.path.exists(local_file_path): - self.update_existing_local_file( - local_file_path, remote_file, remote_files_to_delete, synced_files - ) - else: - self.create_new_local_file( - local_file_path, remote_file, remote_files_to_delete, synced_files - ) + local_file_path = os.path.join(self.local_path, remote_file["file_name"]) + if os.path.exists(local_file_path): + self.update_existing_local_file( + local_file_path, remote_file, remote_files_to_delete, synced_files + ) + else: + self.create_new_local_file( + local_file_path, remote_file, remote_files_to_delete, synced_files + ) def update_existing_local_file( self, local_file_path, remote_file, remote_files_to_delete, synced_files @@ -220,8 +222,12 @@ def update_existing_local_file( ) if remote_mtime > local_mtime: click.echo(f"Updating local file {remote_file['file_name']} from remote...") - with open(local_file_path, "w", encoding="utf-8") as file: - file.write(remote_file["content"]) + with tqdm( + total=1, desc=f"Updating {remote_file['file_name']}", leave=False + ) as pbar: + with open(local_file_path, "w", encoding="utf-8") as file: + file.write(remote_file["content"]) + pbar.update(1) synced_files.add(remote_file["file_name"]) if remote_file["file_name"] in remote_files_to_delete: remote_files_to_delete.remove(remote_file["file_name"]) @@ -230,41 +236,44 @@ def create_new_local_file( self, local_file_path, remote_file, remote_files_to_delete, synced_files ): """ - Create a new local file from a remote file. + Create a new local file from a remote file. - This method creates a new local file with the content from the remote file. + This method creates a new local file with the content from the remote file. - Args: - local_file_path (str): Path to the new local file - - . - remote_file (dict): Dictionary representing the remote file. - remote_files_to_delete (set): Set of remote file names to be considered for deletion. - synced_files (set): Set of file names that have been synchronized. + Args: + local_file_path (str): Path to the new local file. + remote_file (dict): Dictionary representing the remote file. + remote_files_to_delete (set): Set of remote file names to be considered for deletion. + synced_files (set): Set of file names that have been synchronized. """ click.echo(f"Creating new local file {remote_file['file_name']} from remote...") - with open(local_file_path, "w", encoding="utf-8") as file: - file.write(remote_file["content"]) + with tqdm( + total=1, desc=f"Creating {remote_file['file_name']}", leave=False + ) as pbar: + with open(local_file_path, "w", encoding="utf-8") as file: + file.write(remote_file["content"]) + pbar.update(1) synced_files.add(remote_file["file_name"]) if remote_file["file_name"] in remote_files_to_delete: remote_files_to_delete.remove(remote_file["file_name"]) - def delete_remote_files(self, remote_files_to_delete, remote_files): + def delete_remote_files(self, file_to_delete, remote_files): """ - Delete files from the remote project that no longer exist locally. + Delete a file from the remote project that no longer exists locally. - This method deletes remote files that are not present in the local directory. + This method deletes a remote file that is not present in the local directory. Args: - remote_files_to_delete (set): Set of remote file names to be deleted. + file_to_delete (str): Name of the remote file to be deleted. 
remote_files (list): List of dictionaries representing remote files. """ - for file_to_delete in remote_files_to_delete: - click.echo(f"Deleting {file_to_delete} from remote...") - remote_file = next( - rf for rf in remote_files if rf["file_name"] == file_to_delete - ) + click.echo(f"Deleting {file_to_delete} from remote...") + remote_file = next( + rf for rf in remote_files if rf["file_name"] == file_to_delete + ) + with tqdm(total=1, desc=f"Deleting {file_to_delete}", leave=False) as pbar: self.provider.delete_file( self.active_organization_id, self.active_project_id, remote_file["uuid"] ) - time.sleep(self.upload_delay) + pbar.update(1) + time.sleep(self.upload_delay) diff --git a/tests/providers/test_base_claude_ai.py b/tests/providers/test_base_claude_ai.py new file mode 100644 index 0000000..8457fc7 --- /dev/null +++ b/tests/providers/test_base_claude_ai.py @@ -0,0 +1,100 @@ +import unittest +from unittest.mock import patch +from claudesync.providers.base_claude_ai import BaseClaudeAIProvider + + +class TestBaseClaudeAIProvider(unittest.TestCase): + + def setUp(self): + self.provider = BaseClaudeAIProvider("test_session_key") + + @patch("click.prompt") + def test_login(self, mock_prompt): + mock_prompt.return_value = "new_session_key" + result = self.provider.login() + self.assertEqual(result, "new_session_key") + self.assertEqual(self.provider.session_key, "new_session_key") + + @patch.object(BaseClaudeAIProvider, "_make_request") + def test_get_organizations(self, mock_make_request): + mock_make_request.return_value = [ + {"uuid": "org1", "name": "Org 1"}, + {"uuid": "org2", "name": "Org 2"}, + ] + result = self.provider.get_organizations() + expected = [{"id": "org1", "name": "Org 1"}, {"id": "org2", "name": "Org 2"}] + self.assertEqual(result, expected) + + @patch.object(BaseClaudeAIProvider, "_make_request") + def test_get_projects(self, mock_make_request): + mock_make_request.return_value = [ + {"uuid": "proj1", "name": "Project 1", "archived_at": None}, + {"uuid": "proj2", "name": "Project 2", "archived_at": "2023-01-01"}, + ] + result = self.provider.get_projects("org1", include_archived=True) + expected = [ + {"id": "proj1", "name": "Project 1", "archived_at": None}, + {"id": "proj2", "name": "Project 2", "archived_at": "2023-01-01"}, + ] + self.assertEqual(result, expected) + + @patch.object(BaseClaudeAIProvider, "_make_request") + def test_list_files(self, mock_make_request): + mock_make_request.return_value = [ + { + "uuid": "file1", + "file_name": "test1.txt", + "content": "content1", + "created_at": "2023-01-01", + }, + { + "uuid": "file2", + "file_name": "test2.txt", + "content": "content2", + "created_at": "2023-01-02", + }, + ] + result = self.provider.list_files("org1", "proj1") + expected = [ + { + "uuid": "file1", + "file_name": "test1.txt", + "content": "content1", + "created_at": "2023-01-01", + }, + { + "uuid": "file2", + "file_name": "test2.txt", + "content": "content2", + "created_at": "2023-01-02", + }, + ] + self.assertEqual(result, expected) + + @patch.object(BaseClaudeAIProvider, "_make_request") + def test_upload_file(self, mock_make_request): + mock_make_request.return_value = {"uuid": "new_file", "file_name": "test.txt"} + result = self.provider.upload_file("org1", "proj1", "test.txt", "content") + self.assertEqual(result, {"uuid": "new_file", "file_name": "test.txt"}) + mock_make_request.assert_called_once_with( + "POST", + "/organizations/org1/projects/proj1/docs", + {"file_name": "test.txt", "content": "content"}, + ) + + 
@patch.object(BaseClaudeAIProvider, "_make_request") + def test_delete_file(self, mock_make_request): + mock_make_request.return_value = {"status": "deleted"} + result = self.provider.delete_file("org1", "proj1", "file1") + self.assertEqual(result, {"status": "deleted"}) + mock_make_request.assert_called_once_with( + "DELETE", "/organizations/org1/projects/proj1/docs/file1" + ) + + def test_make_request_not_implemented(self): + with self.assertRaises(NotImplementedError): + self.provider._make_request("GET", "/test") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/providers/test_claude_ai.py b/tests/providers/test_claude_ai.py index 597ec65..aa4884a 100644 --- a/tests/providers/test_claude_ai.py +++ b/tests/providers/test_claude_ai.py @@ -3,18 +3,12 @@ import requests from claudesync.providers.claude_ai import ClaudeAIProvider from claudesync.exceptions import ProviderError -from claudesync.config_manager import ConfigManager class TestClaudeAIProvider(unittest.TestCase): def setUp(self): self.provider = ClaudeAIProvider("test_session_key") - # Mock ConfigManager - self.mock_config = MagicMock(spec=ConfigManager) - self.mock_config.get_headers.return_value = {} - self.mock_config.get.return_value = {} - self.provider.config = self.mock_config @patch("claudesync.providers.claude_ai.requests.request") def test_make_request_success(self, mock_request): @@ -35,129 +29,13 @@ def test_make_request_failure(self, mock_request): with self.assertRaises(ProviderError): self.provider._make_request("GET", "/test") - @patch("claudesync.providers.claude_ai.click.prompt") - def test_login(self, mock_prompt): - mock_prompt.return_value = "new_session_key" - - result = self.provider.login() - - self.assertEqual(result, "new_session_key") - self.assertEqual(self.provider.session_key, "new_session_key") - - @patch("claudesync.providers.claude_ai.ClaudeAIProvider._make_request") - def test_get_organizations(self, mock_make_request): - mock_make_request.return_value = [ - {"uuid": "org1", "name": "Org 1"}, - {"uuid": "org2", "name": "Org 2"}, - ] - - result = self.provider.get_organizations() - - expected = [{"id": "org1", "name": "Org 1"}, {"id": "org2", "name": "Org 2"}] - self.assertEqual(result, expected) - - @patch("claudesync.providers.claude_ai.ClaudeAIProvider._make_request") - def test_get_projects(self, mock_make_request): - mock_make_request.return_value = [ - {"uuid": "proj1", "name": "Project 1", "archived_at": None}, - {"uuid": "proj2", "name": "Project 2", "archived_at": "2023-01-01"}, - ] - - result = self.provider.get_projects("org1", include_archived=True) - - expected = [ - {"id": "proj1", "name": "Project 1", "archived_at": None}, - {"id": "proj2", "name": "Project 2", "archived_at": "2023-01-01"}, - ] - self.assertEqual(result, expected) - - @patch("claudesync.providers.claude_ai.ClaudeAIProvider._make_request") - def test_list_files(self, mock_make_request): - mock_make_request.return_value = [ - { - "uuid": "file1", - "file_name": "test1.txt", - "content": "content1", - "created_at": "2023-01-01", - }, - { - "uuid": "file2", - "file_name": "test2.txt", - "content": "content2", - "created_at": "2023-01-02", - }, - ] - - result = self.provider.list_files("org1", "proj1") - - expected = [ - { - "uuid": "file1", - "file_name": "test1.txt", - "content": "content1", - "created_at": "2023-01-01", - }, - { - "uuid": "file2", - "file_name": "test2.txt", - "content": "content2", - "created_at": "2023-01-02", - }, - ] - self.assertEqual(result, expected) - - 
@patch("claudesync.providers.claude_ai.ClaudeAIProvider._make_request") - def test_upload_file(self, mock_make_request): - mock_make_request.return_value = {"uuid": "new_file", "file_name": "test.txt"} - - result = self.provider.upload_file("org1", "proj1", "test.txt", "content") - - self.assertEqual(result, {"uuid": "new_file", "file_name": "test.txt"}) - mock_make_request.assert_called_once_with( - "POST", - "/organizations/org1/projects/proj1/docs", - json={"file_name": "test.txt", "content": "content"}, - ) - - @patch("claudesync.providers.claude_ai.ClaudeAIProvider._make_request") - def test_delete_file(self, mock_make_request): - mock_make_request.return_value = {"status": "deleted"} - - result = self.provider.delete_file("org1", "proj1", "file1") - - self.assertEqual(result, {"status": "deleted"}) - mock_make_request.assert_called_once_with( - "DELETE", "/organizations/org1/projects/proj1/docs/file1" - ) - - @patch("claudesync.providers.claude_ai.ClaudeAIProvider._make_request") - def test_archive_project(self, mock_make_request): - mock_make_request.return_value = {"uuid": "proj1", "is_archived": True} - - result = self.provider.archive_project("org1", "proj1") - - self.assertEqual(result, {"uuid": "proj1", "is_archived": True}) - mock_make_request.assert_called_once_with( - "PUT", "/organizations/org1/projects/proj1", json={"is_archived": True} - ) - - @patch("claudesync.providers.claude_ai.ClaudeAIProvider._make_request") - def test_create_project(self, mock_make_request): - mock_make_request.return_value = {"uuid": "new_proj", "name": "New Project"} - - result = self.provider.create_project("org1", "New Project", "Description") - - self.assertEqual(result, {"uuid": "new_proj", "name": "New Project"}) - mock_make_request.assert_called_once_with( - "POST", - "/organizations/org1/projects", - json={ - "name": "New Project", - "description": "Description", - "is_private": True, - }, - ) + @patch("claudesync.providers.claude_ai.requests.request") + def test_make_request_403_error(self, mock_request): + mock_response = MagicMock() + mock_response.status_code = 403 + mock_request.return_value = mock_response + with self.assertRaises(ProviderError) as context: + self.provider._make_request("GET", "/test") -if __name__ == "__main__": - unittest.main() + self.assertIn("403 Forbidden error", str(context.exception)) diff --git a/tests/providers/test_claude_ai_curl.py b/tests/providers/test_claude_ai_curl.py index 6ec095d..e28f850 100644 --- a/tests/providers/test_claude_ai_curl.py +++ b/tests/providers/test_claude_ai_curl.py @@ -11,149 +11,32 @@ def setUp(self): self.provider = ClaudeAICurlProvider("test_session_key") @patch("subprocess.run") - def test_execute_curl_success(self, mock_run): + def test_make_request_success(self, mock_run): mock_result = MagicMock() mock_result.stdout = '{"key": "value"}' mock_result.returncode = 0 mock_run.return_value = mock_result - result = self.provider._execute_curl("GET", "/test") + result = self.provider._make_request("GET", "/test") self.assertEqual(result, {"key": "value"}) mock_run.assert_called_once() @patch("subprocess.run") - def test_execute_curl_failure(self, mock_run): + def test_make_request_failure(self, mock_run): mock_run.side_effect = subprocess.CalledProcessError( 1, "curl", stderr="Test error" ) with self.assertRaises(ProviderError): - self.provider._execute_curl("GET", "/test") - - @patch("claudesync.providers.claude_ai_curl.click.prompt") - def test_login(self, mock_prompt): - mock_prompt.return_value = "new_session_key" - - result = 
self.provider.login() - - self.assertEqual(result, "new_session_key") - self.assertEqual(self.provider.session_key, "new_session_key") - - @patch("claudesync.providers.claude_ai_curl.ClaudeAICurlProvider._execute_curl") - def test_get_organizations(self, mock_execute_curl): - mock_execute_curl.return_value = [ - {"uuid": "org1", "name": "Org 1"}, - {"uuid": "org2", "name": "Org 2"}, - ] - - result = self.provider.get_organizations() - - expected = [{"id": "org1", "name": "Org 1"}, {"id": "org2", "name": "Org 2"}] - self.assertEqual(result, expected) - - @patch("claudesync.providers.claude_ai_curl.ClaudeAICurlProvider._execute_curl") - def test_get_projects(self, mock_execute_curl): - mock_execute_curl.return_value = [ - {"uuid": "proj1", "name": "Project 1", "archived_at": None}, - {"uuid": "proj2", "name": "Project 2", "archived_at": "2023-01-01"}, - ] - - result = self.provider.get_projects("org1", include_archived=True) - - expected = [ - {"id": "proj1", "name": "Project 1", "archived_at": None}, - {"id": "proj2", "name": "Project 2", "archived_at": "2023-01-01"}, - ] - self.assertEqual(result, expected) - - @patch("claudesync.providers.claude_ai_curl.ClaudeAICurlProvider._execute_curl") - def test_list_files(self, mock_execute_curl): - mock_execute_curl.return_value = [ - { - "uuid": "file1", - "file_name": "test1.txt", - "content": "content1", - "created_at": "2023-01-01", - }, - { - "uuid": "file2", - "file_name": "test2.txt", - "content": "content2", - "created_at": "2023-01-02", - }, - ] - - result = self.provider.list_files("org1", "proj1") - - expected = [ - { - "uuid": "file1", - "file_name": "test1.txt", - "content": "content1", - "created_at": "2023-01-01", - }, - { - "uuid": "file2", - "file_name": "test2.txt", - "content": "content2", - "created_at": "2023-01-02", - }, - ] - self.assertEqual(result, expected) - - @patch("claudesync.providers.claude_ai_curl.ClaudeAICurlProvider._execute_curl") - def test_upload_file(self, mock_execute_curl): - mock_execute_curl.return_value = {"uuid": "new_file", "file_name": "test.txt"} - - result = self.provider.upload_file("org1", "proj1", "test.txt", "content") - - self.assertEqual(result, {"uuid": "new_file", "file_name": "test.txt"}) - mock_execute_curl.assert_called_once_with( - "POST", - "/organizations/org1/projects/proj1/docs", - {"file_name": "test.txt", "content": "content"}, - ) - - @patch("claudesync.providers.claude_ai_curl.ClaudeAICurlProvider._execute_curl") - def test_delete_file(self, mock_execute_curl): - mock_execute_curl.return_value = {"status": "deleted"} - - result = self.provider.delete_file("org1", "proj1", "file1") - - self.assertEqual(result, {"status": "deleted"}) - mock_execute_curl.assert_called_once_with( - "DELETE", "/organizations/org1/projects/proj1/docs/file1" - ) - - @patch("claudesync.providers.claude_ai_curl.ClaudeAICurlProvider._execute_curl") - def test_archive_project(self, mock_execute_curl): - mock_execute_curl.return_value = {"uuid": "proj1", "is_archived": True} - - result = self.provider.archive_project("org1", "proj1") - - self.assertEqual(result, {"uuid": "proj1", "is_archived": True}) - mock_execute_curl.assert_called_once_with( - "PUT", "/organizations/org1/projects/proj1", {"is_archived": True} - ) - - @patch("claudesync.providers.claude_ai_curl.ClaudeAICurlProvider._execute_curl") - def test_create_project(self, mock_execute_curl): - mock_execute_curl.return_value = {"uuid": "new_proj", "name": "New Project"} - - result = self.provider.create_project("org1", "New Project", "Description") - 
- self.assertEqual(result, {"uuid": "new_proj", "name": "New Project"}) - mock_execute_curl.assert_called_once_with( - "POST", - "/organizations/org1/projects", - { - "name": "New Project", - "description": "Description", - "is_private": True, - }, - ) + self.provider._make_request("GET", "/test") + @patch("subprocess.run") + def test_make_request_json_decode_error(self, mock_run): + mock_result = MagicMock() + mock_result.stdout = "Invalid JSON" + mock_result.returncode = 0 + mock_run.return_value = mock_result -if __name__ == "__main__": - unittest.main() + with self.assertRaises(ProviderError): + self.provider._make_request("GET", "/test") diff --git a/tests/test_chat_sync.py b/tests/test_chat_sync.py new file mode 100644 index 0000000..0fdea44 --- /dev/null +++ b/tests/test_chat_sync.py @@ -0,0 +1,99 @@ +import textwrap +import unittest +from unittest.mock import MagicMock + +from claudesync.chat_sync import extract_artifacts, get_file_extension, sync_chats +from claudesync.exceptions import ConfigurationError + + +class TestExtractArtifacts(unittest.TestCase): + + def setUp(self): + self.mock_provider = MagicMock() + self.mock_config = MagicMock() + self.mock_config.get.side_effect = lambda key, default=None: { + "local_path": "/test/path", + "active_organization_id": "org123", + "active_project_id": "proj456", + }.get(key, default) + + def test_extract_single_artifact(self): + text = """ + Here is some introductory text. + + + Test + Test Content + + + Some concluding text. + """ + expected_result = [ + { + "identifier": "test-id", + "type": "text/html", + "content": "\nTest\nTest Content\n", + } + ] + self.assertEqual(extract_artifacts(textwrap.dedent(text)), expected_result) + + def test_extract_multiple_artifacts(self): + text = """ + Here is some introductory text. + + First artifact content. + + Some middle text. + + + User + ChatGPT + Reminder + Don't forget to check your email! + + + Some concluding text. + """ + expected_result = [ + { + "identifier": "first-id", + "type": "text/plain", + "content": "First artifact content.", + }, + { + "identifier": "second-id", + "type": "text/xml", + "content": "\nUser\nChatGPT" + "\nReminder\nDon't forget to check your email!\n", + }, + ] + self.assertEqual(extract_artifacts(textwrap.dedent(text)), expected_result) + + def test_no_artifacts(self): + text = """ + Here is some text without any artifacts. + """ + expected_result = [] + self.assertEqual(extract_artifacts(text), expected_result) + + def test_sync_chats_no_local_path(self): + self.mock_config.get.side_effect = lambda key, default=None: ( + None if key == "local_path" else "some_value" + ) + with self.assertRaises(ConfigurationError): + sync_chats(self.mock_provider, self.mock_config) + + def test_sync_chats_no_organization(self): + self.mock_config.get.side_effect = lambda key, default=None: ( + None if key == "active_organization_id" else "some_value" + ) + with self.assertRaises(ConfigurationError): + sync_chats(self.mock_provider, self.mock_config) + + def test_get_file_extension(self): + self.assertEqual(get_file_extension("text/html"), "html") + self.assertEqual(get_file_extension("application/vnd.ant.code"), "txt") + self.assertEqual(get_file_extension("image/svg+xml"), "svg") + self.assertEqual(get_file_extension("application/vnd.ant.mermaid"), "mmd") + self.assertEqual(get_file_extension("application/vnd.ant.react"), "jsx") + self.assertEqual(get_file_extension("unknown/type"), "txt")