From bd703654ce0ea977821971f90ca1f764192fa1fc Mon Sep 17 00:00:00 2001
From: Andrew Gazelka
Date: Wed, 20 Nov 2024 01:48:00 -0800
Subject: [PATCH] [FEAT]: connect: `df.where`

---
Reviewer notes (free text between the "---" separator and the diff; not part
of the commit message):

- Adds one new file, tests/connect/test_where.py, containing a single test
  function, test_where.
- test_where takes a `spark_session` argument; presumably a pytest fixture
  defined outside this patch -- TODO confirm.
- The test builds spark_session.range(10), keeps the rows where col("id") > 5
  via df.where, converts the result with toPandas(), and asserts that the
  "id" column equals [6, 7, 8, 9].
- NOTE(review): the assertion compares an ordered list, so it relies on the
  filtered rows coming back in ascending id order -- confirm the backend under
  test guarantees that ordering.

 tests/connect/test_where.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)
 create mode 100644 tests/connect/test_where.py

diff --git a/tests/connect/test_where.py b/tests/connect/test_where.py
new file mode 100644
index 0000000000..b3b2855a4e
--- /dev/null
+++ b/tests/connect/test_where.py
@@ -0,0 +1,16 @@
+from __future__ import annotations
+
+from pyspark.sql.functions import col
+
+
+def test_where(spark_session):
+    # Create DataFrame from range(10)
+    df = spark_session.range(10)
+
+    # Filter the DataFrame where 'id' is greater than 5
+    df_filtered = df.where(col("id") > 5)
+
+    # Verify the filter was applied correctly by checking the expected data
+    df_filtered_pandas = df_filtered.toPandas()
+    expected_data = [6, 7, 8, 9]
+    assert df_filtered_pandas["id"].tolist() == expected_data, "Filtered data does not match expected data"