From 0b23119f832eae056c472f1ed258893e92ce8179 Mon Sep 17 00:00:00 2001 From: Jahziah Wagner Date: Thu, 1 Aug 2024 22:57:03 +0200 Subject: [PATCH] fixup! Sync chats and artifacts --- .gitignore | 3 +- .../metadata.json | 20 +++ src/claudesync/chat_sync.py | 155 +++++++++--------- src/claudesync/cli/chat.py | 2 - src/claudesync/config_manager.py | 8 +- src/claudesync/providers/base_claude_ai.py | 2 +- tests/test_chat_sync.py | 71 ++++++++ 7 files changed, 177 insertions(+), 84 deletions(-) create mode 100644 chats/57506f44-16c7-4a0d-aac8-43f4f770c6e7/metadata.json create mode 100644 tests/test_chat_sync.py diff --git a/.gitignore b/.gitignore index b1b513e..7d8bcf8 100644 --- a/.gitignore +++ b/.gitignore @@ -168,4 +168,5 @@ __pycache__ # claude claude.sync config.json -claudesync.log \ No newline at end of file +claudesync.log +chats \ No newline at end of file diff --git a/chats/57506f44-16c7-4a0d-aac8-43f4f770c6e7/metadata.json b/chats/57506f44-16c7-4a0d-aac8-43f4f770c6e7/metadata.json new file mode 100644 index 0000000..5c6d832 --- /dev/null +++ b/chats/57506f44-16c7-4a0d-aac8-43f4f770c6e7/metadata.json @@ -0,0 +1,20 @@ +{ + "uuid": "57506f44-16c7-4a0d-aac8-43f4f770c6e7", + "name": "", + "summary": "", + "model": null, + "created_at": "2024-08-01T12:18:07.059322+00:00", + "updated_at": "2024-08-01T12:18:07.059322+00:00", + "settings": { + "preview_feature_uses_artifacts": true, + "preview_feature_uses_latex": null, + "preview_feature_uses_citations": null + }, + "is_starred": false, + "project_uuid": "5bdc9859-67b1-4024-ae16-67b7375097f5", + "current_leaf_message_uuid": null, + "project": { + "uuid": "5bdc9859-67b1-4024-ae16-67b7375097f5", + "name": "claudesync" + } +} \ No newline at end of file diff --git a/src/claudesync/chat_sync.py b/src/claudesync/chat_sync.py index f753137..1556e3c 100644 --- a/src/claudesync/chat_sync.py +++ b/src/claudesync/chat_sync.py @@ -1,6 +1,8 @@ import os import json import logging +import re + from tqdm import tqdm from .config_manager import ConfigManager from .exceptions import ConfigurationError @@ -8,7 +10,7 @@ logger = logging.getLogger(__name__) -def sync_chats(provider, config): +def sync_chats(provider, config, sync_all=False): """ Synchronize chats and their artifacts from the remote source. @@ -18,19 +20,22 @@ def sync_chats(provider, config): Args: provider: The API provider instance. config: The configuration manager instance. + sync_all (bool): If True, sync all chats regardless of project. If False, only sync chats for the active project. Raises: ConfigurationError: If required configuration settings are missing. """ - # Get the configured destinations for chats and artifacts - chat_destination = config.get("chat_destination") - artifact_destination = config.get("artifact_destination") - if not chat_destination or not artifact_destination: + # Get the local_path for chats + local_path = config.get("local_path") + if not local_path: raise ConfigurationError( - "Chat or artifact destination not set. Use 'claudesync config set chat_destination ' and " - "'claudesync config set artifact_destination ' to set them." + "Local path not set. Use 'claudesync project select' or 'claudesync project create' to set it." ) + # Create chats directory within local_path + chat_destination = os.path.join(local_path, "chats") + os.makedirs(chat_destination, exist_ok=True) + # Get the active organization ID organization_id = config.get("active_organization_id") if not organization_id: @@ -38,51 +43,66 @@ def sync_chats(provider, config): "No active organization set. Please select an organization." ) + # Get the active project ID + active_project_id = config.get("active_project_id") + if not active_project_id and not sync_all: + raise ConfigurationError( + "No active project set. Please select a project or use the -a flag to sync all chats." + ) + # Fetch all chats for the organization - logger.info(f"Fetching chats for organization {organization_id}") + logger.debug(f"Fetching chats for organization {organization_id}") chats = provider.get_chat_conversations(organization_id) - logger.info(f"Found {len(chats)} chats") + logger.debug(f"Found {len(chats)} chats") # Process each chat for chat in tqdm(chats, desc="Syncing chats"): - logger.info(f"Processing chat {chat['uuid']}") - chat_folder = os.path.join(chat_destination, chat["uuid"]) - os.makedirs(chat_folder, exist_ok=True) - - # Save chat metadata - with open(os.path.join(chat_folder, "metadata.json"), "w") as f: - json.dump(chat, f, indent=2) - - # Fetch full chat conversation - logger.info(f"Fetching full conversation for chat {chat['uuid']}") - full_chat = provider.get_chat_conversation(organization_id, chat["uuid"]) - - # Process each message in the chat - for message in full_chat["chat_messages"]: - # Save the message - message_file = os.path.join(chat_folder, f"{message['uuid']}.json") - with open(message_file, "w") as f: - json.dump(message, f, indent=2) - - # Handle artifacts in assistant messages - if message["sender"] == "assistant": - artifacts = extract_artifacts(message["text"]) - if artifacts: - logger.info( - f"Found {len(artifacts)} artifacts in message {message['uuid']}" - ) - for artifact in artifacts: - # Save each artifact - artifact_file = os.path.join( - artifact_destination, - f"{artifact['identifier']}.{get_file_extension(artifact['type'])}", + # Check if the chat belongs to the active project or if we're syncing all chats + if sync_all or ( + chat.get("project") and chat["project"].get("uuid") == active_project_id + ): + logger.info(f"Processing chat {chat['uuid']}") + chat_folder = os.path.join(chat_destination, chat["uuid"]) + os.makedirs(chat_folder, exist_ok=True) + + # Save chat metadata + with open(os.path.join(chat_folder, "metadata.json"), "w") as f: + json.dump(chat, f, indent=2) + + # Fetch full chat conversation + logger.debug(f"Fetching full conversation for chat {chat['uuid']}") + full_chat = provider.get_chat_conversation(organization_id, chat["uuid"]) + + # Process each message in the chat + for message in full_chat["chat_messages"]: + # Save the message + message_file = os.path.join(chat_folder, f"{message['uuid']}.json") + with open(message_file, "w") as f: + json.dump(message, f, indent=2) + + # Handle artifacts in assistant messages + if message["sender"] == "assistant": + artifacts = extract_artifacts(message["text"]) + if artifacts: + logger.info( + f"Found {len(artifacts)} artifacts in message {message['uuid']}" ) - os.makedirs(os.path.dirname(artifact_file), exist_ok=True) - with open(artifact_file, "w") as f: - f.write(artifact["content"]) - - logger.info(f"Chats synchronized to {chat_destination}") - logger.info(f"Artifacts synchronized to {artifact_destination}") + artifact_folder = os.path.join(chat_folder, "artifacts") + os.makedirs(artifact_folder, exist_ok=True) + for artifact in artifacts: + # Save each artifact + artifact_file = os.path.join( + artifact_folder, + f"{artifact['identifier']}.{get_file_extension(artifact['type'])}", + ) + with open(artifact_file, "w") as f: + f.write(artifact["content"]) + else: + logger.debug( + f"Skipping chat {chat['uuid']} as it doesn't belong to the active project" + ) + + logger.debug(f"Chats and artifacts synchronized to {chat_destination}") def get_file_extension(artifact_type): @@ -119,39 +139,24 @@ def extract_artifacts(text): list: A list of dictionaries containing artifact information. """ artifacts = [] - start_tag = '' - while start_tag in text: - start = text.index(start_tag) - end = text.index(end_tag, start) + len(end_tag) + # Regular expression to match the tags and extract their attributes and content + pattern = re.compile( + r'([\s\S]*?)', + re.MULTILINE, + ) - artifact_text = text[start:end] - identifier = extract_attribute(artifact_text, "identifier") - artifact_type = extract_attribute(artifact_text, "type") - content = artifact_text[ - artifact_text.index(">") + 1 : artifact_text.rindex("<") - ] + # Find all matches in the text + matches = pattern.findall(text) + for match in matches: + identifier, artifact_type, title, content = match artifacts.append( - {"identifier": identifier, "type": artifact_type, "content": content} + { + "identifier": identifier, + "type": artifact_type, + "content": content.strip(), + } ) - text = text[end:] - return artifacts - - -def extract_attribute(text, attribute): - """ - Extract the value of a specific attribute from an XML-like tag. - - Args: - text (str): The XML-like tag text. - attribute (str): The name of the attribute to extract. - - Returns: - str: The value of the specified attribute. - """ - start = text.index(f'{attribute}="') + len(f'{attribute}="') - end = text.index('"', start) - return text[start:end] diff --git a/src/claudesync/cli/chat.py b/src/claudesync/cli/chat.py index 737c585..37c7dd2 100644 --- a/src/claudesync/cli/chat.py +++ b/src/claudesync/cli/chat.py @@ -1,7 +1,5 @@ import click import logging -import json -from tqdm import tqdm from ..exceptions import ProviderError from ..utils import handle_errors, validate_and_get_provider from ..chat_sync import sync_chats diff --git a/src/claudesync/config_manager.py b/src/claudesync/config_manager.py index dd1c474..469e80c 100644 --- a/src/claudesync/config_manager.py +++ b/src/claudesync/config_manager.py @@ -40,8 +40,6 @@ def _get_default_config(self): "upload_delay": 0.5, "max_file_size": 32 * 1024, # Default 32 KB "two_way_sync": False, # Default to False - "chat_destination": str(Path.home() / ".claudesync" / "chats"), - "artifact_destination": str(Path.home() / ".claudesync" / "artifacts"), } def _load_config(self): @@ -64,7 +62,7 @@ def _load_config(self): for key, value in defaults.items(): if key not in config: config[key] = value - elif key in ["chat_destination", "artifact_destination"]: + elif key == "chat_destination": # Expand user home directory for path-based settings config[key] = str(Path(config[key]).expanduser()) return config @@ -96,7 +94,7 @@ def set(self, key, value): """ Sets a configuration value and saves the configuration. - For path-based settings (chat_destination and artifact_destination), this method expands the user's home directory. + For path-based settings (chat_destination), this method expands the user's home directory. Args: key (str): The key for the configuration setting to set. @@ -104,7 +102,7 @@ def set(self, key, value): This method updates the configuration with the provided key-value pair and then saves the configuration to the file. """ - if key in ["chat_destination", "artifact_destination"]: + if key == "chat_destination": # Expand user home directory for path-based settings value = str(Path(value).expanduser()) self.config[key] = value diff --git a/src/claudesync/providers/base_claude_ai.py b/src/claudesync/providers/base_claude_ai.py index f08a3e4..170b54b 100644 --- a/src/claudesync/providers/base_claude_ai.py +++ b/src/claudesync/providers/base_claude_ai.py @@ -101,7 +101,7 @@ def get_published_artifacts(self, organization_id): def get_chat_conversation(self, organization_id, conversation_id): return self._make_request( "GET", - f"/organizations/{organization_id}/chat_conversations/{conversation_id}", + f"/organizations/{organization_id}/chat_conversations/{conversation_id}?rendering_mode=raw", ) def get_artifact_content(self, organization_id, artifact_uuid): diff --git a/tests/test_chat_sync.py b/tests/test_chat_sync.py new file mode 100644 index 0000000..f3ae1e7 --- /dev/null +++ b/tests/test_chat_sync.py @@ -0,0 +1,71 @@ +import textwrap +import unittest + +from claudesync.chat_sync import extract_artifacts + + +class TestExtractArtifacts(unittest.TestCase): + + def test_extract_single_artifact(self): + text = """ + Here is some introductory text. + + + + Test + Test Content + + + + Some concluding text. + """ + expected_result = [ + { + "identifier": "test-id", + "type": "text/html", + "content": "\nTest\nTest Content\n", + } + ] + self.assertEqual(extract_artifacts(textwrap.dedent(text)), expected_result) + + def test_extract_multiple_artifacts(self): + text = """ + Here is some introductory text. + + + First artifact content. + + + Some middle text. + + + + User + ChatGPT + Reminder + Don't forget to check your email! + + + + Some concluding text. + """ + expected_result = [ + { + "identifier": "first-id", + "type": "text/plain", + "content": "First artifact content.", + }, + { + "identifier": "second-id", + "type": "text/xml", + "content": "\nUser\nChatGPT\nReminder\nDon't forget to check your email!\n", + }, + ] + self.assertEqual(extract_artifacts(textwrap.dedent(text)), expected_result) + + def test_no_artifacts(self): + text = """ + Here is some text without any artifacts. + """ + expected_result = [] + self.assertEqual(extract_artifacts(text), expected_result)