Implement Spark Connect connection method
rebase with upstream/main

refactor and improve integration test

fix integration tests

refactor and add grpc url parameters

url build fixes

Add import pandas and fix typo in build-dist

rebase on upstream master

Fix typos in tox.ini

remove some tests and remove session UUID generation

revert session URL hardcoding, add comments

fix committer email

delete todo in circleci
vakarisbk committed Oct 3, 2023
1 parent 992de28 commit 38bfada
Showing 20 changed files with 237 additions and 53 deletions.
30 changes: 30 additions & 0 deletions .circleci/config.yml
@@ -30,6 +30,33 @@ jobs:
- store_artifacts:
path: ./logs

integration-spark-connect:
environment:
DBT_INVOCATION_ENV: circle
docker:
- image: eclipse-temurin:11.0.20.1_1-jre-jammy
steps:
- run:
name: install python, pip and git
command: apt update && apt install -y python3.10 pip git
- run:
name: install pyspark and tox
command: pip install pyspark==3.5.0 tox
- run:
name: start connect server
command: |
spark-submit --class org.apache.spark.sql.connect.service.SparkConnectServer \
--conf spark.sql.catalogImplementation=hive \
--packages org.apache.spark:spark-connect_2.12:3.5.0
background: true
- checkout
- run:
name: Run integration tests
command: tox -e integration-spark-connect
no_output_timeout: 1h
- store_artifacts:
path: /tmp/logs
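
For context, a minimal smoke test (not part of this commit) that can confirm the Connect server started by the step above is accepting connections before the tox run; it assumes pyspark==3.5.0 with the connect dependencies installed and the server listening on the default port 15002:

# Hypothetical check that the Spark Connect server is reachable.
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
assert spark.sql("SELECT 1 AS ok").collect()[0]["ok"] == 1  # round-trip through the server
spark.stop()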

integration-spark-thrift:
environment:
DBT_INVOCATION_ENV: circle
@@ -120,6 +147,9 @@ workflows:
- integration-spark-session:
requires:
- unit
- integration-spark-connect:
requires:
- unit
- integration-spark-thrift:
requires:
- unit
63 changes: 61 additions & 2 deletions dbt/adapters/spark/connections.py
@@ -55,6 +55,7 @@ class SparkConnectionMethod(StrEnum):
HTTP = "http"
ODBC = "odbc"
SESSION = "session"
CONNECT = "connect"


@dataclass
@@ -77,7 +78,6 @@ class SparkCredentials(Credentials):
connect_timeout: int = 10
use_ssl: bool = False
server_side_parameters: Dict[str, str] = field(default_factory=dict)
retry_all: bool = False

@classmethod
def __pre_deserialize__(cls, data: Any) -> Any:
@@ -149,6 +149,21 @@ def __post_init__(self) -> None:
f"ImportError({e.msg})"
) from e

if self.method == SparkConnectionMethod.CONNECT:
try:
import pyspark # noqa: F401 F811
import grpc # noqa: F401
import pyarrow # noqa: F401
import pandas # noqa: F401
except ImportError as e:
raise dbt.exceptions.DbtRuntimeError(
f"{self.method} connection method requires "
"additional dependencies. \n"
"Install the additional required dependencies with "
"`pip install dbt-spark[connect]`\n\n"
f"ImportError({e.msg})"
) from e

if self.method != SparkConnectionMethod.SESSION:
self.host = self.host.rstrip("/")

@@ -521,8 +536,52 @@ def open(cls, connection: Connection) -> Connection:
SessionConnectionWrapper,
)

# Pass session type (session or connect) into SessionConnectionWrapper
handle = SessionConnectionWrapper(
Connection(server_side_parameters=creds.server_side_parameters)
Connection(
conn_method=creds.method,
conn_url="localhost",
server_side_parameters=creds.server_side_parameters,
)
)
elif creds.method == SparkConnectionMethod.CONNECT:
# Create the url

host = creds.host
port = creds.port
token = creds.token
use_ssl = creds.use_ssl
user = creds.user

# URL Format: sc://localhost:15002/;user_id=str;token=str;use_ssl=bool
base_url = host if host.startswith("sc://") else f"sc://{host}"
base_url += f":{port}"

url_extensions = []
if user:
    url_extensions.append(f"user_id={user}")
if use_ssl:
    url_extensions.append(f"use_ssl={use_ssl}")
if token:
    url_extensions.append(f"token={token}")

if url_extensions:
    conn_url = base_url + "/;" + ";".join(url_extensions)
else:
    conn_url = base_url

logger.debug("connection url: {}".format(conn_url))

from .session import ( # noqa: F401
Connection,
SessionConnectionWrapper,
)

# Pass session type (session or connect) into SessionConnectionWrapper
handle = SessionConnectionWrapper(
Connection(
conn_method=creds.method,
conn_url=conn_url,
server_side_parameters=creds.server_side_parameters,
)
)
else:
raise dbt.exceptions.DbtProfileError(
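
As an aside, a short sketch (not part of the diff) of the connection URL the connect branch above produces for a hypothetical profile; the host, port, user, and token values are made up:

# Illustrative only: mirrors the URL assembly in the connect branch above.
host, port, user, token = "localhost", 15002, "dbt", "abc123"

base_url = host if host.startswith("sc://") else f"sc://{host}"
base_url += f":{port}"

url_extensions = [f"user_id={user}", f"token={token}"]  # use_ssl omitted here
conn_url = base_url + "/;" + ";".join(url_extensions)
print(conn_url)  # sc://localhost:15002/;user_id=dbt;token=abc123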
35 changes: 29 additions & 6 deletions dbt/adapters/spark/session.py
@@ -6,7 +6,7 @@
from types import TracebackType
from typing import Any, Dict, List, Optional, Tuple, Union, Sequence

from dbt.adapters.spark.connections import SparkConnectionWrapper
from dbt.adapters.spark.connections import SparkConnectionMethod, SparkConnectionWrapper
from dbt.events import AdapterLogger
from dbt.utils import DECIMALS
from dbt.exceptions import DbtRuntimeError
@@ -27,9 +27,17 @@ class Cursor:
https://github.com/mkleehammer/pyodbc/wiki/Cursor
"""

def __init__(self, *, server_side_parameters: Optional[Dict[str, Any]] = None) -> None:
def __init__(
self,
*,
conn_method: SparkConnectionMethod,
conn_url: str,
server_side_parameters: Optional[Dict[str, Any]] = None,
) -> None:
self._df: Optional[DataFrame] = None
self._rows: Optional[List[Row]] = None
self.conn_method: SparkConnectionMethod = conn_method
self.conn_url: str = conn_url
self.server_side_parameters = server_side_parameters or {}

def __enter__(self) -> Cursor:
@@ -113,12 +121,15 @@ def execute(self, sql: str, *parameters: Any) -> None:
if len(parameters) > 0:
sql = sql % parameters

builder = SparkSession.builder.enableHiveSupport()
builder = SparkSession.builder

for parameter, value in self.server_side_parameters.items():
builder = builder.config(parameter, value)

spark_session = builder.getOrCreate()
if self.conn_method == SparkConnectionMethod.CONNECT:
spark_session = builder.remote(self.conn_url).getOrCreate()
elif self.conn_method == SparkConnectionMethod.SESSION:
spark_session = builder.enableHiveSupport().getOrCreate()

try:
self._df = spark_session.sql(sql)
@@ -175,7 +186,15 @@ class Connection:
https://github.com/mkleehammer/pyodbc/wiki/Connection
"""

def __init__(self, *, server_side_parameters: Optional[Dict[Any, str]] = None) -> None:
def __init__(
self,
*,
conn_method: SparkConnectionMethod,
conn_url: str,
server_side_parameters: Optional[Dict[Any, str]] = None,
) -> None:
self.conn_method = conn_method
self.conn_url = conn_url
self.server_side_parameters = server_side_parameters or {}

def cursor(self) -> Cursor:
@@ -187,7 +206,11 @@ def cursor(self) -> Cursor:
out : Cursor
The cursor.
"""
return Cursor(server_side_parameters=self.server_side_parameters)
return Cursor(
conn_method=self.conn_method,
conn_url=self.conn_url,
server_side_parameters=self.server_side_parameters,
)


class SessionConnectionWrapper(SparkConnectionWrapper):
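
For orientation, a minimal usage sketch (not from the commit) of the Connection and Cursor wrappers above with the connect method; the server URL and the server-side parameter are illustrative:

# Hypothetical usage; assumes a Spark Connect server at sc://localhost:15002
# and the optional connect dependencies installed.
from dbt.adapters.spark.connections import SparkConnectionMethod
from dbt.adapters.spark.session import Connection

conn = Connection(
    conn_method=SparkConnectionMethod.CONNECT,
    conn_url="sc://localhost:15002",
    server_side_parameters={"spark.sql.shuffle.partitions": "8"},
)
cursor = conn.cursor()
cursor.execute("SELECT 1 AS ok")  # goes through SparkSession.builder.remote(...)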
9 changes: 9 additions & 0 deletions requirements.txt
@@ -6,5 +6,14 @@ sqlparams>=3.0.0
thrift>=0.13.0
sqlparse>=0.4.2 # not directly required, pinned by Snyk to avoid a vulnerability

# spark-connect
pyspark==3.5.0
pandas>=1.0.5
pyarrow>=4.0.0
numpy>=1.15
grpcio>=1.46,<1.57
grpcio-status>=1.46,<1.57
googleapis-common-protos==1.56.4

types-PyYAML
types-python-dateutil
12 changes: 11 additions & 1 deletion setup.py
@@ -59,7 +59,16 @@ def _get_dbt_core_version():
"thrift>=0.11.0,<0.17.0",
]
session_extras = ["pyspark>=3.0.0,<4.0.0"]
all_extras = odbc_extras + pyhive_extras + session_extras
connect_extras = [
"pyspark==3.5.0",
"pandas>=1.05",
"pyarrow>=4.0.0",
"numpy>=1.15",
"grpcio>=1.46,<1.57",
"grpcio-status>=1.46,<1.57",
"googleapis-common-protos==1.56.4",
]
all_extras = odbc_extras + pyhive_extras + session_extras + connect_extras

setup(
name=package_name,
@@ -80,6 +89,7 @@ def _get_dbt_core_version():
"ODBC": odbc_extras,
"PyHive": pyhive_extras,
"session": session_extras,
"connect": connect_extras,
"all": all_extras,
},
zip_safe=False,
12 changes: 7 additions & 5 deletions tests/conftest.py
@@ -30,6 +30,8 @@ def dbt_profile_target(request):
target = databricks_http_cluster_target()
elif profile_type == "spark_session":
target = spark_session_target()
elif profile_type == "spark_connect":
target = spark_connect_target()
else:
raise ValueError(f"Invalid profile type '{profile_type}'")
return target
@@ -97,11 +99,11 @@ def databricks_http_cluster_target():


def spark_session_target():
return {
"type": "spark",
"host": "localhost",
"method": "session",
}
return {"type": "spark", "host": "localhost", "method": "session"}


def spark_connect_target():
return {"type": "spark", "host": "localhost", "port": 15002, "method": "connect"}


@pytest.fixture(autouse=True)
2 changes: 1 addition & 1 deletion tests/functional/adapter/dbt_clone/test_dbt_clone.py
@@ -14,7 +14,7 @@
)


@pytest.mark.skip_profile("apache_spark", "spark_session")
@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect")
class TestSparkBigqueryClonePossible(BaseClonePossible):
@pytest.fixture(scope="class")
def models(self):
@@ -5,7 +5,7 @@
)


@pytest.mark.skip_profile("spark_session", "apache_spark")
@pytest.mark.skip_profile("spark_session", "apache_spark", "spark_connect")
class TestMergeExcludeColumns(BaseMergeExcludeColumns):
@pytest.fixture(scope="class")
def project_config_update(self):
@@ -32,7 +32,7 @@ def project_config_update(self):
}


@pytest.mark.skip_profile("databricks_sql_endpoint", "spark_session")
@pytest.mark.skip_profile("databricks_sql_endpoint", "spark_session", "spark_connect")
class TestInsertOverwriteOnSchemaChange(IncrementalOnSchemaChangeIgnoreFail):
@pytest.fixture(scope="class")
def project_config_update(self):
@@ -45,7 +45,7 @@ def project_config_update(self):
}


@pytest.mark.skip_profile("apache_spark", "spark_session")
@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect")
class TestDeltaOnSchemaChange(BaseIncrementalOnSchemaChangeSetup):
@pytest.fixture(scope="class")
def project_config_update(self):
@@ -27,7 +27,7 @@
"""


@pytest.mark.skip_profile("spark_session", "apache_spark")
@pytest.mark.skip_profile("spark_session", "apache_spark", "spark_connect")
class TestIncrementalPredicatesMergeSpark(BaseIncrementalPredicates):
@pytest.fixture(scope="class")
def project_config_update(self):
@@ -46,7 +46,7 @@ def models(self):
}


@pytest.mark.skip_profile("spark_session", "apache_spark")
@pytest.mark.skip_profile("spark_session", "apache_spark", "spark_connect")
class TestPredicatesMergeSpark(BaseIncrementalPredicates):
@pytest.fixture(scope="class")
def project_config_update(self):
@@ -2,7 +2,7 @@
from dbt.tests.adapter.incremental.test_incremental_unique_id import BaseIncrementalUniqueKey


@pytest.mark.skip_profile("spark_session", "apache_spark")
@pytest.mark.skip_profile("spark_session", "apache_spark", "spark_connect")
class TestUniqueKeySpark(BaseIncrementalUniqueKey):
@pytest.fixture(scope="class")
def project_config_update(self):
@@ -103,7 +103,11 @@ def run_and_test(self, project):
check_relations_equal(project.adapter, ["merge_update_columns", "expected_partial_upsert"])

@pytest.mark.skip_profile(
"apache_spark", "databricks_http_cluster", "databricks_sql_endpoint", "spark_session"
"apache_spark",
"databricks_http_cluster",
"databricks_sql_endpoint",
"spark_session",
"spark_connect",
)
def test_delta_strategies(self, project):
self.run_and_test(project)
6 changes: 3 additions & 3 deletions tests/functional/adapter/persist_docs/test_persist_docs.py
@@ -15,7 +15,7 @@
)


@pytest.mark.skip_profile("apache_spark", "spark_session")
@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect")
class TestPersistDocsDeltaTable:
@pytest.fixture(scope="class")
def models(self):
@@ -78,7 +78,7 @@ def test_delta_comments(self, project):
assert result[2].startswith("Some stuff here and then a call to")


@pytest.mark.skip_profile("apache_spark", "spark_session")
@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect")
class TestPersistDocsDeltaView:
@pytest.fixture(scope="class")
def models(self):
@@ -120,7 +120,7 @@ def test_delta_comments(self, project):
assert result[2] is None


@pytest.mark.skip_profile("apache_spark", "spark_session")
@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect")
class TestPersistDocsMissingColumn:
@pytest.fixture(scope="class")
def project_config_update(self):