Skip to content

Commit

Permalink
feat: support specifying chat model (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
jahwag authored Jan 11, 2025
1 parent 1845865 commit a723729
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 34 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "claudesync"
version = "0.6.9"
version = "0.7.0"
authors = [
{name = "Jahziah Wagner", email = "[email protected]"},
]
Expand Down
35 changes: 22 additions & 13 deletions src/claudesync/cli/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,16 @@ def delete_all_chats(provider, organization_id):
"""Delete all chats for the given organization."""
if click.confirm("Are you sure you want to delete all chats?"):
total_deleted = 0
with click.progressbar(length=100, label="Deleting chats") as bar:
while True:
chats = provider.get_chat_conversations(organization_id)
if not chats:
break
uuids_to_delete = [chat["uuid"] for chat in chats[:50]]
deleted, _ = delete_chats(provider, organization_id, uuids_to_delete)
total_deleted += deleted
bar.update(len(uuids_to_delete))
click.echo(f"Chat deletion complete. Total chats deleted: {total_deleted}")
with click.progressbar(length=100, label="Deleting chats") as bar:
while True:
chats = provider.get_chat_conversations(organization_id)
if not chats:
break
uuids_to_delete = [chat["uuid"] for chat in chats[:50]]
deleted, _ = delete_chats(provider, organization_id, uuids_to_delete)
total_deleted += deleted
bar.update(len(uuids_to_delete))
click.echo(f"Chat deletion complete. Total chats deleted: {total_deleted}")


def delete_single_chat(provider, organization_id):
Expand Down Expand Up @@ -184,9 +184,16 @@ def init(config, name, project):
@click.argument("message", nargs=-1, required=True)
@click.option("--chat", help="UUID of the chat to send the message to")
@click.option("--timezone", default="UTC", help="Timezone for the message")
@click.option(
"--model",
help="Model to use for the conversation. Available options:\n"
+ "- claude-3-5-haiku-20241022\n"
+ "- claude-3-opus-20240229\n"
+ "Or any custom model string. If not specified, uses the default model.",
)
@click.pass_obj
@handle_errors
def message(config, message, chat, timezone):
def message(config, message, chat, timezone, model):
"""Send a message to a specified chat or create a new chat and send the message."""
provider = validate_and_get_provider(config, require_project=True)
active_organization_id = config.get("active_organization_id")
Expand All @@ -203,13 +210,14 @@ def message(config, message, chat, timezone):
chat,
active_organization_id,
provider,
model,
)
if chat is None:
return

# Send message and process the streaming response
for event in provider.send_message(
active_organization_id, chat, message, timezone
active_organization_id, chat, message, timezone, model
):
if "completion" in event:
click.echo(event["completion"], nl=False)
Expand All @@ -235,6 +243,7 @@ def create_chat(
chat,
active_organization_id,
provider,
model,
):
if not chat:
if not active_project_name:
Expand All @@ -250,7 +259,7 @@ def create_chat(

# Create a new chat with the selected project
new_chat = provider.create_chat(
active_organization_id, project_uuid=active_project_id
active_organization_id, project_uuid=active_project_id, model=model
)
chat = new_chat["uuid"]
click.echo(f"New chat created with ID: {chat}")
Expand Down
26 changes: 10 additions & 16 deletions src/claudesync/providers/base_claude_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,26 +267,15 @@ def delete_chat(self, organization_id, conversation_uuids):
def _make_request(self, method, endpoint, data=None):
raise NotImplementedError("This method should be implemented by subclasses")

def create_chat(self, organization_id, chat_name="", project_uuid=None):
"""
Create a new chat conversation in the specified organization.
Args:
organization_id (str): The UUID of the organization.
chat_name (str, optional): The name of the chat. Defaults to an empty string.
project_uuid (str, optional): The UUID of the project to associate the chat with. Defaults to None.
Returns:
dict: The created chat conversation data.
Raises:
ProviderError: If the chat creation fails.
"""
def create_chat(self, organization_id, chat_name="", project_uuid=None, model=None):
data = {
"uuid": self._generate_uuid(),
"name": chat_name,
"project_uuid": project_uuid,
}
if model is not None:
data["model"] = model

return self._make_request(
"POST", f"/organizations/{organization_id}/chat_conversations", data
)
Expand All @@ -302,7 +291,9 @@ def _make_request_stream(self, method, endpoint, data=None):
# that can be used with sseclient
raise NotImplementedError("This method should be implemented by subclasses")

def send_message(self, organization_id, chat_id, prompt, timezone="UTC"):
def send_message(
self, organization_id, chat_id, prompt, timezone="UTC", model=None
):
endpoint = (
f"/organizations/{organization_id}/chat_conversations/{chat_id}/completion"
)
Expand All @@ -312,6 +303,9 @@ def send_message(self, organization_id, chat_id, prompt, timezone="UTC"):
"attachments": [],
"files": [],
}
if model is not None:
data["model"] = model

response = self._make_request_stream("POST", endpoint, data)
client = sseclient.SSEClient(response)
for event in client.events():
Expand Down
37 changes: 33 additions & 4 deletions src/claudesync/providers/base_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,40 @@ def delete_chat(self, organization_id, conversation_uuids):
pass

@abstractmethod
def create_chat(self, organization_id, chat_name="", project_uuid=None):
"""Create a new chat conversation in the specified organization."""
def create_chat(self, organization_id, chat_name="", project_uuid=None, model=None):
"""
Create a new chat conversation in the specified organization.
Args:
organization_id (str): The UUID of the organization.
chat_name (str, optional): The name of the chat. Defaults to an empty string.
project_uuid (str, optional): The UUID of the project to associate the chat with. Defaults to None.
model (str, optional): The chat model to use. Defaults to None.
Returns:
dict: The created chat conversation data.
Raises:
ProviderError: If the chat creation fails.
"""
pass

@abstractmethod
def send_message(self, organization_id, chat_id, prompt, timezone="UTC"):
"""Send a message to a specified chat conversation."""
def send_message(
self, organization_id, chat_id, prompt, timezone="UTC", model=None
):
"""Send a message to a specified chat conversation.
Args:
organization_id (str): The organization ID
chat_id (str): The chat conversation ID
prompt (str): The message to send
timezone (str, optional): The timezone. Defaults to "UTC"
model (str, optional): The model to use. If None, uses the default model.
Available models:
- None (default)
- claude-3-5-haiku-20241022
- claude-3-opus-20240229
- custom string entry
"""
pass

0 comments on commit a723729

Please sign in to comment.