From b4789d398b47bf26ec24ed605e9eac13539d20f0 Mon Sep 17 00:00:00 2001 From: jahwag <540380+jahwag@users.noreply.github.com> Date: Sat, 11 Jan 2025 09:23:54 +0100 Subject: [PATCH] feat: support specifying chat model --- pyproject.toml | 2 +- src/claudesync/cli/chat.py | 35 ++++++++++++-------- src/claudesync/providers/base_claude_ai.py | 26 ++++++--------- src/claudesync/providers/base_provider.py | 37 +++++++++++++++++++--- 4 files changed, 66 insertions(+), 34 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6997ef7..4ca323a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "claudesync" -version = "0.6.9" +version = "0.7.0" authors = [ {name = "Jahziah Wagner", email = "540380+jahwag@users.noreply.github.com"}, ] diff --git a/src/claudesync/cli/chat.py b/src/claudesync/cli/chat.py index c74f9d7..5e6eba1 100644 --- a/src/claudesync/cli/chat.py +++ b/src/claudesync/cli/chat.py @@ -74,16 +74,16 @@ def delete_all_chats(provider, organization_id): """Delete all chats for the given organization.""" if click.confirm("Are you sure you want to delete all chats?"): total_deleted = 0 - with click.progressbar(length=100, label="Deleting chats") as bar: - while True: - chats = provider.get_chat_conversations(organization_id) - if not chats: - break - uuids_to_delete = [chat["uuid"] for chat in chats[:50]] - deleted, _ = delete_chats(provider, organization_id, uuids_to_delete) - total_deleted += deleted - bar.update(len(uuids_to_delete)) - click.echo(f"Chat deletion complete. Total chats deleted: {total_deleted}") + with click.progressbar(length=100, label="Deleting chats") as bar: + while True: + chats = provider.get_chat_conversations(organization_id) + if not chats: + break + uuids_to_delete = [chat["uuid"] for chat in chats[:50]] + deleted, _ = delete_chats(provider, organization_id, uuids_to_delete) + total_deleted += deleted + bar.update(len(uuids_to_delete)) + click.echo(f"Chat deletion complete. Total chats deleted: {total_deleted}") def delete_single_chat(provider, organization_id): @@ -184,9 +184,16 @@ def init(config, name, project): @click.argument("message", nargs=-1, required=True) @click.option("--chat", help="UUID of the chat to send the message to") @click.option("--timezone", default="UTC", help="Timezone for the message") +@click.option( + "--model", + help="Model to use for the conversation. Available options:\n" + + "- claude-3-5-haiku-20241022\n" + + "- claude-3-opus-20240229\n" + + "Or any custom model string. If not specified, uses the default model.", +) @click.pass_obj @handle_errors -def message(config, message, chat, timezone): +def message(config, message, chat, timezone, model): """Send a message to a specified chat or create a new chat and send the message.""" provider = validate_and_get_provider(config, require_project=True) active_organization_id = config.get("active_organization_id") @@ -203,13 +210,14 @@ def message(config, message, chat, timezone): chat, active_organization_id, provider, + model, ) if chat is None: return # Send message and process the streaming response for event in provider.send_message( - active_organization_id, chat, message, timezone + active_organization_id, chat, message, timezone, model ): if "completion" in event: click.echo(event["completion"], nl=False) @@ -235,6 +243,7 @@ def create_chat( chat, active_organization_id, provider, + model, ): if not chat: if not active_project_name: @@ -250,7 +259,7 @@ def create_chat( # Create a new chat with the selected project new_chat = provider.create_chat( - active_organization_id, project_uuid=active_project_id + active_organization_id, project_uuid=active_project_id, model=model ) chat = new_chat["uuid"] click.echo(f"New chat created with ID: {chat}") diff --git a/src/claudesync/providers/base_claude_ai.py b/src/claudesync/providers/base_claude_ai.py index 3930423..f0d31be 100644 --- a/src/claudesync/providers/base_claude_ai.py +++ b/src/claudesync/providers/base_claude_ai.py @@ -267,26 +267,15 @@ def delete_chat(self, organization_id, conversation_uuids): def _make_request(self, method, endpoint, data=None): raise NotImplementedError("This method should be implemented by subclasses") - def create_chat(self, organization_id, chat_name="", project_uuid=None): - """ - Create a new chat conversation in the specified organization. - - Args: - organization_id (str): The UUID of the organization. - chat_name (str, optional): The name of the chat. Defaults to an empty string. - project_uuid (str, optional): The UUID of the project to associate the chat with. Defaults to None. - - Returns: - dict: The created chat conversation data. - - Raises: - ProviderError: If the chat creation fails. - """ + def create_chat(self, organization_id, chat_name="", project_uuid=None, model=None): data = { "uuid": self._generate_uuid(), "name": chat_name, "project_uuid": project_uuid, } + if model is not None: + data["model"] = model + return self._make_request( "POST", f"/organizations/{organization_id}/chat_conversations", data ) @@ -302,7 +291,9 @@ def _make_request_stream(self, method, endpoint, data=None): # that can be used with sseclient raise NotImplementedError("This method should be implemented by subclasses") - def send_message(self, organization_id, chat_id, prompt, timezone="UTC"): + def send_message( + self, organization_id, chat_id, prompt, timezone="UTC", model=None + ): endpoint = ( f"/organizations/{organization_id}/chat_conversations/{chat_id}/completion" ) @@ -312,6 +303,9 @@ def send_message(self, organization_id, chat_id, prompt, timezone="UTC"): "attachments": [], "files": [], } + if model is not None: + data["model"] = model + response = self._make_request_stream("POST", endpoint, data) client = sseclient.SSEClient(response) for event in client.events(): diff --git a/src/claudesync/providers/base_provider.py b/src/claudesync/providers/base_provider.py index 2509073..a94f9d7 100644 --- a/src/claudesync/providers/base_provider.py +++ b/src/claudesync/providers/base_provider.py @@ -70,11 +70,40 @@ def delete_chat(self, organization_id, conversation_uuids): pass @abstractmethod - def create_chat(self, organization_id, chat_name="", project_uuid=None): - """Create a new chat conversation in the specified organization.""" + def create_chat(self, organization_id, chat_name="", project_uuid=None, model=None): + """ + Create a new chat conversation in the specified organization. + + Args: + organization_id (str): The UUID of the organization. + chat_name (str, optional): The name of the chat. Defaults to an empty string. + project_uuid (str, optional): The UUID of the project to associate the chat with. Defaults to None. + model (str, optional): The chat model to use. Defaults to None. + + Returns: + dict: The created chat conversation data. + + Raises: + ProviderError: If the chat creation fails. + """ pass @abstractmethod - def send_message(self, organization_id, chat_id, prompt, timezone="UTC"): - """Send a message to a specified chat conversation.""" + def send_message( + self, organization_id, chat_id, prompt, timezone="UTC", model=None + ): + """Send a message to a specified chat conversation. + + Args: + organization_id (str): The organization ID + chat_id (str): The chat conversation ID + prompt (str): The message to send + timezone (str, optional): The timezone. Defaults to "UTC" + model (str, optional): The model to use. If None, uses the default model. + Available models: + - None (default) + - claude-3-5-haiku-20241022 + - claude-3-opus-20240229 + - custom string entry + """ pass