| 1 | #include <Common/typeid_cast.h> |
| 2 | #include <Functions/FunctionsComparison.h> |
| 3 | #include <Functions/FunctionsLogical.h> |
| 4 | #include <Interpreters/CrossToInnerJoinVisitor.h> |
| 5 | #include <Interpreters/DatabaseAndTableWithAlias.h> |
| 6 | #include <Interpreters/IdentifierSemantic.h> |
| 7 | #include <Interpreters/misc.h> |
| 8 | #include <Parsers/ASTSelectQuery.h> |
| 9 | #include <Parsers/ASTTablesInSelectQuery.h> |
| 10 | #include <Parsers/ASTIdentifier.h> |
| 11 | #include <Parsers/ASTFunction.h> |
| 12 | #include <Parsers/ASTExpressionList.h> |
| 13 | #include <Parsers/ParserTablesInSelectQuery.h> |
| 14 | #include <Parsers/ExpressionListParsers.h> |
| 15 | #include <Parsers/parseQuery.h> |
| 16 | #include <IO/WriteHelpers.h> |
| 17 | |
| 18 | namespace DB |
| 19 | { |
| 20 | |
/// Error codes thrown from this file. They are only declared here (extern); the definitions live elsewhere.
namespace ErrorCodes
{
    extern const int NOT_IMPLEMENTED;
    extern const int LOGICAL_ERROR;
}
| 26 | |
| 27 | namespace |
| 28 | { |
| 29 | |
| 30 | struct JoinedTable |
| 31 | { |
| 32 | DatabaseAndTableWithAlias table; |
| 33 | ASTTablesInSelectQueryElement * element = nullptr; |
| 34 | ASTTableJoin * join = nullptr; |
| 35 | ASTPtr array_join = nullptr; |
| 36 | bool has_using = false; |
| 37 | |
| 38 | JoinedTable(ASTPtr table_element) |
| 39 | { |
| 40 | element = table_element->as<ASTTablesInSelectQueryElement>(); |
| 41 | if (!element) |
| 42 | throw Exception("Logical error: TablesInSelectQueryElement expected" , ErrorCodes::LOGICAL_ERROR); |
| 43 | |
| 44 | if (element->table_join) |
| 45 | { |
| 46 | join = element->table_join->as<ASTTableJoin>(); |
| 47 | if (join->kind == ASTTableJoin::Kind::Cross || |
| 48 | join->kind == ASTTableJoin::Kind::Comma) |
| 49 | { |
| 50 | if (!join->children.empty()) |
| 51 | throw Exception("Logical error: CROSS JOIN has expressions" , ErrorCodes::LOGICAL_ERROR); |
| 52 | } |
| 53 | |
| 54 | if (join->using_expression_list) |
| 55 | has_using = true; |
| 56 | } |
| 57 | |
| 58 | if (element->table_expression) |
| 59 | { |
| 60 | const auto & expr = element->table_expression->as<ASTTableExpression &>(); |
| 61 | table = DatabaseAndTableWithAlias(expr); |
| 62 | } |
| 63 | |
| 64 | array_join = element->array_join; |
| 65 | } |
| 66 | |
| 67 | void rewriteCommaToCross() |
| 68 | { |
| 69 | if (join) |
| 70 | join->kind = ASTTableJoin::Kind::Cross; |
| 71 | } |
| 72 | |
| 73 | bool canAttachOnExpression() const { return join && !join->on_expression; } |
| 74 | }; |
| 75 | |
| 76 | bool isComparison(const String & name) |
| 77 | { |
| 78 | return name == NameEquals::name || |
| 79 | name == NameNotEquals::name || |
| 80 | name == NameLess::name || |
| 81 | name == NameGreater::name || |
| 82 | name == NameLessOrEquals::name || |
| 83 | name == NameGreaterOrEquals::name; |
| 84 | } |
| 85 | |
| 86 | /// It checks if where expression could be moved to JOIN ON expression partially or entirely. |
| 87 | class CheckExpressionVisitorData |
| 88 | { |
| 89 | public: |
| 90 | using TypeToVisit = const ASTFunction; |
| 91 | |
| 92 | CheckExpressionVisitorData(const std::vector<JoinedTable> & tables_) |
| 93 | : joined_tables(tables_) |
| 94 | , ands_only(true) |
| 95 | { |
| 96 | for (auto & joined : joined_tables) |
| 97 | tables.push_back(joined.table); |
| 98 | } |
| 99 | |
| 100 | void visit(const ASTFunction & node, const ASTPtr & ast) |
| 101 | { |
| 102 | if (!ands_only) |
| 103 | return; |
| 104 | |
| 105 | if (node.name == NameAnd::name) |
| 106 | { |
| 107 | if (!node.arguments || node.arguments->children.empty()) |
| 108 | throw Exception("Logical error: function requires argiment" , ErrorCodes::LOGICAL_ERROR); |
| 109 | |
| 110 | for (auto & child : node.arguments->children) |
| 111 | { |
| 112 | if (const auto * func = child->as<ASTFunction>()) |
| 113 | visit(*func, child); |
| 114 | else |
| 115 | ands_only = false; |
| 116 | } |
| 117 | } |
| 118 | else if (node.name == NameEquals::name) |
| 119 | { |
| 120 | if (size_t min_table = canMoveEqualsToJoinOn(node)) |
| 121 | asts_to_join_on[min_table].push_back(ast); |
| 122 | } |
| 123 | else if (isComparison(node.name)) |
| 124 | { |
| 125 | /// leave other comparisons as is |
| 126 | } |
| 127 | else if (functionIsInOperator(node.name)) /// IN, NOT IN |
| 128 | { |
| 129 | if (auto ident = node.arguments->children.at(0)->as<ASTIdentifier>()) |
| 130 | if (size_t min_table = checkIdentifier(*ident)) |
| 131 | asts_to_join_on[min_table].push_back(ast); |
| 132 | } |
| 133 | else |
| 134 | { |
| 135 | ands_only = false; |
| 136 | asts_to_join_on.clear(); |
| 137 | } |
| 138 | } |
| 139 | |
| 140 | bool complex() const { return !ands_only; } |
| 141 | bool matchAny(size_t t) const { return asts_to_join_on.count(t); } |
| 142 | |
| 143 | ASTPtr makeOnExpression(size_t table_pos) |
| 144 | { |
| 145 | if (!asts_to_join_on.count(table_pos)) |
| 146 | return {}; |
| 147 | |
| 148 | std::vector<ASTPtr> & expressions = asts_to_join_on[table_pos]; |
| 149 | |
| 150 | if (expressions.size() == 1) |
| 151 | return expressions[0]->clone(); |
| 152 | |
| 153 | std::vector<ASTPtr> arguments; |
| 154 | arguments.reserve(expressions.size()); |
| 155 | for (auto & ast : expressions) |
| 156 | arguments.emplace_back(ast->clone()); |
| 157 | |
| 158 | return makeASTFunction(NameAnd::name, std::move(arguments)); |
| 159 | } |
| 160 | |
| 161 | private: |
| 162 | const std::vector<JoinedTable> & joined_tables; |
| 163 | std::vector<DatabaseAndTableWithAlias> tables; |
| 164 | std::map<size_t, std::vector<ASTPtr>> asts_to_join_on; |
| 165 | bool ands_only; |
| 166 | |
| 167 | size_t canMoveEqualsToJoinOn(const ASTFunction & node) |
| 168 | { |
| 169 | if (!node.arguments) |
| 170 | throw Exception("Logical error: function requires arguments" , ErrorCodes::LOGICAL_ERROR); |
| 171 | if (node.arguments->children.size() != 2) |
| 172 | return false; |
| 173 | |
| 174 | const auto * left = node.arguments->children[0]->as<ASTIdentifier>(); |
| 175 | const auto * right = node.arguments->children[1]->as<ASTIdentifier>(); |
| 176 | if (!left || !right) |
| 177 | return false; |
| 178 | |
| 179 | return checkIdentifiers(*left, *right); |
| 180 | } |
| 181 | |
| 182 | /// Check if the identifiers are from different joined tables. If it's a self joint, tables should have aliases. |
| 183 | /// select * from t1 a cross join t2 b where a.x = b.x |
| 184 | /// @return table position to attach expression to or 0. |
| 185 | size_t checkIdentifiers(const ASTIdentifier & left, const ASTIdentifier & right) |
| 186 | { |
| 187 | size_t left_table_pos = 0; |
| 188 | bool left_match = IdentifierSemantic::chooseTable(left, tables, left_table_pos); |
| 189 | |
| 190 | size_t right_table_pos = 0; |
| 191 | bool right_match = IdentifierSemantic::chooseTable(right, tables, right_table_pos); |
| 192 | |
| 193 | if (left_match && right_match && (left_table_pos != right_table_pos)) |
| 194 | { |
| 195 | size_t table_pos = std::max(left_table_pos, right_table_pos); |
| 196 | if (joined_tables[table_pos].canAttachOnExpression()) |
| 197 | return table_pos; |
| 198 | } |
| 199 | return 0; |
| 200 | } |
| 201 | |
| 202 | size_t checkIdentifier(const ASTIdentifier & identifier) |
| 203 | { |
| 204 | size_t best_table_pos = 0; |
| 205 | bool match = IdentifierSemantic::chooseTable(identifier, tables, best_table_pos); |
| 206 | |
| 207 | if (match && joined_tables[best_table_pos].canAttachOnExpression()) |
| 208 | return best_table_pos; |
| 209 | return 0; |
| 210 | } |
| 211 | }; |
| 212 | |
| 213 | using CheckExpressionMatcher = ConstOneTypeMatcher<CheckExpressionVisitorData, false>; |
| 214 | using CheckExpressionVisitor = ConstInDepthNodeVisitor<CheckExpressionMatcher, true>; |
| 215 | |
| 216 | |
| 217 | bool getTables(ASTSelectQuery & select, std::vector<JoinedTable> & joined_tables, size_t & num_comma) |
| 218 | { |
| 219 | if (!select.tables()) |
| 220 | return false; |
| 221 | |
| 222 | const auto * tables = select.tables()->as<ASTTablesInSelectQuery>(); |
| 223 | if (!tables) |
| 224 | return false; |
| 225 | |
| 226 | size_t num_tables = tables->children.size(); |
| 227 | if (num_tables < 2) |
| 228 | return false; |
| 229 | |
| 230 | joined_tables.reserve(num_tables); |
| 231 | size_t num_array_join = 0; |
| 232 | size_t num_using = 0; |
| 233 | |
| 234 | for (auto & child : tables->children) |
| 235 | { |
| 236 | joined_tables.emplace_back(JoinedTable(child)); |
| 237 | JoinedTable & t = joined_tables.back(); |
| 238 | if (t.array_join) |
| 239 | { |
| 240 | ++num_array_join; |
| 241 | continue; |
| 242 | } |
| 243 | |
| 244 | if (t.has_using) |
| 245 | { |
| 246 | ++num_using; |
| 247 | continue; |
| 248 | } |
| 249 | |
| 250 | if (auto * join = t.join) |
| 251 | if (join->kind == ASTTableJoin::Kind::Comma) |
| 252 | ++num_comma; |
| 253 | } |
| 254 | |
| 255 | if (num_using && (num_tables - num_array_join) > 2) |
| 256 | throw Exception("Multiple CROSS/COMMA JOIN do not support USING" , ErrorCodes::NOT_IMPLEMENTED); |
| 257 | |
| 258 | if (num_comma && (num_comma != (joined_tables.size() - 1))) |
| 259 | throw Exception("Mix of COMMA and other JOINS is not supported" , ErrorCodes::NOT_IMPLEMENTED); |
| 260 | |
| 261 | if (num_array_join || num_using) |
| 262 | return false; |
| 263 | return true; |
| 264 | } |
| 265 | |
| 266 | } |
| 267 | |
| 268 | |
| 269 | void CrossToInnerJoinMatcher::visit(ASTPtr & ast, Data & data) |
| 270 | { |
| 271 | if (auto * t = ast->as<ASTSelectQuery>()) |
| 272 | visit(*t, ast, data); |
| 273 | } |
| 274 | |
| 275 | void CrossToInnerJoinMatcher::visit(ASTSelectQuery & select, ASTPtr &, Data & data) |
| 276 | { |
| 277 | size_t num_comma = 0; |
| 278 | std::vector<JoinedTable> joined_tables; |
| 279 | if (!getTables(select, joined_tables, num_comma)) |
| 280 | return; |
| 281 | |
| 282 | /// COMMA to CROSS |
| 283 | |
| 284 | if (num_comma) |
| 285 | { |
| 286 | for (auto & table : joined_tables) |
| 287 | table.rewriteCommaToCross(); |
| 288 | } |
| 289 | |
| 290 | /// CROSS to INNER |
| 291 | |
| 292 | if (!select.where()) |
| 293 | return; |
| 294 | |
| 295 | CheckExpressionVisitor::Data visitor_data{joined_tables}; |
| 296 | CheckExpressionVisitor(visitor_data).visit(select.where()); |
| 297 | |
| 298 | if (visitor_data.complex()) |
| 299 | return; |
| 300 | |
| 301 | for (size_t i = 1; i < joined_tables.size(); ++i) |
| 302 | { |
| 303 | if (visitor_data.matchAny(i)) |
| 304 | { |
| 305 | ASTTableJoin & join = *joined_tables[i].join; |
| 306 | join.kind = ASTTableJoin::Kind::Inner; |
| 307 | join.strictness = ASTTableJoin::Strictness::All; |
| 308 | |
| 309 | join.on_expression = visitor_data.makeOnExpression(i); |
| 310 | join.children.push_back(join.on_expression); |
| 311 | data.done = true; |
| 312 | } |
| 313 | } |
| 314 | } |
| 315 | |
| 316 | } |
| 317 | |