| 1 | #include <Common/typeid_cast.h> | 
|---|
| 2 | #include <Functions/FunctionsComparison.h> | 
|---|
| 3 | #include <Functions/FunctionsLogical.h> | 
|---|
| 4 | #include <Interpreters/CrossToInnerJoinVisitor.h> | 
|---|
| 5 | #include <Interpreters/DatabaseAndTableWithAlias.h> | 
|---|
| 6 | #include <Interpreters/IdentifierSemantic.h> | 
|---|
| 7 | #include <Interpreters/misc.h> | 
|---|
| 8 | #include <Parsers/ASTSelectQuery.h> | 
|---|
| 9 | #include <Parsers/ASTTablesInSelectQuery.h> | 
|---|
| 10 | #include <Parsers/ASTIdentifier.h> | 
|---|
| 11 | #include <Parsers/ASTFunction.h> | 
|---|
| 12 | #include <Parsers/ASTExpressionList.h> | 
|---|
| 13 | #include <Parsers/ParserTablesInSelectQuery.h> | 
|---|
| 14 | #include <Parsers/ExpressionListParsers.h> | 
|---|
| 15 | #include <Parsers/parseQuery.h> | 
|---|
| 16 | #include <IO/WriteHelpers.h> | 
|---|
| 17 |  | 
|---|
| 18 | namespace DB | 
|---|
| 19 | { | 
|---|
| 20 |  | 
|---|
| 21 | namespace ErrorCodes | 
|---|
| 22 | { | 
|---|
| 23 | extern const int LOGICAL_ERROR; | 
|---|
| 24 | extern const int NOT_IMPLEMENTED; | 
|---|
| 25 | } | 
|---|
| 26 |  | 
|---|
| 27 | namespace | 
|---|
| 28 | { | 
|---|
| 29 |  | 
|---|
| 30 | struct JoinedTable | 
|---|
| 31 | { | 
|---|
| 32 | DatabaseAndTableWithAlias table; | 
|---|
| 33 | ASTTablesInSelectQueryElement * element = nullptr; | 
|---|
| 34 | ASTTableJoin * join = nullptr; | 
|---|
| 35 | ASTPtr array_join = nullptr; | 
|---|
| 36 | bool has_using = false; | 
|---|
| 37 |  | 
|---|
| 38 | JoinedTable(ASTPtr table_element) | 
|---|
| 39 | { | 
|---|
| 40 | element = table_element->as<ASTTablesInSelectQueryElement>(); | 
|---|
| 41 | if (!element) | 
|---|
| 42 | throw Exception( "Logical error: TablesInSelectQueryElement expected", ErrorCodes::LOGICAL_ERROR); | 
|---|
| 43 |  | 
|---|
| 44 | if (element->table_join) | 
|---|
| 45 | { | 
|---|
| 46 | join = element->table_join->as<ASTTableJoin>(); | 
|---|
| 47 | if (join->kind == ASTTableJoin::Kind::Cross || | 
|---|
| 48 | join->kind == ASTTableJoin::Kind::Comma) | 
|---|
| 49 | { | 
|---|
| 50 | if (!join->children.empty()) | 
|---|
| 51 | throw Exception( "Logical error: CROSS JOIN has expressions", ErrorCodes::LOGICAL_ERROR); | 
|---|
| 52 | } | 
|---|
| 53 |  | 
|---|
| 54 | if (join->using_expression_list) | 
|---|
| 55 | has_using = true; | 
|---|
| 56 | } | 
|---|
| 57 |  | 
|---|
| 58 | if (element->table_expression) | 
|---|
| 59 | { | 
|---|
| 60 | const auto & expr = element->table_expression->as<ASTTableExpression &>(); | 
|---|
| 61 | table = DatabaseAndTableWithAlias(expr); | 
|---|
| 62 | } | 
|---|
| 63 |  | 
|---|
| 64 | array_join = element->array_join; | 
|---|
| 65 | } | 
|---|
| 66 |  | 
|---|
| 67 | void rewriteCommaToCross() | 
|---|
| 68 | { | 
|---|
| 69 | if (join) | 
|---|
| 70 | join->kind = ASTTableJoin::Kind::Cross; | 
|---|
| 71 | } | 
|---|
| 72 |  | 
|---|
| 73 | bool canAttachOnExpression() const { return join && !join->on_expression; } | 
|---|
| 74 | }; | 
|---|
| 75 |  | 
|---|
| 76 | bool isComparison(const String & name) | 
|---|
| 77 | { | 
|---|
| 78 | return name == NameEquals::name || | 
|---|
| 79 | name == NameNotEquals::name || | 
|---|
| 80 | name == NameLess::name || | 
|---|
| 81 | name == NameGreater::name || | 
|---|
| 82 | name == NameLessOrEquals::name || | 
|---|
| 83 | name == NameGreaterOrEquals::name; | 
|---|
| 84 | } | 
|---|
| 85 |  | 
|---|
| 86 | /// It checks if where expression could be moved to JOIN ON expression partially or entirely. | 
|---|
| 87 | class CheckExpressionVisitorData | 
|---|
| 88 | { | 
|---|
| 89 | public: | 
|---|
| 90 | using TypeToVisit = const ASTFunction; | 
|---|
| 91 |  | 
|---|
| 92 | CheckExpressionVisitorData(const std::vector<JoinedTable> & tables_) | 
|---|
| 93 | : joined_tables(tables_) | 
|---|
| 94 | , ands_only(true) | 
|---|
| 95 | { | 
|---|
| 96 | for (auto & joined : joined_tables) | 
|---|
| 97 | tables.push_back(joined.table); | 
|---|
| 98 | } | 
|---|
| 99 |  | 
|---|
| 100 | void visit(const ASTFunction & node, const ASTPtr & ast) | 
|---|
| 101 | { | 
|---|
| 102 | if (!ands_only) | 
|---|
| 103 | return; | 
|---|
| 104 |  | 
|---|
| 105 | if (node.name == NameAnd::name) | 
|---|
| 106 | { | 
|---|
| 107 | if (!node.arguments || node.arguments->children.empty()) | 
|---|
| 108 | throw Exception( "Logical error: function requires argiment", ErrorCodes::LOGICAL_ERROR); | 
|---|
| 109 |  | 
|---|
| 110 | for (auto & child : node.arguments->children) | 
|---|
| 111 | { | 
|---|
| 112 | if (const auto * func = child->as<ASTFunction>()) | 
|---|
| 113 | visit(*func, child); | 
|---|
| 114 | else | 
|---|
| 115 | ands_only = false; | 
|---|
| 116 | } | 
|---|
| 117 | } | 
|---|
| 118 | else if (node.name == NameEquals::name) | 
|---|
| 119 | { | 
|---|
| 120 | if (size_t min_table = canMoveEqualsToJoinOn(node)) | 
|---|
| 121 | asts_to_join_on[min_table].push_back(ast); | 
|---|
| 122 | } | 
|---|
| 123 | else if (isComparison(node.name)) | 
|---|
| 124 | { | 
|---|
| 125 | /// leave other comparisons as is | 
|---|
| 126 | } | 
|---|
| 127 | else if (functionIsInOperator(node.name)) /// IN, NOT IN | 
|---|
| 128 | { | 
|---|
| 129 | if (auto ident = node.arguments->children.at(0)->as<ASTIdentifier>()) | 
|---|
| 130 | if (size_t min_table = checkIdentifier(*ident)) | 
|---|
| 131 | asts_to_join_on[min_table].push_back(ast); | 
|---|
| 132 | } | 
|---|
| 133 | else | 
|---|
| 134 | { | 
|---|
| 135 | ands_only = false; | 
|---|
| 136 | asts_to_join_on.clear(); | 
|---|
| 137 | } | 
|---|
| 138 | } | 
|---|
| 139 |  | 
|---|
| 140 | bool complex() const { return !ands_only; } | 
|---|
| 141 | bool matchAny(size_t t) const { return asts_to_join_on.count(t); } | 
|---|
| 142 |  | 
|---|
| 143 | ASTPtr makeOnExpression(size_t table_pos) | 
|---|
| 144 | { | 
|---|
| 145 | if (!asts_to_join_on.count(table_pos)) | 
|---|
| 146 | return {}; | 
|---|
| 147 |  | 
|---|
| 148 | std::vector<ASTPtr> & expressions = asts_to_join_on[table_pos]; | 
|---|
| 149 |  | 
|---|
| 150 | if (expressions.size() == 1) | 
|---|
| 151 | return expressions[0]->clone(); | 
|---|
| 152 |  | 
|---|
| 153 | std::vector<ASTPtr> arguments; | 
|---|
| 154 | arguments.reserve(expressions.size()); | 
|---|
| 155 | for (auto & ast : expressions) | 
|---|
| 156 | arguments.emplace_back(ast->clone()); | 
|---|
| 157 |  | 
|---|
| 158 | return makeASTFunction(NameAnd::name, std::move(arguments)); | 
|---|
| 159 | } | 
|---|
| 160 |  | 
|---|
| 161 | private: | 
|---|
| 162 | const std::vector<JoinedTable> & joined_tables; | 
|---|
| 163 | std::vector<DatabaseAndTableWithAlias> tables; | 
|---|
| 164 | std::map<size_t, std::vector<ASTPtr>> asts_to_join_on; | 
|---|
| 165 | bool ands_only; | 
|---|
| 166 |  | 
|---|
| 167 | size_t canMoveEqualsToJoinOn(const ASTFunction & node) | 
|---|
| 168 | { | 
|---|
| 169 | if (!node.arguments) | 
|---|
| 170 | throw Exception( "Logical error: function requires arguments", ErrorCodes::LOGICAL_ERROR); | 
|---|
| 171 | if (node.arguments->children.size() != 2) | 
|---|
| 172 | return false; | 
|---|
| 173 |  | 
|---|
| 174 | const auto * left = node.arguments->children[0]->as<ASTIdentifier>(); | 
|---|
| 175 | const auto * right = node.arguments->children[1]->as<ASTIdentifier>(); | 
|---|
| 176 | if (!left || !right) | 
|---|
| 177 | return false; | 
|---|
| 178 |  | 
|---|
| 179 | return checkIdentifiers(*left, *right); | 
|---|
| 180 | } | 
|---|
| 181 |  | 
|---|
| 182 | /// Check if the identifiers are from different joined tables. If it's a self joint, tables should have aliases. | 
|---|
| 183 | /// select * from t1 a cross join t2 b where a.x = b.x | 
|---|
| 184 | /// @return table position to attach expression to or 0. | 
|---|
| 185 | size_t checkIdentifiers(const ASTIdentifier & left, const ASTIdentifier & right) | 
|---|
| 186 | { | 
|---|
| 187 | size_t left_table_pos = 0; | 
|---|
| 188 | bool left_match = IdentifierSemantic::chooseTable(left, tables, left_table_pos); | 
|---|
| 189 |  | 
|---|
| 190 | size_t right_table_pos = 0; | 
|---|
| 191 | bool right_match = IdentifierSemantic::chooseTable(right, tables, right_table_pos); | 
|---|
| 192 |  | 
|---|
| 193 | if (left_match && right_match && (left_table_pos != right_table_pos)) | 
|---|
| 194 | { | 
|---|
| 195 | size_t table_pos = std::max(left_table_pos, right_table_pos); | 
|---|
| 196 | if (joined_tables[table_pos].canAttachOnExpression()) | 
|---|
| 197 | return table_pos; | 
|---|
| 198 | } | 
|---|
| 199 | return 0; | 
|---|
| 200 | } | 
|---|
| 201 |  | 
|---|
| 202 | size_t checkIdentifier(const ASTIdentifier & identifier) | 
|---|
| 203 | { | 
|---|
| 204 | size_t best_table_pos = 0; | 
|---|
| 205 | bool match = IdentifierSemantic::chooseTable(identifier, tables, best_table_pos); | 
|---|
| 206 |  | 
|---|
| 207 | if (match && joined_tables[best_table_pos].canAttachOnExpression()) | 
|---|
| 208 | return best_table_pos; | 
|---|
| 209 | return 0; | 
|---|
| 210 | } | 
|---|
| 211 | }; | 
|---|
| 212 |  | 
|---|
| 213 | using CheckExpressionMatcher = ConstOneTypeMatcher<CheckExpressionVisitorData, false>; | 
|---|
| 214 | using CheckExpressionVisitor = ConstInDepthNodeVisitor<CheckExpressionMatcher, true>; | 
|---|
| 215 |  | 
|---|
| 216 |  | 
|---|
| 217 | bool getTables(ASTSelectQuery & select, std::vector<JoinedTable> & joined_tables, size_t & num_comma) | 
|---|
| 218 | { | 
|---|
| 219 | if (!select.tables()) | 
|---|
| 220 | return false; | 
|---|
| 221 |  | 
|---|
| 222 | const auto * tables = select.tables()->as<ASTTablesInSelectQuery>(); | 
|---|
| 223 | if (!tables) | 
|---|
| 224 | return false; | 
|---|
| 225 |  | 
|---|
| 226 | size_t num_tables = tables->children.size(); | 
|---|
| 227 | if (num_tables < 2) | 
|---|
| 228 | return false; | 
|---|
| 229 |  | 
|---|
| 230 | joined_tables.reserve(num_tables); | 
|---|
| 231 | size_t num_array_join = 0; | 
|---|
| 232 | size_t num_using = 0; | 
|---|
| 233 |  | 
|---|
| 234 | for (auto & child : tables->children) | 
|---|
| 235 | { | 
|---|
| 236 | joined_tables.emplace_back(JoinedTable(child)); | 
|---|
| 237 | JoinedTable & t = joined_tables.back(); | 
|---|
| 238 | if (t.array_join) | 
|---|
| 239 | { | 
|---|
| 240 | ++num_array_join; | 
|---|
| 241 | continue; | 
|---|
| 242 | } | 
|---|
| 243 |  | 
|---|
| 244 | if (t.has_using) | 
|---|
| 245 | { | 
|---|
| 246 | ++num_using; | 
|---|
| 247 | continue; | 
|---|
| 248 | } | 
|---|
| 249 |  | 
|---|
| 250 | if (auto * join = t.join) | 
|---|
| 251 | if (join->kind == ASTTableJoin::Kind::Comma) | 
|---|
| 252 | ++num_comma; | 
|---|
| 253 | } | 
|---|
| 254 |  | 
|---|
| 255 | if (num_using && (num_tables - num_array_join) > 2) | 
|---|
| 256 | throw Exception( "Multiple CROSS/COMMA JOIN do not support USING", ErrorCodes::NOT_IMPLEMENTED); | 
|---|
| 257 |  | 
|---|
| 258 | if (num_comma && (num_comma != (joined_tables.size() - 1))) | 
|---|
| 259 | throw Exception( "Mix of COMMA and other JOINS is not supported", ErrorCodes::NOT_IMPLEMENTED); | 
|---|
| 260 |  | 
|---|
| 261 | if (num_array_join || num_using) | 
|---|
| 262 | return false; | 
|---|
| 263 | return true; | 
|---|
| 264 | } | 
|---|
| 265 |  | 
|---|
| 266 | } | 
|---|
| 267 |  | 
|---|
| 268 |  | 
|---|
| 269 | void CrossToInnerJoinMatcher::visit(ASTPtr & ast, Data & data) | 
|---|
| 270 | { | 
|---|
| 271 | if (auto * t = ast->as<ASTSelectQuery>()) | 
|---|
| 272 | visit(*t, ast, data); | 
|---|
| 273 | } | 
|---|
| 274 |  | 
|---|
| 275 | void CrossToInnerJoinMatcher::visit(ASTSelectQuery & select, ASTPtr &, Data & data) | 
|---|
| 276 | { | 
|---|
| 277 | size_t num_comma = 0; | 
|---|
| 278 | std::vector<JoinedTable> joined_tables; | 
|---|
| 279 | if (!getTables(select, joined_tables, num_comma)) | 
|---|
| 280 | return; | 
|---|
| 281 |  | 
|---|
| 282 | /// COMMA to CROSS | 
|---|
| 283 |  | 
|---|
| 284 | if (num_comma) | 
|---|
| 285 | { | 
|---|
| 286 | for (auto & table : joined_tables) | 
|---|
| 287 | table.rewriteCommaToCross(); | 
|---|
| 288 | } | 
|---|
| 289 |  | 
|---|
| 290 | /// CROSS to INNER | 
|---|
| 291 |  | 
|---|
| 292 | if (!select.where()) | 
|---|
| 293 | return; | 
|---|
| 294 |  | 
|---|
| 295 | CheckExpressionVisitor::Data visitor_data{joined_tables}; | 
|---|
| 296 | CheckExpressionVisitor(visitor_data).visit(select.where()); | 
|---|
| 297 |  | 
|---|
| 298 | if (visitor_data.complex()) | 
|---|
| 299 | return; | 
|---|
| 300 |  | 
|---|
| 301 | for (size_t i = 1; i < joined_tables.size(); ++i) | 
|---|
| 302 | { | 
|---|
| 303 | if (visitor_data.matchAny(i)) | 
|---|
| 304 | { | 
|---|
| 305 | ASTTableJoin & join = *joined_tables[i].join; | 
|---|
| 306 | join.kind = ASTTableJoin::Kind::Inner; | 
|---|
| 307 | join.strictness = ASTTableJoin::Strictness::All; | 
|---|
| 308 |  | 
|---|
| 309 | join.on_expression = visitor_data.makeOnExpression(i); | 
|---|
| 310 | join.children.push_back(join.on_expression); | 
|---|
| 311 | data.done = true; | 
|---|
| 312 | } | 
|---|
| 313 | } | 
|---|
| 314 | } | 
|---|
| 315 |  | 
|---|
| 316 | } | 
|---|
| 317 |  | 
|---|