#include "duckdb/function/table/range.hpp"
#include "duckdb/function/table/summary.hpp"
#include "duckdb/function/table_function.hpp"
#include "duckdb/function/function_set.hpp"
#include "duckdb/common/algorithm.hpp"
#include "duckdb/common/limits.hpp"
#include "duckdb/common/operator/add.hpp"
#include "duckdb/common/types/timestamp.hpp"
| 8 | |
| 9 | namespace duckdb { |
| 10 | |
| 11 | //===--------------------------------------------------------------------===// |
| 12 | // Range (integers) |
| 13 | //===--------------------------------------------------------------------===// |
| 14 | struct RangeFunctionBindData : public TableFunctionData { |
| 15 | hugeint_t start; |
| 16 | hugeint_t end; |
| 17 | hugeint_t increment; |
| 18 | |
| 19 | public: |
| 20 | bool Equals(const FunctionData &other_p) const override { |
| 21 | auto &other = other_p.Cast<RangeFunctionBindData>(); |
| 22 | return other.start == start && other.end == end && other.increment == increment; |
| 23 | } |
| 24 | }; |
| 25 | |
| 26 | template <bool GENERATE_SERIES> |
| 27 | static void GenerateRangeParameters(const vector<Value> &inputs, RangeFunctionBindData &result) { |
| 28 | for (auto &input : inputs) { |
| 29 | if (input.IsNull()) { |
| 30 | result.start = GENERATE_SERIES ? 1 : 0; |
| 31 | result.end = 0; |
| 32 | result.increment = 1; |
| 33 | return; |
| 34 | } |
| 35 | } |
| 36 | if (inputs.size() < 2) { |
| 37 | // single argument: only the end is specified |
| 38 | result.start = 0; |
| 39 | result.end = inputs[0].GetValue<int64_t>(); |
| 40 | } else { |
| 41 | // two arguments: first two arguments are start and end |
| 42 | result.start = inputs[0].GetValue<int64_t>(); |
| 43 | result.end = inputs[1].GetValue<int64_t>(); |
| 44 | } |
| 45 | if (inputs.size() < 3) { |
| 46 | result.increment = 1; |
| 47 | } else { |
| 48 | result.increment = inputs[2].GetValue<int64_t>(); |
| 49 | } |
| 50 | if (result.increment == 0) { |
| 51 | throw BinderException("interval cannot be 0!" ); |
| 52 | } |
| 53 | if (result.start > result.end && result.increment > 0) { |
| 54 | throw BinderException("start is bigger than end, but increment is positive: cannot generate infinite series" ); |
| 55 | } else if (result.start < result.end && result.increment < 0) { |
| 56 | throw BinderException("start is smaller than end, but increment is negative: cannot generate infinite series" ); |
| 57 | } |
| 58 | } |
| 59 | |
| 60 | template <bool GENERATE_SERIES> |
| 61 | static unique_ptr<FunctionData> RangeFunctionBind(ClientContext &context, TableFunctionBindInput &input, |
| 62 | vector<LogicalType> &return_types, vector<string> &names) { |
| 63 | auto result = make_uniq<RangeFunctionBindData>(); |
| 64 | auto &inputs = input.inputs; |
| 65 | GenerateRangeParameters<GENERATE_SERIES>(inputs, *result); |
| 66 | |
| 67 | return_types.emplace_back(args: LogicalType::BIGINT); |
| 68 | if (GENERATE_SERIES) { |
| 69 | // generate_series has inclusive bounds on the RHS |
| 70 | if (result->increment < 0) { |
| 71 | result->end = result->end - 1; |
| 72 | } else { |
| 73 | result->end = result->end + 1; |
| 74 | } |
| 75 | names.emplace_back(args: "generate_series" ); |
| 76 | } else { |
| 77 | names.emplace_back(args: "range" ); |
| 78 | } |
| 79 | return std::move(result); |
| 80 | } |
| 81 | |
| 82 | struct RangeFunctionState : public GlobalTableFunctionState { |
| 83 | RangeFunctionState() : current_idx(0) { |
| 84 | } |
| 85 | |
| 86 | int64_t current_idx; |
| 87 | }; |
| 88 | |
| 89 | static unique_ptr<GlobalTableFunctionState> RangeFunctionInit(ClientContext &context, TableFunctionInitInput &input) { |
| 90 | return make_uniq<RangeFunctionState>(); |
| 91 | } |
| 92 | |
| 93 | static void RangeFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { |
| 94 | auto &bind_data = data_p.bind_data->Cast<RangeFunctionBindData>(); |
| 95 | auto &state = data_p.global_state->Cast<RangeFunctionState>(); |
| 96 | |
| 97 | auto increment = bind_data.increment; |
| 98 | auto end = bind_data.end; |
| 99 | hugeint_t current_value = bind_data.start + increment * state.current_idx; |
| 100 | int64_t current_value_i64; |
| 101 | if (!Hugeint::TryCast<int64_t>(input: current_value, result&: current_value_i64)) { |
| 102 | return; |
| 103 | } |
| 104 | int64_t offset = increment < 0 ? 1 : -1; |
| 105 | idx_t remaining = MinValue<idx_t>(a: Hugeint::Cast<idx_t>(input: (end - current_value + (increment + offset)) / increment), |
| 106 | STANDARD_VECTOR_SIZE); |
| 107 | // set the result vector as a sequence vector |
| 108 | output.data[0].Sequence(start: current_value_i64, increment: Hugeint::Cast<int64_t>(input: increment), count: remaining); |
| 109 | // increment the index pointer by the remaining count |
| 110 | state.current_idx += remaining; |
| 111 | output.SetCardinality(remaining); |
| 112 | } |
| 113 | |
| 114 | unique_ptr<NodeStatistics> RangeCardinality(ClientContext &context, const FunctionData *bind_data_p) { |
| 115 | auto &bind_data = bind_data_p->Cast<RangeFunctionBindData>(); |
| 116 | idx_t cardinality = Hugeint::Cast<idx_t>(input: (bind_data.end - bind_data.start) / bind_data.increment); |
| 117 | return make_uniq<NodeStatistics>(args&: cardinality, args&: cardinality); |
| 118 | } |
| 119 | |
| 120 | //===--------------------------------------------------------------------===// |
| 121 | // Range (timestamp) |
| 122 | //===--------------------------------------------------------------------===// |
| 123 | struct RangeDateTimeBindData : public TableFunctionData { |
| 124 | timestamp_t start; |
| 125 | timestamp_t end; |
| 126 | interval_t increment; |
| 127 | bool inclusive_bound; |
| 128 | bool greater_than_check; |
| 129 | |
| 130 | public: |
| 131 | bool Equals(const FunctionData &other_p) const override { |
| 132 | auto &other = other_p.Cast<RangeDateTimeBindData>(); |
| 133 | return other.start == start && other.end == end && other.increment == increment && |
| 134 | other.inclusive_bound == inclusive_bound && other.greater_than_check == greater_than_check; |
| 135 | } |
| 136 | |
| 137 | bool Finished(timestamp_t current_value) const { |
| 138 | if (greater_than_check) { |
| 139 | if (inclusive_bound) { |
| 140 | return current_value > end; |
| 141 | } else { |
| 142 | return current_value >= end; |
| 143 | } |
| 144 | } else { |
| 145 | if (inclusive_bound) { |
| 146 | return current_value < end; |
| 147 | } else { |
| 148 | return current_value <= end; |
| 149 | } |
| 150 | } |
| 151 | } |
| 152 | }; |
| 153 | |
| 154 | template <bool GENERATE_SERIES> |
| 155 | static unique_ptr<FunctionData> RangeDateTimeBind(ClientContext &context, TableFunctionBindInput &input, |
| 156 | vector<LogicalType> &return_types, vector<string> &names) { |
| 157 | auto result = make_uniq<RangeDateTimeBindData>(); |
| 158 | auto &inputs = input.inputs; |
| 159 | D_ASSERT(inputs.size() == 3); |
| 160 | result->start = inputs[0].GetValue<timestamp_t>(); |
| 161 | result->end = inputs[1].GetValue<timestamp_t>(); |
| 162 | result->increment = inputs[2].GetValue<interval_t>(); |
| 163 | |
| 164 | // Infinities either cause errors or infinite loops, so just ban them |
| 165 | if (!Timestamp::IsFinite(timestamp: result->start) || !Timestamp::IsFinite(timestamp: result->end)) { |
| 166 | throw BinderException("RANGE with infinite bounds is not supported" ); |
| 167 | } |
| 168 | |
| 169 | if (result->increment.months == 0 && result->increment.days == 0 && result->increment.micros == 0) { |
| 170 | throw BinderException("interval cannot be 0!" ); |
| 171 | } |
| 172 | // all elements should point in the same direction |
| 173 | if (result->increment.months > 0 || result->increment.days > 0 || result->increment.micros > 0) { |
| 174 | if (result->increment.months < 0 || result->increment.days < 0 || result->increment.micros < 0) { |
| 175 | throw BinderException("RANGE with composite interval that has mixed signs is not supported" ); |
| 176 | } |
| 177 | result->greater_than_check = true; |
| 178 | if (result->start > result->end) { |
| 179 | throw BinderException( |
| 180 | "start is bigger than end, but increment is positive: cannot generate infinite series" ); |
| 181 | } |
| 182 | } else { |
| 183 | result->greater_than_check = false; |
| 184 | if (result->start < result->end) { |
| 185 | throw BinderException( |
| 186 | "start is smaller than end, but increment is negative: cannot generate infinite series" ); |
| 187 | } |
| 188 | } |
| 189 | return_types.push_back(x: inputs[0].type()); |
| 190 | if (GENERATE_SERIES) { |
| 191 | // generate_series has inclusive bounds on the RHS |
| 192 | result->inclusive_bound = true; |
| 193 | names.emplace_back(args: "generate_series" ); |
| 194 | } else { |
| 195 | result->inclusive_bound = false; |
| 196 | names.emplace_back(args: "range" ); |
| 197 | } |
| 198 | return std::move(result); |
| 199 | } |
| 200 | |
| 201 | struct RangeDateTimeState : public GlobalTableFunctionState { |
| 202 | explicit RangeDateTimeState(timestamp_t start_p) : current_state(start_p) { |
| 203 | } |
| 204 | |
| 205 | timestamp_t current_state; |
| 206 | bool finished = false; |
| 207 | }; |
| 208 | |
| 209 | static unique_ptr<GlobalTableFunctionState> RangeDateTimeInit(ClientContext &context, TableFunctionInitInput &input) { |
| 210 | auto &bind_data = input.bind_data->Cast<RangeDateTimeBindData>(); |
| 211 | return make_uniq<RangeDateTimeState>(args: bind_data.start); |
| 212 | } |
| 213 | |
| 214 | static void RangeDateTimeFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { |
| 215 | auto &bind_data = data_p.bind_data->Cast<RangeDateTimeBindData>(); |
| 216 | auto &state = data_p.global_state->Cast<RangeDateTimeState>(); |
| 217 | if (state.finished) { |
| 218 | return; |
| 219 | } |
| 220 | |
| 221 | idx_t size = 0; |
| 222 | auto data = FlatVector::GetData<timestamp_t>(vector&: output.data[0]); |
| 223 | while (true) { |
| 224 | data[size++] = state.current_state; |
| 225 | state.current_state = |
| 226 | AddOperator::Operation<timestamp_t, interval_t, timestamp_t>(left: state.current_state, right: bind_data.increment); |
| 227 | if (bind_data.Finished(current_value: state.current_state)) { |
| 228 | state.finished = true; |
| 229 | break; |
| 230 | } |
| 231 | if (size >= STANDARD_VECTOR_SIZE) { |
| 232 | break; |
| 233 | } |
| 234 | } |
| 235 | output.SetCardinality(size); |
| 236 | } |
| 237 | |
| 238 | void RangeTableFunction::RegisterFunction(BuiltinFunctions &set) { |
| 239 | TableFunctionSet range("range" ); |
| 240 | |
| 241 | TableFunction range_function({LogicalType::BIGINT}, RangeFunction, RangeFunctionBind<false>, RangeFunctionInit); |
| 242 | range_function.cardinality = RangeCardinality; |
| 243 | |
| 244 | // single argument range: (end) - implicit start = 0 and increment = 1 |
| 245 | range.AddFunction(function: range_function); |
| 246 | // two arguments range: (start, end) - implicit increment = 1 |
| 247 | range_function.arguments = {LogicalType::BIGINT, LogicalType::BIGINT}; |
| 248 | range.AddFunction(function: range_function); |
| 249 | // three arguments range: (start, end, increment) |
| 250 | range_function.arguments = {LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT}; |
| 251 | range.AddFunction(function: range_function); |
| 252 | range.AddFunction(function: TableFunction({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::INTERVAL}, |
| 253 | RangeDateTimeFunction, RangeDateTimeBind<false>, RangeDateTimeInit)); |
| 254 | set.AddFunction(set: range); |
| 255 | // generate_series: similar to range, but inclusive instead of exclusive bounds on the RHS |
| 256 | TableFunctionSet generate_series("generate_series" ); |
| 257 | range_function.bind = RangeFunctionBind<true>; |
| 258 | range_function.arguments = {LogicalType::BIGINT}; |
| 259 | generate_series.AddFunction(function: range_function); |
| 260 | range_function.arguments = {LogicalType::BIGINT, LogicalType::BIGINT}; |
| 261 | generate_series.AddFunction(function: range_function); |
| 262 | range_function.arguments = {LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT}; |
| 263 | generate_series.AddFunction(function: range_function); |
| 264 | generate_series.AddFunction(function: TableFunction({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::INTERVAL}, |
| 265 | RangeDateTimeFunction, RangeDateTimeBind<true>, RangeDateTimeInit)); |
| 266 | set.AddFunction(set: generate_series); |
| 267 | } |
| 268 | |
| 269 | void BuiltinFunctions::RegisterTableFunctions() { |
| 270 | CheckpointFunction::RegisterFunction(set&: *this); |
| 271 | GlobTableFunction::RegisterFunction(set&: *this); |
| 272 | RangeTableFunction::RegisterFunction(set&: *this); |
| 273 | RepeatTableFunction::RegisterFunction(set&: *this); |
| 274 | SummaryTableFunction::RegisterFunction(set&: *this); |
| 275 | UnnestTableFunction::RegisterFunction(set&: *this); |
| 276 | RepeatRowTableFunction::RegisterFunction(set&: *this); |
| 277 | } |
| 278 | |
| 279 | } // namespace duckdb |
| 280 | |