Prune unused stats columns when reading Delta checkpoint
Add support for stats projection in Delta checkpoint iterator
findinpath authored and findepi committed Nov 27, 2023
1 parent 95e8126 commit f6f7646
Showing 37 changed files with 463 additions and 48 deletions.
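
At a high level, the commit threads an optional column-name predicate (addStatsMinMaxColumnFilter) from each caller's projected columns down to the checkpoint Parquet reader, so per-file min/max statistics are decoded only for the columns a query actually references. The sketch below shows just the predicate-building idea with simplified types; the class and method names are illustrative stand-ins, not code from this commit:

    import java.util.Optional;
    import java.util.Set;
    import java.util.function.Predicate;

    class StatsProjectionSketch
    {
        // Given the base column names a query projects, keep min/max stats only
        // for those columns; with no projection information, read no stats.
        static Predicate<String> statsColumnFilter(Optional<Set<String>> projectedBaseColumns)
        {
            return projectedBaseColumns
                    .<Predicate<String>>map(names -> names::contains)
                    .orElse(column -> false); // mirrors Guava's Predicates.alwaysFalse()
        }

        public static void main(String[] args)
        {
            Predicate<String> filter = statsColumnFilter(Optional.of(Set.of("order_key", "total_price")));
            System.out.println(filter.test("order_key")); // true: stats column kept
            System.out.println(filter.test("comment"));   // false: stats column pruned
        }
    }

The hunks below show this plumbing at each call site.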
@@ -3546,7 +3546,13 @@ private OptionalLong executeDelete(ConnectorSession session, ConnectorTableHandl
     private List<AddFileEntry> getAddFileEntriesMatchingEnforcedPartitionConstraint(ConnectorSession session, DeltaLakeTableHandle tableHandle)
     {
         TableSnapshot tableSnapshot = getSnapshot(session, tableHandle);
-        List<AddFileEntry> validDataFiles = transactionLogAccess.getActiveFiles(tableSnapshot, tableHandle.getMetadataEntry(), tableHandle.getProtocolEntry(), tableHandle.getEnforcedPartitionConstraint(), session);
+        List<AddFileEntry> validDataFiles = transactionLogAccess.getActiveFiles(
+                tableSnapshot,
+                tableHandle.getMetadataEntry(),
+                tableHandle.getProtocolEntry(),
+                tableHandle.getEnforcedPartitionConstraint(),
+                tableHandle.getProjectedColumns(),
+                session);
         TupleDomain<DeltaLakeColumnHandle> enforcedPartitionConstraint = tableHandle.getEnforcedPartitionConstraint();
         if (enforcedPartitionConstraint.isAll()) {
             return validDataFiles;
@@ -154,7 +154,13 @@ private Stream<DeltaLakeSplit> getSplits(
     {
         TableSnapshot tableSnapshot = deltaLakeTransactionManager.get(transaction, session.getIdentity())
                 .getSnapshot(session, tableHandle.getSchemaTableName(), tableHandle.getLocation(), tableHandle.getReadVersion());
-        List<AddFileEntry> validDataFiles = transactionLogAccess.getActiveFiles(tableSnapshot, tableHandle.getMetadataEntry(), tableHandle.getProtocolEntry(), tableHandle.getEnforcedPartitionConstraint(), session);
+        List<AddFileEntry> validDataFiles = transactionLogAccess.getActiveFiles(
+                tableSnapshot,
+                tableHandle.getMetadataEntry(),
+                tableHandle.getProtocolEntry(),
+                tableHandle.getEnforcedPartitionConstraint(),
+                tableHandle.getProjectedColumns(),
+                session);
         TupleDomain<DeltaLakeColumnHandle> enforcedPartitionConstraint = tableHandle.getEnforcedPartitionConstraint();
         TupleDomain<DeltaLakeColumnHandle> nonPartitionConstraint = tableHandle.getNonPartitionConstraint();
         Domain pathDomain = getPathDomain(nonPartitionConstraint);
@@ -110,7 +110,13 @@ public TableStatistics getTableStatistics(ConnectorSession session, DeltaLakeTab
                 .filter(column -> predicatedColumnNames.contains(column.getName()))
                 .collect(toImmutableList());
 
-        for (AddFileEntry addEntry : transactionLogAccess.getActiveFiles(tableSnapshot, tableHandle.getMetadataEntry(), tableHandle.getProtocolEntry(), session)) {
+        for (AddFileEntry addEntry : transactionLogAccess.getActiveFiles(
+                tableSnapshot,
+                tableHandle.getMetadataEntry(),
+                tableHandle.getProtocolEntry(),
+                tableHandle.getEnforcedPartitionConstraint(),
+                tableHandle.getProjectedColumns(),
+                session)) {
             Optional<? extends DeltaLakeFileStatistics> fileStatistics = addEntry.getStats();
             if (fileStatistics.isEmpty()) {
                 // Open source Delta Lake does not collect stats
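
Callers such as the split manager and the statistics provider above now pass both the enforced partition constraint and the projected columns, so active-file listing can prune in two dimensions: whole files by partition value, and min/max stats by column. A toy end-to-end illustration of that combined effect; AddFile and every name here are simplified stand-ins, not the connector's AddFileEntry API:

    import java.util.List;
    import java.util.Map;
    import java.util.Set;
    import java.util.function.Predicate;
    import java.util.stream.Collectors;

    class ActiveFilesPruningSketch
    {
        // Stand-in for an add entry: partition values plus per-column min stats.
        record AddFile(Map<String, String> partitionValues, Map<String, Long> minValues) {}

        public static void main(String[] args)
        {
            List<AddFile> checkpointAddEntries = List.of(
                    new AddFile(Map.of("region", "eu"), Map.of("order_key", 1L, "total_price", 10L)),
                    new AddFile(Map.of("region", "us"), Map.of("order_key", 5L, "total_price", 99L)));

            // Enforced partition constraint: region = 'eu' (file-level pruning).
            Predicate<AddFile> partitionConstraint = file -> "eu".equals(file.partitionValues().get("region"));
            // Projected columns: only order_key stats are decoded (column-level pruning).
            Set<String> projectedBaseColumns = Set.of("order_key");

            List<AddFile> pruned = checkpointAddEntries.stream()
                    .filter(partitionConstraint)
                    .map(file -> new AddFile(
                            file.partitionValues(),
                            file.minValues().entrySet().stream()
                                    .filter(e -> projectedBaseColumns.contains(e.getKey()))
                                    .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))))
                    .collect(Collectors.toList());

            System.out.println(pruned); // [AddFile[partitionValues={region=eu}, minValues={order_key=1}]]
        }
    }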
@@ -36,6 +36,7 @@
 import java.util.List;
 import java.util.Optional;
 import java.util.Set;
+import java.util.function.Predicate;
 import java.util.stream.Stream;
 
 import static com.google.common.base.Preconditions.checkState;
@@ -181,7 +182,8 @@ public Stream<DeltaLakeTransactionLogEntry> getCheckpointTransactionLogEntries(
             TrinoFileSystem fileSystem,
             FileFormatDataSourceStats stats,
             Optional<MetadataAndProtocolEntry> metadataAndProtocol,
-            TupleDomain<DeltaLakeColumnHandle> partitionConstraint)
+            TupleDomain<DeltaLakeColumnHandle> partitionConstraint,
+            Optional<Predicate<String>> addStatsMinMaxColumnFilter)
             throws IOException
     {
         if (lastCheckpoint.isEmpty()) {
@@ -210,7 +212,8 @@ public Stream<DeltaLakeTransactionLogEntry> getCheckpointTransactionLogEntries(
                             stats,
                             checkpoint,
                             checkpointFile,
-                            partitionConstraint)));
+                            partitionConstraint,
+                            addStatsMinMaxColumnFilter)));
         }
         return resultStream;
     }
@@ -230,7 +233,8 @@ private Iterator<DeltaLakeTransactionLogEntry> getCheckpointTransactionLogEntrie
             FileFormatDataSourceStats stats,
             LastCheckpoint checkpoint,
             TrinoInputFile checkpointFile,
-            TupleDomain<DeltaLakeColumnHandle> partitionConstraint)
+            TupleDomain<DeltaLakeColumnHandle> partitionConstraint,
+            Optional<Predicate<String>> addStatsMinMaxColumnFilter)
             throws IOException
     {
         long fileSize;
@@ -253,7 +257,8 @@ private Iterator<DeltaLakeTransactionLogEntry> getCheckpointTransactionLogEntrie
                 parquetReaderOptions,
                 checkpointRowStatisticsWritingEnabled,
                 domainCompactionThreshold,
-                partitionConstraint);
+                partitionConstraint,
+                addStatsMinMaxColumnFilter);
     }
 
     public record MetadataAndProtocolEntry(MetadataEntry metadataEntry, ProtocolEntry protocolEntry)
@@ -68,11 +68,15 @@
 import java.util.concurrent.TimeUnit;
 import java.util.function.BiFunction;
 import java.util.function.Function;
+import java.util.function.Predicate;
 import java.util.stream.Stream;
 
+import static com.google.common.base.Predicates.alwaysFalse;
+import static com.google.common.base.Predicates.alwaysTrue;
 import static com.google.common.base.Throwables.throwIfUnchecked;
 import static com.google.common.collect.ImmutableList.toImmutableList;
 import static com.google.common.collect.ImmutableMap.toImmutableMap;
+import static com.google.common.collect.ImmutableSet.toImmutableSet;
 import static io.airlift.slice.SizeOf.estimatedSizeOf;
 import static io.airlift.slice.SizeOf.instanceSize;
 import static io.trino.cache.CacheUtils.invalidateAllIf;
@@ -223,17 +227,48 @@ public MetadataEntry getMetadataEntry(TableSnapshot tableSnapshot, ConnectorSess
                 .orElseThrow(() -> new TrinoException(DELTA_LAKE_INVALID_SCHEMA, "Metadata not found in transaction log for " + tableSnapshot.getTable()));
     }
 
+    // Deprecated in favor of the namesake method which allows checkpoint filtering
+    // to be able to perform partition pruning and stats projection on the `add` entries
+    // from the checkpoint.
+    /**
+     * @see #getActiveFiles(TableSnapshot, MetadataEntry, ProtocolEntry, TupleDomain, Optional, ConnectorSession)
+     */
+    @Deprecated
     public List<AddFileEntry> getActiveFiles(TableSnapshot tableSnapshot, MetadataEntry metadataEntry, ProtocolEntry protocolEntry, ConnectorSession session)
     {
-        return getActiveFiles(tableSnapshot, metadataEntry, protocolEntry, TupleDomain.all(), session);
+        return retrieveActiveFiles(tableSnapshot, metadataEntry, protocolEntry, TupleDomain.all(), Optional.empty(), session);
     }
 
+    public List<AddFileEntry> getActiveFiles(
+            TableSnapshot tableSnapshot,
+            MetadataEntry metadataEntry,
+            ProtocolEntry protocolEntry,
+            TupleDomain<DeltaLakeColumnHandle> partitionConstraint,
+            Optional<Set<DeltaLakeColumnHandle>> projectedColumns,
+            ConnectorSession session)
+    {
+        Optional<Predicate<String>> addStatsMinMaxColumnFilter = Optional.of(alwaysFalse());
+        if (projectedColumns.isPresent()) {
+            Set<String> baseColumnNames = projectedColumns.get().stream()
+                    .filter(DeltaLakeColumnHandle::isBaseColumn) // Only base column stats are supported
+                    .map(DeltaLakeColumnHandle::getColumnName)
+                    .collect(toImmutableSet());
+            addStatsMinMaxColumnFilter = Optional.of(baseColumnNames::contains);
+        }
+        return retrieveActiveFiles(tableSnapshot, metadataEntry, protocolEntry, partitionConstraint, addStatsMinMaxColumnFilter, session);
+    }
+
-    public List<AddFileEntry> getActiveFiles(TableSnapshot tableSnapshot, MetadataEntry metadataEntry, ProtocolEntry protocolEntry, TupleDomain<DeltaLakeColumnHandle> partitionConstraint, ConnectorSession session)
+    private List<AddFileEntry> retrieveActiveFiles(
+            TableSnapshot tableSnapshot,
+            MetadataEntry metadataEntry,
+            ProtocolEntry protocolEntry,
+            TupleDomain<DeltaLakeColumnHandle> partitionConstraint,
+            Optional<Predicate<String>> addStatsMinMaxColumnFilter,
+            ConnectorSession session)
     {
         try {
             if (isCheckpointFilteringEnabled(session)) {
-                return loadActiveFiles(tableSnapshot, metadataEntry, protocolEntry, partitionConstraint, session).stream()
+                return loadActiveFiles(tableSnapshot, metadataEntry, protocolEntry, partitionConstraint, addStatsMinMaxColumnFilter, session).stream()
                         .collect(toImmutableList());
             }
 
@@ -264,7 +299,7 @@ public List<AddFileEntry> getActiveFiles(TableSnapshot tableSnapshot, MetadataEn
                 }
             }
 
-            List<AddFileEntry> activeFiles = loadActiveFiles(tableSnapshot, metadataEntry, protocolEntry, TupleDomain.all(), session);
+            List<AddFileEntry> activeFiles = loadActiveFiles(tableSnapshot, metadataEntry, protocolEntry, TupleDomain.all(), Optional.of(alwaysTrue()), session);
             return new DeltaLakeDataFileCacheEntry(tableSnapshot.getVersion(), activeFiles);
         });
         return cacheEntry.getActiveFiles();
@@ -279,6 +314,7 @@ private List<AddFileEntry> loadActiveFiles(
             MetadataEntry metadataEntry,
             ProtocolEntry protocolEntry,
             TupleDomain<DeltaLakeColumnHandle> partitionConstraint,
+            Optional<Predicate<String>> addStatsMinMaxColumnFilter,
             ConnectorSession session)
     {
         List<Transaction> transactions = tableSnapshot.getTransactions();
@@ -290,7 +326,8 @@
                 fileSystemFactory.create(session),
                 fileFormatDataSourceStats,
                 Optional.of(new MetadataAndProtocolEntry(metadataEntry, protocolEntry)),
-                partitionConstraint)) {
+                partitionConstraint,
+                addStatsMinMaxColumnFilter)) {
             return activeAddEntries(checkpointEntries, transactions)
                     .filter(partitionConstraint.isAll()
                             ? addAction -> true
@@ -433,7 +470,7 @@ private <T> Stream<T> getEntries(
         List<Transaction> transactions = tableSnapshot.getTransactions();
         // Passing TupleDomain.all() because this method is used for getting all entries
         Stream<DeltaLakeTransactionLogEntry> checkpointEntries = tableSnapshot.getCheckpointTransactionLogEntries(
-                session, entryTypes, checkpointSchemaManager, typeManager, fileSystem, stats, Optional.empty(), TupleDomain.all());
+                session, entryTypes, checkpointSchemaManager, typeManager, fileSystem, stats, Optional.empty(), TupleDomain.all(), Optional.of(alwaysTrue()));
 
         return entryMapper.apply(
                 checkpointEntries,
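
Three filter values are now in play across the call paths above: the deprecated overload passes Optional.empty(), the new overload defaults to alwaysFalse() and only widens to the projected base columns when a projection is supplied, and both cache population and getEntries() pass alwaysTrue() so cached AddFileEntry lists stay complete enough to serve any later projection. A self-contained demonstration of those semantics, using plain java.util.function stand-ins for the Guava predicates:

    import java.util.Optional;
    import java.util.Set;
    import java.util.function.Predicate;

    class StatsFilterDefaults
    {
        public static void main(String[] args)
        {
            // Deprecated getActiveFiles overload: no projection information available.
            Optional<Predicate<String>> deprecatedPath = Optional.empty();

            // New overload with projected columns: keep only the projected base columns.
            Optional<Predicate<String>> projected = Optional.of(Set.of("order_key")::contains);

            // New overload without projected columns: prune every min/max stats column.
            Optional<Predicate<String>> noProjection = Optional.of(column -> false); // alwaysFalse()

            // Cache population and getEntries(): keep every column's stats.
            Optional<Predicate<String>> cachePath = Optional.of(column -> true); // alwaysTrue()

            System.out.println(deprecatedPath.isPresent());                   // false
            System.out.println(projected.orElseThrow().test("order_key"));    // true
            System.out.println(noProjection.orElseThrow().test("order_key")); // false
            System.out.println(cachePath.orElseThrow().test("order_key"));    // true
        }
    }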
@@ -69,6 +69,7 @@
 import java.util.OptionalLong;
 import java.util.Queue;
 import java.util.Set;
+import java.util.function.Predicate;
 
 import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Verify.verify;
@@ -159,14 +160,16 @@ public CheckpointEntryIterator(
             ParquetReaderOptions parquetReaderOptions,
             boolean checkpointRowStatisticsWritingEnabled,
             int domainCompactionThreshold,
-            TupleDomain<DeltaLakeColumnHandle> partitionConstraint)
+            TupleDomain<DeltaLakeColumnHandle> partitionConstraint,
+            Optional<Predicate<String>> addStatsMinMaxColumnFilter)
     {
         this.checkpointPath = checkpoint.location().toString();
         this.session = requireNonNull(session, "session is null");
         this.stringList = (ArrayType) typeManager.getType(TypeSignature.arrayType(VARCHAR.getTypeSignature()));
         this.stringMap = (MapType) typeManager.getType(TypeSignature.mapType(VARCHAR.getTypeSignature(), VARCHAR.getTypeSignature()));
         this.checkpointRowStatisticsWritingEnabled = checkpointRowStatisticsWritingEnabled;
         this.partitionConstraint = requireNonNull(partitionConstraint, "partitionConstraint is null");
+        requireNonNull(addStatsMinMaxColumnFilter, "addStatsMinMaxColumnFilter is null");
         checkArgument(!fields.isEmpty(), "fields is empty");
         Map<EntryType, CheckPointFieldExtractor> extractors = ImmutableMap.<EntryType, CheckPointFieldExtractor>builder()
                 .put(TRANSACTION, this::buildTxnEntry)
@@ -182,14 +185,19 @@ public CheckpointEntryIterator(
             this.metadataEntry = metadataEntry.get();
             checkArgument(protocolEntry.isPresent(), "Protocol entry must be provided when reading ADD entries from Checkpoint files");
             this.protocolEntry = protocolEntry.get();
+            checkArgument(addStatsMinMaxColumnFilter.isPresent(), "addStatsMinMaxColumnFilter must be provided when reading ADD entries from Checkpoint files");
             this.schema = extractSchema(this.metadataEntry, this.protocolEntry, typeManager);
             this.columnsWithMinMaxStats = columnsWithStats(schema, this.metadataEntry.getOriginalPartitionColumns());
+            Predicate<String> columnStatsFilterFunction = addStatsMinMaxColumnFilter.orElseThrow();
+            this.columnsWithMinMaxStats = columnsWithMinMaxStats.stream()
+                    .filter(column -> columnStatsFilterFunction.test(column.getName()))
+                    .collect(toImmutableList());
         }
 
         ImmutableList.Builder<HiveColumnHandle> columnsBuilder = ImmutableList.builderWithExpectedSize(fields.size());
         ImmutableList.Builder<TupleDomain<HiveColumnHandle>> disjunctDomainsBuilder = ImmutableList.builderWithExpectedSize(fields.size());
         for (EntryType field : fields) {
-            HiveColumnHandle column = buildColumnHandle(field, checkpointSchemaManager, this.metadataEntry, this.protocolEntry).toHiveColumnHandle();
+            HiveColumnHandle column = buildColumnHandle(field, checkpointSchemaManager, this.metadataEntry, this.protocolEntry, addStatsMinMaxColumnFilter).toHiveColumnHandle();
             columnsBuilder.add(column);
             disjunctDomainsBuilder.add(buildTupleDomainColumnHandle(field, column));
         }
@@ -220,11 +228,16 @@ public CheckpointEntryIterator(
                 .collect(toImmutableList());
     }
 
-    private DeltaLakeColumnHandle buildColumnHandle(EntryType entryType, CheckpointSchemaManager schemaManager, MetadataEntry metadataEntry, ProtocolEntry protocolEntry)
+    private DeltaLakeColumnHandle buildColumnHandle(
+            EntryType entryType,
+            CheckpointSchemaManager schemaManager,
+            MetadataEntry metadataEntry,
+            ProtocolEntry protocolEntry,
+            Optional<Predicate<String>> addStatsMinMaxColumnFilter)
     {
         Type type = switch (entryType) {
             case TRANSACTION -> schemaManager.getTxnEntryType();
-            case ADD -> schemaManager.getAddEntryType(metadataEntry, protocolEntry, true, true, true);
+            case ADD -> schemaManager.getAddEntryType(metadataEntry, protocolEntry, addStatsMinMaxColumnFilter.orElseThrow(), true, true, true);
             case REMOVE -> schemaManager.getRemoveEntryType();
             case METADATA -> schemaManager.getMetadataEntryType();
             case PROTOCOL -> schemaManager.getProtocolEntryType(true, true);
@@ -696,6 +709,12 @@ OptionalLong getCompletedPositions()
         return pageSource.getCompletedPositions();
     }
 
+    @VisibleForTesting
+    long getCompletedBytes()
+    {
+        return pageSource.getCompletedBytes();
+    }
+
     @FunctionalInterface
     public interface CheckPointFieldExtractor
     {
@@ -29,6 +29,7 @@
 
 import java.util.List;
 import java.util.Optional;
+import java.util.function.Predicate;
 
 import static com.google.common.collect.ImmutableList.toImmutableList;
 import static io.trino.plugin.deltalake.transactionlog.DeltaLakeSchemaSupport.extractPartitionColumns;
@@ -114,10 +115,19 @@ public RowType getMetadataEntryType()
         return metadataEntryType;
     }
 
-    public RowType getAddEntryType(MetadataEntry metadataEntry, ProtocolEntry protocolEntry, boolean requireWriteStatsAsJson, boolean requireWriteStatsAsStruct, boolean usePartitionValuesParsed)
+    public RowType getAddEntryType(
+            MetadataEntry metadataEntry,
+            ProtocolEntry protocolEntry,
+            Predicate<String> addStatsMinMaxColumnFilter,
+            boolean requireWriteStatsAsJson,
+            boolean requireWriteStatsAsStruct,
+            boolean usePartitionValuesParsed)
     {
         List<DeltaLakeColumnMetadata> allColumns = extractSchema(metadataEntry, protocolEntry, typeManager);
         List<DeltaLakeColumnMetadata> minMaxColumns = columnsWithStats(metadataEntry, protocolEntry, typeManager);
+        minMaxColumns = minMaxColumns.stream()
+                .filter(column -> addStatsMinMaxColumnFilter.test(column.getName()))
+                .collect(toImmutableList());
         boolean deletionVectorEnabled = isDeletionVectorEnabled(metadataEntry, protocolEntry);
 
         ImmutableList.Builder<RowType.Field> minMaxFields = ImmutableList.builder();
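
Filtering minMaxColumns before building the row type means the minValues/maxValues structs contain fields only for the surviving columns, which narrows the Parquet schema the checkpoint reader has to materialize. A sketch with plain records standing in for DeltaLakeColumnMetadata and RowType; all names here are hypothetical:

    import java.util.List;
    import java.util.function.Predicate;
    import java.util.stream.Collectors;

    class PrunedAddEntrySchemaSketch
    {
        // Stand-in for a column's name and type in the checkpoint schema.
        record Field(String name, String type) {}

        public static void main(String[] args)
        {
            List<Field> minMaxColumns = List.of(
                    new Field("order_key", "bigint"),
                    new Field("order_date", "date"),
                    new Field("comment", "varchar"));

            // Only order_date is projected by the query.
            Predicate<String> addStatsMinMaxColumnFilter = "order_date"::equals;

            List<Field> minMaxFields = minMaxColumns.stream()
                    .filter(field -> addStatsMinMaxColumnFilter.test(field.name()))
                    .collect(Collectors.toList());

            // The minValues/maxValues structs now carry one field instead of three,
            // so the checkpoint reader materializes fewer Parquet leaf columns.
            System.out.println(minMaxFields); // [Field[name=order_date, type=date]]
        }
    }

Note the contrast with the checkpoint writer below, which passes alwaysTrue(): pruning is strictly a read-side optimization, and written checkpoints always persist stats for every column.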
@@ -52,6 +52,7 @@
 import java.util.Optional;
 
 import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Predicates.alwaysTrue;
 import static com.google.common.collect.ImmutableList.toImmutableList;
 import static com.google.common.collect.ImmutableMap.toImmutableMap;
 import static io.airlift.slice.Slices.utf8Slice;
@@ -112,7 +113,13 @@ public void write(CheckpointEntries entries, TrinoOutputFile outputFile)
         RowType protocolEntryType = checkpointSchemaManager.getProtocolEntryType(protocolEntry.getReaderFeatures().isPresent(), protocolEntry.getWriterFeatures().isPresent());
         RowType txnEntryType = checkpointSchemaManager.getTxnEntryType();
         // TODO https://github.com/trinodb/trino/issues/19586 Add support for writing 'partitionValues_parsed' field
-        RowType addEntryType = checkpointSchemaManager.getAddEntryType(entries.getMetadataEntry(), entries.getProtocolEntry(), writeStatsAsJson, writeStatsAsStruct, false);
+        RowType addEntryType = checkpointSchemaManager.getAddEntryType(
+                entries.getMetadataEntry(),
+                entries.getProtocolEntry(),
+                alwaysTrue(),
+                writeStatsAsJson,
+                writeStatsAsStruct,
+                false);
         RowType removeEntryType = checkpointSchemaManager.getRemoveEntryType();
 
         List<String> columnNames = ImmutableList.of(