|
| 1 | +#include "common/exception/binder.h" |
| 2 | +#include "common/exception/message.h" |
| 3 | +#include "common/type_utils.h" |
| 4 | +#include "common/types/types.h" |
| 5 | +#include "function/list/functions/list_function_utils.h" |
| 6 | +#include "function/list/functions/list_position_function.h" |
| 7 | +#include "function/list/functions/list_unique_function.h" |
| 8 | +#include "function/list/vector_list_functions.h" |
| 9 | +#include "function/scalar_function.h" |
| 10 | + |
| 11 | +using namespace kuzu::common; |
| 12 | + |
| 13 | +namespace kuzu { |
| 14 | +namespace function { |
| 15 | + |
| 16 | +struct ListWhere { |
| 17 | + static void operation(common::list_entry_t& left, common::list_entry_t& right, |
| 18 | + common::list_entry_t& result, common::ValueVector& leftVector, |
| 19 | + common::ValueVector& rightVector, common::ValueVector& resultVector) { |
| 20 | + if (right.size!=left.size) { |
| 21 | + throw BinderException(stringFormat("LIST_WHERE expecting lists of same size, receiving size {} and size {}", left.size, left.size)); |
| 22 | + } |
| 23 | + auto leftDataVector = common::ListVector::getDataVector(&leftVector); |
| 24 | + auto leftPos = left.offset; |
| 25 | + auto rightDataVector = common::ListVector::getDataVector(&rightVector); |
| 26 | + auto rightPos = right.offset; |
| 27 | + list_size_t resultSize=0; |
| 28 | + std::vector<bool> maskListBools; |
| 29 | + for (auto i=0u; i < right.size; i++) { |
| 30 | + auto maskBool=rightDataVector->getValue<bool>(rightPos+i); |
| 31 | + maskListBools.push_back(maskBool); |
| 32 | + if (maskBool) { |
| 33 | + resultSize++; |
| 34 | + } |
| 35 | + } |
| 36 | + result = common::ListVector::addList(&resultVector, resultSize); |
| 37 | + auto resultDataVector = common::ListVector::getDataVector(&resultVector); |
| 38 | + auto resultPos = result.offset; |
| 39 | + for (auto i=0u; i < right.size; i++) { |
| 40 | + auto maskBool=maskListBools.at(i); |
| 41 | + if (maskBool) { |
| 42 | + resultDataVector->copyFromVectorData(resultPos++, leftDataVector, leftPos+i); |
| 43 | + } |
| 44 | + } |
| 45 | + } |
| 46 | +}; |
| 47 | +static std::unique_ptr<FunctionBindData> bindFunc(const ScalarBindFuncInput& input) { |
| 48 | + std::vector<LogicalType> types; |
| 49 | + types.push_back(input.arguments[0]->getDataType().copy()); |
| 50 | + types.push_back(input.arguments[1]->getDataType().copy()); |
| 51 | + if (types[1].getPhysicalType()!=PhysicalTypeID::LIST) { |
| 52 | + throw BinderException(ExceptionMessage::listFunctionIncompatibleChildrenType( |
| 53 | + ListIntersectFunction::name, types[0].toString(), types[1].toString())); |
| 54 | + } else { |
| 55 | + auto thisExtraTypeInfo=types[1].getExtraTypeInfo(); |
| 56 | + auto thisListTypeInfo=ku_dynamic_cast<const ListTypeInfo*>(thisExtraTypeInfo); |
| 57 | + if (thisListTypeInfo->getChildType().getPhysicalType()!=PhysicalTypeID::BOOL) { |
| 58 | + throw BinderException("LIST_SELECT expecting argument type: LIST of ANY, LIST of BOOL"); |
| 59 | + } |
| 60 | + } |
| 61 | + return std::make_unique<FunctionBindData>(std::move(types), types[0].copy()); |
| 62 | +} |
| 63 | + |
| 64 | +function_set ListWhereFunction::getFunctionSet() { |
| 65 | + function_set result; |
| 66 | + auto execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t, list_entry_t, |
| 67 | + list_entry_t, ListWhere>; |
| 68 | + auto function = std::make_unique<ScalarFunction>(name, |
| 69 | + std::vector<LogicalTypeID>{LogicalTypeID::LIST, LogicalTypeID::LIST}, LogicalTypeID::LIST, |
| 70 | + execFunc); |
| 71 | + function->bindFunc = bindFunc; |
| 72 | + result.push_back(std::move(function)); |
| 73 | + return result; |
| 74 | +} |
| 75 | + |
| 76 | +} // namespace function |
| 77 | +} // namespace kuzu |
0 commit comments