diff --git a/flink-table/flink-sql-client/src/test/resources/sql/function.q b/flink-table/flink-sql-client/src/test/resources/sql/function.q
index 4cc0c043dc28ca..5e0466b0cd2ab2 100644
--- a/flink-table/flink-sql-client/src/test/resources/sql/function.q
+++ b/flink-table/flink-sql-client/src/test/resources/sql/function.q
@@ -406,7 +406,7 @@ describe function extended temp_upperudf;
| requirements | $VAR_UDF_JAR_PATH_SPACE [] |
| is deterministic | $VAR_UDF_JAR_PATH_SPACE true |
| supports constant folding | $VAR_UDF_JAR_PATH_SPACE true |
-| signature | $VAR_UDF_JAR_PATH_SPACE c1.db.temp_upperudf(arg0 => STRING) |
+| signature | $VAR_UDF_JAR_PATH_SPACE c1.db.temp_upperudf(STRING) |
+---------------------------+---------------------------------------------$VAR_UDF_JAR_PATH_DASH+
10 rows in set
!ok
@@ -437,7 +437,7 @@ desc function extended temp_upperudf;
| requirements | $VAR_UDF_JAR_PATH_SPACE [] |
| is deterministic | $VAR_UDF_JAR_PATH_SPACE true |
| supports constant folding | $VAR_UDF_JAR_PATH_SPACE true |
-| signature | $VAR_UDF_JAR_PATH_SPACE c1.db.temp_upperudf(arg0 => STRING) |
+| signature | $VAR_UDF_JAR_PATH_SPACE c1.db.temp_upperudf(STRING) |
+---------------------------+---------------------------------------------$VAR_UDF_JAR_PATH_DASH+
10 rows in set
!ok
@@ -498,7 +498,7 @@ describe function extended temp_upperudf;
| requirements | [] |
| is deterministic | true |
| supports constant folding | true |
-| signature | temp_upperudf(arg0 => STRING) |
+| signature | temp_upperudf(STRING) |
+---------------------------+-------------------------------+
7 rows in set
!ok
diff --git a/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/types/extraction/TypeInferenceExtractorScalaTest.scala b/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/types/extraction/TypeInferenceExtractorScalaTest.scala
index bf79a1eb7d92f8..f2286754f6b4e4 100644
--- a/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/types/extraction/TypeInferenceExtractorScalaTest.scala
+++ b/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/types/extraction/TypeInferenceExtractorScalaTest.scala
@@ -21,12 +21,13 @@ import org.apache.flink.table.annotation.{DataTypeHint, FunctionHint}
import org.apache.flink.table.api.DataTypes
import org.apache.flink.table.functions.ScalarFunction
import org.apache.flink.table.types.extraction.TypeInferenceExtractorTest.TestSpec
-import org.apache.flink.table.types.inference.{ArgumentTypeStrategy, InputTypeStrategies, TypeStrategies}
+import org.apache.flink.table.types.inference.{ArgumentTypeStrategy, InputTypeStrategies, StaticArgument, TypeStrategies}
import org.assertj.core.api.AssertionsForClassTypes.assertThat
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.MethodSource
+import java.util
import java.util.{stream, Optional}
import scala.annotation.varargs
@@ -36,19 +37,10 @@ class TypeInferenceExtractorScalaTest {
@ParameterizedTest
@MethodSource(Array("testData"))
- def testArgumentNames(testSpec: TestSpec): Unit = {
- if (testSpec.expectedArgumentNames != null) {
- assertThat(testSpec.typeInferenceExtraction.get.getNamedArguments)
- .isEqualTo(Optional.of(testSpec.expectedArgumentNames))
- }
- }
-
- @ParameterizedTest
- @MethodSource(Array("testData"))
- def testArgumentTypes(testSpec: TestSpec): Unit = {
- if (testSpec.expectedArgumentTypes != null) {
- assertThat(testSpec.typeInferenceExtraction.get.getTypedArguments)
- .isEqualTo(Optional.of(testSpec.expectedArgumentTypes))
+ def testStaticArguments(testSpec: TestSpec): Unit = {
+ if (testSpec.expectedStaticArguments != null) {
+ val staticArguments = testSpec.typeInferenceExtraction.get.getStaticArguments
+ assertThat(staticArguments).isEqualTo(Optional.of(testSpec.expectedStaticArguments))
}
}
@@ -56,8 +48,13 @@ class TypeInferenceExtractorScalaTest {
@MethodSource(Array("testData"))
def testOutputTypeStrategy(testSpec: TestSpec): Unit = {
if (!testSpec.expectedOutputStrategies.isEmpty) {
- assertThat(testSpec.typeInferenceExtraction.get.getOutputTypeStrategy)
- .isEqualTo(TypeStrategies.mapping(testSpec.expectedOutputStrategies))
+ if (testSpec.expectedOutputStrategies.size == 1) {
+ assertThat(testSpec.typeInferenceExtraction.get.getOutputTypeStrategy)
+ .isEqualTo(testSpec.expectedOutputStrategies.values.iterator.next)
+ } else {
+ assertThat(testSpec.typeInferenceExtraction.get.getOutputTypeStrategy)
+ .isEqualTo(TypeStrategies.mapping(testSpec.expectedOutputStrategies))
+ }
}
}
}
@@ -68,22 +65,12 @@ object TypeInferenceExtractorScalaTest {
// Scala function with data type hint
TestSpec
.forScalarFunction(classOf[ScalaScalarFunction])
- .expectNamedArguments("i", "s", "d")
- .expectTypedArguments(
- DataTypes.INT.notNull().bridgedTo(classOf[Int]),
- DataTypes.STRING,
- DataTypes.DECIMAL(10, 4))
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- Array[String]("i", "s", "d"),
- Array[ArgumentTypeStrategy](
- InputTypeStrategies.explicit(DataTypes.INT.notNull().bridgedTo(classOf[Int])),
- InputTypeStrategies.explicit(DataTypes.STRING),
- InputTypeStrategies.explicit(DataTypes.DECIMAL(10, 4))
- )
- ),
- TypeStrategies.explicit(DataTypes.BOOLEAN.notNull().bridgedTo(classOf[Boolean]))
- ),
+ .expectStaticArgument(
+ StaticArgument.scalar("i", DataTypes.INT.notNull().bridgedTo(classOf[Int]), false))
+ .expectStaticArgument(StaticArgument.scalar("s", DataTypes.STRING, false))
+ .expectStaticArgument(StaticArgument.scalar("d", DataTypes.DECIMAL(10, 4), false))
+ .expectOutput(TypeStrategies.explicit(
+ DataTypes.BOOLEAN.notNull().bridgedTo(classOf[Boolean]))),
TestSpec
.forScalarFunction(classOf[ScalaPrimitiveVarArgScalarFunction])
.expectOutputMapping(
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/annotation/ArgumentTrait.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/annotation/ArgumentTrait.java
index df44b5a64f7be7..fe161faac7cb7d 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/annotation/ArgumentTrait.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/annotation/ArgumentTrait.java
@@ -40,7 +40,7 @@ public enum ArgumentTrait {
*
*
It's the default if no {@link ArgumentHint} is provided.
*/
- SCALAR(StaticArgumentTrait.SCALAR),
+ SCALAR(true, StaticArgumentTrait.SCALAR),
/**
* An argument that accepts a table "as row" (i.e. with row semantics). This trait only applies
@@ -56,7 +56,7 @@ public enum ArgumentTrait {
* can be processed independently. The framework is free in how to distribute rows across
* virtual processors and each virtual processor has access only to the currently processed row.
*/
- TABLE_AS_ROW(StaticArgumentTrait.TABLE_AS_ROW),
+ TABLE_AS_ROW(true, StaticArgumentTrait.TABLE_AS_ROW),
/**
* An argument that accepts a table "as set" (i.e. with set semantics). This trait only applies
@@ -77,22 +77,28 @@ public enum ArgumentTrait {
*
It is also possible not to provide a key ({@link #OPTIONAL_PARTITION_BY}), in which case
* only one virtual processor handles the entire table, thereby losing scalability benefits.
*/
- TABLE_AS_SET(StaticArgumentTrait.TABLE_AS_SET),
+ TABLE_AS_SET(true, StaticArgumentTrait.TABLE_AS_SET),
/**
* Defines that a PARTITION BY clause is optional for {@link #TABLE_AS_SET}. By default, it is
* mandatory for improving the parallel execution by distributing the table by key.
*/
- OPTIONAL_PARTITION_BY(StaticArgumentTrait.OPTIONAL_PARTITION_BY, TABLE_AS_SET);
+ OPTIONAL_PARTITION_BY(false, StaticArgumentTrait.OPTIONAL_PARTITION_BY, TABLE_AS_SET);
+ private final boolean isRoot;
private final StaticArgumentTrait staticTrait;
private final Set requirements;
- ArgumentTrait(StaticArgumentTrait staticTrait, ArgumentTrait... requirements) {
+ ArgumentTrait(boolean isRoot, StaticArgumentTrait staticTrait, ArgumentTrait... requirements) {
+ this.isRoot = isRoot;
this.staticTrait = staticTrait;
this.requirements = Arrays.stream(requirements).collect(Collectors.toSet());
}
+ public boolean isRoot() {
+ return isRoot;
+ }
+
public Set getRequirements() {
return requirements;
}
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/ProcessTableFunction.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/ProcessTableFunction.java
index 1e2e215dca8fcd..4dd05ebb1cbb14 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/ProcessTableFunction.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/ProcessTableFunction.java
@@ -24,6 +24,7 @@
import org.apache.flink.table.annotation.DataTypeHint;
import org.apache.flink.table.annotation.FunctionHint;
import org.apache.flink.table.catalog.DataTypeFactory;
+import org.apache.flink.table.types.extraction.TypeInferenceExtractor;
import org.apache.flink.table.types.inference.TypeInference;
import org.apache.flink.util.Collector;
@@ -225,8 +226,9 @@ public final FunctionKind getKind() {
}
@Override
+ @SuppressWarnings({"unchecked", "rawtypes"})
public TypeInference getTypeInference(DataTypeFactory typeFactory) {
- throw new UnsupportedOperationException("Type inference is not implemented yet.");
+ return TypeInferenceExtractor.forProcessTableFunction(typeFactory, (Class) getClass());
}
/**
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/UserDefinedFunctionHelper.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/UserDefinedFunctionHelper.java
index 354146d70c3717..23307852cbdc6c 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/UserDefinedFunctionHelper.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/UserDefinedFunctionHelper.java
@@ -92,6 +92,8 @@ public final class UserDefinedFunctionHelper {
public static final String ASYNC_TABLE_EVAL = "eval";
+ public static final String PROCESS_TABLE_EVAL = "eval";
+
/**
* Tries to infer the TypeInformation of an AggregateFunction's accumulator type.
*
@@ -320,9 +322,12 @@ public static void validateClassForRuntime(
methods.stream()
.anyMatch(
method ->
- ExtractionUtils.isInvokable(method, argumentClasses)
+ ExtractionUtils.isInvokable(false, method, argumentClasses)
&& ExtractionUtils.isAssignable(
- outputClass, method.getReturnType(), true));
+ outputClass,
+ method.getReturnType(),
+ true,
+ false));
if (!isMatching) {
throw new ValidationException(
String.format(
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/BaseMappingExtractor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/BaseMappingExtractor.java
index 8e3aab464e9351..2e849eb1d68e23 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/BaseMappingExtractor.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/BaseMappingExtractor.java
@@ -21,13 +21,21 @@
import org.apache.flink.table.annotation.ArgumentHint;
import org.apache.flink.table.annotation.ArgumentTrait;
import org.apache.flink.table.annotation.DataTypeHint;
+import org.apache.flink.table.annotation.StateHint;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.catalog.DataTypeFactory;
+import org.apache.flink.table.data.RowData;
import org.apache.flink.table.functions.UserDefinedFunction;
import org.apache.flink.table.procedures.Procedure;
import org.apache.flink.table.types.CollectionDataType;
import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.extraction.FunctionResultTemplate.FunctionOutputTemplate;
+import org.apache.flink.table.types.extraction.FunctionResultTemplate.FunctionStateTemplate;
+import org.apache.flink.table.types.inference.StaticArgumentTrait;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.lang3.ArrayUtils;
import javax.annotation.Nullable;
@@ -36,6 +44,7 @@
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.Arrays;
+import java.util.EnumSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@@ -65,27 +74,110 @@ abstract class BaseMappingExtractor {
protected final DataTypeFactory typeFactory;
- private final String methodName;
+ protected final String methodName;
private final SignatureExtraction signatureExtraction;
protected final ResultExtraction outputExtraction;
- protected final MethodVerification verification;
+ protected final MethodVerification outputVerification;
public BaseMappingExtractor(
DataTypeFactory typeFactory,
String methodName,
SignatureExtraction signatureExtraction,
ResultExtraction outputExtraction,
- MethodVerification verification) {
+ MethodVerification outputVerification) {
this.typeFactory = typeFactory;
this.methodName = methodName;
this.signatureExtraction = signatureExtraction;
this.outputExtraction = outputExtraction;
- this.verification = verification;
+ this.outputVerification = outputVerification;
+ }
+
+ Map extractOutputMapping() {
+ try {
+ return extractResultMappings(
+ outputExtraction, FunctionTemplate::getOutputTemplate, outputVerification);
+ } catch (Throwable t) {
+ throw extractionError(t, "Error in extracting a signature to output mapping.");
+ }
+ }
+
+ // --------------------------------------------------------------------------------------------
+ // Extraction strategies
+ // --------------------------------------------------------------------------------------------
+
+ /**
+ * Extraction that uses the method parameters for producing a {@link FunctionSignatureTemplate}.
+ */
+ static SignatureExtraction createArgumentsFromParametersExtraction(
+ int offset, @Nullable Class> contextClass) {
+ return (extractor, method) -> {
+ final List args =
+ extractArgumentParameters(method, offset, contextClass);
+
+ final EnumSet[] argumentTraits = extractArgumentTraits(args);
+
+ final List argumentTemplates =
+ extractArgumentTemplates(
+ extractor.typeFactory, extractor.getFunctionClass(), args);
+
+ final String[] argumentNames = extractArgumentNames(method, args);
+
+ final boolean[] argumentOptionals = extractArgumentOptionals(args);
+
+ return FunctionSignatureTemplate.of(
+ argumentTemplates,
+ method.isVarArgs(),
+ argumentTraits,
+ argumentNames,
+ argumentOptionals);
+ };
+ }
+
+ /**
+ * Extraction that uses the method parameters for producing a {@link FunctionSignatureTemplate}.
+ */
+ static SignatureExtraction createArgumentsFromParametersExtraction(int offset) {
+ return createArgumentsFromParametersExtraction(offset, null);
+ }
+
+ /** Extraction that uses the method parameters with {@link StateHint} for state entries. */
+ static ResultExtraction createStateFromParametersExtraction() {
+ return (extractor, method) -> {
+ final List stateParameters = extractStateParameters(method);
+ return createStateTemplateFromParameters(extractor, method, stateParameters);
+ };
+ }
+
+ /**
+ * Extraction that uses a generic type variable for producing a {@link FunctionStateTemplate}.
+ * Or method parameters with {@link StateHint} for state entries as a fallback.
+ */
+ static ResultExtraction createStateFromGenericInClassOrParameters(
+ Class extends UserDefinedFunction> baseClass, int genericPos) {
+ return (extractor, method) -> {
+ final List stateParameters = extractStateParameters(method);
+ if (stateParameters.isEmpty()) {
+ final DataType dataType =
+ DataTypeExtractor.extractFromGeneric(
+ extractor.typeFactory,
+ baseClass,
+ genericPos,
+ extractor.getFunctionClass());
+ final LinkedHashMap state = new LinkedHashMap<>();
+ state.put("acc", dataType);
+ return FunctionResultTemplate.ofState(state);
+ }
+ return createStateTemplateFromParameters(extractor, method, stateParameters);
+ };
}
+ // --------------------------------------------------------------------------------------------
+ // Methods for subclasses
+ // --------------------------------------------------------------------------------------------
+
protected abstract Set extractGlobalFunctionTemplates();
protected abstract Set extractLocalFunctionTemplates(Method method);
@@ -96,27 +188,35 @@ public BaseMappingExtractor(
protected abstract String getHintType();
- Map extractOutputMapping() {
- try {
- return extractResultMappings(
- outputExtraction, FunctionTemplate::getOutputTemplate, verification);
- } catch (Throwable t) {
- throw extractionError(t, "Error in extracting a signature to output mapping.");
- }
+ protected static Class>[] assembleParameters(List> state, List> arguments) {
+ return Stream.concat(state.stream(), arguments.stream()).toArray(Class[]::new);
+ }
+
+ protected static ValidationException createMethodNotFoundError(
+ String methodName,
+ Class>[] parameters,
+ @Nullable Class> returnType,
+ String pattern) {
+ return extractionError(
+ "Considering all hints, the method should comply with the signature:\n%s%s",
+ createMethodSignatureString(methodName, parameters, returnType),
+ pattern.isEmpty() ? "" : "\nPattern: " + pattern);
}
/**
- * Extracts mappings from signature to result (either accumulator or output) for the entire
- * function. Verifies if the extracted inference matches with the implementation.
+ * Extracts mappings from signature to result (either state or output) for the entire function.
+ * Verifies if the extracted inference matches with the implementation.
*
* For example, from {@code (INT, BOOLEAN, ANY) -> INT}. It does this by going through all
* implementation methods and collecting all "per-method" mappings. The function mapping is the
* union of all "per-method" mappings.
*/
- protected Map extractResultMappings(
- ResultExtraction resultExtraction,
- Function accessor,
- MethodVerification verification) {
+ @SuppressWarnings("unchecked")
+ protected
+ Map extractResultMappings(
+ ResultExtraction resultExtraction,
+ Function accessor,
+ @Nullable MethodVerification verification) {
final Set global = extractGlobalFunctionTemplates();
final Set globalResultOnly =
findResultOnlyTemplates(global, accessor);
@@ -125,7 +225,7 @@ protected Map extractResultMa
final Map collectedMappings =
new LinkedHashMap<>();
final List methods = collectMethods(methodName);
- if (methods.size() == 0) {
+ if (methods.isEmpty()) {
throw extractionError(
"Could not find a publicly accessible method named '%s'.", methodName);
}
@@ -145,9 +245,6 @@ protected Map extractResultMa
// check if the method can be called
verifyMappingForMethod(correctMethod, collectedMappingsPerMethod, verification);
- // check if we declare optional on a primitive type parameter
- verifyOptionalOnPrimitiveParameter(correctMethod, collectedMappingsPerMethod);
-
// check if method strategies conflict with function strategies
collectedMappingsPerMethod.forEach(
(signature, result) -> putMapping(collectedMappings, signature, result));
@@ -158,48 +255,55 @@ protected Map extractResultMa
method.toString());
}
}
- return collectedMappings;
+ return (Map) collectedMappings;
}
- /**
- * Special case for Scala which generates two methods when using var-args (a {@code Seq < String
- * >} and {@code String...}). This method searches for the Java-like variant.
- */
- static Method correctVarArgMethod(Method method) {
- final int paramCount = method.getParameterCount();
- final Class>[] paramClasses = method.getParameterTypes();
- if (paramCount > 0
- && paramClasses[paramCount - 1].getName().equals("scala.collection.Seq")) {
- final Type[] paramTypes = method.getGenericParameterTypes();
- final ParameterizedType seqType = (ParameterizedType) paramTypes[paramCount - 1];
- final Type varArgType = seqType.getActualTypeArguments()[0];
- return ExtractionUtils.collectMethods(method.getDeclaringClass(), method.getName())
- .stream()
- .filter(Method::isVarArgs)
- .filter(candidate -> candidate.getParameterCount() == paramCount)
- .filter(
- candidate -> {
- final Type[] candidateParamTypes =
- candidate.getGenericParameterTypes();
- for (int i = 0; i < paramCount - 1; i++) {
- if (candidateParamTypes[i] != paramTypes[i]) {
- return false;
- }
- }
- final Class> candidateVarArgType =
- candidate.getParameterTypes()[paramCount - 1];
- return candidateVarArgType.isArray()
- &&
- // check for Object is needed in case of Scala primitives
- // (e.g. Int)
- (varArgType == Object.class
- || candidateVarArgType.getComponentType()
- == varArgType);
- })
- .findAny()
- .orElse(method);
+ protected static void checkNoState(@Nullable List> state) {
+ if (state != null && !state.isEmpty()) {
+ throw extractionError("State is not supported for this kind of function.");
+ }
+ }
+
+ protected static void checkSingleState(@Nullable List> state) {
+ if (state == null || state.size() != 1) {
+ throw extractionError(
+ "Aggregating functions need exactly one state entry for the accumulator.");
}
- return method;
+ }
+
+ // --------------------------------------------------------------------------------------------
+ // Helper methods
+ // --------------------------------------------------------------------------------------------
+
+ private static FunctionStateTemplate createStateTemplateFromParameters(
+ BaseMappingExtractor extractor, Method method, List stateParameters) {
+ final String[] argumentNames = extractStateNames(method, stateParameters);
+ if (argumentNames == null) {
+ throw extractionError("Unable to extract names for all state entries.");
+ }
+
+ final List dataTypes =
+ stateParameters.stream()
+ .map(
+ s ->
+ DataTypeExtractor.extractFromMethodParameter(
+ extractor.typeFactory,
+ extractor.getFunctionClass(),
+ s.method,
+ s.pos))
+ .collect(Collectors.toList());
+
+ final LinkedHashMap state =
+ IntStream.range(0, dataTypes.size())
+ .mapToObj(i -> Map.entry(argumentNames[i], dataTypes.get(i)))
+ .collect(
+ Collectors.toMap(
+ Map.Entry::getKey,
+ Map.Entry::getValue,
+ (o, n) -> o,
+ LinkedHashMap::new));
+
+ return FunctionResultTemplate.ofState(state);
}
/**
@@ -247,9 +351,47 @@ private Map collectMethodMapp
return collectedMappingsPerMethod;
}
- // --------------------------------------------------------------------------------------------
- // Helper methods (ordered by invocation order)
- // --------------------------------------------------------------------------------------------
+ /**
+ * Special case for Scala which generates two methods when using var-args (a {@code Seq < String
+ * >} and {@code String...}). This method searches for the Java-like variant.
+ */
+ private static Method correctVarArgMethod(Method method) {
+ final int paramCount = method.getParameterCount();
+ final Class>[] paramClasses = method.getParameterTypes();
+ if (paramCount > 0
+ && paramClasses[paramCount - 1].getName().equals("scala.collection.Seq")) {
+ final Type[] paramTypes = method.getGenericParameterTypes();
+ final ParameterizedType seqType = (ParameterizedType) paramTypes[paramCount - 1];
+ final Type varArgType = seqType.getActualTypeArguments()[0];
+ return ExtractionUtils.collectMethods(method.getDeclaringClass(), method.getName())
+ .stream()
+ .filter(Method::isVarArgs)
+ .filter(candidate -> candidate.getParameterCount() == paramCount)
+ .filter(
+ candidate -> {
+ final Type[] candidateParamTypes =
+ candidate.getGenericParameterTypes();
+ for (int i = 0; i < paramCount - 1; i++) {
+ if (candidateParamTypes[i] != paramTypes[i]) {
+ return false;
+ }
+ }
+ final Class> candidateVarArgType =
+ candidate.getParameterTypes()[paramCount - 1];
+ return candidateVarArgType.isArray()
+ &&
+ // check for Object is needed in case of Scala primitives
+ // (e.g. Int)
+ (varArgType == Object.class
+ || candidateVarArgType.getComponentType()
+ == varArgType);
+ })
+ .findAny()
+ .orElse(method);
+ }
+ return method;
+ }
+
/** Explicit mappings with complete signature to result declaration. */
private void putExplicitMappings(
Map collectedMappings,
@@ -322,166 +464,259 @@ else if (!existingResult.equals(result)) {
private void verifyMappingForMethod(
Method method,
Map collectedMappingsPerMethod,
- MethodVerification verification) {
+ @Nullable MethodVerification verification) {
+ if (verification == null) {
+ return;
+ }
collectedMappingsPerMethod.forEach(
- (signature, result) ->
- verification.verify(method, signature.toClass(), result.toClass()));
- }
-
- private void verifyOptionalOnPrimitiveParameter(
- Method method,
- Map collectedMappingsPerMethod) {
- collectedMappingsPerMethod
- .keySet()
- .forEach(
- signature -> {
- Boolean[] argumentOptional = signature.argumentOptionals;
- // Here we restrict that the argument must contain optional parameters
- // in order to obtain the FunctionSignatureTemplate of the method for
- // verification. Therefore, the extract method will only be called once.
- // If no function hint is set, this verify will not be executed.
- if (argumentOptional != null
- && Arrays.stream(argumentOptional)
- .anyMatch(Boolean::booleanValue)) {
- FunctionSignatureTemplate functionResultTemplate =
- signatureExtraction.extract(this, method);
- for (int i = 0; i < argumentOptional.length; i++) {
- DataType dataType =
- functionResultTemplate.argumentTemplates.get(i)
- .dataType;
- if (dataType != null
- && argumentOptional[i]
- && dataType.getConversionClass() != null
- && dataType.getConversionClass().isPrimitive()) {
- throw extractionError(
- "Argument at position %d is optional but a primitive type doesn't accept null value.",
- i);
- }
- }
- }
- });
+ (signature, result) -> {
+ if (result instanceof FunctionStateTemplate) {
+ final FunctionStateTemplate stateTemplate = (FunctionStateTemplate) result;
+ verification.verify(
+ method, stateTemplate.toClassList(), signature.toClassList(), null);
+ } else if (result instanceof FunctionOutputTemplate) {
+ final FunctionOutputTemplate outputTemplate =
+ (FunctionOutputTemplate) result;
+ verification.verify(
+ method,
+ List.of(),
+ signature.toClassList(),
+ outputTemplate.toClass());
+ }
+ });
}
// --------------------------------------------------------------------------------------------
- // Context sensitive extraction and verification logic
+ // Parameters extraction (i.e. state and arguments)
// --------------------------------------------------------------------------------------------
- /**
- * Extraction that uses the method parameters for producing a {@link FunctionSignatureTemplate}.
- */
- static SignatureExtraction createParameterSignatureExtraction(int offset) {
- return (extractor, method) -> {
- final List parameterTypes =
- extractArgumentTemplates(
- extractor.typeFactory, extractor.getFunctionClass(), method, offset);
+ /** Method parameter that qualifies as a function argument (i.e. not a context or state). */
+ private static class ArgumentParameter {
+ final Parameter parameter;
+ final Method method;
+ // Pos in the method, not necessarily in the extracted function
+ final int pos;
+
+ private ArgumentParameter(Parameter parameter, Method method, int pos) {
+ this.parameter = parameter;
+ this.method = method;
+ this.pos = pos;
+ }
+ }
- final String[] argumentNames = extractArgumentNames(method, offset);
+ /** Method parameter that qualifies as a function state (i.e. not a context or argument). */
+ private static class StateParameter {
+ final Parameter parameter;
+ final Method method;
+ // Pos in the method, not necessarily in the extracted function
+ final int pos;
+
+ private StateParameter(Parameter parameter, Method method, int pos) {
+ this.parameter = parameter;
+ this.method = method;
+ this.pos = pos;
+ }
+ }
- final Boolean[] argumentOptionals = extractArgumentOptionals(method, offset);
+ private static List extractArgumentParameters(
+ Method method, int offset, @Nullable Class> contextClass) {
+ final Parameter[] parameters = method.getParameters();
+ return IntStream.range(0, parameters.length)
+ .mapToObj(
+ pos -> {
+ final Parameter parameter = parameters[pos];
+ return new ArgumentParameter(parameter, method, pos);
+ })
+ .skip(offset)
+ .filter(arg -> contextClass == null || arg.parameter.getType() != contextClass)
+ .filter(arg -> arg.parameter.getAnnotation(StateHint.class) == null)
+ .collect(Collectors.toList());
+ }
- return FunctionSignatureTemplate.of(
- parameterTypes, method.isVarArgs(), argumentNames, argumentOptionals);
- };
+ private static List extractStateParameters(Method method) {
+ final Parameter[] parameters = method.getParameters();
+ return IntStream.range(0, parameters.length)
+ .mapToObj(
+ pos -> {
+ final Parameter parameter = parameters[pos];
+ return new StateParameter(parameter, method, pos);
+ })
+ .filter(arg -> arg.parameter.getAnnotation(StateHint.class) != null)
+ .collect(Collectors.toList());
}
private static List extractArgumentTemplates(
- DataTypeFactory typeFactory, Class> extractedClass, Method method, int offset) {
- return IntStream.range(offset, method.getParameterCount())
- .mapToObj(
- i ->
+ DataTypeFactory typeFactory, Class> extractedClass, List args) {
+ return args.stream()
+ .map(
+ arg ->
// check for input group before start extracting a data type
- tryExtractInputGroupArgument(method, i)
+ tryExtractInputGroupArgument(arg)
.orElseGet(
() ->
- extractDataTypeArgument(
- typeFactory,
- extractedClass,
- method,
- i)))
+ extractArgumentByKind(
+ typeFactory, extractedClass, arg)))
.collect(Collectors.toList());
}
- static Optional tryExtractInputGroupArgument(
- Method method, int paramPos) {
- final Parameter parameter = method.getParameters()[paramPos];
- final DataTypeHint hint = parameter.getAnnotation(DataTypeHint.class);
- final ArgumentHint argumentHint = parameter.getAnnotation(ArgumentHint.class);
- if (hint != null && argumentHint != null) {
+ private static Optional tryExtractInputGroupArgument(
+ ArgumentParameter arg) {
+ final DataTypeHint dataTypehint = arg.parameter.getAnnotation(DataTypeHint.class);
+ final ArgumentHint argumentHint = arg.parameter.getAnnotation(ArgumentHint.class);
+ if (dataTypehint != null && argumentHint != null) {
throw extractionError(
- "Argument and dataType hints cannot be declared in the same parameter at position %d.",
- paramPos);
+ "Argument and data type hints cannot be declared at the same time at position %d.",
+ arg.pos);
}
if (argumentHint != null) {
final DataTypeTemplate template = DataTypeTemplate.fromAnnotation(argumentHint, null);
if (template.inputGroup != null) {
- return Optional.of(FunctionArgumentTemplate.of(template.inputGroup));
+ return Optional.of(FunctionArgumentTemplate.ofInputGroup(template.inputGroup));
}
- } else if (hint != null) {
- final DataTypeTemplate template = DataTypeTemplate.fromAnnotation(hint, null);
+ } else if (dataTypehint != null) {
+ final DataTypeTemplate template = DataTypeTemplate.fromAnnotation(dataTypehint, null);
if (template.inputGroup != null) {
- return Optional.of(FunctionArgumentTemplate.of(template.inputGroup));
+ return Optional.of(FunctionArgumentTemplate.ofInputGroup(template.inputGroup));
}
}
return Optional.empty();
}
- private static FunctionArgumentTemplate extractDataTypeArgument(
- DataTypeFactory typeFactory, Class> extractedClass, Method method, int paramPos) {
+ private static FunctionArgumentTemplate extractArgumentByKind(
+ DataTypeFactory typeFactory, Class> extractedClass, ArgumentParameter arg) {
+ final Parameter parameter = arg.parameter;
+ final ArgumentHint argumentHint = parameter.getAnnotation(ArgumentHint.class);
+ final int pos = arg.pos;
+ final Set rootTrait =
+ Optional.ofNullable(argumentHint)
+ .map(
+ hint ->
+ Arrays.stream(hint.value())
+ .filter(ArgumentTrait::isRoot)
+ .collect(Collectors.toSet()))
+ .orElse(Set.of(ArgumentTrait.SCALAR));
+ if (rootTrait.size() != 1) {
+ throw extractionError(
+ "Incorrect argument kind at position %d. Argument kind must be one of: %s",
+ pos,
+ Arrays.stream(ArgumentTrait.values())
+ .filter(ArgumentTrait::isRoot)
+ .collect(Collectors.toList()));
+ }
+
+ if (rootTrait.contains(ArgumentTrait.SCALAR)) {
+ return extractScalarArgument(typeFactory, extractedClass, arg);
+ } else if (rootTrait.contains(ArgumentTrait.TABLE_AS_ROW)
+ || rootTrait.contains(ArgumentTrait.TABLE_AS_SET)) {
+ return extractTableArgument(typeFactory, argumentHint, extractedClass, arg);
+ } else {
+ throw extractionError("Unknown argument kind.");
+ }
+ }
+
+ private static FunctionArgumentTemplate extractTableArgument(
+ DataTypeFactory typeFactory,
+ ArgumentHint argumentHint,
+ Class> extractedClass,
+ ArgumentParameter arg) {
+ try {
+ final DataType type =
+ DataTypeExtractor.extractFromMethodParameter(
+ typeFactory, extractedClass, arg.method, arg.pos);
+ return FunctionArgumentTemplate.ofDataType(type);
+ } catch (Throwable t) {
+ final Class> paramClass = arg.parameter.getType();
+ final Class> argClass = argumentHint.type().bridgedTo();
+ if (argClass == Row.class || argClass == RowData.class) {
+ return FunctionArgumentTemplate.ofTable(argClass);
+ }
+ if (paramClass == Row.class || paramClass == RowData.class) {
+ return FunctionArgumentTemplate.ofTable(paramClass);
+ }
+ // Just a regular error for a typed argument
+ throw t;
+ }
+ }
+
+ private static FunctionArgumentTemplate extractScalarArgument(
+ DataTypeFactory typeFactory, Class> extractedClass, ArgumentParameter arg) {
final DataType type =
DataTypeExtractor.extractFromMethodParameter(
- typeFactory, extractedClass, method, paramPos);
+ typeFactory, extractedClass, arg.method, arg.pos);
// unwrap data type in case of varargs
- if (method.isVarArgs() && paramPos == method.getParameterCount() - 1) {
+ if (arg.parameter.isVarArgs()) {
// for ARRAY
if (type instanceof CollectionDataType) {
- return FunctionArgumentTemplate.of(
+ return FunctionArgumentTemplate.ofDataType(
((CollectionDataType) type).getElementDataType());
}
// special case for varargs that have been misinterpreted as BYTES
else if (type.equals(DataTypes.BYTES())) {
- return FunctionArgumentTemplate.of(
+ return FunctionArgumentTemplate.ofDataType(
DataTypes.TINYINT().notNull().bridgedTo(byte.class));
}
}
- return FunctionArgumentTemplate.of(type);
+ return FunctionArgumentTemplate.ofDataType(type);
}
- static @Nullable String[] extractArgumentNames(Method method, int offset) {
+ @SuppressWarnings("unchecked")
+ private static EnumSet[] extractArgumentTraits(
+ List args) {
+ return args.stream()
+ .map(
+ arg -> {
+ final ArgumentHint argumentHint =
+ arg.parameter.getAnnotation(ArgumentHint.class);
+ if (argumentHint == null) {
+ return EnumSet.of(StaticArgumentTrait.SCALAR);
+ }
+ final List traits =
+ Arrays.stream(argumentHint.value())
+ .map(ArgumentTrait::toStaticTrait)
+ .collect(Collectors.toList());
+ return EnumSet.copyOf(traits);
+ })
+ .toArray(EnumSet[]::new);
+ }
+
+ private static @Nullable String[] extractArgumentNames(
+ Method method, List args) {
final List methodParameterNames =
ExtractionUtils.extractMethodParameterNames(method);
if (methodParameterNames != null) {
- return methodParameterNames
- .subList(offset, methodParameterNames.size())
- .toArray(new String[0]);
+ return args.stream()
+ .map(arg -> methodParameterNames.get(arg.pos))
+ .toArray(String[]::new);
} else {
return null;
}
}
- static Boolean[] extractArgumentOptionals(Method method, int offset) {
- return Arrays.stream(method.getParameters())
- .skip(offset)
- .map(parameter -> parameter.getAnnotation(ArgumentHint.class))
- .map(
- h -> {
- if (h == null) {
- return false;
- }
- final ArgumentTrait[] traits = h.value();
- if (traits.length != 1 || traits[0] != ArgumentTrait.SCALAR) {
- throw extractionError(
- "Only scalar arguments are supported so far.");
- }
- return h.isOptional();
- })
- .toArray(Boolean[]::new);
+ private static @Nullable String[] extractStateNames(Method method, List state) {
+ final List methodParameterNames =
+ ExtractionUtils.extractMethodParameterNames(method);
+ if (methodParameterNames != null) {
+ return state.stream()
+ .map(arg -> methodParameterNames.get(arg.pos))
+ .toArray(String[]::new);
+ } else {
+ return null;
+ }
}
- protected static ValidationException createMethodNotFoundError(
- String methodName, Class>[] parameters, @Nullable Class> returnType) {
- return extractionError(
- "Considering all hints, the method should comply with the signature:\n%s",
- createMethodSignatureString(methodName, parameters, returnType));
+ private static boolean[] extractArgumentOptionals(List args) {
+ final Boolean[] argumentOptionals =
+ args.stream()
+ .map(arg -> arg.parameter.getAnnotation(ArgumentHint.class))
+ .map(
+ hint -> {
+ if (hint == null) {
+ return false;
+ }
+ return hint.isOptional();
+ })
+ .toArray(Boolean[]::new);
+ return ArrayUtils.toPrimitive(argumentOptionals);
}
// --------------------------------------------------------------------------------------------
@@ -489,18 +724,22 @@ protected static ValidationException createMethodNotFoundError(
// --------------------------------------------------------------------------------------------
/** Extracts a {@link FunctionSignatureTemplate} from a method. */
- protected interface SignatureExtraction {
+ interface SignatureExtraction {
FunctionSignatureTemplate extract(BaseMappingExtractor extractor, Method method);
}
/** Extracts a {@link FunctionResultTemplate} from a class or method. */
- protected interface ResultExtraction {
+ interface ResultExtraction {
@Nullable
FunctionResultTemplate extract(BaseMappingExtractor extractor, Method method);
}
/** Verifies the signature of a method. */
protected interface MethodVerification {
- void verify(Method method, List> arguments, Class> result);
+ void verify(
+ Method method,
+ List> state,
+ List> arguments,
+ @Nullable Class> result);
}
}
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/DataTypeExtractor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/DataTypeExtractor.java
index 203b717ec117ca..e3db03d8b2c0c4 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/DataTypeExtractor.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/DataTypeExtractor.java
@@ -22,6 +22,7 @@
import org.apache.flink.api.java.typeutils.AvroUtils;
import org.apache.flink.table.annotation.ArgumentHint;
import org.apache.flink.table.annotation.DataTypeHint;
+import org.apache.flink.table.annotation.StateHint;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.dataview.DataView;
import org.apache.flink.table.api.dataview.ListView;
@@ -144,8 +145,11 @@ public static DataType extractFromMethodParameter(
final Parameter parameter = method.getParameters()[paramPos];
final DataTypeHint hint = parameter.getAnnotation(DataTypeHint.class);
final ArgumentHint argumentHint = parameter.getAnnotation(ArgumentHint.class);
+ final StateHint stateHint = parameter.getAnnotation(StateHint.class);
final DataTypeTemplate template;
- if (argumentHint != null) {
+ if (stateHint != null) {
+ template = DataTypeTemplate.fromAnnotation(typeFactory, stateHint.type());
+ } else if (argumentHint != null) {
template = DataTypeTemplate.fromAnnotation(typeFactory, argumentHint.type());
} else if (hint != null) {
template = DataTypeTemplate.fromAnnotation(typeFactory, hint);
@@ -206,9 +210,9 @@ public static DataType extractFromGenericMethodParameter(
* Extracts a data type from a method return type by considering surrounding classes and method
* annotation.
*/
- public static DataType extractFromMethodOutput(
+ public static DataType extractFromMethodReturnType(
DataTypeFactory typeFactory, Class> baseClass, Method method) {
- return extractFromMethodOutput(
+ return extractFromMethodReturnType(
typeFactory, baseClass, method, method.getGenericReturnType());
}
@@ -216,7 +220,7 @@ public static DataType extractFromMethodOutput(
* Extracts a data type from a method return type with specifying the method's type explicitly
* by considering surrounding classes and method annotation.
*/
- public static DataType extractFromMethodOutput(
+ public static DataType extractFromMethodReturnType(
DataTypeFactory typeFactory, Class> baseClass, Method method, Type methodReturnType) {
final DataTypeHint hint = method.getAnnotation(DataTypeHint.class);
final DataTypeTemplate template;
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/ExtractionUtils.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/ExtractionUtils.java
index 0d01e04a08ffc8..34473f9c8cae91 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/ExtractionUtils.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/ExtractionUtils.java
@@ -22,6 +22,7 @@
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.table.annotation.ArgumentHint;
+import org.apache.flink.table.annotation.StateHint;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.catalog.DataTypeFactory;
@@ -91,7 +92,8 @@ public static List collectMethods(Class> function, String methodName)
* E.g., {@code (int.class, int.class)} matches {@code f(Object...), f(int, int), f(Integer,
* Object)} and so forth.
*/
- public static boolean isInvokable(Executable executable, Class>... classes) {
+ public static boolean isInvokable(
+ boolean strictAutoboxing, Executable executable, Class>... classes) {
final int m = executable.getModifiers();
if (!Modifier.isPublic(m)) {
return false;
@@ -110,21 +112,25 @@ public static boolean isInvokable(Executable executable, Class>... classes) {
if (currentParam == paramCount - 1 && executable.isVarArgs()) {
final Class> paramComponent =
executable.getParameterTypes()[currentParam].getComponentType();
- // we have more than 1 classes left so the vararg needs to consume them all
+ // we have more than one class left so the vararg needs to consume them all
if (classCount - currentClass > 1) {
while (currentClass < classCount
&& ExtractionUtils.isAssignable(
- classes[currentClass], paramComponent, true)) {
+ classes[currentClass],
+ paramComponent,
+ true,
+ strictAutoboxing)) {
currentClass++;
}
} else if (currentClass < classCount
- && (parameterMatches(classes[currentClass], param)
- || parameterMatches(classes[currentClass], paramComponent))) {
+ && (parameterMatches(strictAutoboxing, classes[currentClass], param)
+ || parameterMatches(
+ strictAutoboxing, classes[currentClass], paramComponent))) {
currentClass++;
}
}
// entire parameter matches
- else if (parameterMatches(classes[currentClass], param)) {
+ else if (parameterMatches(strictAutoboxing, classes[currentClass], param)) {
currentClass++;
}
}
@@ -132,8 +138,9 @@ else if (parameterMatches(classes[currentClass], param)) {
return currentClass == classCount;
}
- private static boolean parameterMatches(Class> clz, Class> param) {
- return clz == null || ExtractionUtils.isAssignable(clz, param, true);
+ private static boolean parameterMatches(
+ boolean strictAutoboxing, Class> clz, Class> param) {
+ return clz == null || ExtractionUtils.isAssignable(clz, param, true, strictAutoboxing);
}
/** Creates a method signature string like {@code int eval(Integer, String)}. */
@@ -298,11 +305,11 @@ private static String normalizeAccessorName(String name) {
/**
* Checks for an invokable constructor matching the given arguments.
*
- * @see #isInvokable(Executable, Class[])
+ * @see #isInvokable(boolean, Executable, Class[])
*/
public static boolean hasInvokableConstructor(Class> clazz, Class>... classes) {
for (Constructor> constructor : clazz.getDeclaredConstructors()) {
- if (isInvokable(constructor, classes)) {
+ if (isInvokable(false, constructor, classes)) {
return true;
}
}
@@ -758,16 +765,20 @@ private AssigningConstructor(Constructor> constructor, List parameterN
} else {
offset = 0;
}
- // by default parameter names are "arg0, arg1, arg2, ..." if compiler flag is not set
- // so we need to extract them manually if possible
+ // by default parameter names are "arg0, arg1, arg2, ..." if compiler flag is not set,
+ // we need to extract them manually if possible
List parameterNames =
Stream.of(executable.getParameters())
.map(
parameter -> {
- ArgumentHint argumentHint =
+ final StateHint stateHint =
+ parameter.getAnnotation(StateHint.class);
+ final ArgumentHint argHint =
parameter.getAnnotation(ArgumentHint.class);
- if (argumentHint != null && !argumentHint.name().isEmpty()) {
- return argumentHint.name();
+ if (stateHint != null && !stateHint.name().isEmpty()) {
+ return stateHint.name();
+ } else if (argHint != null && !argHint.name().isEmpty()) {
+ return argHint.name();
} else {
return parameter.getName();
}
@@ -787,7 +798,7 @@ private AssigningConstructor(Constructor> constructor, List parameterN
return null;
}
// remove "this" and additional local variables
- // select less names if class file has not the required information
+ // select fewer names if class file has not the required information
parameterNames =
extractedNames.subList(
offset,
@@ -936,10 +947,11 @@ public void visitLocalVariable(
* @param cls the Class to check, may be null
* @param toClass the Class to try to assign into, returns false if null
* @param autoboxing whether to use implicit autoboxing/unboxing between primitives and wrappers
+ * @param strictAutoboxing checks whether null would end up in a primitive type and forbids it
* @return {@code true} if assignment possible
*/
public static boolean isAssignable(
- Class> cls, final Class> toClass, final boolean autoboxing) {
+ Class> cls, final Class> toClass, boolean autoboxing, boolean strictAutoboxing) {
if (toClass == null) {
return false;
}
@@ -955,10 +967,12 @@ public static boolean isAssignable(
return false;
}
}
- if (toClass.isPrimitive() && !cls.isPrimitive()) {
- cls = wrapperToPrimitive(cls);
- if (cls == null) {
- return false;
+ if (!strictAutoboxing) {
+ if (toClass.isPrimitive() && !cls.isPrimitive()) {
+ cls = wrapperToPrimitive(cls);
+ if (cls == null) {
+ return false;
+ }
}
}
}
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionArgumentTemplate.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionArgumentTemplate.java
index dd289d98a567b1..778efce2301ef1 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionArgumentTemplate.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionArgumentTemplate.java
@@ -37,21 +37,29 @@
@Internal
final class FunctionArgumentTemplate {
- final @Nullable DataType dataType;
+ private final @Nullable DataType dataType;
+ private final @Nullable InputGroup inputGroup;
+ private final @Nullable Class> conversionClass;
- final @Nullable InputGroup inputGroup;
-
- private FunctionArgumentTemplate(@Nullable DataType dataType, @Nullable InputGroup inputGroup) {
+ private FunctionArgumentTemplate(
+ @Nullable DataType dataType,
+ @Nullable InputGroup inputGroup,
+ @Nullable Class> conversionClass) {
this.dataType = dataType;
this.inputGroup = inputGroup;
+ this.conversionClass = conversionClass;
+ }
+
+ static FunctionArgumentTemplate ofDataType(DataType dataType) {
+ return new FunctionArgumentTemplate(dataType, null, null);
}
- static FunctionArgumentTemplate of(DataType dataType) {
- return new FunctionArgumentTemplate(dataType, null);
+ static FunctionArgumentTemplate ofInputGroup(InputGroup inputGroup) {
+ return new FunctionArgumentTemplate(null, inputGroup, null);
}
- static FunctionArgumentTemplate of(InputGroup inputGroup) {
- return new FunctionArgumentTemplate(null, inputGroup);
+ static FunctionArgumentTemplate ofTable(Class> conversionClass) {
+ return new FunctionArgumentTemplate(null, null, conversionClass);
}
ArgumentTypeStrategy toArgumentTypeStrategy() {
@@ -68,10 +76,17 @@ ArgumentTypeStrategy toArgumentTypeStrategy() {
}
}
+ public @Nullable DataType toDataType() {
+ return dataType;
+ }
+
Class> toConversionClass() {
if (dataType != null) {
return dataType.getConversionClass();
}
+ if (conversionClass != null) {
+ return conversionClass;
+ }
assert inputGroup != null;
switch (inputGroup) {
case ANY:
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionMappingExtractor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionMappingExtractor.java
index 26909257645a18..44b0491f6a0c8e 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionMappingExtractor.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionMappingExtractor.java
@@ -24,6 +24,8 @@
import org.apache.flink.table.catalog.DataTypeFactory;
import org.apache.flink.table.functions.UserDefinedFunction;
import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.extraction.FunctionResultTemplate.FunctionOutputTemplate;
+import org.apache.flink.table.types.extraction.FunctionResultTemplate.FunctionStateTemplate;
import org.apache.flink.util.Preconditions;
import javax.annotation.Nullable;
@@ -31,12 +33,13 @@
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
+import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
-import java.util.stream.Collectors;
+import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;
import static org.apache.flink.table.types.extraction.ExtractionUtils.collectAnnotationsOfClass;
@@ -59,81 +62,47 @@ final class FunctionMappingExtractor extends BaseMappingExtractor {
private final Class extends UserDefinedFunction> function;
- private final @Nullable ResultExtraction accumulatorExtraction;
+ private final @Nullable ResultExtraction stateExtraction;
+ private final @Nullable MethodVerification stateVerification;
FunctionMappingExtractor(
DataTypeFactory typeFactory,
Class extends UserDefinedFunction> function,
String methodName,
SignatureExtraction signatureExtraction,
- @Nullable ResultExtraction accumulatorExtraction,
+ @Nullable ResultExtraction stateExtraction,
+ @Nullable MethodVerification stateVerification,
ResultExtraction outputExtraction,
- MethodVerification verification) {
- super(typeFactory, methodName, signatureExtraction, outputExtraction, verification);
+ @Nullable MethodVerification outputVerification) {
+ super(typeFactory, methodName, signatureExtraction, outputExtraction, outputVerification);
this.function = function;
- this.accumulatorExtraction = accumulatorExtraction;
+ this.stateExtraction = stateExtraction;
+ this.stateVerification = stateVerification;
}
- Class extends UserDefinedFunction> getFunction() {
- return function;
- }
-
- boolean hasAccumulator() {
- return accumulatorExtraction != null;
- }
-
- @Override
- protected Set extractGlobalFunctionTemplates() {
- return TemplateUtils.extractGlobalFunctionTemplates(typeFactory, function);
- }
-
- @Override
- protected Set extractLocalFunctionTemplates(Method method) {
- return TemplateUtils.extractLocalFunctionTemplates(typeFactory, method);
- }
-
- @Override
- protected List collectMethods(String methodName) {
- return ExtractionUtils.collectMethods(function, methodName);
- }
-
- @Override
- protected Class> getFunctionClass() {
- return function;
- }
-
- @Override
- protected String getHintType() {
- return "Function";
- }
-
- Map extractAccumulatorMapping() {
- Preconditions.checkState(hasAccumulator());
+ Map extractStateMapping() {
+ Preconditions.checkState(supportsState());
try {
return extractResultMappings(
- accumulatorExtraction,
- FunctionTemplate::getAccumulatorTemplate,
- (method, signature, result) -> {
- // put the result into the signature for accumulators
- final List> arguments =
- Stream.concat(Stream.of(result), signature.stream())
- .collect(Collectors.toList());
- verification.verify(method, arguments, null);
- });
+ stateExtraction, FunctionTemplate::getStateTemplate, stateVerification);
} catch (Throwable t) {
- throw extractionError(t, "Error in extracting a signature to accumulator mapping.");
+ throw extractionError(t, "Error in extracting a signature to state mapping.");
}
}
+ // --------------------------------------------------------------------------------------------
+ // Extraction strategies
+ // --------------------------------------------------------------------------------------------
+
/**
- * Extraction that uses the method return type for producing a {@link FunctionResultTemplate}.
+ * Extraction that uses the method return type for producing a {@link FunctionOutputTemplate}.
*/
- static ResultExtraction createReturnTypeResultExtraction() {
+ static ResultExtraction createOutputFromReturnTypeInMethod() {
return (extractor, method) -> {
final DataType dataType =
- DataTypeExtractor.extractFromMethodOutput(
+ DataTypeExtractor.extractFromMethodReturnType(
extractor.typeFactory, extractor.getFunctionClass(), method);
- return FunctionResultTemplate.of(dataType);
+ return FunctionResultTemplate.ofOutput(dataType);
};
}
@@ -142,7 +111,7 @@ static ResultExtraction createReturnTypeResultExtraction() {
*
* If enabled, a {@link DataTypeHint} from method or class has higher priority.
*/
- static ResultExtraction createGenericResultExtraction(
+ static ResultExtraction createOutputFromGenericInClass(
Class extends UserDefinedFunction> baseClass,
int genericPos,
boolean allowDataTypeHint) {
@@ -159,7 +128,7 @@ static ResultExtraction createGenericResultExtraction(
baseClass,
genericPos,
extractor.getFunctionClass());
- return FunctionResultTemplate.of(dataType);
+ return FunctionResultTemplate.ofOutput(dataType);
};
}
@@ -169,7 +138,7 @@ static ResultExtraction createGenericResultExtraction(
*
*
If enabled, a {@link DataTypeHint} from method or class has higher priority.
*/
- static ResultExtraction createGenericResultExtractionFromMethod(
+ static ResultExtraction createOutputFromGenericInMethod(
int paramPos, int genericPos, boolean allowDataTypeHint) {
return (extractor, method) -> {
if (allowDataTypeHint) {
@@ -185,101 +154,166 @@ static ResultExtraction createGenericResultExtractionFromMethod(
method,
paramPos,
genericPos);
- return FunctionResultTemplate.of(dataType);
+ return FunctionResultTemplate.ofOutput(dataType);
};
}
- /** Uses hints to extract functional template. */
- private static Optional extractHints(
- BaseMappingExtractor extractor, Method method) {
- final Set dataTypeHints = new HashSet<>();
- dataTypeHints.addAll(collectAnnotationsOfMethod(DataTypeHint.class, method));
- dataTypeHints.addAll(
- collectAnnotationsOfClass(DataTypeHint.class, extractor.getFunctionClass()));
- if (dataTypeHints.size() > 1) {
- throw extractionError(
- "More than one data type hint found for output of function. "
- + "Please use a function hint instead.");
- }
- if (dataTypeHints.size() == 1) {
- return Optional.ofNullable(
- FunctionTemplate.createResultTemplate(
- extractor.typeFactory, dataTypeHints.iterator().next()));
- }
- // otherwise continue with regular extraction
- return Optional.empty();
- }
+ // --------------------------------------------------------------------------------------------
+ // Verification strategies
+ // --------------------------------------------------------------------------------------------
- /** Verification that checks a method by parameters and return type. */
+ /** Verification that checks a method by parameters (arguments only) and return type. */
static MethodVerification createParameterAndReturnTypeVerification() {
- return (method, signature, result) -> {
- final Class>[] parameters = signature.toArray(new Class[0]);
+ return (method, state, arguments, result) -> {
+ checkNoState(state);
+ final Class>[] parameters = assembleParameters(state, arguments);
final Class> returnType = method.getReturnType();
+ // TODO enable strict autoboxing
final boolean isValid =
- isInvokable(method, parameters) && isAssignable(result, returnType, true);
+ isInvokable(false, method, parameters)
+ && isAssignable(result, returnType, true, false);
if (!isValid) {
- throw createMethodNotFoundError(method.getName(), parameters, result);
+ throw createMethodNotFoundError(method.getName(), parameters, result, "");
}
};
}
- /** Verification that checks a method by parameters including an accumulator. */
- static MethodVerification createParameterWithAccumulatorVerification() {
- return (method, signature, result) -> {
- if (result != null) {
- // ignore the accumulator in the first argument
- createParameterWithArgumentVerification(null).verify(method, signature, result);
+ /** Verification that checks a method by parameters (arguments only or with accumulator). */
+ static MethodVerification createParameterVerification(boolean requireAccumulator) {
+ return (method, state, arguments, result) -> {
+ if (requireAccumulator) {
+ checkSingleState(state);
} else {
- // check the signature only
- createParameterVerification().verify(method, signature, null);
+ checkNoState(state);
}
- };
- }
-
- /** Verification that checks a method by parameters including an additional first parameter. */
- static MethodVerification createParameterWithArgumentVerification(
- @Nullable Class> argumentClass) {
- return (method, signature, result) -> {
- final Class>[] parameters =
- Stream.concat(Stream.of(argumentClass), signature.stream())
- .toArray(Class>[]::new);
- if (!isInvokable(method, parameters)) {
- throw createMethodNotFoundError(method.getName(), parameters, null);
+ final Class>[] parameters = assembleParameters(state, arguments);
+ // TODO enable strict autoboxing
+ if (!isInvokable(false, method, parameters)) {
+ throw createMethodNotFoundError(
+ method.getName(),
+ parameters,
+ null,
+ requireAccumulator ? "( [, ]*)" : "");
}
};
}
- /** Verification that checks a method by parameters including an additional first parameter. */
- static MethodVerification createGenericParameterWithArgumentAndReturnTypeVerification(
- Class> baseClass, Class> argumentClass, int paramPos, int genericPos) {
- return (method, signature, result) -> {
- final Class>[] parameters =
- Stream.concat(Stream.of(argumentClass), signature.stream())
+ /**
+ * Verification that checks a method by parameters (arguments only) with mandatory {@link
+ * CompletableFuture}.
+ */
+ static MethodVerification createParameterAndCompletableFutureVerification(Class> baseClass) {
+ return (method, state, arguments, result) -> {
+ checkNoState(state);
+ final Class>[] parameters = assembleParameters(state, arguments);
+ final Class>[] parametersWithFuture =
+ Stream.concat(Stream.of(CompletableFuture.class), Arrays.stream(parameters))
.toArray(Class>[]::new);
- Type genericType = method.getGenericParameterTypes()[paramPos];
+ Type genericType = method.getGenericParameterTypes()[0];
genericType = resolveVariableWithClassContext(baseClass, genericType);
if (!(genericType instanceof ParameterizedType)) {
throw extractionError(
- "The method '%s' needs generic parameters for the %d arg.",
- method.getName(), paramPos);
+ "The method '%s' needs generic parameters for the CompletableFuture at position %d.",
+ method.getName(), 0);
}
- Type returnType =
- ((ParameterizedType) genericType).getActualTypeArguments()[genericPos];
+ final Type returnType = ((ParameterizedType) genericType).getActualTypeArguments()[0];
Class> returnClazz = getClassFromType(returnType);
- if (!(isInvokable(method, parameters) && isAssignable(result, returnClazz, true))) {
- throw createMethodNotFoundError(method.getName(), parameters, null);
+ // TODO enable strict autoboxing
+ if (!(isInvokable(false, method, parametersWithFuture)
+ && isAssignable(result, returnClazz, true, false))) {
+ throw createMethodNotFoundError(
+ method.getName(),
+ parametersWithFuture,
+ null,
+ "( [, ]*)");
}
};
}
- /** Verification that checks a method by parameters. */
- static MethodVerification createParameterVerification() {
- return (method, signature, result) -> {
- final Class>[] parameters = signature.toArray(new Class[0]);
- if (!isInvokable(method, parameters)) {
- throw createMethodNotFoundError(method.getName(), parameters, null);
+ /**
+ * Verification that checks a method by parameters (state and arguments) with optional context.
+ */
+ static MethodVerification createParameterAndOptionalContextVerification(
+ Class> context, boolean allowState) {
+ return (method, state, arguments, result) -> {
+ if (!allowState) {
+ checkNoState(state);
+ }
+ final Class>[] parameters = assembleParameters(state, arguments);
+ final Class>[] parametersWithContext =
+ Stream.concat(Stream.of(context), Arrays.stream(parameters))
+ .toArray(Class>[]::new);
+ if (!isInvokable(true, method, parameters)
+ && !isInvokable(true, method, parametersWithContext)) {
+ throw createMethodNotFoundError(
+ method.getName(),
+ parameters,
+ null,
+ allowState ? "(? [, ]* [, ]*)" : "");
}
};
}
+
+ // --------------------------------------------------------------------------------------------
+ // Methods from super class
+ // --------------------------------------------------------------------------------------------
+
+ Class extends UserDefinedFunction> getFunction() {
+ return function;
+ }
+
+ boolean supportsState() {
+ return stateExtraction != null;
+ }
+
+ @Override
+ protected Set extractGlobalFunctionTemplates() {
+ return TemplateUtils.extractGlobalFunctionTemplates(typeFactory, function);
+ }
+
+ @Override
+ protected Set extractLocalFunctionTemplates(Method method) {
+ return TemplateUtils.extractLocalFunctionTemplates(typeFactory, method);
+ }
+
+ @Override
+ protected List collectMethods(String methodName) {
+ return ExtractionUtils.collectMethods(function, methodName);
+ }
+
+ @Override
+ protected Class> getFunctionClass() {
+ return function;
+ }
+
+ @Override
+ protected String getHintType() {
+ return "Function";
+ }
+
+ // --------------------------------------------------------------------------------------------
+ // Helper methods
+ // --------------------------------------------------------------------------------------------
+
+ /** Uses hints to extract functional template. */
+ private static Optional extractHints(
+ BaseMappingExtractor extractor, Method method) {
+ final Set dataTypeHints = new HashSet<>();
+ dataTypeHints.addAll(collectAnnotationsOfMethod(DataTypeHint.class, method));
+ dataTypeHints.addAll(
+ collectAnnotationsOfClass(DataTypeHint.class, extractor.getFunctionClass()));
+ if (dataTypeHints.size() > 1) {
+ throw extractionError(
+ "More than one data type hint found for output of function. "
+ + "Please use a function hint instead.");
+ }
+ if (dataTypeHints.size() == 1) {
+ return Optional.ofNullable(
+ FunctionTemplate.createOutputTemplate(
+ extractor.typeFactory, dataTypeHints.iterator().next()));
+ }
+ // otherwise continue with regular extraction
+ return Optional.empty();
+ }
}
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionResultTemplate.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionResultTemplate.java
index e8305f9abe0a31..3d71c6c4da942e 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionResultTemplate.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionResultTemplate.java
@@ -20,47 +20,130 @@
import org.apache.flink.annotation.Internal;
import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.StateTypeStrategy;
+import org.apache.flink.table.types.inference.StateTypeStrategyWrapper;
import org.apache.flink.table.types.inference.TypeStrategies;
import org.apache.flink.table.types.inference.TypeStrategy;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
import java.util.Objects;
+import java.util.stream.Collectors;
-/** Template of a function intermediate result (i.e. accumulator) or final result (i.e. output). */
-@Internal
-final class FunctionResultTemplate {
+import static org.apache.flink.table.types.extraction.ExtractionUtils.extractionError;
- final DataType dataType;
+/** Template of a function intermediate result (i.e. state) or final result (i.e. output). */
+@Internal
+interface FunctionResultTemplate {
- private FunctionResultTemplate(DataType dataType) {
- this.dataType = dataType;
+ static FunctionOutputTemplate ofOutput(DataType dataType) {
+ return new FunctionOutputTemplate(dataType);
}
- static FunctionResultTemplate of(DataType dataType) {
- return new FunctionResultTemplate(dataType);
+ static FunctionStateTemplate ofState(LinkedHashMap state) {
+ return new FunctionStateTemplate(state);
}
- TypeStrategy toTypeStrategy() {
- return TypeStrategies.explicit(dataType);
- }
+ @Internal
+ class FunctionOutputTemplate implements FunctionResultTemplate {
- Class> toClass() {
- return dataType.getConversionClass();
- }
+ private final DataType dataType;
+
+ private FunctionOutputTemplate(DataType dataType) {
+ this.dataType = dataType;
+ }
+
+ TypeStrategy toTypeStrategy() {
+ return TypeStrategies.explicit(dataType);
+ }
- @Override
- public boolean equals(Object o) {
- if (this == o) {
- return true;
+ Class> toClass() {
+ return dataType.getConversionClass();
}
- if (o == null || getClass() != o.getClass()) {
- return false;
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ final FunctionOutputTemplate template = (FunctionOutputTemplate) o;
+ return Objects.equals(dataType, template.dataType);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(dataType);
}
- FunctionResultTemplate that = (FunctionResultTemplate) o;
- return dataType.equals(that.dataType);
}
- @Override
- public int hashCode() {
- return Objects.hash(dataType);
+ @Internal
+ class FunctionStateTemplate implements FunctionResultTemplate {
+
+ private final LinkedHashMap state;
+
+ private FunctionStateTemplate(LinkedHashMap state) {
+ this.state = state;
+ }
+
+ List> toClassList() {
+ return state.values().stream()
+ .map(DataType::getConversionClass)
+ .collect(Collectors.toList());
+ }
+
+ LinkedHashMap toStateTypeStrategies() {
+ return state.entrySet().stream()
+ .collect(
+ Collectors.toMap(
+ Map.Entry::getKey,
+ e -> createStateTypeStrategy(e.getValue()),
+ (o, n) -> o,
+ LinkedHashMap::new));
+ }
+
+ String toAccumulatorStateName() {
+ checkSingleStateEntry();
+ return state.keySet().iterator().next();
+ }
+
+ TypeStrategy toAccumulatorTypeStrategy() {
+ checkSingleStateEntry();
+ return createTypeStrategy(state.values().iterator().next());
+ }
+
+ private void checkSingleStateEntry() {
+ if (state.size() != 1) {
+ throw extractionError("Aggregating functions support only one state entry.");
+ }
+ }
+
+ private static StateTypeStrategy createStateTypeStrategy(DataType dataType) {
+ return StateTypeStrategyWrapper.of(TypeStrategies.explicit(dataType));
+ }
+
+ private static TypeStrategy createTypeStrategy(DataType dataType) {
+ return TypeStrategies.explicit(dataType);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ final FunctionStateTemplate that = (FunctionStateTemplate) o;
+ return Objects.equals(state, that.state);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(state);
+ }
}
}
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionSignatureTemplate.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionSignatureTemplate.java
index 855f11e6fa830d..db641fb39ed270 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionSignatureTemplate.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionSignatureTemplate.java
@@ -23,11 +23,14 @@
import org.apache.flink.table.types.inference.ArgumentTypeStrategy;
import org.apache.flink.table.types.inference.InputTypeStrategies;
import org.apache.flink.table.types.inference.InputTypeStrategy;
+import org.apache.flink.table.types.inference.StaticArgument;
+import org.apache.flink.table.types.inference.StaticArgumentTrait;
import javax.annotation.Nullable;
import java.lang.reflect.Array;
import java.util.Arrays;
+import java.util.EnumSet;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
@@ -43,17 +46,21 @@ final class FunctionSignatureTemplate {
final boolean isVarArgs;
+ final EnumSet[] argumentTraits;
+
final @Nullable String[] argumentNames;
- final Boolean[] argumentOptionals;
+ final boolean[] argumentOptionals;
private FunctionSignatureTemplate(
List argumentTemplates,
boolean isVarArgs,
+ EnumSet[] argumentTraits,
@Nullable String[] argumentNames,
- Boolean[] argumentOptionals) {
+ boolean[] argumentOptionals) {
this.argumentTemplates = argumentTemplates;
this.isVarArgs = isVarArgs;
+ this.argumentTraits = argumentTraits;
this.argumentNames = argumentNames;
this.argumentOptionals = argumentOptionals;
}
@@ -61,8 +68,9 @@ private FunctionSignatureTemplate(
static FunctionSignatureTemplate of(
List argumentTemplates,
boolean isVarArgs,
+ EnumSet[] argumentTraits,
@Nullable String[] argumentNames,
- Boolean[] argumentOptionals) {
+ boolean[] argumentOptionals) {
if (argumentNames != null && argumentNames.length != argumentTemplates.size()) {
throw extractionError(
"Mismatch between number of argument names '%s' and argument types '%s'.",
@@ -80,7 +88,7 @@ static FunctionSignatureTemplate of(
}
if (argumentOptionals != null) {
for (int i = 0; i < argumentTemplates.size(); i++) {
- DataType dataType = argumentTemplates.get(i).dataType;
+ DataType dataType = argumentTemplates.get(i).toDataType();
if (dataType != null
&& !dataType.getLogicalType().isNullable()
&& argumentOptionals[i]) {
@@ -91,7 +99,70 @@ static FunctionSignatureTemplate of(
}
}
return new FunctionSignatureTemplate(
- argumentTemplates, isVarArgs, argumentNames, argumentOptionals);
+ argumentTemplates, isVarArgs, argumentTraits, argumentNames, argumentOptionals);
+ }
+
+ /**
+ * Converts the given signature into a list of static arguments if the signature allows it. E.g.
+ * no var-args and all arguments are named.
+ */
+ @Nullable
+ List toStaticArguments() {
+ if (isVarArgs || argumentNames == null) {
+ return null;
+ }
+ final List arguments =
+ IntStream.range(0, argumentTemplates.size())
+ .mapToObj(
+ pos -> {
+ final String name = argumentNames[pos];
+ final boolean isOptional = argumentOptionals[pos];
+ final FunctionArgumentTemplate template =
+ argumentTemplates.get(pos);
+ final EnumSet traits = argumentTraits[pos];
+ if (traits.contains(StaticArgumentTrait.TABLE_AS_ROW)
+ || traits.contains(StaticArgumentTrait.TABLE_AS_SET)) {
+ return createTableArgument(
+ name,
+ isOptional,
+ traits,
+ template.toDataType(),
+ template.toConversionClass());
+ } else if (traits.contains(StaticArgumentTrait.SCALAR)) {
+ return createScalarArgument(
+ name, isOptional, template.toDataType());
+ } else {
+ return null;
+ }
+ })
+ .collect(Collectors.toList());
+ if (arguments.contains(null)) {
+ return null;
+ }
+ return arguments;
+ }
+
+ private static @Nullable StaticArgument createTableArgument(
+ String name,
+ boolean isOptional,
+ EnumSet traits,
+ @Nullable DataType dataType,
+ @Nullable Class> conversionClass) {
+ if (dataType != null) {
+ return StaticArgument.table(name, dataType, isOptional, traits);
+ }
+ if (conversionClass != null) {
+ return StaticArgument.table(name, conversionClass, isOptional, traits);
+ }
+ return null;
+ }
+
+ private static @Nullable StaticArgument createScalarArgument(
+ String name, boolean isOptional, @Nullable DataType dataType) {
+ if (dataType != null) {
+ return StaticArgument.scalar(name, dataType, isOptional);
+ }
+ return null;
}
InputTypeStrategy toInputTypeStrategy() {
@@ -117,7 +188,7 @@ InputTypeStrategy toInputTypeStrategy() {
return strategy;
}
- List> toClass() {
+ List> toClassList() {
return IntStream.range(0, argumentTemplates.size())
.mapToObj(
i -> {
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionTemplate.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionTemplate.java
index c92830393cc148..0b62b0e340cda8 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionTemplate.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionTemplate.java
@@ -24,15 +24,27 @@
import org.apache.flink.table.annotation.DataTypeHint;
import org.apache.flink.table.annotation.FunctionHint;
import org.apache.flink.table.annotation.ProcedureHint;
+import org.apache.flink.table.annotation.StateHint;
import org.apache.flink.table.catalog.DataTypeFactory;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.extraction.FunctionResultTemplate.FunctionOutputTemplate;
+import org.apache.flink.table.types.extraction.FunctionResultTemplate.FunctionStateTemplate;
+import org.apache.flink.table.types.inference.StaticArgumentTrait;
+import org.apache.flink.types.Row;
import javax.annotation.Nullable;
import java.lang.annotation.Annotation;
import java.util.Arrays;
+import java.util.EnumSet;
+import java.util.LinkedHashMap;
+import java.util.List;
import java.util.Objects;
+import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
+import java.util.stream.IntStream;
import static org.apache.flink.table.types.extraction.ExtractionUtils.extractionError;
@@ -47,16 +59,16 @@ final class FunctionTemplate {
private final @Nullable FunctionSignatureTemplate signatureTemplate;
- private final @Nullable FunctionResultTemplate accumulatorTemplate;
+ private final @Nullable FunctionStateTemplate stateTemplate;
- private final @Nullable FunctionResultTemplate outputTemplate;
+ private final @Nullable FunctionOutputTemplate outputTemplate;
private FunctionTemplate(
@Nullable FunctionSignatureTemplate signatureTemplate,
- @Nullable FunctionResultTemplate accumulatorTemplate,
- @Nullable FunctionResultTemplate outputTemplate) {
+ @Nullable FunctionStateTemplate stateTemplate,
+ @Nullable FunctionOutputTemplate outputTemplate) {
this.signatureTemplate = signatureTemplate;
- this.accumulatorTemplate = accumulatorTemplate;
+ this.stateTemplate = stateTemplate;
this.outputTemplate = outputTemplate;
}
@@ -64,10 +76,8 @@ private FunctionTemplate(
* Creates an instance using the given {@link FunctionHint}. It resolves explicitly defined data
* types.
*/
+ @SuppressWarnings("deprecation")
static FunctionTemplate fromAnnotation(DataTypeFactory typeFactory, FunctionHint hint) {
- if (hint.state().length > 0) {
- throw extractionError("State hints are not supported yet.");
- }
return new FunctionTemplate(
createSignatureTemplate(
typeFactory,
@@ -76,14 +86,18 @@ static FunctionTemplate fromAnnotation(DataTypeFactory typeFactory, FunctionHint
defaultAsNull(hint, FunctionHint::argument),
defaultAsNull(hint, FunctionHint::arguments),
hint.isVarArgs()),
- createResultTemplate(typeFactory, defaultAsNull(hint, FunctionHint::accumulator)),
- createResultTemplate(typeFactory, defaultAsNull(hint, FunctionHint::output)));
+ createStateTemplate(
+ typeFactory,
+ defaultAsNull(hint, FunctionHint::accumulator),
+ defaultAsNull(hint, FunctionHint::state)),
+ createOutputTemplate(typeFactory, defaultAsNull(hint, FunctionHint::output)));
}
/**
* Creates an instance using the given {@link ProcedureHint}. It resolves explicitly defined
* data types.
*/
+ @SuppressWarnings("deprecation")
static FunctionTemplate fromAnnotation(DataTypeFactory typeFactory, ProcedureHint hint) {
return new FunctionTemplate(
createSignatureTemplate(
@@ -93,12 +107,12 @@ static FunctionTemplate fromAnnotation(DataTypeFactory typeFactory, ProcedureHin
defaultAsNull(hint, ProcedureHint::argument),
defaultAsNull(hint, ProcedureHint::arguments),
hint.isVarArgs()),
- null,
- createResultTemplate(typeFactory, defaultAsNull(hint, ProcedureHint::output)));
+ createStateTemplate(typeFactory, null, null),
+ createOutputTemplate(typeFactory, defaultAsNull(hint, ProcedureHint::output)));
}
/** Creates an instance of {@link FunctionResultTemplate} from a {@link DataTypeHint}. */
- static @Nullable FunctionResultTemplate createResultTemplate(
+ static @Nullable FunctionOutputTemplate createOutputTemplate(
DataTypeFactory typeFactory, @Nullable DataTypeHint hint) {
if (hint == null) {
return null;
@@ -110,20 +124,49 @@ static FunctionTemplate fromAnnotation(DataTypeFactory typeFactory, ProcedureHin
throw extractionError(t, "Error in data type hint annotation.");
}
if (template.dataType != null) {
- return FunctionResultTemplate.of(template.dataType);
+ return FunctionResultTemplate.ofOutput(template.dataType);
}
throw extractionError(
"Data type hint does not specify a data type for use as function result.");
}
+ /** Creates a {@link FunctionStateTemplate}s from {@link StateHint}s or accumulator. */
+ static @Nullable FunctionStateTemplate createStateTemplate(
+ DataTypeFactory typeFactory,
+ @Nullable DataTypeHint accumulatorHint,
+ @Nullable StateHint[] stateHints) {
+ if (accumulatorHint == null && stateHints == null) {
+ return null;
+ }
+ if (accumulatorHint != null && stateHints != null) {
+ throw extractionError(
+ "State hints and accumulator cannot be declared in the same function hint. "
+ + "Use either one or the other.");
+ }
+ final LinkedHashMap state = new LinkedHashMap<>();
+ if (accumulatorHint != null) {
+ state.put("acc", createStateDataType(typeFactory, accumulatorHint, "accumulator"));
+ return FunctionResultTemplate.ofState(state);
+ }
+ IntStream.range(0, stateHints.length)
+ .forEach(
+ pos -> {
+ final StateHint hint = stateHints[pos];
+ state.put(
+ hint.name(),
+ createStateDataType(typeFactory, hint.type(), "state entry"));
+ });
+ return FunctionResultTemplate.ofState(state);
+ }
+
@Nullable
FunctionSignatureTemplate getSignatureTemplate() {
return signatureTemplate;
}
@Nullable
- FunctionResultTemplate getAccumulatorTemplate() {
- return accumulatorTemplate;
+ FunctionResultTemplate getStateTemplate() {
+ return stateTemplate;
}
@Nullable
@@ -141,13 +184,13 @@ public boolean equals(Object o) {
}
FunctionTemplate template = (FunctionTemplate) o;
return Objects.equals(signatureTemplate, template.signatureTemplate)
- && Objects.equals(accumulatorTemplate, template.accumulatorTemplate)
+ && Objects.equals(stateTemplate, template.stateTemplate)
&& Objects.equals(outputTemplate, template.outputTemplate);
}
@Override
public int hashCode() {
- return Objects.hash(signatureTemplate, accumulatorTemplate, outputTemplate);
+ return Objects.hash(signatureTemplate, stateTemplate, outputTemplate);
}
// --------------------------------------------------------------------------------------------
@@ -185,6 +228,7 @@ private static T defaultAsNull(
return actualValue;
}
+ @SuppressWarnings("unchecked")
private static @Nullable FunctionSignatureTemplate createSignatureTemplate(
DataTypeFactory typeFactory,
@Nullable DataTypeHint[] inputs,
@@ -204,79 +248,150 @@ private static T defaultAsNull(
argumentHints = pluralArgumentHints;
}
- String[] argumentHintNames;
- DataTypeHint[] argumentHintTypes;
-
// Deal with #arguments() and #input()
if (argumentHints != null && inputs != null) {
throw extractionError(
- "Argument and input hints cannot be declared in the same function hint.");
+ "Argument and input hints cannot be declared in the same function hint. "
+ + "Use either one or the other.");
}
-
- Boolean[] argumentOptionals;
+ final DataTypeHint[] argumentHintTypes;
+ final boolean[] argumentOptionals;
+ final ArgumentTrait[][] argumentTraits;
+ String[] argumentHintNames;
if (argumentHints != null) {
- final boolean allScalar =
- Arrays.stream(argumentHints)
- .allMatch(
- h -> {
- final ArgumentTrait[] traits = h.value();
- return traits.length == 1
- && traits[0] == ArgumentTrait.SCALAR;
- });
- if (!allScalar) {
- throw extractionError("Only scalar arguments are supported so far.");
- }
-
- argumentHintNames = new String[argumentHints.length];
argumentHintTypes = new DataTypeHint[argumentHints.length];
- argumentOptionals = new Boolean[argumentHints.length];
- boolean allArgumentNameNotSet = true;
+ argumentOptionals = new boolean[argumentHints.length];
+ argumentTraits = new ArgumentTrait[argumentHints.length][];
+ argumentHintNames = new String[argumentHints.length];
+ boolean allArgumentNamesNotSet = true;
for (int i = 0; i < argumentHints.length; i++) {
- ArgumentHint argumentHint = argumentHints[i];
+ final ArgumentHint argumentHint = argumentHints[i];
argumentHintNames[i] = defaultAsNull(argumentHint, ArgumentHint::name);
argumentHintTypes[i] = defaultAsNull(argumentHint, ArgumentHint::type);
argumentOptionals[i] = argumentHint.isOptional();
- if (argumentHintTypes[i] == null) {
- throw extractionError("The type of the argument at position %d is not set.", i);
- }
+ argumentTraits[i] = argumentHint.value();
if (argumentHintNames[i] != null) {
- allArgumentNameNotSet = false;
- } else if (!allArgumentNameNotSet) {
+ allArgumentNamesNotSet = false;
+ } else if (!allArgumentNamesNotSet) {
throw extractionError(
- "The argument name in function hint must be either fully set or not set at all.");
+ "Argument names in function hint must be either fully set or not set at all.");
}
}
- if (allArgumentNameNotSet) {
+ if (allArgumentNamesNotSet) {
argumentHintNames = null;
}
- } else {
- if (inputs == null) {
- return null;
- }
+ } else if (inputs != null) {
argumentHintTypes = inputs;
argumentHintNames = argumentNames;
- argumentOptionals = new Boolean[inputs.length];
- Arrays.fill(argumentOptionals, false);
+ argumentOptionals = new boolean[inputs.length];
+ argumentTraits = new ArgumentTrait[inputs.length][];
+ Arrays.fill(argumentTraits, new ArgumentTrait[] {ArgumentTrait.SCALAR});
+ } else {
+ return null;
}
+ final List argumentTemplates =
+ IntStream.range(0, argumentHintTypes.length)
+ .mapToObj(
+ i ->
+ createArgumentTemplate(
+ typeFactory,
+ i,
+ argumentHintTypes[i],
+ argumentTraits[i]))
+ .collect(Collectors.toList());
+
return FunctionSignatureTemplate.of(
- Arrays.stream(argumentHintTypes)
- .map(dataTypeHint -> createArgumentTemplate(typeFactory, dataTypeHint))
- .collect(Collectors.toList()),
+ argumentTemplates,
isVarArg,
+ Arrays.stream(argumentTraits)
+ .map(
+ t -> {
+ final List traits =
+ Arrays.stream(t)
+ .map(ArgumentTrait::toStaticTrait)
+ .collect(Collectors.toList());
+ return EnumSet.copyOf(traits);
+ })
+ .toArray(EnumSet[]::new),
argumentHintNames,
argumentOptionals);
}
private static FunctionArgumentTemplate createArgumentTemplate(
- DataTypeFactory typeFactory, DataTypeHint hint) {
- final DataTypeTemplate template = DataTypeTemplate.fromAnnotation(typeFactory, hint);
+ DataTypeFactory typeFactory,
+ int pos,
+ @Nullable DataTypeHint hint,
+ ArgumentTrait[] argumentTraits) {
+ final Set rootTrait =
+ Arrays.stream(argumentTraits)
+ .filter(ArgumentTrait::isRoot)
+ .collect(Collectors.toSet());
+ if (rootTrait.size() != 1) {
+ throw extractionError(
+ "Incorrect argument kind at position %d. Argument kind must be one of: %s",
+ pos,
+ Arrays.stream(ArgumentTrait.values())
+ .filter(ArgumentTrait::isRoot)
+ .collect(Collectors.toList()));
+ }
+
+ if (rootTrait.contains(ArgumentTrait.SCALAR)) {
+ if (hint != null) {
+ final DataTypeTemplate template;
+ try {
+ template = DataTypeTemplate.fromAnnotation(typeFactory, hint);
+ } catch (Throwable t) {
+ throw extractionError(
+ t,
+ "Error in data type hint annotation for argument at position %s.",
+ pos);
+ }
+ if (template.dataType != null) {
+ return FunctionArgumentTemplate.ofDataType(template.dataType);
+ } else if (template.inputGroup != null) {
+ return FunctionArgumentTemplate.ofInputGroup(template.inputGroup);
+ }
+ }
+ throw extractionError("Data type missing for scalar argument at position %s.", pos);
+ } else if (rootTrait.contains(ArgumentTrait.TABLE_AS_ROW)
+ || rootTrait.contains(ArgumentTrait.TABLE_AS_SET)) {
+ try {
+ final DataTypeTemplate template =
+ DataTypeTemplate.fromAnnotation(typeFactory, hint);
+ if (template.dataType != null) {
+ return FunctionArgumentTemplate.ofDataType(template.dataType);
+ } else if (template.inputGroup != null) {
+ throw extractionError(
+ "Input groups are not supported for table argument at position %s.",
+ pos);
+ }
+ return FunctionArgumentTemplate.ofTable(Row.class);
+ } catch (Throwable t) {
+ final Class> argClass = hint == null ? Row.class : hint.bridgedTo();
+ if (argClass == Row.class || argClass == RowData.class) {
+ return FunctionArgumentTemplate.ofTable(argClass);
+ }
+ // Just a regular error for a typed argument
+ throw t;
+ }
+ } else {
+ throw extractionError("Unknown argument kind.");
+ }
+ }
+
+ private static DataType createStateDataType(
+ DataTypeFactory typeFactory, DataTypeHint dataTypeHint, String description) {
+ final DataTypeTemplate template;
+ try {
+ template = DataTypeTemplate.fromAnnotation(typeFactory, dataTypeHint);
+ } catch (Throwable t) {
+ throw extractionError(t, "Error in data type hint annotation.");
+ }
if (template.dataType != null) {
- return FunctionArgumentTemplate.of(template.dataType);
- } else if (template.inputGroup != null) {
- return FunctionArgumentTemplate.of(template.inputGroup);
+ return template.dataType;
}
throw extractionError(
- "Data type hint does neither specify a data type nor input group for use as function argument.");
+ "Data type hint does not specify a data type for use as %s.", description);
}
}
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/ProcedureMappingExtractor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/ProcedureMappingExtractor.java
index da37485011b437..43765a9b5cbcd9 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/ProcedureMappingExtractor.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/ProcedureMappingExtractor.java
@@ -22,9 +22,11 @@
import org.apache.flink.table.catalog.DataTypeFactory;
import org.apache.flink.table.procedures.Procedure;
import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.extraction.FunctionResultTemplate.FunctionOutputTemplate;
import java.lang.reflect.Array;
import java.lang.reflect.Method;
+import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.stream.Stream;
@@ -49,12 +51,66 @@ final class ProcedureMappingExtractor extends BaseMappingExtractor {
Class extends Procedure> procedure,
String methodName,
SignatureExtraction signatureExtraction,
- ResultExtraction outputExtraction,
+ ResultExtraction resultExtraction,
MethodVerification verification) {
- super(typeFactory, methodName, signatureExtraction, outputExtraction, verification);
+ super(typeFactory, methodName, signatureExtraction, resultExtraction, verification);
this.procedure = procedure;
}
+ // --------------------------------------------------------------------------------------------
+ // Extraction strategy
+ // --------------------------------------------------------------------------------------------
+
+ /**
+ * Extraction that uses the method return type for producing a {@link FunctionOutputTemplate}.
+ */
+ static ResultExtraction createOutputFromArrayReturnTypeInMethod() {
+ return (extractor, method) -> {
+ final DataType dataType =
+ DataTypeExtractor.extractFromMethodReturnType(
+ extractor.typeFactory,
+ extractor.getFunctionClass(),
+ method,
+ method.getReturnType().getComponentType());
+ return FunctionResultTemplate.ofOutput(dataType);
+ };
+ }
+
+ // --------------------------------------------------------------------------------------------
+ // Verification strategy
+ // --------------------------------------------------------------------------------------------
+
+ /**
+ * Verification that checks a method by parameters (arguments only) with mandatory context and
+ * array return type.
+ */
+ static MethodVerification createParameterWithOptionalContextAndArrayReturnTypeVerification() {
+ return (method, state, arguments, result) -> {
+ checkNoState(state);
+ final Class>[] parameters = assembleParameters(state, arguments);
+ // ignore the ProcedureContext in the first argument
+ final Class>[] parametersWithContext =
+ Stream.concat(Stream.of((Class>) null), Arrays.stream(parameters))
+ .toArray(Class>[]::new);
+ final Class> returnType = method.getReturnType();
+ final boolean isValid =
+ isInvokable(true, method, parametersWithContext)
+ && returnType.isArray()
+ && isAssignable(result, returnType.getComponentType(), true, true);
+ if (!isValid) {
+ throw createMethodNotFoundError(
+ method.getName(),
+ parametersWithContext,
+ Array.newInstance(result, 0).getClass(),
+ "( [, ]*)");
+ }
+ };
+ }
+
+ // --------------------------------------------------------------------------------------------
+ // Methods from super class
+ // --------------------------------------------------------------------------------------------
+
@Override
protected Set extractGlobalFunctionTemplates() {
return TemplateUtils.extractProcedureGlobalFunctionTemplates(typeFactory, procedure);
@@ -79,37 +135,4 @@ protected Class> getFunctionClass() {
protected String getHintType() {
return "Procedure";
}
-
- /**
- * Extraction that uses the method return type for producing a {@link FunctionResultTemplate}.
- */
- static ResultExtraction createReturnTypeResultExtraction() {
- return (extractor, method) -> {
- final DataType dataType =
- DataTypeExtractor.extractFromMethodOutput(
- extractor.typeFactory,
- extractor.getFunctionClass(),
- method,
- method.getReturnType().getComponentType());
- return FunctionResultTemplate.of(dataType);
- };
- }
-
- static MethodVerification createParameterAndReturnTypeVerification() {
- return ((method, signature, result) -> {
- // ignore the ProcedureContext in the first argument
- final Class>[] parameters =
- Stream.concat(Stream.of((Class>) null), signature.stream())
- .toArray(Class>[]::new);
- final Class> returnType = method.getReturnType();
- final boolean isValid =
- isInvokable(method, parameters)
- && returnType.isArray()
- && isAssignable(result, returnType.getComponentType(), true);
- if (!isValid) {
- throw createMethodNotFoundError(
- method.getName(), parameters, Array.newInstance(result, 0).getClass());
- }
- });
- }
}
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/TypeInferenceExtractor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/TypeInferenceExtractor.java
index 1b54abd5574769..5ef431057fadf2 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/TypeInferenceExtractor.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/TypeInferenceExtractor.java
@@ -25,6 +25,7 @@
import org.apache.flink.table.functions.AggregateFunction;
import org.apache.flink.table.functions.AsyncScalarFunction;
import org.apache.flink.table.functions.AsyncTableFunction;
+import org.apache.flink.table.functions.ProcessTableFunction;
import org.apache.flink.table.functions.ScalarFunction;
import org.apache.flink.table.functions.TableAggregateFunction;
import org.apache.flink.table.functions.TableFunction;
@@ -32,33 +33,38 @@
import org.apache.flink.table.functions.UserDefinedFunctionHelper;
import org.apache.flink.table.procedures.Procedure;
import org.apache.flink.table.procedures.ProcedureDefinition;
-import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.extraction.FunctionResultTemplate.FunctionOutputTemplate;
+import org.apache.flink.table.types.extraction.FunctionResultTemplate.FunctionStateTemplate;
import org.apache.flink.table.types.inference.InputTypeStrategies;
import org.apache.flink.table.types.inference.InputTypeStrategy;
+import org.apache.flink.table.types.inference.StateTypeStrategy;
+import org.apache.flink.table.types.inference.StateTypeStrategyWrapper;
+import org.apache.flink.table.types.inference.StaticArgument;
import org.apache.flink.table.types.inference.TypeInference;
import org.apache.flink.table.types.inference.TypeStrategies;
import org.apache.flink.table.types.inference.TypeStrategy;
import javax.annotation.Nullable;
-import java.util.Arrays;
+import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
-import java.util.Objects;
import java.util.Set;
-import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
+import static org.apache.flink.table.types.extraction.BaseMappingExtractor.createArgumentsFromParametersExtraction;
+import static org.apache.flink.table.types.extraction.BaseMappingExtractor.createStateFromParametersExtraction;
import static org.apache.flink.table.types.extraction.ExtractionUtils.extractionError;
-import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createGenericParameterWithArgumentAndReturnTypeVerification;
-import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createGenericResultExtraction;
-import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createGenericResultExtractionFromMethod;
+import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createOutputFromGenericInClass;
+import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createOutputFromGenericInMethod;
+import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createOutputFromReturnTypeInMethod;
+import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createParameterAndCompletableFutureVerification;
+import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createParameterAndOptionalContextVerification;
import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createParameterAndReturnTypeVerification;
-import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createParameterSignatureExtraction;
import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createParameterVerification;
-import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createParameterWithAccumulatorVerification;
-import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createParameterWithArgumentVerification;
-import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createReturnTypeResultExtraction;
+import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createStateFromGenericInClassOrParameters;
+import static org.apache.flink.table.types.extraction.ProcedureMappingExtractor.createOutputFromArrayReturnTypeInMethod;
+import static org.apache.flink.table.types.extraction.ProcedureMappingExtractor.createParameterWithOptionalContextAndArrayReturnTypeVerification;
/**
* Reflection-based utility for extracting a {@link TypeInference} from a supported subclass of
@@ -81,11 +87,12 @@ public static TypeInference forScalarFunction(
typeFactory,
function,
UserDefinedFunctionHelper.SCALAR_EVAL,
- createParameterSignatureExtraction(0),
+ createArgumentsFromParametersExtraction(0),
null,
- createReturnTypeResultExtraction(),
+ null,
+ createOutputFromReturnTypeInMethod(),
createParameterAndReturnTypeVerification());
- return extractTypeInference(mappingExtractor);
+ return extractTypeInference(mappingExtractor, false);
}
/** Extracts a type inference from a {@link AsyncScalarFunction}. */
@@ -96,12 +103,12 @@ public static TypeInference forAsyncScalarFunction(
typeFactory,
function,
UserDefinedFunctionHelper.ASYNC_SCALAR_EVAL,
- createParameterSignatureExtraction(1),
+ createArgumentsFromParametersExtraction(1),
null,
- createGenericResultExtractionFromMethod(0, 0, true),
- createGenericParameterWithArgumentAndReturnTypeVerification(
- function, CompletableFuture.class, 0, 0));
- return extractTypeInference(mappingExtractor);
+ null,
+ createOutputFromGenericInMethod(0, 0, true),
+ createParameterAndCompletableFutureVerification(function));
+ return extractTypeInference(mappingExtractor, false);
}
/** Extracts a type inference from a {@link AggregateFunction}. */
@@ -112,11 +119,12 @@ public static TypeInference forAggregateFunction(
typeFactory,
function,
UserDefinedFunctionHelper.AGGREGATE_ACCUMULATE,
- createParameterSignatureExtraction(1),
- createGenericResultExtraction(AggregateFunction.class, 1, false),
- createGenericResultExtraction(AggregateFunction.class, 0, true),
- createParameterWithAccumulatorVerification());
- return extractTypeInference(mappingExtractor);
+ createArgumentsFromParametersExtraction(1),
+ createStateFromGenericInClassOrParameters(AggregateFunction.class, 1),
+ createParameterVerification(true),
+ createOutputFromGenericInClass(AggregateFunction.class, 0, true),
+ null);
+ return extractTypeInference(mappingExtractor, false);
}
/** Extracts a type inference from a {@link TableFunction}. */
@@ -127,11 +135,12 @@ public static TypeInference forTableFunction(
typeFactory,
function,
UserDefinedFunctionHelper.TABLE_EVAL,
- createParameterSignatureExtraction(0),
+ createArgumentsFromParametersExtraction(0),
null,
- createGenericResultExtraction(TableFunction.class, 0, true),
- createParameterVerification());
- return extractTypeInference(mappingExtractor);
+ null,
+ createOutputFromGenericInClass(TableFunction.class, 0, true),
+ createParameterVerification(false));
+ return extractTypeInference(mappingExtractor, false);
}
/** Extracts a type inference from a {@link TableAggregateFunction}. */
@@ -142,11 +151,12 @@ public static TypeInference forTableAggregateFunction(
typeFactory,
function,
UserDefinedFunctionHelper.TABLE_AGGREGATE_ACCUMULATE,
- createParameterSignatureExtraction(1),
- createGenericResultExtraction(TableAggregateFunction.class, 1, false),
- createGenericResultExtraction(TableAggregateFunction.class, 0, true),
- createParameterWithAccumulatorVerification());
- return extractTypeInference(mappingExtractor);
+ createArgumentsFromParametersExtraction(1),
+ createStateFromGenericInClassOrParameters(TableAggregateFunction.class, 1),
+ createParameterVerification(true),
+ createOutputFromGenericInClass(TableAggregateFunction.class, 0, true),
+ null);
+ return extractTypeInference(mappingExtractor, false);
}
/** Extracts a type inference from a {@link AsyncTableFunction}. */
@@ -157,11 +167,30 @@ public static TypeInference forAsyncTableFunction(
typeFactory,
function,
UserDefinedFunctionHelper.ASYNC_TABLE_EVAL,
- createParameterSignatureExtraction(1),
+ createArgumentsFromParametersExtraction(1),
null,
- createGenericResultExtraction(AsyncTableFunction.class, 0, true),
- createParameterWithArgumentVerification(CompletableFuture.class));
- return extractTypeInference(mappingExtractor);
+ null,
+ createOutputFromGenericInClass(AsyncTableFunction.class, 0, true),
+ createParameterAndCompletableFutureVerification(function));
+ return extractTypeInference(mappingExtractor, false);
+ }
+
+ /** Extracts a type inference from a {@link ProcessTableFunction}. */
+ public static TypeInference forProcessTableFunction(
+ DataTypeFactory typeFactory, Class extends ProcessTableFunction>> function) {
+ final FunctionMappingExtractor mappingExtractor =
+ new FunctionMappingExtractor(
+ typeFactory,
+ function,
+ UserDefinedFunctionHelper.PROCESS_TABLE_EVAL,
+ createArgumentsFromParametersExtraction(
+ 0, ProcessTableFunction.Context.class),
+ createStateFromParametersExtraction(),
+ createParameterAndOptionalContextVerification(
+ ProcessTableFunction.Context.class, true),
+ createOutputFromGenericInClass(ProcessTableFunction.class, 0, true),
+ null);
+ return extractTypeInference(mappingExtractor, true);
}
/** Extracts a type in inference from a {@link Procedure}. */
@@ -172,15 +201,16 @@ public static TypeInference forProcedure(
typeFactory,
procedure,
ProcedureDefinition.PROCEDURE_CALL,
- ProcedureMappingExtractor.createParameterSignatureExtraction(1),
- ProcedureMappingExtractor.createReturnTypeResultExtraction(),
- ProcedureMappingExtractor.createParameterAndReturnTypeVerification());
+ createArgumentsFromParametersExtraction(1),
+ createOutputFromArrayReturnTypeInMethod(),
+ createParameterWithOptionalContextAndArrayReturnTypeVerification());
return extractTypeInference(mappingExtractor);
}
- private static TypeInference extractTypeInference(FunctionMappingExtractor mappingExtractor) {
+ private static TypeInference extractTypeInference(
+ FunctionMappingExtractor mappingExtractor, boolean requiresStaticSignature) {
try {
- return extractTypeInferenceOrError(mappingExtractor);
+ return extractTypeInferenceOrError(mappingExtractor, requiresStaticSignature);
} catch (Throwable t) {
throw extractionError(
t,
@@ -192,7 +222,9 @@ private static TypeInference extractTypeInference(FunctionMappingExtractor mappi
private static TypeInference extractTypeInference(ProcedureMappingExtractor mappingExtractor) {
try {
- return extractTypeInferenceOrError(mappingExtractor);
+ final Map outputMapping =
+ mappingExtractor.extractOutputMapping();
+ return buildInference(null, outputMapping, false);
} catch (Throwable t) {
throw extractionError(
t,
@@ -203,110 +235,118 @@ private static TypeInference extractTypeInference(ProcedureMappingExtractor mapp
}
private static TypeInference extractTypeInferenceOrError(
- FunctionMappingExtractor mappingExtractor) {
- final Map outputMapping =
+ FunctionMappingExtractor mappingExtractor, boolean requiresStaticSignature) {
+ final Map outputMapping =
mappingExtractor.extractOutputMapping();
- if (!mappingExtractor.hasAccumulator()) {
- return buildInference(null, outputMapping);
+ if (!mappingExtractor.supportsState()) {
+ return buildInference(null, outputMapping, requiresStaticSignature);
}
- final Map accumulatorMapping =
- mappingExtractor.extractAccumulatorMapping();
- return buildInference(accumulatorMapping, outputMapping);
- }
+ final Map stateMapping =
+ mappingExtractor.extractStateMapping();
- private static TypeInference extractTypeInferenceOrError(
- ProcedureMappingExtractor mappingExtractor) {
- final Map outputMapping =
- mappingExtractor.extractOutputMapping();
- return buildInference(null, outputMapping);
+ return buildInference(stateMapping, outputMapping, requiresStaticSignature);
}
private static TypeInference buildInference(
- @Nullable Map accumulatorMapping,
- Map outputMapping) {
+ @Nullable Map stateMapping,
+ Map outputMapping,
+ boolean requiresStaticSignature) {
final TypeInference.Builder builder = TypeInference.newBuilder();
- configureNamedArguments(builder, outputMapping);
- configureOptionalArguments(builder, outputMapping);
- configureTypedArguments(builder, outputMapping);
-
- builder.inputTypeStrategy(translateInputTypeStrategy(outputMapping));
+ if (!configureStaticArguments(builder, outputMapping)) {
+ if (requiresStaticSignature) {
+ throw extractionError(
+ "Process table functions require a non-overloaded, non-vararg, and static signature.");
+ }
+ builder.inputTypeStrategy(translateInputTypeStrategy(outputMapping));
+ }
- if (accumulatorMapping != null) {
- // verify that accumulator and output are derived from the same input strategy
- if (!accumulatorMapping.keySet().equals(outputMapping.keySet())) {
+ if (stateMapping != null) {
+ // verify that state and output are derived from the same signatures
+ if (!stateMapping.keySet().equals(outputMapping.keySet())) {
throw extractionError(
- "Mismatch between accumulator signature and output signature. "
+ "Mismatch between state signature and output signature. "
+ "Both intermediate and output results must be derived from the same input strategy.");
}
- builder.accumulatorTypeStrategy(translateResultTypeStrategy(accumulatorMapping));
+ builder.stateTypeStrategies(translateStateTypeStrategies(stateMapping));
}
- builder.outputTypeStrategy(translateResultTypeStrategy(outputMapping));
+ builder.outputTypeStrategy(translateOutputTypeStrategy(outputMapping));
+
return builder.build();
}
- private static void configureNamedArguments(
+ private static boolean configureStaticArguments(
TypeInference.Builder builder,
- Map outputMapping) {
+ Map outputMapping) {
final Set signatures = outputMapping.keySet();
- if (signatures.stream().anyMatch(s -> s.isVarArgs || s.argumentNames == null)) {
- return;
+ if (signatures.size() != 1) {
+ // Function is overloaded
+ return false;
}
- final List> argumentNames =
- signatures.stream()
- .map(
- s -> {
- assert s.argumentNames != null;
- return Arrays.asList(s.argumentNames);
- })
- .collect(Collectors.toList());
- if (argumentNames.size() != 1) {
- return;
+ final List arguments = signatures.iterator().next().toStaticArguments();
+ if (arguments == null) {
+ // Function is var arg or non-static (e.g. uses input groups instead of typed arguments)
+ return false;
}
- builder.namedArguments(argumentNames.iterator().next());
+ builder.staticArguments(arguments);
+ return true;
}
- private static void configureOptionalArguments(
- TypeInference.Builder builder,
- Map outputMapping) {
- final Set signatures = outputMapping.keySet();
- if (signatures.stream().anyMatch(s -> s.isVarArgs || s.argumentNames == null)) {
- return;
- }
- final List> argumentOptional =
- signatures.stream()
- .filter(s -> s.argumentOptionals != null)
- .map(s -> Arrays.asList(s.argumentOptionals))
- .collect(Collectors.toList());
- if (argumentOptional.size() != 1 || argumentOptional.size() != signatures.size()) {
- return;
- }
- builder.optionalArguments(argumentOptional.get(0));
+ private static InputTypeStrategy translateInputTypeStrategy(
+ Map outputMapping) {
+ return outputMapping.keySet().stream()
+ .map(FunctionSignatureTemplate::toInputTypeStrategy)
+ .reduce(InputTypeStrategies::or)
+ .orElse(InputTypeStrategies.sequence());
}
- private static void configureTypedArguments(
- TypeInference.Builder builder,
- Map outputMapping) {
- if (outputMapping.size() != 1) {
- return;
+ private static LinkedHashMap translateStateTypeStrategies(
+ Map stateMapping) {
+ // Simple signatures don't require a mapping, default for process table functions
+ if (stateMapping.size() == 1) {
+ final FunctionStateTemplate template =
+ stateMapping.entrySet().iterator().next().getValue();
+ return template.toStateTypeStrategies();
}
- final FunctionSignatureTemplate signature = outputMapping.keySet().iterator().next();
- final List dataTypes =
- signature.argumentTemplates.stream()
- .map(a -> a.dataType)
- .collect(Collectors.toList());
- if (!signature.isVarArgs && dataTypes.stream().allMatch(Objects::nonNull)) {
- builder.typedArguments(dataTypes);
+ // For overloaded signatures to accumulators in aggregating functions
+ final Map mappings =
+ stateMapping.entrySet().stream()
+ .collect(
+ Collectors.toMap(
+ e -> e.getKey().toInputTypeStrategy(),
+ e -> e.getValue().toAccumulatorTypeStrategy()));
+ final StateTypeStrategy accumulatorStrategy =
+ StateTypeStrategyWrapper.of(TypeStrategies.mapping(mappings));
+ final Set stateNames =
+ stateMapping.values().stream()
+ .map(FunctionStateTemplate::toAccumulatorStateName)
+ .collect(Collectors.toSet());
+ if (stateMapping.size() > 1 && stateNames.size() > 1) {
+ throw extractionError(
+ "Overloaded aggregating functions must use the same name for state entries. "
+ + "Found: %s",
+ stateNames);
}
+ final String stateName = stateNames.iterator().next();
+ final LinkedHashMap stateTypeStrategies = new LinkedHashMap<>();
+ stateTypeStrategies.put(stateName, accumulatorStrategy);
+ return stateTypeStrategies;
}
- private static TypeStrategy translateResultTypeStrategy(
- Map resultMapping) {
+ private static TypeStrategy translateOutputTypeStrategy(
+ Map outputMapping) {
+ // Simple signatures don't require a mapping
+ if (outputMapping.size() == 1) {
+ final FunctionOutputTemplate template =
+ outputMapping.entrySet().iterator().next().getValue();
+ return template.toTypeStrategy();
+ }
+ // For overloaded signatures
final Map mappings =
- resultMapping.entrySet().stream()
+ outputMapping.entrySet().stream()
.collect(
Collectors.toMap(
e -> e.getKey().toInputTypeStrategy(),
@@ -314,12 +354,4 @@ private static TypeStrategy translateResultTypeStrategy(
(t1, t2) -> t2));
return TypeStrategies.mapping(mappings);
}
-
- private static InputTypeStrategy translateInputTypeStrategy(
- Map outputMapping) {
- return outputMapping.keySet().stream()
- .map(FunctionSignatureTemplate::toInputTypeStrategy)
- .reduce(InputTypeStrategies::or)
- .orElse(InputTypeStrategies.sequence());
- }
}
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgument.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgument.java
index ad791585c7806c..1faf3a0deaa55d 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgument.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgument.java
@@ -19,15 +19,21 @@
package org.apache.flink.table.types.inference;
import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.NullType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.StructuredType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeChecks;
import org.apache.flink.util.Preconditions;
import javax.annotation.Nullable;
import java.util.EnumSet;
+import java.util.Objects;
import java.util.Optional;
+import java.util.stream.Collectors;
/**
* Describes an argument in a static signature that is not overloaded and does not support varargs.
@@ -43,6 +49,8 @@
@PublicEvolving
public class StaticArgument {
+ private static final RowType DUMMY_ROW_TYPE = RowType.of(new NullType());
+
private final String name;
private final @Nullable DataType dataType;
private final @Nullable Class> conversionClass;
@@ -55,13 +63,15 @@ private StaticArgument(
@Nullable Class> conversionClass,
boolean isOptional,
EnumSet traits) {
- StaticArgumentTrait.checkIntegrity(
- Preconditions.checkNotNull(traits, "Traits must not be null."));
this.name = Preconditions.checkNotNull(name, "Name must not be null.");
this.dataType = dataType;
this.conversionClass = conversionClass;
this.isOptional = isOptional;
- this.traits = traits;
+ this.traits = Preconditions.checkNotNull(traits, "Traits must not be null.");
+ checkName();
+ checkTraits(traits);
+ checkOptionalType();
+ checkTableType();
}
/**
@@ -162,4 +172,121 @@ public boolean isOptional() {
public EnumSet getTraits() {
return traits;
}
+
+ @Override
+ public String toString() {
+ final StringBuilder s = new StringBuilder();
+ // Possible signatures:
+ // (myScalar INT)
+ // (myTypedTable ROW {TABLE BY ROW})
+ // (myUntypedTable {TABLE BY ROW})
+ s.append(name);
+ if (dataType != null) {
+ s.append(" ");
+ s.append(dataType);
+ }
+ if (!traits.equals(EnumSet.of(StaticArgumentTrait.SCALAR))) {
+ s.append(" ");
+ s.append(traits.stream().map(Enum::name).collect(Collectors.joining(", ", "{", "}")));
+ }
+ return s.toString();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ final StaticArgument that = (StaticArgument) o;
+ return isOptional == that.isOptional
+ && Objects.equals(name, that.name)
+ && Objects.equals(dataType, that.dataType)
+ && Objects.equals(conversionClass, that.conversionClass)
+ && Objects.equals(traits, that.traits);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(name, dataType, conversionClass, isOptional, traits);
+ }
+
+ private void checkName() {
+ if (!TypeInference.PARAMETER_NAME_FORMAT.test(name)) {
+ throw new ValidationException(
+ String.format(
+ "Invalid argument name '%s'. An argument must follow "
+ + "the pattern [a-zA-Z_$][a-zA-Z_$0-9].",
+ name));
+ }
+ }
+
+ private void checkTraits(EnumSet traits) {
+ if (traits.stream().filter(t -> t.getRequirements().isEmpty()).count() != 1) {
+ throw new ValidationException(
+ String.format(
+ "Invalid argument traits for argument '%s'. "
+ + "An argument must be declared as either scalar, table, or model.",
+ name));
+ }
+ traits.forEach(
+ trait ->
+ trait.getRequirements()
+ .forEach(
+ requirement -> {
+ if (!traits.contains(requirement)) {
+ throw new ValidationException(
+ String.format(
+ "Invalid argument traits for argument '%s'. Trait %s requires %s.",
+ name, trait, requirement));
+ }
+ }));
+ }
+
+ private void checkOptionalType() {
+ if (!isOptional) {
+ return;
+ }
+ // e.g. for untyped table arguments
+ if (dataType == null) {
+ return;
+ }
+
+ final LogicalType type = dataType.getLogicalType();
+ if (!type.isNullable() || !type.supportsInputConversion(dataType.getConversionClass())) {
+ throw new ValidationException(
+ String.format(
+ "Invalid data type for optional argument '%s'. "
+ + "An optional argument has to accept null values.",
+ name));
+ }
+ }
+
+ void checkTableType() {
+ if (!traits.contains(StaticArgumentTrait.TABLE)) {
+ return;
+ }
+ if (dataType == null
+ && conversionClass != null
+ && !DUMMY_ROW_TYPE.supportsInputConversion(conversionClass)) {
+ throw new ValidationException(
+ String.format(
+ "Invalid conversion class '%s' for argument '%s'. "
+ + "Polymorphic, untyped table arguments must use a row class.",
+ conversionClass.getName(), name));
+ }
+ if (dataType != null) {
+ final LogicalType type = dataType.getLogicalType();
+ if (traits.contains(StaticArgumentTrait.TABLE)
+ && !LogicalTypeChecks.isCompositeType(type)) {
+ throw new ValidationException(
+ String.format(
+ "Invalid data type '%s' for table argument '%s'. "
+ + "Typed table arguments must use a composite type (i.e. row or structured type).",
+ type, name));
+ }
+ }
+ }
}
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgumentTrait.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgumentTrait.java
index 76a4e6e26902a6..0590d21a340cd3 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgumentTrait.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgumentTrait.java
@@ -19,10 +19,8 @@
package org.apache.flink.table.types.inference;
import org.apache.flink.annotation.PublicEvolving;
-import org.apache.flink.table.api.ValidationException;
import java.util.Arrays;
-import java.util.EnumSet;
import java.util.Set;
import java.util.stream.Collectors;
@@ -47,21 +45,7 @@ public enum StaticArgumentTrait {
this.requirements = Arrays.stream(requirements).collect(Collectors.toSet());
}
- public static void checkIntegrity(EnumSet traits) {
- if (traits.stream().filter(t -> t.requirements.isEmpty()).count() != 1) {
- throw new ValidationException(
- "Invalid argument traits. An argument must be declared as either scalar, table, or model.");
- }
- traits.forEach(
- trait ->
- trait.requirements.forEach(
- requirement -> {
- if (!traits.contains(requirement)) {
- throw new ValidationException(
- String.format(
- "Invalid argument traits. Trait %s requires %s.",
- trait, requirement));
- }
- }));
+ public Set getRequirements() {
+ return requirements;
}
}
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInference.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInference.java
index 64d372b36c1e78..1939c34b6f18dc 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInference.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInference.java
@@ -19,6 +19,7 @@
package org.apache.flink.table.types.inference;
import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.types.DataType;
import org.apache.flink.util.Preconditions;
@@ -28,6 +29,8 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Optional;
+import java.util.function.Predicate;
+import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
@@ -47,6 +50,10 @@
@PublicEvolving
public final class TypeInference {
+ /** Format for both arguments and state entries. */
+ static final Predicate PARAMETER_NAME_FORMAT =
+ Pattern.compile("^[a-zA-Z_$][a-zA-Z_$0-9]*$").asPredicate();
+
private final @Nullable List staticArguments;
private final InputTypeStrategy inputTypeStrategy;
private final LinkedHashMap stateTypeStrategies;
@@ -61,6 +68,7 @@ private TypeInference(
this.inputTypeStrategy = inputTypeStrategy;
this.stateTypeStrategies = stateTypeStrategies;
this.outputTypeStrategy = outputTypeStrategy;
+ checkStateEntries();
}
/** Builder for configuring and creating instances of {@link TypeInference}. */
@@ -144,6 +152,19 @@ public Optional getAccumulatorTypeStrategy() {
return Optional.of(stateTypeStrategies.values().iterator().next());
}
+ private void checkStateEntries() {
+ // Verify state
+ final List invalidStateEntries =
+ stateTypeStrategies.keySet().stream()
+ .filter(n -> !PARAMETER_NAME_FORMAT.test(n))
+ .collect(Collectors.toList());
+ if (!invalidStateEntries.isEmpty()) {
+ throw new ValidationException(
+ "Invalid state names. A state entry must follow the pattern [a-zA-Z_$][a-zA-Z_$0-9]. But found: "
+ + invalidStateEntries);
+ }
+ }
+
// --------------------------------------------------------------------------------------------
// Builder
// --------------------------------------------------------------------------------------------
diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/DataTypeExtractorTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/DataTypeExtractorTest.java
index 03e1e66d137479..21eed2f0277124 100644
--- a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/DataTypeExtractorTest.java
+++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/DataTypeExtractorTest.java
@@ -579,7 +579,8 @@ static TestSpec forMethodOutput(String description, Class> clazz) {
final Method method = clazz.getMethods()[0];
return new TestSpec(
description,
- (lookup) -> DataTypeExtractor.extractFromMethodOutput(lookup, clazz, method));
+ (lookup) ->
+ DataTypeExtractor.extractFromMethodReturnType(lookup, clazz, method));
}
static TestSpec forMethodOutput(Class> clazz) {
diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/TypeInferenceExtractorTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/TypeInferenceExtractorTest.java
index 5fbc6dc81d354c..9b7ee509926cff 100644
--- a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/TypeInferenceExtractorTest.java
+++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/TypeInferenceExtractorTest.java
@@ -31,6 +31,7 @@
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.functions.AggregateFunction;
import org.apache.flink.table.functions.AsyncScalarFunction;
+import org.apache.flink.table.functions.ProcessTableFunction;
import org.apache.flink.table.functions.ScalarFunction;
import org.apache.flink.table.functions.TableAggregateFunction;
import org.apache.flink.table.functions.TableFunction;
@@ -39,6 +40,10 @@
import org.apache.flink.table.types.inference.ArgumentTypeStrategy;
import org.apache.flink.table.types.inference.InputTypeStrategies;
import org.apache.flink.table.types.inference.InputTypeStrategy;
+import org.apache.flink.table.types.inference.StateTypeStrategy;
+import org.apache.flink.table.types.inference.StateTypeStrategyWrapper;
+import org.apache.flink.table.types.inference.StaticArgument;
+import org.apache.flink.table.types.inference.StaticArgumentTrait;
import org.apache.flink.table.types.inference.TypeInference;
import org.apache.flink.table.types.inference.TypeStrategies;
import org.apache.flink.table.types.inference.TypeStrategy;
@@ -50,7 +55,10 @@
import javax.annotation.Nullable;
+import java.math.BigDecimal;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.EnumSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@@ -74,17 +82,10 @@ private static Stream functionSpecs() {
return Stream.of(
// function hint defines everything
TestSpec.forScalarFunction(FullFunctionHint.class)
- .expectNamedArguments("i", "s")
- .expectTypedArguments(DataTypes.INT(), DataTypes.STRING())
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[] {"i", "s"},
- new ArgumentTypeStrategy[] {
- InputTypeStrategies.explicit(DataTypes.INT()),
- InputTypeStrategies.explicit(DataTypes.STRING())
- }),
- TypeStrategies.explicit(DataTypes.BOOLEAN())),
-
+ .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false))
+ .expectStaticArgument(StaticArgument.scalar("s", DataTypes.STRING(), false))
+ .expectOutput(TypeStrategies.explicit(DataTypes.BOOLEAN())),
+ // ---
// function hint defines everything with overloading
TestSpec.forScalarFunction(FullFunctionHints.class)
.expectOutputMapping(
@@ -95,7 +96,7 @@ private static Stream functionSpecs() {
InputTypeStrategies.sequence(
InputTypeStrategies.explicit(DataTypes.BIGINT())),
TypeStrategies.explicit(DataTypes.BIGINT())),
-
+ // ---
// global output hint with local input overloading
TestSpec.forScalarFunction(GlobalOutputFunctionHint.class)
.expectOutputMapping(
@@ -106,12 +107,12 @@ private static Stream functionSpecs() {
InputTypeStrategies.sequence(
InputTypeStrategies.explicit(DataTypes.STRING())),
TypeStrategies.explicit(DataTypes.INT())),
-
+ // ---
// unsupported output overloading
TestSpec.forScalarFunction(InvalidSingleOutputFunctionHint.class)
.expectErrorMessage(
"Function hints that lead to ambiguous results are not allowed."),
-
+ // ---
// global and local overloading
TestSpec.forScalarFunction(SplitFullFunctionHints.class)
.expectOutputMapping(
@@ -122,22 +123,21 @@ private static Stream functionSpecs() {
InputTypeStrategies.sequence(
InputTypeStrategies.explicit(DataTypes.BIGINT())),
TypeStrategies.explicit(DataTypes.BIGINT())),
-
+ // ---
// global and local overloading with unsupported output overloading
TestSpec.forScalarFunction(InvalidFullOutputFunctionHint.class)
.expectErrorMessage(
"Function hints with same input definition but different result types are not allowed."),
-
+ // ---
// ignore argument names during overloading
TestSpec.forScalarFunction(InvalidFullOutputFunctionWithArgNamesHint.class)
.expectErrorMessage(
"Function hints with same input definition but different result types are not allowed."),
-
+ // ---
// invalid data type hint
TestSpec.forScalarFunction(IncompleteFunctionHint.class)
- .expectErrorMessage(
- "Data type hint does neither specify a data type nor input group for use as function argument."),
-
+ .expectErrorMessage("Data type missing for scalar argument at position 1."),
+ // ---
// varargs and ANY input group
TestSpec.forScalarFunction(ComplexFunctionHint.class)
.expectOutputMapping(
@@ -149,7 +149,7 @@ private static Stream functionSpecs() {
InputTypeStrategies.ANY
}),
TypeStrategies.explicit(DataTypes.BOOLEAN())),
-
+ // ---
// global input hints and local output hints
TestSpec.forScalarFunction(GlobalInputFunctionHints.class)
.expectOutputMapping(
@@ -160,55 +160,33 @@ private static Stream functionSpecs() {
InputTypeStrategies.sequence(
InputTypeStrategies.explicit(DataTypes.BIGINT())),
TypeStrategies.explicit(DataTypes.INT())),
-
+ // ---
// no arguments
TestSpec.forScalarFunction(ZeroArgFunction.class)
- .expectNamedArguments()
- .expectTypedArguments()
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[0], new ArgumentTypeStrategy[0]),
- TypeStrategies.explicit(DataTypes.INT())),
-
+ .expectEmptyStaticArguments()
+ .expectOutput(TypeStrategies.explicit(DataTypes.INT())),
+ // ---
// no arguments async
TestSpec.forAsyncScalarFunction(ZeroArgFunctionAsync.class)
- .expectNamedArguments()
- .expectTypedArguments()
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[0], new ArgumentTypeStrategy[0]),
- TypeStrategies.explicit(DataTypes.INT())),
-
+ .expectEmptyStaticArguments()
+ .expectOutput(TypeStrategies.explicit(DataTypes.INT())),
+ // ---
// test primitive arguments extraction
TestSpec.forScalarFunction(MixedArgFunction.class)
- .expectNamedArguments("i", "d")
- .expectTypedArguments(
- DataTypes.INT().notNull().bridgedTo(int.class), DataTypes.DOUBLE())
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[] {"i", "d"},
- new ArgumentTypeStrategy[] {
- InputTypeStrategies.explicit(
- DataTypes.INT().notNull().bridgedTo(int.class)),
- InputTypeStrategies.explicit(DataTypes.DOUBLE())
- }),
- TypeStrategies.explicit(DataTypes.INT())),
-
+ .expectStaticArgument(
+ StaticArgument.scalar(
+ "i", DataTypes.INT().notNull().bridgedTo(int.class), false))
+ .expectStaticArgument(StaticArgument.scalar("d", DataTypes.DOUBLE(), false))
+ .expectOutput(TypeStrategies.explicit(DataTypes.INT())),
+ // ---
// test primitive arguments extraction async
TestSpec.forAsyncScalarFunction(MixedArgFunctionAsync.class)
- .expectNamedArguments("i", "d")
- .expectTypedArguments(
- DataTypes.INT().notNull().bridgedTo(int.class), DataTypes.DOUBLE())
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[] {"i", "d"},
- new ArgumentTypeStrategy[] {
- InputTypeStrategies.explicit(
- DataTypes.INT().notNull().bridgedTo(int.class)),
- InputTypeStrategies.explicit(DataTypes.DOUBLE())
- }),
- TypeStrategies.explicit(DataTypes.INT())),
-
+ .expectStaticArgument(
+ StaticArgument.scalar(
+ "i", DataTypes.INT().notNull().bridgedTo(int.class), false))
+ .expectStaticArgument(StaticArgument.scalar("d", DataTypes.DOUBLE(), false))
+ .expectOutput(TypeStrategies.explicit(DataTypes.INT())),
+ // ---
// test overloaded arguments extraction
TestSpec.forScalarFunction(OverloadedFunction.class)
.expectOutputMapping(
@@ -228,7 +206,7 @@ private static Stream functionSpecs() {
}),
TypeStrategies.explicit(
DataTypes.BIGINT().notNull().bridgedTo(long.class))),
-
+ // ---
// test overloaded arguments extraction async
TestSpec.forAsyncScalarFunction(OverloadedFunctionAsync.class)
.expectOutputMapping(
@@ -247,7 +225,7 @@ private static Stream functionSpecs() {
InputTypeStrategies.explicit(DataTypes.STRING())
}),
TypeStrategies.explicit(DataTypes.BIGINT())),
-
+ // ---
// test varying arguments extraction
TestSpec.forScalarFunction(VarArgFunction.class)
.expectOutputMapping(
@@ -260,7 +238,7 @@ private static Stream functionSpecs() {
DataTypes.INT().notNull().bridgedTo(int.class))
}),
TypeStrategies.explicit(DataTypes.STRING())),
-
+ // ---
// test varying arguments extraction async
TestSpec.forAsyncScalarFunction(VarArgFunctionAsync.class)
.expectOutputMapping(
@@ -273,7 +251,7 @@ private static Stream functionSpecs() {
DataTypes.INT().notNull().bridgedTo(int.class))
}),
TypeStrategies.explicit(DataTypes.STRING())),
-
+ // ---
// test varying arguments extraction with byte
TestSpec.forScalarFunction(VarArgWithByteFunction.class)
.expectOutputMapping(
@@ -286,7 +264,7 @@ private static Stream functionSpecs() {
.bridgedTo(byte.class))
}),
TypeStrategies.explicit(DataTypes.STRING())),
-
+ // ---
// test varying arguments extraction with byte async
TestSpec.forAsyncScalarFunction(VarArgWithByteFunctionAsync.class)
.expectOutputMapping(
@@ -299,57 +277,49 @@ private static Stream functionSpecs() {
.bridgedTo(byte.class))
}),
TypeStrategies.explicit(DataTypes.STRING())),
-
+ // ---
// output hint with input extraction
TestSpec.forScalarFunction(ExtractWithOutputHintFunction.class)
- .expectNamedArguments("i")
- .expectTypedArguments(DataTypes.INT())
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[] {"i"},
- new ArgumentTypeStrategy[] {
- InputTypeStrategies.explicit(DataTypes.INT())
- }),
- TypeStrategies.explicit(DataTypes.INT())),
-
+ .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false))
+ .expectOutput(TypeStrategies.explicit(DataTypes.INT())),
+ // ---
// output hint with input extraction
TestSpec.forAsyncScalarFunction(ExtractWithOutputHintFunctionAsync.class)
- .expectNamedArguments("i")
- .expectTypedArguments(DataTypes.INT())
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[] {"i"},
- new ArgumentTypeStrategy[] {
- InputTypeStrategies.explicit(DataTypes.INT())
- }),
- TypeStrategies.explicit(DataTypes.INT())),
-
+ .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false))
+ .expectOutput(TypeStrategies.explicit(DataTypes.INT())),
+ // ---
// output extraction with input hints
TestSpec.forScalarFunction(ExtractWithInputHintFunction.class)
- .expectNamedArguments("i", "b")
- .expectTypedArguments(DataTypes.INT(), DataTypes.BOOLEAN())
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[] {"i", "b"},
- new ArgumentTypeStrategy[] {
- InputTypeStrategies.explicit(DataTypes.INT()),
- InputTypeStrategies.explicit(DataTypes.BOOLEAN())
- }),
+ .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false))
+ .expectStaticArgument(
+ StaticArgument.scalar("b", DataTypes.BOOLEAN(), false))
+ .expectOutput(
TypeStrategies.explicit(
DataTypes.DOUBLE().notNull().bridgedTo(double.class))),
-
+ // ---
// different accumulator depending on input
TestSpec.forAggregateFunction(InputDependentAccumulatorFunction.class)
- .expectAccumulatorMapping(
- InputTypeStrategies.sequence(
- InputTypeStrategies.explicit(DataTypes.BIGINT())),
- TypeStrategies.explicit(
- DataTypes.ROW(DataTypes.FIELD("f", DataTypes.BIGINT()))))
- .expectAccumulatorMapping(
- InputTypeStrategies.sequence(
- InputTypeStrategies.explicit(DataTypes.STRING())),
- TypeStrategies.explicit(
- DataTypes.ROW(DataTypes.FIELD("f", DataTypes.STRING()))))
+ .expectAccumulator(
+ TypeStrategies.mapping(
+ Map.of(
+ InputTypeStrategies.sequence(
+ InputTypeStrategies.explicit(
+ DataTypes.BIGINT())),
+ TypeStrategies.explicit(
+ DataTypes.ROW(
+ DataTypes.FIELD(
+ "f",
+ DataTypes
+ .BIGINT()))),
+ InputTypeStrategies.sequence(
+ InputTypeStrategies.explicit(
+ DataTypes.STRING())),
+ TypeStrategies.explicit(
+ DataTypes.ROW(
+ DataTypes.FIELD(
+ "f",
+ DataTypes
+ .STRING()))))))
.expectOutputMapping(
InputTypeStrategies.sequence(
InputTypeStrategies.explicit(DataTypes.BIGINT())),
@@ -358,81 +328,72 @@ private static Stream functionSpecs() {
InputTypeStrategies.sequence(
InputTypeStrategies.explicit(DataTypes.STRING())),
TypeStrategies.explicit(DataTypes.STRING())),
-
+ // ---
// input, accumulator, and output are spread across the function
TestSpec.forAggregateFunction(AggregateFunctionWithManyAnnotations.class)
- .expectNamedArguments("r")
- .expectTypedArguments(
- DataTypes.ROW(
- DataTypes.FIELD("i", DataTypes.INT()),
- DataTypes.FIELD("b", DataTypes.BOOLEAN())))
- .expectAccumulatorMapping(
- InputTypeStrategies.sequence(
- new String[] {"r"},
- new ArgumentTypeStrategy[] {
- InputTypeStrategies.explicit(
- DataTypes.ROW(
- DataTypes.FIELD("i", DataTypes.INT()),
- DataTypes.FIELD(
- "b", DataTypes.BOOLEAN())))
- }),
+ .expectStaticArgument(
+ StaticArgument.scalar(
+ "r",
+ DataTypes.ROW(
+ DataTypes.FIELD("i", DataTypes.INT()),
+ DataTypes.FIELD("b", DataTypes.BOOLEAN())),
+ false))
+ .expectAccumulator(
TypeStrategies.explicit(
DataTypes.ROW(DataTypes.FIELD("b", DataTypes.BOOLEAN()))))
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[] {"r"},
- new ArgumentTypeStrategy[] {
- InputTypeStrategies.explicit(
- DataTypes.ROW(
- DataTypes.FIELD("i", DataTypes.INT()),
- DataTypes.FIELD(
- "b", DataTypes.BOOLEAN())))
- }),
- TypeStrategies.explicit(DataTypes.STRING())),
-
+ .expectOutput(TypeStrategies.explicit(DataTypes.STRING())),
+ // ---
+ // accumulator with state hint
+ TestSpec.forAggregateFunction(StateHintAggregateFunction.class)
+ .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false))
+ .expectState("myAcc", TypeStrategies.explicit(MyState.TYPE))
+ .expectOutput(TypeStrategies.explicit(DataTypes.INT())),
+ // ---
+ // accumulator with state hint in function hint
+ TestSpec.forAggregateFunction(StateHintInFunctionHintAggregateFunction.class)
+ .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false))
+ .expectState("myAcc", TypeStrategies.explicit(MyState.TYPE))
+ .expectOutput(TypeStrategies.explicit(DataTypes.INT())),
+ // ---
// test for table functions
TestSpec.forTableFunction(OutputHintTableFunction.class)
- .expectNamedArguments("i")
- .expectTypedArguments(DataTypes.INT().notNull().bridgedTo(int.class))
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[] {"i"},
- new ArgumentTypeStrategy[] {
- InputTypeStrategies.explicit(
- DataTypes.INT().notNull().bridgedTo(int.class))
- }),
+ .expectStaticArgument(
+ StaticArgument.scalar(
+ "i", DataTypes.INT().notNull().bridgedTo(int.class), false))
+ .expectOutput(
TypeStrategies.explicit(
DataTypes.ROW(
DataTypes.FIELD("i", DataTypes.INT()),
DataTypes.FIELD("b", DataTypes.BOOLEAN())))),
-
+ // ---
// mismatch between hints and implementation regarding return type
TestSpec.forScalarFunction(InvalidMethodScalarFunction.class)
.expectErrorMessage(
"Considering all hints, the method should comply with the signature:\n"
+ "java.lang.String eval(int[])"),
-
+ // ---
// mismatch between hints and implementation regarding return type
TestSpec.forAsyncScalarFunction(InvalidMethodScalarFunctionAsync.class)
.expectErrorMessage(
"Considering all hints, the method should comply with the signature:\n"
+ "eval(java.util.concurrent.CompletableFuture, int[])"),
-
+ // ---
// mismatch between hints and implementation regarding accumulator
TestSpec.forAggregateFunction(InvalidMethodAggregateFunction.class)
.expectErrorMessage(
"Considering all hints, the method should comply with the signature:\n"
- + "accumulate(java.lang.Integer, int, boolean)"),
-
+ + "accumulate(java.lang.Integer, int, boolean)\n"
+ + "Pattern: ( [, ]*)"),
+ // ---
// no implementation
TestSpec.forTableFunction(MissingMethodTableFunction.class)
.expectErrorMessage(
"Could not find a publicly accessible method named 'eval'."),
-
+ // ---
// named arguments with overloaded function
// expected no named argument for overloaded function
TestSpec.forScalarFunction(NamedArgumentsScalarFunction.class),
-
+ // ---
// scalar function that takes any input
TestSpec.forScalarFunction(InputGroupScalarFunction.class)
.expectOutputMapping(
@@ -440,7 +401,7 @@ private static Stream functionSpecs() {
new String[] {"o"},
new ArgumentTypeStrategy[] {InputTypeStrategies.ANY}),
TypeStrategies.explicit(DataTypes.STRING())),
-
+ // ---
// scalar function that takes any input as vararg
TestSpec.forScalarFunction(VarArgInputGroupScalarFunction.class)
.expectOutputMapping(
@@ -448,6 +409,7 @@ private static Stream functionSpecs() {
new String[] {"o"},
new ArgumentTypeStrategy[] {InputTypeStrategies.ANY}),
TypeStrategies.explicit(DataTypes.STRING())),
+ // ---
TestSpec.forScalarFunction(
"Scalar function with implicit overloading order",
OrderedScalarFunction.class)
@@ -465,6 +427,7 @@ private static Stream functionSpecs() {
InputTypeStrategies.explicit(DataTypes.BIGINT())
}),
TypeStrategies.explicit(DataTypes.BIGINT())),
+ // ---
TestSpec.forScalarFunction(
"Scalar function with explicit overloading order by class annotations",
OrderedScalarFunction2.class)
@@ -476,6 +439,7 @@ private static Stream functionSpecs() {
InputTypeStrategies.sequence(
InputTypeStrategies.explicit(DataTypes.INT())),
TypeStrategies.explicit(DataTypes.INT())),
+ // ---
TestSpec.forScalarFunction(
"Scalar function with explicit overloading order by method annotations",
OrderedScalarFunction3.class)
@@ -487,138 +451,131 @@ private static Stream functionSpecs() {
InputTypeStrategies.sequence(
InputTypeStrategies.explicit(DataTypes.INT())),
TypeStrategies.explicit(DataTypes.INT())),
+ // ---
TestSpec.forTableFunction(
"A data type hint on the class is used instead of a function output hint",
DataTypeHintOnTableFunctionClass.class)
- .expectNamedArguments()
- .expectTypedArguments()
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[] {}, new ArgumentTypeStrategy[] {}),
+ .expectEmptyStaticArguments()
+ .expectOutput(
TypeStrategies.explicit(
DataTypes.ROW(DataTypes.FIELD("i", DataTypes.INT())))),
+ // ---
TestSpec.forTableFunction(
"A data type hint on the method is used instead of a function output hint",
DataTypeHintOnTableFunctionMethod.class)
- .expectNamedArguments("i")
- .expectTypedArguments(DataTypes.INT())
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[] {"i"},
- new ArgumentTypeStrategy[] {
- InputTypeStrategies.explicit(DataTypes.INT())
- }),
+ .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false))
+ .expectOutput(
TypeStrategies.explicit(
DataTypes.ROW(DataTypes.FIELD("i", DataTypes.INT())))),
+ // ---
TestSpec.forTableFunction(
"Invalid data type hint on top of method and class",
InvalidDataTypeHintOnTableFunction.class)
.expectErrorMessage(
"More than one data type hint found for output of function. "
+ "Please use a function hint instead."),
+ // ---
TestSpec.forScalarFunction(
"A data type hint on the method is used for enriching (not a function output hint)",
DataTypeHintOnScalarFunction.class)
- .expectNamedArguments()
- .expectTypedArguments()
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[] {}, new ArgumentTypeStrategy[] {}),
+ .expectEmptyStaticArguments()
+ .expectOutput(
TypeStrategies.explicit(
DataTypes.ROW(DataTypes.FIELD("i", DataTypes.INT()))
.bridgedTo(RowData.class))),
+ // ---
TestSpec.forAsyncScalarFunction(
"A data type hint on the method is used for enriching (not a function output hint)",
DataTypeHintOnScalarFunctionAsync.class)
- .expectNamedArguments()
- .expectTypedArguments()
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[] {}, new ArgumentTypeStrategy[] {}),
+ .expectEmptyStaticArguments()
+ .expectOutput(
TypeStrategies.explicit(
DataTypes.ROW(DataTypes.FIELD("i", DataTypes.INT()))
.bridgedTo(RowData.class))),
+ // ---
TestSpec.forScalarFunction(
"Scalar function with arguments hints",
ArgumentHintScalarFunction.class)
- .expectNamedArguments("f1", "f2")
- .expectTypedArguments(DataTypes.STRING(), DataTypes.INT())
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[] {"f1", "f2"},
- new ArgumentTypeStrategy[] {
- InputTypeStrategies.explicit(DataTypes.STRING()),
- InputTypeStrategies.explicit(DataTypes.INT())
- }),
- TypeStrategies.explicit(DataTypes.STRING())),
+ .expectStaticArgument(
+ StaticArgument.scalar("f1", DataTypes.STRING(), false))
+ .expectStaticArgument(StaticArgument.scalar("f2", DataTypes.INT(), false))
+ .expectOutput(TypeStrategies.explicit(DataTypes.STRING())),
+ // ---
TestSpec.forScalarFunction(
"Scalar function with arguments hints missing type",
ArgumentHintMissingTypeScalarFunction.class)
- .expectErrorMessage("The type of the argument at position 0 is not set."),
+ .expectErrorMessage("Data type missing for scalar argument at position 0."),
+ // ---
TestSpec.forScalarFunction(
"Scalar function with arguments hints all missing name",
ArgumentHintMissingNameScalarFunction.class)
- .expectNamedArguments("arg0", "arg1")
- .expectTypedArguments(DataTypes.STRING(), DataTypes.INT()),
+ .expectOutputMapping(
+ InputTypeStrategies.sequence(
+ InputTypeStrategies.explicit(DataTypes.STRING()),
+ InputTypeStrategies.explicit(DataTypes.INT())),
+ TypeStrategies.explicit(DataTypes.STRING())),
+ // ---
TestSpec.forScalarFunction(
"Scalar function with arguments hints all missing partial name",
ArgumentHintMissingPartialNameScalarFunction.class)
.expectErrorMessage(
- "The argument name in function hint must be either fully set or not set at all."),
+ "Argument names in function hint must be either fully set or not set at all."),
+ // ---
TestSpec.forScalarFunction(
"Scalar function with arguments hints name conflict",
ArgumentHintNameConflictScalarFunction.class)
.expectErrorMessage(
"Argument name conflict, there are at least two argument names that are the same."),
+ // ---
TestSpec.forScalarFunction(
"Scalar function with arguments hints on method parameter",
ArgumentHintOnParameterScalarFunction.class)
- .expectNamedArguments("in1", "in2")
- .expectTypedArguments(DataTypes.STRING(), DataTypes.INT())
- .expectOptionalArguments(false, false)
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[] {"in1", "in2"},
- new ArgumentTypeStrategy[] {
- InputTypeStrategies.explicit(DataTypes.STRING()),
- InputTypeStrategies.explicit(DataTypes.INT())
- }),
- TypeStrategies.explicit(DataTypes.STRING())),
+ .expectStaticArgument(
+ StaticArgument.scalar("in1", DataTypes.STRING(), false))
+ .expectStaticArgument(StaticArgument.scalar("in2", DataTypes.INT(), false))
+ .expectOutput(TypeStrategies.explicit(DataTypes.STRING())),
+ // ---
TestSpec.forScalarFunction(
"Scalar function with arguments hints and inputs hints both defined",
ArgumentsAndInputsScalarFunction.class)
.expectErrorMessage(
"Argument and input hints cannot be declared in the same function hint."),
+ // ---
TestSpec.forScalarFunction(
- "Scalar function with argument hint and dataType hint declared in the same parameter",
+ "Scalar function with argument hint and data type hint declared in the same parameter",
ArgumentsHintAndDataTypeHintScalarFunction.class)
.expectErrorMessage(
- "Argument and dataType hints cannot be declared in the same parameter at position 0."),
+ "Argument and data type hints cannot be declared at the same time at position 0."),
+ // ---
TestSpec.forScalarFunction(
"An invalid scalar function that declare FunctionHint for both class and method in the same class.",
InvalidFunctionHintOnClassAndMethod.class)
.expectErrorMessage(
"Argument and input hints cannot be declared in the same function hint."),
+ // ---
TestSpec.forScalarFunction(
"A valid scalar class that declare FunctionHint for both class and method in the same class.",
ValidFunctionHintOnClassAndMethod.class)
- .expectNamedArguments("f1", "f2")
- .expectTypedArguments(DataTypes.STRING(), DataTypes.INT())
- .expectOptionalArguments(true, true),
+ .expectStaticArgument(StaticArgument.scalar("f1", DataTypes.STRING(), true))
+ .expectStaticArgument(StaticArgument.scalar("f2", DataTypes.INT(), true)),
+ // ---
TestSpec.forScalarFunction(
"The FunctionHint of the function conflicts with the method.",
ScalarFunctionWithFunctionHintConflictMethod.class)
.expectErrorMessage(
"Considering all hints, the method should comply with the signature"),
+ // ---
// For function with overloaded function, argument name will be empty
TestSpec.forScalarFunction(
"Scalar function with overloaded functions and arguments hint declared.",
ArgumentsHintScalarFunctionWithOverloadedFunction.class),
+ // ---
TestSpec.forScalarFunction(
"Scalar function with argument type not null but optional.",
ArgumentHintNotNullTypeWithOptionalsScalarFunction.class)
.expectErrorMessage(
"Argument at position 0 is optional but its type doesn't accept null value."),
+ // ---
TestSpec.forScalarFunction(
"Scalar function with arguments hint and variable length args",
ArgumentHintVariableLengthScalarFunction.class)
@@ -630,29 +587,143 @@ private static Stream functionSpecs() {
InputTypeStrategies.explicit(DataTypes.INT())
}),
TypeStrategies.explicit(DataTypes.STRING())),
- TestSpec.forScalarFunction(FunctionHintTableArgScalarFunction.class)
- .expectErrorMessage("Only scalar arguments are supported so far."),
- TestSpec.forScalarFunction(ArgumentHintTableArgScalarFunction.class)
- .expectErrorMessage("Only scalar arguments are supported so far."),
- TestSpec.forScalarFunction(StateHintScalarFunction.class)
- .expectErrorMessage("State hints are not supported yet."));
+ // ---
+ TestSpec.forProcessTableFunction(StatelessProcessTableFunction.class)
+ .expectStaticArgument(
+ StaticArgument.scalar(
+ "i", DataTypes.INT().notNull().bridgedTo(int.class), false))
+ .expectOutput(TypeStrategies.explicit(DataTypes.INT())),
+ // ---
+ TestSpec.forProcessTableFunction(StateProcessTableFunction.class)
+ .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false))
+ .expectState("s", TypeStrategies.explicit(MyState.TYPE))
+ .expectOutput(TypeStrategies.explicit(DataTypes.INT())),
+ // ---
+ TestSpec.forProcessTableFunction(NamedStateProcessTableFunction.class)
+ .expectStaticArgument(
+ StaticArgument.scalar("myArg", DataTypes.INT(), false))
+ .expectState("myState", TypeStrategies.explicit(MyState.TYPE))
+ .expectOutput(TypeStrategies.explicit(DataTypes.INT())),
+ // ---
+ TestSpec.forProcessTableFunction(MultiStateProcessTableFunction.class)
+ .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false))
+ .expectState("s1", TypeStrategies.explicit(MyFirstState.TYPE))
+ .expectState("s2", TypeStrategies.explicit(MySecondState.TYPE))
+ .expectOutput(TypeStrategies.explicit(DataTypes.INT())),
+ // ---
+ TestSpec.forProcessTableFunction(UntypedTableArgProcessTableFunction.class)
+ .expectStaticArgument(
+ StaticArgument.table(
+ "t",
+ Row.class,
+ false,
+ EnumSet.of(StaticArgumentTrait.TABLE_AS_ROW)))
+ .expectOutput(TypeStrategies.explicit(DataTypes.INT())),
+ // ---
+ TestSpec.forProcessTableFunction(TypedTableArgProcessTableFunction.class)
+ .expectStaticArgument(
+ StaticArgument.table(
+ "t",
+ TypedTableArgProcessTableFunction.Customer.TYPE,
+ false,
+ EnumSet.of(StaticArgumentTrait.TABLE_AS_ROW)))
+ .expectOutput(TypeStrategies.explicit(DataTypes.INT())),
+ // ---
+ TestSpec.forProcessTableFunction(ComplexProcessTableFunction.class)
+ .expectStaticArgument(
+ StaticArgument.table(
+ "setTable",
+ RowData.class,
+ false,
+ EnumSet.of(
+ StaticArgumentTrait.TABLE_AS_SET,
+ StaticArgumentTrait.OPTIONAL_PARTITION_BY)))
+ .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false))
+ .expectStaticArgument(
+ StaticArgument.table(
+ "rowTable",
+ Row.class,
+ true,
+ EnumSet.of(StaticArgumentTrait.TABLE_AS_ROW)))
+ .expectStaticArgument(StaticArgument.scalar("s", DataTypes.STRING(), true))
+ .expectState("s1", TypeStrategies.explicit(MyFirstState.TYPE))
+ .expectState(
+ "other",
+ TypeStrategies.explicit(
+ DataTypes.ROW(DataTypes.FIELD("f", DataTypes.FLOAT()))))
+ .expectOutput(
+ TypeStrategies.explicit(
+ DataTypes.ROW(DataTypes.FIELD("b", DataTypes.BOOLEAN())))),
+ // ---
+ TestSpec.forProcessTableFunction(ComplexProcessTableFunctionWithFunctionHint.class)
+ .expectStaticArgument(
+ StaticArgument.table(
+ "setTable",
+ RowData.class,
+ false,
+ EnumSet.of(
+ StaticArgumentTrait.TABLE_AS_SET,
+ StaticArgumentTrait.OPTIONAL_PARTITION_BY)))
+ .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false))
+ .expectStaticArgument(
+ StaticArgument.table(
+ "rowTable",
+ Row.class,
+ true,
+ EnumSet.of(StaticArgumentTrait.TABLE_AS_ROW)))
+ .expectStaticArgument(StaticArgument.scalar("s", DataTypes.STRING(), true))
+ .expectState("s1", TypeStrategies.explicit(MyFirstState.TYPE))
+ .expectState(
+ "other",
+ TypeStrategies.explicit(
+ DataTypes.ROW(DataTypes.FIELD("f", DataTypes.FLOAT()))))
+ .expectOutput(
+ TypeStrategies.explicit(
+ DataTypes.ROW(DataTypes.FIELD("b", DataTypes.BOOLEAN())))),
+ // ---
+ TestSpec.forProcessTableFunction(WrongStateOrderProcessTableFunction.class)
+ .expectErrorMessage(
+ "Considering all hints, the method should comply with the signature:\n"
+ + "eval(org.apache.flink.table.types.extraction.TypeInferenceExtractorTest.MyFirstState, int)\n"
+ + "Pattern: (? [, ]* [, ]*)"),
+ // ---
+ TestSpec.forProcessTableFunction(MissingStateTypeProcessTableFunction.class)
+ .expectErrorMessage(
+ "Could not extract a data type from 'class java.lang.Object' in parameter 0 of method 'eval'"),
+ // ---
+ TestSpec.forProcessTableFunction(EnrichedExtractionStateProcessTableFunction.class)
+ .expectState("d", TypeStrategies.explicit(DataTypes.DECIMAL(3, 2)))
+ .expectOutput(TypeStrategies.explicit(DataTypes.INT())),
+ // ---
+ TestSpec.forProcessTableFunction(WrongTypedTableProcessTableFunction.class)
+ .expectErrorMessage(
+ "Invalid data type 'INT' for table argument 'i'. "
+ + "Typed table arguments must use a composite type (i.e. row or structured type)."),
+ // ---
+ TestSpec.forProcessTableFunction(WrongArgumentTraitsProcessTableFunction.class)
+ .expectErrorMessage(
+ "Invalid argument traits for argument 'r'. "
+ + "Trait OPTIONAL_PARTITION_BY requires TABLE_AS_SET."),
+ // ---
+ TestSpec.forProcessTableFunction(
+ MixingStaticAndInputGroupProcessTableFunction.class)
+ .expectErrorMessage(
+ "Process table functions require a non-overloaded, non-vararg, and static signature."),
+ // ---
+ TestSpec.forProcessTableFunction(MultiEvalProcessTableFunction.class)
+ .expectErrorMessage(
+ "Process table functions require a non-overloaded, non-vararg, and static signature."));
}
private static Stream procedureSpecs() {
return Stream.of(
// procedure hint defines everything
TestSpec.forProcedure(FullProcedureHint.class)
- .expectNamedArguments("i", "s")
- .expectTypedArguments(DataTypes.INT(), DataTypes.STRING())
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[] {"i", "s"},
- new ArgumentTypeStrategy[] {
- InputTypeStrategies.explicit(DataTypes.INT()),
- InputTypeStrategies.explicit(DataTypes.STRING())
- }),
- TypeStrategies.explicit(DataTypes.BOOLEAN())),
- // procedure hint defines everything
+ .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false))
+ .expectStaticArgument(StaticArgument.scalar("s", DataTypes.STRING(), false))
+ .expectOutput(TypeStrategies.explicit(DataTypes.BOOLEAN())),
+ // ---
+ // procedure hints define everything
TestSpec.forProcedure(FullProcedureHints.class)
.expectOutputMapping(
InputTypeStrategies.sequence(
@@ -662,6 +733,7 @@ private static Stream procedureSpecs() {
InputTypeStrategies.sequence(
InputTypeStrategies.explicit(DataTypes.BIGINT())),
TypeStrategies.explicit(DataTypes.BIGINT())),
+ // ---
// global output hint with local input overloading
TestSpec.forProcedure(GlobalOutputProcedureHint.class)
.expectOutputMapping(
@@ -672,6 +744,7 @@ private static Stream procedureSpecs() {
InputTypeStrategies.sequence(
InputTypeStrategies.explicit(DataTypes.STRING())),
TypeStrategies.explicit(DataTypes.INT())),
+ // ---
// global and local overloading
TestSpec.forProcedure(SplitFullProcedureHints.class)
.expectOutputMapping(
@@ -682,6 +755,7 @@ private static Stream procedureSpecs() {
InputTypeStrategies.sequence(
InputTypeStrategies.explicit(DataTypes.BIGINT())),
TypeStrategies.explicit(DataTypes.BIGINT())),
+ // ---
// varargs and ANY input group
TestSpec.forProcedure(ComplexProcedureHint.class)
.expectOutputMapping(
@@ -693,6 +767,7 @@ private static Stream procedureSpecs() {
InputTypeStrategies.ANY
}),
TypeStrategies.explicit(DataTypes.BOOLEAN())),
+ // ---
// global input hints and local output hints
TestSpec.forProcedure(GlobalInputProcedureHints.class)
.expectOutputMapping(
@@ -703,28 +778,20 @@ private static Stream procedureSpecs() {
InputTypeStrategies.sequence(
InputTypeStrategies.explicit(DataTypes.BIGINT())),
TypeStrategies.explicit(DataTypes.INT())),
+ // ---
// no arguments
TestSpec.forProcedure(ZeroArgProcedure.class)
- .expectNamedArguments()
- .expectTypedArguments()
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[0], new ArgumentTypeStrategy[0]),
- TypeStrategies.explicit(DataTypes.INT())),
+ .expectEmptyStaticArguments()
+ .expectOutput(TypeStrategies.explicit(DataTypes.INT())),
+ // ---
// test primitive arguments extraction
TestSpec.forProcedure(MixedArgProcedure.class)
- .expectNamedArguments("i", "d")
- .expectTypedArguments(
- DataTypes.INT().notNull().bridgedTo(int.class), DataTypes.DOUBLE())
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[] {"i", "d"},
- new ArgumentTypeStrategy[] {
- InputTypeStrategies.explicit(
- DataTypes.INT().notNull().bridgedTo(int.class)),
- InputTypeStrategies.explicit(DataTypes.DOUBLE())
- }),
- TypeStrategies.explicit(DataTypes.INT())),
+ .expectStaticArgument(
+ StaticArgument.scalar(
+ "i", DataTypes.INT().notNull().bridgedTo(int.class), false))
+ .expectStaticArgument(StaticArgument.scalar("d", DataTypes.DOUBLE(), false))
+ .expectOutput(TypeStrategies.explicit(DataTypes.INT())),
+ // ---
// test overloaded arguments extraction
TestSpec.forProcedure(OverloadedProcedure.class)
.expectOutputMapping(
@@ -744,6 +811,7 @@ private static Stream procedureSpecs() {
}),
TypeStrategies.explicit(
DataTypes.BIGINT().notNull().bridgedTo(long.class))),
+ // ---
// test varying arguments extraction
TestSpec.forProcedure(VarArgProcedure.class)
.expectOutputMapping(
@@ -756,6 +824,7 @@ private static Stream procedureSpecs() {
DataTypes.INT().notNull().bridgedTo(int.class))
}),
TypeStrategies.explicit(DataTypes.STRING())),
+ // ---
// test varying arguments extraction with byte
TestSpec.forProcedure(VarArgWithByteProcedure.class)
.expectOutputMapping(
@@ -768,33 +837,25 @@ private static Stream procedureSpecs() {
.bridgedTo(byte.class))
}),
TypeStrategies.explicit(DataTypes.STRING())),
+ // ---
// output hint with input extraction
TestSpec.forProcedure(ExtractWithOutputHintProcedure.class)
- .expectNamedArguments("i")
- .expectTypedArguments(DataTypes.INT())
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[] {"i"},
- new ArgumentTypeStrategy[] {
- InputTypeStrategies.explicit(DataTypes.INT())
- }),
- TypeStrategies.explicit(DataTypes.INT())),
+ .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false))
+ .expectOutput(TypeStrategies.explicit(DataTypes.INT())),
+ // ---
// output extraction with input hints
TestSpec.forProcedure(ExtractWithInputHintProcedure.class)
- .expectNamedArguments("i", "b")
- .expectTypedArguments(DataTypes.INT(), DataTypes.BOOLEAN())
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[] {"i", "b"},
- new ArgumentTypeStrategy[] {
- InputTypeStrategies.explicit(DataTypes.INT()),
- InputTypeStrategies.explicit(DataTypes.BOOLEAN())
- }),
+ .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false))
+ .expectStaticArgument(
+ StaticArgument.scalar("b", DataTypes.BOOLEAN(), false))
+ .expectOutput(
TypeStrategies.explicit(
DataTypes.DOUBLE().notNull().bridgedTo(double.class))),
+ // ---
// named arguments with overloaded function
// expected no named argument for overloaded function
TestSpec.forProcedure(NamedArgumentsProcedure.class),
+ // ---
// procedure function that takes any input
TestSpec.forProcedure(InputGroupProcedure.class)
.expectOutputMapping(
@@ -802,6 +863,7 @@ private static Stream procedureSpecs() {
new String[] {"o"},
new ArgumentTypeStrategy[] {InputTypeStrategies.ANY}),
TypeStrategies.explicit(DataTypes.STRING())),
+ // ---
// procedure function that takes any input as vararg
TestSpec.forProcedure(VarArgInputGroupProcedure.class)
.expectOutputMapping(
@@ -809,6 +871,7 @@ private static Stream procedureSpecs() {
new String[] {"o"},
new ArgumentTypeStrategy[] {InputTypeStrategies.ANY}),
TypeStrategies.explicit(DataTypes.STRING())),
+ // ---
TestSpec.forProcedure(
"Procedure with implicit overloading order", OrderedProcedure.class)
.expectOutputMapping(
@@ -825,6 +888,7 @@ private static Stream procedureSpecs() {
InputTypeStrategies.explicit(DataTypes.BIGINT())
}),
TypeStrategies.explicit(DataTypes.BIGINT())),
+ // ---
TestSpec.forProcedure(
"Procedure with explicit overloading order by class annotations",
OrderedProcedure2.class)
@@ -836,6 +900,7 @@ private static Stream procedureSpecs() {
InputTypeStrategies.sequence(
InputTypeStrategies.explicit(DataTypes.INT())),
TypeStrategies.explicit(DataTypes.INT())),
+ // ---
TestSpec.forProcedure(
"Procedure with explicit overloading order by method annotations",
OrderedProcedure3.class)
@@ -847,181 +912,141 @@ private static Stream procedureSpecs() {
InputTypeStrategies.sequence(
InputTypeStrategies.explicit(DataTypes.INT())),
TypeStrategies.explicit(DataTypes.INT())),
+ // ---
TestSpec.forProcedure(
"A data type hint on the method is used for enriching (not a function output hint)",
DataTypeHintOnProcedure.class)
- .expectNamedArguments()
- .expectTypedArguments()
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[] {}, new ArgumentTypeStrategy[] {}),
+ .expectEmptyStaticArguments()
+ .expectOutput(
TypeStrategies.explicit(
DataTypes.ROW(DataTypes.FIELD("i", DataTypes.INT()))
.bridgedTo(RowData.class))),
+ // ---
// unsupported output overloading
TestSpec.forProcedure(InvalidSingleOutputProcedureHint.class)
.expectErrorMessage(
"Procedure hints that lead to ambiguous results are not allowed."),
+ // ---
// global and local overloading with unsupported output overloading
TestSpec.forProcedure(InvalidFullOutputProcedureHint.class)
.expectErrorMessage(
"Procedure hints with same input definition but different result types are not allowed."),
+ // ---
// ignore argument names during overloading
TestSpec.forProcedure(InvalidFullOutputProcedureWithArgNamesHint.class)
.expectErrorMessage(
"Procedure hints with same input definition but different result types are not allowed."),
+ // ---
// invalid data type hint
TestSpec.forProcedure(IncompleteProcedureHint.class)
- .expectErrorMessage(
- "Data type hint does neither specify a data type nor input group for use as function argument."),
+ .expectErrorMessage("Data type missing for scalar argument at position 1."),
+ // ---
// mismatch between hints and implementation regarding return type
TestSpec.forProcedure(InvalidMethodProcedure.class)
.expectErrorMessage(
"Considering all hints, the method should comply with the signature:\n"
- + "java.lang.String[] call(_, int[])"),
+ + "java.lang.String[] call(_, int[])\n"
+ + "Pattern: ( [, ]*)"),
+ // ---
// no implementation
TestSpec.forProcedure(MissingMethodProcedure.class)
.expectErrorMessage(
"Could not find a publicly accessible method named 'call'."),
+ // ---
TestSpec.forProcedure(
"Named arguments procedure with argument hint on method",
ArgumentHintOnMethodProcedure.class)
- .expectNamedArguments("f1", "f2")
- .expectTypedArguments(DataTypes.STRING(), DataTypes.INT())
- .expectOptionalArguments(true, true)
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[] {"f1", "f2"},
- new ArgumentTypeStrategy[] {
- InputTypeStrategies.explicit(DataTypes.STRING()),
- InputTypeStrategies.explicit(DataTypes.INT())
- }),
+ .expectStaticArgument(StaticArgument.scalar("f1", DataTypes.STRING(), true))
+ .expectStaticArgument(StaticArgument.scalar("f2", DataTypes.INT(), true))
+ .expectOutput(
TypeStrategies.explicit(
DataTypes.INT().notNull().bridgedTo(int.class))),
+ // ---
TestSpec.forProcedure(
"Named arguments procedure with argument hint on class",
ArgumentHintOnClassProcedure.class)
- .expectNamedArguments("f1", "f2")
- .expectTypedArguments(DataTypes.STRING(), DataTypes.INT())
- .expectOptionalArguments(true, true)
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[] {"f1", "f2"},
- new ArgumentTypeStrategy[] {
- InputTypeStrategies.explicit(DataTypes.STRING()),
- InputTypeStrategies.explicit(DataTypes.INT())
- }),
+ .expectStaticArgument(StaticArgument.scalar("f1", DataTypes.STRING(), true))
+ .expectStaticArgument(StaticArgument.scalar("f2", DataTypes.INT(), true))
+ .expectOutput(
TypeStrategies.explicit(
DataTypes.INT().notNull().bridgedTo(int.class))),
+ // ---
TestSpec.forProcedure(
"Named arguments procedure with argument hint on parameter",
ArgumentHintOnParameterProcedure.class)
- .expectNamedArguments("parameter_f1", "parameter_f2")
- .expectTypedArguments(
- DataTypes.STRING(), DataTypes.INT().bridgedTo(int.class))
- .expectOptionalArguments(true, false)
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[] {"parameter_f1", "parameter_f2"},
- new ArgumentTypeStrategy[] {
- InputTypeStrategies.explicit(DataTypes.STRING()),
- InputTypeStrategies.explicit(
- DataTypes.INT().bridgedTo(int.class))
- }),
+ .expectStaticArgument(
+ StaticArgument.scalar("parameter_f1", DataTypes.STRING(), true))
+ .expectStaticArgument(
+ StaticArgument.scalar(
+ "parameter_f2",
+ DataTypes.INT().bridgedTo(int.class),
+ false))
+ .expectOutput(
TypeStrategies.explicit(
DataTypes.INT().notNull().bridgedTo(int.class))),
+ // ---
TestSpec.forProcedure(
"Named arguments procedure with argument hint on method and parameter",
ArgumentHintOnMethodAndParameterProcedure.class)
- .expectNamedArguments("local_f1", "local_f2")
- .expectTypedArguments(DataTypes.STRING(), DataTypes.INT())
- .expectOptionalArguments(true, true)
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[] {"local_f1", "local_f2"},
- new ArgumentTypeStrategy[] {
- InputTypeStrategies.explicit(DataTypes.STRING()),
- InputTypeStrategies.explicit(DataTypes.INT())
- }),
+ .expectStaticArgument(
+ StaticArgument.scalar("local_f1", DataTypes.STRING(), true))
+ .expectStaticArgument(
+ StaticArgument.scalar("local_f2", DataTypes.INT(), true))
+ .expectOutput(
TypeStrategies.explicit(
DataTypes.INT().notNull().bridgedTo(int.class))),
+ // ---
TestSpec.forProcedure(
"Named arguments procedure with argument hint on class and method",
ArgumentHintOnClassAndMethodProcedure.class)
- .expectNamedArguments("global_f1", "global_f2")
- .expectTypedArguments(DataTypes.STRING(), DataTypes.INT())
- .expectOptionalArguments(false, false)
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[] {"global_f1", "global_f2"},
- new ArgumentTypeStrategy[] {
- InputTypeStrategies.explicit(DataTypes.STRING()),
- InputTypeStrategies.explicit(DataTypes.INT())
- }),
+ .expectStaticArgument(
+ StaticArgument.scalar("global_f1", DataTypes.STRING(), false))
+ .expectStaticArgument(
+ StaticArgument.scalar("global_f2", DataTypes.INT(), false))
+ .expectOutput(
TypeStrategies.explicit(
DataTypes.INT().notNull().bridgedTo(int.class))),
+ // ---
TestSpec.forProcedure(
"Named arguments procedure with argument hint on class and method and parameter",
ArgumentHintOnClassAndMethodAndParameterProcedure.class)
- .expectNamedArguments("global_f1", "global_f2")
- .expectTypedArguments(DataTypes.STRING(), DataTypes.INT())
- .expectOptionalArguments(false, false)
- .expectOutputMapping(
- InputTypeStrategies.sequence(
- new String[] {"global_f1", "global_f2"},
- new ArgumentTypeStrategy[] {
- InputTypeStrategies.explicit(DataTypes.STRING()),
- InputTypeStrategies.explicit(DataTypes.INT())
- }),
+ .expectStaticArgument(
+ StaticArgument.scalar("global_f1", DataTypes.STRING(), false))
+ .expectStaticArgument(
+ StaticArgument.scalar("global_f2", DataTypes.INT(), false))
+ .expectOutput(
TypeStrategies.explicit(
DataTypes.INT().notNull().bridgedTo(int.class))),
+ // ---
TestSpec.forProcedure(
"Named arguments procedure with argument hint type not null but optional",
ArgumentHintNotNullWithOptionalProcedure.class)
.expectErrorMessage(
"Argument at position 1 is optional but its type doesn't accept null value."),
+ // ---
TestSpec.forProcedure(
"Named arguments procedure with argument name conflict",
ArgumentHintNameConflictProcedure.class)
.expectErrorMessage(
"Argument name conflict, there are at least two argument names that are the same."),
+ // ---
TestSpec.forProcedure(
"Named arguments procedure with optional type on primitive type",
ArgumentHintOptionalOnPrimitiveParameterConflictProcedure.class)
.expectErrorMessage(
- "Argument at position 1 is optional but a primitive type doesn't accept null value."));
- }
-
- @ParameterizedTest(name = "{index}: {0}")
- @MethodSource("testData")
- void testArgumentNames(TestSpec testSpec) {
- if (testSpec.expectedArgumentNames != null) {
- assertThat(testSpec.typeInferenceExtraction.get().getNamedArguments())
- .isEqualTo(Optional.of(testSpec.expectedArgumentNames));
- } else if (testSpec.expectedErrorMessage == null) {
- assertThat(testSpec.typeInferenceExtraction.get().getNamedArguments())
- .isEqualTo(Optional.empty());
- }
- }
-
- @ParameterizedTest(name = "{index}: {0}")
- @MethodSource("testData")
- void testArgumentOptionals(TestSpec testSpec) {
- if (testSpec.expectedArgumentOptionals != null) {
- assertThat(testSpec.typeInferenceExtraction.get().getOptionalArguments())
- .isEqualTo(Optional.of(testSpec.expectedArgumentOptionals));
- }
+ "Considering all hints, the method should comply with the signature:\n"
+ + "int[] call(_, java.lang.String, java.lang.Integer)"));
}
@ParameterizedTest(name = "{index}: {0}")
@MethodSource("testData")
- void testArgumentTypes(TestSpec testSpec) {
- if (testSpec.expectedArgumentTypes != null) {
- assertThat(testSpec.typeInferenceExtraction.get().getTypedArguments())
- .isEqualTo(Optional.of(testSpec.expectedArgumentTypes));
- } else if (testSpec.expectedErrorMessage == null) {
- assertThat(testSpec.typeInferenceExtraction.get().getTypedArguments())
- .isEqualTo(Optional.empty());
+ void testStaticArguments(TestSpec testSpec) {
+ if (testSpec.expectedStaticArguments != null) {
+ final Optional> staticArguments =
+ testSpec.typeInferenceExtraction.get().getStaticArguments();
+ assertThat(staticArguments).isPresent();
+ assertThat(staticArguments.get())
+ .containsExactlyElementsOf(testSpec.expectedStaticArguments);
}
}
@@ -1039,16 +1064,14 @@ void testInputTypeStrategy(TestSpec testSpec) {
@ParameterizedTest(name = "{index}: {0}")
@MethodSource("testData")
- void testAccumulatorTypeStrategy(TestSpec testSpec) {
- if (!testSpec.expectedAccumulatorStrategies.isEmpty()) {
- assertThat(
- testSpec.typeInferenceExtraction
- .get()
- .getAccumulatorTypeStrategy()
- .isPresent())
- .isEqualTo(true);
- assertThat(testSpec.typeInferenceExtraction.get().getAccumulatorTypeStrategy().get())
- .isEqualTo(TypeStrategies.mapping(testSpec.expectedAccumulatorStrategies));
+ void testStateTypeStrategies(TestSpec testSpec) {
+ if (!testSpec.expectedStateStrategies.isEmpty()) {
+ assertThat(testSpec.typeInferenceExtraction.get().getStateTypeStrategies())
+ .isNotEmpty();
+ assertThat(testSpec.typeInferenceExtraction.get().getStateTypeStrategies())
+ .isEqualTo(testSpec.expectedStateStrategies);
+ } else if (testSpec.expectedErrorMessage == null) {
+ assertThat(testSpec.typeInferenceExtraction.get().getStateTypeStrategies()).isEmpty();
}
}
@@ -1056,8 +1079,13 @@ void testAccumulatorTypeStrategy(TestSpec testSpec) {
@MethodSource("testData")
void testOutputTypeStrategy(TestSpec testSpec) {
if (!testSpec.expectedOutputStrategies.isEmpty()) {
- assertThat(testSpec.typeInferenceExtraction.get().getOutputTypeStrategy())
- .isEqualTo(TypeStrategies.mapping(testSpec.expectedOutputStrategies));
+ if (testSpec.expectedOutputStrategies.size() == 1) {
+ assertThat(testSpec.typeInferenceExtraction.get().getOutputTypeStrategy())
+ .isEqualTo(testSpec.expectedOutputStrategies.values().iterator().next());
+ } else {
+ assertThat(testSpec.typeInferenceExtraction.get().getOutputTypeStrategy())
+ .isEqualTo(TypeStrategies.mapping(testSpec.expectedOutputStrategies));
+ }
}
}
@@ -1086,13 +1114,9 @@ static class TestSpec {
final Supplier typeInferenceExtraction;
- @Nullable List expectedArgumentNames;
-
- @Nullable List expectedArgumentOptionals;
-
- @Nullable List expectedArgumentTypes;
+ @Nullable List expectedStaticArguments;
- Map expectedAccumulatorStrategies;
+ LinkedHashMap expectedStateStrategies;
Map expectedOutputStrategies;
@@ -1101,7 +1125,7 @@ static class TestSpec {
private TestSpec(String description, Supplier typeInferenceExtraction) {
this.description = description;
this.typeInferenceExtraction = typeInferenceExtraction;
- this.expectedAccumulatorStrategies = new LinkedHashMap<>();
+ this.expectedStateStrategies = new LinkedHashMap<>();
this.expectedOutputStrategies = new LinkedHashMap<>();
}
@@ -1161,6 +1185,14 @@ static TestSpec forTableAggregateFunction(
new DataTypeFactoryMock(), function));
}
+ static TestSpec forProcessTableFunction(Class extends ProcessTableFunction>> function) {
+ return new TestSpec(
+ function.getSimpleName(),
+ () ->
+ TypeInferenceExtractor.forProcessTableFunction(
+ new DataTypeFactoryMock(), function));
+ }
+
static TestSpec forProcedure(Class extends Procedure> procedure) {
return forProcedure(null, procedure);
}
@@ -1174,24 +1206,36 @@ static TestSpec forProcedure(
new DataTypeFactoryMock(), procedure));
}
- TestSpec expectNamedArguments(String... expectedArgumentNames) {
- this.expectedArgumentNames = Arrays.asList(expectedArgumentNames);
+ TestSpec expectEmptyStaticArguments() {
+ this.expectedStaticArguments = new ArrayList<>();
+ return this;
+ }
+
+ TestSpec expectStaticArgument(StaticArgument argument) {
+ if (this.expectedStaticArguments == null) {
+ this.expectedStaticArguments = new ArrayList<>();
+ }
+ this.expectedStaticArguments.add(argument);
+ return this;
+ }
+
+ TestSpec expectAccumulator(TypeStrategy typeStrategy) {
+ this.expectedStateStrategies.put("acc", StateTypeStrategyWrapper.of(typeStrategy));
return this;
}
- TestSpec expectOptionalArguments(Boolean... expectedArgumentOptionals) {
- this.expectedArgumentOptionals = Arrays.asList(expectedArgumentOptionals);
+ TestSpec expectState(String name, StateTypeStrategy stateTypeStrategy) {
+ this.expectedStateStrategies.put(name, stateTypeStrategy);
return this;
}
- TestSpec expectTypedArguments(DataType... expectedArgumentTypes) {
- this.expectedArgumentTypes = Arrays.asList(expectedArgumentTypes);
+ TestSpec expectState(String name, TypeStrategy typeStrategy) {
+ this.expectedStateStrategies.put(name, StateTypeStrategyWrapper.of(typeStrategy));
return this;
}
- TestSpec expectAccumulatorMapping(
- InputTypeStrategy validator, TypeStrategy accumulatorStrategy) {
- this.expectedAccumulatorStrategies.put(validator, accumulatorStrategy);
+ TestSpec expectState(LinkedHashMap stateTypeStrategy) {
+ this.expectedStateStrategies.putAll(stateTypeStrategy);
return this;
}
@@ -1200,6 +1244,11 @@ TestSpec expectOutputMapping(InputTypeStrategy validator, TypeStrategy outputStr
return this;
}
+ TestSpec expectOutput(TypeStrategy outputStrategy) {
+ this.expectedOutputStrategies.put(InputTypeStrategies.WILDCARD, outputStrategy);
+ return this;
+ }
+
TestSpec expectErrorMessage(String expectedErrorMessage) {
this.expectedErrorMessage = expectedErrorMessage;
return this;
@@ -1411,6 +1460,39 @@ public Row createAccumulator() {
}
}
+ private static class StateHintAggregateFunction extends AggregateFunction {
+ public void accumulate(
+ @StateHint(name = "myAcc") MyState acc, @ArgumentHint(name = "i") Integer i) {}
+
+ @Override
+ public Integer getValue(MyState accumulator) {
+ return null;
+ }
+
+ @Override
+ public MyState createAccumulator() {
+ return new MyState();
+ }
+ }
+
+ @FunctionHint(
+ state = {@StateHint(name = "myAcc", type = @DataTypeHint(bridgedTo = MyState.class))},
+ arguments = {@ArgumentHint(name = "i", type = @DataTypeHint("INT"))})
+ private static class StateHintInFunctionHintAggregateFunction
+ extends AggregateFunction {
+ public void accumulate(Object acc, Integer i) {}
+
+ @Override
+ public Integer getValue(Object accumulator) {
+ return null;
+ }
+
+ @Override
+ public Object createAccumulator() {
+ return new Object();
+ }
+ }
+
@FunctionHint(output = @DataTypeHint("ROW"))
private static class OutputHintTableFunction extends TableFunction {
public void eval(int i) {
@@ -2130,32 +2212,167 @@ public String eval(String f1, Integer... f2) {
}
}
+ private static class StatelessProcessTableFunction extends ProcessTableFunction {
+ public void eval(int i) {}
+ }
+
+ public static class MyState {
+ static final DataType TYPE =
+ DataTypes.STRUCTURED(
+ MyState.class,
+ DataTypes.FIELD("d", DataTypes.DOUBLE().notNull().bridgedTo(double.class)));
+ public double d;
+ }
+
+ public static class MyFirstState {
+ static final DataType TYPE =
+ DataTypes.STRUCTURED(MyFirstState.class, DataTypes.FIELD("d", DataTypes.DOUBLE()));
+ public Double d;
+ }
+
+ public static class MySecondState {
+ static final DataType TYPE =
+ DataTypes.STRUCTURED(MySecondState.class, DataTypes.FIELD("i", DataTypes.INT()));
+ public Integer i;
+ }
+
+ private static class StateProcessTableFunction extends ProcessTableFunction {
+ public void eval(@StateHint MyState s, Integer i) {}
+ }
+
+ private static class NamedStateProcessTableFunction extends ProcessTableFunction {
+ public void eval(
+ @StateHint(name = "myState") MyState s, @ArgumentHint(name = "myArg") Integer i) {}
+ }
+
+ private static class MultiStateProcessTableFunction extends ProcessTableFunction {
+ public void eval(@StateHint MyFirstState s1, @StateHint MySecondState s2, Integer i) {}
+ }
+
+ private static class UntypedTableArgProcessTableFunction extends ProcessTableFunction {
+ public void eval(@ArgumentHint(ArgumentTrait.TABLE_AS_ROW) Row t) {}
+ }
+
+ private static class TypedTableArgProcessTableFunction extends ProcessTableFunction {
+ public static class Customer {
+ static final DataType TYPE =
+ DataTypes.STRUCTURED(
+ Customer.class,
+ DataTypes.FIELD("age", DataTypes.INT()),
+ DataTypes.FIELD("name", DataTypes.STRING()));
+ public String name;
+ public Integer age;
+ }
+
+ public void eval(@ArgumentHint(ArgumentTrait.TABLE_AS_ROW) Customer t) {}
+ }
+
+ @DataTypeHint("ROW")
+ private static class ComplexProcessTableFunction extends ProcessTableFunction {
+ public void eval(
+ Context context,
+ @StateHint(name = "s1") MyFirstState s1,
+ @StateHint(name = "other", type = @DataTypeHint("ROW")) Row s2,
+ @ArgumentHint(
+ value = {
+ ArgumentTrait.TABLE_AS_SET,
+ ArgumentTrait.OPTIONAL_PARTITION_BY
+ },
+ name = "setTable")
+ RowData t1,
+ @ArgumentHint(name = "i") Integer i,
+ @ArgumentHint(
+ value = {ArgumentTrait.TABLE_AS_ROW},
+ name = "rowTable",
+ isOptional = true)
+ Row t2,
+ @ArgumentHint(isOptional = true, name = "s") String s) {}
+ }
+
@FunctionHint(
+ state = {
+ @StateHint(name = "s1", type = @DataTypeHint(bridgedTo = MyFirstState.class)),
+ @StateHint(name = "other", type = @DataTypeHint("ROW"))
+ },
arguments = {
@ArgumentHint(
- value = ArgumentTrait.TABLE_AS_ROW,
- type = @DataTypeHint("ROW"))
- })
- private static class FunctionHintTableArgScalarFunction extends ScalarFunction {
- public String eval(Row table) {
- return "";
- }
+ name = "setTable",
+ value = {ArgumentTrait.TABLE_AS_SET, ArgumentTrait.OPTIONAL_PARTITION_BY},
+ type = @DataTypeHint(bridgedTo = RowData.class)),
+ @ArgumentHint(name = "i", type = @DataTypeHint("INT")),
+ @ArgumentHint(
+ name = "rowTable",
+ value = {ArgumentTrait.TABLE_AS_ROW},
+ isOptional = true),
+ @ArgumentHint(name = "s", isOptional = true, type = @DataTypeHint("STRING"))
+ },
+ output = @DataTypeHint("ROW"))
+ private static class ComplexProcessTableFunctionWithFunctionHint
+ extends ProcessTableFunction {
+
+ public void eval(
+ Context context,
+ MyFirstState arg0,
+ Row arg1,
+ RowData arg2,
+ Integer arg3,
+ Row arg4,
+ String arg5) {}
}
- private static class ArgumentHintTableArgScalarFunction extends ScalarFunction {
- public String eval(
+ private static class WrongStateOrderProcessTableFunction extends ProcessTableFunction {
+
+ public void eval(int i, @StateHint MyFirstState state) {}
+ }
+
+ private static class MissingStateTypeProcessTableFunction
+ extends ProcessTableFunction {
+
+ public void eval(@StateHint Object state) {}
+ }
+
+ private static class EnrichedExtractionStateProcessTableFunction
+ extends ProcessTableFunction {
+
+ public void eval(
+ @StateHint(
+ type =
+ @DataTypeHint(
+ defaultDecimalPrecision = 3,
+ defaultDecimalScale = 2))
+ BigDecimal d) {}
+ }
+
+ private static class WrongTypedTableProcessTableFunction extends ProcessTableFunction {
+ public void eval(@ArgumentHint(ArgumentTrait.TABLE_AS_SET) Integer i) {}
+ }
+
+ private static class WrongArgumentTraitsProcessTableFunction
+ extends ProcessTableFunction {
+ public void eval(
+ @ArgumentHint({ArgumentTrait.TABLE_AS_ROW, ArgumentTrait.OPTIONAL_PARTITION_BY})
+ Row r) {}
+ }
+
+ private static class MixingStaticAndInputGroupProcessTableFunction
+ extends ProcessTableFunction {
+ public void eval(
+ @ArgumentHint(ArgumentTrait.TABLE_AS_ROW) Row r,
+ @DataTypeHint(inputGroup = InputGroup.ANY) Object o) {}
+ }
+
+ private static class InvalidInputGroupTableArgProcessTableFunction
+ extends ProcessTableFunction {
+ public void eval(
@ArgumentHint(
value = ArgumentTrait.TABLE_AS_ROW,
- type = @DataTypeHint("ROW"))
- Row table) {
- return "";
- }
+ type = @DataTypeHint(inputGroup = InputGroup.ANY))
+ Row r) {}
}
- @FunctionHint(state = @StateHint(name = "state", type = @DataTypeHint("INT")))
- private static class StateHintScalarFunction extends ScalarFunction {
- public String eval() {
- return "";
- }
+ private static class MultiEvalProcessTableFunction extends ProcessTableFunction {
+ public void eval(int i) {}
+
+ public void eval(String i) {}
}
}
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/TypeInferenceOperandChecker.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/TypeInferenceOperandChecker.java
index 3d3301d3c48d68..cb5c6778996051 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/TypeInferenceOperandChecker.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/TypeInferenceOperandChecker.java
@@ -27,6 +27,7 @@
import org.apache.flink.table.types.inference.ArgumentCount;
import org.apache.flink.table.types.inference.CallContext;
import org.apache.flink.table.types.inference.ConstantArgumentCount;
+import org.apache.flink.table.types.inference.StaticArgument;
import org.apache.flink.table.types.inference.TypeInference;
import org.apache.flink.table.types.inference.TypeInferenceUtil;
import org.apache.flink.table.types.logical.LogicalType;
@@ -47,7 +48,6 @@
import org.apache.calcite.sql.validate.SqlValidatorNamespace;
import java.util.List;
-import java.util.Optional;
import java.util.stream.Collectors;
import static org.apache.flink.table.planner.calcite.FlinkTypeFactory.toLogicalType;
@@ -83,8 +83,7 @@ public TypeInferenceOperandChecker(
this.dataTypeFactory = dataTypeFactory;
this.definition = definition;
this.typeInference = typeInference;
- this.countRange =
- new ArgumentCountRange(typeInference.getInputTypeStrategy().getArgumentCount());
+ this.countRange = new ArgumentCountRange(deriveArgumentCount(typeInference));
}
@Override
@@ -105,20 +104,7 @@ public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFail
@Override
public SqlOperandCountRange getOperandCountRange() {
- if (typeInference.getOptionalArguments().isPresent()
- && typeInference.getOptionalArguments().get().stream()
- .anyMatch(Boolean::booleanValue)) {
- int notOptionalCount =
- (int)
- typeInference.getOptionalArguments().get().stream()
- .filter(optional -> !optional)
- .count();
- ArgumentCount argumentCount =
- ConstantArgumentCount.between(notOptionalCount, countRange.getMax());
- return new ArgumentCountRange(argumentCount);
- } else {
- return countRange;
- }
+ return countRange;
}
@Override
@@ -133,22 +119,21 @@ public Consistency getConsistency() {
@Override
public boolean isOptional(int i) {
- Optional> optionalArguments = typeInference.getOptionalArguments();
- if (optionalArguments.isPresent()) {
- return optionalArguments.get().get(i);
- } else {
+ if (typeInference.getStaticArguments().isEmpty()) {
return false;
}
+ final List staticArgs = typeInference.getStaticArguments().get();
+ return staticArgs.get(i).isOptional();
}
@Override
public boolean isFixedParameters() {
- // This method returns true only if optional arguments are declared and at least one
- // optional argument is present.
+ // This method returns true only if at least one optional argument is present.
// Otherwise, it defaults to false, bypassing the parameter check in Calcite.
- return typeInference.getOptionalArguments().isPresent()
- && typeInference.getOptionalArguments().get().stream()
- .anyMatch(Boolean::booleanValue);
+ return typeInference
+ .getStaticArguments()
+ .map(args -> args.stream().anyMatch(StaticArgument::isOptional))
+ .orElse(false);
}
@Override
@@ -239,4 +224,17 @@ private void updateInferredType(SqlValidator validator, SqlNode node, RelDataTyp
namespace.setType(type);
}
}
+
+ private static ArgumentCount deriveArgumentCount(TypeInference typeInference) {
+ final int staticArgs = typeInference.getStaticArguments().map(List::size).orElse(-1);
+ if (staticArgs == -1) {
+ return typeInference.getInputTypeStrategy().getArgumentCount();
+ }
+ final int optionalArgs =
+ typeInference
+ .getStaticArguments()
+ .map(args -> (int) args.stream().filter(StaticArgument::isOptional).count())
+ .orElse(0);
+ return ConstantArgumentCount.between(staticArgs - optionalArgs, staticArgs);
+ }
}
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/operations/PlannerCallProcedureOperation.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/operations/PlannerCallProcedureOperation.java
index 66d6a6074e3522..a88a5e808d707f 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/operations/PlannerCallProcedureOperation.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/operations/PlannerCallProcedureOperation.java
@@ -171,12 +171,13 @@ private Object callProcedure(Procedure procedure, Class>[] inputClz, Object[]
methods.stream()
.filter(
method ->
- ExtractionUtils.isInvokable(method, inputClz)
+ ExtractionUtils.isInvokable(false, method, inputClz)
&& method.getReturnType().isArray()
&& isAssignable(
outputType.getConversionClass(),
method.getReturnType().getComponentType(),
- true))
+ true,
+ false))
.collect(Collectors.toList());
if (callMethods.isEmpty()) {
throw new ValidationException(
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/operations/SqlNodeToCallOperationTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/operations/SqlNodeToCallOperationTest.java
index 0a0369bc3dce90..f17259c9c69f7d 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/operations/SqlNodeToCallOperationTest.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/operations/SqlNodeToCallOperationTest.java
@@ -212,7 +212,7 @@ private static class ProcedureWithNamedArguments implements Procedure {
input = {@DataTypeHint("STRING"), @DataTypeHint("STRING")},
output = @DataTypeHint("INT"),
argumentNames = {"c", "d"})
- public int[] call(ProcedureContext context, String arg3, String arg4) {
+ public java.lang.Integer[] call(ProcedureContext context, String arg3, String arg4) {
return null;
}
}
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/utils/userDefinedScalarFunctions.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/utils/userDefinedScalarFunctions.scala
index 6a1fd981620dfa..2ee92e7cda0c72 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/utils/userDefinedScalarFunctions.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/utils/userDefinedScalarFunctions.scala
@@ -108,7 +108,7 @@ class RichFunc1 extends ScalarFunction {
@FunctionHint(
input = Array(new DataTypeHint("INT")),
output = new DataTypeHint(value = "INT", bridgedTo = classOf[JInt]))
- def eval(index: Int): Int = {
+ def eval(index: JInt): JInt = {
index + added
}
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/utils/UserDefinedFunctionTestUtils.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/utils/UserDefinedFunctionTestUtils.scala
index 72b2678f79a5f6..0556e2c529f0f4 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/utils/UserDefinedFunctionTestUtils.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/utils/UserDefinedFunctionTestUtils.scala
@@ -382,20 +382,15 @@ object UserDefinedFunctionTestUtils {
TestAddWithOpen.aliveCounter.incrementAndGet()
}
- @FunctionHint(
- input = Array(
- new DataTypeHint(value = "BIGINT", bridgedTo = classOf[JLong]),
- new DataTypeHint(value = "BIGINT", bridgedTo = classOf[JLong])),
- output = new DataTypeHint(value = "BIGINT", bridgedTo = classOf[JLong]))
- def eval(a: Long, b: Long): Long = {
+ def eval(a: JLong, b: JLong): JLong = {
if (!isOpened) {
throw new IllegalStateException("Open method is not called.")
}
a + b
}
- def eval(a: Long, b: Int): Long = {
- eval(a, b.asInstanceOf[Long])
+ def eval(a: JLong, b: JInt): JLong = {
+ eval(a, b.toLong)
}
override def close(): Unit = {
@@ -411,13 +406,7 @@ object UserDefinedFunctionTestUtils {
@SerialVersionUID(1L)
object TestMod extends ScalarFunction {
- @FunctionHint(
- input = Array(
- new DataTypeHint(value = "BIGINT", bridgedTo = classOf[JLong]),
- new DataTypeHint(value = "INT", bridgedTo = classOf[JInt])
- ),
- output = new DataTypeHint(value = "BIGINT", bridgedTo = classOf[JLong]))
- def eval(src: Long, mod: Int): Long = {
+ def eval(src: JLong, mod: JInt): JLong = {
src % mod
}
}