diff --git a/tests/connect/test_collect.py b/tests/connect/test_collect.py new file mode 100644 index 0000000000..516fa621ad --- /dev/null +++ b/tests/connect/test_collect.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import time + +import pytest +from pyspark.sql import SparkSession + + +@pytest.fixture +def spark_session(): + """Fixture to create and clean up a Spark session.""" + from daft.daft import connect_start + + # Start Daft Connect server + server = connect_start("sc://localhost:50051") + + # Initialize Spark Connect session + session = SparkSession.builder.appName("DaftConfigTest").remote("sc://localhost:50051").getOrCreate() + + yield session + + # Cleanup + server.shutdown() + session.stop() + time.sleep(2) # Allow time for session cleanup + + +def test_range_collect(spark_session): + # Create a range using Spark + # For example, creating a range from 0 to 9 + spark_range = spark_session.range(10) # Creates DataFrame with numbers 0 to 9 + + # Collect the data + collected_rows = spark_range.collect() + + # Verify the collected data has expected values + assert len(collected_rows) == 10, "Should have 10 rows" + assert [row["id"] for row in collected_rows] == list(range(10)), "Should contain values 0-9"