1#include <Common/typeid_cast.h>
2#include <Parsers/ASTLiteral.h>
3#include <Parsers/ASTFunction.h>
4#include <Parsers/ASTExpressionList.h>
5#include <Interpreters/OptimizeIfChains.h>
6#include <IO/WriteHelpers.h>
7
8namespace DB
9{
10
namespace ErrorCodes
{
    /// Codes used by OptimizeIfChainsVisitor::ifChain when an `if` node is malformed.
    extern const int UNEXPECTED_AST_STRUCTURE;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
16
17void OptimizeIfChainsVisitor::visit(ASTPtr & current_ast)
18{
19 if (!current_ast)
20 return;
21
22 for (ASTPtr & child : current_ast->children)
23 {
24 /// Fallthrough cases
25
26 const auto * function_node = child->as<ASTFunction>();
27 if (!function_node || function_node->name != "if" || !function_node->arguments)
28 {
29 visit(child);
30 continue;
31 }
32
33 const auto * function_args = function_node->arguments->as<ASTExpressionList>();
34 if (!function_args || function_args->children.size() != 3 || !function_args->children[2])
35 {
36 visit(child);
37 continue;
38 }
39
40 const auto * else_arg = function_args->children[2]->as<ASTFunction>();
41 if (!else_arg || else_arg->name != "if")
42 {
43 visit(child);
44 continue;
45 }
46
47 /// The case of:
48 /// if(cond, a, if(...))
49
50 auto chain = ifChain(child);
51 std::reverse(chain.begin(), chain.end());
52 child->as<ASTFunction>()->name = "multiIf";
53 child->as<ASTFunction>()->arguments->children = std::move(chain);
54 }
55}
56
57ASTs OptimizeIfChainsVisitor::ifChain(const ASTPtr & child)
58{
59 const auto * function_node = child->as<ASTFunction>();
60 if (!function_node || !function_node->arguments)
61 throw Exception("Unexpected AST for function 'if'", ErrorCodes::UNEXPECTED_AST_STRUCTURE);
62
63 const auto * function_args = function_node->arguments->as<ASTExpressionList>();
64
65 if (!function_args || function_args->children.size() != 3)
66 throw Exception("Wrong number of arguments for function 'if' (" + toString(function_args->children.size()) + " instead of 3)",
67 ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
68
69 const auto * else_arg = function_args->children[2]->as<ASTFunction>();
70
71 /// Recursively collect arguments from the innermost if ("head-resursion").
72 /// Arguments will be returned in reverse order.
73
74 if (else_arg && else_arg->name == "if")
75 {
76 auto cur = ifChain(function_node->arguments->children[2]);
77 cur.push_back(function_node->arguments->children[1]);
78 cur.push_back(function_node->arguments->children[0]);
79 return cur;
80 }
81 else
82 {
83 ASTs end;
84 end.reserve(3);
85 end.push_back(function_node->arguments->children[2]);
86 end.push_back(function_node->arguments->children[1]);
87 end.push_back(function_node->arguments->children[0]);
88 return end;
89 }
90}
91
92}
93