Commit 98a1b86
[FEAT] consolidate Spark session fixture into conftest.py
andrewgazelka committed Nov 20, 2024
1 parent b89ee3d commit 98a1b86
Showing 8 changed files with 65 additions and 62 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion daft/daft/__init__.pyi
@@ -1208,7 +1208,7 @@ def sql_expr(sql: str) -> PyExpr: ...
 def list_sql_functions() -> list[SQLFunctionStub]: ...
 def utf8_count_matches(expr: PyExpr, patterns: PyExpr, whole_words: bool, case_sensitive: bool) -> PyExpr: ...
 def to_struct(inputs: list[PyExpr]) -> PyExpr: ...
-def connect_start(addr: str) -> ConnectionHandle: ...
+def connect_start(addr: str) -> tuple[ConnectionHandle, int]: ...
 
 class ConnectionHandle:
     def shutdown(self) -> None: ...
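Since `connect_start` now returns the bound port alongside the handle, callers unpack a two-tuple instead of a bare `ConnectionHandle`. A minimal sketch of the new calling convention (mirroring the conftest.py fixture below; the port-0 address asks the OS for a free port):

    from daft.daft import connect_start

    # connect_start now returns the server handle together with the bound port.
    server, port = connect_start("sc://localhost:0")
    print(f"Daft Connect server listening on port {port}")
    server.shutdown()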
1 change: 1 addition & 0 deletions src/daft-connect/Cargo.toml
@@ -1,5 +1,6 @@
 [dependencies]
 arrow2 = {workspace = true}
+async-stream = "0.3.6"
 common-daft-config = {workspace = true}
 daft-local-execution = {workspace = true}
 daft-local-plan = {workspace = true}
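The new async-stream dependency supplies the `stream!` macro that the lib.rs change below uses to turn an accept loop into a `Stream` of incoming connections. A loose, purely illustrative Python analogue (names hypothetical, not Daft code) is an async generator that yields accepted sockets:

    import asyncio

    async def incoming(server_sock):
        # Yield accepted connections one at a time -- the same shape as the
        # async_stream::stream! accept loop in lib.rs; server_sock must be a
        # non-blocking, listening socket.
        loop = asyncio.get_running_loop()
        while True:
            conn, _addr = await loop.sock_accept(server_sock)
            yield conn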
49 changes: 33 additions & 16 deletions src/daft-connect/src/lib.rs
@@ -49,10 +49,13 @@ impl ConnectionHandle {
     }
 }
 
-pub fn start(addr: &str) -> eyre::Result<ConnectionHandle> {
+pub fn start(addr: &str) -> eyre::Result<(ConnectionHandle, u16)> {
     info!("Daft-Connect server listening on {addr}");
     let addr = util::parse_spark_connect_address(addr)?;
 
+    let listener = std::net::TcpListener::bind(addr)?;
+    let port = listener.local_addr()?.port();
+
     let service = DaftSparkConnectService::default();
 
     info!("Daft-Connect server listening on {addr}");
@@ -65,21 +68,35 @@ pub fn start(addr: &str) -> eyre::Result<ConnectionHandle> {
 
     std::thread::spawn(move || {
         let runtime = tokio::runtime::Runtime::new().unwrap();
-        let result = runtime
-            .block_on(async {
-                tokio::select! {
-                    result = Server::builder()
-                        .add_service(SparkConnectServiceServer::new(service))
-                        .serve(addr) => {
-                        result
-                    }
-                    _ = shutdown_receiver => {
-                        info!("Received shutdown signal");
-                        Ok(())
-                    }
-                }
-            })
-            .wrap_err_with(|| format!("Failed to start server on {addr}"));
+        let result = runtime.block_on(async {
+            let incoming = {
+                let listener = tokio::net::TcpListener::from_std(listener)
+                    .wrap_err("Failed to create TcpListener from std::net::TcpListener")?;
+
+                async_stream::stream! {
+                    loop {
+                        match listener.accept().await {
+                            Ok((stream, _)) => yield Ok(stream),
+                            Err(e) => yield Err(e),
+                        }
+                    }
+                }
+            };
+
+            let result = tokio::select! {
+                result = Server::builder()
+                    .add_service(SparkConnectServiceServer::new(service))
+                    .serve_with_incoming(incoming) => {
+                    result
+                }
+                _ = shutdown_receiver => {
+                    info!("Received shutdown signal");
+                    Ok(())
+                }
+            };
+
+            result.wrap_err_with(|| format!("Failed to start server on {addr}"))
+        });
 
         if let Err(e) = result {
             eprintln!("Daft-Connect server error: {e:?}");
@@ -88,7 +105,7 @@ pub fn start(addr: &str) -> eyre::Result<ConnectionHandle> {
         eyre::Result::<_>::Ok(())
     });
 
-    Ok(handle)
+    Ok((handle, port))
 }
 
 #[derive(Default)]
@@ -364,7 +381,7 @@ impl SparkConnectService for DaftSparkConnectService {
 #[cfg(feature = "python")]
 #[pyo3::pyfunction]
 #[pyo3(name = "connect_start")]
-pub fn py_connect_start(addr: &str) -> pyo3::PyResult<ConnectionHandle> {
+pub fn py_connect_start(addr: &str) -> pyo3::PyResult<(ConnectionHandle, u16)> {
     start(addr).map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e:?}")))
 }
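The heart of this change: the listener is bound eagerly on the calling thread, so the OS-assigned port is known before the server thread spawns; the std listener is then converted via `tokio::net::TcpListener::from_std` and fed to tonic through `serve_with_incoming`. A minimal Python sketch of the same bind-to-port-0 idiom (standard library only, illustrative rather than Daft code):

    import socket

    # Bind to port 0 so the OS picks a free ephemeral port, then read it
    # back -- mirroring TcpListener::bind(addr) + local_addr().port() above.
    listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    listener.bind(("localhost", 0))
    listener.listen()
    port = listener.getsockname()[1]
    print(f"OS assigned port {port}")
    listener.close()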
Empty file removed tests/connect/__init__.py
29 changes: 29 additions & 0 deletions tests/connect/conftest.py
@@ -0,0 +1,29 @@
+from __future__ import annotations
+
+import pytest
+from pyspark.sql import SparkSession
+
+
+@pytest.fixture(scope="session")
+def spark_session():
+    """
+    Fixture to create and clean up a Spark session.
+
+    This fixture is available to all test files and creates a single
+    Spark session for the entire test suite run.
+    """
+    from daft.daft import connect_start
+
+    # Start Daft Connect server
+    (server, port) = connect_start("sc://localhost:0")
+
+    url = f"sc://localhost:{port}"
+
+    # Initialize Spark Connect session
+    session = SparkSession.builder.appName("DaftConfigTest").remote(url).getOrCreate()
+
+    yield session
+
+    # Cleanup
+    server.shutdown()
+    session.stop()
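Because the fixture now lives in conftest.py with scope="session", any test module under tests/connect/ can request it by parameter name and pytest will start the server exactly once per run. A hypothetical test module (file name and assertion are illustrative):

    # e.g. tests/connect/test_example.py -- no imports or local fixture needed;
    # pytest resolves `spark_session` from conftest.py.
    def test_range_count(spark_session):
        df = spark_session.range(10)
        assert df.count() == 10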
24 changes: 0 additions & 24 deletions tests/connect/test_config_simple.py
@@ -1,29 +1,5 @@
 from __future__ import annotations
 
-import time
-
-import pytest
-from pyspark.sql import SparkSession
-
-
-@pytest.fixture
-def spark_session():
-    """Fixture to create and clean up a Spark session."""
-    from daft.daft import connect_start
-
-    # Start Daft Connect server
-    server = connect_start("sc://localhost:50051")
-
-    # Initialize Spark Connect session
-    session = SparkSession.builder.appName("DaftConfigTest").remote("sc://localhost:50051").getOrCreate()
-
-    yield session
-
-    # Cleanup
-    server.shutdown()
-    session.stop()
-    time.sleep(2)  # Allow time for session cleanup
-
 
 def test_set_operation(spark_session):
     """Test the Set operation with various data types and edge cases."""
21 changes: 0 additions & 21 deletions tests/connect/test_range_simple.py
@@ -1,26 +1,5 @@
 from __future__ import annotations
 
-import pytest
-from pyspark.sql import SparkSession
-
-
-@pytest.fixture
-def spark_session():
-    """Fixture to create and clean up a Spark session."""
-    from daft.daft import connect_start
-
-    # Start Daft Connect server
-    server = connect_start("sc://localhost:50051")
-
-    # Initialize Spark Connect session
-    session = SparkSession.builder.appName("DaftConfigTest").remote("sc://localhost:50051").getOrCreate()
-
-    yield session
-
-    # Cleanup
-    server.shutdown()
-    session.stop()
-
 
 def test_range_operation(spark_session):
     # Create a range using Spark
