1#include <Interpreters/InJoinSubqueriesPreprocessor.h>
2#include <Interpreters/Context.h>
3#include <Interpreters/DatabaseAndTableWithAlias.h>
4#include <Interpreters/IdentifierSemantic.h>
5#include <Interpreters/InDepthNodeVisitor.h>
6#include <Storages/StorageDistributed.h>
7#include <Parsers/ASTIdentifier.h>
8#include <Parsers/ASTSelectQuery.h>
9#include <Parsers/ASTTablesInSelectQuery.h>
10#include <Parsers/ASTFunction.h>
11#include <Common/typeid_cast.h>
12
13
14namespace DB
15{
16
/// Error codes thrown from this file; they are defined elsewhere in the project.
namespace ErrorCodes
{
    /// Internal invariant violations (unexpected AST shape or setting value).
    extern const int LOGICAL_ERROR;
    /// The query needs a double-distributed IN/JOIN rewrite that the current settings forbid.
    extern const int DISTRIBUTED_IN_JOIN_SUBQUERY_DENIED;
}
22
23
24namespace
25{
26
27StoragePtr tryGetTable(const ASTPtr & database_and_table, const Context & context)
28{
29 DatabaseAndTableWithAlias db_and_table(database_and_table);
30 return context.tryGetTable(db_and_table.database, db_and_table.table);
31}
32
33using CheckShardsAndTables = InJoinSubqueriesPreprocessor::CheckShardsAndTables;
34
35struct NonGlobalTableData
36{
37 using TypeToVisit = ASTTableExpression;
38
39 const CheckShardsAndTables & checker;
40 const Context & context;
41 ASTFunction * function = nullptr;
42 ASTTableJoin * table_join = nullptr;
43
44 void visit(ASTTableExpression & node, ASTPtr &)
45 {
46 ASTPtr & database_and_table = node.database_and_table_name;
47 if (database_and_table)
48 renameIfNeeded(database_and_table);
49 }
50
51private:
52 void renameIfNeeded(ASTPtr & database_and_table)
53 {
54 const SettingDistributedProductMode distributed_product_mode = context.getSettingsRef().distributed_product_mode;
55
56 StoragePtr storage = tryGetTable(database_and_table, context);
57 if (!storage || !checker.hasAtLeastTwoShards(*storage))
58 return;
59
60 if (distributed_product_mode == DistributedProductMode::DENY)
61 {
62 throw Exception("Double-distributed IN/JOIN subqueries is denied (distributed_product_mode = 'deny')."
63 " You may rewrite query to use local tables in subqueries, or use GLOBAL keyword, or set distributed_product_mode to suitable value.",
64 ErrorCodes::DISTRIBUTED_IN_JOIN_SUBQUERY_DENIED);
65 }
66 else if (distributed_product_mode == DistributedProductMode::GLOBAL)
67 {
68 if (function)
69 {
70 auto * concrete = function->as<ASTFunction>();
71
72 if (concrete->name == "in")
73 concrete->name = "globalIn";
74 else if (concrete->name == "notIn")
75 concrete->name = "globalNotIn";
76 else if (concrete->name == "globalIn" || concrete->name == "globalNotIn")
77 {
78 /// Already processed.
79 }
80 else
81 throw Exception("Logical error: unexpected function name " + concrete->name, ErrorCodes::LOGICAL_ERROR);
82 }
83 else if (table_join)
84 table_join->locality = ASTTableJoin::Locality::Global;
85 else
86 throw Exception("Logical error: unexpected AST node", ErrorCodes::LOGICAL_ERROR);
87 }
88 else if (distributed_product_mode == DistributedProductMode::LOCAL)
89 {
90 /// Convert distributed table to corresponding remote table.
91
92 std::string database;
93 std::string table;
94 std::tie(database, table) = checker.getRemoteDatabaseAndTableName(*storage);
95
96 String alias = database_and_table->tryGetAlias();
97 if (alias.empty())
98 throw Exception("Distributed table should have an alias when distributed_product_mode set to local.",
99 ErrorCodes::DISTRIBUTED_IN_JOIN_SUBQUERY_DENIED);
100
101 database_and_table = createTableIdentifier(database, table);
102 database_and_table->setAlias(alias);
103 }
104 else
105 throw Exception("InJoinSubqueriesPreprocessor: unexpected value of 'distributed_product_mode' setting",
106 ErrorCodes::LOGICAL_ERROR);
107 }
108};
109
110using NonGlobalTableMatcher = OneTypeMatcher<NonGlobalTableData>;
111using NonGlobalTableVisitor = InDepthNodeVisitor<NonGlobalTableMatcher, true>;
112
113
114class NonGlobalSubqueryMatcher
115{
116public:
117 struct Data
118 {
119 const CheckShardsAndTables & checker;
120 const Context & context;
121 };
122
123 static void visit(ASTPtr & node, Data & data)
124 {
125 if (auto * function = node->as<ASTFunction>())
126 visit(*function, node, data);
127 if (const auto * tables = node->as<ASTTablesInSelectQueryElement>())
128 visit(*tables, node, data);
129 }
130
131 static bool needChildVisit(ASTPtr & node, const ASTPtr & child)
132 {
133 if (auto * function = node->as<ASTFunction>())
134 if (function->name == "in" || function->name == "notIn")
135 return false; /// Processed, process others
136
137 if (const auto * t = node->as<ASTTablesInSelectQueryElement>())
138 if (t->table_join && t->table_expression)
139 return false; /// Processed, process others
140
141 /// Descent into all children, but not into subqueries of other kind (scalar subqueries), that are irrelevant to us.
142 if (child->as<ASTSelectQuery>())
143 return false;
144 return true;
145 }
146
147private:
148 static void visit(ASTFunction & node, ASTPtr &, Data & data)
149 {
150 if (node.name == "in" || node.name == "notIn")
151 {
152 auto & subquery = node.arguments->children.at(1);
153 NonGlobalTableVisitor::Data table_data{data.checker, data.context, &node, nullptr};
154 NonGlobalTableVisitor(table_data).visit(subquery);
155 }
156 }
157
158 static void visit(const ASTTablesInSelectQueryElement & node, ASTPtr &, Data & data)
159 {
160 if (!node.table_join || !node.table_expression)
161 return;
162
163 ASTTableJoin * table_join = node.table_join->as<ASTTableJoin>();
164 if (table_join->locality != ASTTableJoin::Locality::Global)
165 {
166 if (auto & subquery = node.table_expression->as<ASTTableExpression>()->subquery)
167 {
168 NonGlobalTableVisitor::Data table_data{data.checker, data.context, nullptr, table_join};
169 NonGlobalTableVisitor(table_data).visit(subquery);
170 }
171 }
172 }
173};
174
175using NonGlobalSubqueryVisitor = InDepthNodeVisitor<NonGlobalSubqueryMatcher, true>;
176
177}
178
179
180void InJoinSubqueriesPreprocessor::visit(ASTPtr & ast) const
181{
182 if (!ast)
183 return;
184
185 ASTSelectQuery * query = ast->as<ASTSelectQuery>();
186 if (!query || !query->tables())
187 return;
188
189 if (context.getSettingsRef().distributed_product_mode == DistributedProductMode::ALLOW)
190 return;
191
192 const auto & tables_in_select_query = query->tables()->as<ASTTablesInSelectQuery &>();
193 if (tables_in_select_query.children.empty())
194 return;
195
196 const auto & tables_element = tables_in_select_query.children[0]->as<ASTTablesInSelectQueryElement &>();
197 if (!tables_element.table_expression)
198 return;
199
200 const auto * table_expression = tables_element.table_expression->as<ASTTableExpression>();
201
202 /// If not ordinary table, skip it.
203 if (!table_expression->database_and_table_name)
204 return;
205
206 /// If not really distributed table, skip it.
207 {
208 StoragePtr storage = tryGetTable(table_expression->database_and_table_name, context);
209 if (!storage || !checker->hasAtLeastTwoShards(*storage))
210 return;
211 }
212
213 NonGlobalSubqueryVisitor::Data visitor_data{*checker, context};
214 NonGlobalSubqueryVisitor(visitor_data).visit(ast);
215}
216
217
218bool InJoinSubqueriesPreprocessor::CheckShardsAndTables::hasAtLeastTwoShards(const IStorage & table) const
219{
220 const StorageDistributed * distributed = dynamic_cast<const StorageDistributed *>(&table);
221 if (!distributed)
222 return false;
223
224 return distributed->getShardCount() >= 2;
225}
226
227
228std::pair<std::string, std::string>
229InJoinSubqueriesPreprocessor::CheckShardsAndTables::getRemoteDatabaseAndTableName(const IStorage & table) const
230{
231 const StorageDistributed & distributed = dynamic_cast<const StorageDistributed &>(table);
232 return { distributed.getRemoteDatabaseName(), distributed.getRemoteTableName() };
233}
234
235
236}
237