Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jahwag committed Jul 27, 2024
1 parent aa99239 commit 859db15
Show file tree
Hide file tree
Showing 2 changed files with 197 additions and 21 deletions.
158 changes: 158 additions & 0 deletions tests/providers/test_claude_ai_curl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import unittest
from unittest.mock import patch, MagicMock
import json
import subprocess
from claudesync.providers.claude_ai_curl import ClaudeAICurlProvider
from claudesync.exceptions import ProviderError


class TestClaudeAICurlProvider(unittest.TestCase):

def setUp(self):
self.provider = ClaudeAICurlProvider("test_session_key")

@patch("subprocess.run")
def test_execute_curl_success(self, mock_run):
mock_result = MagicMock()
mock_result.stdout = '{"key": "value"}'
mock_result.returncode = 0
mock_run.return_value = mock_result

result = self.provider._execute_curl("GET", "/test")

self.assertEqual(result, {"key": "value"})
mock_run.assert_called_once()

@patch("subprocess.run")
def test_execute_curl_failure(self, mock_run):
mock_run.side_effect = subprocess.CalledProcessError(1, "curl", stderr="Test error")

with self.assertRaises(ProviderError):
self.provider._execute_curl("GET", "/test")

@patch("claudesync.providers.claude_ai_curl.click.prompt")
def test_login(self, mock_prompt):
mock_prompt.return_value = "new_session_key"

result = self.provider.login()

self.assertEqual(result, "new_session_key")
self.assertEqual(self.provider.session_key, "new_session_key")

@patch("claudesync.providers.claude_ai_curl.ClaudeAICurlProvider._execute_curl")
def test_get_organizations(self, mock_execute_curl):
mock_execute_curl.return_value = [
{"uuid": "org1", "name": "Org 1"},
{"uuid": "org2", "name": "Org 2"},
]

result = self.provider.get_organizations()

expected = [{"id": "org1", "name": "Org 1"}, {"id": "org2", "name": "Org 2"}]
self.assertEqual(result, expected)

@patch("claudesync.providers.claude_ai_curl.ClaudeAICurlProvider._execute_curl")
def test_get_projects(self, mock_execute_curl):
mock_execute_curl.return_value = [
{"uuid": "proj1", "name": "Project 1", "archived_at": None},
{"uuid": "proj2", "name": "Project 2", "archived_at": "2023-01-01"},
]

result = self.provider.get_projects("org1", include_archived=True)

expected = [
{"id": "proj1", "name": "Project 1", "archived_at": None},
{"id": "proj2", "name": "Project 2", "archived_at": "2023-01-01"},
]
self.assertEqual(result, expected)

@patch("claudesync.providers.claude_ai_curl.ClaudeAICurlProvider._execute_curl")
def test_list_files(self, mock_execute_curl):
mock_execute_curl.return_value = [
{
"uuid": "file1",
"file_name": "test1.txt",
"content": "content1",
"created_at": "2023-01-01",
},
{
"uuid": "file2",
"file_name": "test2.txt",
"content": "content2",
"created_at": "2023-01-02",
},
]

result = self.provider.list_files("org1", "proj1")

expected = [
{
"uuid": "file1",
"file_name": "test1.txt",
"content": "content1",
"created_at": "2023-01-01",
},
{
"uuid": "file2",
"file_name": "test2.txt",
"content": "content2",
"created_at": "2023-01-02",
},
]
self.assertEqual(result, expected)

@patch("claudesync.providers.claude_ai_curl.ClaudeAICurlProvider._execute_curl")
def test_upload_file(self, mock_execute_curl):
mock_execute_curl.return_value = {"uuid": "new_file", "file_name": "test.txt"}

result = self.provider.upload_file("org1", "proj1", "test.txt", "content")

self.assertEqual(result, {"uuid": "new_file", "file_name": "test.txt"})
mock_execute_curl.assert_called_once_with(
"POST",
"/organizations/org1/projects/proj1/docs",
{"file_name": "test.txt", "content": "content"},
)

@patch("claudesync.providers.claude_ai_curl.ClaudeAICurlProvider._execute_curl")
def test_delete_file(self, mock_execute_curl):
mock_execute_curl.return_value = {"status": "deleted"}

result = self.provider.delete_file("org1", "proj1", "file1")

self.assertEqual(result, {"status": "deleted"})
mock_execute_curl.assert_called_once_with(
"DELETE", "/organizations/org1/projects/proj1/docs/file1"
)

