1+ #include " math.h"
2+
13#include " common/exception/binder.h"
24#include " common/exception/message.h"
35#include " common/type_utils.h"
6+ #include " common/vector/value_vector.h"
47#include " function/list/functions/list_function_utils.h"
58#include " function/list/vector_list_functions.h"
69#include " function/scalar_function.h"
7- #include " math.h"
8- #include " common/vector/value_vector.h"
910#include < simsimd.h>
1011
1112using namespace kuzu ::common;
@@ -14,9 +15,10 @@ namespace kuzu {
1415namespace function {
1516
1617struct ListCosineSimilarity {
17- template <std::floating_point T>
18- static void operation (common::list_entry_t & left, common::list_entry_t & right, T& result, common::ValueVector& leftVector,
19- common::ValueVector& rightVector, common::ValueVector& /* resultVector*/ ) {
18+ template <std::floating_point T>
19+ static void operation (common::list_entry_t & left, common::list_entry_t & right, T& result,
20+ common::ValueVector& leftVector, common::ValueVector& rightVector,
21+ common::ValueVector& /* resultVector*/ ) {
2022 auto leftElements = (T*)common::ListVector::getListValues (&leftVector, left);
2123 auto rightElements = (T*)common::ListVector::getListValues (&rightVector, right);
2224 KU_ASSERT (left.size == right.size );
@@ -54,8 +56,7 @@ static LogicalType validateListFunctionParameters(const LogicalType& leftType,
5456 return rightType.copy ();
5557 }
5658 throw BinderException (
57- stringFormat (" {} requires at least one argument to be LIST." ,
58- functionName));
59+ stringFormat (" {} requires at least one argument to be LIST." , functionName));
5960}
6061
6162template <typename OPERATION, typename RESULT>
@@ -83,12 +84,13 @@ scalar_func_exec_t getScalarExecFunc(LogicalType type) {
8384
8485static std::unique_ptr<FunctionBindData> bindFunc (const ScalarBindFuncInput& input) {
8586 std::vector<LogicalType> types;
86- // auto scalarFunction = input.definition->ptrCast<ScalarFunction>();
87+ // auto scalarFunction = input.definition->ptrCast<ScalarFunction>();
8788 types.push_back (input.arguments [0 ]->getDataType ().copy ());
8889 types.push_back (input.arguments [1 ]->getDataType ().copy ());
8990 auto paramType = validateListFunctionParameters (types[0 ], types[1 ], input.definition ->name );
90- // const auto& resultType = ListType::getChildType(input.arguments[0]->dataType);
91- input.definition ->ptrCast <ScalarFunction>()->execFunc = std::move (getScalarExecFunc<ListCosineSimilarity>(paramType.copy ()));
91+ // const auto& resultType = ListType::getChildType(input.arguments[0]->dataType);
92+ input.definition ->ptrCast <ScalarFunction>()->execFunc =
93+ std::move (getScalarExecFunc<ListCosineSimilarity>(paramType.copy ()));
9294 auto bindData = std::make_unique<FunctionBindData>(ListType::getChildType (paramType).copy ());
9395 std::vector<LogicalType> paramTypes;
9496 for (auto & _ : input.arguments ) {
@@ -100,8 +102,9 @@ static std::unique_ptr<FunctionBindData> bindFunc(const ScalarBindFuncInput& inp
100102
101103function_set ListConsineSimilarityFunction::getFunctionSet () {
102104 function_set result;
103- // auto execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t, list_entry_t, float, ListCosineSimilarity>;
104- auto function = std::make_unique<ScalarFunction>(name,
105+ // auto execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t, list_entry_t,
106+ // float, ListCosineSimilarity>;
107+ auto function = std::make_unique<ScalarFunction>(name,
105108 std::vector<LogicalTypeID>{LogicalTypeID::LIST, LogicalTypeID::LIST}, LogicalTypeID::ANY);
106109 function->bindFunc = bindFunc;
107110 result.push_back (std::move (function));
0 commit comments