1#include "duckdb/parser/expression/comparison_expression.hpp"
2#include "duckdb/planner/expression/bound_cast_expression.hpp"
3#include "duckdb/planner/expression/bound_constant_expression.hpp"
4#include "duckdb/planner/expression/bound_comparison_expression.hpp"
5#include "duckdb/planner/expression/bound_function_expression.hpp"
6#include "duckdb/planner/expression/bound_parameter_expression.hpp"
7#include "duckdb/planner/expression_binder.hpp"
8#include "duckdb/catalog/catalog_entry/collate_catalog_entry.hpp"
9#include "duckdb/common/string_util.hpp"
10
11#include "duckdb/function/scalar/string_functions.hpp"
12
13#include "duckdb/common/types/decimal.hpp"
14
15#include "duckdb/main/config.hpp"
16#include "duckdb/catalog/catalog.hpp"
17#include "duckdb/function/function_binder.hpp"
18
19namespace duckdb {
20
21unique_ptr<Expression> ExpressionBinder::PushCollation(ClientContext &context, unique_ptr<Expression> source,
22 const string &collation_p, bool equality_only) {
23 // replace default collation with system collation
24 string collation;
25 if (collation_p.empty()) {
26 collation = DBConfig::GetConfig(context).options.collation;
27 } else {
28 collation = collation_p;
29 }
30 collation = StringUtil::Lower(str: collation);
31 // bind the collation
32 if (collation.empty() || collation == "binary" || collation == "c" || collation == "posix") {
33 // binary collation: just skip
34 return source;
35 }
36 auto &catalog = Catalog::GetSystemCatalog(context);
37 auto splits = StringUtil::Split(input: StringUtil::Lower(str: collation), split: ".");
38 vector<reference<CollateCatalogEntry>> entries;
39 for (auto &collation_argument : splits) {
40 auto &collation_entry = catalog.GetEntry<CollateCatalogEntry>(context, DEFAULT_SCHEMA, name: collation_argument);
41 if (collation_entry.combinable) {
42 entries.insert(position: entries.begin(), x: collation_entry);
43 } else {
44 if (!entries.empty() && !entries.back().get().combinable) {
45 throw BinderException("Cannot combine collation types \"%s\" and \"%s\"", entries.back().get().name,
46 collation_entry.name);
47 }
48 entries.push_back(x: collation_entry);
49 }
50 }
51 for (auto &entry : entries) {
52 auto &collation_entry = entry.get();
53 if (equality_only && collation_entry.not_required_for_equality) {
54 continue;
55 }
56 vector<unique_ptr<Expression>> children;
57 children.push_back(x: std::move(source));
58
59 FunctionBinder function_binder(context);
60 auto function = function_binder.BindScalarFunction(bound_function: collation_entry.function, children: std::move(children));
61 source = std::move(function);
62 }
63 return source;
64}
65
66void ExpressionBinder::TestCollation(ClientContext &context, const string &collation) {
67 PushCollation(context, source: make_uniq<BoundConstantExpression>(args: Value("")), collation_p: collation);
68}
69
70LogicalType BoundComparisonExpression::BindComparison(LogicalType left_type, LogicalType right_type) {
71 auto result_type = LogicalType::MaxLogicalType(left: left_type, right: right_type);
72 switch (result_type.id()) {
73 case LogicalTypeId::DECIMAL: {
74 // result is a decimal: we need the maximum width and the maximum scale over width
75 vector<LogicalType> argument_types = {left_type, right_type};
76 uint8_t max_width = 0, max_scale = 0, max_width_over_scale = 0;
77 for (idx_t i = 0; i < argument_types.size(); i++) {
78 uint8_t width, scale;
79 auto can_convert = argument_types[i].GetDecimalProperties(width, scale);
80 if (!can_convert) {
81 return result_type;
82 }
83 max_width = MaxValue<uint8_t>(a: width, b: max_width);
84 max_scale = MaxValue<uint8_t>(a: scale, b: max_scale);
85 max_width_over_scale = MaxValue<uint8_t>(a: width - scale, b: max_width_over_scale);
86 }
87 max_width = MaxValue<uint8_t>(a: max_scale + max_width_over_scale, b: max_width);
88 if (max_width > Decimal::MAX_WIDTH_DECIMAL) {
89 // target width does not fit in decimal: truncate the scale (if possible) to try and make it fit
90 max_width = Decimal::MAX_WIDTH_DECIMAL;
91 }
92 return LogicalType::DECIMAL(width: max_width, scale: max_scale);
93 }
94 case LogicalTypeId::VARCHAR:
95 // for comparison with strings, we prefer to bind to the numeric types
96 if (left_type.IsNumeric() || left_type.id() == LogicalTypeId::BOOLEAN) {
97 return left_type;
98 } else if (right_type.IsNumeric() || right_type.id() == LogicalTypeId::BOOLEAN) {
99 return right_type;
100 } else {
101 // else: check if collations are compatible
102 auto left_collation = StringType::GetCollation(type: left_type);
103 auto right_collation = StringType::GetCollation(type: right_type);
104 if (!left_collation.empty() && !right_collation.empty() && left_collation != right_collation) {
105 throw BinderException("Cannot combine types with different collation!");
106 }
107 }
108 return result_type;
109 default:
110 return result_type;
111 }
112}
113
114BindResult ExpressionBinder::BindExpression(ComparisonExpression &expr, idx_t depth) {
115 // first try to bind the children of the case expression
116 string error;
117 BindChild(expr&: expr.left, depth, error);
118 BindChild(expr&: expr.right, depth, error);
119 if (!error.empty()) {
120 return BindResult(error);
121 }
122 // the children have been successfully resolved
123 auto &left = BoundExpression::GetExpression(expr&: *expr.left);
124 auto &right = BoundExpression::GetExpression(expr&: *expr.right);
125 auto left_sql_type = left->return_type;
126 auto right_sql_type = right->return_type;
127 // cast the input types to the same type
128 // now obtain the result type of the input types
129 auto input_type = BoundComparisonExpression::BindComparison(left_type: left_sql_type, right_type: right_sql_type);
130 // add casts (if necessary)
131 left = BoundCastExpression::AddCastToType(context, expr: std::move(left), target_type: input_type,
132 try_cast: input_type.id() == LogicalTypeId::ENUM);
133 right = BoundCastExpression::AddCastToType(context, expr: std::move(right), target_type: input_type,
134 try_cast: input_type.id() == LogicalTypeId::ENUM);
135
136 if (input_type.id() == LogicalTypeId::VARCHAR) {
137 // handle collation
138 auto collation = StringType::GetCollation(type: input_type);
139 left = PushCollation(context, source: std::move(left), collation_p: collation, equality_only: expr.type == ExpressionType::COMPARE_EQUAL);
140 right = PushCollation(context, source: std::move(right), collation_p: collation, equality_only: expr.type == ExpressionType::COMPARE_EQUAL);
141 }
142 // now create the bound comparison expression
143 return BindResult(make_uniq<BoundComparisonExpression>(args&: expr.type, args: std::move(left), args: std::move(right)));
144}
145
146} // namespace duckdb
147