Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: add risingwave engine adapter support for sqlmesh. #3436

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .circleci/continue_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ workflows:
- spark
- clickhouse
- clickhouse-cluster
- risingwave
- engine_tests_cloud:
name: cloud_engine_<< matrix.engine >>
context:
Expand Down
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ spark-test: engine-spark-up
trino-test: engine-trino-up
pytest -n auto -x -m "trino or trino_iceberg or trino_delta" --retries 3 --junitxml=test-results/junit-trino.xml

risingwave-test: engine-risingwave-up
pytest -n auto -x -m "risingwave" --retries 3 --junitxml=test-results/junit-risingwave.xml

#################
# Cloud Engines #
#################
Expand Down
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@
"sse-starlette>=0.2.2",
"pyarrow",
],
"risingwave": [
"psycopg2",
],
},
classifiers=[
"Intended Audience :: Developers",
Expand Down
42 changes: 42 additions & 0 deletions sqlmesh/core/config/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1672,6 +1672,48 @@ def get_catalog(self) -> t.Optional[str]:
return self.catalog_name


class RisingwaveConnectionConfig(ConnectionConfig):
    """Connection configuration for RisingWave.

    RisingWave speaks the PostgreSQL wire protocol, so the connection
    properties mirror the Postgres configuration and ``psycopg2`` is used
    as the driver.

    NOTE(review): these fields duplicate PostgresConnectionConfig; consider
    extending that class and overriding only ``type_`` and ``_engine_adapter``
    instead of copy+pasting.
    """

    host: str
    user: str
    password: str
    port: int
    database: str
    # Seconds of connection inactivity before TCP keepalive probes are sent;
    # None defers to the driver default.
    keepalives_idle: t.Optional[int] = None
    connect_timeout: int = 10
    role: t.Optional[str] = None
    sslmode: t.Optional[str] = None

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: bool = True

    type_: Literal["risingwave"] = Field(alias="type", default="risingwave")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        # Keys forwarded verbatim as keyword arguments to psycopg2.connect().
        return {
            "host",
            "user",
            "password",
            "port",
            "database",
            "keepalives_idle",
            "connect_timeout",
            "role",
            "sslmode",
        }

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        return engine_adapter.RisingwaveEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        # Imported lazily so psycopg2 is only required when this engine is used.
        from psycopg2 import connect

        return connect


CONNECTION_CONFIG_TO_TYPE = {
# Map all subclasses of ConnectionConfig to the value of their `type_` field.
tpe.all_field_infos()["type_"].default: tpe
Expand Down
2 changes: 2 additions & 0 deletions sqlmesh/core/engine_adapter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter
from sqlmesh.core.engine_adapter.trino import TrinoEngineAdapter
from sqlmesh.core.engine_adapter.athena import AthenaEngineAdapter
from sqlmesh.core.engine_adapter.risingwave import RisingwaveEngineAdapter

DIALECT_TO_ENGINE_ADAPTER = {
"hive": SparkEngineAdapter,
Expand All @@ -33,6 +34,7 @@
"mssql": MSSQLEngineAdapter,
"trino": TrinoEngineAdapter,
"athena": AthenaEngineAdapter,
"risingwave": RisingwaveEngineAdapter,
}

DIALECT_ALIASES = {
Expand Down
124 changes: 124 additions & 0 deletions sqlmesh/core/engine_adapter/risingwave.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from __future__ import annotations

import logging
import typing as t


from sqlglot import exp, Dialect

from sqlmesh.core.engine_adapter.base_postgres import BasePostgresEngineAdapter
from sqlmesh.core.engine_adapter.mixins import (
GetCurrentCatalogFromFunctionMixin,
PandasNativeFetchDFSupportMixin,
)
from sqlmesh.core.engine_adapter.shared import (
set_catalog,
CatalogSupport,
CommentCreationView,
CommentCreationTable,
)
from sqlmesh.core.schema_diff import SchemaDiffer


if t.TYPE_CHECKING:
from sqlmesh.core._typing import SessionProperties
from sqlmesh.core.engine_adapter._typing import DF

logger = logging.getLogger(__name__)


@set_catalog()
class RisingwaveEngineAdapter(
    BasePostgresEngineAdapter,
    PandasNativeFetchDFSupportMixin,
    GetCurrentCatalogFromFunctionMixin,
):
    """Engine adapter for RisingWave, a Postgres-compatible streaming database.

    NOTE(review): much of this class (notably ``SCHEMA_DIFFER``) mirrors the
    Postgres adapter; consider extending ``PostgresEngineAdapter`` instead of
    duplicating it.
    """

    DIALECT = "risingwave"
    SUPPORTS_INDEXES = True
    HAS_VIEW_BINDING = True
    CURRENT_CATALOG_EXPRESSION = exp.column("current_catalog")
    SUPPORTS_REPLACE_TABLE = False
    DEFAULT_BATCH_SIZE = 400
    CATALOG_SUPPORT = CatalogSupport.SINGLE_CATALOG_ONLY
    COMMENT_CREATION_TABLE = CommentCreationTable.COMMENT_COMMAND_ONLY
    COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY
    SUPPORTS_MATERIALIZED_VIEWS = True

    SCHEMA_DIFFER = SchemaDiffer(
        parameterized_type_defaults={
            # DECIMAL without precision is "up to 131072 digits before the decimal point; up to 16383 digits after the decimal point"
            exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(131072 + 16383, 16383), (0,)],
            exp.DataType.build("CHAR", dialect=DIALECT).this: [(1,)],
            exp.DataType.build("TIME", dialect=DIALECT).this: [(6,)],
            exp.DataType.build("TIMESTAMP", dialect=DIALECT).this: [(6,)],
        },
        types_with_unlimited_length={
            # all can ALTER to `TEXT`
            exp.DataType.build("TEXT", dialect=DIALECT).this: {
                exp.DataType.build("VARCHAR", dialect=DIALECT).this,
                exp.DataType.build("CHAR", dialect=DIALECT).this,
                exp.DataType.build("BPCHAR", dialect=DIALECT).this,
            },
            # all can ALTER to unparameterized `VARCHAR`
            exp.DataType.build("VARCHAR", dialect=DIALECT).this: {
                exp.DataType.build("VARCHAR", dialect=DIALECT).this,
                exp.DataType.build("CHAR", dialect=DIALECT).this,
                exp.DataType.build("BPCHAR", dialect=DIALECT).this,
                exp.DataType.build("TEXT", dialect=DIALECT).this,
            },
            # parameterized `BPCHAR(n)` can ALTER to unparameterized `BPCHAR`
            exp.DataType.build("BPCHAR", dialect=DIALECT).this: {
                exp.DataType.build("BPCHAR", dialect=DIALECT).this
            },
        },
    )

    def _set_flush(self) -> None:
        """Enable implicit flush so DML is visible to immediately-following reads.

        RisingWave is eventually consistent by default; RW_IMPLICIT_FLUSH makes
        each statement flush before returning, which SQLMesh relies on.
        """
        sql = "SET RW_IMPLICIT_FLUSH TO true;"
        self._execute(sql)

    def __init__(
        self,
        connection_factory: t.Callable[[], t.Any],
        dialect: str = "",
        sql_gen_kwargs: t.Optional[t.Dict[str, Dialect | bool | str]] = None,
        multithreaded: bool = False,
        cursor_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
        cursor_init: t.Optional[t.Callable[[t.Any], None]] = None,
        default_catalog: t.Optional[str] = None,
        execute_log_level: int = logging.DEBUG,
        register_comments: bool = True,
        pre_ping: bool = False,
        **kwargs: t.Any,
    ):
        super().__init__(
            connection_factory,
            dialect,
            sql_gen_kwargs,
            multithreaded,
            cursor_kwargs,
            cursor_init,
            default_catalog,
            execute_log_level,
            register_comments,
            pre_ping,
            **kwargs,
        )
        # Turn on implicit flush as soon as a cursor is available; sessions
        # created later get it via _begin_session.
        if hasattr(self, "cursor"):
            self._set_flush()

    def _begin_session(self, properties: SessionProperties) -> t.Any:
        """Begin a new session."""
        self._set_flush()

    def _fetch_native_df(
        self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
    ) -> DF:
        """
        `read_sql_query` when using psycopg will result on a hanging transaction that must be committed
        https://github.com/pandas-dev/pandas/pull/42277
        """
        df = super()._fetch_native_df(query, quote_identifiers)
        if not self._connection_pool.is_transaction_active:
            self._connection_pool.commit()
        return df
8 changes: 8 additions & 0 deletions tests/core/engine_adapter/integration/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ gateways:
cluster: cluster1
state_connection:
type: duckdb
inttest_risingwave:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the integration tests to run on CircleCI without the port 5432 failed: server closed the connection unexpectedly error, you'll also need to add a branch for risingwave to .circleci/wait-for-db.sh to block until the DB server is available

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Erin, thanks for your help. I have updated the code; there are now some issues with the integration tests, which I am trying to fix.

connection:
type: risingwave
user: risingwave
password: risingwave
database: risingwave
host: {{ env_var('DOCKER_HOSTNAME', 'localhost') }}
port: 5432


# Cloud databases
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
services:
risingwave:
image: risingwavelabs/risingwave:v2.0.1
ports:
- '5432:5432'
environment:
RISINGWAVE_PASSWORD: risingwave
8 changes: 8 additions & 0 deletions tests/core/engine_adapter/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,14 @@ def test_type(request):
pytest.mark.athena,
],
),
pytest.param(
"risingwave",
marks=[
pytest.mark.docker,
pytest.mark.engine,
pytest.mark.risingwave,
],
),
]
)
def mark_gateway(request) -> t.Tuple[str, str]:
Expand Down
49 changes: 49 additions & 0 deletions tests/core/engine_adapter/test_risingwave.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# type: ignore
import typing as t
from unittest.mock import call

