1#pragma once
2
#include <Interpreters/InDepthNodeVisitor.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <unordered_set>
#include <vector>
5
6namespace DB
7{
8
/// Error codes thrown from this header; defined elsewhere in the project.
namespace ErrorCodes
{
    /// Thrown by GetAggregatesMatcher when Data::assert_no_aggregates is set and an aggregate call is found.
    extern const int ILLEGAL_AGGREGATION;
}
13
14class GetAggregatesMatcher
15{
16public:
17 using Visitor = ConstInDepthNodeVisitor<GetAggregatesMatcher, true>;
18
19 struct Data
20 {
21 const char * assert_no_aggregates = nullptr;
22 std::unordered_set<String> uniq_names;
23 std::vector<const ASTFunction *> aggregates;
24 };
25
26 static bool needChildVisit(const ASTPtr & node, const ASTPtr & child)
27 {
28 if (child->as<ASTSubquery>() || child->as<ASTSelectQuery>())
29 return false;
30 if (auto * func = node->as<ASTFunction>())
31 if (isAggregateFunction(func->name))
32 return false;
33 return true;
34 }
35
36 static void visit(const ASTPtr & ast, Data & data)
37 {
38 if (auto * func = ast->as<ASTFunction>())
39 visit(*func, ast, data);
40 }
41
42private:
43 static void visit(const ASTFunction & node, const ASTPtr &, Data & data)
44 {
45 if (!isAggregateFunction(node.name))
46 return;
47
48 if (data.assert_no_aggregates)
49 throw Exception("Aggregate function " + node.getColumnName() + " is found " + String(data.assert_no_aggregates) + " in query",
50 ErrorCodes::ILLEGAL_AGGREGATION);
51
52 String column_name = node.getColumnName();
53 if (data.uniq_names.count(column_name))
54 return;
55
56 data.uniq_names.insert(column_name);
57 data.aggregates.push_back(&node);
58 }
59
60 static bool isAggregateFunction(const String & name)
61 {
62 return AggregateFunctionFactory::instance().isAggregateFunctionName(name);
63 }
64};
65
66using GetAggregatesVisitor = GetAggregatesMatcher::Visitor;
67
68
69inline void assertNoAggregates(const ASTPtr & ast, const char * description)
70{
71 GetAggregatesVisitor::Data data{description, {}, {}};
72 GetAggregatesVisitor(data).visit(ast);
73}
74
75}
76