diff --git a/python/pyspark/errors/utils.py b/python/pyspark/errors/utils.py
index cbe5739204ac1..416a2323b170e 100644
--- a/python/pyspark/errors/utils.py
+++ b/python/pyspark/errors/utils.py
@@ -31,21 +31,42 @@
     Type,
     Optional,
     Union,
-    TYPE_CHECKING,
     overload,
     cast,
 )
 
 import pyspark
 from pyspark.errors.error_classes import ERROR_CLASSES_MAP
 
-if TYPE_CHECKING:
-    from pyspark.sql import SparkSession
-
 T = TypeVar("T")
 
 FuncT = TypeVar("FuncT", bound=Callable[..., Any])
 
 _current_origin = threading.local()
 
+# Providing DataFrame debugging options to reduce performance slowdown.
+# Default is True.
+_enable_debugging_cache = None
+
+
+def is_debugging_enabled() -> bool:
+    global _enable_debugging_cache
+
+    if _enable_debugging_cache is None:
+        from pyspark.sql import SparkSession
+
+        spark = SparkSession.getActiveSession()
+        if spark is not None:
+            _enable_debugging_cache = (
+                spark.conf.get(
+                    "spark.python.sql.dataFrameDebugging.enabled",
+                    "true",  # type: ignore[union-attr]
+                ).lower()
+                == "true"
+            )
+        else:
+            _enable_debugging_cache = False
+
+    return _enable_debugging_cache
+
 
 def current_origin() -> threading.local:
     global _current_origin
@@ -164,17 +185,12 @@ def get_message_template(self, errorClass: str) -> str:
         return message_template
 
 
-def _capture_call_site(spark_session: "SparkSession", depth: int) -> str:
+def _capture_call_site(depth: int) -> str:
     """
     Capture the call site information including file name, line number, and function name.
     This function updates the thread-local storage from JVM side (PySparkCurrentOrigin)
     with the current call site information when a PySpark API function is called.
 
-    Parameters
-    ----------
-    spark_session : SparkSession
-        Current active Spark session.
-
     Notes
     -----
     The call site information is used to enhance error messages with the exact location
@@ -245,7 +261,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
                 # Getting the configuration requires RPC call. Uses the default value for now.
                 depth = 1
-                set_current_origin(func.__name__, _capture_call_site(spark, depth))
+                set_current_origin(func.__name__, _capture_call_site(depth))
 
                 try:
                     return func(*args, **kwargs)
@@ -262,7 +278,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
                     )
                 )
                 # Update call site when the function is called
-                jvm_pyspark_origin.set(func.__name__, _capture_call_site(spark, depth))
+                jvm_pyspark_origin.set(func.__name__, _capture_call_site(depth))
 
                 try:
                     return func(*args, **kwargs)
@@ -297,7 +313,10 @@ def with_origin_to_class(
         return lambda cls: with_origin_to_class(cls, ignores)
     else:
         cls = cls_or_ignores
-        if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
+        if (
+            os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true"
+            and is_debugging_enabled()
+        ):
             skipping = set(
                 ["__init__", "__new__", "__iter__", "__nonzero__", "__repr__", "__bool__"]
                 + (ignores or [])
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
index cd17a63e5d433..407baba8280c0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
@@ -295,4 +295,14 @@ object StaticSQLConf {
       .version("3.1.0")
       .stringConf
       .createWithDefault("")
+
+  val DATA_FRAME_DEBUGGING_ENABLED =
+    buildStaticConf("spark.python.sql.dataFrameDebugging.enabled")
+      .internal()
+      .doc(
+        "Enable the DataFrame debugging. This feature is enabled by default, but has a " +
+          "non-trivial performance overhead because of the stack trace collection.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(true)
 }
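For context, here is a minimal sketch (not part of this patch) of how the new flag might be exercised from PySpark. It assumes a plain local SparkSession bootstrap; the builder setup shown is illustrative only.

```python
# Sketch: turning off DataFrame debugging to skip the per-call stack-trace capture
# introduced by with_origin_to_class. The flag is a static conf, so it has to be
# set before the SparkSession is created.
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .config("spark.python.sql.dataFrameDebugging.enabled", "false")
    .getOrCreate()
)

# With the flag off, is_debugging_enabled() caches False on first use and
# with_origin_to_class leaves DataFrame methods unwrapped, avoiding the overhead.
```

Note that `is_debugging_enabled()` stores its result in the module-level `_enable_debugging_cache`, so the conf is consulted only once per Python process; flipping the value after the first lookup has no effect.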