#include <Functions/IFunctionImpl.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/ColumnsCommon.h>
#include <DataTypes/DataTypesNumber.h>
#include <IO/WriteHelpers.h>

#include <optional>
#include <type_traits>
9
10
11namespace DB
12{
13
/// Error codes used by this file; only declared here, the definitions live elsewhere.
namespace ErrorCodes
{
    extern const int FUNCTION_THROW_IF_VALUE_IS_NON_ZERO;
    extern const int ILLEGAL_COLUMN;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
21
22
23/// Throw an exception if the argument is non zero.
24class FunctionThrowIf : public IFunction
25{
26public:
27 static constexpr auto name = "throwIf";
28 static FunctionPtr create(const Context &)
29 {
30 return std::make_shared<FunctionThrowIf>();
31 }
32
33 String getName() const override
34 {
35 return name;
36 }
37
38 bool isVariadic() const override { return true; }
39 size_t getNumberOfArguments() const override
40 {
41 return 0;
42 }
43
44 DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
45 {
46 const size_t number_of_arguments = arguments.size();
47
48 if (number_of_arguments < 1 || number_of_arguments > 2)
49 throw Exception{"Number of arguments for function " + getName() + " doesn't match: passed "
50 + toString(number_of_arguments) + ", should be 1 or 2",
51 ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
52
53 if (!isNativeNumber(arguments[0]))
54 throw Exception{"Argument for function " + getName() + " must be number", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
55
56 if (number_of_arguments > 1 && !isString(arguments[1]))
57 throw Exception{"Illegal type " + arguments[1]->getName() + " of argument of function " + getName(),
58 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
59
60
61 return std::make_shared<DataTypeUInt8>();
62 }
63
64 bool useDefaultImplementationForConstants() const override { return true; }
65 ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; }
66
67 void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
68 {
69 std::optional<String> custom_message;
70 if (arguments.size() == 2)
71 {
72 auto * msg_column = checkAndGetColumnConst<ColumnString>(block.getByPosition(arguments[1]).column.get());
73 if (!msg_column)
74 throw Exception{"Second argument for function " + getName() + " must be constant String", ErrorCodes::ILLEGAL_COLUMN};
75 custom_message = msg_column->getValue<String>();
76 }
77
78 const auto in = block.getByPosition(arguments.front()).column.get();
79
80 if ( !execute<UInt8>(block, in, result, custom_message)
81 && !execute<UInt16>(block, in, result, custom_message)
82 && !execute<UInt32>(block, in, result, custom_message)
83 && !execute<UInt64>(block, in, result, custom_message)
84 && !execute<Int8>(block, in, result, custom_message)
85 && !execute<Int16>(block, in, result, custom_message)
86 && !execute<Int32>(block, in, result, custom_message)
87 && !execute<Int64>(block, in, result, custom_message)
88 && !execute<Float32>(block, in, result, custom_message)
89 && !execute<Float64>(block, in, result, custom_message))
90 throw Exception{"Illegal column " + in->getName() + " of first argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN};
91 }
92
93 template <typename T>
94 bool execute(Block & block, const IColumn * in_untyped, const size_t result, const std::optional<String> & message)
95 {
96 if (const auto in = checkAndGetColumn<ColumnVector<T>>(in_untyped))
97 {
98 const auto & in_data = in->getData();
99 if (!memoryIsZero(in_data.data(), in_data.size() * sizeof(in_data[0])))
100 throw Exception{message.value_or("Value passed to '" + getName() + "' function is non zero"),
101 ErrorCodes::FUNCTION_THROW_IF_VALUE_IS_NON_ZERO};
102
103 /// We return non constant to avoid constant folding.
104 block.getByPosition(result).column = ColumnUInt8::create(in_data.size(), 0);
105 return true;
106 }
107
108 return false;
109 }
110};
111
112
113void registerFunctionThrowIf(FunctionFactory & factory)
114{
115 factory.registerFunction<FunctionThrowIf>();
116}
117
118}
119