diff --git a/tests/trace_server/test_clickhouse_trace_server_migrator.py b/tests/trace_server/test_clickhouse_trace_server_migrator.py new file mode 100644 index 00000000000..3a6a92f2479 --- /dev/null +++ b/tests/trace_server/test_clickhouse_trace_server_migrator.py @@ -0,0 +1,230 @@ +import types +from unittest.mock import Mock, call, patch + +import pytest + +from weave.trace_server import clickhouse_trace_server_migrator as trace_server_migrator +from weave.trace_server.clickhouse_trace_server_migrator import MigrationError + + +@pytest.fixture +def mock_costs(): + with patch( + "weave.trace_server.costs.insert_costs.should_insert_costs", return_value=False + ) as mock_should_insert: + with patch( + "weave.trace_server.costs.insert_costs.get_current_costs", return_value=[] + ) as mock_get_costs: + yield + + +@pytest.fixture +def migrator(): + ch_client = Mock() + migrator = trace_server_migrator.ClickHouseTraceServerMigrator(ch_client) + migrator._get_migration_status = Mock() + migrator._get_migrations = Mock() + migrator._determine_migrations_to_apply = Mock() + migrator._update_migration_status = Mock() + ch_client.command.reset_mock() + return migrator + + +def test_apply_migrations_with_target_version(mock_costs, migrator, tmp_path): + # Setup + migrator._get_migration_status.return_value = { + "curr_version": 1, + "partially_applied_version": None, + } + migrator._get_migrations.return_value = { + "1": {"up": "1.up.sql", "down": "1.down.sql"}, + "2": {"up": "2.up.sql", "down": "2.down.sql"}, + } + migrator._determine_migrations_to_apply.return_value = [(2, "2.up.sql")] + + # Create a temporary migration file + migration_dir = tmp_path / "migrations" + migration_dir.mkdir() + migration_file = migration_dir / "2.up.sql" + migration_file.write_text( + "CREATE TABLE test1 (id Int32);\nCREATE TABLE test2 (id Int32);" + ) + + # Mock the migration directory path + with patch("os.path.dirname") as mock_dirname: + mock_dirname.return_value = str(tmp_path) + + # Execute + migrator.apply_migrations("test_db", target_version=2) + + # Verify + migrator._get_migration_status.assert_called_once_with("test_db") + migrator._get_migrations.assert_called_once() + migrator._determine_migrations_to_apply.assert_called_once_with( + 1, migrator._get_migrations.return_value, 2 + ) + + # Verify migration execution + assert migrator._update_migration_status.call_count == 2 + migrator._update_migration_status.assert_has_calls( + [call("test_db", 2, is_start=True), call("test_db", 2, is_start=False)] + ) + + # Verify the actual SQL commands were executed + ch_client = migrator.ch_client + assert ch_client.command.call_count == 2 + ch_client.command.assert_has_calls( + [call("CREATE TABLE test1 (id Int32)"), call("CREATE TABLE test2 (id Int32)")] + ) + + +def test_execute_migration_command(mock_costs, migrator): + # Setup + ch_client = migrator.ch_client + ch_client.database = "original_db" + + # Execute + migrator._execute_migration_command("test_db", "CREATE TABLE test (id Int32)") + + # Verify + assert ch_client.database == "original_db" # Should restore original database + ch_client.command.assert_called_once_with("CREATE TABLE test (id Int32)") + + +def test_migration_replicated(mock_costs, migrator): + ch_client = migrator.ch_client + orig = "CREATE TABLE test (id String, project_id String) ENGINE = MergeTree ORDER BY (project_id, id);" + migrator._execute_migration_command("test_db", orig) + ch_client.command.assert_called_once_with(orig) + + +def test_update_migration_status(mock_costs, migrator): + # Don't mock _update_migration_status for this test + migrator._update_migration_status = types.MethodType( + trace_server_migrator.ClickHouseTraceServerMigrator._update_migration_status, + migrator, + ) + + # Test start of migration + migrator._update_migration_status("test_db", 2, is_start=True) + migrator.ch_client.command.assert_called_with( + "ALTER TABLE db_management.migrations UPDATE partially_applied_version = 2 WHERE db_name = 'test_db'" + ) + + # Test end of migration + migrator._update_migration_status("test_db", 2, is_start=False) + migrator.ch_client.command.assert_called_with( + "ALTER TABLE db_management.migrations UPDATE curr_version = 2, partially_applied_version = NULL WHERE db_name = 'test_db'" + ) + + +def test_is_safe_identifier(mock_costs, migrator): + # Valid identifiers + assert migrator._is_safe_identifier("test_db") + assert migrator._is_safe_identifier("my_db123") + assert migrator._is_safe_identifier("db.table") + + # Invalid identifiers + assert not migrator._is_safe_identifier("test-db") + assert not migrator._is_safe_identifier("db;") + assert not migrator._is_safe_identifier("db'name") + assert not migrator._is_safe_identifier("db/*") + + +def test_create_db_sql_validation(mock_costs, migrator): + # Test invalid database name + with pytest.raises(MigrationError, match="Invalid database name"): + migrator._create_db_sql("test;db") + + # Test replicated mode with invalid values + migrator.replicated = True + migrator.replicated_cluster = "test;cluster" + with pytest.raises(MigrationError, match="Invalid cluster name"): + migrator._create_db_sql("test_db") + + migrator.replicated_cluster = "test_cluster" + migrator.replicated_path = "/clickhouse/bad;path/{db}" + with pytest.raises(MigrationError, match="Invalid replicated path"): + migrator._create_db_sql("test_db") + + +def test_create_db_sql_non_replicated(mock_costs, migrator): + # Test non-replicated mode + migrator.replicated = False + sql = migrator._create_db_sql("test_db") + assert sql.strip() == "CREATE DATABASE IF NOT EXISTS test_db" + + +def test_create_db_sql_replicated(mock_costs, migrator): + # Test replicated mode + migrator.replicated = True + migrator.replicated_path = "/clickhouse/tables/{db}" + migrator.replicated_cluster = "test_cluster" + + sql = migrator._create_db_sql("test_db") + expected = """ + CREATE DATABASE IF NOT EXISTS test_db ON CLUSTER test_cluster ENGINE=Replicated('/clickhouse/tables/test_db', '{shard}', '{replica}') + """.strip() + assert sql.strip() == expected + + +def test_format_replicated_sql_non_replicated(mock_costs, migrator): + # Test that SQL is unchanged when replicated=False + migrator.replicated = False + test_cases = [ + "CREATE TABLE test (id Int32) ENGINE = MergeTree", + "CREATE TABLE test (id Int32) ENGINE = SummingMergeTree", + "CREATE TABLE test (id Int32) ENGINE=ReplacingMergeTree", + ] + + for sql in test_cases: + assert migrator._format_replicated_sql(sql) == sql + + +def test_format_replicated_sql_replicated(mock_costs, migrator): + # Test that MergeTree engines are converted to Replicated variants + migrator.replicated = True + + test_cases = [ + ( + "CREATE TABLE test (id Int32) ENGINE = MergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree", + ), + ( + "CREATE TABLE test (id Int32) ENGINE = SummingMergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedSummingMergeTree", + ), + ( + "CREATE TABLE test (id Int32) ENGINE=ReplacingMergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedReplacingMergeTree", + ), + # Test with extra whitespace + ( + "CREATE TABLE test (id Int32) ENGINE = MergeTree", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree", + ), + # Test with parameters + ( + "CREATE TABLE test (id Int32) ENGINE = MergeTree()", + "CREATE TABLE test (id Int32) ENGINE = ReplicatedMergeTree()", + ), + ] + + for input_sql, expected_sql in test_cases: + assert migrator._format_replicated_sql(input_sql) == expected_sql + + +def test_format_replicated_sql_non_mergetree(mock_costs, migrator): + # Test that non-MergeTree engines are left unchanged + migrator.replicated = True + + test_cases = [ + "CREATE TABLE test (id Int32) ENGINE = Memory", + "CREATE TABLE test (id Int32) ENGINE = Log", + "CREATE TABLE test (id Int32) ENGINE = TinyLog", + # This should not be changed as it's not a complete word match + "CREATE TABLE test (id Int32) ENGINE = MyMergeTreeCustom", + ] + + for sql in test_cases: + assert migrator._format_replicated_sql(sql) == sql diff --git a/weave-js/src/common/util/SdkPointCloudToBabylon.test.ts b/weave-js/src/common/util/SdkPointCloudToBabylon.test.ts index 95c7639dc4a..df5eca57b46 100644 --- a/weave-js/src/common/util/SdkPointCloudToBabylon.test.ts +++ b/weave-js/src/common/util/SdkPointCloudToBabylon.test.ts @@ -4,6 +4,7 @@ import { DEFAULT_POINT_COLOR, getFilteringOptionsForPointCloud, getVertexCompatiblePositionsAndColors, + loadPointCloud, MAX_BOUNDING_BOX_LABELS_FOR_DISPLAY, MaxAlphaValue, } from './SdkPointCloudToBabylon'; @@ -174,3 +175,16 @@ describe('getFilteringOptionsForPointCloud', () => { expect(newClassIdToLabel.get(49)).toEqual('label49'); }); }); +describe('loadPointCloud', () => { + it('appropriate defaults set when loading point cloud from file', () => { + const fileContents = JSON.stringify({ + boxes: [], + points: [[]], + type: 'lidar/beta', + vectors: [], + }); + const babylonPointCloud = loadPointCloud(fileContents); + expect(babylonPointCloud.points).toHaveLength(1); + expect(babylonPointCloud.points[0].position).toEqual([0, 0, 0]); + }); +}); diff --git a/weave-js/src/common/util/SdkPointCloudToBabylon.ts b/weave-js/src/common/util/SdkPointCloudToBabylon.ts index 274e1676be4..d52682743ee 100644 --- a/weave-js/src/common/util/SdkPointCloudToBabylon.ts +++ b/weave-js/src/common/util/SdkPointCloudToBabylon.ts @@ -160,7 +160,7 @@ export const handlePoints = (object3D: Object3DScene): ScenePoint[] => { // Draw Points return truncatedPoints.map(point => { const [x, y, z, r, g, b] = point; - const position: Position = [x, y, z]; + const position: Position = [x ?? 0, y ?? 0, z ?? 0]; const category = r; if (r !== undefined && g !== undefined && b !== undefined) { diff --git a/weave-js/src/common/util/render_babylon.ts b/weave-js/src/common/util/render_babylon.ts index 10aee3f6c51..ebd213c2677 100644 --- a/weave-js/src/common/util/render_babylon.ts +++ b/weave-js/src/common/util/render_babylon.ts @@ -394,6 +394,15 @@ const pointCloudScene = ( // Apply vertexData to custom mesh vertexData.applyToMesh(pcMesh); + // A file without any points defined still includes a single, empty "point". + // In order to play nice with Babylon, we position this empty point at 0,0,0. + // Hence, a pointCloud with a single point at 0,0,0 is likely empty. + const isEmpty = + pointCloud.points.length === 1 && + pointCloud.points[0].position[0] === 0 && + pointCloud.points[0].position[1] === 0 && + pointCloud.points[0].position[2] === 0; + camera.parent = pcMesh; const pcMaterial = new Babylon.StandardMaterial('mat', scene); @@ -472,8 +481,8 @@ const pointCloudScene = ( new Vector3(edges.length * 2, edges.length * 2, edges.length * 2) ); - // If we are iterating over camera, target a box - if (index === meta?.cameraIndex) { + // If we are iterating over camera or the cloud is empty, target a box + if (index === meta?.cameraIndex || (index === 0 && isEmpty)) { camera.position = center.add(new Vector3(0, 0, 1000)); camera.target = center; camera.zoomOn([lines]); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx index ef6bcbd69ff..0b3c9603fef 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/feedback/StructuredFeedback/FeedbackSidebar.tsx @@ -96,7 +96,7 @@ export const FeedbackSidebar = ({
Feedback
-
+
{humanAnnotationSpecs.length > 0 ? ( <>
diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx index 484b038c193..a3315266f65 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/CallPage/CallPage.tsx @@ -1,5 +1,4 @@ import Box from '@mui/material/Box'; -import {useViewerInfo} from '@wandb/weave/common/hooks/useViewerInfo'; import {Loading} from '@wandb/weave/components/Loading'; import {urlPrefixed} from '@wandb/weave/config'; import {useViewTraceEvent} from '@wandb/weave/integrations/analytics/useViewEvents'; @@ -71,8 +70,10 @@ export const CallPage: FC<{ }; export const useShowRunnableUI = () => { - const viewerInfo = useViewerInfo(); - return viewerInfo.loading ? false : viewerInfo.userInfo?.admin; + return false; + // Uncomment to re-enable + // const viewerInfo = useViewerInfo(); + // return viewerInfo.loading ? false : viewerInfo.userInfo?.admin; }; const useCallTabs = (call: CallSchema) => { diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx index c8143f9549e..d1a2c59d5d0 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoiceView.tsx @@ -6,15 +6,21 @@ import {Choice} from './types'; type ChoiceViewProps = { choice: Choice; isStructuredOutput?: boolean; + isNested?: boolean; }; -export const ChoiceView = ({choice, isStructuredOutput}: ChoiceViewProps) => { +export const ChoiceView = ({ + choice, + isStructuredOutput, + isNested, +}: ChoiceViewProps) => { const {message} = choice; return ( ); diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesDrawer.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesDrawer.tsx new file mode 100644 index 00000000000..16d6897c27b --- /dev/null +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesDrawer.tsx @@ -0,0 +1,104 @@ +import {Box, Drawer} from '@mui/material'; +import {MOON_200} from '@wandb/weave/common/css/color.styles'; +import {Tag} from '@wandb/weave/components/Tag'; +import {Tailwind} from '@wandb/weave/components/Tailwind'; +import React from 'react'; + +import {Button} from '../../../../../Button'; +import {ChoiceView} from './ChoiceView'; +import {Choice} from './types'; + +type ChoicesDrawerProps = { + choices: Choice[]; + isStructuredOutput?: boolean; + isDrawerOpen: boolean; + setIsDrawerOpen: React.Dispatch>; + selectedChoiceIndex: number; + setSelectedChoiceIndex: (choiceIndex: number) => void; +}; + +export const ChoicesDrawer = ({ + choices, + isStructuredOutput, + isDrawerOpen, + setIsDrawerOpen, + selectedChoiceIndex, + setSelectedChoiceIndex, +}: ChoicesDrawerProps) => { + return ( + setIsDrawerOpen(false)} + title="Choices" + anchor="right" + sx={{ + '& .MuiDrawer-paper': {mt: '60px', width: '400px'}, + }}> + + + Responses + + + ) : ( + + )} +
+ +
+ ))} +
+ + + ); +}; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx index c22df7c63d7..5ddc7f12202 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesView.tsx @@ -1,9 +1,9 @@ import React, {useState} from 'react'; +import {ChoicesDrawer} from './ChoicesDrawer'; import {ChoicesViewCarousel} from './ChoicesViewCarousel'; -import {ChoicesViewLinear} from './ChoicesViewLinear'; import {ChoiceView} from './ChoiceView'; -import {Choice, ChoicesMode} from './types'; +import {Choice} from './types'; type ChoicesViewProps = { choices: Choice[]; @@ -14,7 +14,12 @@ export const ChoicesView = ({ choices, isStructuredOutput, }: ChoicesViewProps) => { - const [mode, setMode] = useState('linear'); + const [isDrawerOpen, setIsDrawerOpen] = useState(false); + const [localSelectedChoiceIndex, setLocalSelectedChoiceIndex] = useState(0); + + const handleSetSelectedChoiceIndex = (choiceIndex: number) => { + setLocalSelectedChoiceIndex(choiceIndex); + }; if (choices.length === 0) { return null; @@ -26,20 +31,20 @@ export const ChoicesView = ({ } return ( <> - {mode === 'linear' && ( - - )} - {mode === 'carousel' && ( - - )} + + ); }; diff --git a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx index f4a52fc6801..a34932dea17 100644 --- a/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx +++ b/weave-js/src/components/PagePanelComponents/Home/Browse3/pages/ChatView/ChoicesViewCarousel.tsx @@ -1,34 +1,37 @@ -import React, {useState} from 'react'; +import React from 'react'; import {Button} from '../../../../../Button'; import {ChoiceView} from './ChoiceView'; -import {Choice, ChoicesMode} from './types'; +import {Choice} from './types'; type ChoicesViewCarouselProps = { choices: Choice[]; isStructuredOutput?: boolean; - setMode: React.Dispatch>; + setIsDrawerOpen: React.Dispatch>; + selectedChoiceIndex: number; + setSelectedChoiceIndex: (choiceIndex: number) => void; }; export const ChoicesViewCarousel = ({ choices, isStructuredOutput, - setMode, + setIsDrawerOpen, + selectedChoiceIndex, + setSelectedChoiceIndex, }: ChoicesViewCarouselProps) => { - const [step, setStep] = useState(0); - const onNext = () => { - setStep((step + 1) % choices.length); + setSelectedChoiceIndex((selectedChoiceIndex + 1) % choices.length); }; const onBack = () => { - const newStep = step === 0 ? choices.length - 1 : step - 1; - setStep(newStep); + const newStep = + selectedChoiceIndex === 0 ? choices.length - 1 : selectedChoiceIndex - 1; + setSelectedChoiceIndex(newStep); }; return ( <>
@@ -37,7 +40,7 @@ export const ChoicesViewCarousel = ({ size="small" variant="quiet" icon="expand-uncollapse" - onClick={() => setMode('linear')} + onClick={() => setIsDrawerOpen(true)} tooltip="Switch to linear view" />
@@ -48,7 +51,7 @@ export const ChoicesViewCarousel = ({ size="small" onClick={onBack} /> - {step + 1} of {choices.length} + {selectedChoiceIndex + 1} of {choices.length} - - ))} + {options.map( + ({value: optionValue, icon, isDisabled: optionIsDisabled}) => ( + + + + ) + )} ); diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index 86ba7b8653e..0eca3fcbedb 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -239,7 +239,9 @@ def func_name(self) -> str: @property def feedback(self) -> RefFeedbackQuery: if not self.id: - raise ValueError("Can't get feedback for call without ID") + raise ValueError( + "Can't get feedback for call without ID, was `weave.init` called?" + ) if self._feedback is None: try: @@ -253,7 +255,9 @@ def feedback(self) -> RefFeedbackQuery: @property def ui_url(self) -> str: if not self.id: - raise ValueError("Can't get URL for call without ID") + raise ValueError( + "Can't get URL for call without ID, was `weave.init` called?" + ) try: entity, project = self.project_id.split("/") @@ -265,7 +269,9 @@ def ui_url(self) -> str: def ref(self) -> CallRef: entity, project = self.project_id.split("/") if not self.id: - raise ValueError("Can't get ref for call without ID") + raise ValueError( + "Can't get ref for call without ID, was `weave.init` called?" + ) return CallRef(entity, project, self.id) @@ -273,7 +279,9 @@ def ref(self) -> CallRef: def children(self) -> CallsIter: client = weave_client_context.require_weave_client() if not self.id: - raise ValueError("Can't get children of call without ID") + raise ValueError( + "Can't get children of call without ID, was `weave.init` called?" + ) client = weave_client_context.require_weave_client() return CallsIter( diff --git a/weave/trace_server/clickhouse_trace_server_migrator.py b/weave/trace_server/clickhouse_trace_server_migrator.py index 30dffe89365..4336630bf50 100644 --- a/weave/trace_server/clickhouse_trace_server_migrator.py +++ b/weave/trace_server/clickhouse_trace_server_migrator.py @@ -1,6 +1,7 @@ # Clickhouse Trace Server Manager import logging import os +import re from typing import Optional from clickhouse_connect.driver.client import Client as CHClient @@ -9,6 +10,11 @@ logger = logging.getLogger(__name__) +# These settings are only used when `replicated` mode is enabled for +# self managed clickhouse instances. +DEFAULT_REPLICATED_PATH = "/clickhouse/tables/{db}" +DEFAULT_REPLICATED_CLUSTER = "weave_cluster" + class MigrationError(RuntimeError): """Raised when a migration error occurs.""" @@ -16,15 +22,77 @@ class MigrationError(RuntimeError): class ClickHouseTraceServerMigrator: ch_client: CHClient + replicated: bool + replicated_path: str + replicated_cluster: str def __init__( self, ch_client: CHClient, + replicated: Optional[bool] = None, + replicated_path: Optional[str] = None, + replicated_cluster: Optional[str] = None, ): super().__init__() self.ch_client = ch_client + self.replicated = False if replicated is None else replicated + self.replicated_path = ( + DEFAULT_REPLICATED_PATH if replicated_path is None else replicated_path + ) + self.replicated_cluster = ( + DEFAULT_REPLICATED_CLUSTER + if replicated_cluster is None + else replicated_cluster + ) self._initialize_migration_db() + def _is_safe_identifier(self, value: str) -> bool: + """Check if a string is safe to use as an identifier in SQL.""" + return bool(re.match(r"^[a-zA-Z0-9_\.]+$", value)) + + def _format_replicated_sql(self, sql_query: str) -> str: + """Format SQL query to use replicated engines if replicated mode is enabled.""" + if not self.replicated: + return sql_query + + # Match "ENGINE = MergeTree" followed by word boundary + pattern = r"ENGINE\s*=\s*(\w+)?MergeTree\b" + + def replace_engine(match: re.Match[str]) -> str: + engine_prefix = match.group(1) or "" + return f"ENGINE = Replicated{engine_prefix}MergeTree" + + return re.sub(pattern, replace_engine, sql_query, flags=re.IGNORECASE) + + def _create_db_sql(self, db_name: str) -> str: + """Geneate SQL database create string for normal and replicated databases.""" + if not self._is_safe_identifier(db_name): + raise MigrationError(f"Invalid database name: {db_name}") + + replicated_engine = "" + replicated_cluster = "" + if self.replicated: + if not self._is_safe_identifier(self.replicated_cluster): + raise MigrationError(f"Invalid cluster name: {self.replicated_cluster}") + + replicated_path = self.replicated_path.replace("{db}", db_name) + if not all( + self._is_safe_identifier(part) + for part in replicated_path.split("/") + if part + ): + raise MigrationError(f"Invalid replicated path: {replicated_path}") + + replicated_cluster = f" ON CLUSTER {self.replicated_cluster}" + replicated_engine = ( + f" ENGINE=Replicated('{replicated_path}', '{{shard}}', '{{replica}}')" + ) + + create_db_sql = f""" + CREATE DATABASE IF NOT EXISTS {db_name}{replicated_cluster}{replicated_engine} + """ + return create_db_sql + def apply_migrations( self, target_db: str, target_version: Optional[int] = None ) -> None: @@ -46,20 +114,15 @@ def apply_migrations( return logger.info(f"Migrations to apply: {migrations_to_apply}") if status["curr_version"] == 0: - self.ch_client.command(f"CREATE DATABASE IF NOT EXISTS {target_db}") + self.ch_client.command(self._create_db_sql(target_db)) for target_version, migration_file in migrations_to_apply: self._apply_migration(target_db, target_version, migration_file) if should_insert_costs(status["curr_version"], target_version): insert_costs(self.ch_client, target_db) def _initialize_migration_db(self) -> None: - self.ch_client.command( - """ - CREATE DATABASE IF NOT EXISTS db_management - """ - ) - self.ch_client.command( - """ + self.ch_client.command(self._create_db_sql("db_management")) + create_table_sql = """ CREATE TABLE IF NOT EXISTS db_management.migrations ( db_name String, @@ -69,7 +132,7 @@ def _initialize_migration_db(self) -> None: ENGINE = MergeTree() ORDER BY (db_name) """ - ) + self.ch_client.command(self._format_replicated_sql(create_table_sql)) def _get_migration_status(self, db_name: str) -> dict: column_names = ["db_name", "curr_version", "partially_applied_version"] @@ -184,31 +247,48 @@ def _determine_migrations_to_apply( return [] + def _execute_migration_command(self, target_db: str, command: str) -> None: + """Execute a single migration command in the context of the target database.""" + command = command.strip() + if len(command) == 0: + return + curr_db = self.ch_client.database + self.ch_client.database = target_db + self.ch_client.command(self._format_replicated_sql(command)) + self.ch_client.database = curr_db + + def _update_migration_status( + self, target_db: str, target_version: int, is_start: bool = True + ) -> None: + """Update the migration status in db_management.migrations table.""" + if is_start: + self.ch_client.command( + f"ALTER TABLE db_management.migrations UPDATE partially_applied_version = {target_version} WHERE db_name = '{target_db}'" + ) + else: + self.ch_client.command( + f"ALTER TABLE db_management.migrations UPDATE curr_version = {target_version}, partially_applied_version = NULL WHERE db_name = '{target_db}'" + ) + def _apply_migration( self, target_db: str, target_version: int, migration_file: str ) -> None: logger.info(f"Applying migration {migration_file} to `{target_db}`") migration_dir = os.path.join(os.path.dirname(__file__), "migrations") migration_file_path = os.path.join(migration_dir, migration_file) + with open(migration_file_path) as f: migration_sql = f.read() - self.ch_client.command( - f""" - ALTER TABLE db_management.migrations UPDATE partially_applied_version = {target_version} WHERE db_name = '{target_db}' - """ - ) + + # Mark migration as partially applied + self._update_migration_status(target_db, target_version, is_start=True) + + # Execute each command in the migration migration_sub_commands = migration_sql.split(";") for command in migration_sub_commands: - command = command.strip() - if len(command) == 0: - continue - curr_db = self.ch_client.database - self.ch_client.database = target_db - self.ch_client.command(command) - self.ch_client.database = curr_db - self.ch_client.command( - f""" - ALTER TABLE db_management.migrations UPDATE curr_version = {target_version}, partially_applied_version = NULL WHERE db_name = '{target_db}' - """ - ) + self._execute_migration_command(target_db, command) + + # Mark migration as fully applied + self._update_migration_status(target_db, target_version, is_start=False) + logger.info(f"Migration {migration_file} applied to `{target_db}`")