Skip to content

Commit

Permalink
Merge branch 'master' into adrian/trace_sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
adrnswanberg authored Dec 11, 2024
2 parents 0cd0fc2 + 0b2d99c commit a720f05
Show file tree
Hide file tree
Showing 9 changed files with 488 additions and 134 deletions.
230 changes: 230 additions & 0 deletions tests/trace_server/test_clickhouse_trace_server_migrator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
import types
from unittest.mock import Mock, call, patch

import pytest

from weave.trace_server import clickhouse_trace_server_migrator as trace_server_migrator
from weave.trace_server.clickhouse_trace_server_migrator import MigrationError


@pytest.fixture
def mock_costs():
with patch(
"weave.trace_server.costs.insert_costs.should_insert_costs", return_value=False
) as mock_should_insert:
with patch(
"weave.trace_server.costs.insert_costs.get_current_costs", return_value=[]
) as mock_get_costs:
yield


@pytest.fixture
def migrator():
ch_client = Mock()
migrator = trace_server_migrator.ClickHouseTraceServerMigrator(ch_client)
migrator._get_migration_status = Mock()
migrator._get_migrations = Mock()
migrator._determine_migrations_to_apply = Mock()
migrator._update_migration_status = Mock()
ch_client.command.reset_mock()
return migrator


def test_apply_migrations_with_target_version(mock_costs, migrator, tmp_path):
# Setup
migrator._get_migration_status.return_value = {
"curr_version": 1,
"partially_applied_version": None,
}
migrator._get_migrations.return_value = {
"1": {"up": "1.up.sql", "down": "1.down.sql"},
"2": {"up": "2.up.sql", "down": "2.down.sql"},
}
migrator._determine_migrations_to_apply.return_value = [(2, "2.up.sql")]

# Create a temporary migration file
migration_dir = tmp_path / "migrations"
migration_dir.mkdir()
migration_file = migration_dir / "2.up.sql"
migration_file.write_text(
"CREATE TABLE test1 (id Int32);\nCREATE TABLE test2 (id Int32);"
)

# Mock the migration directory path
with patch("os.path.dirname") as mock_dirname:
mock_dirname.return_value = str(tmp_path)

# Execute
migrator.apply_migrations("test_db", target_version=2)

# Verify
migrator._get_migration_status.assert_called_once_with("test_db")
migrator._get_migrations.assert_called_once()
migrator._determine_migrations_to_apply.assert_called_once_with(
1, migrator._get_migrations.return_value, 2
)

# Verify migration execution
assert migrator._update_migration_status.call_count == 2
migrator._update_migration_status.assert_has_calls(
[call("test_db", 2, is_start=True), call("test_db", 2, is_start=False)]
)

# Verify the actual SQL commands were executed
ch_client = migrator.ch_client
assert ch_client.command.call_count == 2
ch_client.command.assert_has_calls(
[call("CREATE TABLE test1 (id Int32)"), call("CREATE TABLE test2 (id Int32)")]
)


def test_execute_migration_command(mock_costs, migrator):
# Setup
ch_client = migrator.ch_client
ch_client.database = "original_db"

# Execute
migrator._execute_migration_command("test_db", "CREATE TABLE test (id Int32)")

# Verify
assert ch_client.database == "original_db" # Should restore original database
ch_client.command.assert_called_once_with("CREATE TABLE test (id Int32)")


def test_migration_replicated(mock_costs, migrator):
ch_client = migrator.ch_client
orig = "CREATE TABLE test (id String, project_id String) ENGINE = MergeTree ORDER BY (project_id, id);"
migrator._execute_migration_command("test_db", orig)
ch_client.command.assert_called_once_with(orig)


def test_update_migration_status(mock_costs, migrator):
# Don't mock _update_migration_status for this test
migrator._update_migration_status = types.MethodType(
trace_server_migrator.ClickHouseTraceServerMigrator._update_migration_status,
migrator,
)

# Test start of migration
migrator._update_migration_status("test_db", 2, is_start=True)
migrator.ch_client.command.assert_called_with(
"ALTER TABLE db_management.migrations UPDATE partially_applied_version = 2 WHERE db_name = 'test_db'"
)

