Skip to content

Commit 8d773a7

Browse files
[feature](agg) support aggregate function group_array_intersect (#33265)
1 parent 9e72044 commit 8d773a7

File tree

12 files changed

+1115
-2
lines changed

12 files changed

+1115
-2
lines changed
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
// This file is copied from
18+
// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionGroupArrayIntersect.cpp
19+
// and modified by Doris
20+
21+
#include "vec/aggregate_functions/aggregate_function_group_array_intersect.h"
22+
23+
namespace doris::vectorized {
24+
25+
IAggregateFunction* create_with_extra_types(const DataTypePtr& nested_type,
26+
const DataTypes& argument_types) {
27+
WhichDataType which(nested_type);
28+
if (which.idx == TypeIndex::Date || which.idx == TypeIndex::DateTime) {
29+
throw Exception(ErrorCode::INVALID_ARGUMENT,
30+
"We don't support array<date> or array<datetime> for "
31+
"group_array_intersect(), please use array<datev2> or array<datetimev2>.");
32+
} else if (which.idx == TypeIndex::DateV2) {
33+
return new AggregateFunctionGroupArrayIntersect<DateV2>(argument_types);
34+
} else if (which.idx == TypeIndex::DateTimeV2) {
35+
return new AggregateFunctionGroupArrayIntersect<DateTimeV2>(argument_types);
36+
} else {
37+
/// Check that we can use plain version of AggregateFunctionGroupArrayIntersectGeneric
38+
if (nested_type->is_value_unambiguously_represented_in_contiguous_memory_region())
39+
return new AggregateFunctionGroupArrayIntersectGeneric<true>(argument_types);
40+
else
41+
return new AggregateFunctionGroupArrayIntersectGeneric<false>(argument_types);
42+
}
43+
}
44+
45+
inline AggregateFunctionPtr create_aggregate_function_group_array_intersect_impl(
46+
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) {
47+
const auto& nested_type = remove_nullable(
48+
dynamic_cast<const DataTypeArray&>(*(argument_types[0])).get_nested_type());
49+
AggregateFunctionPtr res = nullptr;
50+
51+
WhichDataType which(nested_type);
52+
#define DISPATCH(TYPE) \
53+
if (which.idx == TypeIndex::TYPE) \
54+
res = creator_without_type::create<AggregateFunctionGroupArrayIntersect<TYPE>>( \
55+
argument_types, result_is_nullable);
56+
FOR_NUMERIC_TYPES(DISPATCH)
57+
#undef DISPATCH
58+
59+
if (!res) {
60+
res = AggregateFunctionPtr(create_with_extra_types(nested_type, argument_types));
61+
}
62+
63+
if (!res) {
64+
throw Exception(ErrorCode::INVALID_ARGUMENT,
65+
"Illegal type {} of argument for aggregate function {}",
66+
argument_types[0]->get_name(), name);
67+
}
68+
69+
return res;
70+
}
71+
72+
AggregateFunctionPtr create_aggregate_function_group_array_intersect(
73+
const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) {
74+
assert_unary(name, argument_types);
75+
const DataTypePtr& argument_type = remove_nullable(argument_types[0]);
76+
77+
if (!WhichDataType(argument_type).is_array())
78+
throw Exception(ErrorCode::INVALID_ARGUMENT,
79+
"Aggregate function groupArrayIntersect accepts only array type argument. "
80+
"Provided argument type: " +
81+
argument_type->get_name());
82+
return create_aggregate_function_group_array_intersect_impl(name, {argument_type},
83+
result_is_nullable);
84+
}
85+
86+
void register_aggregate_function_group_array_intersect(AggregateFunctionSimpleFactory& factory) {
87+
factory.register_function_both("group_array_intersect",
88+
create_aggregate_function_group_array_intersect);
89+
}
90+
} // namespace doris::vectorized

0 commit comments

Comments
 (0)