Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2854,9 +2854,39 @@ public RelNode visitReplace(Replace node, CalcitePlanContext context) {
for (ReplacePair pair : node.getReplacePairs()) {
RexNode patternNode = rexVisitor.analyze(pair.getPattern(), context);
RexNode replacementNode = rexVisitor.analyze(pair.getReplacement(), context);
fieldRef =
context.relBuilder.call(
SqlStdOperatorTable.REPLACE, fieldRef, patternNode, replacementNode);

String patternStr = pair.getPattern().getValue().toString();
String replacementStr = pair.getReplacement().getValue().toString();

if (patternStr.contains("*")) {
WildcardUtils.validateWildcardSymmetry(patternStr, replacementStr);

String regexPattern = WildcardUtils.convertWildcardPatternToRegex(patternStr);
String regexReplacement =
WildcardUtils.convertWildcardReplacementToRegex(replacementStr);

RexNode regexPatternNode =
context.rexBuilder.makeLiteral(
regexPattern,
context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR),
true);
RexNode regexReplacementNode =
context.rexBuilder.makeLiteral(
regexReplacement,
context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR),
true);

fieldRef =
context.rexBuilder.makeCall(
org.apache.calcite.sql.fun.SqlLibraryOperators.REGEXP_REPLACE_3,
fieldRef,
regexPatternNode,
regexReplacementNode);
} else {
fieldRef =
context.relBuilder.call(
SqlStdOperatorTable.REPLACE, fieldRef, patternNode, replacementNode);
}
}

projectList.add(fieldRef);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.sql.calcite.utils;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -92,4 +93,141 @@ private static boolean matchesCompiledPattern(String[] parts, String fieldName)
public static boolean containsWildcard(String str) {
  // A null input can never hold a wildcard; otherwise look for the WILDCARD marker.
  if (str == null) {
    return false;
  }
  return str.contains(WILDCARD);
}

/**
* Converts a wildcard pattern to a regex pattern.
*
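* <p>Example: "*ada" → "^\Q\E(.*?)\Qada\E$", which matches the same input as "^(.*?)ada$" (every
* literal segment, including an empty one, is wrapped in \Q...\E by Pattern.quote)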
*
* @param wildcardPattern wildcard pattern with '*' and escape sequences (\*, \\)
* @return regex pattern with capture groups
*/
public static String convertWildcardPatternToRegex(String wildcardPattern) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add comprehensive unit tests for the methods in WildcardUtils.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added unit tests in WildcardUtilsTest

String[] parts = splitWildcards(wildcardPattern);
StringBuilder regexBuilder = new StringBuilder("^");

for (int i = 0; i < parts.length; i++) {
regexBuilder.append(java.util.regex.Pattern.quote(parts[i]));
if (i < parts.length - 1) {
regexBuilder.append("(.*?)"); // Non-greedy capture group for wildcard
}
}
regexBuilder.append("$");

return regexBuilder.toString();
}

/**
 * Converts a wildcard replacement string to a regex replacement string.
 *
 * <p>Each unescaped '*' becomes a reference to the next capture group ($1, $2, ...). A backslash
 * makes the character after it literal, so \* yields a literal '*' and \\ yields a literal
 * backslash. Literal '$' and backslash characters are backslash-escaped in the result because
 * both are metacharacters in a java.util.regex replacement string; left unescaped they would be
 * read as a group reference or as an escape instead of being emitted verbatim.
 *
 * <p>Examples: "*_*" → "$1_$2", "$*" → "\$$1"
 *
 * @param wildcardReplacement replacement string with '*' and escape sequences (\*, \\); a lone
 *     trailing backslash is ignored
 * @return regex replacement string with capture group references
 */
public static String convertWildcardReplacementToRegex(String wildcardReplacement) {
  StringBuilder result = new StringBuilder(wildcardReplacement.length() + 8);
  int captureIndex = 1; // Regex capture groups start at $1
  boolean escaped = false;

  for (int i = 0; i < wildcardReplacement.length(); i++) {
    char c = wildcardReplacement.charAt(i);

    if (!escaped && c == '\\') {
      // Start of an escape sequence: the next character is taken literally.
      escaped = true;
      continue;
    }
    if (!escaped && c == '*') {
      // Unescaped wildcard: refer to the matching capture group.
      result.append('$').append(captureIndex++);
      continue;
    }

    // Literal character. '$' and '\' must be escaped so the regex engine does not treat them
    // as a group reference or an escape prefix.
    if (c == '$' || c == '\\') {
      result.append('\\');
    }
    result.append(c);
    escaped = false;
  }

  return result.toString();
}

/**
 * Breaks a wildcard pattern into the literal segments that sit between unescaped '*' characters.
 *
 * <p>A backslash makes the character after it part of the current segment (the backslash itself
 * is dropped). The result always has one more element than the number of unescaped wildcards, so
 * leading, trailing, or adjacent wildcards produce empty segments.
 *
 * <p>Example: "a*b*c" → ["a", "b", "c"]
 *
 * @param pattern wildcard pattern with escape sequences
 * @return array of literal segments, in order
 * @throws IllegalArgumentException if the pattern ends with a backslash that escapes nothing
 */
private static String[] splitWildcards(String pattern) {
  final int length = pattern.length();
  List<String> segments = new ArrayList<>();
  StringBuilder segment = new StringBuilder();

  int pos = 0;
  while (pos < length) {
    char ch = pattern.charAt(pos++);
    if (ch == '\\') {
      // The backslash must be followed by the character it escapes.
      if (pos == length) {
        throw new IllegalArgumentException(
            "Invalid escape sequence: pattern ends with unescaped backslash");
      }
      segment.append(pattern.charAt(pos++));
    } else if (ch == '*') {
      // Wildcard boundary: close the current segment and start a new one.
      segments.add(segment.toString());
      segment.setLength(0);
    } else {
      segment.append(ch);
    }
  }

  // Whatever follows the last wildcard (possibly nothing) is the final segment.
  segments.add(segment.toString());
  return segments.toArray(new String[0]);
}

/**
 * Returns how many '*' characters in the string are not preceded by an escaping backslash.
 *
 * <p>A backslash consumes the character after it, so "\*" contributes nothing to the count and
 * "\\*" contributes one.
 *
 * @param str string to scan
 * @return number of unescaped wildcards
 */
private static int countWildcards(String str) {
  int total = 0;
  for (int pos = 0; pos < str.length(); pos++) {
    char ch = str.charAt(pos);
    if (ch == '\\') {
      pos++; // Skip the escaped character so it is never counted.
    } else if (ch == '*') {
      total++;
    }
  }
  return total;
}

/**
 * Checks that the replacement uses a wildcard count compatible with the pattern.
 *
 * <p>The replacement is accepted when it contains no unescaped wildcards at all, or exactly as
 * many as the pattern does. Any other combination is rejected.
 *
 * @param pattern wildcard pattern
 * @param replacement wildcard replacement
 * @throws IllegalArgumentException if the replacement has a non-zero wildcard count that differs
 *     from the pattern's
 */
public static void validateWildcardSymmetry(String pattern, String replacement) {
  final int expected = countWildcards(pattern);
  final int actual = countWildcards(replacement);

  // Nothing to report when the replacement is wildcard-free or mirrors the pattern.
  if (actual == 0 || actual == expected) {
    return;
  }

  String message =
      String.format(
          "Error in 'replace' command: Wildcard count mismatch - pattern has %d wildcard(s), "
              + "replacement has %d. Replacement must have same number of wildcards or none.",
          expected, actual);
  throw new IllegalArgumentException(message);
}
}
Loading
Loading