Skip to content

Commit

Permalink
Fetch organizations from /api/organizations (#13)
Browse files Browse the repository at this point in the history
* Fetch organizations from /api/organizations

* Format code with black
  • Loading branch information
jahwag authored Jul 21, 2024
1 parent d6e720e commit fc8f947
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 43 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "claudesync"
version = "0.3.3"
version = "0.3.4"
authors = [
{name = "Jahziah Wagner", email = "[email protected]"},
]
Expand Down
2 changes: 1 addition & 1 deletion src/claudesync/cli/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def archive(config):
if 1 <= selection <= len(projects):
selected_project = projects[selection - 1]
if click.confirm(
f"Are you sure you want to archive '{selected_project['name']}'?"
f"Are you sure you want to archive '{selected_project['name']}'?"
):
provider.archive_project(active_organization_id, selected_project["id"])
click.echo(f"Project '{selected_project['name']}' has been archived.")
Expand Down
4 changes: 2 additions & 2 deletions src/claudesync/cli/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def sync(config):
active_organization_id, active_project_id, remote_file["uuid"]
)
with open(
os.path.join(local_path, local_file), "r", encoding="utf-8"
os.path.join(local_path, local_file), "r", encoding="utf-8"
) as file:
content = file.read()
provider.upload_file(
Expand All @@ -82,7 +82,7 @@ def sync(config):
else:
click.echo(f"Uploading new file {local_file} to remote...")
with open(
os.path.join(local_path, local_file), "r", encoding="utf-8"
os.path.join(local_path, local_file), "r", encoding="utf-8"
) as file:
content = file.read()
provider.upload_file(
Expand Down
7 changes: 4 additions & 3 deletions src/claudesync/provider_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

# Import other providers here as they are added


def get_provider(provider_name=None, session_key=None):
"""
Retrieve an instance of a provider class based on the provider name and session key.
Expand All @@ -12,9 +13,9 @@ def get_provider(provider_name=None, session_key=None):
name is specified but not found in the registry, it raises a ValueError. If a session key is provided, it
is passed to the provider class constructor.
Args:
provider_name (str, optional): The name of the provider to retrieve. If None, returns a list of available provider names.
session_key (str, optional): The session key to be used by the provider for authentication. Defaults to None.
Args: provider_name (str, optional): The name of the provider to retrieve. If None, returns a list of available
provider names. session_key (str, optional): The session key to be used by the provider for authentication.
Defaults to None.
Returns:
object: An instance of the requested provider class if both provider_name and session_key are provided.
Expand Down
35 changes: 19 additions & 16 deletions src/claudesync/providers/claude_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,21 @@ def __init__(self, session_key=None):

def _configure_logging(self):
"""
Configures the logging level for the application based on the configuration.
This method sets the global logging configuration to the level specified in the application's configuration.
If the log level is not specified in the configuration, it defaults to "INFO".
It ensures that all log messages across the application are handled at the configured log level.
"""
Configures the logging level for the application based on the configuration.
This method sets the global logging configuration to the level specified in the application's configuration.
If the log level is not specified in the configuration, it defaults to "INFO".
It ensures that all log messages across the application are handled at the configured log level.
"""

log_level = self.config.get("log_level", "INFO") # Retrieve log level from config, default to "INFO"
logging.basicConfig(level=getattr(logging, log_level)) # Set global logging configuration
logger.setLevel(getattr(logging, log_level)) # Set logger instance to the specified log level
log_level = self.config.get(
"log_level", "INFO"
) # Retrieve log level from config, default to "INFO"
logging.basicConfig(
level=getattr(logging, log_level)
) # Set global logging configuration
logger.setLevel(
getattr(logging, log_level)
) # Set logger instance to the specified log level

def login(self):
"""
Expand Down Expand Up @@ -98,19 +104,16 @@ def get_organizations(self):
Returns:
list of dict: A list of dictionaries, each containing the 'id' and 'name' of an organization.
"""
account_info = self._make_request("GET", "/bootstrap")
if (
"account" not in account_info
or "memberships" not in account_info["account"]
):
organizations = self._make_request("GET", "/organizations")
if not organizations:
raise ProviderError("Unable to retrieve organization information")

return [
{
"id": membership["organization"]["uuid"],
"name": membership["organization"]["name"],
"id": org["uuid"],
"name": org["name"],
}
for membership in account_info["account"]["memberships"]
for org in organizations
]

def get_projects(self, organization_id, include_archived=False):
Expand Down
1 change: 1 addition & 0 deletions src/claudesync/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def validate_and_store_local_path(config):
This function uses `click.prompt` to interact with the user, providing a default path (the current working directory)
and validating the user's input to ensure it meets the criteria for an absolute path to a directory.
"""

def get_default_path():
return os.getcwd()

Expand Down
31 changes: 22 additions & 9 deletions tests/cli/test_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import unittest
from unittest.mock import patch, MagicMock
from claudesync.providers.claude_ai import ClaudeAIProvider
from claudesync.exceptions import ProviderError


class TestClaudeAIProvider(unittest.TestCase):
Expand All @@ -11,14 +10,28 @@ def setUp(self):
@patch("claudesync.providers.claude_ai.requests.request")
def test_get_organizations(self, mock_request):
mock_response = MagicMock()
mock_response.json.return_value = {
"account": {
"memberships": [
{"organization": {"uuid": "org1", "name": "Organization 1"}},
{"organization": {"uuid": "org2", "name": "Organization 2"}},
]
}
}
mock_response.json.return_value = [
{
"uuid": "org1",
"name": "Organization 1",
"settings": {},
"capabilities": [],
"rate_limit_tier": "",
"billing_type": "",
"created_at": "",
"updated_at": "",
},
{
"uuid": "org2",
"name": "Organization 2",
"settings": {},
"capabilities": [],
"rate_limit_tier": "",
"billing_type": "",
"created_at": "",
"updated_at": "",
},
]
mock_request.return_value = mock_response

organizations = self.provider.get_organizations()
Expand Down
33 changes: 23 additions & 10 deletions tests/providers/test_claude_ai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import unittest
from unittest.mock import patch, MagicMock
from claudesync.providers.claude_ai import ClaudeAIProvider
from claudesync.exceptions import ProviderError


class TestClaudeAIProvider(unittest.TestCase):
Expand All @@ -18,14 +17,28 @@ def test_login(self, mock_prompt):
@patch("claudesync.providers.claude_ai.requests.request")
def test_get_organizations(self, mock_request):
mock_response = MagicMock()
mock_response.json.return_value = {
"account": {
"memberships": [
{"organization": {"uuid": "org1", "name": "Organization 1"}},
{"organization": {"uuid": "org2", "name": "Organization 2"}},
]
}
}
mock_response.json.return_value = [
{
"uuid": "org1",
"name": "Organization 1",
"settings": {},
"capabilities": [],
"rate_limit_tier": "",
"billing_type": "",
"created_at": "",
"updated_at": "",
},
{
"uuid": "org2",
"name": "Organization 2",
"settings": {},
"capabilities": [],
"rate_limit_tier": "",
"billing_type": "",
"created_at": "",
"updated_at": "",
},
]
mock_request.return_value = mock_response

organizations = self.provider.get_organizations()
Expand Down Expand Up @@ -96,7 +109,7 @@ def test_delete_file(self, mock_request):
mock_response.status_code = 204
mock_request.return_value = mock_response

result = self.provider.delete_file("org1", "proj1", "file1")
self.provider.delete_file("org1", "proj1", "file1")
mock_request.assert_called_once_with(
"DELETE",
f"{self.provider.BASE_URL}/organizations/org1/projects/proj1/docs/file1",
Expand Down
3 changes: 2 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def test_get_local_files(self):
self.assertIn("file1.txt", local_files)
self.assertIn("file2.py", local_files)
self.assertIn(os.path.join("subdir", "file3.txt"), local_files)
self.assertEqual(len(local_files), 3) # Ensure ignored files not included
# Ensure ignored files not included
self.assertEqual(len(local_files), 3)


if __name__ == "__main__":
Expand Down

0 comments on commit fc8f947

Please sign in to comment.