1#include "duckdb/parser/expression/operator_expression.hpp"
2#include "duckdb/planner/expression/bound_cast_expression.hpp"
3#include "duckdb/planner/expression/bound_operator_expression.hpp"
4#include "duckdb/planner/expression/bound_case_expression.hpp"
5#include "duckdb/planner/expression/bound_parameter_expression.hpp"
6#include "duckdb/parser/expression/function_expression.hpp"
7#include "duckdb/planner/expression_binder.hpp"
8
9namespace duckdb {
10
11static LogicalType ResolveNotType(OperatorExpression &op, vector<unique_ptr<Expression>> &children) {
12 // NOT expression, cast child to BOOLEAN
13 D_ASSERT(children.size() == 1);
14 children[0] = BoundCastExpression::AddDefaultCastToType(expr: std::move(children[0]), target_type: LogicalType::BOOLEAN);
15 return LogicalType(LogicalTypeId::BOOLEAN);
16}
17
18static LogicalType ResolveInType(OperatorExpression &op, vector<unique_ptr<Expression>> &children) {
19 if (children.empty()) {
20 throw InternalException("IN requires at least a single child node");
21 }
22 // get the maximum type from the children
23 LogicalType max_type = children[0]->return_type;
24 bool any_varchar = children[0]->return_type == LogicalType::VARCHAR;
25 bool any_enum = children[0]->return_type.id() == LogicalTypeId::ENUM;
26 for (idx_t i = 1; i < children.size(); i++) {
27 max_type = LogicalType::MaxLogicalType(left: max_type, right: children[i]->return_type);
28 if (children[i]->return_type == LogicalType::VARCHAR) {
29 any_varchar = true;
30 }
31 if (children[i]->return_type.id() == LogicalTypeId::ENUM) {
32 any_enum = true;
33 }
34 }
35 if (any_varchar && any_enum) {
36 // For the coalesce function, we must be sure we always upcast the parameters to VARCHAR, if there are at least
37 // one enum and one varchar
38 max_type = LogicalType::VARCHAR;
39 }
40
41 // cast all children to the same type
42 for (idx_t i = 0; i < children.size(); i++) {
43 children[i] = BoundCastExpression::AddDefaultCastToType(expr: std::move(children[i]), target_type: max_type);
44 }
45 // (NOT) IN always returns a boolean
46 return LogicalType::BOOLEAN;
47}
48
49static LogicalType ResolveOperatorType(OperatorExpression &op, vector<unique_ptr<Expression>> &children) {
50 switch (op.type) {
51 case ExpressionType::OPERATOR_IS_NULL:
52 case ExpressionType::OPERATOR_IS_NOT_NULL:
53 // IS (NOT) NULL always returns a boolean, and does not cast its children
54 if (!children[0]->return_type.IsValid()) {
55 throw ParameterNotResolvedException();
56 }
57 return LogicalType::BOOLEAN;
58 case ExpressionType::COMPARE_IN:
59 case ExpressionType::COMPARE_NOT_IN:
60 return ResolveInType(op, children);
61 case ExpressionType::OPERATOR_COALESCE: {
62 ResolveInType(op, children);
63 return children[0]->return_type;
64 }
65 case ExpressionType::OPERATOR_NOT:
66 return ResolveNotType(op, children);
67 default:
68 throw InternalException("Unrecognized expression type for ResolveOperatorType");
69 }
70}
71
72BindResult ExpressionBinder::BindGroupingFunction(OperatorExpression &op, idx_t depth) {
73 return BindResult("GROUPING function is not supported here");
74}
75
76BindResult ExpressionBinder::BindExpression(OperatorExpression &op, idx_t depth) {
77 if (op.type == ExpressionType::GROUPING_FUNCTION) {
78 return BindGroupingFunction(op, depth);
79 }
80 // bind the children of the operator expression
81 string error;
82 for (idx_t i = 0; i < op.children.size(); i++) {
83 BindChild(expr&: op.children[i], depth, error);
84 }
85 if (!error.empty()) {
86 return BindResult(error);
87 }
88 // all children bound successfully
89 string function_name;
90 switch (op.type) {
91 case ExpressionType::ARRAY_EXTRACT: {
92 D_ASSERT(op.children[0]->expression_class == ExpressionClass::BOUND_EXPRESSION);
93 auto &b_exp = BoundExpression::GetExpression(expr&: *op.children[0]);
94 if (b_exp->return_type.id() == LogicalTypeId::MAP) {
95 function_name = "map_extract";
96 } else {
97 function_name = "array_extract";
98 }
99 break;
100 }
101 case ExpressionType::ARRAY_SLICE:
102 function_name = "array_slice";
103 break;
104 case ExpressionType::STRUCT_EXTRACT: {
105 D_ASSERT(op.children.size() == 2);
106 D_ASSERT(op.children[0]->expression_class == ExpressionClass::BOUND_EXPRESSION);
107 D_ASSERT(op.children[1]->expression_class == ExpressionClass::BOUND_EXPRESSION);
108 auto &extract_exp = BoundExpression::GetExpression(expr&: *op.children[0]);
109 auto &name_exp = BoundExpression::GetExpression(expr&: *op.children[1]);
110 auto extract_expr_type = extract_exp->return_type.id();
111 if (extract_expr_type != LogicalTypeId::STRUCT && extract_expr_type != LogicalTypeId::UNION &&
112 extract_expr_type != LogicalTypeId::SQLNULL) {
113 return BindResult(StringUtil::Format(
114 fmt_str: "Cannot extract field %s from expression \"%s\" because it is not a struct or a union",
115 params: name_exp->ToString(), params: extract_exp->ToString()));
116 }
117 function_name = extract_expr_type == LogicalTypeId::UNION ? "union_extract" : "struct_extract";
118 break;
119 }
120 case ExpressionType::ARRAY_CONSTRUCTOR:
121 function_name = "list_value";
122 break;
123 case ExpressionType::ARROW:
124 function_name = "json_extract";
125 break;
126 default:
127 break;
128 }
129 if (!function_name.empty()) {
130 auto function = make_uniq_base<ParsedExpression, FunctionExpression>(args&: function_name, args: std::move(op.children));
131 return BindExpression(expr_ptr&: function, depth, root_expression: false);
132 }
133
134 vector<unique_ptr<Expression>> children;
135 for (idx_t i = 0; i < op.children.size(); i++) {
136 D_ASSERT(op.children[i]->expression_class == ExpressionClass::BOUND_EXPRESSION);
137 children.push_back(x: std::move(BoundExpression::GetExpression(expr&: *op.children[i])));
138 }
139 // now resolve the types
140 LogicalType result_type = ResolveOperatorType(op, children);
141 if (op.type == ExpressionType::OPERATOR_COALESCE) {
142 if (children.empty()) {
143 throw BinderException("COALESCE needs at least one child");
144 }
145 if (children.size() == 1) {
146 return BindResult(std::move(children[0]));
147 }
148 }
149
150 auto result = make_uniq<BoundOperatorExpression>(args&: op.type, args&: result_type);
151 for (auto &child : children) {
152 result->children.push_back(x: std::move(child));
153 }
154 return BindResult(std::move(result));
155}
156
157} // namespace duckdb
158