1 | #include <Common/typeid_cast.h> |
---|---|
2 | #include <Parsers/ASTLiteral.h> |
3 | #include <Parsers/ASTFunction.h> |
4 | #include <Parsers/ASTExpressionList.h> |
5 | #include <Interpreters/OptimizeIfWithConstantConditionVisitor.h> |
6 | #include <IO/WriteHelpers.h> |
7 | |
8 | namespace DB |
9 | { |
10 | |
11 | namespace ErrorCodes |
12 | { |
13 | extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; |
14 | } |
15 | |
16 | static bool tryExtractConstValueFromCondition(const ASTPtr & condition, bool & value) |
17 | { |
18 | /// numeric constant in condition |
19 | if (const auto * literal = condition->as<ASTLiteral>()) |
20 | { |
21 | if (literal->value.getType() == Field::Types::Int64 || |
22 | literal->value.getType() == Field::Types::UInt64) |
23 | { |
24 | value = literal->value.get<Int64>(); |
25 | return true; |
26 | } |
27 | } |
28 | |
29 | /// cast of numeric constant in condition to UInt8 |
30 | if (const auto * function = condition->as<ASTFunction>()) |
31 | { |
32 | if (function->name == "CAST") |
33 | { |
34 | if (const auto * expr_list = function->arguments->as<ASTExpressionList>()) |
35 | { |
36 | const ASTPtr & type_ast = expr_list->children.at(1); |
37 | if (const auto * type_literal = type_ast->as<ASTLiteral>()) |
38 | { |
39 | if (type_literal->value.getType() == Field::Types::String && |
40 | type_literal->value.get<std::string>() == "UInt8") |
41 | return tryExtractConstValueFromCondition(expr_list->children.at(0), value); |
42 | } |
43 | } |
44 | } |
45 | } |
46 | |
47 | return false; |
48 | } |
49 | |
50 | void OptimizeIfWithConstantConditionVisitor::visit(ASTPtr & current_ast) |
51 | { |
52 | if (!current_ast) |
53 | return; |
54 | |
55 | for (ASTPtr & child : current_ast->children) |
56 | { |
57 | auto * function_node = child->as<ASTFunction>(); |
58 | if (!function_node || function_node->name != "if") |
59 | { |
60 | visit(child); |
61 | continue; |
62 | } |
63 | |
64 | visit(function_node->arguments); |
65 | const auto * args = function_node->arguments->as<ASTExpressionList>(); |
66 | |
67 | if (args->children.size() != 3) |
68 | throw Exception("Wrong number of arguments for function 'if' ("+ toString(args->children.size()) + " instead of 3)", |
69 | ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
70 | |
71 | ASTPtr condition_expr = args->children[0]; |
72 | ASTPtr then_expr = args->children[1]; |
73 | ASTPtr else_expr = args->children[2]; |
74 | |
75 | bool condition; |
76 | if (tryExtractConstValueFromCondition(condition_expr, condition)) |
77 | { |
78 | ASTPtr replace_ast = condition ? then_expr : else_expr; |
79 | ASTPtr child_copy = child; |
80 | String replace_alias = replace_ast->tryGetAlias(); |
81 | String if_alias = child->tryGetAlias(); |
82 | |
83 | if (replace_alias.empty()) |
84 | { |
85 | replace_ast->setAlias(if_alias); |
86 | child = replace_ast; |
87 | } |
88 | else |
89 | { |
90 | /// Only copy of one node is required here. |
91 | /// But IAST has only method for deep copy of subtree. |
92 | /// This can be a reason of performance degradation in case of deep queries. |
93 | ASTPtr replace_ast_deep_copy = replace_ast->clone(); |
94 | replace_ast_deep_copy->setAlias(if_alias); |
95 | child = replace_ast_deep_copy; |
96 | } |
97 | |
98 | if (!if_alias.empty()) |
99 | { |
100 | auto alias_it = aliases.find(if_alias); |
101 | if (alias_it != aliases.end() && alias_it->second.get() == child_copy.get()) |
102 | alias_it->second = child; |
103 | } |
104 | } |
105 | } |
106 | } |
107 | |
108 | } |
109 |