import pytest
from sqlglot import parse_one
from sqlmesh.core.engine_adapter.risingwave import RisingwaveEngineAdapter

pytestmark = [pytest.mark.engine, pytest.mark.postgres, pytest.mark.risingwave]


def test_create_view(make_mocked_engine_adapter: t.Callable):
    """Replacing a view drops it first; a plain create issues only CREATE VIEW."""
    adapter = make_mocked_engine_adapter(RisingwaveEngineAdapter)

    # First with replace=True (drop + create), then with replace=False (create only).
    for should_replace in (True, False):
        adapter.create_view("db.view", parse_one("SELECT 1"), replace=should_replace)

    expected_calls = [
        # replace=True
        call('DROP VIEW IF EXISTS "db"."view" CASCADE'),
        call('CREATE VIEW "db"."view" AS SELECT 1'),
        # replace=False
        call('CREATE VIEW "db"."view" AS SELECT 1'),
    ]
    adapter.cursor.execute.assert_has_calls(expected_calls)


def test_drop_view(make_mocked_engine_adapter: t.Callable):
    """Dropping views emits the right DROP variant for materialized/cascade options."""
    adapter = make_mocked_engine_adapter(RisingwaveEngineAdapter)

    adapter.SUPPORTS_MATERIALIZED_VIEWS = True

    # Default drop, materialized drop, then a non-cascading drop.
    adapter.drop_view("db.view")
    adapter.drop_view("db.view", materialized=True)
    adapter.drop_view("db.view", cascade=False)

    expected_calls = [
        call('DROP VIEW IF EXISTS "db"."view" CASCADE'),
        call('DROP MATERIALIZED VIEW IF EXISTS "db"."view" CASCADE'),
        call('DROP VIEW IF EXISTS "db"."view"'),
    ]
    adapter.cursor.execute.assert_has_calls(expected_calls)