Skip to content

Commit

Permalink
[FLINK-36706][table] Refactor TypeInferenceExtractor for PTFs
Browse files Browse the repository at this point in the history
  • Loading branch information
twalthr authored Dec 19, 2024
1 parent d6c3e8c commit ab8286e
Show file tree
Hide file tree
Showing 27 changed files with 2,257 additions and 1,173 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@ import org.apache.flink.table.annotation.{DataTypeHint, FunctionHint}
import org.apache.flink.table.api.DataTypes
import org.apache.flink.table.functions.ScalarFunction
import org.apache.flink.table.types.extraction.TypeInferenceExtractorTest.TestSpec
import org.apache.flink.table.types.inference.{ArgumentTypeStrategy, InputTypeStrategies, TypeStrategies}
import org.apache.flink.table.types.inference.{ArgumentTypeStrategy, InputTypeStrategies, StaticArgument, TypeStrategies}

import org.assertj.core.api.AssertionsForClassTypes.assertThat
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.MethodSource

import java.util
import java.util.{stream, Optional}

import scala.annotation.varargs
Expand All @@ -36,28 +37,24 @@ class TypeInferenceExtractorScalaTest {

@ParameterizedTest
@MethodSource(Array("testData"))
def testArgumentNames(testSpec: TestSpec): Unit = {
if (testSpec.expectedArgumentNames != null) {
assertThat(testSpec.typeInferenceExtraction.get.getNamedArguments)
.isEqualTo(Optional.of(testSpec.expectedArgumentNames))
}
}

@ParameterizedTest
@MethodSource(Array("testData"))
def testArgumentTypes(testSpec: TestSpec): Unit = {
if (testSpec.expectedArgumentTypes != null) {
assertThat(testSpec.typeInferenceExtraction.get.getTypedArguments)
.isEqualTo(Optional.of(testSpec.expectedArgumentTypes))
def testStaticArguments(testSpec: TestSpec): Unit = {
if (testSpec.expectedStaticArguments != null) {
val staticArguments = testSpec.typeInferenceExtraction.get.getStaticArguments
assertThat(staticArguments).isEqualTo(Optional.of(testSpec.expectedStaticArguments))
}
}

@ParameterizedTest
@MethodSource(Array("testData"))
def testOutputTypeStrategy(testSpec: TestSpec): Unit = {
if (!testSpec.expectedOutputStrategies.isEmpty) {
assertThat(testSpec.typeInferenceExtraction.get.getOutputTypeStrategy)
.isEqualTo(TypeStrategies.mapping(testSpec.expectedOutputStrategies))
if (testSpec.expectedOutputStrategies.size == 1) {
assertThat(testSpec.typeInferenceExtraction.get.getOutputTypeStrategy)
.isEqualTo(testSpec.expectedOutputStrategies.values.iterator.next)
} else {
assertThat(testSpec.typeInferenceExtraction.get.getOutputTypeStrategy)
.isEqualTo(TypeStrategies.mapping(testSpec.expectedOutputStrategies))
}
}
}
}
Expand All @@ -68,22 +65,12 @@ object TypeInferenceExtractorScalaTest {
// Scala function with data type hint
TestSpec
.forScalarFunction(classOf[ScalaScalarFunction])
.expectNamedArguments("i", "s", "d")
.expectTypedArguments(
DataTypes.INT.notNull().bridgedTo(classOf[Int]),
DataTypes.STRING,
DataTypes.DECIMAL(10, 4))
.expectOutputMapping(
InputTypeStrategies.sequence(
Array[String]("i", "s", "d"),
Array[ArgumentTypeStrategy](
InputTypeStrategies.explicit(DataTypes.INT.notNull().bridgedTo(classOf[Int])),
InputTypeStrategies.explicit(DataTypes.STRING),
InputTypeStrategies.explicit(DataTypes.DECIMAL(10, 4))
)
),
TypeStrategies.explicit(DataTypes.BOOLEAN.notNull().bridgedTo(classOf[Boolean]))
),
.expectStaticArgument(
StaticArgument.scalar("i", DataTypes.INT.notNull().bridgedTo(classOf[Int]), false))
.expectStaticArgument(StaticArgument.scalar("s", DataTypes.STRING, false))
.expectStaticArgument(StaticArgument.scalar("d", DataTypes.DECIMAL(10, 4), false))
.expectOutput(TypeStrategies.explicit(
DataTypes.BOOLEAN.notNull().bridgedTo(classOf[Boolean]))),
TestSpec
.forScalarFunction(classOf[ScalaPrimitiveVarArgScalarFunction])
.expectOutputMapping(
Expand Down Expand Up @@ -128,11 +115,17 @@ object TypeInferenceExtractorScalaTest {
TestSpec
.forScalarFunction(classOf[ScalaGlobalOutputFunctionHint])
.expectOutputMapping(
InputTypeStrategies.sequence(InputTypeStrategies.explicit(DataTypes.INT)),
TypeStrategies.explicit(DataTypes.INT))
InputTypeStrategies.sequence(
Array[String]("arg0"),
Array[ArgumentTypeStrategy](InputTypeStrategies.explicit(DataTypes.INT))),
TypeStrategies.explicit(DataTypes.INT)
)
.expectOutputMapping(
InputTypeStrategies.sequence(InputTypeStrategies.explicit(DataTypes.STRING)),
TypeStrategies.explicit(DataTypes.INT))
InputTypeStrategies.sequence(
Array[String]("arg0"),
Array[ArgumentTypeStrategy](InputTypeStrategies.explicit(DataTypes.STRING))),
TypeStrategies.explicit(DataTypes.INT)
)
)

// ----------------------------------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public enum ArgumentTrait {
*
* <p>It's the default if no {@link ArgumentHint} is provided.
*/
SCALAR(StaticArgumentTrait.SCALAR),
SCALAR(true, StaticArgumentTrait.SCALAR),

/**
* An argument that accepts a table "as row" (i.e. with row semantics). This trait only applies
Expand All @@ -56,7 +56,7 @@ public enum ArgumentTrait {
* can be processed independently. The framework is free in how to distribute rows across
* virtual processors and each virtual processor has access only to the currently processed row.
*/
TABLE_AS_ROW(StaticArgumentTrait.TABLE_AS_ROW),
TABLE_AS_ROW(true, StaticArgumentTrait.TABLE_AS_ROW),

/**
* An argument that accepts a table "as set" (i.e. with set semantics). This trait only applies
Expand All @@ -77,22 +77,28 @@ public enum ArgumentTrait {
* <p>It is also possible not to provide a key ({@link #OPTIONAL_PARTITION_BY}), in which case
* only one virtual processor handles the entire table, thereby losing scalability benefits.
*/
TABLE_AS_SET(StaticArgumentTrait.TABLE_AS_SET),
TABLE_AS_SET(true, StaticArgumentTrait.TABLE_AS_SET),

/**
* Defines that a PARTITION BY clause is optional for {@link #TABLE_AS_SET}. By default, it is
* mandatory for improving the parallel execution by distributing the table by key.
*/
OPTIONAL_PARTITION_BY(StaticArgumentTrait.OPTIONAL_PARTITION_BY, TABLE_AS_SET);
OPTIONAL_PARTITION_BY(false, StaticArgumentTrait.OPTIONAL_PARTITION_BY, TABLE_AS_SET);

private final boolean isRoot;
private final StaticArgumentTrait staticTrait;
private final Set<ArgumentTrait> requirements;

ArgumentTrait(StaticArgumentTrait staticTrait, ArgumentTrait... requirements) {
ArgumentTrait(boolean isRoot, StaticArgumentTrait staticTrait, ArgumentTrait... requirements) {
this.isRoot = isRoot;
this.staticTrait = staticTrait;
this.requirements = Arrays.stream(requirements).collect(Collectors.toSet());
}

/**
 * Returns whether this trait can stand on its own.
 *
 * <p>NOTE(review): based on the visible constructor calls, only the modifier-like
 * {@code OPTIONAL_PARTITION_BY} passes {@code false} here — confirm intent against
 * the full enum declaration.
 */
public boolean isRoot() {
    return this.isRoot;
}

/**
 * Returns the set of other traits that this trait was declared to require
 * (possibly empty; populated from the constructor's varargs).
 */
public Set<ArgumentTrait> getRequirements() {
    return this.requirements;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.flink.table.annotation.DataTypeHint;
import org.apache.flink.table.annotation.FunctionHint;
import org.apache.flink.table.catalog.DataTypeFactory;
import org.apache.flink.table.types.extraction.TypeInferenceExtractor;
import org.apache.flink.table.types.inference.TypeInference;
import org.apache.flink.util.Collector;

Expand Down Expand Up @@ -225,8 +226,9 @@ public final FunctionKind getKind() {
}

@Override
@SuppressWarnings({"unchecked", "rawtypes"})
public TypeInference getTypeInference(DataTypeFactory typeFactory) {
throw new UnsupportedOperationException("Type inference is not implemented yet.");
return TypeInferenceExtractor.forProcessTableFunction(typeFactory, (Class) getClass());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.apache.flink.table.functions.python.utils.PythonFunctionUtils;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.extraction.ExtractionUtils;
import org.apache.flink.table.types.extraction.ExtractionUtils.Autoboxing;
import org.apache.flink.table.types.inference.CallContext;
import org.apache.flink.util.InstantiationUtil;

Expand Down Expand Up @@ -92,6 +93,8 @@ public final class UserDefinedFunctionHelper {

public static final String ASYNC_TABLE_EVAL = "eval";

public static final String PROCESS_TABLE_EVAL = "eval";

/**
* Tries to infer the TypeInformation of an AggregateFunction's accumulator type.
*
Expand Down Expand Up @@ -320,9 +323,13 @@ public static void validateClassForRuntime(
methods.stream()
.anyMatch(
method ->
ExtractionUtils.isInvokable(method, argumentClasses)
// Strict autoboxing is disabled for backwards compatibility
ExtractionUtils.isInvokable(
Autoboxing.JVM, method, argumentClasses)
&& ExtractionUtils.isAssignable(
outputClass, method.getReturnType(), true));
outputClass,
method.getReturnType(),
Autoboxing.JVM));
if (!isMatching) {
throw new ValidationException(
String.format(
Expand Down
Loading

0 comments on commit ab8286e

Please sign in to comment.