Spark 3.5: Support Reporting Column Stats (#10659)
Co-authored-by: Karuppayya Rajendran <[email protected]>
huaxingao and karuppayya authored Jul 31, 2024
1 parent 76dba8f commit 506fee4
Showing 7 changed files with 346 additions and 5 deletions.
SparkReadConf.java
@@ -347,4 +347,12 @@ private boolean executorCacheLocalityEnabledInternal() {
.defaultValue(SparkSQLProperties.EXECUTOR_CACHE_LOCALITY_ENABLED_DEFAULT)
.parse();
}

public boolean reportColumnStats() {
return confParser
.booleanConf()
.sessionConf(SparkSQLProperties.REPORT_COLUMN_STATS)
.defaultValue(SparkSQLProperties.REPORT_COLUMN_STATS_DEFAULT)
.parse();
}
}
SparkSQLProperties.java
@@ -90,4 +90,8 @@ private SparkSQLProperties() {}
public static final String EXECUTOR_CACHE_LOCALITY_ENABLED =
"spark.sql.iceberg.executor-cache.locality.enabled";
public static final boolean EXECUTOR_CACHE_LOCALITY_ENABLED_DEFAULT = false;

// Controls whether to report available column statistics to Spark for query optimization.
public static final String REPORT_COLUMN_STATS = "spark.sql.iceberg.report-column-stats";
public static final boolean REPORT_COLUMN_STATS_DEFAULT = true;
}
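The new flag only takes effect together with Spark's cost-based optimizer: as the SparkScan change later in this diff shows, column stats are attached only when both spark.sql.cbo.enabled and spark.sql.iceberg.report-column-stats are true. A minimal sketch (not part of this commit) of a session configured to use it; the table name db.events is hypothetical:

// Sketch only: configures a session so that SparkScan will report column stats.
import org.apache.spark.sql.SparkSession;

public class EnableColumnStatsExample {
  public static void main(String[] args) {
    SparkSession spark =
        SparkSession.builder()
            .appName("iceberg-column-stats")
            // Spark's CBO must be on; SparkScan checks spark.sql.cbo.enabled.
            .config("spark.sql.cbo.enabled", "true")
            // Defaults to true (REPORT_COLUMN_STATS_DEFAULT); set explicitly for clarity.
            .config("spark.sql.iceberg.report-column-stats", "true")
            .getOrCreate();

    // "cost" mode prints the optimizer's statistics alongside the plan.
    spark.sql("SELECT * FROM db.events WHERE id = 1").explain("cost");
  }
}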
SparkChangelogScan.java
@@ -88,7 +88,7 @@ class SparkChangelogScan implements Scan, SupportsReportStatistics {
public Statistics estimateStatistics() {
long rowsCount = taskGroups().stream().mapToLong(ScanTaskGroup::estimatedRowsCount).sum();
long sizeInBytes = SparkSchemaUtil.estimateSize(readSchema(), rowsCount);
-    return new Stats(sizeInBytes, rowsCount);
+    return new Stats(sizeInBytes, rowsCount, Collections.emptyMap());
}

@Override
SparkColumnStatistics.java (new file)
@@ -0,0 +1,88 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iceberg.spark.source;

import java.util.Optional;
import java.util.OptionalLong;
import org.apache.spark.sql.connector.read.colstats.ColumnStatistics;
import org.apache.spark.sql.connector.read.colstats.Histogram;

class SparkColumnStatistics implements ColumnStatistics {

private final OptionalLong distinctCount;
private final Optional<Object> min;
private final Optional<Object> max;
private final OptionalLong nullCount;
private final OptionalLong avgLen;
private final OptionalLong maxLen;
private final Optional<Histogram> histogram;

SparkColumnStatistics(
Long distinctCount,
Object min,
Object max,
Long nullCount,
Long avgLen,
Long maxLen,
Histogram histogram) {
this.distinctCount =
(distinctCount == null) ? OptionalLong.empty() : OptionalLong.of(distinctCount);
this.min = Optional.ofNullable(min);
this.max = Optional.ofNullable(max);
this.nullCount = (nullCount == null) ? OptionalLong.empty() : OptionalLong.of(nullCount);
this.avgLen = (avgLen == null) ? OptionalLong.empty() : OptionalLong.of(avgLen);
this.maxLen = (maxLen == null) ? OptionalLong.empty() : OptionalLong.of(maxLen);
this.histogram = Optional.ofNullable(histogram);
}

@Override
public OptionalLong distinctCount() {
return distinctCount;
}

@Override
public Optional<Object> min() {
return min;
}

@Override
public Optional<Object> max() {
return max;
}

@Override
public OptionalLong nullCount() {
return nullCount;
}

@Override
public OptionalLong avgLen() {
return avgLen;
}

@Override
public OptionalLong maxLen() {
return maxLen;
}

@Override
public Optional<Histogram> histogram() {
return histogram;
}
}
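As the SparkScan diff below shows, this commit populates only the distinct-count slot and passes null for everything else; the constructor turns null arguments into empty optionals rather than propagating nulls. A small illustration (not part of the commit; the class is package-private, so the sketch sits in the same package):

package org.apache.iceberg.spark.source;

public class SparkColumnStatisticsExample {
  public static void main(String[] args) {
    // Mirrors SparkScan: only the NDV is known, everything else is null.
    SparkColumnStatistics stats =
        new SparkColumnStatistics(42L, null, null, null, null, null, null);

    System.out.println(stats.distinctCount()); // OptionalLong[42]
    System.out.println(stats.min());           // Optional.empty
    System.out.println(stats.histogram());     // Optional.empty
  }
}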
SparkScan.java
@@ -23,15 +23,19 @@
import java.util.Map;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.apache.iceberg.BlobMetadata;
import org.apache.iceberg.ScanTask;
import org.apache.iceberg.ScanTaskGroup;
import org.apache.iceberg.Schema;
import org.apache.iceberg.Snapshot;
import org.apache.iceberg.SnapshotSummary;
import org.apache.iceberg.StatisticsFile;
import org.apache.iceberg.Table;
import org.apache.iceberg.expressions.Expression;
import org.apache.iceberg.metrics.ScanReport;
import org.apache.iceberg.relocated.com.google.common.base.Strings;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.iceberg.spark.Spark3Util;
import org.apache.iceberg.spark.SparkReadConf;
import org.apache.iceberg.spark.SparkSchemaUtil;
@@ -75,22 +79,28 @@
import org.apache.iceberg.util.TableScanUtil;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.connector.expressions.FieldReference;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.metric.CustomMetric;
import org.apache.spark.sql.connector.metric.CustomTaskMetric;
import org.apache.spark.sql.connector.read.Batch;
import org.apache.spark.sql.connector.read.Scan;
import org.apache.spark.sql.connector.read.Statistics;
import org.apache.spark.sql.connector.read.SupportsReportStatistics;
import org.apache.spark.sql.connector.read.colstats.ColumnStatistics;
import org.apache.spark.sql.connector.read.streaming.MicroBatchStream;
import org.apache.spark.sql.internal.SQLConf;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

abstract class SparkScan implements Scan, SupportsReportStatistics {
private static final Logger LOG = LoggerFactory.getLogger(SparkScan.class);
private static final String NDV_KEY = "ndv";

private final JavaSparkContext sparkContext;
private final Table table;
private final SparkSession spark;
private final SparkReadConf readConf;
private final boolean caseSensitive;
private final Schema expectedSchema;
@@ -111,6 +121,7 @@ abstract class SparkScan implements Scan, SupportsReportStatistics {
Schema snapshotSchema = SnapshotUtil.schemaFor(table, readConf.branch());
SparkSchemaUtil.validateMetadataColumnReferences(snapshotSchema, expectedSchema);

this.spark = spark;
this.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext());
this.table = table;
this.readConf = readConf;
@@ -175,7 +186,43 @@ public Statistics estimateStatistics() {
protected Statistics estimateStatistics(Snapshot snapshot) {
// it's a fresh table, no data
if (snapshot == null) {
-      return new Stats(0L, 0L);
+      return new Stats(0L, 0L, Collections.emptyMap());
}

boolean cboEnabled =
Boolean.parseBoolean(spark.conf().get(SQLConf.CBO_ENABLED().key(), "false"));
Map<NamedReference, ColumnStatistics> colStatsMap = Collections.emptyMap();
if (readConf.reportColumnStats() && cboEnabled) {
colStatsMap = Maps.newHashMap();
List<StatisticsFile> files = table.statisticsFiles();
if (!files.isEmpty()) {
List<BlobMetadata> metadataList = (files.get(0)).blobMetadata();

for (BlobMetadata blobMetadata : metadataList) {
int id = blobMetadata.fields().get(0);
String colName = table.schema().findColumnName(id);
NamedReference ref = FieldReference.column(colName);

Long ndv = null;
if (blobMetadata
.type()
.equals(org.apache.iceberg.puffin.StandardBlobTypes.APACHE_DATASKETCHES_THETA_V1)) {
String ndvStr = blobMetadata.properties().get(NDV_KEY);
if (!Strings.isNullOrEmpty(ndvStr)) {
ndv = Long.parseLong(ndvStr);
} else {
LOG.debug("ndv is not set in BlobMetadata for column {}", colName);
}
} else {
LOG.debug("DataSketch blob is not available for column {}", colName);
}

ColumnStatistics colStats =
new SparkColumnStatistics(ndv, null, null, null, null, null, null);

colStatsMap.put(ref, colStats);
}
}
}

// estimate stats using snapshot summary only for partitioned tables
@@ -186,12 +233,13 @@ protected Statistics estimateStatistics(Snapshot snapshot) {
snapshot.snapshotId(),
table.name());
long totalRecords = totalRecords(snapshot);
-      return new Stats(SparkSchemaUtil.estimateSize(readSchema(), totalRecords), totalRecords);
+      return new Stats(
+          SparkSchemaUtil.estimateSize(readSchema(), totalRecords), totalRecords, colStatsMap);
}

long rowsCount = taskGroups().stream().mapToLong(ScanTaskGroup::estimatedRowsCount).sum();
long sizeInBytes = SparkSchemaUtil.estimateSize(readSchema(), rowsCount);
-    return new Stats(sizeInBytes, rowsCount);
+    return new Stats(sizeInBytes, rowsCount, colStatsMap);
}

private long totalRecords(Snapshot snapshot) {
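Note that estimateStatistics() only reads NDV blobs that already exist in table.statisticsFiles(); it does not compute anything itself. A sketch of producing such a Puffin statistics file, assuming the ComputeTableStats action exposed through SparkActions in recent Iceberg releases (table and column names are hypothetical):

import org.apache.iceberg.Table;
import org.apache.iceberg.spark.Spark3Util;
import org.apache.iceberg.spark.actions.SparkActions;
import org.apache.spark.sql.SparkSession;

public class ComputeNdvStatsExample {
  public static void main(String[] args) throws Exception {
    SparkSession spark = SparkSession.builder().getOrCreate();
    Table table = Spark3Util.loadIcebergTable(spark, "db.events");

    // Writes a statistics file whose blobs carry the "ndv" property that the
    // loop above parses (APACHE_DATASKETCHES_THETA_V1 sketches).
    SparkActions.get(spark)
        .computeTableStats(table)
        .columns("id", "category")
        .execute();
  }
}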
Stats.java
@@ -18,16 +18,21 @@
*/
package org.apache.iceberg.spark.source;

import java.util.Map;
import java.util.OptionalLong;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.read.Statistics;
import org.apache.spark.sql.connector.read.colstats.ColumnStatistics;

class Stats implements Statistics {
private final OptionalLong sizeInBytes;
private final OptionalLong numRows;
private final Map<NamedReference, ColumnStatistics> colstats;

-  Stats(long sizeInBytes, long numRows) {
+  Stats(long sizeInBytes, long numRows, Map<NamedReference, ColumnStatistics> colstats) {
this.sizeInBytes = OptionalLong.of(sizeInBytes);
this.numRows = OptionalLong.of(numRows);
this.colstats = colstats;
}

@Override
@@ -39,4 +44,9 @@ public OptionalLong sizeInBytes() {
public OptionalLong numRows() {
return numRows;
}

@Override
public Map<NamedReference, ColumnStatistics> columnStats() {
return colstats;
}
}
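With columnStats() wired up, Spark's CBO folds the reported NDVs into plan costing (e.g., join reordering and broadcast decisions). One way to confirm the stats flow through, continuing the session from the earlier sketches:

// "EXPLAIN COST" prints plan statistics; with CBO enabled, distinct counts
// originating from Stats.columnStats() appear in the optimized plan.
spark.sql("EXPLAIN COST SELECT category, count(*) FROM db.events GROUP BY category")
    .show(false);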
(1 of the 7 changed files is not shown.)
