Skip to content

Commit 1b1dc73

Browse files
authored
fix: prevents exception on construction of FunctionConverter with duplicate functions (#564)
* test: introduce test to capture bug #562 This commit introduces a test to capture the bug where name-colliding extension functions with different URNs cause an exception to be thrown on `FunctionConverter` construction. * refactor: rename alm to nameToFnMap in FunctionConverter This is a small change to make the meaning of a variable clearer in the `FunctionConverter.java` implementation. * fix: allow multiple same-named functions in FunctionConverter * test: that order added determines precedence in FunctionConverter * test: simplify tests and add special test for overlapping default ext Added a test for ltrim to ensure there is no issue specifically with the default extension collection functions (which have some special handling in isthmus) and an introduced ltrim function. * fix: address @nielspardon comments on PR #564 This commit: - improves the documentation around the `FunctionConverter.java` class - fixes some minor stylistic comments around the tests * docs: wrap mentions of SqlOperator in @link
1 parent e6aa766 commit 1b1dc73

File tree

4 files changed

+309
-9
lines changed

4 files changed

+309
-9
lines changed

isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java

Lines changed: 86 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
import com.google.common.collect.ArrayListMultimap;
44
import com.google.common.collect.ImmutableList;
5-
import com.google.common.collect.ImmutableMap;
5+
import com.google.common.collect.ImmutableListMultimap;
6+
import com.google.common.collect.ListMultimap;
67
import com.google.common.collect.Multimap;
78
import com.google.common.collect.Multimaps;
89
import com.google.common.collect.Streams;
@@ -45,6 +46,27 @@
4546
import org.slf4j.Logger;
4647
import org.slf4j.LoggerFactory;
4748

49+
/**
50+
* Abstract base class for converting between Calcite {@link SqlOperator}s and Substrait function
51+
* invocations.
52+
*
53+
* <p>This class handles bidirectional conversion:
54+
*
55+
* <ul>
56+
* <li><b>Calcite → Substrait:</b> Subclasses implement {@code convert()} methods to convert
57+
* Calcite calls to Substrait function invocations
58+
* <li><b>Substrait → Calcite:</b> {@link #getSqlOperatorFromSubstraitFunc} converts Substrait
59+
* function keys to Calcite {@link SqlOperator}s
60+
* </ul>
61+
*
62+
* <p>When multiple functions with the same name and signature are passed into the constructor, a
63+
* <b>last-wins precedence strategy</b> is used for resolution. The last function in the input list
64+
* takes precedence during Calcite to Substrait conversion.
65+
*
66+
* @param <F> the function type (ScalarFunctionVariant, AggregateFunctionVariant, etc.)
67+
* @param <T> the return type for Calcite→Substrait conversion
68+
* @param <C> the call type being converted
69+
*/
4870
public abstract class FunctionConverter<
4971
F extends SimpleExtension.Function, T, C extends FunctionConverter.GenericCall> {
5072

@@ -57,10 +79,32 @@ public abstract class FunctionConverter<
5779

5880
protected final Multimap<String, SqlOperator> substraitFuncKeyToSqlOperatorMap;
5981

82+
/**
83+
* Creates a FunctionConverter with the given functions.
84+
*
85+
* <p>If there are multiple functions provided with the same name and signature (e.g., from
86+
* different extension URNs), the last one in the list will be given precedence during Calcite to
87+
* Substrait conversion.
88+
*
89+
* @param functions the list of function variants to register
90+
* @param typeFactory the Calcite type factory
91+
*/
6092
public FunctionConverter(List<F> functions, RelDataTypeFactory typeFactory) {
6193
this(functions, Collections.EMPTY_LIST, typeFactory, TypeConverter.DEFAULT);
6294
}
6395

96+
/**
97+
* Creates a FunctionConverter with the given functions and additional signatures.
98+
*
99+
* <p>If there are multiple functions provided with the same name and signature (e.g., from
100+
* different extension URNs), the last one in the list will be given precedence during Calcite to
101+
* Substrait conversion.
102+
*
103+
* @param functions the list of function variants to register
104+
* @param additionalSignatures additional Calcite operator signatures to map
105+
* @param typeFactory the Calcite type factory
106+
* @param typeConverter the type converter to use
107+
*/
64108
public FunctionConverter(
65109
List<F> functions,
66110
List<FunctionMappings.Sig> additionalSignatures,
@@ -75,9 +119,9 @@ public FunctionConverter(
75119
this.typeFactory = typeFactory;
76120
this.substraitFuncKeyToSqlOperatorMap = ArrayListMultimap.create();
77121

78-
ArrayListMultimap<String, F> alm = ArrayListMultimap.<String, F>create();
122+
ArrayListMultimap<String, F> nameToFn = ArrayListMultimap.<String, F>create();
79123
for (F f : functions) {
80-
alm.put(f.name().toLowerCase(Locale.ROOT), f);
124+
nameToFn.put(f.name().toLowerCase(Locale.ROOT), f);
81125
}
82126

83127
Multimap<String, FunctionMappings.Sig> calciteOperators =
@@ -87,21 +131,21 @@ public FunctionConverter(
87131
FunctionMappings.Sig::name, Function.identity(), ArrayListMultimap::create));
88132
IdentityHashMap<SqlOperator, FunctionFinder> matcherMap =
89133
new IdentityHashMap<SqlOperator, FunctionFinder>();
90-
for (String key : alm.keySet()) {
134+
for (String key : nameToFn.keySet()) {
91135
Collection<Sig> sigs = calciteOperators.get(key);
92136
if (sigs.isEmpty()) {
93137
LOGGER.atDebug().log("No binding for function: {}", key);
94138
}
95139

96140
for (Sig sig : sigs) {
97-
List<F> implList = alm.get(key);
141+
List<F> implList = nameToFn.get(key);
98142
if (!implList.isEmpty()) {
99143
matcherMap.put(sig.operator(), new FunctionFinder(key, sig.operator(), implList));
100144
}
101145
}
102146
}
103147

104-
for (Entry<String, F> entry : alm.entries()) {
148+
for (Entry<String, F> entry : nameToFn.entries()) {
105149
String key = entry.getKey();
106150
F func = entry.getValue();
107151
for (FunctionMappings.Sig sig : calciteOperators.get(key)) {
@@ -112,6 +156,17 @@ public FunctionConverter(
112156
this.signatures = matcherMap;
113157
}
114158

159+
/**
160+
* Converts a Substrait function to a Calcite {@link SqlOperator} (Substrait → Calcite direction).
161+
*
162+
* <p>Given a Substrait function key (e.g., "concat:str_str") and output type, this method finds
163+
* the corresponding Calcite {@link SqlOperator}. When multiple operators match, the output type
164+
* is used to disambiguate.
165+
*
166+
* @param key the Substrait function key (function name with type signature)
167+
* @param outputType the expected output type
168+
* @return the matching {@link SqlOperator}, or empty if no match found
169+
*/
115170
public Optional<SqlOperator> getSqlOperatorFromSubstraitFunc(String key, Type outputType) {
116171
Map<SqlOperator, TypeBasedResolver> resolver = getTypeBasedResolver();
117172
Collection<SqlOperator> operators = substraitFuncKeyToSqlOperatorMap.get(key);
@@ -155,7 +210,7 @@ protected class FunctionFinder {
155210
private final String substraitName;
156211
private final SqlOperator operator;
157212
private final List<F> functions;
158-
private final Map<String, F> directMap;
213+
private final ListMultimap<String, F> directMap;
159214
private final Optional<SingularArgumentMatcher<F>> singularInputType;
160215
private final Util.IntRange argRange;
161216

@@ -168,7 +223,7 @@ public FunctionFinder(String substraitName, SqlOperator operator, List<F> functi
168223
functions.stream().mapToInt(t -> t.getRange().getStartInclusive()).min().getAsInt(),
169224
functions.stream().mapToInt(t -> t.getRange().getEndExclusive()).max().getAsInt());
170225
this.singularInputType = getSingularInputType(functions);
171-
ImmutableMap.Builder<String, F> directMap = ImmutableMap.builder();
226+
ImmutableListMultimap.Builder<String, F> directMap = ImmutableListMultimap.builder();
172227
for (F func : functions) {
173228
String key = func.key();
174229
directMap.put(key, func);
@@ -342,13 +397,29 @@ private Stream<String> matchKeys(List<RexNode> rexOperands, List<String> opTypes
342397
}
343398
}
344399

400+
/**
401+
* Converts a Calcite call to a Substrait function invocation (Calcite → Substrait direction).
402+
*
403+
* <p>This method tries to find a matching Substrait function for the given Calcite call using
404+
* direct signature matching, type coercion, and least-restrictive type resolution.
405+
*
406+
* <p>If multiple registered function extensions have the same name and signature, the last one
407+
* in the list passed into the constructor will be matched.
408+
*
409+
* @param call the Calcite call to match
410+
* @param topLevelConverter function to convert RexNode operands to Substrait Expressions
411+
* @return the matched Substrait function binding, or empty if no match found
412+
*/
345413
public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelConverter) {
346414

347415
/*
348416
* Here the RexLiteral with an Enum value is mapped to String Literal.
349417
* Not enough context here to construct a substrait EnumArg.
350418
* Once a FunctionVariant is resolved we can map the String Literal
351419
* to a EnumArg.
420+
*
421+
* Note that if there are multiple registered function extensions which can match a particular Call,
422+
* the last one added to the extension collection will be matched.
352423
*/
353424
List<RexNode> operandsList = call.getOperands().collect(Collectors.toList());
354425
List<Expression> operands =
@@ -369,7 +440,13 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
369440
.findFirst();
370441

371442
if (directMatchKey.isPresent()) {
372-
F variant = directMap.get(directMatchKey.get());
443+
List<F> variants = directMap.get(directMatchKey.get());
444+
if (variants.isEmpty()) {
445+
446+
return Optional.empty();
447+
}
448+
449+
F variant = variants.get(variants.size() - 1);
373450
variant.validateOutputType(operands, outputType);
374451
List<FunctionArg> funcArgs =
375452
IntStream.range(0, operandsList.size())
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
package io.substrait.isthmus;
2+
3+
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
4+
import static org.junit.jupiter.api.Assertions.assertEquals;
5+
6+
import io.substrait.expression.Expression;
7+
import io.substrait.extension.SimpleExtension;
8+
import io.substrait.isthmus.expression.AggregateFunctionConverter;
9+
import io.substrait.isthmus.expression.ScalarFunctionConverter;
10+
import io.substrait.isthmus.expression.WindowFunctionConverter;
11+
import java.io.IOException;
12+
import java.io.UncheckedIOException;
13+
import java.util.List;
14+
import java.util.Optional;
15+
import org.apache.calcite.rex.RexBuilder;
16+
import org.apache.calcite.rex.RexCall;
17+
import org.apache.calcite.rex.RexNode;
18+
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
19+
import org.apache.calcite.sql.fun.SqlTrimFunction.Flag;
20+
import org.junit.jupiter.api.Test;
21+
22+
/** Tests to reproduce #562 */
23+
class DuplicateFunctionUrnTest extends PlanTestBase {
24+
25+
static final SimpleExtension.ExtensionCollection collection1;
26+
static final SimpleExtension.ExtensionCollection collection2;
27+
static final SimpleExtension.ExtensionCollection collection;
28+
29+
static {
30+
try {
31+
String extensions1 = asString("extensions/functions_duplicate_urn1.yaml");
32+
String extensions2 = asString("extensions/functions_duplicate_urn2.yaml");
33+
collection1 =
34+
SimpleExtension.load("urn:extension:io.substrait:functions_string", extensions1);
35+
collection2 = SimpleExtension.load("urn:extension:com.domain:string", extensions2);
36+
collection = collection1.merge(collection2);
37+
38+
// Verify that the merged collection contains duplicate concat functions with different URNs
39+
// This is a precondition for the tests - if this fails, the tests don't make sense
40+
List<SimpleExtension.ScalarFunctionVariant> concatFunctions =
41+
collection.scalarFunctions().stream().filter(f -> f.name().equals("concat")).toList();
42+
43+
if (concatFunctions.size() != 2) {
44+
throw new IllegalStateException(
45+
"Expected 2 concat functions in merged collection, but found: "
46+
+ concatFunctions.size());
47+
}
48+
49+
String urn1 = concatFunctions.get(0).getAnchor().urn();
50+
String urn2 = concatFunctions.get(1).getAnchor().urn();
51+
if (urn1.equals(urn2)) {
52+
throw new IllegalStateException(
53+
"Expected different URNs for the two concat functions, but both were: " + urn1);
54+
}
55+
} catch (IOException e) {
56+
throw new UncheckedIOException(e);
57+
}
58+
}
59+
60+
@Test
61+
void testDuplicateFunctionWithDifferentUrns() {
62+
assertDoesNotThrow(
63+
() -> new ScalarFunctionConverter(collection.scalarFunctions(), typeFactory));
64+
}
65+
66+
@Test
67+
void testDuplicateAggregateFunctionWithDifferentUrns() {
68+
assertDoesNotThrow(
69+
() -> new AggregateFunctionConverter(collection.aggregateFunctions(), typeFactory));
70+
}
71+
72+
@Test
73+
void testDuplicateWindowFunctionWithDifferentUrns() {
74+
assertDoesNotThrow(
75+
() -> new WindowFunctionConverter(collection.windowFunctions(), typeFactory));
76+
}
77+
78+
@Test
79+
void testMergeOrderDeterminesFunctionPrecedence() {
80+
// This test verifies that when multiple extension collections contain functions with
81+
// the same name and signature but different URNs, the merge order determines precedence.
82+
// The FunctionConverter uses a "last-wins" strategy: the last function added to the
83+
// extension collection will be matched when converting from Calcite to Substrait.
84+
85+
SimpleExtension.ExtensionCollection reverseCollection = collection2.merge(collection1);
86+
ScalarFunctionConverter converterA =
87+
new ScalarFunctionConverter(collection.scalarFunctions(), typeFactory);
88+
ScalarFunctionConverter converterB =
89+
new ScalarFunctionConverter(reverseCollection.scalarFunctions(), typeFactory);
90+
91+
RexBuilder rexBuilder = new RexBuilder(typeFactory);
92+
RexCall concatCall =
93+
(RexCall)
94+
rexBuilder.makeCall(
95+
SqlStdOperatorTable.CONCAT,
96+
rexBuilder.makeLiteral("hello"),
97+
rexBuilder.makeLiteral("world"));
98+
99+
// Create a simple topLevelConverter that converts literals to Substrait expressions
100+
java.util.function.Function<RexNode, Expression> topLevelConverter =
101+
rexNode -> {
102+
org.apache.calcite.rex.RexLiteral lit = (org.apache.calcite.rex.RexLiteral) rexNode;
103+
return Expression.StrLiteral.builder()
104+
.value(lit.getValueAs(String.class))
105+
.nullable(false)
106+
.build();
107+
};
108+
109+
Optional<Expression> exprA = converterA.convert(concatCall, topLevelConverter);
110+
Optional<Expression> exprB = converterB.convert(concatCall, topLevelConverter);
111+
112+
Expression.ScalarFunctionInvocation funcA = (Expression.ScalarFunctionInvocation) exprA.get();
113+
Expression.ScalarFunctionInvocation funcB = (Expression.ScalarFunctionInvocation) exprB.get();
114+
115+
assertEquals(
116+
"extension:com.domain:string",
117+
funcA.declaration().getAnchor().urn(),
118+
"converterA should use last concat function (from collection2)");
119+
120+
assertEquals(
121+
"extension:io.substrait:functions_string",
122+
funcB.declaration().getAnchor().urn(),
123+
"converterB should use last concat function (from collection1)");
124+
}
125+
126+
@Test
127+
void testLtrimMergeOrderWithDefaultExtensions() {
128+
// This test verifies precedence between a custom ltrim (from collection2 with
129+
// extension:com.domain:string) and the default extension catalog's ltrim
130+
// (extension:io.substrait:functions_string).
131+
// The FunctionConverter uses a "last-wins" strategy.
132+
133+
// Merge default extensions with collection2 - collection2's ltrim should be last
134+
SimpleExtension.ExtensionCollection defaultWithCustom = extensions.merge(collection2);
135+
136+
// Merge collection2 with default extensions - default ltrim should be last
137+
SimpleExtension.ExtensionCollection customWithDefault = collection2.merge(extensions);
138+
139+
ScalarFunctionConverter converterA =
140+
new ScalarFunctionConverter(defaultWithCustom.scalarFunctions(), typeFactory);
141+
ScalarFunctionConverter converterB =
142+
new ScalarFunctionConverter(customWithDefault.scalarFunctions(), typeFactory);
143+
144+
// Create a TRIM(LEADING ' ' FROM 'test') call which uses TrimFunctionMapper to map to ltrim
145+
RexBuilder rexBuilder = new RexBuilder(typeFactory);
146+
RexCall trimCall =
147+
(RexCall)
148+
rexBuilder.makeCall(
149+
SqlStdOperatorTable.TRIM,
150+
rexBuilder.makeFlag(Flag.LEADING),
151+
rexBuilder.makeLiteral(" "),
152+
rexBuilder.makeLiteral("test"));
153+
154+
java.util.function.Function<RexNode, Expression> topLevelConverter =
155+
rexNode -> {
156+
org.apache.calcite.rex.RexLiteral lit = (org.apache.calcite.rex.RexLiteral) rexNode;
157+
Object value = lit.getValue();
158+
if (value == null) {
159+
return Expression.StrLiteral.builder().value("").nullable(true).build();
160+
}
161+
// Convert any literal value to string
162+
return Expression.StrLiteral.builder().value(value.toString()).nullable(false).build();
163+
};
164+
165+
Optional<Expression> exprA = converterA.convert(trimCall, topLevelConverter);
166+
Optional<Expression> exprB = converterB.convert(trimCall, topLevelConverter);
167+
168+
Expression.ScalarFunctionInvocation funcA = (Expression.ScalarFunctionInvocation) exprA.get();
169+
// converterA should use collection2's custom ltrim (last)
170+
assertEquals(
171+
"extension:com.domain:string",
172+
funcA.declaration().getAnchor().urn(),
173+
"converterA should use last ltrim (custom from collection2)");
174+
175+
Expression.ScalarFunctionInvocation funcB = (Expression.ScalarFunctionInvocation) exprB.get();
176+
// converterB should use default extensions' ltrim (last)
177+
assertEquals(
178+
"extension:io.substrait:functions_string",
179+
funcB.declaration().getAnchor().urn(),
180+
"converterB should use last ltrim (from default extensions)");
181+
}
182+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
%YAML 1.2
2+
---
3+
urn: extension:io.substrait:functions_string
4+
5+
scalar_functions:
6+
- name: "concat"
7+
description: "concatenate strings"
8+
impls:
9+
- args:
10+
- name: str1
11+
value: string
12+
- name: str2
13+
value: string
14+
variadic:
15+
min: 0
16+
return: string

0 commit comments

Comments
 (0)