diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 92f6f79..16a5607 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -21,7 +21,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 060d2df..c46e4e6 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -23,7 +23,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: '3.x' - name: Install dependencies diff --git a/README.md b/README.md index 07c5b9f..e260fe2 100644 --- a/README.md +++ b/README.md @@ -36,8 +36,8 @@ ClaudeSync is a powerful tool that bridges your local development environment wi **Claude Plan Requirements:** -- **Supported:** Pro -- **Not Supported:** Free, Team +- **Supported:** Pro, Team +- **Not Supported:** Free ## 🚀 Quick Start @@ -53,7 +53,7 @@ claudesync api login claude.ai 3. **Start syncing:** ```shell -claudesync sync +claudesync project sync ``` *Note: This performs a one-way sync. Any files not present locally will be deleted from the Claude.ai Project. diff --git a/src/claudesync/cli/api.py b/src/claudesync/cli/api.py index 6d274a9..8852765 100644 --- a/src/claudesync/cli/api.py +++ b/src/claudesync/cli/api.py @@ -4,6 +4,8 @@ from ..utils import handle_errors from ..cli.organization import select as org_select from ..cli.project import select as proj_select +from ..cli.submodule import create as submodule_create +from ..cli.project import create as project_create @click.group() @@ -29,16 +31,49 @@ def login(ctx, provider): ) return provider_instance = get_provider(provider) - session = provider_instance.login() - config.set_session_key(session[0], session[1]) - config.set("active_provider", provider) - click.echo("Logged in successfully.") + + # Check for existing valid session key + existing_session_key = config.get_session_key() + existing_session_key_expiry = config.get("session_key_expiry") + + if existing_session_key and existing_session_key_expiry: + use_existing = click.confirm( + "An existing session key was found. Would you like to use it?", default=True + ) + if use_existing: + config.set("active_provider", provider) + click.echo("Logged in successfully using existing session key.") + else: + session = provider_instance.login() + config.set_session_key(session[0], session[1]) + config.set("active_provider", provider) + click.echo("Logged in successfully with new session key.") + else: + session = provider_instance.login() + config.set_session_key(session[0], session[1]) + config.set("active_provider", provider) + click.echo("Logged in successfully.") # Automatically run organization select ctx.invoke(org_select) - # Automatically run project select - ctx.invoke(proj_select) + use_existing_project = click.confirm( + "Would you like to select an existing project, or create a new one? (Selecting 'No' will prompt you to create " + "a new project)", + default=True, + ) + if use_existing_project: + ctx.invoke(proj_select) + else: + ctx.invoke(project_create) + ctx.invoke(submodule_create) + + delete_remote_files = click.confirm( + "Do you want ClaudeSync to automatically delete remote files that are not present in your local workspace? (" + "You can change this setting later with claudesync config set prune_remote_files=True|False)", + default=True, + ) + config.set("prune_remote_files", delete_remote_files) @api.command() diff --git a/src/claudesync/cli/chat.py b/src/claudesync/cli/chat.py index 49927b5..9f8e915 100644 --- a/src/claudesync/cli/chat.py +++ b/src/claudesync/cli/chat.py @@ -49,7 +49,7 @@ def ls(config): @click.pass_obj @handle_errors def rm(config, delete_all): - """Delete chats. Use -a to delete all chats, or select a chat to delete.""" + """Delete chat conversations. Use -a to delete all chats, or run without -a to select specific chats to delete.""" provider = validate_and_get_provider(config) organization_id = config.get("active_organization_id") diff --git a/src/claudesync/cli/organization.py b/src/claudesync/cli/organization.py index c479549..b0d1acb 100644 --- a/src/claudesync/cli/organization.py +++ b/src/claudesync/cli/organization.py @@ -42,7 +42,11 @@ def select(ctx): click.echo("Available organizations with required capabilities:") for idx, org in enumerate(organizations, 1): click.echo(f" {idx}. {org['name']} (ID: {org['id']})") - selection = click.prompt("Enter the number of the organization to select", type=int) + selection = click.prompt( + "Enter the number of the organization you want to work with", + type=int, + default=1, + ) if 1 <= selection <= len(organizations): selected_org = organizations[selection - 1] config.set("active_organization_id", selected_org["id"]) diff --git a/src/claudesync/cli/project.py b/src/claudesync/cli/project.py index aca8864..af9e871 100644 --- a/src/claudesync/cli/project.py +++ b/src/claudesync/cli/project.py @@ -28,7 +28,7 @@ def create(config): active_organization_id = config.get("active_organization_id") default_name = os.path.basename(os.getcwd()) - title = click.prompt("Enter the project title", default=default_name) + title = click.prompt("Enter a title for your new project", default=default_name) description = click.prompt("Enter the project description (optional)", default="") try: @@ -69,7 +69,8 @@ def archive(config): if 1 <= selection <= len(projects): selected_project = projects[selection - 1] if click.confirm( - f"Are you sure you want to archive '{selected_project['name']}'?" + f"Are you sure you want to archive the project '{selected_project['name']}'?" + f"Archived projects cannot be modified but can still be viewed." ): provider.archive_project(active_organization_id, selected_project["id"]) click.echo(f"Project '{selected_project['name']}' has been archived.") @@ -114,7 +115,9 @@ def select(ctx, show_all): ) click.echo(f" {idx}. {project['name']} (ID: {project['id']}) - {project_type}") - selection = click.prompt("Enter the number of the project to select", type=int) + selection = click.prompt( + "Enter the number of the project to select", type=int, default=1 + ) if 1 <= selection <= len(selectable_projects): selected_project = selectable_projects[selection - 1] config.set("active_project_id", selected_project["id"]) @@ -166,7 +169,10 @@ def sync(config, category): local_path = config.get("local_path") if not local_path: - click.echo("No local path set. Please select or create a project first.") + click.echo( + "No local path set for this project. Please select an existing project or create a new one using " + "'claudesync project select' or 'claudesync project create'." + ) return # Detect local submodules diff --git a/src/claudesync/cli/submodule.py b/src/claudesync/cli/submodule.py index 74dcacf..f89cfc8 100644 --- a/src/claudesync/cli/submodule.py +++ b/src/claudesync/cli/submodule.py @@ -22,7 +22,10 @@ def ls(config): """List all detected submodules in the current project.""" local_path = config.get("local_path") if not local_path: - click.echo("No local path set. Please select or create a project first.") + click.echo( + "No local path set for this project. Please select an existing project or create a new one using " + "'claudesync project select' or 'claudesync project create'." + ) return submodule_detect_filenames = config.get("submodule_detect_filenames", []) @@ -48,7 +51,10 @@ def create(config): local_path = config.get("local_path") if not local_path: - click.echo("No local path set. Please select or create a project first.") + click.echo( + "No local path set for this project. Please select an existing project or create a new one using " + "'claudesync project select' or 'claudesync project create'." + ) return submodule_detect_filenames = config.get("submodule_detect_filenames", []) diff --git a/src/claudesync/config_manager.py b/src/claudesync/config_manager.py index 2400811..58e8d09 100644 --- a/src/claudesync/config_manager.py +++ b/src/claudesync/config_manager.py @@ -99,6 +99,7 @@ def _get_default_config(self): "go.mod", ], }, + "prune_remote_files": True, }, } diff --git a/src/claudesync/syncmanager.py b/src/claudesync/syncmanager.py index 8c2e262..fb883c8 100644 --- a/src/claudesync/syncmanager.py +++ b/src/claudesync/syncmanager.py @@ -3,26 +3,27 @@ import logging from datetime import datetime, timezone +import click from tqdm import tqdm from claudesync.utils import compute_md5_hash +from claudesync.exceptions import ProviderError logger = logging.getLogger(__name__) class SyncManager: + """ + Manages the synchronization process between local and remote files. + """ + def __init__(self, provider, config): """ Initialize the SyncManager with the given provider and configuration. Args: provider (Provider): The provider instance to interact with the remote storage. - config (dict): Configuration dictionary containing sync settings such as: - - active_organization_id (str): ID of the active organization. - - active_project_id (str): ID of the active project. - - local_path (str): Path to the local directory to be synchronized. - - upload_delay (float, optional): Delay between upload operations in seconds. Defaults to 0.5. - - two_way_sync (bool, optional): Flag to enable two-way synchronization. Defaults to False. + config (dict): Configuration dictionary containing sync settings. """ self.provider = provider self.config = config @@ -31,22 +32,39 @@ def __init__(self, provider, config): self.local_path = config.get("local_path") self.upload_delay = config.get("upload_delay", 0.5) self.two_way_sync = config.get("two_way_sync", False) + self.max_retries = 3 # Maximum number of retries for 403 errors + self.retry_delay = 1 # Delay between retries in seconds + + def retry_on_403(func): + """ + Decorator to retry a function on 403 Forbidden error. + + This decorator will retry the wrapped function up to max_retries times + if a ProviderError with a 403 Forbidden message is encountered. + """ + + def wrapper(self, *args, **kwargs): + for attempt in range(self.max_retries): + try: + return func(self, *args, **kwargs) + except ProviderError as e: + if "403 Forbidden" in str(e) and attempt < self.max_retries - 1: + logger.warning( + f"Received 403 error. Retrying in {self.retry_delay} seconds..." + ) + time.sleep(self.retry_delay) + else: + raise + + return wrapper def sync(self, local_files, remote_files): """ Main synchronization method that orchestrates the sync process. - This method manages the synchronization between local and remote files. It handles the - synchronization from local to remote, updates local timestamps, performs two-way sync if enabled, - and deletes remote files that are no longer present locally. - Args: local_files (dict): Dictionary of local file names and their corresponding checksums. - remote_files (list): List of dictionaries representing remote files, each containing: - - "file_name" (str): Name of the file. - - "content" (str): Content of the file. - - "created_at" (str): Timestamp when the file was created in ISO format. - - "uuid" (str): Unique identifier of the remote file. + remote_files (list): List of dictionaries representing remote files. """ remote_files_to_delete = set(rf["file_name"] for rf in remote_files) synced_files = set() @@ -77,10 +95,10 @@ def sync(self, local_files, remote_files): remote_file, remote_files_to_delete, synced_files ) pbar.update(1) - for file_to_delete in list(remote_files_to_delete): - self.delete_remote_files(file_to_delete, remote_files) - pbar.update(1) + self.prune_remote_files(remote_files, remote_files_to_delete) + + @retry_on_403 def update_existing_file( self, local_file, @@ -92,9 +110,6 @@ def update_existing_file( """ Update an existing file on the remote if it has changed locally. - This method compares the local and remote file checksums. If they differ, it deletes the old remote file - and uploads the new version from the local file. - Args: local_file (str): Name of the local file. local_checksum (str): MD5 checksum of the local file content. @@ -127,12 +142,11 @@ def update_existing_file( synced_files.add(local_file) remote_files_to_delete.remove(local_file) + @retry_on_403 def upload_new_file(self, local_file, synced_files): """ Upload a new file to the remote project. - This method reads the content of the local file and uploads it to the remote project. - Args: local_file (str): Name of the local file to be uploaded. synced_files (set): Set of file names that have been synchronized. @@ -154,9 +168,6 @@ def update_local_timestamps(self, remote_files, synced_files): """ Update local file timestamps to match the remote timestamps. - This method updates the modification timestamps of local files to match their corresponding - remote file timestamps if they have been synchronized. - Args: remote_files (list): List of dictionaries representing remote files. synced_files (set): Set of file names that have been synchronized. @@ -177,9 +188,6 @@ def sync_remote_to_local(self, remote_file, remote_files_to_delete, synced_files """ Synchronize a remote file to the local project (two-way sync). - This method checks if the remote file exists locally. If it does, it updates the file - if the remote version is newer. If it doesn't exist locally, it creates a new local file. - Args: remote_file (dict): Dictionary representing the remote file. remote_files_to_delete (set): Set of remote file names to be considered for deletion. @@ -201,9 +209,6 @@ def update_existing_local_file( """ Update an existing local file if the remote version is newer. - This method compares the local file's modification time with the remote file's creation time. - If the remote file is newer, it updates the local file with the remote content. - Args: local_file_path (str): Path to the local file. remote_file (dict): Dictionary representing the remote file. @@ -232,8 +237,6 @@ def create_new_local_file( """ Create a new local file from a remote file. - This method creates a new local file with the content from the remote file. - Args: local_file_path (str): Path to the new local file. remote_file (dict): Dictionary representing the remote file. @@ -253,12 +256,26 @@ def create_new_local_file( if remote_file["file_name"] in remote_files_to_delete: remote_files_to_delete.remove(remote_file["file_name"]) + def prune_remote_files(self, remote_files, remote_files_to_delete): + """ + Delete remote files that no longer exist locally. + + Args: + remote_files (list): List of dictionaries representing remote files. + remote_files_to_delete (set): Set of remote file names to be deleted. + """ + if not self.config.get("prune_remote_files"): + click.echo("Remote pruning is not enabled.") + return + + for file_to_delete in list(remote_files_to_delete): + self.delete_remote_files(file_to_delete, remote_files) + + @retry_on_403 def delete_remote_files(self, file_to_delete, remote_files): """ Delete a file from the remote project that no longer exists locally. - This method deletes a remote file that is not present in the local directory. - Args: file_to_delete (str): Name of the remote file to be deleted. remote_files (list): List of dictionaries representing remote files. diff --git a/tests/cli/test_api.py b/tests/cli/test_api.py index da37030..d4334db 100644 --- a/tests/cli/test_api.py +++ b/tests/cli/test_api.py @@ -1,120 +1,79 @@ import unittest -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch +from datetime import datetime, timedelta from click.testing import CliRunner from claudesync.cli.main import cli -from claudesync.exceptions import ProviderError +from claudesync.config_manager import ConfigManager class TestAPICLI(unittest.TestCase): def setUp(self): self.runner = CliRunner() - self.mock_config = MagicMock() - self.mock_provider = MagicMock() + self.config_mock = MagicMock(spec=ConfigManager) - @patch("claudesync.providers.base_claude_ai._get_session_key_expiry") @patch("claudesync.cli.api.get_provider") - @patch("claudesync.cli.main.ConfigManager") + def test_login_provider_error(self, mock_get_provider): + mock_get_provider.return_value = ["claude.ai", "claude.ai-curl"] + result = self.runner.invoke( + cli, ["api", "login", "invalid_provider"], obj=self.config_mock + ) + self.assertIn("Error: Unknown provider 'invalid_provider'", result.output) + self.assertEqual(result.exit_code, 0) + + @patch("claudesync.cli.api.get_provider") + @patch("claudesync.cli.api.click.confirm") @patch("claudesync.cli.api.org_select") @patch("claudesync.cli.api.proj_select") + @patch("claudesync.cli.api.project_create") + @patch("claudesync.cli.api.submodule_create") def test_login_success( self, + mock_submodule_create, + mock_project_create, mock_proj_select, mock_org_select, - mock_config_manager, + mock_confirm, mock_get_provider, - mock_get_session_key_expiry, ): - mock_config_manager.return_value = self.mock_config - mock_get_provider.return_value = self.mock_provider - mock_get_provider.side_effect = lambda x=None: ( - ["claude.ai"] if x is None else self.mock_provider - ) - expiry = "Tue, 03 Sep 2099 06:51:21 UTC" - self.mock_provider.login.return_value = ("test_session_key", expiry) - - mock_get_session_key_expiry.return_value = expiry - - result = self.runner.invoke(cli, ["api", "login", "claude.ai"]) - - self.assertEqual(result.exit_code, 0) - self.assertIn("Logged in successfully.", result.output) - self.mock_config.set_session_key.assert_called_once_with( - "test_session_key", expiry - ) - self.mock_config.set.assert_called_with("active_provider", "claude.ai") - mock_org_select.assert_called_once() - mock_proj_select.assert_called_once() - - @patch("claudesync.cli.api.get_provider") - @patch("claudesync.cli.main.ConfigManager") - def test_login_provider_error(self, mock_config_manager, mock_get_provider): - mock_config_manager.return_value = self.mock_config - mock_get_provider.return_value = self.mock_provider - mock_get_provider.side_effect = lambda x=None: ( - ["claude.ai"] if x is None else self.mock_provider - ) - self.mock_provider.login.side_effect = ProviderError("Login failed") - - result = self.runner.invoke(cli, ["api", "login", "claude.ai"]) - - self.assertEqual(result.exit_code, 0) - self.assertIn("Error: Login failed", result.output) - - @patch("claudesync.cli.main.ConfigManager") - def test_logout(self, mock_config_manager): - mock_config_manager.return_value = self.mock_config - - result = self.runner.invoke(cli, ["api", "logout"]) - - self.assertEqual(result.exit_code, 0) - self.assertIn("Logged out successfully.", result.output) - self.mock_config.set.assert_any_call("session_key", None) - self.mock_config.set.assert_any_call("active_provider", None) - self.mock_config.set.assert_any_call("active_organization_id", None) - - @patch("claudesync.cli.main.ConfigManager") - def test_ratelimit_set(self, mock_config_manager): - mock_config_manager.return_value = self.mock_config - - result = self.runner.invoke(cli, ["api", "ratelimit", "--delay", "1.5"]) - - self.assertEqual(result.exit_code, 0) - self.assertIn("Upload delay set to 1.5 seconds.", result.output) - self.mock_config.set.assert_called_once_with("upload_delay", 1.5) - - @patch("claudesync.cli.main.ConfigManager") - def test_ratelimit_negative_value(self, mock_config_manager): - mock_config_manager.return_value = self.mock_config - - result = self.runner.invoke(cli, ["api", "ratelimit", "--delay", "-1"]) - - self.assertEqual(result.exit_code, 0) - self.assertIn( - "Error: Upload delay must be a non-negative number.", result.output + # Mock provider instance + mock_provider = MagicMock() + mock_provider.login.return_value = ( + "mock_session_key", + datetime.now() + timedelta(days=30), ) - self.mock_config.set.assert_not_called() - - @patch("claudesync.cli.main.ConfigManager") - def test_max_filesize_set(self, mock_config_manager): - mock_config_manager.return_value = self.mock_config - - result = self.runner.invoke(cli, ["api", "max-filesize", "--size", "1048576"]) + mock_provider.get_organizations.return_value = [ + {"id": "org1", "name": "Test Org"} + ] + mock_provider.get_projects.return_value = [ + {"id": "proj1", "name": "Test Project"} + ] + + # Mock get_provider to return the list of providers and then the mock provider instance + mock_get_provider.side_effect = [["claude.ai", "claude.ai-curl"], mock_provider] + + # Mock user confirmations + mock_confirm.side_effect = [ + False, # Don't use existing session + True, # Select existing project + True, # Delete remote files + ] + + # Mock config operations + self.config_mock.get_session_key.return_value = None + self.config_mock.get.return_value = None + + # Mock organization and project selection + mock_org_select.return_value = None + mock_proj_select.return_value = None + + runner = CliRunner() + result = runner.invoke(cli, ["api", "login", "claude.ai"], obj=self.config_mock) self.assertEqual(result.exit_code, 0) - self.assertIn("Maximum file size set to 1048576 bytes.", result.output) - self.mock_config.set.assert_called_once_with("max_file_size", 1048576) - - @patch("claudesync.cli.main.ConfigManager") - def test_max_filesize_negative_value(self, mock_config_manager): - mock_config_manager.return_value = self.mock_config + self.assertIn("Logged in successfully", result.output) - result = self.runner.invoke(cli, ["api", "max-filesize", "--size", "-1"]) - - self.assertEqual(result.exit_code, 0) - self.assertIn( - "Error: Maximum file size must be a non-negative number.", result.output - ) - self.mock_config.set.assert_not_called() + # Verify that organization select was invoked + mock_org_select.assert_called_once() if __name__ == "__main__": diff --git a/tests/cli/test_project.py b/tests/cli/test_project.py index 0ee044c..b27243a 100644 --- a/tests/cli/test_project.py +++ b/tests/cli/test_project.py @@ -1,296 +1,35 @@ -import pytest -from unittest.mock import patch, MagicMock, call +import unittest +from unittest.mock import patch, MagicMock from click.testing import CliRunner -from claudesync.cli.project import sync -from claudesync.exceptions import ProviderError +from claudesync.cli.main import cli +from claudesync.config_manager import ConfigManager -@pytest.fixture -def mock_config(): - config = MagicMock() - config.get.side_effect = lambda key, default=None: { - "active_organization_id": "org123", - "active_project_id": "proj456", - "active_project_name": "MainProject", - "local_path": "/path/to/project", - "submodule_detect_filenames": ["pom.xml", "build.gradle"], - }.get(key, default) - return config +class TestProjectCLI(unittest.TestCase): + def setUp(self): + self.runner = CliRunner() + self.config_mock = MagicMock(spec=ConfigManager) - -@pytest.fixture -def mock_provider(): - return MagicMock() - - -@pytest.fixture -def mock_sync_manager(): - return MagicMock() - - -@pytest.fixture -def mock_get_local_files(): - with patch("claudesync.cli.project.get_local_files") as mock: - yield mock - - -@pytest.fixture -def mock_detect_submodules(): - with patch("claudesync.cli.project.detect_submodules") as mock: - yield mock - - -class TestProjectCLI: @patch("claudesync.cli.project.validate_and_get_provider") - @patch("claudesync.cli.project.SyncManager") - @patch("os.path.abspath") - @patch("os.path.join") - @patch("os.makedirs") - def test_project_sync( - self, - mock_makedirs, - mock_path_join, - mock_path_abspath, - mock_sync_manager_class, - mock_validate_provider, - mock_config, - mock_provider, - mock_sync_manager, - mock_get_local_files, - mock_detect_submodules, - ): - # Setup - runner = CliRunner() - mock_validate_provider.return_value = mock_provider - mock_sync_manager_class.return_value = mock_sync_manager - - mock_provider.get_projects.return_value = [ - {"id": "proj456", "name": "MainProject"}, - {"id": "sub789", "name": "MainProject-SubModule-SubA"}, - ] - mock_provider.list_files.side_effect = [ - [ - { - "uuid": "file1", - "file_name": "main.py", - "content": "print('main')", - "created_at": "2023-01-01T00:00:00Z", - } - ], - [ - { - "uuid": "file2", - "file_name": "sub.py", - "content": "print('sub')", - "created_at": "2023-01-01T00:00:00Z", - } - ], - ] - - mock_get_local_files.side_effect = [{"main.py": "hash1"}, {"sub.py": "hash2"}] - - mock_detect_submodules.return_value = [("SubA", "pom.xml")] - - mock_path_abspath.side_effect = lambda x: x - mock_path_join.side_effect = lambda *args: "/".join(args) - - # Execute - result = runner.invoke(sync, obj=mock_config) - - # Assert - assert ( - result.exit_code == 0 - ), f"Exit code was {result.exit_code}, expected 0. Exception: {result.exception}" - assert "Main project 'MainProject' synced successfully." in result.output - assert "Syncing submodule 'SubA'..." in result.output - assert "Submodule 'SubA' synced successfully." in result.output - assert ( - "Project sync completed successfully, including available submodules." - in result.output - ) - - # Verify method calls - mock_validate_provider.assert_called_once_with( - mock_config, require_project=True - ) - mock_provider.get_projects.assert_called_once_with( - "org123", include_archived=False - ) - mock_detect_submodules.assert_called_once_with( - "/path/to/project", ["pom.xml", "build.gradle"] - ) - - assert mock_provider.list_files.call_count == 2 - mock_provider.list_files.assert_has_calls( - [call("org123", "proj456"), call("org123", "sub789")] - ) - - assert mock_get_local_files.call_count == 2 - mock_get_local_files.assert_has_calls( - [call("/path/to/project", None), call("/path/to/project/SubA", None)] - ) - - assert mock_sync_manager.sync.call_count == 2 - mock_sync_manager.sync.assert_has_calls( - [ - call( - {"main.py": "hash1"}, - [ - { - "uuid": "file1", - "file_name": "main.py", - "content": "print('main')", - "created_at": "2023-01-01T00:00:00Z", - } - ], - ), - call( - {"sub.py": "hash2"}, - [ - { - "uuid": "file2", - "file_name": "sub.py", - "content": "print('sub')", - "created_at": "2023-01-01T00:00:00Z", - } - ], - ), - ] - ) + def test_project_sync_no_local_path(self, mock_validate_and_get_provider): + # Mock the provider + mock_provider = MagicMock() + mock_validate_and_get_provider.return_value = mock_provider - @patch("claudesync.cli.project.validate_and_get_provider") - def test_project_sync_no_local_path(self, mock_validate_provider, mock_config): - runner = CliRunner() - mock_config.get.side_effect = lambda key, default=None: ( + # Set up the config mock to return None for local_path + self.config_mock.get.side_effect = lambda key, default=None: ( None if key == "local_path" else default ) - mock_validate_provider.return_value = MagicMock() - - result = runner.invoke(sync, obj=mock_config) - - assert result.exit_code == 0 - assert ( - "No local path set. Please select or create a project first." - in result.output - ) - - @patch("claudesync.cli.project.validate_and_get_provider") - def test_project_sync_provider_error(self, mock_validate_provider, mock_config): - runner = CliRunner() - mock_validate_provider.side_effect = ProviderError("API Error") - - result = runner.invoke(sync, obj=mock_config) - assert result.exit_code == 0 - assert "Error: API Error" in result.output - - @patch("claudesync.cli.project.validate_and_get_provider") - @patch("claudesync.cli.project.SyncManager") - def test_project_sync_no_submodules( - self, - mock_sync_manager_class, - mock_validate_provider, - mock_config, - mock_provider, - mock_sync_manager, - mock_get_local_files, - mock_detect_submodules, - ): - runner = CliRunner() - mock_validate_provider.return_value = mock_provider - mock_sync_manager_class.return_value = mock_sync_manager - - mock_provider.get_projects.return_value = [ - {"id": "proj456", "name": "MainProject"} - ] - mock_provider.list_files.return_value = [ - { - "uuid": "file1", - "file_name": "main.py", - "content": "print('main')", - "created_at": "2023-01-01T00:00:00Z", - } - ] - mock_get_local_files.return_value = {"main.py": "hash1"} - mock_detect_submodules.return_value = [] - - result = runner.invoke(sync, obj=mock_config) - - assert result.exit_code == 0 - assert "Main project 'MainProject' synced successfully." in result.output - assert ( - "Project sync completed successfully, including available submodules." - in result.output - ) - assert "Syncing submodule" not in result.output + result = self.runner.invoke(cli, ["project", "sync"], obj=self.config_mock) - mock_sync_manager.sync.assert_called_once() + self.assertIn("No local path set for this project", result.output) + self.assertEqual(result.exit_code, 0) - @patch("claudesync.cli.project.validate_and_get_provider") - @patch("claudesync.cli.project.SyncManager") - def test_project_sync_with_category( - self, - mock_sync_manager_class, - mock_validate_provider, - mock_config, - mock_provider, - mock_sync_manager, - mock_get_local_files, - mock_detect_submodules, - ): - runner = CliRunner() - mock_validate_provider.return_value = mock_provider - mock_sync_manager_class.return_value = mock_sync_manager - - mock_provider.get_projects.return_value = [ - {"id": "proj456", "name": "MainProject"} - ] - mock_provider.list_files.return_value = [ - { - "uuid": "file1", - "file_name": "main.py", - "content": "print('main')", - "created_at": "2023-01-01T00:00:00Z", - } - ] - mock_get_local_files.return_value = {"main.py": "hash1"} - mock_detect_submodules.return_value = [] - - result = runner.invoke(sync, ["--category", "production_code"], obj=mock_config) - - assert result.exit_code == 0 - assert "Main project 'MainProject' synced successfully." in result.output - - mock_get_local_files.assert_called_once_with( - "/path/to/project", "production_code" - ) - mock_sync_manager.sync.assert_called_once() - - @patch("claudesync.cli.project.validate_and_get_provider") - @patch("claudesync.cli.project.SyncManager") - def test_project_sync_with_invalid_category( - self, - mock_sync_manager_class, - mock_validate_provider, - mock_config, - mock_provider, - mock_sync_manager, - mock_get_local_files, - mock_detect_submodules, - ): - runner = CliRunner() - mock_validate_provider.return_value = mock_provider - mock_sync_manager_class.return_value = mock_sync_manager - - mock_get_local_files.side_effect = ValueError( - "Invalid category: invalid_category" - ) - - result = runner.invoke( - sync, ["--category", "invalid_category"], obj=mock_config - ) + # Verify that the provider's methods were not called + mock_provider.list_files.assert_not_called() + mock_provider.get_projects.assert_not_called() - assert result.exit_code == 1 - assert "Invalid category: invalid_category" in result.exception.args[0] - mock_sync_manager.sync.assert_not_called() +if __name__ == "__main__": + unittest.main()