1#include "duckdb/function/scalar/math_functions.hpp"
2#include "duckdb/common/exception.hpp"
3#include "duckdb/common/vector_operations/vector_operations.hpp"
4#include "duckdb/execution/expression_executor.hpp"
5#include "duckdb/main/client_context.hpp"
6#include "duckdb/planner/expression/bound_function_expression.hpp"
7
8using namespace duckdb;
9using namespace std;
10
11struct SetseedBindData : public FunctionData {
12 //! The client context for the function call
13 ClientContext &context;
14
15 SetseedBindData(ClientContext &context) : context(context) {
16 }
17
18 unique_ptr<FunctionData> Copy() override {
19 return make_unique<SetseedBindData>(context);
20 }
21};
22
23static void setseed_function(DataChunk &args, ExpressionState &state, Vector &result) {
24 auto &func_expr = (BoundFunctionExpression &)state.expr;
25 auto &info = (SetseedBindData &)*func_expr.bind_info;
26 auto &input = args.data[0];
27 input.Normalify(args.size());
28
29 auto input_seeds = FlatVector::GetData<double>(input);
30 uint32_t half_max = numeric_limits<uint32_t>::max() / 2;
31
32 for (idx_t i = 0; i < args.size(); i++) {
33 if (input_seeds[i] < -1.0 || input_seeds[i] > 1.0) {
34 throw Exception("SETSEED accepts seed values between -1.0 and 1.0, inclusive");
35 }
36 uint32_t norm_seed = (input_seeds[i] + 1.0) * half_max;
37 info.context.random_engine.seed(norm_seed);
38 }
39
40 result.vector_type = VectorType::CONSTANT_VECTOR;
41 ConstantVector::SetNull(result, true);
42}
43
44unique_ptr<FunctionData> setseed_bind(BoundFunctionExpression &expr, ClientContext &context) {
45 return make_unique<SetseedBindData>(context);
46}
47
48void SetseedFun::RegisterFunction(BuiltinFunctions &set) {
49 set.AddFunction(
50 ScalarFunction("setseed", {SQLType::DOUBLE}, SQLType::SQLNULL, setseed_function, true, setseed_bind));
51}
52