From 122e0004b2cbe88d98431eabd14d78d30d7aadb0 Mon Sep 17 00:00:00 2001 From: jianfengmao Date: Mon, 18 Mar 2024 13:39:18 -0600 Subject: [PATCH] Fix bugs and old test cases --- .../table/impl/lang/QueryLanguageParser.java | 35 ++++++++++++- .../engine/util/PyCallableWrapperJpyImpl.java | 35 +++++++++---- py/server/tests/test_udf_numpy_args.py | 50 ++++++++++++++----- .../tests/test_udf_return_java_values.py | 2 +- py/server/tests/test_vectorization.py | 2 +- 5 files changed, 98 insertions(+), 26 deletions(-) diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java index e3e5fc50851..88faecd3d46 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java @@ -153,7 +153,7 @@ public final class QueryLanguageParser extends GenericVisitorAdapter, Q private final Map> staticImportLookupCache = new HashMap<>(); // We need some class to represent null. We know for certain that this one won't be used... - private static final Class NULL_CLASS = QueryLanguageParser.class; + public static final Class NULL_CLASS = QueryLanguageParser.class; /** * The result of the QueryLanguageParser for the expression passed given to the constructor. @@ -1968,6 +1968,39 @@ public static boolean isWideningPrimitiveConversion(Class original, Class return false; } + public static boolean isLosslessWideningPrimitiveConversion(Class original, Class target) { + if (original == null || !original.isPrimitive() || target == null || !target.isPrimitive() + || original.equals(void.class) || target.equals(void.class)) { + throw new IllegalArgumentException("Arguments must be a primitive type (excluding void)!"); + } + + if (original.equals(target)) { + return true; + } + + LanguageParserPrimitiveType originalEnum = LanguageParserPrimitiveType.getPrimitiveType(original); + + switch (originalEnum) { + case BytePrimitive: + if (target == short.class) + return true; + case ShortPrimitive: + case CharPrimitive: + if (target == int.class) + return true; + case IntPrimitive: + if (target == long.class) + return true; + break; + case FloatPrimitive: + if (target == double.class) + return true; + break; + } + + return false; + } + private enum LanguageParserPrimitiveType { // Including "Enum" (or really, any differentiating string) in these names is important. They're used // in a switch() statement, which apparently does not support qualified names. And we can't use diff --git a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java index c3171176cc7..aadce5e6422 100644 --- a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java +++ b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java @@ -3,6 +3,7 @@ // package io.deephaven.engine.util; +import io.deephaven.engine.table.impl.lang.QueryLanguageParser; import io.deephaven.engine.table.impl.select.python.ArgumentsChunked; import io.deephaven.internal.log.LoggerFactory; import io.deephaven.io.logger.Logger; @@ -12,7 +13,9 @@ import java.time.Instant; import java.util.*; -import static io.deephaven.engine.table.impl.lang.QueryLanguageParser.isWideningPrimitiveConversion; +import static io.deephaven.engine.table.impl.lang.QueryLanguageParser.NULL_CLASS; +import static io.deephaven.engine.table.impl.lang.QueryLanguageParser.isLosslessWideningPrimitiveConversion; +import static io.deephaven.util.type.TypeUtils.getUnboxedType; /** * When given a pyObject that is a callable, we stick it inside the callable wrapper, which implements a call() varargs @@ -228,6 +231,11 @@ public void parseSignature() { // skip the array type code ti++; possibleTypes.add(numpyType2JavaArrayClass.get(paramTypeCodes.charAt(ti))); + if (paramTypeCodes.charAt(ti) == '?') { + possibleTypes.add(Boolean[].class); + } + } else if (typeCode == 'N') { + possibleTypes.add(NULL_CLASS); } else { possibleTypes.add(numpyType2JavaClass.get(typeCode)); } @@ -252,7 +260,7 @@ private boolean isSafelyCastable(Set> types, Class type) { if (t.isAssignableFrom(type)) { return true; } - if (t.isPrimitive() && type.isPrimitive() && isWideningPrimitiveConversion(type, t)) { + if (t.isPrimitive() && type.isPrimitive() && isLosslessWideningPrimitiveConversion(type, t)) { return true; } } @@ -263,18 +271,25 @@ private boolean isSafelyCastable(Set> types, Class type) { public void verifyArguments(Class[] argTypes) { String callableName = pyCallable.getAttribute("__name__").toString(); - if (argTypes.length != parameters.size()) { - throw new IllegalArgumentException( - callableName + ": " + "Expected " + parameters.size() + " arguments, got " + argTypes.length); - } + // if (argTypes.length > parameters.size()) { + // throw new IllegalArgumentException( + // callableName + ": " + "Expected " + parameters.size() + " or fewer arguments, got " + argTypes.length); + // } for (int i = 0; i < argTypes.length; i++) { - Set> types = parameters.get(i).getPossibleTypes(); - if (!types.contains(argTypes[i]) && !types.contains(Object.class) - && !isSafelyCastable(types, argTypes[i])) { + Set> types = + parameters.get(i > parameters.size() - 1 ? parameters.size() - 1 : i).getPossibleTypes(); + // Object is a catch-all type, so we don't need to check for it + if (argTypes[i] == Object.class) { + continue; + } + + Class t = getUnboxedType(argTypes[i]) == null ? argTypes[i] : getUnboxedType(argTypes[i]); + if (!types.contains(t) && !types.contains(Object.class) + && !isSafelyCastable(types, t)) { throw new IllegalArgumentException( callableName + ": " + "Expected argument (" + parameters.get(i).getName() + ") to be one of " + parameters.get(i).getPossibleTypes() + ", got " - + argTypes[i]); + + (argTypes[i].equals(NULL_CLASS) ? "null" : argTypes[i])); } } }; diff --git a/py/server/tests/test_udf_numpy_args.py b/py/server/tests/test_udf_numpy_args.py index 55c688b4a8f..fbb913520ec 100644 --- a/py/server/tests/test_udf_numpy_args.py +++ b/py/server/tests/test_udf_numpy_args.py @@ -218,7 +218,7 @@ def f11(p1: Union[float, np.float32]) -> bool: with self.assertRaises(DHError) as cm: t = empty_table(10).update(["X1 = f11(i)"]) - def f2(p1: Union[np.int16, np.float64]) -> Union[Optional[bool]]: + def f2(p1: Union[np.int32, np.float64]) -> Union[Optional[bool]]: return bool(p1) t = empty_table(10).update(["X1 = f2(i)"]) @@ -231,7 +231,7 @@ def f21(p1: Union[np.int16, np.float64]) -> Union[Optional[bool], int]: with self.assertRaises(DHError) as cm: t = empty_table(10).update(["X1 = f21(i)"]) - def f3(p1: Union[np.int16, np.float64], p2=None) -> bool: + def f3(p1: Union[np.int32, np.float64], p2=None) -> bool: return bool(p1) t = empty_table(10).update(["X1 = f3(i)"]) @@ -244,7 +244,7 @@ def f4(p1: Union[np.int16, np.float64], p2=None) -> bool: self.assertEqual(t.columns[0].data_type, dtypes.bool_) with self.assertRaises(DHError) as cm: t = empty_table(10).update(["X1 = f4(now())"]) - self.assertRegex(str(cm.exception), "Argument .* is not compatible with annotation*") + self.assertRegex(str(cm.exception), "f4: Expected .* got .*Instant") def f41(p1: Union[np.int16, np.float64, Union[Any]], p2=None) -> bool: return bool(p1) @@ -266,7 +266,7 @@ def f5(col1, col2: np.ndarray[np.int32]) -> bool: t = t.update(["X1 = f5(X, Y)"]) with self.assertRaises(DHError) as cm: t = t.update(["X1 = f5(X, null)"]) - self.assertRegex(str(cm.exception), "Argument .* is not compatible with annotation*") + self.assertRegex(str(cm.exception), "f5: Expected .* got null") def f51(col1, col2: Optional[np.ndarray[np.int32]]) -> bool: return np.nanmean(col2) == np.mean(col2) @@ -287,7 +287,7 @@ def f6(*args: np.int32, col2: np.ndarray[np.int32]) -> bool: with self.assertRaises(DHError) as cm: t1 = t.update(["X1 = f6(X, Y=null)"]) - self.assertIn("not compatible with annotation", str(cm.exception)) + self.assertIn("f6: Expected argument (col2) to be one of [class [I], got boolean", str(cm.exception)) def test_str_bool_datetime_array(self): with self.subTest("str"): @@ -299,7 +299,7 @@ def f1(p1: np.ndarray[str], p2=None) -> bool: self.assertEqual(t1.columns[2].data_type, dtypes.bool_) with self.assertRaises(DHError) as cm: t2 = t.update(["X1 = f1(null, Y )"]) - self.assertRegex(str(cm.exception), "Argument .* is not compatible with annotation*") + self.assertRegex(str(cm.exception), "f1: Expected .* got null") def f11(p1: Union[np.ndarray[str], None], p2=None) -> bool: return bool(len(p1)) if p1 is not None else False @@ -315,7 +315,7 @@ def f2(p1: np.ndarray[np.datetime64], p2=None) -> bool: self.assertEqual(t1.columns[2].data_type, dtypes.bool_) with self.assertRaises(DHError) as cm: t2 = t.update(["X1 = f2(null, Y )"]) - self.assertRegex(str(cm.exception), "Argument .* is not compatible with annotation*") + self.assertRegex(str(cm.exception), "f2: Expected .* got null") def f21(p1: Union[np.ndarray[np.datetime64], None], p2=None) -> bool: return bool(len(p1)) if p1 is not None else False @@ -337,7 +337,7 @@ def f3(p1: np.ndarray[np.bool_], p2=None) -> bool: self.assertEqual(t1.columns[2].data_type, dtypes.bool_) with self.assertRaises(DHError) as cm: t2 = t.update(["X1 = f3(null, Y )"]) - self.assertRegex(str(cm.exception), "Argument 'p1': None is not compatible with annotation") + self.assertRegex(str(cm.exception), "f3: Expected .* got null") def f31(p1: Optional[np.ndarray[bool]], p2=None) -> bool: return bool(len(p1)) if p1 is not None else False @@ -405,9 +405,9 @@ def f(x: {p_type}) -> bool: # note typing t = empty_table(1).update(["X = i", f"Y = f(({p_type})X)"]) self.assertEqual(1, t.to_string(cols="Y").count("true")) - - np_int_types = {"np.int8", "np.int16", "np.int32", "np.int64"} - for p_type in np_int_types: + def test_np_typehints(self): + widening_np_int_types = {"np.int32", "np.int64"} + for p_type in widening_np_int_types: with self.subTest(p_type): func_str = f""" def f(x: {p_type}) -> bool: # note typing @@ -417,8 +417,20 @@ def f(x: {p_type}) -> bool: # note typing t = empty_table(1).update(["X = i", f"Y = f(X)"]) self.assertEqual(1, t.to_string(cols="Y").count("true")) - np_floating_types = {"np.float32", "np.float64"} - for p_type in np_floating_types: + narrowing_np_int_types = {"np.int8", "np.int16"} + for p_type in narrowing_np_int_types: + with self.subTest(p_type): + func_str = f""" +def f(x: {p_type}) -> bool: # note typing + return type(x) == {p_type} +""" + exec(func_str, globals()) + with self.assertRaises(DHError) as cm: + t = empty_table(1).update(["X = i", f"Y = f(X)"]) + self.assertRegex(str(cm.exception), "f: Expect") + + widening_np_floating_types = {"np.float32", "np.float64"} + for p_type in widening_np_floating_types: with self.subTest(p_type): func_str = f""" def f(x: {p_type}) -> bool: # note typing @@ -428,6 +440,18 @@ def f(x: {p_type}) -> bool: # note typing t = empty_table(1).update(["X = i", f"Y = f((float)X)"]) self.assertEqual(1, t.to_string(cols="Y").count("true")) + int_to_floating_types = {"np.float32", "np.float64"} + for p_type in int_to_floating_types: + with self.subTest(p_type): + func_str = f""" +def f(x: {p_type}) -> bool: # note typing + return type(x) == {p_type} +""" + exec(func_str, globals()) + with self.assertRaises(DHError) as cm: + t = empty_table(1).update(["X = i", f"Y = f(X)"]) + self.assertRegex(str(cm.exception), "f: Expect") + def test_np_typehints_mismatch(self): def f(x: float) -> bool: return True diff --git a/py/server/tests/test_udf_return_java_values.py b/py/server/tests/test_udf_return_java_values.py index 7c1f55c5827..129105d5698 100644 --- a/py/server/tests/test_udf_return_java_values.py +++ b/py/server/tests/test_udf_return_java_values.py @@ -287,7 +287,7 @@ def test_np_ufunc(self): nbsin = numba.vectorize([numba.float64(numba.float64)])(np.sin) # this is the workaround that utilizes vectorization and type inference - @numba.vectorize([numba.float64(numba.float64)], nopython=True) + @numba.vectorize([numba.float64(numba.int64)], nopython=True) def nbsin(x): return np.sin(x) t3 = empty_table(10).update(["X3 = nbsin(i)"]) diff --git a/py/server/tests/test_vectorization.py b/py/server/tests/test_vectorization.py index d7532647640..2685fa843d5 100644 --- a/py/server/tests/test_vectorization.py +++ b/py/server/tests/test_vectorization.py @@ -269,7 +269,7 @@ def sinc2(x): self.assertEqual(t.columns[1].data_type, dtypes.PyObject) def test_optional_annotations(self): - def pyfunc(p1: np.int32, p2: np.int32, p3: Optional[np.int32]) -> Optional[int]: + def pyfunc(p1: np.int32, p2: np.int64, p3: Optional[np.int32]) -> Optional[int]: total = p1 + p2 + p3 return None if total % 3 == 0 else total