1#include "duckdb/function/scalar/generic_functions.hpp"
2#include "duckdb/common/operator/comparison_operators.hpp"
3#include "duckdb/execution/expression_executor.hpp"
4#include "duckdb/planner/expression/bound_function_expression.hpp"
5
6namespace duckdb {
7
8struct ConstantOrNullBindData : public FunctionData {
9 explicit ConstantOrNullBindData(Value val) : value(std::move(val)) {
10 }
11
12 Value value;
13
14public:
15 unique_ptr<FunctionData> Copy() const override {
16 return make_uniq<ConstantOrNullBindData>(args: value);
17 }
18
19 bool Equals(const FunctionData &other_p) const override {
20 auto &other = other_p.Cast<ConstantOrNullBindData>();
21 return value == other.value;
22 }
23};
24
25static void ConstantOrNullFunction(DataChunk &args, ExpressionState &state, Vector &result) {
26 auto &func_expr = state.expr.Cast<BoundFunctionExpression>();
27 auto &info = func_expr.bind_info->Cast<ConstantOrNullBindData>();
28 result.Reference(value: info.value);
29 for (idx_t idx = 1; idx < args.ColumnCount(); idx++) {
30 switch (args.data[idx].GetVectorType()) {
31 case VectorType::FLAT_VECTOR: {
32 auto &input_mask = FlatVector::Validity(vector&: args.data[idx]);
33 if (!input_mask.AllValid()) {
34 // there are null values: need to merge them into the result
35 result.Flatten(count: args.size());
36 auto &result_mask = FlatVector::Validity(vector&: result);
37 result_mask.Combine(other: input_mask, count: args.size());
38 }
39 break;
40 }
41 case VectorType::CONSTANT_VECTOR: {
42 if (ConstantVector::IsNull(vector: args.data[idx])) {
43 // input is constant null, return constant null
44 result.Reference(value: info.value);
45 ConstantVector::SetNull(vector&: result, is_null: true);
46 return;
47 }
48 break;
49 }
50 default: {
51 UnifiedVectorFormat vdata;
52 args.data[idx].ToUnifiedFormat(count: args.size(), data&: vdata);
53 if (!vdata.validity.AllValid()) {
54 result.Flatten(count: args.size());
55 auto &result_mask = FlatVector::Validity(vector&: result);
56 for (idx_t i = 0; i < args.size(); i++) {
57 if (!vdata.validity.RowIsValid(row_idx: vdata.sel->get_index(idx: i))) {
58 result_mask.SetInvalid(i);
59 }
60 }
61 }
62 break;
63 }
64 }
65 }
66}
67
68ScalarFunction ConstantOrNull::GetFunction(const LogicalType &return_type) {
69 return ScalarFunction("constant_or_null", {return_type, LogicalType::ANY}, return_type, ConstantOrNullFunction);
70}
71
72unique_ptr<FunctionData> ConstantOrNull::Bind(Value value) {
73 return make_uniq<ConstantOrNullBindData>(args: std::move(value));
74}
75
76bool ConstantOrNull::IsConstantOrNull(BoundFunctionExpression &expr, const Value &val) {
77 if (expr.function.name != "constant_or_null") {
78 return false;
79 }
80 D_ASSERT(expr.bind_info);
81 auto &bind_data = expr.bind_info->Cast<ConstantOrNullBindData>();
82 D_ASSERT(bind_data.value.type() == val.type());
83 return bind_data.value == val;
84}
85
86unique_ptr<FunctionData> ConstantOrNullBind(ClientContext &context, ScalarFunction &bound_function,
87 vector<unique_ptr<Expression>> &arguments) {
88 if (arguments[0]->HasParameter()) {
89 throw ParameterNotResolvedException();
90 }
91 if (!arguments[0]->IsFoldable()) {
92 throw BinderException("ConstantOrNull requires a constant input");
93 }
94 D_ASSERT(arguments.size() >= 2);
95 auto value = ExpressionExecutor::EvaluateScalar(context, expr: *arguments[0]);
96 bound_function.return_type = arguments[0]->return_type;
97 return make_uniq<ConstantOrNullBindData>(args: std::move(value));
98}
99
100void ConstantOrNull::RegisterFunction(BuiltinFunctions &set) {
101 auto fun = ConstantOrNull::GetFunction(return_type: LogicalType::ANY);
102 fun.bind = ConstantOrNullBind;
103 fun.varargs = LogicalType::ANY;
104 set.AddFunction(function: fun);
105}
106
107} // namespace duckdb
108