From 281b19c884aea64479f287d0732b344d62dc4ab1 Mon Sep 17 00:00:00 2001 From: Jahziah Wagner <540380+jahwag@users.noreply.github.com> Date: Sat, 31 Aug 2024 15:33:38 +0200 Subject: [PATCH] Set command should take provider as option. Updated chat to work with config.local.json. (#60) --- pyproject.toml | 2 +- src/claudesync/chat_sync.py | 6 +-- src/claudesync/cli/chat.py | 63 ++++++++++++++++-------------- src/claudesync/cli/main.py | 2 +- src/claudesync/cli/organization.py | 30 ++++++++++++-- src/claudesync/cli/project.py | 32 +++++++++++++-- tests/mock_http_server.py | 15 ++++++- tests/test_claude_ai.py | 4 +- 8 files changed, 110 insertions(+), 44 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 07e89f5..285a723 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "claudesync" -version = "0.5.5" +version = "0.5.6" authors = [ {name = "Jahziah Wagner", email = "jahziah.wagner+pypi@gmail.com"}, ] diff --git a/src/claudesync/chat_sync.py b/src/claudesync/chat_sync.py index c195b06..b1bdb94 100644 --- a/src/claudesync/chat_sync.py +++ b/src/claudesync/chat_sync.py @@ -29,7 +29,7 @@ def sync_chats(provider, config, sync_all=False): local_path = config.get("local_path") if not local_path: raise ConfigurationError( - "Local path not set. Use 'claudesync project select' or 'claudesync project create' to set it." + "Local path not set. Use 'claudesync project set' or 'claudesync project create' to set it." ) # Create chats directory within local_path @@ -40,14 +40,14 @@ def sync_chats(provider, config, sync_all=False): organization_id = config.get("active_organization_id") if not organization_id: raise ConfigurationError( - "No active organization set. Please select an organization." + "No active organization set. Please set an organization." ) # Get the active project ID active_project_id = config.get("active_project_id") if not active_project_id and not sync_all: raise ConfigurationError( - "No active project set. Please select a project or use the -a flag to sync all chats." + "No active project set. Please set a project or use the -a flag to sync all chats." ) # Fetch all chats for the organization diff --git a/src/claudesync/cli/chat.py b/src/claudesync/cli/chat.py index 0384358..c74f9d7 100644 --- a/src/claudesync/cli/chat.py +++ b/src/claudesync/cli/chat.py @@ -188,32 +188,29 @@ def init(config, name, project): @handle_errors def message(config, message, chat, timezone): """Send a message to a specified chat or create a new chat and send the message.""" - provider = validate_and_get_provider(config) - organization_id = config.get("active_organization_id") + provider = validate_and_get_provider(config, require_project=True) + active_organization_id = config.get("active_organization_id") active_project_id = config.get("active_project_id") active_project_name = config.get("active_project_name") - local_path = config.get("local_path") - - if not organization_id: - click.echo("No active organization set.") - return message = " ".join(message) # Join all message parts into a single string try: chat = create_chat( + config, active_project_id, active_project_name, chat, - local_path, - organization_id, + active_organization_id, provider, ) if chat is None: return # Send message and process the streaming response - for event in provider.send_message(organization_id, chat, message, timezone): + for event in provider.send_message( + active_organization_id, chat, message, timezone + ): if "completion" in event: click.echo(event["completion"], nl=False) elif "content" in event: @@ -232,30 +229,38 @@ def message(config, message, chat, timezone): def create_chat( - active_project_id, active_project_name, chat, local_path, organization_id, provider + config, + active_project_id, + active_project_name, + chat, + active_organization_id, + provider, ): if not chat: - selected_project = select_project( - active_project_id, - active_project_name, - local_path, - organization_id, - provider, - ) - if selected_project is None: - return + if not active_project_name: + active_project_id = select_project( + config, + active_project_id, + active_project_name, + active_organization_id, + provider, + ) + if active_project_id is None: + return None # Create a new chat with the selected project - new_chat = provider.create_chat(organization_id, project_uuid=selected_project) + new_chat = provider.create_chat( + active_organization_id, project_uuid=active_project_id + ) chat = new_chat["uuid"] click.echo(f"New chat created with ID: {chat}") return chat def select_project( - active_project_id, active_project_name, local_path, organization_id, provider + config, active_project_id, active_project_name, active_organization_id, provider ): - all_projects = provider.get_projects(organization_id) + all_projects = provider.get_projects(active_organization_id) if not all_projects: click.echo("No projects found in the active organization.") return None @@ -279,11 +284,7 @@ def select_project( current_dir = os.path.abspath(os.getcwd()) default_project = get_default_project( - active_project_id, - active_project_name, - current_dir, - filtered_projects, - local_path, + config, active_project_id, active_project_name, current_dir, filtered_projects ) click.echo("Available projects:") @@ -314,8 +315,12 @@ def select_project( def get_default_project( - active_project_id, active_project_name, current_dir, filtered_projects, local_path + config, active_project_id, active_project_name, current_dir, filtered_projects ): + local_path = config.get("local_path") + if not local_path: + return None + # Find the project that matches the current directory default_project = None for idx, proj in enumerate(filtered_projects): diff --git a/src/claudesync/cli/main.py b/src/claudesync/cli/main.py index 5985719..ac3a477 100644 --- a/src/claudesync/cli/main.py +++ b/src/claudesync/cli/main.py @@ -130,7 +130,7 @@ def push(config, category, uberproject): if not local_path: click.echo( "No .claudesync directory found in this directory or any parent directories. " - "Please run 'claudesync project create' or 'claudesync project select' first." + "Please run 'claudesync project create' or 'claudesync project set' first." ) return diff --git a/src/claudesync/cli/organization.py b/src/claudesync/cli/organization.py index 41513c1..28045cd 100644 --- a/src/claudesync/cli/organization.py +++ b/src/claudesync/cli/organization.py @@ -27,13 +27,37 @@ def ls(config): @organization.command() @click.option("--org-id", help="ID of the organization to set as active") +@click.option( + "--provider", + type=click.Choice(["claude.ai"]), # Add more providers as they become available + default="claude.ai", + help="Specify the provider for repositories without .claudesync", +) @click.pass_context @handle_errors -def set(ctx, org_id): +def set(ctx, org_id, provider): """Set the active organization.""" config = ctx.obj - provider = validate_and_get_provider(config, require_org=False) - organizations = provider.get_organizations() + + # If provider is not specified, try to get it from the config + if not provider: + provider = config.get("active_provider") + + # If provider is still not available, prompt the user + if not provider: + provider = click.prompt( + "Please specify the provider", + type=click.Choice( + ["claude.ai"] + ), # Add more providers as they become available + ) + + # Update the config with the provider + config.set("active_provider", provider, local=True) + + # Now we can get the provider instance + provider_instance = validate_and_get_provider(config, require_org=False) + organizations = provider_instance.get_organizations() if not organizations: click.echo("No organizations with required capabilities found.") diff --git a/src/claudesync/cli/project.py b/src/claudesync/cli/project.py index efaa18f..03cb3a2 100644 --- a/src/claudesync/cli/project.py +++ b/src/claudesync/cli/project.py @@ -123,15 +123,41 @@ def archive(config): is_flag=True, help="Include submodule projects in the selection", ) +@click.option( + "--provider", + type=click.Choice(["claude.ai"]), # Add more providers as they become available + default="claude.ai", + help="Specify the provider for repositories without .claudesync", +) @click.pass_context @handle_errors -def set(ctx, show_all): +def set(ctx, show_all, provider): """Set the active project for syncing.""" config = ctx.obj - provider = validate_and_get_provider(config) + + # If provider is not specified, try to get it from the config + if not provider: + provider = config.get("active_provider") + + # If provider is still not available, prompt the user + if not provider: + provider = click.prompt( + "Please specify the provider", + type=click.Choice( + ["claude.ai"] + ), # Add more providers as they become available + ) + + # Update the config with the provider + config.set("active_provider", provider, local=True) + + # Now we can get the provider instance + provider_instance = validate_and_get_provider(config) active_organization_id = config.get("active_organization_id") active_project_name = config.get("active_project_name") - projects = provider.get_projects(active_organization_id, include_archived=False) + projects = provider_instance.get_projects( + active_organization_id, include_archived=False + ) if show_all: selectable_projects = projects diff --git a/tests/mock_http_server.py b/tests/mock_http_server.py index de73c72..8f1ad5a 100644 --- a/tests/mock_http_server.py +++ b/tests/mock_http_server.py @@ -95,13 +95,24 @@ def do_GET(self): self.send_error(404, "Not Found") def do_POST(self): + content_length = int(self.headers["Content-Length"]) parsed_path = urlparse(self.path) - if parsed_path.path.endswith("/completion"): + + if parsed_path.path.endswith("/chat_conversations"): + self.send_response(200) + self.send_header("Content-type", "application/json") + self.end_headers() + response = json.dumps({"uuid": "new_chat", "name": "New Chat"}) + self.wfile.write(response.encode()) + elif parsed_path.path.endswith("/completion"): self.send_response(200) self.send_header("Content-type", "text/event-stream") self.end_headers() self.wfile.write(b'data: {"completion": "Hello"}\n\n') - self.wfile.write(b'data: {"completion": " there"}\n\n') + self.wfile.write(b'data: {"completion": " there. "}\n\n') + self.wfile.write( + b'data: {"completion": "I apologize for the confusion. You\'re right."}\n\n' + ) self.wfile.write(b"event: done\n\n") else: # time.sleep(0.01) # Add a small delay to simulate network latency diff --git a/tests/test_claude_ai.py b/tests/test_claude_ai.py index 3f1b7e5..14dd2c8 100644 --- a/tests/test_claude_ai.py +++ b/tests/test_claude_ai.py @@ -147,9 +147,9 @@ def test_get_chat_conversation(self): def test_send_message(self): messages = list(self.provider.send_message("org1", "chat1", "Hello")) - self.assertEqual(len(messages), 2) + self.assertEqual(len(messages), 3) self.assertEqual(messages[0]["completion"], "Hello") - self.assertEqual(messages[1]["completion"], " there") + self.assertEqual(messages[1]["completion"], " there. ") def test_handle_http_error_403(self): # This test still needs to use a mock as we can't easily trigger a 403 from our mock server