Skip to content
This repository was archived by the owner on Oct 10, 2025. It is now read-only.

Commit 70abac0

Browse files
committed
implemented list_where, and debug a typecheck error for list_select
1 parent 31218d0 commit 70abac0

File tree

5 files changed

+88
-2
lines changed

5 files changed

+88
-2
lines changed

src/function/function_collection.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ FunctionCollection* FunctionCollection::getFunctions() {
131131
SCALAR_FUNCTION(ListCosineSimilarityFunction), SCALAR_FUNCTION(ListCosineDistanceFunction),
132132
SCALAR_FUNCTION(ListDistanceFunction), SCALAR_FUNCTION(ListHasAnyFunction),
133133
SCALAR_FUNCTION(ListIntersectFunction), SCALAR_FUNCTION(ListSelectFunction),
134+
SCALAR_FUNCTION(ListWhereFunction),
134135

135136
// Cast functions
136137
SCALAR_FUNCTION(CastToDateFunction), SCALAR_FUNCTION_ALIAS(DateFunction),

src/function/list/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ add_library(kuzu_list_function
3030
list_has_all.cpp
3131
list_has_any.cpp
3232
list_intersect.cpp
33-
list_select_function.cpp)
33+
list_select_function.cpp
34+
list_where_function.cpp)
3435

3536
set(ALL_OBJECT_FILES
3637
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_list_function>

src/function/list/list_select_function.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,10 @@ static std::unique_ptr<FunctionBindData> bindFunc(const ScalarBindFuncInput& inp
3939
if (types[1].getPhysicalType()!=PhysicalTypeID::LIST) {
4040
throw BinderException(ExceptionMessage::listFunctionIncompatibleChildrenType(
4141
ListIntersectFunction::name, types[0].toString(), types[1].toString()));
42+
} else {
4243
auto thisExtraTypeInfo=types[1].getExtraTypeInfo();
4344
auto thisListTypeInfo=ku_dynamic_cast<const ListTypeInfo*>(thisExtraTypeInfo);
44-
if (thisListTypeInfo->getChildType().getLogicalTypeID()!=LogicalTypeID::INT64) {
45+
if (thisListTypeInfo->getChildType().getPhysicalType()!=PhysicalTypeID::INT64) {
4546
throw BinderException("LIST_SELECT expecting argument type: LIST of ANY, LIST of INT");
4647
}
4748
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#include "common/exception/binder.h"
2+
#include "common/exception/message.h"
3+
#include "common/type_utils.h"
4+
#include "common/types/types.h"
5+
#include "function/list/functions/list_function_utils.h"
6+
#include "function/list/functions/list_position_function.h"
7+
#include "function/list/functions/list_unique_function.h"
8+
#include "function/list/vector_list_functions.h"
9+
#include "function/scalar_function.h"
10+
11+
using namespace kuzu::common;
12+
13+
namespace kuzu {
14+
namespace function {
15+
16+
struct ListWhere {
17+
static void operation(common::list_entry_t& left, common::list_entry_t& right,
18+
common::list_entry_t& result, common::ValueVector& leftVector,
19+
common::ValueVector& rightVector, common::ValueVector& resultVector) {
20+
if (right.size!=left.size) {
21+
throw BinderException(stringFormat("LIST_WHERE expecting lists of same size, receiving size {} and size {}", left.size, left.size));
22+
}
23+
auto leftDataVector = common::ListVector::getDataVector(&leftVector);
24+
auto leftPos = left.offset;
25+
auto rightDataVector = common::ListVector::getDataVector(&rightVector);
26+
auto rightPos = right.offset;
27+
list_size_t resultSize=0;
28+
std::vector<bool> maskListBools;
29+
for (auto i=0u; i < right.size; i++) {
30+
auto maskBool=rightDataVector->getValue<bool>(rightPos+i);
31+
maskListBools.push_back(maskBool);
32+
if (maskBool) {
33+
resultSize++;
34+
}
35+
}
36+
result = common::ListVector::addList(&resultVector, resultSize);
37+
auto resultDataVector = common::ListVector::getDataVector(&resultVector);
38+
auto resultPos = result.offset;
39+
for (auto i=0u; i < right.size; i++) {
40+
auto maskBool=maskListBools.at(i);
41+
if (maskBool) {
42+
resultDataVector->copyFromVectorData(resultPos++, leftDataVector, leftPos+i);
43+
}
44+
}
45+
}
46+
};
47+
static std::unique_ptr<FunctionBindData> bindFunc(const ScalarBindFuncInput& input) {
48+
std::vector<LogicalType> types;
49+
types.push_back(input.arguments[0]->getDataType().copy());
50+
types.push_back(input.arguments[1]->getDataType().copy());
51+
if (types[1].getPhysicalType()!=PhysicalTypeID::LIST) {
52+
throw BinderException(ExceptionMessage::listFunctionIncompatibleChildrenType(
53+
ListIntersectFunction::name, types[0].toString(), types[1].toString()));
54+
} else {
55+
auto thisExtraTypeInfo=types[1].getExtraTypeInfo();
56+
auto thisListTypeInfo=ku_dynamic_cast<const ListTypeInfo*>(thisExtraTypeInfo);
57+
if (thisListTypeInfo->getChildType().getPhysicalType()!=PhysicalTypeID::BOOL) {
58+
throw BinderException("LIST_SELECT expecting argument type: LIST of ANY, LIST of BOOL");
59+
}
60+
}
61+
return std::make_unique<FunctionBindData>(std::move(types), types[0].copy());
62+
}
63+
64+
function_set ListWhereFunction::getFunctionSet() {
65+
function_set result;
66+
auto execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t, list_entry_t,
67+
list_entry_t, ListWhere>;
68+
auto function = std::make_unique<ScalarFunction>(name,
69+
std::vector<LogicalTypeID>{LogicalTypeID::LIST, LogicalTypeID::LIST}, LogicalTypeID::LIST,
70+
execFunc);
71+
function->bindFunc = bindFunc;
72+
result.push_back(std::move(function));
73+
return result;
74+
}
75+
76+
} // namespace function
77+
} // namespace kuzu

src/include/function/list/vector_list_functions.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,5 +248,11 @@ struct ListSelectFunction {
248248
static function_set getFunctionSet();
249249
};
250250

251+
struct ListWhereFunction {
252+
static constexpr const char* name = "LIST_WHERE";
253+
254+
static function_set getFunctionSet();
255+
};
256+
251257
} // namespace function
252258
} // namespace kuzu

0 commit comments

Comments
 (0)