# Test end of migration
migrator._update_migration_status("test_db", 2, is_start=False)
migrator.ch_client.command.assert_called_with(
"ALTER TABLE db_management.migrations UPDATE curr_version = 2, partially_applied_version = NULL WHERE db_name = 'test_db'"
)


def test_is_safe_identifier(mock_costs, migrator):
# Valid identifiers
assert migrator._is_safe_identifier("test_db")
assert migrator._is_safe_identifier("my_db123")
assert migrator._is_safe_identifier("db.table")

# Invalid identifiers
assert not migrator._is_safe_identifier("test-db")
assert not migrator._is_safe_identifier("db;")
assert not migrator._is_safe_identifier("db'name")
assert not migrator._is_safe_identifier("db/*")


def test_create_db_sql_validation(mock_costs, migrator):
# Test invalid database name
with pytest.raises(MigrationError, match="Invalid database name"):
migrator._create_db_sql("test;db")

# Test replicated mode with invalid values
migrator.replicated = True
migrator.replicated_cluster = "test;cluster"
with pytest.raises(MigrationError, match="Invalid cluster name"):
migrator._create_db_sql("test_db")

migrator.replicated_cluster = "test_cluster"
migrator.replicated_path = "/clickhouse/bad;path/{db}"
with pytest.raises(MigrationError, match="Invalid replicated path"):
migrator._create_db_sql("test_db")


def test_create_db_sql_non_replicated(mock_costs, migrator):
# Test non-replicated mode
migrator.replicated = False
sql = migrator._create_db_sql("test_db")
assert sql.strip() == "CREATE DATABASE IF NOT EXISTS test_db"


def test_create_db_sql_replicated(mock_costs, migrator):
# Test replicated mode
migrator.replicated = True
migrator.replicated_path = "/clickhouse/tables/{db}"
migrator.replicated_cluster = "test_cluster"

sql = migrator._create_db_sql("test_db")
expected = """
CREATE DATABASE IF NOT EXISTS test_db ON CLUSTER test_cluster ENGINE=Replicated('/clickhouse/tables/test_db', '{shard}', '{replica}')
""".strip()
assert sql.strip() == expected


def test_format_replicated_sql_non_replicated(mock_costs, migrator):
# Test that SQL is unchanged when replicated=False
migrator.replicated = False
test_cases = [
"CREATE TABLE test (id Int32) ENGINE = MergeTree",
"CREATE TABLE test (id Int32) ENGINE = SummingMergeTree",
"CREATE TABLE test (id Int32) ENGINE=ReplacingMergeTree",
]

for sql in test_cases:
assert migrator._format_replicated_sql(sql) == sql


def test_format_replicated_sql_replicated(mock_costs, migrator):
# Test that MergeTree engines are converted to Replicated variants
migrator.replicated = True

test_cases = [
(
"CREATE TABLE test (id Int32) ENGINE = MergeTree",
"CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree",
),
(
"CREATE TABLE test (id Int32) ENGINE = SummingMergeTree",
"CREATE TABLE test (id Int32) ENGINE = ReplicatedSummingMergeTree",
),
(
"CREATE TABLE test (id Int32) ENGINE=ReplacingMergeTree",
"CREATE TABLE test (id Int32) ENGINE = ReplicatedReplacingMergeTree",
),
# Test with extra whitespace
(
"CREATE TABLE test (id Int32) ENGINE = MergeTree",
"CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree",
),
# Test with parameters
(
"CREATE TABLE test (id Int32) ENGINE = MergeTree()",
"CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree()",
),
]

for input_sql, expected_sql in test_cases:
assert migrator._format_replicated_sql(input_sql) == expected_sql


def test_format_replicated_sql_non_mergetree(mock_costs, migrator):
# Test that non-MergeTree engines are left unchanged
migrator.replicated = True

test_cases = [
"CREATE TABLE test (id Int32) ENGINE = Memory",
"CREATE TABLE test (id Int32) ENGINE = Log",
"CREATE TABLE test (id Int32) ENGINE = TinyLog",
# This should not be changed as it's not a complete word match
"CREATE TABLE test (id Int32) ENGINE = MyMergeTreeCustom",
]

