From fc8f94783b6ddd1d05e94665a2c2e86d3f8d6060 Mon Sep 17 00:00:00 2001 From: jahwag <540380+jahwag@users.noreply.github.com> Date: Sun, 21 Jul 2024 18:37:15 +0200 Subject: [PATCH] Fetch organizations from /api/organizations (#13) * Fetch organizations from /api/organizations * Format code with black --- pyproject.toml | 2 +- src/claudesync/cli/project.py | 2 +- src/claudesync/cli/sync.py | 4 +-- src/claudesync/provider_factory.py | 7 +++--- src/claudesync/providers/claude_ai.py | 35 +++++++++++++++------------ src/claudesync/utils.py | 1 + tests/cli/test_api.py | 31 +++++++++++++++++------- tests/providers/test_claude_ai.py | 33 +++++++++++++++++-------- tests/test_utils.py | 3 ++- 9 files changed, 75 insertions(+), 43 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5a5dc4e..98b6f45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "claudesync" -version = "0.3.3" +version = "0.3.4" authors = [ {name = "Jahziah Wagner", email = "jahziah.wagner+pypi@gmail.com"}, ] diff --git a/src/claudesync/cli/project.py b/src/claudesync/cli/project.py index 655b83c..f5f42f7 100644 --- a/src/claudesync/cli/project.py +++ b/src/claudesync/cli/project.py @@ -63,7 +63,7 @@ def archive(config): if 1 <= selection <= len(projects): selected_project = projects[selection - 1] if click.confirm( - f"Are you sure you want to archive '{selected_project['name']}'?" + f"Are you sure you want to archive '{selected_project['name']}'?" ): provider.archive_project(active_organization_id, selected_project["id"]) click.echo(f"Project '{selected_project['name']}' has been archived.") diff --git a/src/claudesync/cli/sync.py b/src/claudesync/cli/sync.py index 72fee0e..d9cb9ce 100644 --- a/src/claudesync/cli/sync.py +++ b/src/claudesync/cli/sync.py @@ -71,7 +71,7 @@ def sync(config): active_organization_id, active_project_id, remote_file["uuid"] ) with open( - os.path.join(local_path, local_file), "r", encoding="utf-8" + os.path.join(local_path, local_file), "r", encoding="utf-8" ) as file: content = file.read() provider.upload_file( @@ -82,7 +82,7 @@ def sync(config): else: click.echo(f"Uploading new file {local_file} to remote...") with open( - os.path.join(local_path, local_file), "r", encoding="utf-8" + os.path.join(local_path, local_file), "r", encoding="utf-8" ) as file: content = file.read() provider.upload_file( diff --git a/src/claudesync/provider_factory.py b/src/claudesync/provider_factory.py index b96e00c..1335aea 100644 --- a/src/claudesync/provider_factory.py +++ b/src/claudesync/provider_factory.py @@ -3,6 +3,7 @@ # Import other providers here as they are added + def get_provider(provider_name=None, session_key=None): """ Retrieve an instance of a provider class based on the provider name and session key. @@ -12,9 +13,9 @@ def get_provider(provider_name=None, session_key=None): name is specified but not found in the registry, it raises a ValueError. If a session key is provided, it is passed to the provider class constructor. - Args: - provider_name (str, optional): The name of the provider to retrieve. If None, returns a list of available provider names. - session_key (str, optional): The session key to be used by the provider for authentication. Defaults to None. + Args: provider_name (str, optional): The name of the provider to retrieve. If None, returns a list of available + provider names. session_key (str, optional): The session key to be used by the provider for authentication. + Defaults to None. Returns: object: An instance of the requested provider class if both provider_name and session_key are provided. diff --git a/src/claudesync/providers/claude_ai.py b/src/claudesync/providers/claude_ai.py index f1ffec9..7581173 100644 --- a/src/claudesync/providers/claude_ai.py +++ b/src/claudesync/providers/claude_ai.py @@ -42,15 +42,21 @@ def __init__(self, session_key=None): def _configure_logging(self): """ - Configures the logging level for the application based on the configuration. - This method sets the global logging configuration to the level specified in the application's configuration. - If the log level is not specified in the configuration, it defaults to "INFO". - It ensures that all log messages across the application are handled at the configured log level. - """ + Configures the logging level for the application based on the configuration. + This method sets the global logging configuration to the level specified in the application's configuration. + If the log level is not specified in the configuration, it defaults to "INFO". + It ensures that all log messages across the application are handled at the configured log level. + """ - log_level = self.config.get("log_level", "INFO") # Retrieve log level from config, default to "INFO" - logging.basicConfig(level=getattr(logging, log_level)) # Set global logging configuration - logger.setLevel(getattr(logging, log_level)) # Set logger instance to the specified log level + log_level = self.config.get( + "log_level", "INFO" + ) # Retrieve log level from config, default to "INFO" + logging.basicConfig( + level=getattr(logging, log_level) + ) # Set global logging configuration + logger.setLevel( + getattr(logging, log_level) + ) # Set logger instance to the specified log level def login(self): """ @@ -98,19 +104,16 @@ def get_organizations(self): Returns: list of dict: A list of dictionaries, each containing the 'id' and 'name' of an organization. """ - account_info = self._make_request("GET", "/bootstrap") - if ( - "account" not in account_info - or "memberships" not in account_info["account"] - ): + organizations = self._make_request("GET", "/organizations") + if not organizations: raise ProviderError("Unable to retrieve organization information") return [ { - "id": membership["organization"]["uuid"], - "name": membership["organization"]["name"], + "id": org["uuid"], + "name": org["name"], } - for membership in account_info["account"]["memberships"] + for org in organizations ] def get_projects(self, organization_id, include_archived=False): diff --git a/src/claudesync/utils.py b/src/claudesync/utils.py index 82c221d..4742b29 100644 --- a/src/claudesync/utils.py +++ b/src/claudesync/utils.py @@ -242,6 +242,7 @@ def validate_and_store_local_path(config): This function uses `click.prompt` to interact with the user, providing a default path (the current working directory) and validating the user's input to ensure it meets the criteria for an absolute path to a directory. """ + def get_default_path(): return os.getcwd() diff --git a/tests/cli/test_api.py b/tests/cli/test_api.py index 6a0ddc2..5e6f2ba 100644 --- a/tests/cli/test_api.py +++ b/tests/cli/test_api.py @@ -1,7 +1,6 @@ import unittest from unittest.mock import patch, MagicMock from claudesync.providers.claude_ai import ClaudeAIProvider -from claudesync.exceptions import ProviderError class TestClaudeAIProvider(unittest.TestCase): @@ -11,14 +10,28 @@ def setUp(self): @patch("claudesync.providers.claude_ai.requests.request") def test_get_organizations(self, mock_request): mock_response = MagicMock() - mock_response.json.return_value = { - "account": { - "memberships": [ - {"organization": {"uuid": "org1", "name": "Organization 1"}}, - {"organization": {"uuid": "org2", "name": "Organization 2"}}, - ] - } - } + mock_response.json.return_value = [ + { + "uuid": "org1", + "name": "Organization 1", + "settings": {}, + "capabilities": [], + "rate_limit_tier": "", + "billing_type": "", + "created_at": "", + "updated_at": "", + }, + { + "uuid": "org2", + "name": "Organization 2", + "settings": {}, + "capabilities": [], + "rate_limit_tier": "", + "billing_type": "", + "created_at": "", + "updated_at": "", + }, + ] mock_request.return_value = mock_response organizations = self.provider.get_organizations() diff --git a/tests/providers/test_claude_ai.py b/tests/providers/test_claude_ai.py index 495f8e4..08e47f7 100644 --- a/tests/providers/test_claude_ai.py +++ b/tests/providers/test_claude_ai.py @@ -1,7 +1,6 @@ import unittest from unittest.mock import patch, MagicMock from claudesync.providers.claude_ai import ClaudeAIProvider -from claudesync.exceptions import ProviderError class TestClaudeAIProvider(unittest.TestCase): @@ -18,14 +17,28 @@ def test_login(self, mock_prompt): @patch("claudesync.providers.claude_ai.requests.request") def test_get_organizations(self, mock_request): mock_response = MagicMock() - mock_response.json.return_value = { - "account": { - "memberships": [ - {"organization": {"uuid": "org1", "name": "Organization 1"}}, - {"organization": {"uuid": "org2", "name": "Organization 2"}}, - ] - } - } + mock_response.json.return_value = [ + { + "uuid": "org1", + "name": "Organization 1", + "settings": {}, + "capabilities": [], + "rate_limit_tier": "", + "billing_type": "", + "created_at": "", + "updated_at": "", + }, + { + "uuid": "org2", + "name": "Organization 2", + "settings": {}, + "capabilities": [], + "rate_limit_tier": "", + "billing_type": "", + "created_at": "", + "updated_at": "", + }, + ] mock_request.return_value = mock_response organizations = self.provider.get_organizations() @@ -96,7 +109,7 @@ def test_delete_file(self, mock_request): mock_response.status_code = 204 mock_request.return_value = mock_response - result = self.provider.delete_file("org1", "proj1", "file1") + self.provider.delete_file("org1", "proj1", "file1") mock_request.assert_called_once_with( "DELETE", f"{self.provider.BASE_URL}/organizations/org1/projects/proj1/docs/file1", diff --git a/tests/test_utils.py b/tests/test_utils.py index cb1ce01..ac438c8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -54,7 +54,8 @@ def test_get_local_files(self): self.assertIn("file1.txt", local_files) self.assertIn("file2.py", local_files) self.assertIn(os.path.join("subdir", "file3.txt"), local_files) - self.assertEqual(len(local_files), 3) # Ensure ignored files not included + # Ensure ignored files not included + self.assertEqual(len(local_files), 3) if __name__ == "__main__":