diff --git a/src/main/java/com/snowflake/snowpark_java/Functions.java b/src/main/java/com/snowflake/snowpark_java/Functions.java index 65f50020..4b3bdb1d 100644 --- a/src/main/java/com/snowflake/snowpark_java/Functions.java +++ b/src/main/java/com/snowflake/snowpark_java/Functions.java @@ -3882,7 +3882,120 @@ public static Column listagg(Column col) { } /** - * Returns a Column expression with values sorted in descending order. + * + * Signature - snowflake.snowpark.functions.regexp_extract (value: Union[Column, str], regexp: + * Union[Column, str], idx: int) Column Extract a specific group matched by a regex, from the + * specified string column. If the regex did not match, or the specified group did not match, an + * empty string is returned. + * Example: + *
{@code
+ * from snowflake.snowpark.functions import regexp_extract
+ * df = session.createDataFrame([["id_20_30", 10], ["id_40_50", 30]], ["id", "age"])
+ * df.select(regexp_extract("id", r"(\d+)", 1).alias("RES")).show()
+ * ---------
+ * |"RES" |
+ * ---------
+ * |20 |
+ * |40 |
+ * ---------
+ * }
+ *
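+ * @since 1.12.1
+ * @param col Column holding the strings to search.
+ * @param exp Regular expression to match; it should contain the capture group to extract.
+ * @param position Position in the string at which the search starts (1 is the first
+ *     character).
+ * @param Occurences Which occurrence of the pattern to use (1 selects the first match).
+ * @param grpIdx Index of the capture group within the pattern to return.
+ * @return Column object.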
+ */
+ public static Column regexp_extract(
+     Column col, String exp, Integer position, Integer Occurences, Integer grpIdx) {
+   // Delegate to the Scala implementation, then wrap the result in the Java Column type.
+   com.snowflake.snowpark.Column extracted =
+       com.snowflake.snowpark.functions.regexp_extract(
+           col.toScalaColumn(), exp, position, Occurences, grpIdx);
+   return new Column(extracted);
+ }
+
+ /**
+  * Returns the sign of its argument:
+  *
+  * <ul>
+  *   <li>-1 if the argument is negative.
+  *   <li>1 if it is positive.
+  *   <li>0 if it is 0.
+  * </ul>
+  *
+  * <p>Example:
+  *
+  * <pre>{@code
+  * DataFrame df = session.sql("select * from values(-2, 2, 0) as T(a, b, c)");
+  * df.select(
+  *         Functions.signum(df.col("a")).alias("a_sign"),
+  *         Functions.signum(df.col("b")).alias("b_sign"),
+  *         Functions.signum(df.col("c")).alias("c_sign"))
+  *     .show();
+  * ----------------------------------
+  * |"A_SIGN"  |"B_SIGN"  |"C_SIGN"  |
+  * ----------------------------------
+  * |-1        |1         |0         |
+  * ----------------------------------
+  * }</pre>
+  *
+  * @since 1.12.1
+  * @param col Column to calculate the sign.
+  * @return Column object.
+  */
+ public static Column signum(Column col) {
+   // Thin wrapper: the computation is done by the Scala signum function.
+   return new Column(com.snowflake.snowpark.functions.signum(col.toScalaColumn()));
+ }
+
+ /**
+  * Returns the sign of its argument:
+  *
+  * <ul>
+  *   <li>-1 if the argument is negative.
+  *   <li>1 if it is positive.
+  *   <li>0 if it is 0.
+  * </ul>
+  *
+  * <p>Example:
+  *
+  * <pre>{@code
+  * DataFrame df = session.sql("select * from values(-2, 2, 0) as T(a, b, c)");
+  * df.select(
+  *         Functions.sign(df.col("a")).alias("a_sign"),
+  *         Functions.sign(df.col("b")).alias("b_sign"),
+  *         Functions.sign(df.col("c")).alias("c_sign"))
+  *     .show();
+  * ----------------------------------
+  * |"A_SIGN"  |"B_SIGN"  |"C_SIGN"  |
+  * ----------------------------------
+  * |-1        |1         |0         |
+  * ----------------------------------
+  * }</pre>
+  *
+  * @since 1.12.1
+  * @param col Column to calculate the sign.
+  * @return Column object.
+  */
+ public static Column sign(Column col) {
+   // Thin wrapper: the computation is done by the Scala sign function.
+   return new Column(com.snowflake.snowpark.functions.sign(col.toScalaColumn()));
+ }
+
+ /**
+  * Returns the part of a string that lies before or after a given occurrence of a delimiter.
+  *
+  * <p>If count is positive, everything to the left of the final delimiter (counting from the
+  * left) is returned. If count is negative, everything to the right of the final delimiter
+  * (counting from the right) is returned. The delimiter is matched case-sensitively.
+  *
+  * @since 1.12.1
+  * @param col Column holding the strings to search.
+  * @param delim Delimiter to look for.
+  * @param count Number of delimiter occurrences to count.
+  * @return Column object.
+  */
+ public static Column substring_index(Column col, String delim, Integer count) {
+   com.snowflake.snowpark.Column scalaResult =
+       com.snowflake.snowpark.functions.substring_index(col.toScalaColumn(), delim, count);
+   return new Column(scalaResult);
+ }
+
+ /**
+  * Returns the input values, pivoted into an ARRAY. If the input is empty, an empty ARRAY is
+  * returned.
+  *
+  * <p>Example:
+  *
+  * <pre>{@code
+  * DataFrame df = session.sql("select * from values(1),(2),(3) as T(a)");
+  * df.select(Functions.collect_list(df.col("a")).alias("result")).show();
+  * }</pre>
+  *
+  * @since 1.10.0
+  * @param col Column to be collected.
+  * @return The array.
+  */
+ public static Column collect_list(Column col) {
+   // Thin wrapper: the aggregation is done by the Scala collect_list function.
+   return new Column(com.snowflake.snowpark.functions.collect_list(col.toScalaColumn()));
+ }
+
+ /**
+ * Returns a Column expression with values sorted in descending order.
*
* Example: order column values in descending
*
diff --git a/src/main/scala/com/snowflake/snowpark/functions.scala b/src/main/scala/com/snowflake/snowpark/functions.scala
index 5c6f599f..588d8290 100644
--- a/src/main/scala/com/snowflake/snowpark/functions.scala
+++ b/src/main/scala/com/snowflake/snowpark/functions.scala
@@ -3142,7 +3142,193 @@ object functions {
def listagg(col: Column): Column = listagg(col, "", isDistinct = false)
/**
- * Returns a Column expression with values sorted in descending order.
+
+ * Signature - snowflake.snowpark.functions.regexp_extract
+ * (value: Union[Column, str], regexp: Union[Column, str], idx: int)
+ * Column
+ * Extract a specific group matched by a regex, from the specified string
+ * column. If the regex did not match, or the specified group did not match,
+ * an empty string is returned.
+ * Example:
+ * from snowflake.snowpark.functions import regexp_extract
+ * df = session.createDataFrame([["id_20_30", 10], ["id_40_50", 30]],
+ * ["id", "age"])
+ * df.select(regexp_extract("id", r"(\d+)", 1).alias("RES")).show()
+ *
+ *
+ * ---------
+ * |"RES" |
+ * ---------
+ * |20 |
+ * |40 |
+ * ---------
+ *
+ * Note: non-greedy tokens such as are not supported
+ * @since 1.12.1
+ * @return Column object.
+ */
+ def regexp_extract(
+     colName: Column,
+     exp: String,
+     position: Int,
+     Occurences: Int,
+     grpIdx: Int): Column = {
+   // Snowflake's regex substring builtin is named REGEXP_SUBSTR ("REGEX_SUBSTR" is not a
+   // builtin). The "ce" parameters request case-sensitive matching ("c") and sub-match
+   // extraction ("e"), so that grpIdx selects the capture group to return.
+   val matched = builtin("REGEXP_SUBSTR")(
+     colName,
+     lit(exp),
+     lit(position),
+     lit(Occurences),
+     lit("ce"),
+     lit(grpIdx))
+   // A NULL input stays NULL; a failed match is turned into an empty string.
+   when(colName.is_null, lit(null)).otherwise(coalesce(matched, lit("")))
+ }
+
+ /**
+  * Returns the sign of its argument:
+  *
+  *   - -1 if the argument is negative.
+  *   - 1 if it is positive.
+  *   - 0 if it is 0.
+  *
+  * Example:
+  * {{{
+  *   val df = Seq((-2, 2, 0)).toDF("a", "b", "c")
+  *   df.select(
+  *     sign(col("a")).alias("a_sign"),
+  *     sign(col("b")).alias("b_sign"),
+  *     sign(col("c")).alias("c_sign")).show()
+  *   ----------------------------------
+  *   |"A_SIGN"  |"B_SIGN"  |"C_SIGN"  |
+  *   ----------------------------------
+  *   |-1        |1         |0         |
+  *   ----------------------------------
+  * }}}
+  *
+  * @since 1.12.1
+  * @param colName Column to calculate the sign.
+  * @return Column object.
+  */
+ def sign(colName: Column): Column = {
+   // Maps directly onto the SQL SIGN builtin.
+   builtin("SIGN")(colName)
+ }
+
+ /**
+  * Returns the sign of its argument. This is an alias of [[sign]]:
+  *
+  *   - -1 if the argument is negative.
+  *   - 1 if it is positive.
+  *   - 0 if it is 0.
+  *
+  * Example:
+  * {{{
+  *   val df = Seq((-2, 2, 0)).toDF("a", "b", "c")
+  *   df.select(
+  *     signum(col("a")).alias("a_sign"),
+  *     signum(col("b")).alias("b_sign"),
+  *     signum(col("c")).alias("c_sign")).show()
+  *   ----------------------------------
+  *   |"A_SIGN"  |"B_SIGN"  |"C_SIGN"  |
+  *   ----------------------------------
+  *   |-1        |1         |0         |
+  *   ----------------------------------
+  * }}}
+  *
+  * @since 1.12.1
+  * @param colName Column to calculate the sign.
+  * @return Column object.
+  */
+ def signum(colName: Column): Column = {
+   // Delegate instead of repeating the builtin call so both names always stay in sync.
+   sign(colName)
+ }
+
+ /**
+  * Returns the sign of the column with the given name: 1 for positive values,
+  * 0 for 0, -1 for negative values and null for null.
+  * The name is resolved with col and the work is done by the Column-based signum overload.
+  * NOTE(review): the handling of NaN and of string inputs (implicit cast, error when the
+  * cast fails) is left to Snowflake - confirm against the SIGN documentation.
+  *
+  * @since 1.12.1
+  * @param columnName Name of the column to calculate the sign.
+  * @return Column object.
+  */
+ def signum(columnName: String): Column = {
+   val target = col(columnName)
+   signum(target)
+ }
+
+ /**
+  * Returns the substring from string str before count occurrences
+  * of the delimiter delim. If count is positive,
+  * everything to the left of the final delimiter (counting from the left)
+  * is returned. If count is negative, everything to the right of the
+  * final delimiter (counting from the right) is returned.
+  * substring_index performs a case-sensitive match when searching for delim.
+  *
+  * Example:
+  * {{{
+  *   val df = Seq("It was the best of times, it was the worst of times").toDF("a")
+  *   // keeps "It ", the text to the left of the first "was"
+  *   df.select(substring_index(col("a"), "was", 1)).show()
+  * }}}
+  *
+  * @since 1.12.1
+  * @param str Column holding the strings to cut.
+  * @param delim Delimiter to look for; it is matched literally.
+  * @param count Number of delimiter occurrences to count; negative values count from the right.
+  * @return Column object.
+  */
+ def substring_index(str: Column, delim: String, count: Int): Column = {
+   // Split on the literal delimiter, keep the leading (count >= 0) or trailing (count < 0)
+   // pieces, and glue them back together with the same delimiter.
+   val pieces = callBuiltin("split", str, lit(delim))
+   val kept =
+     if (count >= 0) {
+       // ARRAY_SLICE(array, from, to) excludes index `to`: this keeps the first `count` pieces.
+       callBuiltin("array_slice", pieces, lit(0), lit(count))
+     } else {
+       // A negative `from` is counted from the end of the array: keep the last |count| pieces.
+       callBuiltin("array_slice", pieces, lit(count), callBuiltin("array_size", pieces))
+     }
+   callBuiltin("array_to_string", kept, lit(delim))
+ }
+
+ /**
+  * Returns the input values, pivoted into an ARRAY. If the input is empty, an empty
+  * ARRAY is returned. The aggregation itself is performed by array_agg.
+  *
+  * Example:
+  * {{{
+  *   val df = Seq(1, 2, 3, 1).toDF("a")
+  *   df.select(collect_list(col("a")).alias("result")).show()
+  * }}}
+  *
+  * @since 1.10.0
+  * @param c Column to be collected.
+  * @return The array.
+  */
+ def collect_list(c: Column): Column = {
+   array_agg(c)
+ }
+
+ /**
+  * Returns the values of the named column, pivoted into an ARRAY. If the input is empty,
+  * an empty ARRAY is returned. The name is resolved with col and aggregated by array_agg.
+  *
+  * Example:
+  * {{{
+  *   val df = Seq(1, 2, 3, 1).toDF("a")
+  *   df.select(collect_list("a").alias("result")).show()
+  * }}}
+  *
+  * @since 1.10.0
+  * @param s Column name to be collected.
+  * @return The array.
+  */
+ def collect_list(s: String): Column = {
+   val named = col(s)
+   array_agg(named)
+ }
+
+ /**
+ * Returns a Column expression with values sorted in descending order.
* Example:
* {{{
* val df = session.createDataFrame(Seq(1, 2, 3)).toDF("id")
diff --git a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
index 624ea481..33f2904a 100644
--- a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
+++ b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
@@ -2765,6 +2765,56 @@ public void any_value() {
assert result[0].getInt(0) == 1 || result[0].getInt(0) == 2 || result[0].getInt(0) == 3;
}
+ @Test
+ public void regexp_extract() {
+   DataFrame df = getSession().sql("select * from values('A MAN A PLAN A CANAL') as T(a)");
+   // The 4th argument is the occurrence of the pattern: 1, 2 and 3 pick the word that follows
+   // the first, second and third "A".
+   Row[] expected = {Row.create("MAN")};
+   checkAnswer(
+       df.select(Functions.regexp_extract(df.col("a"), "A\\W+(\\w+)", 1, 1, 1)), expected, false);
+   Row[] expected2 = {Row.create("PLAN")};
+   checkAnswer(
+       df.select(Functions.regexp_extract(df.col("a"), "A\\W+(\\w+)", 1, 2, 1)), expected2, false);
+   Row[] expected3 = {Row.create("CANAL")};
+   checkAnswer(
+       df.select(Functions.regexp_extract(df.col("a"), "A\\W+(\\w+)", 1, 3, 1)), expected3, false);
+   // The input has no fourth occurrence. The cast makes the varargs call build a row with one
+   // null field instead of passing a null array.
+   // NOTE(review): the Scala implementation coalesces a failed match to "" - confirm whether
+   // this expectation should be an empty string rather than null.
+   Row[] expected4 = {Row.create((Object) null)};
+   checkAnswer(
+       df.select(Functions.regexp_extract(df.col("a"), "A\\W+(\\w+)", 1, 4, 1)), expected4, false);
+ }
+
+ @Test
+ public void signum() {
+   // One value per row, so that selecting the single column yields three one-field rows.
+   DataFrame df = getSession().sql("select * from values(1),(-2),(0) as T(a)");
+   checkAnswer(
+       df.select(Functions.signum(df.col("a"))),
+       new Row[] {Row.create(1), Row.create(-1), Row.create(0)},
+       false);
+ }
+
+ @Test
+ public void sign() {
+   // One value per row, so that selecting the single column yields three one-field rows.
+   DataFrame df = getSession().sql("select * from values(1),(-2),(0) as T(a)");
+   checkAnswer(
+       df.select(Functions.sign(df.col("a"))),
+       new Row[] {Row.create(1), Row.create(-1), Row.create(0)},
+       false);
+ }
+
+ @Test
+ public void collect_list() {
+   // One value per row: the aggregate then has three elements to collect.
+   DataFrame df = getSession().sql("select * from values(10000),(400),(450) as T(a)");
+   Row[] result = df.select(Functions.collect_list(df.col("a"))).collect();
+   assert result.length == 1;
+   // The ARRAY comes back as multi-line JSON text; whitespace is stripped before comparing so
+   // the check does not depend on the exact indentation.
+   assert result[0].get(0).toString().replaceAll("\\s", "").equals("[10000,400,450]");
+ }
+
+ @Test
+ public void substring_index() {
+   DataFrame df =
+       getSession()
+           .sql(
+               "select * from values ('It was the best of times,it was the worst of times') as T(a)");
+   // With count = 1 the documented result is the text to the left of the first "was".
+   checkAnswer(
+       df.select(Functions.substring_index(df.col("a"), "was", 1)),
+       new Row[] {Row.create("It ")},
+       false);
+ }
+
@Test
public void test_asc() {
DataFrame df = getSession().sql("select * from values(3),(1),(2) as t(a)");
diff --git a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
index 8a89d87b..1420bb10 100644
--- a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
@@ -2177,6 +2177,49 @@ trait FunctionSuite extends TestData {
expected,
sort = false)
}
+ test("regexp_extract") {
+   val data = Seq("A MAN A PLAN A CANAL").toDF("a")
+   val pattern = "A\\W+(\\w+)"
+   // Each pair is (occurrence of the pattern, row expected for that occurrence).
+   val cases = Seq(1 -> Row("MAN"), 2 -> Row("PLAN"), 3 -> Row("CANAL"), 4 -> Row(null))
+   cases.foreach {
+     case (occurrence, expectedRow) =>
+       checkAnswer(
+         data.select(regexp_extract(col("a"), pattern, 1, occurrence, 1)),
+         Seq(expectedRow),
+         sort = false)
+   }
+ }
+ test("signum") {
+   // Positive, negative and zero inputs map to 1, -1 and 0.
+   val input = Seq(1, -2, 0).toDF("a")
+   val expected = Seq(Row(1), Row(-1), Row(0))
+   checkAnswer(input.select(signum(col("a"))), expected, sort = false)
+ }
+ test("sign") {
+   // Positive, negative and zero inputs map to 1, -1 and 0.
+   val input = Seq(1, -2, 0).toDF("a")
+   val expected = Seq(Row(1), Row(-1), Row(0))
+   checkAnswer(input.select(sign(col("a"))), expected, sort = false)
+ }
+
+ test("collect_list") {
+   // Text form of the aggregated ARRAY for the "amount" column of monthlySales.
+   val expectedJson =
+     "[\n 10000,\n 400,\n 4500,\n 35000,\n 5000,\n 3000,\n 200,\n 90500,\n 6000,\n " +
+       "5000,\n 2500,\n 9500,\n 8000,\n 10000,\n 800,\n 4500\n]"
+   val firstRow = monthlySales.select(collect_list(col("amount"))).collect()(0)
+   assert(firstRow.get(0).toString == expectedJson)
+ }
+ test("substring_index") {
+   val df = Seq("It was the best of times, it was the worst of times").toDF("a")
+   // With count = 1 the documented result is the text to the left of the first "was".
+   checkAnswer(df.select(substring_index(col("a"), "was", 1)), Seq(Row("It ")), sort = false)
+ }
test("desc column order") {
val input = Seq(1, 2, 3).toDF("data")