From f9690fc7ec12b7f216d99ad2ed778a0c541978f7 Mon Sep 17 00:00:00 2001 From: James Petty Date: Wed, 15 Nov 2023 14:19:48 -0500 Subject: [PATCH] Refactor RunLengthEncodedBlock hash combination logic --- .../operator/FlatHashStrategyCompiler.java | 27 +++++++++---------- .../operator/scalar/CombineHashFunction.java | 13 +++++++++ 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java b/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java index 499ef7125004..0101154a2038 100644 --- a/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java +++ b/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java @@ -443,16 +443,6 @@ private static MethodDefinition generateHashBlockVectorized(ClassDefinition defi block, position); - BytecodeExpression setHashExpression; - if (field.index() == 0) { - // hashes[index] = hash; - setHashExpression = hashes.setElement(index, hash); - } - else { - // hashes[index] = CombineHashFunction.getHash(hashes[index], hash); - setHashExpression = hashes.setElement(index, invokeStatic(CombineHashFunction.class, "getHash", long.class, hashes.getElement(index), hash)); - } - BytecodeBlock rleHandling = new BytecodeBlock() .append(new IfStatement("hash = block.isNull(position) ? NULL_HASH_CODE : hash(block, position)") .condition(block.invoke("isNull", boolean.class, position)) @@ -463,11 +453,18 @@ private static MethodDefinition generateHashBlockVectorized(ClassDefinition defi rleHandling.append(invokeStatic(Arrays.class, "fill", void.class, hashes, constantInt(0), length, hash)); } else { - rleHandling.append(new ForLoop("for (int index = 0; index < length; index++) { hashes[index] = CombineHashFunction.getHash(hashes[index], hash); }") - .initialize(index.set(constantInt(0))) - .condition(lessThan(index, length)) - .update(index.increment()) - .body(setHashExpression)); + // CombineHashFunction.combineAllHashesWithConstant(hashes, 0, length, hash) + rleHandling.append(invokeStatic(CombineHashFunction.class, "combineAllHashesWithConstant", void.class, hashes, constantInt(0), length, hash)); + } + + BytecodeExpression setHashExpression; + if (field.index() == 0) { + // hashes[index] = hash; + setHashExpression = hashes.setElement(index, hash); + } + else { + // hashes[index] = CombineHashFunction.getHash(hashes[index], hash); + setHashExpression = hashes.setElement(index, invokeStatic(CombineHashFunction.class, "getHash", long.class, hashes.getElement(index), hash)); } BytecodeBlock computeHashLoop = new BytecodeBlock() diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/CombineHashFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/CombineHashFunction.java index ed08ebec9bd5..42e4ef800a35 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/CombineHashFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/CombineHashFunction.java @@ -13,10 +13,13 @@ */ package io.trino.operator.scalar; +import io.trino.annotation.UsedByGeneratedCode; import io.trino.spi.function.ScalarFunction; import io.trino.spi.function.SqlType; import io.trino.spi.type.StandardTypes; +import static java.util.Objects.checkFromToIndex; + public final class CombineHashFunction { private CombineHashFunction() {} @@ -27,4 +30,14 @@ public static long getHash(@SqlType(StandardTypes.BIGINT) long previousHashValue { return (31 * previousHashValue + value); } + + @UsedByGeneratedCode + public static void combineAllHashesWithConstant(long[] hashes, int fromIndex, int toIndex, long value) + { + checkFromToIndex(fromIndex, toIndex, hashes.length); + + for (int i = 0; i < toIndex; i++) { + hashes[i] = (31 * hashes[i]) + value; + } + } }