[FLINK-37076][table-planner] Support PTFs until ExecNode level
twalthr committed Jan 20, 2025
1 parent 4a5238c commit 41a33df
Showing 22 changed files with 1,396 additions and 274 deletions.
@@ -21,6 +21,7 @@
import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.table.functions.ProcessTableFunction;
import org.apache.flink.table.types.inference.StaticArgumentTrait;
import org.apache.flink.types.RowKind;

/**
* Declares traits for {@link ArgumentHint}. They enable basic validation by the framework.
@@ -78,6 +79,8 @@ public enum ArgumentTrait {
/**
* Defines that a PARTITION BY clause is optional for {@link #TABLE_AS_SET}. By default, it is
* mandatory, as distributing the table by key improves parallel execution.
*
* <p>Note: This trait is only valid for {@link #TABLE_AS_SET} arguments.
*/
OPTIONAL_PARTITION_BY(false, StaticArgumentTrait.OPTIONAL_PARTITION_BY),

@@ -97,8 +100,34 @@ public enum ArgumentTrait {
*
* <p>In case of multiple table arguments, pass-through columns are added according to the
* declaration order in the PTF signature.
*
* <p>Note: This trait is valid for {@link #TABLE_AS_ROW} and {@link #TABLE_AS_SET} arguments.
*/
PASS_COLUMNS_THROUGH(false, StaticArgumentTrait.PASS_COLUMNS_THROUGH),

/**
* Defines that updates are allowed as input to the given table argument. By default, a table
* argument is insert-only and updates will be rejected.
*
* <p>Input tables become updating when subqueries such as aggregations or outer joins force an
* incremental computation. For example, the following query only works if the function is able
* to digest retraction messages:
*
* <pre>
* // Changes +I[1] followed by -U[1], +U[2], -U[2], +U[3] will enter the function
* WITH UpdatingTable AS (
* SELECT COUNT(*) FROM (VALUES 1, 2, 3)
* )
* SELECT * FROM f(tableArg => TABLE UpdatingTable)
* </pre>
*
* <p>If updates should be supported, ensure that the data type of the table argument can
* encode changes. In other words: choose a row type that exposes the
* {@link RowKind} change flag.
*
* <p>Note: This trait is valid for {@link #TABLE_AS_ROW} and {@link #TABLE_AS_SET} arguments.
*/
-PASS_COLUMNS_THROUGH(false, StaticArgumentTrait.PASS_COLUMNS_THROUGH);
+SUPPORT_UPDATES(false, StaticArgumentTrait.SUPPORT_UPDATES);

private final boolean isRoot;
private final StaticArgumentTrait staticTrait;
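To make the new trait concrete, here is a minimal sketch of how a PTF could declare an updating table argument. The @ArgumentHint declaration style follows the Javadoc above, but the concrete class name and the TableFunction-style collect(...) call are assumptions, not part of this commit:

import org.apache.flink.table.annotation.ArgumentHint;
import org.apache.flink.table.annotation.ArgumentTrait;
import org.apache.flink.table.functions.ProcessTableFunction;
import org.apache.flink.types.Row;

// Hypothetical PTF that accepts an updating table argument with set semantics.
// SUPPORT_UPDATES requires a row type so that each record exposes its
// RowKind change flag (+I, -U, +U, -D).
public class DigestChanges extends ProcessTableFunction<String> {
    public void eval(
            @ArgumentHint({ArgumentTrait.TABLE_AS_SET, ArgumentTrait.SUPPORT_UPDATES})
                    Row input) {
        // Row#getKind() carries the change flag of the incoming record.
        collect(input.getKind().shortString() + " -> " + input);
    }
}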
@@ -22,6 +22,7 @@
import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.LogicalTypeRoot;
import org.apache.flink.table.types.logical.NullType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.StructuredType;
@@ -273,29 +274,46 @@ private void checkOptionalType() {
}
}

-void checkTableType() {
+private void checkTableType() {
if (!traits.contains(StaticArgumentTrait.TABLE)) {
return;
}
-if (dataType == null
-&& conversionClass != null
-&& !DUMMY_ROW_TYPE.supportsInputConversion(conversionClass)) {
+checkPolymorphicTableType();
+checkTypedTableType();
+}
+
+private void checkPolymorphicTableType() {
+if (dataType != null || conversionClass == null) {
+return;
+}
+if (!DUMMY_ROW_TYPE.supportsInputConversion(conversionClass)) {
throw new ValidationException(
String.format(
"Invalid conversion class '%s' for argument '%s'. "
+ "Polymorphic, untyped table arguments must use a row class.",
conversionClass.getName(), name));
}
-if (dataType != null) {
-final LogicalType type = dataType.getLogicalType();
-if (traits.contains(StaticArgumentTrait.TABLE)
-&& !LogicalTypeChecks.isCompositeType(type)) {
-throw new ValidationException(
-String.format(
-"Invalid data type '%s' for table argument '%s'. "
-+ "Typed table arguments must use a composite type (i.e. row or structured type).",
-type, name));
-}
-}
+
+private void checkTypedTableType() {
+if (dataType == null) {
+return;
+}
+final LogicalType type = dataType.getLogicalType();
+if (traits.contains(StaticArgumentTrait.TABLE)
+&& !LogicalTypeChecks.isCompositeType(type)) {
+throw new ValidationException(
+String.format(
+"Invalid data type '%s' for table argument '%s'. "
++ "Typed table arguments must use a composite type (i.e. row or structured type).",
+type, name));
+}
+if (is(StaticArgumentTrait.SUPPORT_UPDATES) && !type.is(LogicalTypeRoot.ROW)) {
+throw new ValidationException(
+String.format(
+"Invalid data type '%s' for table argument '%s'. "
++ "Table arguments that support updates must use a row type.",
+type, name));
+}
+}
}
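As a quick illustration of the checks above: any composite type passes checkTypedTableType, but SUPPORT_UPDATES narrows the requirement to the ROW type root, since only rows expose a RowKind change flag. A runnable sketch using the public DataTypes and LogicalTypeChecks utilities:

import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.LogicalTypeRoot;
import org.apache.flink.table.types.logical.utils.LogicalTypeChecks;

public class TableArgTypeRules {
    public static void main(String[] args) {
        LogicalType row = DataTypes.ROW(DataTypes.FIELD("a", DataTypes.INT())).getLogicalType();
        LogicalType atomic = DataTypes.INT().getLogicalType();

        // Typed table arguments must be composite (row or structured type).
        System.out.println(LogicalTypeChecks.isCompositeType(row));    // true
        System.out.println(LogicalTypeChecks.isCompositeType(atomic)); // false -> ValidationException

        // SUPPORT_UPDATES additionally requires the ROW root specifically.
        System.out.println(row.is(LogicalTypeRoot.ROW)); // true
    }
}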
@@ -37,8 +37,9 @@ public enum StaticArgumentTrait {
MODEL(),
TABLE_AS_ROW(TABLE),
TABLE_AS_SET(TABLE),
-OPTIONAL_PARTITION_BY(TABLE_AS_SET),
-PASS_COLUMNS_THROUGH(TABLE);
+PASS_COLUMNS_THROUGH(TABLE),
+SUPPORT_UPDATES(TABLE),
+OPTIONAL_PARTITION_BY(TABLE_AS_SET);

private final Set<StaticArgumentTrait> requirements;

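The constructor argument of each constant above is the set of traits it requires: SUPPORT_UPDATES and PASS_COLUMNS_THROUGH only make sense on table arguments, OPTIONAL_PARTITION_BY only on set semantics. A hedged sketch of how such a requirement set could be enforced — the getRequirements() accessor is an assumption, only the requirements field is visible in this diff:

import java.util.Set;
import org.apache.flink.table.types.inference.StaticArgumentTrait;

public class TraitValidation {
    // Sketch only: rejects a declaration that names a trait without the
    // traits it was constructed to require.
    static void validateTraits(Set<StaticArgumentTrait> declared) {
        for (StaticArgumentTrait trait : declared) {
            for (StaticArgumentTrait required : trait.getRequirements()) {
                if (!declared.contains(required)) {
                    throw new IllegalArgumentException(
                            trait + " requires " + required + " to be declared as well");
                }
            }
        }
    }
}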
@@ -81,6 +81,10 @@ public static TypeInference of(FunctionKind functionKind, TypeInference origin)
return builder.build();
}

public static boolean isValidUidForProcessTableFunction(String uid) {
return UID_FORMAT.test(uid);
}

// --------------------------------------------------------------------------------------------

private static void checkScalarArgsOnly(List<StaticArgument> defaultArgs) {
@@ -283,7 +287,7 @@ public Optional<List<DataType>> inferInputTypes(
+ "that is not overloaded and doesn't contain varargs.");
}

-checkUidColumn(callContext);
+checkUidArg(callContext);
checkMultipleTableArgs(callContext);
checkTableArgTraits(staticArgs, callContext);

@@ -297,16 +301,16 @@ public List<Signature> getExpectedSignatures(FunctionDefinition definition) {
return origin.getExpectedSignatures(definition);
}

-private static void checkUidColumn(CallContext callContext) {
+private static void checkUidArg(CallContext callContext) {
final List<DataType> args = callContext.getArgumentDataTypes();

// Verify the uid format if provided
int uidPos = args.size() - 1;
if (!callContext.isArgumentNull(uidPos)) {
final String uid = callContext.getArgumentValue(uidPos, String.class).orElse("");
-if (!UID_FORMAT.test(uid)) {
+if (!isValidUidForProcessTableFunction(uid)) {
throw new ValidationException(
"Invalid unique identifier for process table function. The 'uid' argument "
"Invalid unique identifier for process table function. The `uid` argument "
+ "must be a string literal that follows the pattern [a-zA-Z_][a-zA-Z-_0-9]*. "
+ "But found: "
+ uid);
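The uid pattern named in the error message can be reproduced with a plain java.util.regex.Pattern. This sketch assumes UID_FORMAT is an equivalent string predicate:

import java.util.regex.Pattern;

public class UidCheck {
    // [a-zA-Z_][a-zA-Z-_0-9]*: a letter or underscore, followed by letters,
    // digits, hyphens, or underscores.
    private static final Pattern UID_PATTERN = Pattern.compile("[a-zA-Z_][a-zA-Z-_0-9]*");

    public static boolean isValidUid(String uid) {
        return UID_PATTERN.matcher(uid).matches();
    }

    public static void main(String[] args) {
        System.out.println(isValidUid("my-id")); // true
        System.out.println(isValidUid("my id")); // false: space not allowed
        System.out.println(isValidUid("1st"));   // false: must not start with a digit
    }
}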
@@ -28,6 +28,7 @@

import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

/**
@@ -96,4 +97,29 @@ public RexCall clone(RelDataType type, List<RexNode> operands) {
public RexTableArgCall copy(RelDataType type, int[] partitionKeys, int[] orderKeys) {
return new RexTableArgCall(type, inputIndex, partitionKeys, orderKeys);
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
if (!super.equals(o)) {
return false;
}
final RexTableArgCall that = (RexTableArgCall) o;
return inputIndex == that.inputIndex
&& Arrays.equals(partitionKeys, that.partitionKeys)
&& Arrays.equals(orderKeys, that.orderKeys);
}

@Override
public int hashCode() {
int result = Objects.hash(super.hashCode(), inputIndex);
result = 31 * result + Arrays.hashCode(partitionKeys);
result = 31 * result + Arrays.hashCode(orderKeys);
return result;
}
}
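A short note on the equals/hashCode pair above: the key arrays must go through Arrays.equals and Arrays.hashCode because Java arrays inherit identity semantics from Object — Objects.hash on an int[] would hash the reference, not the contents:

import java.util.Arrays;
import java.util.Objects;

public class ArrayEqualityDemo {
    public static void main(String[] args) {
        int[] a = {1, 2};
        int[] b = {1, 2};
        System.out.println(Objects.equals(a, b)); // false: reference identity
        System.out.println(Arrays.equals(a, b));  // true: element-wise
        System.out.println(Arrays.hashCode(a) == Arrays.hashCode(b)); // true
    }
}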
@@ -19,15 +19,19 @@
package org.apache.flink.table.planner.plan.nodes.exec;

import org.apache.flink.table.api.TableException;
import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.planner.plan.nodes.common.CommonIntermediateTableScan;
import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecProcessTableFunction;
import org.apache.flink.table.planner.plan.nodes.physical.FlinkPhysicalRel;

import org.apache.calcite.rel.RelNode;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
* A generator that generates a {@link ExecNode} graph from a graph of {@link FlinkPhysicalRel}s.
@@ -43,9 +47,11 @@
public class ExecNodeGraphGenerator {

private final Map<FlinkPhysicalRel, ExecNode<?>> visitedRels;
private final Set<String> visitedProcessTableFunctionUids;

public ExecNodeGraphGenerator() {
this.visitedRels = new IdentityHashMap<>();
this.visitedProcessTableFunctionUids = new HashSet<>();
}

public ExecNodeGraph generate(List<FlinkPhysicalRel> relNodes, boolean isCompiled) {
@@ -78,8 +84,25 @@ private ExecNode<?> generate(FlinkPhysicalRel rel, boolean isCompiled) {
inputEdges.add(ExecEdge.builder().source(inputNode).target(execNode).build());
}
execNode.setInputEdges(inputEdges);

checkUidForProcessTableFunction(execNode);
visitedRels.put(rel, execNode);
return execNode;
}

private void checkUidForProcessTableFunction(ExecNode<?> execNode) {
if (!(execNode instanceof StreamExecProcessTableFunction)) {
return;
}
final String uid = ((StreamExecProcessTableFunction) execNode).getUid();
if (visitedProcessTableFunctionUids.contains(uid)) {
throw new ValidationException(
String.format(
"Duplicate unique identifier '%s' detected among process table functions. "
+ "Make sure that all PTF calls have an identifier defined that is globally unique. "
+ "Please provide a custom identifier using the implicit `uid` argument. "
+ "For example: myFunction(..., uid => 'my-id')",
uid));
}
visitedProcessTableFunctionUids.add(uid);
}
}
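In SQL terms, the check above rejects plans in which two PTF calls carry the same identifier; the remedy named in the error message is the implicit uid argument. A sketch with hypothetical table and function names:

import org.apache.flink.table.api.EnvironmentSettings;
import org.apache.flink.table.api.TableEnvironment;

public class UidExample {
    public static void main(String[] args) {
        TableEnvironment env =
                TableEnvironment.create(EnvironmentSettings.inStreamingMode());

        // Two calls to the same PTF need globally unique identifiers; without
        // explicit uids the translation above throws a ValidationException.
        env.executeSql("SELECT * FROM f(tableArg => TABLE t, uid => 'f-1')");
        env.executeSql("SELECT * FROM f(tableArg => TABLE t, uid => 'f-2')");
    }
}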
@@ -27,6 +27,7 @@
import org.apache.flink.table.functions.FunctionIdentifier;
import org.apache.flink.table.functions.UserDefinedFunction;
import org.apache.flink.table.functions.UserDefinedFunctionHelper;
import org.apache.flink.table.planner.calcite.RexTableArgCall;
import org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction;
import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction;
import org.apache.flink.table.planner.functions.sql.BuiltInSqlOperator;
@@ -93,6 +94,8 @@
import static org.apache.flink.table.planner.plan.nodes.exec.serde.RexNodeJsonSerializer.FIELD_NAME_NAME;
import static org.apache.flink.table.planner.plan.nodes.exec.serde.RexNodeJsonSerializer.FIELD_NAME_NULL_AS;
import static org.apache.flink.table.planner.plan.nodes.exec.serde.RexNodeJsonSerializer.FIELD_NAME_OPERANDS;
import static org.apache.flink.table.planner.plan.nodes.exec.serde.RexNodeJsonSerializer.FIELD_NAME_ORDER_KEYS;
import static org.apache.flink.table.planner.plan.nodes.exec.serde.RexNodeJsonSerializer.FIELD_NAME_PARTITION_KEYS;
import static org.apache.flink.table.planner.plan.nodes.exec.serde.RexNodeJsonSerializer.FIELD_NAME_RANGES;
import static org.apache.flink.table.planner.plan.nodes.exec.serde.RexNodeJsonSerializer.FIELD_NAME_SARG;
import static org.apache.flink.table.planner.plan.nodes.exec.serde.RexNodeJsonSerializer.FIELD_NAME_SQL_KIND;
@@ -107,6 +110,7 @@
import static org.apache.flink.table.planner.plan.nodes.exec.serde.RexNodeJsonSerializer.KIND_INPUT_REF;
import static org.apache.flink.table.planner.plan.nodes.exec.serde.RexNodeJsonSerializer.KIND_LITERAL;
import static org.apache.flink.table.planner.plan.nodes.exec.serde.RexNodeJsonSerializer.KIND_PATTERN_INPUT_REF;
import static org.apache.flink.table.planner.plan.nodes.exec.serde.RexNodeJsonSerializer.KIND_TABLE_ARG_CALL;
import static org.apache.flink.table.planner.typeutils.SymbolUtil.serializableToCalcite;

/**
@@ -144,6 +148,8 @@ private static RexNode deserialize(JsonNode jsonNode, SerdeContext serdeContext)
return deserializeCorrelVariable(jsonNode, serdeContext);
case KIND_PATTERN_INPUT_REF:
return deserializePatternFieldRef(jsonNode, serdeContext);
case KIND_TABLE_ARG_CALL:
return deserializeTableArgCall(jsonNode, serdeContext);
case KIND_CALL:
return deserializeCall(jsonNode, serdeContext);
default:
@@ -313,6 +319,28 @@ private static RexNode deserializePatternFieldRef(
return serdeContext.getRexBuilder().makePatternFieldRef(alpha, fieldType, inputIndex);
}

private static RexNode deserializeTableArgCall(JsonNode jsonNode, SerdeContext serdeContext) {
final JsonNode logicalTypeNode = jsonNode.required(FIELD_NAME_TYPE);
final RelDataType callType =
RelDataTypeJsonDeserializer.deserialize(logicalTypeNode, serdeContext);

final int inputIndex = jsonNode.required(FIELD_NAME_INPUT_INDEX).intValue();

final JsonNode partitionKeysNode = jsonNode.required(FIELD_NAME_PARTITION_KEYS);
final int[] partitionKeys = new int[partitionKeysNode.size()];
for (int i = 0; i < partitionKeysNode.size(); ++i) {
partitionKeys[i] = partitionKeysNode.get(i).asInt();
}

final JsonNode orderKeysNode = jsonNode.required(FIELD_NAME_ORDER_KEYS);
final int[] orderKeys = new int[orderKeysNode.size()];
for (int i = 0; i < orderKeysNode.size(); ++i) {
orderKeys[i] = orderKeysNode.get(i).asInt();
}

return new RexTableArgCall(callType, inputIndex, partitionKeys, orderKeys);
}

private static RexNode deserializeCall(JsonNode jsonNode, SerdeContext serdeContext)
throws IOException {
final SqlOperator operator = deserializeSqlOperator(jsonNode, serdeContext);
@@ -33,6 +33,7 @@
import org.apache.flink.table.functions.TableAggregateFunctionDefinition;
import org.apache.flink.table.functions.TableFunctionDefinition;
import org.apache.flink.table.functions.UserDefinedFunction;
import org.apache.flink.table.planner.calcite.RexTableArgCall;
import org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction;
import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction;
import org.apache.flink.table.planner.functions.sql.BuiltInSqlOperator;
@@ -81,10 +82,10 @@ final class RexNodeJsonSerializer extends StdSerializer<RexNode> {
static final String FIELD_NAME_VALUE = "value";
static final String FIELD_NAME_TYPE = "type";
static final String FIELD_NAME_NAME = "name";
+static final String FIELD_NAME_INPUT_INDEX = "inputIndex";

// INPUT_REF
static final String KIND_INPUT_REF = "INPUT_REF";
-static final String FIELD_NAME_INPUT_INDEX = "inputIndex";

// LITERAL
static final String KIND_LITERAL = "LITERAL";
@@ -122,6 +123,11 @@ final class RexNodeJsonSerializer extends StdSerializer<RexNode> {
static final String FIELD_NAME_SQL_KIND = "sqlKind";
static final String FIELD_NAME_CLASS = "class";

// TABLE_ARG_CALL
static final String KIND_TABLE_ARG_CALL = "TABLE_ARG_CALL";
static final String FIELD_NAME_PARTITION_KEYS = "partitionKeys";
static final String FIELD_NAME_ORDER_KEYS = "orderKeys";

RexNodeJsonSerializer() {
super(RexNode.class);
}
Expand Down Expand Up @@ -154,7 +160,10 @@ public void serialize(
(RexPatternFieldRef) rexNode, jsonGenerator, serializerProvider);
break;
default:
-if (rexNode instanceof RexCall) {
+if (rexNode instanceof RexTableArgCall) {
+serializeTableArgCall(
+(RexTableArgCall) rexNode, jsonGenerator, serializerProvider);
+} else if (rexNode instanceof RexCall) {
serializeCall(
(RexCall) rexNode,
jsonGenerator,
@@ -323,6 +332,20 @@ private static void serializeCorrelVariable(
gen.writeEndObject();
}

private static void serializeTableArgCall(
RexTableArgCall tableArgCall, JsonGenerator gen, SerializerProvider serializerProvider)
throws IOException {
gen.writeStartObject();
gen.writeStringField(FIELD_NAME_KIND, KIND_TABLE_ARG_CALL);
gen.writeNumberField(FIELD_NAME_INPUT_INDEX, tableArgCall.getInputIndex());
gen.writeFieldName(FIELD_NAME_PARTITION_KEYS);
gen.writeArray(tableArgCall.getPartitionKeys(), 0, tableArgCall.getPartitionKeys().length);
gen.writeFieldName(FIELD_NAME_ORDER_KEYS);
gen.writeArray(tableArgCall.getOrderKeys(), 0, tableArgCall.getOrderKeys().length);
serializerProvider.defaultSerializeField(FIELD_NAME_TYPE, tableArgCall.getType(), gen);
gen.writeEndObject();
}

private static void serializeCall(
RexCall call,
JsonGenerator gen,
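Putting serializer and deserializer together, a RexTableArgCall round-trips through a JSON payload shaped roughly as follows. The field names match the constants above; the type payload, produced by the RelDataType serde, is abbreviated:

// Approximate JSON for a table argument call over input 0, partitioned by
// column 0, with no order keys.
String tableArgCallJson =
        "{"
                + "\"kind\": \"TABLE_ARG_CALL\","
                + "\"inputIndex\": 0,"
                + "\"partitionKeys\": [0],"
                + "\"orderKeys\": [],"
                + "\"type\": {...}" // serialized RelDataType (abbreviated)
                + "}";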