for sql in test_cases:
assert migrator._format_replicated_sql(sql) == sql
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import Box from '@mui/material/Box';
import {useViewerInfo} from '@wandb/weave/common/hooks/useViewerInfo';
import {Loading} from '@wandb/weave/components/Loading';
import {urlPrefixed} from '@wandb/weave/config';
import {useViewTraceEvent} from '@wandb/weave/integrations/analytics/useViewEvents';
Expand Down Expand Up @@ -71,8 +70,10 @@ export const CallPage: FC<{
};

export const useShowRunnableUI = () => {
const viewerInfo = useViewerInfo();
return viewerInfo.loading ? false : viewerInfo.userInfo?.admin;
return false;
// Uncomment to re-enable
// const viewerInfo = useViewerInfo();
// return viewerInfo.loading ? false : viewerInfo.userInfo?.admin;
};

const useCallTabs = (call: CallSchema) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,21 @@ import {Choice} from './types';
type ChoiceViewProps = {
choice: Choice;
isStructuredOutput?: boolean;
isNested?: boolean;
};

export const ChoiceView = ({choice, isStructuredOutput}: ChoiceViewProps) => {
export const ChoiceView = ({
choice,
isStructuredOutput,
isNested,
}: ChoiceViewProps) => {
const {message} = choice;
return (
<MessagePanel
index={choice.index}
message={message}
isStructuredOutput={isStructuredOutput}
isNested={isNested}
isChoice
/>
);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import {Box, Drawer} from '@mui/material';
import {MOON_200} from '@wandb/weave/common/css/color.styles';
import {Tag} from '@wandb/weave/components/Tag';
import {Tailwind} from '@wandb/weave/components/Tailwind';
import React from 'react';

import {Button} from '../../../../../Button';
import {ChoiceView} from './ChoiceView';
import {Choice} from './types';

type ChoicesDrawerProps = {
choices: Choice[];
isStructuredOutput?: boolean;
isDrawerOpen: boolean;
setIsDrawerOpen: React.Dispatch<React.SetStateAction<boolean>>;
selectedChoiceIndex: number;
setSelectedChoiceIndex: (choiceIndex: number) => void;
};

export const ChoicesDrawer = ({
choices,
isStructuredOutput,
isDrawerOpen,
setIsDrawerOpen,
selectedChoiceIndex,
setSelectedChoiceIndex,
}: ChoicesDrawerProps) => {
return (
<Drawer
open={isDrawerOpen}
onClose={() => setIsDrawerOpen(false)}
title="Choices"
anchor="right"
sx={{
'& .MuiDrawer-paper': {mt: '60px', width: '400px'},
}}>
<Box
sx={{
position: 'sticky',
top: 0,
zIndex: 1,
px: 2,
height: 44,
width: '100%',
borderBottom: `1px solid ${MOON_200}`,
display: 'flex',
flexDirection: 'row',
alignItems: 'center',
justifyContent: 'space-between',
}}>
<Box
sx={{
height: 44,
display: 'flex',
alignItems: 'center',
fontWeight: 600,
fontSize: '1.25rem',
}}>
Responses
</Box>
<Button
size="medium"
variant="ghost"
icon="close"
onClick={() => setIsDrawerOpen(false)}
tooltip="Close"
/>
</Box>
<Tailwind>
<div className="flex flex-col p-12">
{choices.map((c, index) => (
<div key={c.index}>
<div className="flex items-center gap-4 font-semibold">
<Tag color="moon" label={`Response ${index + 1}`} />
{index === selectedChoiceIndex ? (
<Button
className="text-green-500"
size="small"
variant="ghost"
icon="checkmark">
<span className="text-moon-500">Response selected</span>
</Button>
) : (
<Button
size="small"
variant="ghost"
icon="boolean"
onClick={() => setSelectedChoiceIndex(index)}>
<span className="text-moon-500">Select response</span>
</Button>
)}
</div>
<ChoiceView
choice={c}
isStructuredOutput={isStructuredOutput}
isNested
/>
</div>
))}
</div>
</Tailwind>
</Drawer>
);
};
Loading

0 comments on commit a720f05

Please sign in to comment.