@patch("claudesync.providers.claude_ai_curl.ClaudeAICurlProvider._execute_curl")
def test_archive_project(self, mock_execute_curl):
mock_execute_curl.return_value = {"uuid": "proj1", "is_archived": True}

result = self.provider.archive_project("org1", "proj1")

self.assertEqual(result, {"uuid": "proj1", "is_archived": True})
mock_execute_curl.assert_called_once_with(
"PUT", "/organizations/org1/projects/proj1", {"is_archived": True}
)

@patch("claudesync.providers.claude_ai_curl.ClaudeAICurlProvider._execute_curl")
def test_create_project(self, mock_execute_curl):
mock_execute_curl.return_value = {"uuid": "new_proj", "name": "New Project"}

result = self.provider.create_project("org1", "New Project", "Description")

self.assertEqual(result, {"uuid": "new_proj", "name": "New Project"})
mock_execute_curl.assert_called_once_with(
"POST",
"/organizations/org1/projects",
{
"name": "New Project",
"description": "Description",
"is_private": True,
},
)


if __name__ == "__main__":
unittest.main()
60 changes: 39 additions & 21 deletions tests/test_provider_factory.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,59 @@
import unittest
from unittest.mock import patch
from unittest.mock import patch, MagicMock
import pytest
from claudesync.provider_factory import get_provider
from claudesync.providers.claude_ai import ClaudeAIProvider
from claudesync.providers.claude_ai_curl import ClaudeAICurlProvider

class TestProviderFactory:

class TestProviderFactory(unittest.TestCase):

def test_get_provider_list(self):
@pytest.mark.parametrize("provider_name", ["claude.ai", "claude.ai-curl"])
def test_get_provider_list(self, provider_name):
# Test that get_provider returns a list of available providers when called without arguments
providers = get_provider()
self.assertIsInstance(providers, list)
self.assertIn("claude.ai", providers)

def test_get_provider_claude_ai(self):
# Test that get_provider returns a ClaudeAIProvider instance for "claude.ai"
provider = get_provider("claude.ai")
self.assertIsInstance(provider, ClaudeAIProvider)

def test_get_provider_with_session_key(self):
assert isinstance(providers, list)
assert provider_name in providers

@pytest.mark.parametrize("provider_name, expected_class", [
("claude.ai", ClaudeAIProvider),
("claude.ai-curl", ClaudeAICurlProvider)
])
def test_get_provider_instance(self, provider_name, expected_class):
# Test that get_provider returns the correct provider instance
provider = get_provider(provider_name)
assert isinstance(provider, expected_class)

@pytest.mark.parametrize("provider_name, expected_class", [
("claude.ai", ClaudeAIProvider),
("claude.ai-curl", ClaudeAICurlProvider)
])
def test_get_provider_with_session_key(self, provider_name, expected_class):
# Test that get_provider returns a provider instance with a session key
session_key = "test_session_key"
provider = get_provider("claude.ai", session_key)
self.assertIsInstance(provider, ClaudeAIProvider)
self.assertEqual(provider.session_key, session_key)
provider = get_provider(provider_name, session_key)
assert isinstance(provider, expected_class)
assert provider.session_key == session_key

def test_get_provider_unknown(self):
# Test that get_provider raises a ValueError for an unknown provider
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
get_provider("unknown_provider")

@pytest.mark.parametrize("provider_name, expected_class", [
("claude.ai", ClaudeAIProvider),
("claude.ai-curl", ClaudeAICurlProvider)
])
@patch("claudesync.provider_factory.ClaudeAIProvider")
def test_get_provider_calls_constructor(self, mock_claude_ai_provider):
@patch("claudesync.provider_factory.ClaudeAICurlProvider")
def test_get_provider_calls_constructor(self, mock_claude_ai_curl_provider, mock_claude_ai_provider, provider_name, expected_class):
# Test that get_provider calls the provider's constructor
session_key = "test_session_key"
get_provider("claude.ai", session_key)
mock_claude_ai_provider.assert_called_once_with(session_key)
get_provider(provider_name, session_key)

if provider_name == "claude.ai":
mock_claude_ai_provider.assert_called_once_with(session_key)
else:
mock_claude_ai_curl_provider.assert_called_once_with(session_key)

if __name__ == "__main__":
unittest.main()
pytest.main()

0 comments on commit 859db15

Please sign in to comment.