diff --git a/.circleci/continue_config.yml b/.circleci/continue_config.yml
index e3b053900..05c7aa653 100644
--- a/.circleci/continue_config.yml
+++ b/.circleci/continue_config.yml
@@ -331,6 +331,7 @@ workflows:
             - spark
             - clickhouse
             - clickhouse-cluster
+            - risingwave
       - engine_tests_cloud:
           name: cloud_engine_<< matrix.engine >>
           context:
diff --git a/Makefile b/Makefile
index 4c23f3266..cde4f340b 100644
--- a/Makefile
+++ b/Makefile
@@ -193,6 +193,9 @@ spark-test: engine-spark-up
 trino-test: engine-trino-up
 	pytest -n auto -x -m "trino or trino_iceberg or trino_delta" --retries 3 --junitxml=test-results/junit-trino.xml
 
+risingwave-test: engine-risingwave-up
+	pytest -n auto -x -m "risingwave" --retries 3 --junitxml=test-results/junit-risingwave.xml
+
 #################
 # Cloud Engines #
 #################
diff --git a/setup.py b/setup.py
index fed5cd158..704e90a6d 100644
--- a/setup.py
+++ b/setup.py
@@ -158,6 +158,9 @@
             "sse-starlette>=0.2.2",
             "pyarrow",
         ],
+        "risingwave": [
+            "psycopg2",
+        ],
     },
     classifiers=[
         "Intended Audience :: Developers",
diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py
index b75073f6c..90579fa9e 100644
--- a/sqlmesh/core/config/connection.py
+++ b/sqlmesh/core/config/connection.py
@@ -1672,6 +1672,48 @@ def get_catalog(self) -> t.Optional[str]:
         return self.catalog_name
 
 
+class RisingwaveConnectionConfig(ConnectionConfig):
+    host: str
+    user: str
+    password: str
+    port: int
+    database: str
+    keepalives_idle: t.Optional[int] = None
+    connect_timeout: int = 10
+    role: t.Optional[str] = None
+    sslmode: t.Optional[str] = None
+
+    concurrent_tasks: int = 4
+    register_comments: bool = True
+    pre_ping: bool = True
+
+    type_: Literal["risingwave"] = Field(alias="type", default="risingwave")
+
+    @property
+    def _connection_kwargs_keys(self) -> t.Set[str]:
+        return {
+            "host",
+            "user",
+            "password",
+            "port",
+            "database",
+            "keepalives_idle",
+            "connect_timeout",
+            "role",
+            "sslmode",
+        }
+
+    @property
+    def _engine_adapter(self) -> t.Type[EngineAdapter]:
+        return engine_adapter.RisingwaveEngineAdapter
+
+    @property
+    def _connection_factory(self) -> t.Callable:
+        from psycopg2 import connect
+
+        return connect
+
+
 CONNECTION_CONFIG_TO_TYPE = {
     # Map all subclasses of ConnectionConfig to the value of their `type_` field.
     tpe.all_field_infos()["type_"].default: tpe
diff --git a/sqlmesh/core/engine_adapter/__init__.py b/sqlmesh/core/engine_adapter/__init__.py
index 25c45d2e1..19332dc00 100644
--- a/sqlmesh/core/engine_adapter/__init__.py
+++ b/sqlmesh/core/engine_adapter/__init__.py
@@ -18,6 +18,7 @@
 from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter
 from sqlmesh.core.engine_adapter.trino import TrinoEngineAdapter
 from sqlmesh.core.engine_adapter.athena import AthenaEngineAdapter
+from sqlmesh.core.engine_adapter.risingwave import RisingwaveEngineAdapter
 
 DIALECT_TO_ENGINE_ADAPTER = {
     "hive": SparkEngineAdapter,
@@ -33,6 +34,7 @@
     "mssql": MSSQLEngineAdapter,
     "trino": TrinoEngineAdapter,
     "athena": AthenaEngineAdapter,
+    "risingwave": RisingwaveEngineAdapter,
 }
 
 DIALECT_ALIASES = {
diff --git a/sqlmesh/core/engine_adapter/risingwave.py b/sqlmesh/core/engine_adapter/risingwave.py
new file mode 100644
index 000000000..618d7975a
--- /dev/null
+++ b/sqlmesh/core/engine_adapter/risingwave.py
@@ -0,0 +1,124 @@
+from __future__ import annotations
+
+import logging
+import typing as t
+
+
+from sqlglot import exp, Dialect
+
+from sqlmesh.core.engine_adapter.base_postgres import BasePostgresEngineAdapter
+from sqlmesh.core.engine_adapter.mixins import (
+    GetCurrentCatalogFromFunctionMixin,
+    PandasNativeFetchDFSupportMixin,
+)
+from sqlmesh.core.engine_adapter.shared import (
+    set_catalog,
+    CatalogSupport,
+    CommentCreationView,
+    CommentCreationTable,
+)
+from sqlmesh.core.schema_diff import SchemaDiffer
+
+
+if t.TYPE_CHECKING:
+    from sqlmesh.core._typing import SessionProperties
+    from sqlmesh.core.engine_adapter._typing import DF
+
+logger = logging.getLogger(__name__)
+
+
+@set_catalog()
+class RisingwaveEngineAdapter(
+    BasePostgresEngineAdapter,
+    PandasNativeFetchDFSupportMixin,
+    GetCurrentCatalogFromFunctionMixin,
+):
+    DIALECT = "risingwave"
+    SUPPORTS_INDEXES = True
+    HAS_VIEW_BINDING = True
+    CURRENT_CATALOG_EXPRESSION = exp.column("current_catalog")
+    SUPPORTS_REPLACE_TABLE = False
+    DEFAULT_BATCH_SIZE = 400
+    CATALOG_SUPPORT = CatalogSupport.SINGLE_CATALOG_ONLY
+    COMMENT_CREATION_TABLE = CommentCreationTable.COMMENT_COMMAND_ONLY
+    COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY
+    SUPPORTS_MATERIALIZED_VIEWS = True
+
+    SCHEMA_DIFFER = SchemaDiffer(
+        parameterized_type_defaults={
+            # DECIMAL without precision is "up to 131072 digits before the decimal point; up to 16383 digits after the decimal point"
+            exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(131072 + 16383, 16383), (0,)],
+            exp.DataType.build("CHAR", dialect=DIALECT).this: [(1,)],
+            exp.DataType.build("TIME", dialect=DIALECT).this: [(6,)],
+            exp.DataType.build("TIMESTAMP", dialect=DIALECT).this: [(6,)],
+        },
+        types_with_unlimited_length={
+            # all can ALTER to `TEXT`
+            exp.DataType.build("TEXT", dialect=DIALECT).this: {
+                exp.DataType.build("VARCHAR", dialect=DIALECT).this,
+                exp.DataType.build("CHAR", dialect=DIALECT).this,
+                exp.DataType.build("BPCHAR", dialect=DIALECT).this,
+            },
+            # all can ALTER to unparameterized `VARCHAR`
+            exp.DataType.build("VARCHAR", dialect=DIALECT).this: {
+                exp.DataType.build("VARCHAR", dialect=DIALECT).this,
+                exp.DataType.build("CHAR", dialect=DIALECT).this,
+                exp.DataType.build("BPCHAR", dialect=DIALECT).this,
+                exp.DataType.build("TEXT", dialect=DIALECT).this,
+            },
+            # parameterized `BPCHAR(n)` can ALTER to unparameterized `BPCHAR`
+            exp.DataType.build("BPCHAR", dialect=DIALECT).this: {
+                exp.DataType.build("BPCHAR", dialect=DIALECT).this
+            },
+        },
+    )
+
+    def _set_flush(self) -> None:
+        sql = "SET RW_IMPLICIT_FLUSH TO true;"
+        self._execute(sql)
+
+    def __init__(
+        self,
+        connection_factory: t.Callable[[], t.Any],
+        dialect: str = "",
+        sql_gen_kwargs: t.Optional[t.Dict[str, Dialect | bool | str]] = None,
+        multithreaded: bool = False,
+        cursor_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
+        cursor_init: t.Optional[t.Callable[[t.Any], None]] = None,
+        default_catalog: t.Optional[str] = None,
+        execute_log_level: int = logging.DEBUG,
+        register_comments: bool = True,
+        pre_ping: bool = False,
+        **kwargs: t.Any,
+    ):
+        super().__init__(
+            connection_factory,
+            dialect,
+            sql_gen_kwargs,
+            multithreaded,
+            cursor_kwargs,
+            cursor_init,
+            default_catalog,
+            execute_log_level,
+            register_comments,
+            pre_ping,
+            **kwargs,
+        )
+        if hasattr(self, "cursor"):
+            self._set_flush()
+
+    def _begin_session(self, properties: SessionProperties) -> t.Any:
+        """Begin a new session."""
+        self._set_flush()
+
+    def _fetch_native_df(
+        self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
+    ) -> DF:
+        """
+        `read_sql_query`, when used with psycopg, results in a hanging transaction that must be committed:
+        https://github.com/pandas-dev/pandas/pull/42277
+        """
+        df = super()._fetch_native_df(query, quote_identifiers)
+        if not self._connection_pool.is_transaction_active:
+            self._connection_pool.commit()
+        return df
diff --git a/tests/core/engine_adapter/integration/config.yaml b/tests/core/engine_adapter/integration/config.yaml
index 77f63d41f..eda4c9340 100644
--- a/tests/core/engine_adapter/integration/config.yaml
+++ b/tests/core/engine_adapter/integration/config.yaml
@@ -89,6 +89,14 @@ gateways:
       cluster: cluster1
     state_connection:
       type: duckdb
+  inttest_risingwave:
+    connection:
+      type: risingwave
+      user: risingwave
+      password: risingwave
+      database: risingwave
+      host: {{ env_var('DOCKER_HOSTNAME', 'localhost') }}
+      port: 5432
 
 
 # Cloud databases
diff --git a/tests/core/engine_adapter/integration/docker/compose.risingwave.yaml b/tests/core/engine_adapter/integration/docker/compose.risingwave.yaml
new file mode 100644
index 000000000..90d1dce35
--- /dev/null
+++ b/tests/core/engine_adapter/integration/docker/compose.risingwave.yaml
@@ -0,0 +1,7 @@
+services:
+  risingwave:
+    image: risingwavelabs/risingwave:v2.0.1
+    ports:
+      - '5432:5432'
+    environment:
+      RISINGWAVE_PASSWORD: risingwave
\ No newline at end of file
diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py
index ca4a3e9d2..3b8eb75af 100644
--- a/tests/core/engine_adapter/integration/test_integration.py
+++ b/tests/core/engine_adapter/integration/test_integration.py
@@ -220,6 +220,14 @@ def test_type(request):
                 pytest.mark.athena,
             ],
         ),
+        pytest.param(
+            "risingwave",
+            marks=[
+                pytest.mark.docker,
+                pytest.mark.engine,
+                pytest.mark.risingwave,
+            ],
+        ),
     ]
 )
 def mark_gateway(request) -> t.Tuple[str, str]:
diff --git a/tests/core/engine_adapter/test_risingwave.py b/tests/core/engine_adapter/test_risingwave.py
new file mode 100644
index 000000000..9b2553917
--- /dev/null
+++ b/tests/core/engine_adapter/test_risingwave.py
@@ -0,0 +1,49 @@
+# type: ignore
+import typing as t
+from unittest.mock import call
+
+import pytest
+from sqlglot import parse_one
+
+from sqlmesh.core.engine_adapter.risingwave import RisingwaveEngineAdapter
+
+pytestmark = [pytest.mark.engine, pytest.mark.postgres, pytest.mark.risingwave]
+
+
+def test_create_view(make_mocked_engine_adapter: t.Callable):
+    adapter = make_mocked_engine_adapter(RisingwaveEngineAdapter)
+
+    adapter.create_view("db.view", parse_one("SELECT 1"), replace=True)
+    adapter.create_view("db.view", parse_one("SELECT 1"), replace=False)
+
+    adapter.cursor.execute.assert_has_calls(
+        [
+            # 1st call: replace=True drops the existing view before recreating it
+            call('DROP VIEW IF EXISTS "db"."view" CASCADE'),
+            call('CREATE VIEW "db"."view" AS SELECT 1'),
+            # 2nd call: replace=False creates the view directly
+            call('CREATE VIEW "db"."view" AS SELECT 1'),
+        ]
+    )
+
+
+def test_drop_view(make_mocked_engine_adapter: t.Callable):
+    adapter = make_mocked_engine_adapter(RisingwaveEngineAdapter)
+
+    adapter.SUPPORTS_MATERIALIZED_VIEWS = True
+
+    adapter.drop_view("db.view")
+
+    adapter.drop_view("db.view", materialized=True)
+
+    adapter.drop_view("db.view", cascade=False)
+
+    adapter.cursor.execute.assert_has_calls(
+        [
+            # 1st call: plain view, default cascade
+            call('DROP VIEW IF EXISTS "db"."view" CASCADE'),
+            # 2nd call: materialized view
+            call('DROP MATERIALIZED VIEW IF EXISTS "db"."view" CASCADE'),
+            # 3rd call: cascade=False omits the CASCADE clause
+            call('DROP VIEW IF EXISTS "db"."view"'),
+        ]
+    )
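
Not part of the diff: for reviewers who want to exercise the new adapter by hand against the container from compose.risingwave.yaml, a minimal sketch is below. The connection values mirror the inttest_risingwave gateway above; everything else is an illustration under those assumptions, not code from this PR.

import typing as t

import psycopg2

from sqlmesh.core.engine_adapter.risingwave import RisingwaveEngineAdapter


def connection_factory() -> t.Any:
    # Assumed local setup: the Docker container above, reachable on localhost:5432
    # with the credentials from the inttest_risingwave gateway.
    return psycopg2.connect(
        host="localhost",
        port=5432,
        user="risingwave",
        password="risingwave",
        dbname="risingwave",
    )


# __init__ runs SET RW_IMPLICIT_FLUSH TO true on the first cursor, so DML is
# immediately visible to subsequent reads in the same session.
adapter = RisingwaveEngineAdapter(connection_factory)
print(adapter.fetchone("SELECT 1"))

Routing through RisingwaveConnectionConfig in a project's config.yaml (type: risingwave) achieves the same wiring; the sketch only bypasses it to make the factory and the implicit-flush behavior visible.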