Skip to content

Commit

Permalink
Set command should take provider as option. Updated chat to work with…
Browse files Browse the repository at this point in the history
… config.local.json. (#60)
  • Loading branch information
jahwag authored Aug 31, 2024
1 parent 1026d5c commit e3d9389
Show file tree
Hide file tree
Showing 9 changed files with 183 additions and 44 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "claudesync"
version = "0.5.5"
version = "0.5.6"
authors = [
{name = "Jahziah Wagner", email = "[email protected]"},
]
Expand Down
6 changes: 3 additions & 3 deletions src/claudesync/chat_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def sync_chats(provider, config, sync_all=False):
local_path = config.get("local_path")
if not local_path:
raise ConfigurationError(
"Local path not set. Use 'claudesync project select' or 'claudesync project create' to set it."
"Local path not set. Use 'claudesync project set' or 'claudesync project create' to set it."
)

# Create chats directory within local_path
Expand All @@ -40,14 +40,14 @@ def sync_chats(provider, config, sync_all=False):
organization_id = config.get("active_organization_id")
if not organization_id:
raise ConfigurationError(
"No active organization set. Please select an organization."
"No active organization set. Please set an organization."
)

# Get the active project ID
active_project_id = config.get("active_project_id")
if not active_project_id and not sync_all:
raise ConfigurationError(
"No active project set. Please select a project or use the -a flag to sync all chats."
"No active project set. Please set a project or use the -a flag to sync all chats."
)

# Fetch all chats for the organization
Expand Down
63 changes: 34 additions & 29 deletions src/claudesync/cli/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,32 +188,29 @@ def init(config, name, project):
@handle_errors
def message(config, message, chat, timezone):
"""Send a message to a specified chat or create a new chat and send the message."""
provider = validate_and_get_provider(config)
organization_id = config.get("active_organization_id")
provider = validate_and_get_provider(config, require_project=True)
active_organization_id = config.get("active_organization_id")
active_project_id = config.get("active_project_id")
active_project_name = config.get("active_project_name")
local_path = config.get("local_path")

if not organization_id:
click.echo("No active organization set.")
return

message = " ".join(message) # Join all message parts into a single string

try:
chat = create_chat(
config,
active_project_id,
active_project_name,
chat,
local_path,
organization_id,
active_organization_id,
provider,
)
if chat is None:
return

# Send message and process the streaming response
for event in provider.send_message(organization_id, chat, message, timezone):
for event in provider.send_message(
active_organization_id, chat, message, timezone
):
if "completion" in event:
click.echo(event["completion"], nl=False)
elif "content" in event:
Expand All @@ -232,30 +229,38 @@ def message(config, message, chat, timezone):


def create_chat(
active_project_id, active_project_name, chat, local_path, organization_id, provider
config,
active_project_id,
active_project_name,
chat,
active_organization_id,
provider,
):
if not chat:
selected_project = select_project(
active_project_id,
active_project_name,
local_path,
organization_id,
provider,
)
if selected_project is None:
return
if not active_project_name:
active_project_id = select_project(
config,
active_project_id,
active_project_name,
active_organization_id,
provider,
)
if active_project_id is None:
return None

# Create a new chat with the selected project
new_chat = provider.create_chat(organization_id, project_uuid=selected_project)
new_chat = provider.create_chat(
active_organization_id, project_uuid=active_project_id
)
chat = new_chat["uuid"]
click.echo(f"New chat created with ID: {chat}")
return chat


def select_project(
active_project_id, active_project_name, local_path, organization_id, provider
config, active_project_id, active_project_name, active_organization_id, provider
):
all_projects = provider.get_projects(organization_id)
all_projects = provider.get_projects(active_organization_id)
if not all_projects:
click.echo("No projects found in the active organization.")
return None
Expand All @@ -279,11 +284,7 @@ def select_project(
current_dir = os.path.abspath(os.getcwd())

default_project = get_default_project(
active_project_id,
active_project_name,
current_dir,
filtered_projects,
local_path,
config, active_project_id, active_project_name, current_dir, filtered_projects
)

click.echo("Available projects:")
Expand Down Expand Up @@ -314,8 +315,12 @@ def select_project(


def get_default_project(
active_project_id, active_project_name, current_dir, filtered_projects, local_path
config, active_project_id, active_project_name, current_dir, filtered_projects
):
local_path = config.get("local_path")
if not local_path:
return None

# Find the project that matches the current directory
default_project = None
for idx, proj in enumerate(filtered_projects):
Expand Down
2 changes: 1 addition & 1 deletion src/claudesync/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def push(config, category, uberproject):
if not local_path:
click.echo(
"No .claudesync directory found in this directory or any parent directories. "
"Please run 'claudesync project create' or 'claudesync project select' first."
"Please run 'claudesync project create' or 'claudesync project set' first."
)
return

Expand Down
30 changes: 27 additions & 3 deletions src/claudesync/cli/organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,37 @@ def ls(config):

@organization.command()
@click.option("--org-id", help="ID of the organization to set as active")
@click.option(
"--provider",
type=click.Choice(["claude.ai"]), # Add more providers as they become available
default="claude.ai",
help="Specify the provider for repositories without .claudesync",
)
@click.pass_context
@handle_errors
def set(ctx, org_id):
def set(ctx, org_id, provider):
"""Set the active organization."""
config = ctx.obj
provider = validate_and_get_provider(config, require_org=False)
organizations = provider.get_organizations()

# If provider is not specified, try to get it from the config
if not provider:
provider = config.get("active_provider")

# If provider is still not available, prompt the user
if not provider:
provider = click.prompt(
"Please specify the provider",
type=click.Choice(
["claude.ai"]
), # Add more providers as they become available
)

# Update the config with the provider
config.set("active_provider", provider, local=True)

# Now we can get the provider instance
provider_instance = validate_and_get_provider(config, require_org=False)
organizations = provider_instance.get_organizations()

if not organizations:
click.echo("No organizations with required capabilities found.")
Expand Down
32 changes: 29 additions & 3 deletions src/claudesync/cli/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,41 @@ def archive(config):
is_flag=True,
help="Include submodule projects in the selection",
)
@click.option(
"--provider",
type=click.Choice(["claude.ai"]), # Add more providers as they become available
default="claude.ai",
help="Specify the provider for repositories without .claudesync",
)
@click.pass_context
@handle_errors
def set(ctx, show_all):
def set(ctx, show_all, provider):
"""Set the active project for syncing."""
config = ctx.obj
provider = validate_and_get_provider(config)

# If provider is not specified, try to get it from the config
if not provider:
provider = config.get("active_provider")

# If provider is still not available, prompt the user
if not provider:
provider = click.prompt(
"Please specify the provider",
type=click.Choice(
["claude.ai"]
), # Add more providers as they become available
)

# Update the config with the provider
config.set("active_provider", provider, local=True)

# Now we can get the provider instance
provider_instance = validate_and_get_provider(config)
active_organization_id = config.get("active_organization_id")
active_project_name = config.get("active_project_name")
projects = provider.get_projects(active_organization_id, include_archived=False)
projects = provider_instance.get_projects(
active_organization_id, include_archived=False
)

if show_all:
selectable_projects = projects
Expand Down
15 changes: 13 additions & 2 deletions tests/mock_http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,24 @@ def do_GET(self):
self.send_error(404, "Not Found")

def do_POST(self):
content_length = int(self.headers["Content-Length"])
parsed_path = urlparse(self.path)
if parsed_path.path.endswith("/completion"):

if parsed_path.path.endswith("/chat_conversations"):
self.send_response(200)
self.send_header("Content-type", "application/json")
self.end_headers()
response = json.dumps({"uuid": "new_chat", "name": "New Chat"})
self.wfile.write(response.encode())
elif parsed_path.path.endswith("/completion"):
self.send_response(200)
self.send_header("Content-type", "text/event-stream")
self.end_headers()
self.wfile.write(b'data: {"completion": "Hello"}\n\n')
self.wfile.write(b'data: {"completion": " there"}\n\n')
self.wfile.write(b'data: {"completion": " there. "}\n\n')
self.wfile.write(
b'data: {"completion": "I apologize for the confusion. You\'re right."}\n\n'
)
self.wfile.write(b"event: done\n\n")
else:
# time.sleep(0.01) # Add a small delay to simulate network latency
Expand Down
73 changes: 73 additions & 0 deletions tests/test_chat_happy_path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import unittest
import threading
import time
from click.testing import CliRunner
from claudesync.cli.main import cli
from claudesync.configmanager import InMemoryConfigManager
from mock_http_server import run_mock_server


class TestChatHappyPath(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Start the mock server in a separate thread
cls.mock_server_thread = threading.Thread(target=run_mock_server)
cls.mock_server_thread.daemon = True
cls.mock_server_thread.start()
time.sleep(1) # Give the server a moment to start

def setUp(self):
self.runner = CliRunner()
self.config = InMemoryConfigManager()
self.config.set("claude_api_url", "http://localhost:8000/api")

def test_chat_happy_path(self):
# Step 1: Login
result = self.runner.invoke(
cli,
["auth", "login", "--provider", "claude.ai"],
input="sk-ant-1234\nThu, 26 Sep 2099 17:07:53 UTC\n",
obj=self.config,
)
self.assertEqual(result.exit_code, 0)
self.assertIn("Successfully authenticated with claude.ai", result.output)

# Step 2: Set organization
result = self.runner.invoke(
cli, ["organization", "set"], input="1\n", obj=self.config
)
self.assertEqual(result.exit_code, 0)
self.assertIn("Selected organization: Test Org 1", result.output)

# Step 3: Create project
result = self.runner.invoke(
cli,
[
"project",
"create",
"--name",
"Test Project",
"--description",
"Test Description",
"--local-path",
".",
],
obj=self.config,
)
self.assertEqual(result.exit_code, 0)
self.assertIn(
"Project 'New Project' (uuid: new_proj) has been created successfully.",
result.output,
)

# Step 4: Send message
result = self.runner.invoke(
cli, ["chat", "message", "Hello, Claude!"], input="1\n", obj=self.config
)
self.assertEqual(result.exit_code, 0)
self.assertIn("Hello there.", result.output)
self.assertIn("I apologize for the confusion. You're right.", result.output)


if __name__ == "__main__":
unittest.main()
4 changes: 2 additions & 2 deletions tests/test_claude_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,9 @@ def test_get_chat_conversation(self):

def test_send_message(self):
messages = list(self.provider.send_message("org1", "chat1", "Hello"))
self.assertEqual(len(messages), 2)
self.assertEqual(len(messages), 3)
self.assertEqual(messages[0]["completion"], "Hello")
self.assertEqual(messages[1]["completion"], " there")
self.assertEqual(messages[1]["completion"], " there. ")

def test_handle_http_error_403(self):
# This test still needs to use a mock as we can't easily trigger a 403 from our mock server
Expand Down

0 comments on commit e3d9389

Please sign in to comment.