|
| 1 | +// Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +// or more contributor license agreements. See the NOTICE file |
| 3 | +// distributed with this work for additional information |
| 4 | +// regarding copyright ownership. The ASF licenses this file |
| 5 | +// to you under the Apache License, Version 2.0 (the |
| 6 | +// "License"); you may not use this file except in compliance |
| 7 | +// with the License. You may obtain a copy of the License at |
| 8 | +// |
| 9 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +// |
| 11 | +// Unless required by applicable law or agreed to in writing, |
| 12 | +// software distributed under the License is distributed on an |
| 13 | +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +// KIND, either express or implied. See the License for the |
| 15 | +// specific language governing permissions and limitations |
| 16 | +// under the License. |
| 17 | +// This file is copied from |
| 18 | +// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionGroupArrayIntersect.cpp |
| 19 | +// and modified by Doris |
| 20 | + |
| 21 | +#include "vec/aggregate_functions/aggregate_function_group_array_intersect.h" |
| 22 | + |
| 23 | +namespace doris::vectorized { |
| 24 | + |
| 25 | +IAggregateFunction* create_with_extra_types(const DataTypePtr& nested_type, |
| 26 | + const DataTypes& argument_types) { |
| 27 | + WhichDataType which(nested_type); |
| 28 | + if (which.idx == TypeIndex::Date || which.idx == TypeIndex::DateTime) { |
| 29 | + throw Exception(ErrorCode::INVALID_ARGUMENT, |
| 30 | + "We don't support array<date> or array<datetime> for " |
| 31 | + "group_array_intersect(), please use array<datev2> or array<datetimev2>."); |
| 32 | + } else if (which.idx == TypeIndex::DateV2) { |
| 33 | + return new AggregateFunctionGroupArrayIntersect<DateV2>(argument_types); |
| 34 | + } else if (which.idx == TypeIndex::DateTimeV2) { |
| 35 | + return new AggregateFunctionGroupArrayIntersect<DateTimeV2>(argument_types); |
| 36 | + } else { |
| 37 | + /// Check that we can use plain version of AggregateFunctionGroupArrayIntersectGeneric |
| 38 | + if (nested_type->is_value_unambiguously_represented_in_contiguous_memory_region()) |
| 39 | + return new AggregateFunctionGroupArrayIntersectGeneric<true>(argument_types); |
| 40 | + else |
| 41 | + return new AggregateFunctionGroupArrayIntersectGeneric<false>(argument_types); |
| 42 | + } |
| 43 | +} |
| 44 | + |
| 45 | +inline AggregateFunctionPtr create_aggregate_function_group_array_intersect_impl( |
| 46 | + const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { |
| 47 | + const auto& nested_type = remove_nullable( |
| 48 | + dynamic_cast<const DataTypeArray&>(*(argument_types[0])).get_nested_type()); |
| 49 | + AggregateFunctionPtr res = nullptr; |
| 50 | + |
| 51 | + WhichDataType which(nested_type); |
| 52 | +#define DISPATCH(TYPE) \ |
| 53 | + if (which.idx == TypeIndex::TYPE) \ |
| 54 | + res = creator_without_type::create<AggregateFunctionGroupArrayIntersect<TYPE>>( \ |
| 55 | + argument_types, result_is_nullable); |
| 56 | + FOR_NUMERIC_TYPES(DISPATCH) |
| 57 | +#undef DISPATCH |
| 58 | + |
| 59 | + if (!res) { |
| 60 | + res = AggregateFunctionPtr(create_with_extra_types(nested_type, argument_types)); |
| 61 | + } |
| 62 | + |
| 63 | + if (!res) { |
| 64 | + throw Exception(ErrorCode::INVALID_ARGUMENT, |
| 65 | + "Illegal type {} of argument for aggregate function {}", |
| 66 | + argument_types[0]->get_name(), name); |
| 67 | + } |
| 68 | + |
| 69 | + return res; |
| 70 | +} |
| 71 | + |
| 72 | +AggregateFunctionPtr create_aggregate_function_group_array_intersect( |
| 73 | + const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { |
| 74 | + assert_unary(name, argument_types); |
| 75 | + const DataTypePtr& argument_type = remove_nullable(argument_types[0]); |
| 76 | + |
| 77 | + if (!WhichDataType(argument_type).is_array()) |
| 78 | + throw Exception(ErrorCode::INVALID_ARGUMENT, |
| 79 | + "Aggregate function groupArrayIntersect accepts only array type argument. " |
| 80 | + "Provided argument type: " + |
| 81 | + argument_type->get_name()); |
| 82 | + return create_aggregate_function_group_array_intersect_impl(name, {argument_type}, |
| 83 | + result_is_nullable); |
| 84 | +} |
| 85 | + |
| 86 | +void register_aggregate_function_group_array_intersect(AggregateFunctionSimpleFactory& factory) { |
| 87 | + factory.register_function_both("group_array_intersect", |
| 88 | + create_aggregate_function_group_array_intersect); |
| 89 | +} |
| 90 | +} // namespace doris::vectorized |
0 commit comments