1#include <Parsers/ASTFunction.h>
2#include <Parsers/ASTLiteral.h>
3#include <Parsers/ASTSubquery.h>
4#include <Parsers/ASTSelectQuery.h>
5#include <Parsers/ASTTablesInSelectQuery.h>
6#include <Parsers/ASTExpressionList.h>
7
8#include <Interpreters/Context.h>
9#include <Interpreters/misc.h>
10#include <Interpreters/InterpreterSelectWithUnionQuery.h>
11#include <Interpreters/ExecuteScalarSubqueriesVisitor.h>
12#include <Interpreters/addTypeConversionToAST.h>
13
14#include <DataStreams/IBlockInputStream.h>
15#include <DataStreams/materializeBlock.h>
16#include <DataTypes/DataTypeAggregateFunction.h>
17#include <DataTypes/DataTypeTuple.h>
18
19#include <Columns/ColumnTuple.h>
20
21namespace DB
22{
23
24namespace ErrorCodes
25{
26 extern const int INCORRECT_RESULT_OF_SCALAR_SUBQUERY;
27 extern const int TOO_MANY_ROWS;
28}
29
30
31bool ExecuteScalarSubqueriesMatcher::needChildVisit(ASTPtr & node, const ASTPtr & child)
32{
33 /// Processed
34 if (node->as<ASTSubquery>() || node->as<ASTFunction>())
35 return false;
36
37 /// Don't descend into subqueries in FROM section
38 if (node->as<ASTTableExpression>())
39 return false;
40
41 if (node->as<ASTSelectQuery>())
42 {
43 /// Do not go to FROM, JOIN, UNION.
44 if (child->as<ASTTableExpression>() || child->as<ASTSelectQuery>())
45 return false;
46 }
47
48 return true;
49}
50
51void ExecuteScalarSubqueriesMatcher::visit(ASTPtr & ast, Data & data)
52{
53 if (const auto * t = ast->as<ASTSubquery>())
54 visit(*t, ast, data);
55 if (const auto * t = ast->as<ASTFunction>())
56 visit(*t, ast, data);
57}
58
59/// Converting to literal values might take a fair amount of overhead when the value is large, (e.g.
60/// Array, BitMap, etc.), This conversion is required for constant folding, index lookup, branch
61/// elimination. However, these optimizations should never be related to large values, thus we
62/// blacklist them here.
63static bool worthConvertingToLiteral(const Block & scalar)
64{
65 auto scalar_type_name = scalar.safeGetByPosition(0).type->getFamilyName();
66 std::set<String> useless_literal_types = {"Array", "Tuple", "AggregateFunction", "Function", "Set", "LowCardinality"};
67 return !useless_literal_types.count(scalar_type_name);
68}
69
70void ExecuteScalarSubqueriesMatcher::visit(const ASTSubquery & subquery, ASTPtr & ast, Data & data)
71{
72 auto hash = subquery.getTreeHash();
73 auto scalar_query_hash_str = toString(hash.first) + "_" + toString(hash.second);
74
75 Block scalar;
76 if (data.context.hasQueryContext() && data.context.getQueryContext().hasScalar(scalar_query_hash_str))
77 scalar = data.context.getQueryContext().getScalar(scalar_query_hash_str);
78 else if (data.scalars.count(scalar_query_hash_str))
79 scalar = data.scalars[scalar_query_hash_str];
80 else
81 {
82 Context subquery_context = data.context;
83 Settings subquery_settings = data.context.getSettings();
84 subquery_settings.max_result_rows = 1;
85 subquery_settings.extremes = 0;
86 subquery_context.setSettings(subquery_settings);
87
88 ASTPtr subquery_select = subquery.children.at(0);
89 BlockIO res = InterpreterSelectWithUnionQuery(
90 subquery_select, subquery_context, SelectQueryOptions(QueryProcessingStage::Complete, data.subquery_depth + 1)).execute();
91
92 Block block;
93 try
94 {
95 block = res.in->read();
96
97 if (!block)
98 {
99 /// Interpret subquery with empty result as Null literal
100 auto ast_new = std::make_unique<ASTLiteral>(Null());
101 ast_new->setAlias(ast->tryGetAlias());
102 ast = std::move(ast_new);
103 return;
104 }
105
106 if (block.rows() != 1 || res.in->read())
107 throw Exception("Scalar subquery returned more than one row", ErrorCodes::INCORRECT_RESULT_OF_SCALAR_SUBQUERY);
108 }
109 catch (const Exception & e)
110 {
111 if (e.code() == ErrorCodes::TOO_MANY_ROWS)
112 throw Exception("Scalar subquery returned more than one row", ErrorCodes::INCORRECT_RESULT_OF_SCALAR_SUBQUERY);
113 else
114 throw;
115 }
116
117 block = materializeBlock(block);
118 size_t columns = block.columns();
119
120 if (columns == 1)
121 scalar = block;
122 else
123 {
124
125 ColumnWithTypeAndName ctn;
126 ctn.type = std::make_shared<DataTypeTuple>(block.getDataTypes());
127 ctn.column = ColumnTuple::create(block.getColumns());
128 scalar.insert(ctn);
129 }
130 }
131
132 const Settings & settings = data.context.getSettingsRef();
133
134 // Always convert to literals when there is no query context.
135 if (!settings.enable_scalar_subquery_optimization || worthConvertingToLiteral(scalar) || !data.context.hasQueryContext())
136 {
137 auto lit = std::make_unique<ASTLiteral>((*scalar.safeGetByPosition(0).column)[0]);
138 lit->alias = subquery.alias;
139 lit->prefer_alias_to_column_name = subquery.prefer_alias_to_column_name;
140 ast = addTypeConversionToAST(std::move(lit), scalar.safeGetByPosition(0).type->getName());
141 }
142 else
143 {
144 auto func = makeASTFunction("__getScalar", std::make_shared<ASTLiteral>(scalar_query_hash_str));
145 func->alias = subquery.alias;
146 func->prefer_alias_to_column_name = subquery.prefer_alias_to_column_name;
147 ast = std::move(func);
148 }
149
150 data.scalars[scalar_query_hash_str] = std::move(scalar);
151}
152
153void ExecuteScalarSubqueriesMatcher::visit(const ASTFunction & func, ASTPtr & ast, Data & data)
154{
155 /// Don't descend into subqueries in arguments of IN operator.
156 /// But if an argument is not subquery, than deeper may be scalar subqueries and we need to descend in them.
157
158 std::vector<ASTPtr *> out;
159 if (functionIsInOrGlobalInOperator(func.name))
160 {
161 for (auto & child : ast->children)
162 {
163 if (child != func.arguments)
164 out.push_back(&child);
165 else
166 for (size_t i = 0, size = func.arguments->children.size(); i < size; ++i)
167 if (i != 1 || !func.arguments->children[i]->as<ASTSubquery>())
168 out.push_back(&func.arguments->children[i]);
169 }
170 }
171 else
172 for (auto & child : ast->children)
173 out.push_back(&child);
174
175 for (ASTPtr * add_node : out)
176 Visitor(data).visit(*add_node);
177}
178
179}
180