diff --git a/pyproject.toml b/pyproject.toml index 07e89f5..285a723 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "claudesync" -version = "0.5.5" +version = "0.5.6" authors = [ {name = "Jahziah Wagner", email = "jahziah.wagner+pypi@gmail.com"}, ] diff --git a/src/claudesync/chat_sync.py b/src/claudesync/chat_sync.py index c195b06..b1bdb94 100644 --- a/src/claudesync/chat_sync.py +++ b/src/claudesync/chat_sync.py @@ -29,7 +29,7 @@ def sync_chats(provider, config, sync_all=False): local_path = config.get("local_path") if not local_path: raise ConfigurationError( - "Local path not set. Use 'claudesync project select' or 'claudesync project create' to set it." + "Local path not set. Use 'claudesync project set' or 'claudesync project create' to set it." ) # Create chats directory within local_path @@ -40,14 +40,14 @@ def sync_chats(provider, config, sync_all=False): organization_id = config.get("active_organization_id") if not organization_id: raise ConfigurationError( - "No active organization set. Please select an organization." + "No active organization set. Please set an organization." ) # Get the active project ID active_project_id = config.get("active_project_id") if not active_project_id and not sync_all: raise ConfigurationError( - "No active project set. Please select a project or use the -a flag to sync all chats." + "No active project set. Please set a project or use the -a flag to sync all chats." ) # Fetch all chats for the organization diff --git a/src/claudesync/cli/chat.py b/src/claudesync/cli/chat.py index 0384358..c74f9d7 100644 --- a/src/claudesync/cli/chat.py +++ b/src/claudesync/cli/chat.py @@ -188,32 +188,29 @@ def init(config, name, project): @handle_errors def message(config, message, chat, timezone): """Send a message to a specified chat or create a new chat and send the message.""" - provider = validate_and_get_provider(config) - organization_id = config.get("active_organization_id") + provider = validate_and_get_provider(config, require_project=True) + active_organization_id = config.get("active_organization_id") active_project_id = config.get("active_project_id") active_project_name = config.get("active_project_name") - local_path = config.get("local_path") - - if not organization_id: - click.echo("No active organization set.") - return message = " ".join(message) # Join all message parts into a single string try: chat = create_chat( + config, active_project_id, active_project_name, chat, - local_path, - organization_id, + active_organization_id, provider, ) if chat is None: return # Send message and process the streaming response - for event in provider.send_message(organization_id, chat, message, timezone): + for event in provider.send_message( + active_organization_id, chat, message, timezone + ): if "completion" in event: click.echo(event["completion"], nl=False) elif "content" in event: @@ -232,30 +229,38 @@ def message(config, message, chat, timezone): def create_chat( - active_project_id, active_project_name, chat, local_path, organization_id, provider + config, + active_project_id, + active_project_name, + chat, + active_organization_id, + provider, ): if not chat: - selected_project = select_project( - active_project_id, - active_project_name, - local_path, - organization_id, - provider, - ) - if selected_project is None: - return + if not active_project_name: + active_project_id = select_project( + config, + active_project_id, + active_project_name, + active_organization_id, + provider, + ) + if active_project_id is None: + return None # Create a new chat with the selected project - new_chat = provider.create_chat(organization_id, project_uuid=selected_project) + new_chat = provider.create_chat( + active_organization_id, project_uuid=active_project_id + ) chat = new_chat["uuid"] click.echo(f"New chat created with ID: {chat}") return chat def select_project( - active_project_id, active_project_name, local_path, organization_id, provider + config, active_project_id, active_project_name, active_organization_id, provider ): - all_projects = provider.get_projects(organization_id) + all_projects = provider.get_projects(active_organization_id) if not all_projects: click.echo("No projects found in the active organization.") return None @@ -279,11 +284,7 @@ def select_project( current_dir = os.path.abspath(os.getcwd()) default_project = get_default_project( - active_project_id, - active_project_name, - current_dir, - filtered_projects, - local_path, + config, active_project_id, active_project_name, current_dir, filtered_projects ) click.echo("Available projects:") @@ -314,8 +315,12 @@ def select_project( def get_default_project( - active_project_id, active_project_name, current_dir, filtered_projects, local_path + config, active_project_id, active_project_name, current_dir, filtered_projects ): + local_path = config.get("local_path") + if not local_path: + return None + # Find the project that matches the current directory default_project = None for idx, proj in enumerate(filtered_projects): diff --git a/src/claudesync/cli/main.py b/src/claudesync/cli/main.py index 5985719..ac3a477 100644 --- a/src/claudesync/cli/main.py +++ b/src/claudesync/cli/main.py @@ -130,7 +130,7 @@ def push(config, category, uberproject): if not local_path: click.echo( "No .claudesync directory found in this directory or any parent directories. " - "Please run 'claudesync project create' or 'claudesync project select' first." + "Please run 'claudesync project create' or 'claudesync project set' first." ) return diff --git a/src/claudesync/cli/organization.py b/src/claudesync/cli/organization.py index 41513c1..28045cd 100644 --- a/src/claudesync/cli/organization.py +++ b/src/claudesync/cli/organization.py @@ -27,13 +27,37 @@ def ls(config): @organization.command() @click.option("--org-id", help="ID of the organization to set as active") +@click.option( + "--provider", + type=click.Choice(["claude.ai"]), # Add more providers as they become available + default="claude.ai", + help="Specify the provider for repositories without .claudesync", +) @click.pass_context @handle_errors -def set(ctx, org_id): +def set(ctx, org_id, provider): """Set the active organization.""" config = ctx.obj - provider = validate_and_get_provider(config, require_org=False) - organizations = provider.get_organizations() + + # If provider is not specified, try to get it from the config + if not provider: + provider = config.get("active_provider") + + # If provider is still not available, prompt the user + if not provider: + provider = click.prompt( + "Please specify the provider", + type=click.Choice( + ["claude.ai"] + ), # Add more providers as they become available + ) + + # Update the config with the provider + config.set("active_provider", provider, local=True) + + # Now we can get the provider instance + provider_instance = validate_and_get_provider(config, require_org=False) + organizations = provider_instance.get_organizations() if not organizations: click.echo("No organizations with required capabilities found.") diff --git a/src/claudesync/cli/project.py b/src/claudesync/cli/project.py index efaa18f..03cb3a2 100644 --- a/src/claudesync/cli/project.py +++ b/src/claudesync/cli/project.py @@ -123,15 +123,41 @@ def archive(config): is_flag=True, help="Include submodule projects in the selection", ) +@click.option( + "--provider", + type=click.Choice(["claude.ai"]), # Add more providers as they become available + default="claude.ai", + help="Specify the provider for repositories without .claudesync", +) @click.pass_context @handle_errors -def set(ctx, show_all): +def set(ctx, show_all, provider): """Set the active project for syncing.""" config = ctx.obj - provider = validate_and_get_provider(config) + + # If provider is not specified, try to get it from the config + if not provider: + provider = config.get("active_provider") + + # If provider is still not available, prompt the user + if not provider: + provider = click.prompt( + "Please specify the provider", + type=click.Choice( + ["claude.ai"] + ), # Add more providers as they become available + ) + + # Update the config with the provider + config.set("active_provider", provider, local=True) + + # Now we can get the provider instance + provider_instance = validate_and_get_provider(config) active_organization_id = config.get("active_organization_id") active_project_name = config.get("active_project_name") - projects = provider.get_projects(active_organization_id, include_archived=False) + projects = provider_instance.get_projects( + active_organization_id, include_archived=False + ) if show_all: selectable_projects = projects diff --git a/tests/mock_http_server.py b/tests/mock_http_server.py index de73c72..8f1ad5a 100644 --- a/tests/mock_http_server.py +++ b/tests/mock_http_server.py @@ -95,13 +95,24 @@ def do_GET(self): self.send_error(404, "Not Found") def do_POST(self): + content_length = int(self.headers["Content-Length"]) parsed_path = urlparse(self.path) - if parsed_path.path.endswith("/completion"): + + if parsed_path.path.endswith("/chat_conversations"): + self.send_response(200) + self.send_header("Content-type", "application/json") + self.end_headers() + response = json.dumps({"uuid": "new_chat", "name": "New Chat"}) + self.wfile.write(response.encode()) + elif parsed_path.path.endswith("/completion"): self.send_response(200) self.send_header("Content-type", "text/event-stream") self.end_headers() self.wfile.write(b'data: {"completion": "Hello"}\n\n') - self.wfile.write(b'data: {"completion": " there"}\n\n') + self.wfile.write(b'data: {"completion": " there. "}\n\n') + self.wfile.write( + b'data: {"completion": "I apologize for the confusion. You\'re right."}\n\n' + ) self.wfile.write(b"event: done\n\n") else: # time.sleep(0.01) # Add a small delay to simulate network latency diff --git a/tests/test_chat_happy_path.py b/tests/test_chat_happy_path.py new file mode 100644 index 0000000..f535f97 --- /dev/null +++ b/tests/test_chat_happy_path.py @@ -0,0 +1,73 @@ +import unittest +import threading +import time +from click.testing import CliRunner +from claudesync.cli.main import cli +from claudesync.configmanager import InMemoryConfigManager +from mock_http_server import run_mock_server + + +class TestChatHappyPath(unittest.TestCase): + @classmethod + def setUpClass(cls): + # Start the mock server in a separate thread + cls.mock_server_thread = threading.Thread(target=run_mock_server) + cls.mock_server_thread.daemon = True + cls.mock_server_thread.start() + time.sleep(1) # Give the server a moment to start + + def setUp(self): + self.runner = CliRunner() + self.config = InMemoryConfigManager() + self.config.set("claude_api_url", "http://localhost:8000/api") + + def test_chat_happy_path(self): + # Step 1: Login + result = self.runner.invoke( + cli, + ["auth", "login", "--provider", "claude.ai"], + input="sk-ant-1234\nThu, 26 Sep 2099 17:07:53 UTC\n", + obj=self.config, + ) + self.assertEqual(result.exit_code, 0) + self.assertIn("Successfully authenticated with claude.ai", result.output) + + # Step 2: Set organization + result = self.runner.invoke( + cli, ["organization", "set"], input="1\n", obj=self.config + ) + self.assertEqual(result.exit_code, 0) + self.assertIn("Selected organization: Test Org 1", result.output) + + # Step 3: Create project + result = self.runner.invoke( + cli, + [ + "project", + "create", + "--name", + "Test Project", + "--description", + "Test Description", + "--local-path", + ".", + ], + obj=self.config, + ) + self.assertEqual(result.exit_code, 0) + self.assertIn( + "Project 'New Project' (uuid: new_proj) has been created successfully.", + result.output, + ) + + # Step 4: Send message + result = self.runner.invoke( + cli, ["chat", "message", "Hello, Claude!"], input="1\n", obj=self.config + ) + self.assertEqual(result.exit_code, 0) + self.assertIn("Hello there.", result.output) + self.assertIn("I apologize for the confusion. You're right.", result.output) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_claude_ai.py b/tests/test_claude_ai.py index 3f1b7e5..14dd2c8 100644 --- a/tests/test_claude_ai.py +++ b/tests/test_claude_ai.py @@ -147,9 +147,9 @@ def test_get_chat_conversation(self): def test_send_message(self): messages = list(self.provider.send_message("org1", "chat1", "Hello")) - self.assertEqual(len(messages), 2) + self.assertEqual(len(messages), 3) self.assertEqual(messages[0]["completion"], "Hello") - self.assertEqual(messages[1]["completion"], " there") + self.assertEqual(messages[1]["completion"], " there. ") def test_handle_http_error_403(self): # This test still needs to use a mock as we can't easily trigger a 403 from our mock server