1 | #include "duckdb/parser/expression/comparison_expression.hpp" |
2 | #include "duckdb/planner/expression/bound_cast_expression.hpp" |
3 | #include "duckdb/planner/expression/bound_constant_expression.hpp" |
4 | #include "duckdb/planner/expression/bound_comparison_expression.hpp" |
5 | #include "duckdb/planner/expression/bound_function_expression.hpp" |
6 | #include "duckdb/planner/expression/bound_parameter_expression.hpp" |
7 | #include "duckdb/planner/expression_binder.hpp" |
8 | #include "duckdb/catalog/catalog_entry/collate_catalog_entry.hpp" |
9 | #include "duckdb/common/string_util.hpp" |
10 | |
11 | #include "duckdb/function/scalar/string_functions.hpp" |
12 | |
13 | #include "duckdb/common/types/decimal.hpp" |
14 | |
15 | #include "duckdb/main/config.hpp" |
16 | #include "duckdb/catalog/catalog.hpp" |
17 | #include "duckdb/function/function_binder.hpp" |
18 | |
19 | namespace duckdb { |
20 | |
21 | unique_ptr<Expression> ExpressionBinder::PushCollation(ClientContext &context, unique_ptr<Expression> source, |
22 | const string &collation_p, bool equality_only) { |
23 | // replace default collation with system collation |
24 | string collation; |
25 | if (collation_p.empty()) { |
26 | collation = DBConfig::GetConfig(context).options.collation; |
27 | } else { |
28 | collation = collation_p; |
29 | } |
30 | collation = StringUtil::Lower(str: collation); |
31 | // bind the collation |
32 | if (collation.empty() || collation == "binary" || collation == "c" || collation == "posix" ) { |
33 | // binary collation: just skip |
34 | return source; |
35 | } |
36 | auto &catalog = Catalog::GetSystemCatalog(context); |
37 | auto splits = StringUtil::Split(input: StringUtil::Lower(str: collation), split: "." ); |
38 | vector<reference<CollateCatalogEntry>> entries; |
39 | for (auto &collation_argument : splits) { |
40 | auto &collation_entry = catalog.GetEntry<CollateCatalogEntry>(context, DEFAULT_SCHEMA, name: collation_argument); |
41 | if (collation_entry.combinable) { |
42 | entries.insert(position: entries.begin(), x: collation_entry); |
43 | } else { |
44 | if (!entries.empty() && !entries.back().get().combinable) { |
45 | throw BinderException("Cannot combine collation types \"%s\" and \"%s\"" , entries.back().get().name, |
46 | collation_entry.name); |
47 | } |
48 | entries.push_back(x: collation_entry); |
49 | } |
50 | } |
51 | for (auto &entry : entries) { |
52 | auto &collation_entry = entry.get(); |
53 | if (equality_only && collation_entry.not_required_for_equality) { |
54 | continue; |
55 | } |
56 | vector<unique_ptr<Expression>> children; |
57 | children.push_back(x: std::move(source)); |
58 | |
59 | FunctionBinder function_binder(context); |
60 | auto function = function_binder.BindScalarFunction(bound_function: collation_entry.function, children: std::move(children)); |
61 | source = std::move(function); |
62 | } |
63 | return source; |
64 | } |
65 | |
66 | void ExpressionBinder::TestCollation(ClientContext &context, const string &collation) { |
67 | PushCollation(context, source: make_uniq<BoundConstantExpression>(args: Value("" )), collation_p: collation); |
68 | } |
69 | |
70 | LogicalType BoundComparisonExpression::BindComparison(LogicalType left_type, LogicalType right_type) { |
71 | auto result_type = LogicalType::MaxLogicalType(left: left_type, right: right_type); |
72 | switch (result_type.id()) { |
73 | case LogicalTypeId::DECIMAL: { |
74 | // result is a decimal: we need the maximum width and the maximum scale over width |
75 | vector<LogicalType> argument_types = {left_type, right_type}; |
76 | uint8_t max_width = 0, max_scale = 0, max_width_over_scale = 0; |
77 | for (idx_t i = 0; i < argument_types.size(); i++) { |
78 | uint8_t width, scale; |
79 | auto can_convert = argument_types[i].GetDecimalProperties(width, scale); |
80 | if (!can_convert) { |
81 | return result_type; |
82 | } |
83 | max_width = MaxValue<uint8_t>(a: width, b: max_width); |
84 | max_scale = MaxValue<uint8_t>(a: scale, b: max_scale); |
85 | max_width_over_scale = MaxValue<uint8_t>(a: width - scale, b: max_width_over_scale); |
86 | } |
87 | max_width = MaxValue<uint8_t>(a: max_scale + max_width_over_scale, b: max_width); |
88 | if (max_width > Decimal::MAX_WIDTH_DECIMAL) { |
89 | // target width does not fit in decimal: truncate the scale (if possible) to try and make it fit |
90 | max_width = Decimal::MAX_WIDTH_DECIMAL; |
91 | } |
92 | return LogicalType::DECIMAL(width: max_width, scale: max_scale); |
93 | } |
94 | case LogicalTypeId::VARCHAR: |
95 | // for comparison with strings, we prefer to bind to the numeric types |
96 | if (left_type.IsNumeric() || left_type.id() == LogicalTypeId::BOOLEAN) { |
97 | return left_type; |
98 | } else if (right_type.IsNumeric() || right_type.id() == LogicalTypeId::BOOLEAN) { |
99 | return right_type; |
100 | } else { |
101 | // else: check if collations are compatible |
102 | auto left_collation = StringType::GetCollation(type: left_type); |
103 | auto right_collation = StringType::GetCollation(type: right_type); |
104 | if (!left_collation.empty() && !right_collation.empty() && left_collation != right_collation) { |
105 | throw BinderException("Cannot combine types with different collation!" ); |
106 | } |
107 | } |
108 | return result_type; |
109 | default: |
110 | return result_type; |
111 | } |
112 | } |
113 | |
114 | BindResult ExpressionBinder::BindExpression(ComparisonExpression &expr, idx_t depth) { |
115 | // first try to bind the children of the case expression |
116 | string error; |
117 | BindChild(expr&: expr.left, depth, error); |
118 | BindChild(expr&: expr.right, depth, error); |
119 | if (!error.empty()) { |
120 | return BindResult(error); |
121 | } |
122 | // the children have been successfully resolved |
123 | auto &left = BoundExpression::GetExpression(expr&: *expr.left); |
124 | auto &right = BoundExpression::GetExpression(expr&: *expr.right); |
125 | auto left_sql_type = left->return_type; |
126 | auto right_sql_type = right->return_type; |
127 | // cast the input types to the same type |
128 | // now obtain the result type of the input types |
129 | auto input_type = BoundComparisonExpression::BindComparison(left_type: left_sql_type, right_type: right_sql_type); |
130 | // add casts (if necessary) |
131 | left = BoundCastExpression::AddCastToType(context, expr: std::move(left), target_type: input_type, |
132 | try_cast: input_type.id() == LogicalTypeId::ENUM); |
133 | right = BoundCastExpression::AddCastToType(context, expr: std::move(right), target_type: input_type, |
134 | try_cast: input_type.id() == LogicalTypeId::ENUM); |
135 | |
136 | if (input_type.id() == LogicalTypeId::VARCHAR) { |
137 | // handle collation |
138 | auto collation = StringType::GetCollation(type: input_type); |
139 | left = PushCollation(context, source: std::move(left), collation_p: collation, equality_only: expr.type == ExpressionType::COMPARE_EQUAL); |
140 | right = PushCollation(context, source: std::move(right), collation_p: collation, equality_only: expr.type == ExpressionType::COMPARE_EQUAL); |
141 | } |
142 | // now create the bound comparison expression |
143 | return BindResult(make_uniq<BoundComparisonExpression>(args&: expr.type, args: std::move(left), args: std::move(right))); |
144 | } |
145 | |
146 | } // namespace duckdb |
147 | |