1#include "duckdb/function/scalar/sequence_functions.hpp"
2
3#include "duckdb/catalog/catalog.hpp"
4#include "duckdb/catalog/dependency_list.hpp"
5#include "duckdb/catalog/catalog_entry/sequence_catalog_entry.hpp"
6#include "duckdb/common/exception.hpp"
7#include "duckdb/common/vector_operations/vector_operations.hpp"
8#include "duckdb/execution/expression_executor.hpp"
9#include "duckdb/planner/expression/bound_function_expression.hpp"
10#include "duckdb/transaction/duck_transaction.hpp"
11#include "duckdb/common/vector_operations/unary_executor.hpp"
12#include "duckdb/common/operator/add.hpp"
13#include "duckdb/planner/binder.hpp"
14
15namespace duckdb {
16
17struct NextvalBindData : public FunctionData {
18 explicit NextvalBindData(optional_ptr<SequenceCatalogEntry> sequence) : sequence(sequence) {
19 }
20
21 //! The sequence to use for the nextval computation; only if the sequence is a constant
22 optional_ptr<SequenceCatalogEntry> sequence;
23
24 unique_ptr<FunctionData> Copy() const override {
25 return make_uniq<NextvalBindData>(args: sequence);
26 }
27
28 bool Equals(const FunctionData &other_p) const override {
29 auto &other = other_p.Cast<NextvalBindData>();
30 return sequence == other.sequence;
31 }
32};
33
34struct CurrentSequenceValueOperator {
35 static int64_t Operation(DuckTransaction &transaction, SequenceCatalogEntry &seq) {
36 lock_guard<mutex> seqlock(seq.lock);
37 int64_t result;
38 if (seq.usage_count == 0u) {
39 throw SequenceException("currval: sequence is not yet defined in this session");
40 }
41 result = seq.last_value;
42 return result;
43 }
44};
45
46struct NextSequenceValueOperator {
47 static int64_t Operation(DuckTransaction &transaction, SequenceCatalogEntry &seq) {
48 lock_guard<mutex> seqlock(seq.lock);
49 int64_t result;
50 result = seq.counter;
51 bool overflow = !TryAddOperator::Operation(left: seq.counter, right: seq.increment, result&: seq.counter);
52 if (seq.cycle) {
53 if (overflow) {
54 seq.counter = seq.increment < 0 ? seq.max_value : seq.min_value;
55 } else if (seq.counter < seq.min_value) {
56 seq.counter = seq.max_value;
57 } else if (seq.counter > seq.max_value) {
58 seq.counter = seq.min_value;
59 }
60 } else {
61 if (result < seq.min_value || (overflow && seq.increment < 0)) {
62 throw SequenceException("nextval: reached minimum value of sequence \"%s\" (%lld)", seq.name,
63 seq.min_value);
64 }
65 if (result > seq.max_value || overflow) {
66 throw SequenceException("nextval: reached maximum value of sequence \"%s\" (%lld)", seq.name,
67 seq.max_value);
68 }
69 }
70 seq.last_value = result;
71 seq.usage_count++;
72 if (!seq.temporary) {
73 transaction.sequence_usage[&seq] = SequenceValue(seq.usage_count, seq.counter);
74 }
75 return result;
76 }
77};
78
79SequenceCatalogEntry &BindSequence(ClientContext &context, const string &name) {
80 auto qname = QualifiedName::Parse(input: name);
81 // fetch the sequence from the catalog
82 Binder::BindSchemaOrCatalog(context, catalog&: qname.catalog, schema&: qname.schema);
83 return Catalog::GetEntry<SequenceCatalogEntry>(context, catalog_name: qname.catalog, schema_name: qname.schema, name: qname.name);
84}
85
86template <class OP>
87static void NextValFunction(DataChunk &args, ExpressionState &state, Vector &result) {
88 auto &func_expr = state.expr.Cast<BoundFunctionExpression>();
89 auto &info = func_expr.bind_info->Cast<NextvalBindData>();
90 auto &input = args.data[0];
91
92 auto &context = state.GetContext();
93 if (info.sequence) {
94 auto &sequence = *info.sequence;
95 auto &transaction = DuckTransaction::Get(context, catalog&: sequence.catalog);
96 // sequence to use is hard coded
97 // increment the sequence
98 result.SetVectorType(VectorType::FLAT_VECTOR);
99 auto result_data = FlatVector::GetData<int64_t>(vector&: result);
100 for (idx_t i = 0; i < args.size(); i++) {
101 // get the next value from the sequence
102 result_data[i] = OP::Operation(transaction, sequence);
103 }
104 } else {
105 // sequence to use comes from the input
106 UnaryExecutor::Execute<string_t, int64_t>(input, result, args.size(), [&](string_t value) {
107 // fetch the sequence from the catalog
108 auto &sequence = BindSequence(context, name: value.GetString());
109 // finally get the next value from the sequence
110 auto &transaction = DuckTransaction::Get(context, catalog&: sequence.catalog);
111 return OP::Operation(transaction, sequence);
112 });
113 }
114}
115
116static unique_ptr<FunctionData> NextValBind(ClientContext &context, ScalarFunction &bound_function,
117 vector<unique_ptr<Expression>> &arguments) {
118 optional_ptr<SequenceCatalogEntry> sequence;
119 if (arguments[0]->IsFoldable()) {
120 // parameter to nextval function is a foldable constant
121 // evaluate the constant and perform the catalog lookup already
122 auto seqname = ExpressionExecutor::EvaluateScalar(context, expr: *arguments[0]);
123 if (!seqname.IsNull()) {
124 sequence = &BindSequence(context, name: seqname.ToString());
125 }
126 }
127 return make_uniq<NextvalBindData>(args&: sequence);
128}
129
130static void NextValDependency(BoundFunctionExpression &expr, DependencyList &dependencies) {
131 auto &info = expr.bind_info->Cast<NextvalBindData>();
132 if (info.sequence) {
133 dependencies.AddDependency(entry&: *info.sequence);
134 }
135}
136
137void NextvalFun::RegisterFunction(BuiltinFunctions &set) {
138 ScalarFunction next_val("nextval", {LogicalType::VARCHAR}, LogicalType::BIGINT,
139 NextValFunction<NextSequenceValueOperator>, NextValBind, NextValDependency);
140 next_val.side_effects = FunctionSideEffects::HAS_SIDE_EFFECTS;
141 set.AddFunction(function: next_val);
142}
143
144void CurrvalFun::RegisterFunction(BuiltinFunctions &set) {
145 ScalarFunction curr_val("currval", {LogicalType::VARCHAR}, LogicalType::BIGINT,
146 NextValFunction<CurrentSequenceValueOperator>, NextValBind, NextValDependency);
147 curr_val.side_effects = FunctionSideEffects::HAS_SIDE_EFFECTS;
148 set.AddFunction(function: curr_val);
149}
150
151} // namespace duckdb
152