1 | #include "duckdb/function/table/range.hpp" |
2 | #include "duckdb/function/table/summary.hpp" |
3 | #include "duckdb/function/table_function.hpp" |
4 | #include "duckdb/function/function_set.hpp" |
5 | #include "duckdb/common/algorithm.hpp" |
6 | #include "duckdb/common/operator/add.hpp" |
7 | #include "duckdb/common/types/timestamp.hpp" |
8 | |
9 | namespace duckdb { |
10 | |
11 | //===--------------------------------------------------------------------===// |
12 | // Range (integers) |
13 | //===--------------------------------------------------------------------===// |
14 | struct RangeFunctionBindData : public TableFunctionData { |
15 | hugeint_t start; |
16 | hugeint_t end; |
17 | hugeint_t increment; |
18 | |
19 | public: |
20 | bool Equals(const FunctionData &other_p) const override { |
21 | auto &other = other_p.Cast<RangeFunctionBindData>(); |
22 | return other.start == start && other.end == end && other.increment == increment; |
23 | } |
24 | }; |
25 | |
26 | template <bool GENERATE_SERIES> |
27 | static void GenerateRangeParameters(const vector<Value> &inputs, RangeFunctionBindData &result) { |
28 | for (auto &input : inputs) { |
29 | if (input.IsNull()) { |
30 | result.start = GENERATE_SERIES ? 1 : 0; |
31 | result.end = 0; |
32 | result.increment = 1; |
33 | return; |
34 | } |
35 | } |
36 | if (inputs.size() < 2) { |
37 | // single argument: only the end is specified |
38 | result.start = 0; |
39 | result.end = inputs[0].GetValue<int64_t>(); |
40 | } else { |
41 | // two arguments: first two arguments are start and end |
42 | result.start = inputs[0].GetValue<int64_t>(); |
43 | result.end = inputs[1].GetValue<int64_t>(); |
44 | } |
45 | if (inputs.size() < 3) { |
46 | result.increment = 1; |
47 | } else { |
48 | result.increment = inputs[2].GetValue<int64_t>(); |
49 | } |
50 | if (result.increment == 0) { |
51 | throw BinderException("interval cannot be 0!" ); |
52 | } |
53 | if (result.start > result.end && result.increment > 0) { |
54 | throw BinderException("start is bigger than end, but increment is positive: cannot generate infinite series" ); |
55 | } else if (result.start < result.end && result.increment < 0) { |
56 | throw BinderException("start is smaller than end, but increment is negative: cannot generate infinite series" ); |
57 | } |
58 | } |
59 | |
60 | template <bool GENERATE_SERIES> |
61 | static unique_ptr<FunctionData> RangeFunctionBind(ClientContext &context, TableFunctionBindInput &input, |
62 | vector<LogicalType> &return_types, vector<string> &names) { |
63 | auto result = make_uniq<RangeFunctionBindData>(); |
64 | auto &inputs = input.inputs; |
65 | GenerateRangeParameters<GENERATE_SERIES>(inputs, *result); |
66 | |
67 | return_types.emplace_back(args: LogicalType::BIGINT); |
68 | if (GENERATE_SERIES) { |
69 | // generate_series has inclusive bounds on the RHS |
70 | if (result->increment < 0) { |
71 | result->end = result->end - 1; |
72 | } else { |
73 | result->end = result->end + 1; |
74 | } |
75 | names.emplace_back(args: "generate_series" ); |
76 | } else { |
77 | names.emplace_back(args: "range" ); |
78 | } |
79 | return std::move(result); |
80 | } |
81 | |
82 | struct RangeFunctionState : public GlobalTableFunctionState { |
83 | RangeFunctionState() : current_idx(0) { |
84 | } |
85 | |
86 | int64_t current_idx; |
87 | }; |
88 | |
89 | static unique_ptr<GlobalTableFunctionState> RangeFunctionInit(ClientContext &context, TableFunctionInitInput &input) { |
90 | return make_uniq<RangeFunctionState>(); |
91 | } |
92 | |
93 | static void RangeFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { |
94 | auto &bind_data = data_p.bind_data->Cast<RangeFunctionBindData>(); |
95 | auto &state = data_p.global_state->Cast<RangeFunctionState>(); |
96 | |
97 | auto increment = bind_data.increment; |
98 | auto end = bind_data.end; |
99 | hugeint_t current_value = bind_data.start + increment * state.current_idx; |
100 | int64_t current_value_i64; |
101 | if (!Hugeint::TryCast<int64_t>(input: current_value, result&: current_value_i64)) { |
102 | return; |
103 | } |
104 | int64_t offset = increment < 0 ? 1 : -1; |
105 | idx_t remaining = MinValue<idx_t>(a: Hugeint::Cast<idx_t>(input: (end - current_value + (increment + offset)) / increment), |
106 | STANDARD_VECTOR_SIZE); |
107 | // set the result vector as a sequence vector |
108 | output.data[0].Sequence(start: current_value_i64, increment: Hugeint::Cast<int64_t>(input: increment), count: remaining); |
109 | // increment the index pointer by the remaining count |
110 | state.current_idx += remaining; |
111 | output.SetCardinality(remaining); |
112 | } |
113 | |
114 | unique_ptr<NodeStatistics> RangeCardinality(ClientContext &context, const FunctionData *bind_data_p) { |
115 | auto &bind_data = bind_data_p->Cast<RangeFunctionBindData>(); |
116 | idx_t cardinality = Hugeint::Cast<idx_t>(input: (bind_data.end - bind_data.start) / bind_data.increment); |
117 | return make_uniq<NodeStatistics>(args&: cardinality, args&: cardinality); |
118 | } |
119 | |
120 | //===--------------------------------------------------------------------===// |
121 | // Range (timestamp) |
122 | //===--------------------------------------------------------------------===// |
123 | struct RangeDateTimeBindData : public TableFunctionData { |
124 | timestamp_t start; |
125 | timestamp_t end; |
126 | interval_t increment; |
127 | bool inclusive_bound; |
128 | bool greater_than_check; |
129 | |
130 | public: |
131 | bool Equals(const FunctionData &other_p) const override { |
132 | auto &other = other_p.Cast<RangeDateTimeBindData>(); |
133 | return other.start == start && other.end == end && other.increment == increment && |
134 | other.inclusive_bound == inclusive_bound && other.greater_than_check == greater_than_check; |
135 | } |
136 | |
137 | bool Finished(timestamp_t current_value) const { |
138 | if (greater_than_check) { |
139 | if (inclusive_bound) { |
140 | return current_value > end; |
141 | } else { |
142 | return current_value >= end; |
143 | } |
144 | } else { |
145 | if (inclusive_bound) { |
146 | return current_value < end; |
147 | } else { |
148 | return current_value <= end; |
149 | } |
150 | } |
151 | } |
152 | }; |
153 | |
154 | template <bool GENERATE_SERIES> |
155 | static unique_ptr<FunctionData> RangeDateTimeBind(ClientContext &context, TableFunctionBindInput &input, |
156 | vector<LogicalType> &return_types, vector<string> &names) { |
157 | auto result = make_uniq<RangeDateTimeBindData>(); |
158 | auto &inputs = input.inputs; |
159 | D_ASSERT(inputs.size() == 3); |
160 | result->start = inputs[0].GetValue<timestamp_t>(); |
161 | result->end = inputs[1].GetValue<timestamp_t>(); |
162 | result->increment = inputs[2].GetValue<interval_t>(); |
163 | |
164 | // Infinities either cause errors or infinite loops, so just ban them |
165 | if (!Timestamp::IsFinite(timestamp: result->start) || !Timestamp::IsFinite(timestamp: result->end)) { |
166 | throw BinderException("RANGE with infinite bounds is not supported" ); |
167 | } |
168 | |
169 | if (result->increment.months == 0 && result->increment.days == 0 && result->increment.micros == 0) { |
170 | throw BinderException("interval cannot be 0!" ); |
171 | } |
172 | // all elements should point in the same direction |
173 | if (result->increment.months > 0 || result->increment.days > 0 || result->increment.micros > 0) { |
174 | if (result->increment.months < 0 || result->increment.days < 0 || result->increment.micros < 0) { |
175 | throw BinderException("RANGE with composite interval that has mixed signs is not supported" ); |
176 | } |
177 | result->greater_than_check = true; |
178 | if (result->start > result->end) { |
179 | throw BinderException( |
180 | "start is bigger than end, but increment is positive: cannot generate infinite series" ); |
181 | } |
182 | } else { |
183 | result->greater_than_check = false; |
184 | if (result->start < result->end) { |
185 | throw BinderException( |
186 | "start is smaller than end, but increment is negative: cannot generate infinite series" ); |
187 | } |
188 | } |
189 | return_types.push_back(x: inputs[0].type()); |
190 | if (GENERATE_SERIES) { |
191 | // generate_series has inclusive bounds on the RHS |
192 | result->inclusive_bound = true; |
193 | names.emplace_back(args: "generate_series" ); |
194 | } else { |
195 | result->inclusive_bound = false; |
196 | names.emplace_back(args: "range" ); |
197 | } |
198 | return std::move(result); |
199 | } |
200 | |
201 | struct RangeDateTimeState : public GlobalTableFunctionState { |
202 | explicit RangeDateTimeState(timestamp_t start_p) : current_state(start_p) { |
203 | } |
204 | |
205 | timestamp_t current_state; |
206 | bool finished = false; |
207 | }; |
208 | |
209 | static unique_ptr<GlobalTableFunctionState> RangeDateTimeInit(ClientContext &context, TableFunctionInitInput &input) { |
210 | auto &bind_data = input.bind_data->Cast<RangeDateTimeBindData>(); |
211 | return make_uniq<RangeDateTimeState>(args: bind_data.start); |
212 | } |
213 | |
214 | static void RangeDateTimeFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { |
215 | auto &bind_data = data_p.bind_data->Cast<RangeDateTimeBindData>(); |
216 | auto &state = data_p.global_state->Cast<RangeDateTimeState>(); |
217 | if (state.finished) { |
218 | return; |
219 | } |
220 | |
221 | idx_t size = 0; |
222 | auto data = FlatVector::GetData<timestamp_t>(vector&: output.data[0]); |
223 | while (true) { |
224 | data[size++] = state.current_state; |
225 | state.current_state = |
226 | AddOperator::Operation<timestamp_t, interval_t, timestamp_t>(left: state.current_state, right: bind_data.increment); |
227 | if (bind_data.Finished(current_value: state.current_state)) { |
228 | state.finished = true; |
229 | break; |
230 | } |
231 | if (size >= STANDARD_VECTOR_SIZE) { |
232 | break; |
233 | } |
234 | } |
235 | output.SetCardinality(size); |
236 | } |
237 | |
238 | void RangeTableFunction::RegisterFunction(BuiltinFunctions &set) { |
239 | TableFunctionSet range("range" ); |
240 | |
241 | TableFunction range_function({LogicalType::BIGINT}, RangeFunction, RangeFunctionBind<false>, RangeFunctionInit); |
242 | range_function.cardinality = RangeCardinality; |
243 | |
244 | // single argument range: (end) - implicit start = 0 and increment = 1 |
245 | range.AddFunction(function: range_function); |
246 | // two arguments range: (start, end) - implicit increment = 1 |
247 | range_function.arguments = {LogicalType::BIGINT, LogicalType::BIGINT}; |
248 | range.AddFunction(function: range_function); |
249 | // three arguments range: (start, end, increment) |
250 | range_function.arguments = {LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT}; |
251 | range.AddFunction(function: range_function); |
252 | range.AddFunction(function: TableFunction({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::INTERVAL}, |
253 | RangeDateTimeFunction, RangeDateTimeBind<false>, RangeDateTimeInit)); |
254 | set.AddFunction(set: range); |
255 | // generate_series: similar to range, but inclusive instead of exclusive bounds on the RHS |
256 | TableFunctionSet generate_series("generate_series" ); |
257 | range_function.bind = RangeFunctionBind<true>; |
258 | range_function.arguments = {LogicalType::BIGINT}; |
259 | generate_series.AddFunction(function: range_function); |
260 | range_function.arguments = {LogicalType::BIGINT, LogicalType::BIGINT}; |
261 | generate_series.AddFunction(function: range_function); |
262 | range_function.arguments = {LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT}; |
263 | generate_series.AddFunction(function: range_function); |
264 | generate_series.AddFunction(function: TableFunction({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::INTERVAL}, |
265 | RangeDateTimeFunction, RangeDateTimeBind<true>, RangeDateTimeInit)); |
266 | set.AddFunction(set: generate_series); |
267 | } |
268 | |
269 | void BuiltinFunctions::RegisterTableFunctions() { |
270 | CheckpointFunction::RegisterFunction(set&: *this); |
271 | GlobTableFunction::RegisterFunction(set&: *this); |
272 | RangeTableFunction::RegisterFunction(set&: *this); |
273 | RepeatTableFunction::RegisterFunction(set&: *this); |
274 | SummaryTableFunction::RegisterFunction(set&: *this); |
275 | UnnestTableFunction::RegisterFunction(set&: *this); |
276 | RepeatRowTableFunction::RegisterFunction(set&: *this); |
277 | } |
278 | |
279 | } // namespace duckdb |
280 | |