From 9a37821da88dd4e3b14450b02a03d936c67eb873 Mon Sep 17 00:00:00 2001 From: jahwag <540380+jahwag@users.noreply.github.com> Date: Sat, 27 Jul 2024 21:56:31 +0200 Subject: [PATCH] curl fallback (#21) --- README.md | 44 +++- pyproject.toml | 92 +++---- src/claudesync/provider_factory.py | 84 +++--- src/claudesync/providers/claude_ai.py | 11 + src/claudesync/providers/claude_ai_curl.py | 292 +++++++++++++++++++++ tests/providers/test_claude_ai_curl.py | 159 +++++++++++ tests/test_provider_factory.py | 61 +++-- 7 files changed, 636 insertions(+), 107 deletions(-) create mode 100644 src/claudesync/providers/claude_ai_curl.py create mode 100644 tests/providers/test_claude_ai_curl.py diff --git a/README.md b/README.md index 815ebed..d5b7096 100644 --- a/README.md +++ b/README.md @@ -116,6 +116,47 @@ Set up automatic syncing at regular intervals: claudesync schedule ``` +### Providers + +ClaudeSync offers two providers for interacting with the Claude.ai API: + +1. **claude.ai (Default)**: + - Uses built-in Python libraries to make API requests. + - No additional dependencies required. + - Recommended for most users. + +2. **claude.ai-curl**: + - Uses cURL to make API requests. + - Requires cURL to be installed on your system. + - Can be used as a workaround for certain issues, such as 403 Forbidden errors. + + **Note for Windows Users**: To use the claude.ai-curl provider on Windows, you need to have cURL installed. This can be done by: + - Installing [Git for Windows](https://git-scm.com/download/win) (which includes cURL) + - Installing [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) + + Make sure cURL is accessible from your command line before using this provider. + +### Troubleshooting + +#### 403 Forbidden Error +If you encounter a 403 Forbidden error when using ClaudeSync, it might be due to an issue with the session key or API access. As a workaround, you can try using the `claude.ai-curl` provider: + +1. Ensure cURL is installed on your system (see note above for Windows users). + +2. Logout from your current session: + ```bash + claudesync api logout + ``` + +3. Login using the claude.ai-curl provider: + ```bash + claudesync api login claude.ai-curl + ``` + +4. Try your operation again. + +If the issue persists, please check your network connection and ensure that you have the necessary permissions to access Claude.ai. + ## Contributing We welcome contributions! Please see our [Contributing Guidelines](CONTRIBUTING.md) for more information. @@ -134,5 +175,4 @@ ClaudeSync is licensed under the MIT License. See the [LICENSE](LICENSE) file fo --- -Made with ❤️ by the ClaudeSync team -``` +Made with ❤️ by the ClaudeSync team \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 27da9cd..9ca26be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,46 +1,46 @@ -[build-system] -requires = ["setuptools>=42", "wheel"] -build-backend = "setuptools.build_meta" - -[project] -name = "claudesync" -version = "0.3.7" -authors = [ - {name = "Jahziah Wagner", email = "jahziah.wagner+pypi@gmail.com"}, -] -description = "A tool to synchronize local files with Claude.ai projects" -license = {file = "LICENSE"} -readme = "README.md" -requires-python = ">=3.7" -classifiers = [ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", -] -dependencies = [ - "Click", - "requests", - "pathspec", - "crontab", - "setuptools", - "pytest", - "pytest-cov", - "click_completion", -] - -[project.urls] -"Homepage" = "https://github.com/jahwag/claudesync" -"Bug Tracker" = "https://github.com/jahwag/claudesync/issues" - -[project.scripts] -claudesync = "claudesync.cli.main:cli" - -[tool.setuptools.packages.find] -where = ["src"] -include = ["claudesync*"] - - -[tool.pytest.ini_options] -testpaths = ["tests"] -python_files = "test_*.py" -addopts = "-v --cov=claudesync --cov-report=term-missing" +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "claudesync" +version = "0.3.8" +authors = [ + {name = "Jahziah Wagner", email = "jahziah.wagner+pypi@gmail.com"}, +] +description = "A tool to synchronize local files with Claude.ai projects" +license = {file = "LICENSE"} +readme = "README.md" +requires-python = ">=3.7" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", +] +dependencies = [ + "Click", + "requests", + "pathspec", + "crontab", + "setuptools", + "pytest", + "pytest-cov", + "click_completion", +] + +[project.urls] +"Homepage" = "https://github.com/jahwag/claudesync" +"Bug Tracker" = "https://github.com/jahwag/claudesync/issues" + +[project.scripts] +claudesync = "claudesync.cli.main:cli" + +[tool.setuptools.packages.find] +where = ["src"] +include = ["claudesync*"] + + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = "test_*.py" +addopts = "-v --cov=claudesync --cov-report=term-missing" diff --git a/src/claudesync/provider_factory.py b/src/claudesync/provider_factory.py index b69beb0..c36e34a 100644 --- a/src/claudesync/provider_factory.py +++ b/src/claudesync/provider_factory.py @@ -1,41 +1,43 @@ -# src/claudesync/provider_factory.py - -from .providers.base_provider import BaseProvider -from .providers.claude_ai import ClaudeAIProvider - - -def get_provider(provider_name=None, session_key=None) -> BaseProvider: - """ - Retrieve an instance of a provider class based on the provider name and session key. - - This function serves as a factory to instantiate provider classes. It maintains a registry of available - providers. If a provider name is not specified, it returns a list of available provider names. If a provider - name is specified but not found in the registry, it raises a ValueError. If a session key is provided, it - is passed to the provider class constructor. - - Args: - provider_name (str, optional): The name of the provider to retrieve. If None, returns a list of available - provider names. - session_key (str, optional): The session key to be used by the provider for authentication. - Defaults to None. - - Returns: - BaseProvider: An instance of the requested provider class if both provider_name and session_key are provided. - list: A list of available provider names if provider_name is None. - - Raises: - ValueError: If the specified provider_name is not found in the registry of providers. - """ - providers = { - "claude.ai": ClaudeAIProvider, - # Add other providers here as they are implemented - } - - if provider_name is None: - return list(providers.keys()) - - provider_class = providers.get(provider_name) - if provider_class is None: - raise ValueError(f"Unsupported provider: {provider_name}") - - return provider_class(session_key) if session_key else provider_class() +# src/claudesync/provider_factory.py + +from .providers.base_provider import BaseProvider +from .providers.claude_ai import ClaudeAIProvider +from .providers.claude_ai_curl import ClaudeAICurlProvider + + +def get_provider(provider_name=None, session_key=None) -> BaseProvider: + """ + Retrieve an instance of a provider class based on the provider name and session key. + + This function serves as a factory to instantiate provider classes. It maintains a registry of available + providers. If a provider name is not specified, it returns a list of available provider names. If a provider + name is specified but not found in the registry, it raises a ValueError. If a session key is provided, it + is passed to the provider class constructor. + + Args: + provider_name (str, optional): The name of the provider to retrieve. If None, returns a list of available + provider names. + session_key (str, optional): The session key to be used by the provider for authentication. + Defaults to None. + + Returns: + BaseProvider: An instance of the requested provider class if both provider_name and session_key are provided. + list: A list of available provider names if provider_name is None. + + Raises: + ValueError: If the specified provider_name is not found in the registry of providers. + """ + providers = { + "claude.ai": ClaudeAIProvider, + "claude.ai-curl": ClaudeAICurlProvider, + # Add other providers here as they are implemented + } + + if provider_name is None: + return list(providers.keys()) + + provider_class = providers.get(provider_name) + if provider_class is None: + raise ValueError(f"Unsupported provider: {provider_name}") + + return provider_class(session_key) if session_key else provider_class() diff --git a/src/claudesync/providers/claude_ai.py b/src/claudesync/providers/claude_ai.py index 8103ec4..c76d83c 100644 --- a/src/claudesync/providers/claude_ai.py +++ b/src/claudesync/providers/claude_ai.py @@ -116,6 +116,17 @@ def _make_request(self, method, endpoint, **kwargs): # Update cookies with any new values from the response self.config.update_cookies(response.cookies.get_dict()) + if response.status_code == 403: + error_msg = ( + "Received a 403 Forbidden error. Your session key might be invalid. " + "Please try logging out and logging in again. If the issue persists, " + "you can try using the claude.ai-curl provider as a workaround:\n" + "claudesync api logout\n" + "claudesync api login claude.ai-curl" + ) + logger.error(error_msg) + raise ProviderError(error_msg) + response.raise_for_status() if not response.content: diff --git a/src/claudesync/providers/claude_ai_curl.py b/src/claudesync/providers/claude_ai_curl.py new file mode 100644 index 0000000..b1a3d57 --- /dev/null +++ b/src/claudesync/providers/claude_ai_curl.py @@ -0,0 +1,292 @@ +import json +import subprocess +import click +from .base_provider import BaseProvider +from ..exceptions import ProviderError + + +class ClaudeAICurlProvider(BaseProvider): + """ + A provider class for interacting with the Claude AI API using cURL. + + This class encapsulates methods for performing API operations such as logging in, retrieving organizations, + projects, and files, as well as uploading and deleting files. It uses cURL commands for HTTP requests and + a session key for authentication. + + Attributes: + BASE_URL (str): The base URL for the Claude AI API. + session_key (str, optional): The session key used for authentication with the API. + """ + + BASE_URL = "https://claude.ai/api" + + def __init__(self, session_key=None): + """ + Initializes the ClaudeAICurlProvider instance. + + Args: + session_key (str, optional): The session key used for authentication. Defaults to None. + """ + self.session_key = session_key + + def _execute_curl(self, method, endpoint, data=None): + """ + Executes a cURL command to make an HTTP request to the Claude AI API. + + This method constructs and executes a cURL command based on the provided method, endpoint, and data. + It handles the response and potential errors from the cURL execution. + + Args: + method (str): The HTTP method for the request (e.g., "GET", "POST", "PUT", "DELETE"). + endpoint (str): The API endpoint to call, relative to the BASE_URL. + data (dict, optional): The data to send with the request for POST or PUT methods. Defaults to None. + + Returns: + dict: The JSON-decoded response from the API. + + Raises: + ProviderError: If the cURL command fails, returns no output, or the response cannot be parsed as JSON. + """ + url = f"{self.BASE_URL}{endpoint}" + headers = [ + "-H", + "User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:128.0) Gecko/20100101 Firefox/128.0", + "-H", + f"Cookie: sessionKey={self.session_key};", + "-H", + "Content-Type: application/json", + ] + + command = [ + "curl", + url, + "--compressed", + "-s", + "-S", + ] # Add -s for silent mode, -S to show errors + command.extend(headers) + + if method != "GET": + command.extend(["-X", method]) + + if data: + json_data = json.dumps(data) + command.extend(["-d", json_data]) + + try: + result = subprocess.run( + command, capture_output=True, text=True, check=True, encoding="utf-8" + ) + + if not result.stdout: + return None + + try: + return json.loads(result.stdout) + except json.JSONDecodeError as e: + raise ProviderError( + f"Failed to parse JSON response: {e}. Response content: {result.stdout}" + ) + + except subprocess.CalledProcessError as e: + error_message = f"cURL command failed with return code {e.returncode}. " + error_message += f"stdout: {e.stdout}, stderr: {e.stderr}" + raise ProviderError(error_message) + except UnicodeDecodeError as e: + error_message = f"Failed to decode cURL output: {e}. " + error_message += ( + "This might be due to non-UTF-8 characters in the response." + ) + raise ProviderError(error_message) + + def login(self): + """ + Guides the user through obtaining a session key from the Claude AI website. + + This method provides step-by-step instructions for the user to log in to the Claude AI website, + access the developer tools of their browser, navigate to the cookies section, and retrieve the + 'sessionKey' cookie value. It then prompts the user to enter this session key, which is stored + in the instance for future requests. + + Returns: + str: The session key entered by the user. + """ + click.echo("To obtain your session key, please follow these steps:") + click.echo("1. Open your web browser and go to https://claude.ai") + click.echo("2. Log in to your Claude account if you haven't already") + click.echo("3. Once logged in, open your browser's developer tools:") + click.echo(" - Chrome/Edge: Press F12 or Ctrl+Shift+I (Cmd+Option+I on Mac)") + click.echo(" - Firefox: Press F12 or Ctrl+Shift+I (Cmd+Option+I on Mac)") + click.echo( + " - Safari: Enable developer tools in Preferences > Advanced, then press Cmd+Option+I" + ) + click.echo( + "4. In the developer tools, go to the 'Application' tab (Chrome/Edge) or 'Storage' tab (Firefox)" + ) + click.echo( + "5. In the left sidebar, expand 'Cookies' and select 'https://claude.ai'" + ) + click.echo("6. Find the cookie named 'sessionKey' and copy its value") + self.session_key = click.prompt("Please enter your sessionKey", type=str) + return self.session_key + + def get_organizations(self): + """ + Retrieves a list of organizations the user is a member of. + + This method sends a GET request to the '/organizations' endpoint to fetch account information, + including memberships in organizations. It parses the response to extract and return + organization IDs and names. + + Returns: + list of dict: A list of dictionaries, each containing the 'id' and 'name' of an organization. + + Raises: + ProviderError: If there's an issue with retrieving organization information. + """ + response = self._execute_curl("GET", "/organizations") + if not response: + raise ProviderError("Unable to retrieve organization information") + return [{"id": org["uuid"], "name": org["name"]} for org in response] + + def get_projects(self, organization_id, include_archived=False): + """ + Retrieves a list of projects for a specified organization. + + This method sends a GET request to fetch all projects associated with a given organization ID. + It then filters these projects based on the `include_archived` parameter. + + Args: + organization_id (str): The unique identifier for the organization. + include_archived (bool, optional): Flag to include archived projects in the result. Defaults to False. + + Returns: + list of dict: A list of dictionaries, each representing a project with its ID, name, and archival status. + """ + response = self._execute_curl( + "GET", f"/organizations/{organization_id}/projects" + ) + projects = [ + { + "id": project["uuid"], + "name": project["name"], + "archived_at": project.get("archived_at"), + } + for project in response + if include_archived or project.get("archived_at") is None + ] + return projects + + def list_files(self, organization_id, project_id): + """ + Lists all files within a specified project and organization. + + This method sends a GET request to the Claude AI API to retrieve all documents associated with a given project + within an organization. It then formats the response into a list of dictionaries, each representing a file with + its unique identifier, file name, content, and creation date. + + Args: + organization_id (str): The unique identifier for the organization. + project_id (str): The unique identifier for the project within the organization. + + Returns: + list of dict: A list of dictionaries, each containing details of a file such as its UUID, file name, + content, and the date it was created. + """ + response = self._execute_curl( + "GET", f"/organizations/{organization_id}/projects/{project_id}/docs" + ) + return [ + { + "uuid": file["uuid"], + "file_name": file["file_name"], + "content": file["content"], + "created_at": file["created_at"], + } + for file in response + ] + + def upload_file(self, organization_id, project_id, file_name, content): + """ + Uploads a file to a specified project within an organization. + + This method sends a POST request to the Claude AI API to upload a file with the given name and content + to a specified project within an organization. The file's metadata, including its name and content, + is sent as JSON in the request body. + + Args: + organization_id (str): The unique identifier for the organization. + project_id (str): The unique identifier for the project within the organization. + file_name (str): The name of the file to be uploaded. + content (str): The content of the file to be uploaded. + + Returns: + dict: The response from the server after the file upload operation, typically including details + about the uploaded file such as its ID, name, and a confirmation of the upload status. + """ + data = {"file_name": file_name, "content": content} + return self._execute_curl( + "POST", f"/organizations/{organization_id}/projects/{project_id}/docs", data + ) + + def delete_file(self, organization_id, project_id, file_uuid): + """ + Deletes a file from a specified project within an organization. + + This method sends a DELETE request to the Claude AI API to remove a file, identified by its UUID, + from a specified project within an organization. The organization and project are identified by their + respective unique identifiers. + + Args: + organization_id (str): The unique identifier for the organization. + project_id (str): The unique identifier for the project within the organization. + file_uuid (str): The unique identifier (UUID) of the file to be deleted. + + Returns: + dict: The response from the server after the file deletion operation, typically confirming the deletion. + """ + return self._execute_curl( + "DELETE", + f"/organizations/{organization_id}/projects/{project_id}/docs/{file_uuid}", + ) + + def archive_project(self, organization_id, project_id): + """ + Archives a specified project within an organization. + + This method sends a PUT request to the Claude AI API to change the archival status of a specified project + to archive. The project and organization are identified by their respective unique identifiers. + + Args: + organization_id (str): The unique identifier for the organization. + project_id (str): The unique identifier for the project within the organization. + + Returns: + dict: The response from the server after the archival operation, typically confirming the archival status. + """ + data = {"is_archived": True} + return self._execute_curl( + "PUT", f"/organizations/{organization_id}/projects/{project_id}", data + ) + + def create_project(self, organization_id, name, description=""): + """ + Creates a new project within a specified organization. + + This method sends a POST request to the Claude AI API to create a new project with the given name, + description, and sets it as private within the specified organization. The project's name, description, + and privacy status are sent as JSON in the request body. + + Args: + organization_id (str): The unique identifier for the organization. + name (str): The name of the project to be created. + description (str, optional): A description of the project. Defaults to an empty string. + + Returns: + dict: The response from the server after the project creation operation, typically including details + about the created project such as its ID, name, and a confirmation of the creation status. + """ + data = {"name": name, "description": description, "is_private": True} + return self._execute_curl( + "POST", f"/organizations/{organization_id}/projects", data + ) diff --git a/tests/providers/test_claude_ai_curl.py b/tests/providers/test_claude_ai_curl.py new file mode 100644 index 0000000..6ec095d --- /dev/null +++ b/tests/providers/test_claude_ai_curl.py @@ -0,0 +1,159 @@ +import unittest +from unittest.mock import patch, MagicMock +import subprocess +from claudesync.providers.claude_ai_curl import ClaudeAICurlProvider +from claudesync.exceptions import ProviderError + + +class TestClaudeAICurlProvider(unittest.TestCase): + + def setUp(self): + self.provider = ClaudeAICurlProvider("test_session_key") + + @patch("subprocess.run") + def test_execute_curl_success(self, mock_run): + mock_result = MagicMock() + mock_result.stdout = '{"key": "value"}' + mock_result.returncode = 0 + mock_run.return_value = mock_result + + result = self.provider._execute_curl("GET", "/test") + + self.assertEqual(result, {"key": "value"}) + mock_run.assert_called_once() + + @patch("subprocess.run") + def test_execute_curl_failure(self, mock_run): + mock_run.side_effect = subprocess.CalledProcessError( + 1, "curl", stderr="Test error" + ) + + with self.assertRaises(ProviderError): + self.provider._execute_curl("GET", "/test") + + @patch("claudesync.providers.claude_ai_curl.click.prompt") + def test_login(self, mock_prompt): + mock_prompt.return_value = "new_session_key" + + result = self.provider.login() + + self.assertEqual(result, "new_session_key") + self.assertEqual(self.provider.session_key, "new_session_key") + + @patch("claudesync.providers.claude_ai_curl.ClaudeAICurlProvider._execute_curl") + def test_get_organizations(self, mock_execute_curl): + mock_execute_curl.return_value = [ + {"uuid": "org1", "name": "Org 1"}, + {"uuid": "org2", "name": "Org 2"}, + ] + + result = self.provider.get_organizations() + + expected = [{"id": "org1", "name": "Org 1"}, {"id": "org2", "name": "Org 2"}] + self.assertEqual(result, expected) + + @patch("claudesync.providers.claude_ai_curl.ClaudeAICurlProvider._execute_curl") + def test_get_projects(self, mock_execute_curl): + mock_execute_curl.return_value = [ + {"uuid": "proj1", "name": "Project 1", "archived_at": None}, + {"uuid": "proj2", "name": "Project 2", "archived_at": "2023-01-01"}, + ] + + result = self.provider.get_projects("org1", include_archived=True) + + expected = [ + {"id": "proj1", "name": "Project 1", "archived_at": None}, + {"id": "proj2", "name": "Project 2", "archived_at": "2023-01-01"}, + ] + self.assertEqual(result, expected) + + @patch("claudesync.providers.claude_ai_curl.ClaudeAICurlProvider._execute_curl") + def test_list_files(self, mock_execute_curl): + mock_execute_curl.return_value = [ + { + "uuid": "file1", + "file_name": "test1.txt", + "content": "content1", + "created_at": "2023-01-01", + }, + { + "uuid": "file2", + "file_name": "test2.txt", + "content": "content2", + "created_at": "2023-01-02", + }, + ] + + result = self.provider.list_files("org1", "proj1") + + expected = [ + { + "uuid": "file1", + "file_name": "test1.txt", + "content": "content1", + "created_at": "2023-01-01", + }, + { + "uuid": "file2", + "file_name": "test2.txt", + "content": "content2", + "created_at": "2023-01-02", + }, + ] + self.assertEqual(result, expected) + + @patch("claudesync.providers.claude_ai_curl.ClaudeAICurlProvider._execute_curl") + def test_upload_file(self, mock_execute_curl): + mock_execute_curl.return_value = {"uuid": "new_file", "file_name": "test.txt"} + + result = self.provider.upload_file("org1", "proj1", "test.txt", "content") + + self.assertEqual(result, {"uuid": "new_file", "file_name": "test.txt"}) + mock_execute_curl.assert_called_once_with( + "POST", + "/organizations/org1/projects/proj1/docs", + {"file_name": "test.txt", "content": "content"}, + ) + + @patch("claudesync.providers.claude_ai_curl.ClaudeAICurlProvider._execute_curl") + def test_delete_file(self, mock_execute_curl): + mock_execute_curl.return_value = {"status": "deleted"} + + result = self.provider.delete_file("org1", "proj1", "file1") + + self.assertEqual(result, {"status": "deleted"}) + mock_execute_curl.assert_called_once_with( + "DELETE", "/organizations/org1/projects/proj1/docs/file1" + ) + + @patch("claudesync.providers.claude_ai_curl.ClaudeAICurlProvider._execute_curl") + def test_archive_project(self, mock_execute_curl): + mock_execute_curl.return_value = {"uuid": "proj1", "is_archived": True} + + result = self.provider.archive_project("org1", "proj1") + + self.assertEqual(result, {"uuid": "proj1", "is_archived": True}) + mock_execute_curl.assert_called_once_with( + "PUT", "/organizations/org1/projects/proj1", {"is_archived": True} + ) + + @patch("claudesync.providers.claude_ai_curl.ClaudeAICurlProvider._execute_curl") + def test_create_project(self, mock_execute_curl): + mock_execute_curl.return_value = {"uuid": "new_proj", "name": "New Project"} + + result = self.provider.create_project("org1", "New Project", "Description") + + self.assertEqual(result, {"uuid": "new_proj", "name": "New Project"}) + mock_execute_curl.assert_called_once_with( + "POST", + "/organizations/org1/projects", + { + "name": "New Project", + "description": "Description", + "is_private": True, + }, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_provider_factory.py b/tests/test_provider_factory.py index d9ddcd8..24993a7 100644 --- a/tests/test_provider_factory.py +++ b/tests/test_provider_factory.py @@ -1,41 +1,66 @@ -import unittest from unittest.mock import patch +import pytest from claudesync.provider_factory import get_provider from claudesync.providers.claude_ai import ClaudeAIProvider +from claudesync.providers.claude_ai_curl import ClaudeAICurlProvider -class TestProviderFactory(unittest.TestCase): +class TestProviderFactory: - def test_get_provider_list(self): + @pytest.mark.parametrize("provider_name", ["claude.ai", "claude.ai-curl"]) + def test_get_provider_list(self, provider_name): # Test that get_provider returns a list of available providers when called without arguments providers = get_provider() - self.assertIsInstance(providers, list) - self.assertIn("claude.ai", providers) + assert isinstance(providers, list) + assert provider_name in providers - def test_get_provider_claude_ai(self): - # Test that get_provider returns a ClaudeAIProvider instance for "claude.ai" - provider = get_provider("claude.ai") - self.assertIsInstance(provider, ClaudeAIProvider) + @pytest.mark.parametrize( + "provider_name, expected_class", + [("claude.ai", ClaudeAIProvider), ("claude.ai-curl", ClaudeAICurlProvider)], + ) + def test_get_provider_instance(self, provider_name, expected_class): + # Test that get_provider returns the correct provider instance + provider = get_provider(provider_name) + assert isinstance(provider, expected_class) - def test_get_provider_with_session_key(self): + @pytest.mark.parametrize( + "provider_name, expected_class", + [("claude.ai", ClaudeAIProvider), ("claude.ai-curl", ClaudeAICurlProvider)], + ) + def test_get_provider_with_session_key(self, provider_name, expected_class): # Test that get_provider returns a provider instance with a session key session_key = "test_session_key" - provider = get_provider("claude.ai", session_key) - self.assertIsInstance(provider, ClaudeAIProvider) - self.assertEqual(provider.session_key, session_key) + provider = get_provider(provider_name, session_key) + assert isinstance(provider, expected_class) + assert provider.session_key == session_key def test_get_provider_unknown(self): # Test that get_provider raises a ValueError for an unknown provider - with self.assertRaises(ValueError): + with pytest.raises(ValueError): get_provider("unknown_provider") + @pytest.mark.parametrize( + "provider_name, expected_class", + [("claude.ai", ClaudeAIProvider), ("claude.ai-curl", ClaudeAICurlProvider)], + ) @patch("claudesync.provider_factory.ClaudeAIProvider") - def test_get_provider_calls_constructor(self, mock_claude_ai_provider): + @patch("claudesync.provider_factory.ClaudeAICurlProvider") + def test_get_provider_calls_constructor( + self, + mock_claude_ai_curl_provider, + mock_claude_ai_provider, + provider_name, + expected_class, + ): # Test that get_provider calls the provider's constructor session_key = "test_session_key" - get_provider("claude.ai", session_key) - mock_claude_ai_provider.assert_called_once_with(session_key) + get_provider(provider_name, session_key) + + if provider_name == "claude.ai": + mock_claude_ai_provider.assert_called_once_with(session_key) + else: + mock_claude_ai_curl_provider.assert_called_once_with(session_key) if __name__ == "__main__": - unittest.main() + pytest.main()