| 1 | #include "duckdb/planner/expression_iterator.hpp" |
| 2 | |
| 3 | #include "duckdb/planner/bound_query_node.hpp" |
| 4 | #include "duckdb/planner/expression/list.hpp" |
| 5 | #include "duckdb/planner/query_node/bound_select_node.hpp" |
| 6 | #include "duckdb/planner/query_node/bound_set_operation_node.hpp" |
| 7 | #include "duckdb/planner/query_node/bound_recursive_cte_node.hpp" |
| 8 | #include "duckdb/planner/tableref/list.hpp" |
| 9 | |
| 10 | namespace duckdb { |
| 11 | |
| 12 | void ExpressionIterator::EnumerateChildren(const Expression &expr, |
| 13 | const std::function<void(const Expression &child)> &callback) { |
| 14 | EnumerateChildren(expression&: (Expression &)expr, callback: [&](unique_ptr<Expression> &child) { callback(*child); }); |
| 15 | } |
| 16 | |
| 17 | void ExpressionIterator::EnumerateChildren(Expression &expr, const std::function<void(Expression &child)> &callback) { |
| 18 | EnumerateChildren(expression&: expr, callback: [&](unique_ptr<Expression> &child) { callback(*child); }); |
| 19 | } |
| 20 | |
| 21 | void ExpressionIterator::EnumerateChildren(Expression &expr, |
| 22 | const std::function<void(unique_ptr<Expression> &child)> &callback) { |
| 23 | switch (expr.expression_class) { |
| 24 | case ExpressionClass::BOUND_AGGREGATE: { |
| 25 | auto &aggr_expr = expr.Cast<BoundAggregateExpression>(); |
| 26 | for (auto &child : aggr_expr.children) { |
| 27 | callback(child); |
| 28 | } |
| 29 | if (aggr_expr.filter) { |
| 30 | callback(aggr_expr.filter); |
| 31 | } |
| 32 | if (aggr_expr.order_bys) { |
| 33 | for (auto &order : aggr_expr.order_bys->orders) { |
| 34 | callback(order.expression); |
| 35 | } |
| 36 | } |
| 37 | break; |
| 38 | } |
| 39 | case ExpressionClass::BOUND_BETWEEN: { |
| 40 | auto &between_expr = expr.Cast<BoundBetweenExpression>(); |
| 41 | callback(between_expr.input); |
| 42 | callback(between_expr.lower); |
| 43 | callback(between_expr.upper); |
| 44 | break; |
| 45 | } |
| 46 | case ExpressionClass::BOUND_CASE: { |
| 47 | auto &case_expr = expr.Cast<BoundCaseExpression>(); |
| 48 | for (auto &case_check : case_expr.case_checks) { |
| 49 | callback(case_check.when_expr); |
| 50 | callback(case_check.then_expr); |
| 51 | } |
| 52 | callback(case_expr.else_expr); |
| 53 | break; |
| 54 | } |
| 55 | case ExpressionClass::BOUND_CAST: { |
| 56 | auto &cast_expr = expr.Cast<BoundCastExpression>(); |
| 57 | callback(cast_expr.child); |
| 58 | break; |
| 59 | } |
| 60 | case ExpressionClass::BOUND_COMPARISON: { |
| 61 | auto &comp_expr = expr.Cast<BoundComparisonExpression>(); |
| 62 | callback(comp_expr.left); |
| 63 | callback(comp_expr.right); |
| 64 | break; |
| 65 | } |
| 66 | case ExpressionClass::BOUND_CONJUNCTION: { |
| 67 | auto &conj_expr = expr.Cast<BoundConjunctionExpression>(); |
| 68 | for (auto &child : conj_expr.children) { |
| 69 | callback(child); |
| 70 | } |
| 71 | break; |
| 72 | } |
| 73 | case ExpressionClass::BOUND_FUNCTION: { |
| 74 | auto &func_expr = expr.Cast<BoundFunctionExpression>(); |
| 75 | for (auto &child : func_expr.children) { |
| 76 | callback(child); |
| 77 | } |
| 78 | break; |
| 79 | } |
| 80 | case ExpressionClass::BOUND_OPERATOR: { |
| 81 | auto &op_expr = expr.Cast<BoundOperatorExpression>(); |
| 82 | for (auto &child : op_expr.children) { |
| 83 | callback(child); |
| 84 | } |
| 85 | break; |
| 86 | } |
| 87 | case ExpressionClass::BOUND_SUBQUERY: { |
| 88 | auto &subquery_expr = expr.Cast<BoundSubqueryExpression>(); |
| 89 | if (subquery_expr.child) { |
| 90 | callback(subquery_expr.child); |
| 91 | } |
| 92 | break; |
| 93 | } |
| 94 | case ExpressionClass::BOUND_WINDOW: { |
| 95 | auto &window_expr = expr.Cast<BoundWindowExpression>(); |
| 96 | for (auto &partition : window_expr.partitions) { |
| 97 | callback(partition); |
| 98 | } |
| 99 | for (auto &order : window_expr.orders) { |
| 100 | callback(order.expression); |
| 101 | } |
| 102 | for (auto &child : window_expr.children) { |
| 103 | callback(child); |
| 104 | } |
| 105 | if (window_expr.filter_expr) { |
| 106 | callback(window_expr.filter_expr); |
| 107 | } |
| 108 | if (window_expr.start_expr) { |
| 109 | callback(window_expr.start_expr); |
| 110 | } |
| 111 | if (window_expr.end_expr) { |
| 112 | callback(window_expr.end_expr); |
| 113 | } |
| 114 | if (window_expr.offset_expr) { |
| 115 | callback(window_expr.offset_expr); |
| 116 | } |
| 117 | if (window_expr.default_expr) { |
| 118 | callback(window_expr.default_expr); |
| 119 | } |
| 120 | break; |
| 121 | } |
| 122 | case ExpressionClass::BOUND_UNNEST: { |
| 123 | auto &unnest_expr = expr.Cast<BoundUnnestExpression>(); |
| 124 | callback(unnest_expr.child); |
| 125 | break; |
| 126 | } |
| 127 | case ExpressionClass::BOUND_COLUMN_REF: |
| 128 | case ExpressionClass::BOUND_LAMBDA_REF: |
| 129 | case ExpressionClass::BOUND_CONSTANT: |
| 130 | case ExpressionClass::BOUND_DEFAULT: |
| 131 | case ExpressionClass::BOUND_PARAMETER: |
| 132 | case ExpressionClass::BOUND_REF: |
| 133 | // these node types have no children |
| 134 | break; |
| 135 | default: |
| 136 | throw InternalException("ExpressionIterator used on unbound expression" ); |
| 137 | } |
| 138 | } |
| 139 | |
| 140 | void ExpressionIterator::EnumerateExpression(unique_ptr<Expression> &expr, |
| 141 | const std::function<void(Expression &child)> &callback) { |
| 142 | if (!expr) { |
| 143 | return; |
| 144 | } |
| 145 | callback(*expr); |
| 146 | ExpressionIterator::EnumerateChildren(expr&: *expr, |
| 147 | callback: [&](unique_ptr<Expression> &child) { EnumerateExpression(expr&: child, callback); }); |
| 148 | } |
| 149 | |
| 150 | void ExpressionIterator::EnumerateTableRefChildren(BoundTableRef &ref, |
| 151 | const std::function<void(Expression &child)> &callback) { |
| 152 | switch (ref.type) { |
| 153 | case TableReferenceType::EXPRESSION_LIST: { |
| 154 | auto &bound_expr_list = ref.Cast<BoundExpressionListRef>(); |
| 155 | for (auto &expr_list : bound_expr_list.values) { |
| 156 | for (auto &expr : expr_list) { |
| 157 | EnumerateExpression(expr, callback); |
| 158 | } |
| 159 | } |
| 160 | break; |
| 161 | } |
| 162 | case TableReferenceType::JOIN: { |
| 163 | auto &bound_join = ref.Cast<BoundJoinRef>(); |
| 164 | if (bound_join.condition) { |
| 165 | EnumerateExpression(expr&: bound_join.condition, callback); |
| 166 | } |
| 167 | EnumerateTableRefChildren(ref&: *bound_join.left, callback); |
| 168 | EnumerateTableRefChildren(ref&: *bound_join.right, callback); |
| 169 | break; |
| 170 | } |
| 171 | case TableReferenceType::SUBQUERY: { |
| 172 | auto &bound_subquery = ref.Cast<BoundSubqueryRef>(); |
| 173 | EnumerateQueryNodeChildren(node&: *bound_subquery.subquery, callback); |
| 174 | break; |
| 175 | } |
| 176 | case TableReferenceType::TABLE_FUNCTION: |
| 177 | case TableReferenceType::EMPTY: |
| 178 | case TableReferenceType::BASE_TABLE: |
| 179 | case TableReferenceType::CTE: |
| 180 | break; |
| 181 | default: |
| 182 | throw NotImplementedException("Unimplemented table reference type in ExpressionIterator" ); |
| 183 | } |
| 184 | } |
| 185 | |
| 186 | void ExpressionIterator::EnumerateQueryNodeChildren(BoundQueryNode &node, |
| 187 | const std::function<void(Expression &child)> &callback) { |
| 188 | switch (node.type) { |
| 189 | case QueryNodeType::SET_OPERATION_NODE: { |
| 190 | auto &bound_setop = node.Cast<BoundSetOperationNode>(); |
| 191 | EnumerateQueryNodeChildren(node&: *bound_setop.left, callback); |
| 192 | EnumerateQueryNodeChildren(node&: *bound_setop.right, callback); |
| 193 | break; |
| 194 | } |
| 195 | case QueryNodeType::RECURSIVE_CTE_NODE: { |
| 196 | auto &cte_node = node.Cast<BoundRecursiveCTENode>(); |
| 197 | EnumerateQueryNodeChildren(node&: *cte_node.left, callback); |
| 198 | EnumerateQueryNodeChildren(node&: *cte_node.right, callback); |
| 199 | break; |
| 200 | } |
| 201 | case QueryNodeType::SELECT_NODE: { |
| 202 | auto &bound_select = node.Cast<BoundSelectNode>(); |
| 203 | for (auto &expr : bound_select.select_list) { |
| 204 | EnumerateExpression(expr, callback); |
| 205 | } |
| 206 | EnumerateExpression(expr&: bound_select.where_clause, callback); |
| 207 | for (auto &expr : bound_select.groups.group_expressions) { |
| 208 | EnumerateExpression(expr, callback); |
| 209 | } |
| 210 | EnumerateExpression(expr&: bound_select.having, callback); |
| 211 | for (auto &expr : bound_select.aggregates) { |
| 212 | EnumerateExpression(expr, callback); |
| 213 | } |
| 214 | for (auto &entry : bound_select.unnests) { |
| 215 | for (auto &expr : entry.second.expressions) { |
| 216 | EnumerateExpression(expr, callback); |
| 217 | } |
| 218 | } |
| 219 | for (auto &expr : bound_select.windows) { |
| 220 | EnumerateExpression(expr, callback); |
| 221 | } |
| 222 | if (bound_select.from_table) { |
| 223 | EnumerateTableRefChildren(ref&: *bound_select.from_table, callback); |
| 224 | } |
| 225 | break; |
| 226 | } |
| 227 | default: |
| 228 | throw NotImplementedException("Unimplemented query node in ExpressionIterator" ); |
| 229 | } |
| 230 | for (idx_t i = 0; i < node.modifiers.size(); i++) { |
| 231 | switch (node.modifiers[i]->type) { |
| 232 | case ResultModifierType::DISTINCT_MODIFIER: |
| 233 | for (auto &expr : node.modifiers[i]->Cast<BoundDistinctModifier>().target_distincts) { |
| 234 | EnumerateExpression(expr, callback); |
| 235 | } |
| 236 | break; |
| 237 | case ResultModifierType::ORDER_MODIFIER: |
| 238 | for (auto &order : node.modifiers[i]->Cast<BoundOrderModifier>().orders) { |
| 239 | EnumerateExpression(expr&: order.expression, callback); |
| 240 | } |
| 241 | break; |
| 242 | default: |
| 243 | break; |
| 244 | } |
| 245 | } |
| 246 | } |
| 247 | |
| 248 | } // namespace duckdb |
| 249 | |