diff --git a/tests/trace_server/test_clickhouse_trace_server_migrator.py b/tests/trace_server/test_clickhouse_trace_server_migrator.py
new file mode 100644
index 00000000000..3a6a92f2479
--- /dev/null
+++ b/tests/trace_server/test_clickhouse_trace_server_migrator.py
@@ -0,0 +1,230 @@
+import types
+from unittest.mock import Mock, call, patch
+
+import pytest
+
+from weave.trace_server import clickhouse_trace_server_migrator as trace_server_migrator
+from weave.trace_server.clickhouse_trace_server_migrator import MigrationError
+
+
+@pytest.fixture
+def mock_costs():
+ with patch(
+ "weave.trace_server.costs.insert_costs.should_insert_costs", return_value=False
+ ) as mock_should_insert:
+ with patch(
+ "weave.trace_server.costs.insert_costs.get_current_costs", return_value=[]
+ ) as mock_get_costs:
+ yield
+
+
+@pytest.fixture
+def migrator():
+ ch_client = Mock()
+ migrator = trace_server_migrator.ClickHouseTraceServerMigrator(ch_client)
+ migrator._get_migration_status = Mock()
+ migrator._get_migrations = Mock()
+ migrator._determine_migrations_to_apply = Mock()
+ migrator._update_migration_status = Mock()
+ ch_client.command.reset_mock()
+ return migrator
+
+
+def test_apply_migrations_with_target_version(mock_costs, migrator, tmp_path):
+ # Setup
+ migrator._get_migration_status.return_value = {
+ "curr_version": 1,
+ "partially_applied_version": None,
+ }
+ migrator._get_migrations.return_value = {
+ "1": {"up": "1.up.sql", "down": "1.down.sql"},
+ "2": {"up": "2.up.sql", "down": "2.down.sql"},
+ }
+ migrator._determine_migrations_to_apply.return_value = [(2, "2.up.sql")]
+
+ # Create a temporary migration file
+ migration_dir = tmp_path / "migrations"
+ migration_dir.mkdir()
+ migration_file = migration_dir / "2.up.sql"
+ migration_file.write_text(
+ "CREATE TABLE test1 (id Int32);\nCREATE TABLE test2 (id Int32);"
+ )
+
+ # Mock the migration directory path
+ with patch("os.path.dirname") as mock_dirname:
+ mock_dirname.return_value = str(tmp_path)
+
+ # Execute
+ migrator.apply_migrations("test_db", target_version=2)
+
+ # Verify
+ migrator._get_migration_status.assert_called_once_with("test_db")
+ migrator._get_migrations.assert_called_once()
+ migrator._determine_migrations_to_apply.assert_called_once_with(
+ 1, migrator._get_migrations.return_value, 2
+ )
+
+ # Verify migration execution
+ assert migrator._update_migration_status.call_count == 2
+ migrator._update_migration_status.assert_has_calls(
+ [call("test_db", 2, is_start=True), call("test_db", 2, is_start=False)]
+ )
+
+ # Verify the actual SQL commands were executed
+ ch_client = migrator.ch_client
+ assert ch_client.command.call_count == 2
+ ch_client.command.assert_has_calls(
+ [call("CREATE TABLE test1 (id Int32)"), call("CREATE TABLE test2 (id Int32)")]
+ )
+
+
+def test_execute_migration_command(mock_costs, migrator):
+ # Setup
+ ch_client = migrator.ch_client
+ ch_client.database = "original_db"
+
+ # Execute
+ migrator._execute_migration_command("test_db", "CREATE TABLE test (id Int32)")
+
+ # Verify
+ assert ch_client.database == "original_db" # Should restore original database
+ ch_client.command.assert_called_once_with("CREATE TABLE test (id Int32)")
+
+
+def test_migration_replicated(mock_costs, migrator):
+ ch_client = migrator.ch_client
+ orig = "CREATE TABLE test (id String, project_id String) ENGINE = MergeTree ORDER BY (project_id, id);"
+ migrator._execute_migration_command("test_db", orig)
+ ch_client.command.assert_called_once_with(orig)
+
+
+def test_update_migration_status(mock_costs, migrator):
+ # Don't mock _update_migration_status for this test
+ migrator._update_migration_status = types.MethodType(
+ trace_server_migrator.ClickHouseTraceServerMigrator._update_migration_status,
+ migrator,
+ )
+
+ # Test start of migration
+ migrator._update_migration_status("test_db", 2, is_start=True)
+ migrator.ch_client.command.assert_called_with(
+ "ALTER TABLE db_management.migrations UPDATE partially_applied_version = 2 WHERE db_name = 'test_db'"
+ )
+
+ # Test end of migration
+ migrator._update_migration_status("test_db", 2, is_start=False)
+ migrator.ch_client.command.assert_called_with(
+ "ALTER TABLE db_management.migrations UPDATE curr_version = 2, partially_applied_version = NULL WHERE db_name = 'test_db'"
+ )
+
+
+def test_is_safe_identifier(mock_costs, migrator):
+ # Valid identifiers
+ assert migrator._is_safe_identifier("test_db")
+ assert migrator._is_safe_identifier("my_db123")
+ assert migrator._is_safe_identifier("db.table")
+
+ # Invalid identifiers
+ assert not migrator._is_safe_identifier("test-db")
+ assert not migrator._is_safe_identifier("db;")
+ assert not migrator._is_safe_identifier("db'name")
+ assert not migrator._is_safe_identifier("db/*")
+
+
+def test_create_db_sql_validation(mock_costs, migrator):
+ # Test invalid database name
+ with pytest.raises(MigrationError, match="Invalid database name"):
+ migrator._create_db_sql("test;db")
+
+ # Test replicated mode with invalid values
+ migrator.replicated = True
+ migrator.replicated_cluster = "test;cluster"
+ with pytest.raises(MigrationError, match="Invalid cluster name"):
+ migrator._create_db_sql("test_db")
+
+ migrator.replicated_cluster = "test_cluster"
+ migrator.replicated_path = "/clickhouse/bad;path/{db}"
+ with pytest.raises(MigrationError, match="Invalid replicated path"):
+ migrator._create_db_sql("test_db")
+
+
+def test_create_db_sql_non_replicated(mock_costs, migrator):
+ # Test non-replicated mode
+ migrator.replicated = False
+ sql = migrator._create_db_sql("test_db")
+ assert sql.strip() == "CREATE DATABASE IF NOT EXISTS test_db"
+
+
+def test_create_db_sql_replicated(mock_costs, migrator):
+ # Test replicated mode
+ migrator.replicated = True
+ migrator.replicated_path = "/clickhouse/tables/{db}"
+ migrator.replicated_cluster = "test_cluster"
+
+ sql = migrator._create_db_sql("test_db")
+ expected = """
+ CREATE DATABASE IF NOT EXISTS test_db ON CLUSTER test_cluster ENGINE=Replicated('/clickhouse/tables/test_db', '{shard}', '{replica}')
+ """.strip()
+ assert sql.strip() == expected
+
+
+def test_format_replicated_sql_non_replicated(mock_costs, migrator):
+ # Test that SQL is unchanged when replicated=False
+ migrator.replicated = False
+ test_cases = [
+ "CREATE TABLE test (id Int32) ENGINE = MergeTree",
+ "CREATE TABLE test (id Int32) ENGINE = SummingMergeTree",
+ "CREATE TABLE test (id Int32) ENGINE=ReplacingMergeTree",
+ ]
+
+ for sql in test_cases:
+ assert migrator._format_replicated_sql(sql) == sql
+
+
+def test_format_replicated_sql_replicated(mock_costs, migrator):
+ # Test that MergeTree engines are converted to Replicated variants
+ migrator.replicated = True
+
+ test_cases = [
+ (
+ "CREATE TABLE test (id Int32) ENGINE = MergeTree",
+ "CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree",
+ ),
+ (
+ "CREATE TABLE test (id Int32) ENGINE = SummingMergeTree",
+ "CREATE TABLE test (id Int32) ENGINE = ReplicatedSummingMergeTree",
+ ),
+ (
+ "CREATE TABLE test (id Int32) ENGINE=ReplacingMergeTree",
+ "CREATE TABLE test (id Int32) ENGINE = ReplicatedReplacingMergeTree",
+ ),
+ # Test with extra whitespace
+ (
+ "CREATE TABLE test (id Int32) ENGINE = MergeTree",
+ "CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree",
+ ),
+ # Test with parameters
+ (
+ "CREATE TABLE test (id Int32) ENGINE = MergeTree()",
+ "CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree()",
+ ),
+ ]
+
+ for input_sql, expected_sql in test_cases:
+ assert migrator._format_replicated_sql(input_sql) == expected_sql
+
+
+def test_format_replicated_sql_non_mergetree(mock_costs, migrator):
+ # Test that non-MergeTree engines are left unchanged
+ migrator.replicated = True
+
+ test_cases = [
+ "CREATE TABLE test (id Int32) ENGINE = Memory",
+ "CREATE TABLE test (id Int32) ENGINE = Log",
+ "CREATE TABLE test (id Int32) ENGINE = TinyLog",
+ # This should not be changed as it's not a complete word match
+ "CREATE TABLE test (id Int32) ENGINE = MyMergeTreeCustom",
+ ]
+
+ for sql in test_cases:
+ assert migrator._format_replicated_sql(sql) == sql
diff --git a/weave-js/src/common/util/SdkPointCloudToBabylon.test.ts b/weave-js/src/common/util/SdkPointCloudToBabylon.test.ts
index 95c7639dc4a..df5eca57b46 100644
--- a/weave-js/src/common/util/SdkPointCloudToBabylon.test.ts
+++ b/weave-js/src/common/util/SdkPointCloudToBabylon.test.ts
@@ -4,6 +4,7 @@ import {
DEFAULT_POINT_COLOR,
getFilteringOptionsForPointCloud,
getVertexCompatiblePositionsAndColors,
+ loadPointCloud,
MAX_BOUNDING_BOX_LABELS_FOR_DISPLAY,
MaxAlphaValue,
} from './SdkPointCloudToBabylon';
@@ -174,3 +175,16 @@ describe('getFilteringOptionsForPointCloud', () => {
expect(newClassIdToLabel.get(49)).toEqual('label49');
});
});
+describe('loadPointCloud', () => {
+ it('appropriate defaults set when loading point cloud from file', () => {
+ const fileContents = JSON.stringify({
+ boxes: [],
+ points: [[]],
+ type: 'lidar/beta',
+ vectors: [],
+ });
+ const babylonPointCloud = loadPointCloud(fileContents);
+ expect(babylonPointCloud.points).toHaveLength(1);
+ expect(babylonPointCloud.points[0].position).toEqual([0, 0, 0]);
+ });
+});
diff --git a/weave-js/src/common/util/SdkPointCloudToBabylon.ts b/weave-js/src/common/util/SdkPointCloudToBabylon.ts
index 274e1676be4..d52682743ee 100644
--- a/weave-js/src/common/util/SdkPointCloudToBabylon.ts
+++ b/weave-js/src/common/util/SdkPointCloudToBabylon.ts
@@ -160,7 +160,7 @@ export const handlePoints = (object3D: Object3DScene): ScenePoint[] => {
// Draw Points
return truncatedPoints.map(point => {
const [x, y, z, r, g, b] = point;
- const position: Position = [x, y, z];
+ const position: Position = [x ?? 0, y ?? 0, z ?? 0];
const category = r;
if (r !== undefined && g !== undefined && b !== undefined) {
diff --git a/weave-js/src/common/util/render_babylon.ts b/weave-js/src/common/util/render_babylon.ts
index 10aee3f6c51..ebd213c2677 100644
--- a/weave-js/src/common/util/render_babylon.ts
+++ b/weave-js/src/common/util/render_babylon.ts
@@ -394,6 +394,15 @@ const pointCloudScene = (
// Apply vertexData to custom mesh
vertexData.applyToMesh(pcMesh);
+ // A file without any points defined still includes a single, empty "point".
+ // In order to play nice with Babylon, we position this empty point at 0,0,0.
+ // Hence, a pointCloud with a single point at 0,0,0 is likely empty.
+ const isEmpty =
+ pointCloud.points.length === 1 &&
+ pointCloud.points[0].position[0] === 0 &&
+ pointCloud.points[0].position[1] === 0 &&
+ pointCloud.points[0].position[2] === 0;
+
camera.parent = pcMesh;
const pcMaterial = new Babylon.StandardMaterial('mat', scene);
@@ -472,8 +481,8 @@ const pointCloudScene = (
new Vector3(edges.length * 2, edges.length * 2, edges.length * 2)
);
- // If we are iterating over camera, target a box
- if (index === meta?.cameraIndex) {
+ // If we are iterating over camera or the cloud is empty, target a box
+ if (index === meta?.cameraIndex || (index === 0 && isEmpty)) {
camera.position = center.add(new Vector3(0, 0, 1000));
camera.target = center;
camera.zoomOn([lines]);
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx
index ef6bcbd69ff..0b3c9603fef 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx
@@ -96,7 +96,7 @@ export const FeedbackSidebar = ({
Feedback
-
+
{humanAnnotationSpecs.length > 0 ? (
<>
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx
index 484b038c193..a3315266f65 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx
@@ -1,5 +1,4 @@
import Box from '@mui/material/Box';
-import {useViewerInfo} from '@wandb/weave/common/hooks/useViewerInfo';
import {Loading} from '@wandb/weave/components/Loading';
import {urlPrefixed} from '@wandb/weave/config';
import {useViewTraceEvent} from '@wandb/weave/integrations/analytics/useViewEvents';
@@ -71,8 +70,10 @@ export const CallPage: FC<{
};
export const useShowRunnableUI = () => {
- const viewerInfo = useViewerInfo();
- return viewerInfo.loading ? false : viewerInfo.userInfo?.admin;
+ return false;
+ // Uncomment to re-enable
+ // const viewerInfo = useViewerInfo();
+ // return viewerInfo.loading ? false : viewerInfo.userInfo?.admin;
};
const useCallTabs = (call: CallSchema) => {
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx
index c8143f9549e..d1a2c59d5d0 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx
@@ -6,15 +6,21 @@ import {Choice} from './types';
type ChoiceViewProps = {
choice: Choice;
isStructuredOutput?: boolean;
+ isNested?: boolean;
};
-export const ChoiceView = ({choice, isStructuredOutput}: ChoiceViewProps) => {
+export const ChoiceView = ({
+ choice,
+ isStructuredOutput,
+ isNested,
+}: ChoiceViewProps) => {
const {message} = choice;
return (
);
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesDrawer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesDrawer.tsx
new file mode 100644
index 00000000000..16d6897c27b
--- /dev/null
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesDrawer.tsx
@@ -0,0 +1,104 @@
+import {Box, Drawer} from '@mui/material';
+import {MOON_200} from '@wandb/weave/common/css/color.styles';
+import {Tag} from '@wandb/weave/components/Tag';
+import {Tailwind} from '@wandb/weave/components/Tailwind';
+import React from 'react';
+
+import {Button} from '../../../../../Button';
+import {ChoiceView} from './ChoiceView';
+import {Choice} from './types';
+
+type ChoicesDrawerProps = {
+ choices: Choice[];
+ isStructuredOutput?: boolean;
+ isDrawerOpen: boolean;
+ setIsDrawerOpen: React.Dispatch
>;
+ selectedChoiceIndex: number;
+ setSelectedChoiceIndex: (choiceIndex: number) => void;
+};
+
+export const ChoicesDrawer = ({
+ choices,
+ isStructuredOutput,
+ isDrawerOpen,
+ setIsDrawerOpen,
+ selectedChoiceIndex,
+ setSelectedChoiceIndex,
+}: ChoicesDrawerProps) => {
+ return (
+ setIsDrawerOpen(false)}
+ title="Choices"
+ anchor="right"
+ sx={{
+ '& .MuiDrawer-paper': {mt: '60px', width: '400px'},
+ }}>
+
+
+ Responses
+
+
+
+
+ {choices.map((c, index) => (
+
+
+
+ {index === selectedChoiceIndex ? (
+
+ ) : (
+
+ )}
+
+
+
+ ))}
+
+
+
+ );
+};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx
index c22df7c63d7..5ddc7f12202 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx
@@ -1,9 +1,9 @@
import React, {useState} from 'react';
+import {ChoicesDrawer} from './ChoicesDrawer';
import {ChoicesViewCarousel} from './ChoicesViewCarousel';
-import {ChoicesViewLinear} from './ChoicesViewLinear';
import {ChoiceView} from './ChoiceView';
-import {Choice, ChoicesMode} from './types';
+import {Choice} from './types';
type ChoicesViewProps = {
choices: Choice[];
@@ -14,7 +14,12 @@ export const ChoicesView = ({
choices,
isStructuredOutput,
}: ChoicesViewProps) => {
- const [mode, setMode] = useState('linear');
+ const [isDrawerOpen, setIsDrawerOpen] = useState(false);
+ const [localSelectedChoiceIndex, setLocalSelectedChoiceIndex] = useState(0);
+
+ const handleSetSelectedChoiceIndex = (choiceIndex: number) => {
+ setLocalSelectedChoiceIndex(choiceIndex);
+ };
if (choices.length === 0) {
return null;
@@ -26,20 +31,20 @@ export const ChoicesView = ({
}
return (
<>
- {mode === 'linear' && (
-
- )}
- {mode === 'carousel' && (
-
- )}
+
+
>
);
};
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx
index f4a52fc6801..a34932dea17 100644
--- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx
+++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx
@@ -1,34 +1,37 @@
-import React, {useState} from 'react';
+import React from 'react';
import {Button} from '../../../../../Button';
import {ChoiceView} from './ChoiceView';
-import {Choice, ChoicesMode} from './types';
+import {Choice} from './types';
type ChoicesViewCarouselProps = {
choices: Choice[];
isStructuredOutput?: boolean;
- setMode: React.Dispatch>;
+ setIsDrawerOpen: React.Dispatch>;
+ selectedChoiceIndex: number;
+ setSelectedChoiceIndex: (choiceIndex: number) => void;
};
export const ChoicesViewCarousel = ({
choices,
isStructuredOutput,
- setMode,
+ setIsDrawerOpen,
+ selectedChoiceIndex,
+ setSelectedChoiceIndex,
}: ChoicesViewCarouselProps) => {
- const [step, setStep] = useState(0);
-
const onNext = () => {
- setStep((step + 1) % choices.length);
+ setSelectedChoiceIndex((selectedChoiceIndex + 1) % choices.length);
};
const onBack = () => {
- const newStep = step === 0 ? choices.length - 1 : step - 1;
- setStep(newStep);
+ const newStep =
+ selectedChoiceIndex === 0 ? choices.length - 1 : selectedChoiceIndex - 1;
+ setSelectedChoiceIndex(newStep);
};
return (
<>
@@ -37,7 +40,7 @@ export const ChoicesViewCarousel = ({
size="small"
variant="quiet"
icon="expand-uncollapse"
- onClick={() => setMode('linear')}
+ onClick={() => setIsDrawerOpen(true)}
tooltip="Switch to linear view"
/>
@@ -48,7 +51,7 @@ export const ChoicesViewCarousel = ({
size="small"
onClick={onBack}
/>
- {step + 1} of {choices.length}
+ {selectedChoiceIndex + 1} of {choices.length}