1#pragma once
2
3#include <Functions/IFunctionAdaptors.h>
4#include <Interpreters/ExpressionActions.h>
5#include <DataTypes/DataTypeFunction.h>
6#include <IO/WriteBufferFromString.h>
7#include <IO/Operators.h>
8#include <Columns/ColumnFunction.h>
9#include <DataTypes/DataTypesNumber.h>
10
11namespace DB
12{
13
14class ExecutableFunctionExpression : public IExecutableFunctionImpl
15{
16public:
17 struct Signature
18 {
19 Names argument_names;
20 String return_name;
21 };
22
23 using SignaturePtr = std::shared_ptr<Signature>;
24
25 ExecutableFunctionExpression(ExpressionActionsPtr expression_actions_, SignaturePtr signature_)
26 : expression_actions(std::move(expression_actions_))
27 , signature(std::move(signature_))
28 {}
29
30 String getName() const override { return "FunctionExpression"; }
31
32 void execute(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
33 {
34 Block expr_block;
35 for (size_t i = 0; i < arguments.size(); ++i)
36 {
37 const auto & argument = block.getByPosition(arguments[i]);
38 /// Replace column name with value from argument_names.
39 expr_block.insert({argument.column, argument.type, signature->argument_names[i]});
40 }
41
42 expression_actions->execute(expr_block);
43
44 block.getByPosition(result).column = expr_block.getByName(signature->return_name).column;
45 }
46
47bool useDefaultImplementationForNulls() const override { return false; }
48
49private:
50 ExpressionActionsPtr expression_actions;
51 SignaturePtr signature;
52};
53
54/// Executes expression. Uses for lambda functions implementation. Can't be created from factory.
55class FunctionExpression : public IFunctionBaseImpl
56{
57public:
58 using Signature = ExecutableFunctionExpression::Signature;
59 using SignaturePtr = ExecutableFunctionExpression::SignaturePtr;
60
61 FunctionExpression(ExpressionActionsPtr expression_actions_,
62 DataTypes argument_types_, const Names & argument_names_,
63 DataTypePtr return_type_, const std::string & return_name_)
64 : expression_actions(std::move(expression_actions_))
65 , signature(std::make_shared<Signature>(Signature{argument_names_, return_name_}))
66 , argument_types(std::move(argument_types_)), return_type(std::move(return_type_))
67 {
68 }
69
70 String getName() const override { return "FunctionExpression"; }
71
72 bool isDeterministic() const override { return true; }
73 bool isDeterministicInScopeOfQuery() const override { return true; }
74
75 const DataTypes & getArgumentTypes() const override { return argument_types; }
76 const DataTypePtr & getReturnType() const override { return return_type; }
77
78 ExecutableFunctionImplPtr prepare(const Block &, const ColumnNumbers &, size_t) const override
79 {
80 return std::make_unique<ExecutableFunctionExpression>(expression_actions, signature);
81 }
82
83private:
84 ExpressionActionsPtr expression_actions;
85 SignaturePtr signature;
86 DataTypes argument_types;
87 DataTypePtr return_type;
88};
89
90/// Captures columns which are used by lambda function but not in argument list.
91/// Returns ColumnFunction with captured columns.
92/// For lambda(x, x + y) x is in lambda_arguments, y is in captured arguments, expression_actions is 'x + y'.
93/// execute(y) returns ColumnFunction(FunctionExpression(x + y), y) with type Function(x) -> function_return_type.
94class ExecutableFunctionCapture : public IExecutableFunctionImpl
95{
96public:
97 struct Capture
98 {
99 Names captured_names;
100 DataTypes captured_types;
101 NamesAndTypesList lambda_arguments;
102 String return_name;
103 DataTypePtr return_type;
104 };
105
106 using CapturePtr = std::shared_ptr<Capture>;
107
108 ExecutableFunctionCapture(ExpressionActionsPtr expression_actions_, CapturePtr capture_)
109 : expression_actions(std::move(expression_actions_)), capture(std::move(capture_)) {}
110
111 String getName() const override { return "FunctionCapture"; }
112
113 bool useDefaultImplementationForNulls() const override { return false; }
114
115 void execute(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
116 {
117 ColumnsWithTypeAndName columns;
118 columns.reserve(arguments.size());
119
120 Names names;
121 DataTypes types;
122
123 names.reserve(capture->captured_names.size() + capture->lambda_arguments.size());
124 names.insert(names.end(), capture->captured_names.begin(), capture->captured_names.end());
125
126 types.reserve(capture->captured_types.size() + capture->lambda_arguments.size());
127 types.insert(types.end(), capture->captured_types.begin(), capture->captured_types.end());
128
129 for (const auto & lambda_argument : capture->lambda_arguments)
130 {
131 names.push_back(lambda_argument.name);
132 types.push_back(lambda_argument.type);
133 }
134
135 for (const auto & argument : arguments)
136 columns.push_back(block.getByPosition(argument));
137
138 auto function = std::make_unique<FunctionExpression>(expression_actions, types, names,
139 capture->return_type, capture->return_name);
140 auto function_adaptor = std::make_shared<FunctionBaseAdaptor>(std::move(function));
141 block.getByPosition(result).column = ColumnFunction::create(input_rows_count, std::move(function_adaptor), columns);
142 }
143
144private:
145 ExpressionActionsPtr expression_actions;
146 CapturePtr capture;
147};
148
149class FunctionCapture : public IFunctionBaseImpl
150{
151public:
152 using Capture = ExecutableFunctionCapture::Capture;
153 using CapturePtr = ExecutableFunctionCapture::CapturePtr;
154
155 FunctionCapture(
156 ExpressionActionsPtr expression_actions_,
157 CapturePtr capture_,
158 DataTypePtr return_type_,
159 String name_)
160 : expression_actions(std::move(expression_actions_))
161 , capture(std::move(capture_))
162 , return_type(std::move(return_type_))
163 , name(std::move(name_))
164 {
165 }
166
167 String getName() const override { return name; }
168
169 bool isDeterministic() const override { return true; }
170 bool isDeterministicInScopeOfQuery() const override { return true; }
171
172 const DataTypes & getArgumentTypes() const override { return capture->captured_types; }
173 const DataTypePtr & getReturnType() const override { return return_type; }
174
175 ExecutableFunctionImplPtr prepare(const Block &, const ColumnNumbers &, size_t) const override
176 {
177 return std::make_unique<ExecutableFunctionCapture>(expression_actions, capture);
178 }
179
180private:
181 ExpressionActionsPtr expression_actions;
182 CapturePtr capture;
183 DataTypePtr return_type;
184 String name;
185};
186
187class FunctionCaptureOverloadResolver : public IFunctionOverloadResolverImpl
188{
189public:
190 using Capture = ExecutableFunctionCapture::Capture;
191 using CapturePtr = ExecutableFunctionCapture::CapturePtr;
192
193 FunctionCaptureOverloadResolver(
194 ExpressionActionsPtr expression_actions_,
195 const Names & captured_names_,
196 const NamesAndTypesList & lambda_arguments_,
197 const DataTypePtr & function_return_type_,
198 const String & expression_return_name_)
199 : expression_actions(std::move(expression_actions_))
200 {
201 std::unordered_map<std::string, DataTypePtr> arguments_map;
202
203 const auto & all_arguments = expression_actions->getRequiredColumnsWithTypes();
204 for (const auto & arg : all_arguments)
205 arguments_map[arg.name] = arg.type;
206
207 DataTypes captured_types;
208 captured_types.reserve(captured_names_.size());
209
210 for (const auto & captured_name : captured_names_)
211 {
212 auto it = arguments_map.find(captured_name);
213 if (it == arguments_map.end())
214 throw Exception("Lambda captured argument " + captured_name + " not found in required columns.",
215 ErrorCodes::LOGICAL_ERROR);
216
217 captured_types.push_back(it->second);
218 arguments_map.erase(it);
219 }
220
221 DataTypes argument_types;
222 argument_types.reserve(lambda_arguments_.size());
223 for (const auto & lambda_argument : lambda_arguments_)
224 argument_types.push_back(lambda_argument.type);
225
226 return_type = std::make_shared<DataTypeFunction>(argument_types, function_return_type_);
227
228 name = "Capture[" + toString(captured_types) + "](" + toString(argument_types) + ") -> "
229 + function_return_type_->getName();
230
231 capture = std::make_shared<Capture>(Capture{
232 .captured_names = captured_names_,
233 .captured_types = std::move(captured_types),
234 .lambda_arguments = lambda_arguments_,
235 .return_name = expression_return_name_,
236 .return_type = function_return_type_,
237 });
238 }
239
240 String getName() const override { return name; }
241 bool useDefaultImplementationForNulls() const override { return false; }
242 DataTypePtr getReturnType(const ColumnsWithTypeAndName &) const override { return return_type; }
243 size_t getNumberOfArguments() const override { return capture->captured_types.size(); }
244
245 FunctionBaseImplPtr build(const ColumnsWithTypeAndName &, const DataTypePtr &) const override
246 {
247 return std::make_unique<FunctionCapture>(expression_actions, capture, return_type, name);
248 }
249
250private:
251 ExpressionActionsPtr expression_actions;
252 CapturePtr capture;
253 DataTypePtr return_type;
254 String name;
255
256 static String toString(const DataTypes & data_types)
257 {
258 std::string result;
259 {
260 WriteBufferFromString buffer(result);
261 bool first = true;
262 for (const auto & type : data_types)
263 {
264 if (!first)
265 buffer << ", ";
266
267 first = false;
268 buffer << type->getName();
269 }
270 }
271
272 return result;
273 }
274};
275
276}
277