Skip to content

Commit

Permalink
Simplify planning of MERGE
Browse files Browse the repository at this point in the history
Currently, the planner inserts a projection of the following shape
to assemble the merged row:

    merge_row := (
        CASE
            WHEN ... THEN
                ROW(..., $not((present IS NULL)), <operation>, 0)
            WHEN ... THEN
                ROW(..., $not((present IS NULL)), <operation>, 1)
            ...
            ELSE
                ROW(<nulls>, $not((present IS NULL)), -1, -1)
        END)

This change replaces the ELSE branch so that it returns a single null row instead of
a synthetic row filled with nulls. By reducing the size of the projection, it allows
wider tables to be used with MERGE.
  • Loading branch information
martint committed Oct 10, 2024
1 parent 54827dc commit c0d70e6
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,6 @@ public Page transformPage(Page inputPage)
checkArgument(positionCount > 0, "positionCount should be > 0, but is %s", positionCount);

Block mergeRow = inputPage.getBlock(mergeRowChannel).getLoadedBlock();
if (mergeRow.mayHaveNull()) {
for (int position = 0; position < positionCount; position++) {
checkArgument(!mergeRow.isNull(position), "The mergeRow may not have null rows");
}
}

List<Block> fields = getRowFieldsFromBlock(mergeRow);
List<Block> builder = new ArrayList<>(dataColumnChannels.size() + 3);
for (int channel : dataColumnChannels) {
Expand All @@ -86,7 +80,7 @@ public Page transformPage(Page inputPage)

int defaultCaseCount = 0;
for (int position = 0; position < positionCount; position++) {
if (TINYINT.getByte(operationChannelBlock, position) == DEFAULT_CASE_OPERATION_NUMBER) {
if (mergeRow.isNull(position)) {
defaultCaseCount++;
}
}
Expand All @@ -97,7 +91,7 @@ public Page transformPage(Page inputPage)
int usedCases = 0;
int[] positions = new int[positionCount - defaultCaseCount];
for (int position = 0; position < positionCount; position++) {
if (TINYINT.getByte(operationChannelBlock, position) != DEFAULT_CASE_OPERATION_NUMBER) {
if (!mergeRow.isNull(position)) {
positions[usedCases] = position;
usedCases++;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,22 +99,24 @@ public Page transformPage(Page inputPage)
int originalPositionCount = inputPage.getPositionCount();
checkArgument(originalPositionCount > 0, "originalPositionCount should be > 0, but is %s", originalPositionCount);

List<Block> fields = getRowFieldsFromBlock(inputPage.getBlock(mergeRowChannel));
Block mergeRow = inputPage.getBlock(mergeRowChannel);
List<Block> fields = getRowFieldsFromBlock(mergeRow);
Block operationChannelBlock = fields.get(fields.size() - 2);

int updatePositions = 0;
int insertPositions = 0;
int deletePositions = 0;
for (int position = 0; position < originalPositionCount; position++) {
byte operation = TINYINT.getByte(operationChannelBlock, position);
switch (operation) {
case DEFAULT_CASE_OPERATION_NUMBER -> { /* ignored */ }
case INSERT_OPERATION_NUMBER -> insertPositions++;
case DELETE_OPERATION_NUMBER -> deletePositions++;
case UPDATE_OPERATION_NUMBER -> updatePositions++;
// This class will create such rows, they are not expected on input
case UPDATE_INSERT_OPERATION_NUMBER, UPDATE_DELETE_OPERATION_NUMBER -> throw new IllegalArgumentException("Unexpected operator number: " + operation);
default -> throw new IllegalArgumentException("Unknown operator number: " + operation);
if (!mergeRow.isNull(position)) {
byte operation = TINYINT.getByte(operationChannelBlock, position);
switch (operation) {
case INSERT_OPERATION_NUMBER -> insertPositions++;
case DELETE_OPERATION_NUMBER -> deletePositions++;
case UPDATE_OPERATION_NUMBER -> updatePositions++;
// This class will create such rows, they are not expected on input
case UPDATE_INSERT_OPERATION_NUMBER, UPDATE_DELETE_OPERATION_NUMBER -> throw new IllegalArgumentException("Unexpected operator number: " + operation);
default -> throw new IllegalArgumentException("Unknown operator number: " + operation);
}
}
}

Expand All @@ -128,8 +130,8 @@ public Page transformPage(Page inputPage)

PageBuilder pageBuilder = new PageBuilder(totalPositions, pageTypes);
for (int position = 0; position < originalPositionCount; position++) {
byte operation = TINYINT.getByte(operationChannelBlock, position);
if (operation != DEFAULT_CASE_OPERATION_NUMBER) {
if (!mergeRow.isNull(position)) {
byte operation = TINYINT.getByte(operationChannelBlock, position);
// Delete and Update because both create a delete row
if (operation == DELETE_OPERATION_NUMBER || operation == UPDATE_OPERATION_NUMBER) {
addDeleteRow(pageBuilder, inputPage, position, operation != DELETE_OPERATION_NUMBER);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

public interface MergeRowChangeProcessor
{
int DEFAULT_CASE_OPERATION_NUMBER = -1;

/**
* Transform a page generated by an SQL MERGE operation into page of data columns and
* operations. The SQL MERGE input page consists of the following:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import io.trino.spi.connector.SortOrder;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Int128;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.sql.NodeUtils;
import io.trino.sql.PlannerContext;
Expand Down Expand Up @@ -861,17 +862,16 @@ public MergeWriterNode plan(Merge merge)
}
}

// Build the "else" clause for the SearchedCaseExpression
ImmutableList.Builder<Expression> rowBuilder = ImmutableList.builder();
dataColumnSchemas.forEach(columnSchema ->
rowBuilder.add(new Constant(columnSchema.getType(), null)));
rowBuilder.add(not(metadata, new IsNull(presentColumn.toSymbolReference())));
// The operation number
rowBuilder.add(new Constant(TINYINT, -1L));
// The case number
rowBuilder.add(new Constant(INTEGER, -1L));

Case caseExpression = new Case(whenClauses.build(), new Row(rowBuilder.build()));
Case caseExpression = new Case(
whenClauses.build(),
new Constant(
RowType.anonymous(ImmutableList.<Type>builder()
.addAll(dataColumnSchemas.stream().map(ColumnSchema::getType).collect(toImmutableList()))
.add(BOOLEAN)
.add(TINYINT)
.add(INTEGER)
.build()),
null));

Symbol mergeRowSymbol = symbolAllocator.newSymbol("merge_row", mergeAnalysis.getMergeRowType());
Symbol caseNumberSymbol = symbolAllocator.newSymbol("case_number", INTEGER);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import java.util.List;
import java.util.Optional;

import static io.trino.operator.MergeRowChangeProcessor.DEFAULT_CASE_OPERATION_NUMBER;
import static io.trino.spi.connector.ConnectorMergeSink.DELETE_OPERATION_NUMBER;
import static io.trino.spi.connector.ConnectorMergeSink.INSERT_OPERATION_NUMBER;
import static io.trino.spi.connector.ConnectorMergeSink.UPDATE_OPERATION_NUMBER;
Expand All @@ -58,20 +57,21 @@ public void testSimpleDeletedRowMerge()
// THEN DELETE
// expected: ('Dave', 11, 'Darbyshire')
DeleteAndInsertMergeProcessor processor = makeMergeProcessor();
Page inputPage = makePageFromBlocks(
2,
Optional.empty(),
new Block[] {
makeLongArrayBlock(1, 1), // TransactionId
makeLongArrayBlock(1, 0), // rowId
makeIntArrayBlock(536870912, 536870912)}, // bucket
new Block[] {
makeVarcharArrayBlock("", "Dave"), // customer
makeIntArrayBlock(0, 11), // purchases
makeVarcharArrayBlock("", "Devon"), // address
makeByteArrayBlock(1, 1), // "present" boolean
makeByteArrayBlock(DEFAULT_CASE_OPERATION_NUMBER, DELETE_OPERATION_NUMBER),
makeIntArrayBlock(-1, 0)});
Block[] rowIdBlocks = new Block[] {
makeLongArrayBlock(1, 1), // TransactionId
makeLongArrayBlock(1, 0), // rowId
makeIntArrayBlock(536870912, 536870912)}; // bucket
Block[] mergeCaseBlocks = new Block[] {
makeVarcharArrayBlock(null, "Dave"), // customer
new IntArrayBlock(2, Optional.of(new boolean[] {true, false}), new int[] {0, 11}), // purchases
makeVarcharArrayBlock(null, "Devon"), // address
new ByteArrayBlock(2, Optional.of(new boolean[] {true, false}), new byte[] {0, 1}), // "present" boolean
new ByteArrayBlock(2, Optional.of(new boolean[] {true, false}), new byte[] {0, DELETE_OPERATION_NUMBER}), // operation number
new IntArrayBlock(2, Optional.of(new boolean[] {true, false}), new int[] {0, 0})
};
Page inputPage = new Page(
RowBlock.fromNotNullSuppressedFieldBlocks(2, Optional.empty(), rowIdBlocks),
RowBlock.fromNotNullSuppressedFieldBlocks(2, Optional.of(new boolean[] {true, false}), mergeCaseBlocks));

Page outputPage = processor.transformPage(inputPage);
assertThat(outputPage.getPositionCount()).isEqualTo(1);
Expand Down Expand Up @@ -215,7 +215,12 @@ private Block makeVarcharArrayBlock(String... elements)
{
BlockBuilder builder = VARCHAR.createBlockBuilder(new PageBuilderStatus().createBlockBuilderStatus(), elements.length);
for (String element : elements) {
VARCHAR.writeSlice(builder, Slices.utf8Slice(element));
if (element == null) {
builder.appendNull();
}
else {
VARCHAR.writeSlice(builder, Slices.utf8Slice(element));
}
}
return builder.build();
}
Expand Down

0 comments on commit c0d70e6

Please sign in to comment.