-
Notifications
You must be signed in to change notification settings - Fork 168
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat: add risingwave engine adapter support for sqlmesh. #3436
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
import typing as t | ||
|
||
|
||
from sqlglot import exp, Dialect | ||
|
||
from sqlmesh.core.engine_adapter.base_postgres import BasePostgresEngineAdapter | ||
from sqlmesh.core.engine_adapter.mixins import ( | ||
GetCurrentCatalogFromFunctionMixin, | ||
PandasNativeFetchDFSupportMixin, | ||
) | ||
from sqlmesh.core.engine_adapter.shared import ( | ||
set_catalog, | ||
CatalogSupport, | ||
CommentCreationView, | ||
CommentCreationTable, | ||
) | ||
from sqlmesh.core.schema_diff import SchemaDiffer | ||
|
||
|
||
if t.TYPE_CHECKING: | ||
from sqlmesh.core._typing import SessionProperties | ||
from sqlmesh.core.engine_adapter._typing import DF | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@set_catalog() | ||
class RisingwaveEngineAdapter( | ||
BasePostgresEngineAdapter, | ||
PandasNativeFetchDFSupportMixin, | ||
GetCurrentCatalogFromFunctionMixin, | ||
): | ||
DIALECT = "risingwave" | ||
SUPPORTS_INDEXES = True | ||
HAS_VIEW_BINDING = True | ||
CURRENT_CATALOG_EXPRESSION = exp.column("current_catalog") | ||
SUPPORTS_REPLACE_TABLE = False | ||
DEFAULT_BATCH_SIZE = 400 | ||
CATALOG_SUPPORT = CatalogSupport.SINGLE_CATALOG_ONLY | ||
COMMENT_CREATION_TABLE = CommentCreationTable.COMMENT_COMMAND_ONLY | ||
COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY | ||
SUPPORTS_MATERIALIZED_VIEWS = True | ||
|
||
SCHEMA_DIFFER = SchemaDiffer( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This appears to be copy+pasted directly from the Postgres adapter. Rather than duplicating it, can you extend it instead? |
||
parameterized_type_defaults={ | ||
# DECIMAL without precision is "up to 131072 digits before the decimal point; up to 16383 digits after the decimal point" | ||
exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(131072 + 16383, 16383), (0,)], | ||
exp.DataType.build("CHAR", dialect=DIALECT).this: [(1,)], | ||
exp.DataType.build("TIME", dialect=DIALECT).this: [(6,)], | ||
exp.DataType.build("TIMESTAMP", dialect=DIALECT).this: [(6,)], | ||
}, | ||
types_with_unlimited_length={ | ||
# all can ALTER to `TEXT` | ||
exp.DataType.build("TEXT", dialect=DIALECT).this: { | ||
exp.DataType.build("VARCHAR", dialect=DIALECT).this, | ||
exp.DataType.build("CHAR", dialect=DIALECT).this, | ||
exp.DataType.build("BPCHAR", dialect=DIALECT).this, | ||
}, | ||
# all can ALTER to unparameterized `VARCHAR` | ||
exp.DataType.build("VARCHAR", dialect=DIALECT).this: { | ||
exp.DataType.build("VARCHAR", dialect=DIALECT).this, | ||
exp.DataType.build("CHAR", dialect=DIALECT).this, | ||
exp.DataType.build("BPCHAR", dialect=DIALECT).this, | ||
exp.DataType.build("TEXT", dialect=DIALECT).this, | ||
}, | ||
# parameterized `BPCHAR(n)` can ALTER to unparameterized `BPCHAR` | ||
exp.DataType.build("BPCHAR", dialect=DIALECT).this: { | ||
exp.DataType.build("BPCHAR", dialect=DIALECT).this | ||
}, | ||
}, | ||
) | ||
|
||
def _set_flush(self) -> None: | ||
sql = "SET RW_IMPLICIT_FLUSH TO true;" | ||
self._execute(sql) | ||
|
||
def __init__( | ||
self, | ||
connection_factory: t.Callable[[], t.Any], | ||
dialect: str = "", | ||
sql_gen_kwargs: t.Optional[t.Dict[str, Dialect | bool | str]] = None, | ||
multithreaded: bool = False, | ||
cursor_kwargs: t.Optional[t.Dict[str, t.Any]] = None, | ||
cursor_init: t.Optional[t.Callable[[t.Any], None]] = None, | ||
default_catalog: t.Optional[str] = None, | ||
execute_log_level: int = logging.DEBUG, | ||
register_comments: bool = True, | ||
pre_ping: bool = False, | ||
**kwargs: t.Any, | ||
): | ||
super().__init__( | ||
connection_factory, | ||
dialect, | ||
sql_gen_kwargs, | ||
multithreaded, | ||
cursor_kwargs, | ||
cursor_init, | ||
default_catalog, | ||
execute_log_level, | ||
register_comments, | ||
pre_ping, | ||
**kwargs, | ||
) | ||
if hasattr(self, "cursor"): | ||
self._set_flush() | ||
|
||
def _begin_session(self, properties: SessionProperties) -> t.Any: | ||
"""Begin a new session.""" | ||
self._set_flush() | ||
|
||
def _fetch_native_df( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This also looks like it was copy+pasted directly from the Postgres adapter... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, because risingwave is postgres-compatible, which means postgres engine adapter works for risingwave with few changes. Do you have an idea on how to support it without manually copying? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So instead of copy+pasting the
Is there a reason you cant extend it instead?
|
||
self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False | ||
) -> DF: | ||
""" | ||
`read_sql_query` when using psycopg will result on a hanging transaction that must be committed | ||
https://github.com/pandas-dev/pandas/pull/42277 | ||
""" | ||
df = super()._fetch_native_df(query, quote_identifiers) | ||
if not self._connection_pool.is_transaction_active: | ||
self._connection_pool.commit() | ||
return df |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -89,6 +89,14 @@ gateways: | |
cluster: cluster1 | ||
state_connection: | ||
type: duckdb | ||
inttest_risingwave: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For the integration tests to run on CircleCI without the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi Erin, thanks for your help, I have updated and now there are some issues with the integration tests and I am trying to fix them. |
||
connection: | ||
type: risingwave | ||
user: risingwave | ||
password: risingwave | ||
database: risingwave | ||
host: {{ env_var('DOCKER_HOSTNAME', 'localhost') }} | ||
port: 5432 | ||
|
||
|
||
# Cloud databases | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
services: | ||
risingwave: | ||
image: risingwavelabs/risingwave:v2.0.1 | ||
ports: | ||
- '5432:5432' | ||
environment: | ||
RISINGWAVE_PASSWORD: risingwave |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# type: ignore | ||
import typing as t | ||
from unittest.mock import call | ||
|
||
import pytest | ||
from sqlglot import parse_one | ||
from sqlmesh.core.engine_adapter.risingwave import RisingwaveEngineAdapter | ||
|
||
pytestmark = [pytest.mark.engine, pytest.mark.postgres, pytest.mark.risingwave] | ||
|
||
|
||
def test_create_view(make_mocked_engine_adapter: t.Callable): | ||
adapter = make_mocked_engine_adapter(RisingwaveEngineAdapter) | ||
|
||
adapter.create_view("db.view", parse_one("SELECT 1"), replace=True) | ||
adapter.create_view("db.view", parse_one("SELECT 1"), replace=False) | ||
|
||
adapter.cursor.execute.assert_has_calls( | ||
[ | ||
# 1st call | ||
call('DROP VIEW IF EXISTS "db"."view" CASCADE'), | ||
call('CREATE VIEW "db"."view" AS SELECT 1'), | ||
# 2nd call | ||
call('CREATE VIEW "db"."view" AS SELECT 1'), | ||
] | ||
) | ||
|
||
|
||
def test_drop_view(make_mocked_engine_adapter: t.Callable): | ||
adapter = make_mocked_engine_adapter(RisingwaveEngineAdapter) | ||
|
||
adapter.SUPPORTS_MATERIALIZED_VIEWS = True | ||
|
||
adapter.drop_view("db.view") | ||
|
||
adapter.drop_view("db.view", materialized=True) | ||
|
||
adapter.drop_view("db.view", cascade=False) | ||
|
||
adapter.cursor.execute.assert_has_calls( | ||
[ | ||
# 1st call | ||
call('DROP VIEW IF EXISTS "db"."view" CASCADE'), | ||
# 2nd call | ||
call('DROP MATERIALIZED VIEW IF EXISTS "db"."view" CASCADE'), | ||
# 3rd call | ||
call('DROP VIEW IF EXISTS "db"."view"'), | ||
] | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here. If there is nothing different in the connection properties between
RisingWaveConnectionConfig
and thePostgresConnectionConfig
, why not just extendPostgresConnectionConfig
and override the internal properties liketype_
and_engine_adapter
etc instead of copy+paste?Unnecessary copy+paste is considered a code smell and doesn't meet the standards of this codebase