1 | #include <Functions/FunctionJoinGet.h> |
2 | |
3 | #include <Functions/FunctionFactory.h> |
4 | #include <Functions/FunctionHelpers.h> |
5 | #include <Interpreters/Context.h> |
6 | #include <Interpreters/Join.h> |
7 | #include <Storages/StorageJoin.h> |
8 | |
9 | |
10 | namespace DB |
11 | { |
12 | namespace ErrorCodes |
13 | { |
14 | extern const int ILLEGAL_TYPE_OF_ARGUMENT; |
15 | extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; |
16 | } |
17 | |
18 | static auto getJoin(const ColumnsWithTypeAndName & arguments, const Context & context) |
19 | { |
20 | if (arguments.size() != 3) |
21 | throw Exception{"Function joinGet takes 3 arguments" , ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH}; |
22 | |
23 | String join_name; |
24 | if (auto name_col = checkAndGetColumnConst<ColumnString>(arguments[0].column.get())) |
25 | { |
26 | join_name = name_col->getValue<String>(); |
27 | } |
28 | else |
29 | throw Exception{"Illegal type " + arguments[0].type->getName() + " of first argument of function joinGet, expected a const string." , |
30 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; |
31 | |
32 | size_t dot = join_name.find('.'); |
33 | String database_name; |
34 | if (dot == String::npos) |
35 | { |
36 | database_name = context.getCurrentDatabase(); |
37 | dot = 0; |
38 | } |
39 | else |
40 | { |
41 | database_name = join_name.substr(0, dot); |
42 | ++dot; |
43 | } |
44 | String table_name = join_name.substr(dot); |
45 | auto table = context.getTable(database_name, table_name); |
46 | auto storage_join = std::dynamic_pointer_cast<StorageJoin>(table); |
47 | if (!storage_join) |
48 | throw Exception{"Table " + join_name + " should have engine StorageJoin" , ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; |
49 | |
50 | String attr_name; |
51 | if (auto name_col = checkAndGetColumnConst<ColumnString>(arguments[1].column.get())) |
52 | { |
53 | attr_name = name_col->getValue<String>(); |
54 | } |
55 | else |
56 | throw Exception{"Illegal type " + arguments[1].type->getName() |
57 | + " of second argument of function joinGet, expected a const string." , |
58 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; |
59 | return std::make_pair(storage_join, attr_name); |
60 | } |
61 | |
62 | FunctionBaseImplPtr JoinGetOverloadResolver::build(const ColumnsWithTypeAndName & arguments, const DataTypePtr &) const |
63 | { |
64 | auto [storage_join, attr_name] = getJoin(arguments, context); |
65 | auto join = storage_join->getJoin(); |
66 | DataTypes data_types(arguments.size()); |
67 | |
68 | auto table_lock = storage_join->lockStructureForShare(false, context.getInitialQueryId()); |
69 | for (size_t i = 0; i < arguments.size(); ++i) |
70 | data_types[i] = arguments[i].type; |
71 | |
72 | auto return_type = join->joinGetReturnType(attr_name); |
73 | return std::make_unique<FunctionJoinGet>(table_lock, storage_join, join, attr_name, data_types, return_type); |
74 | } |
75 | |
76 | DataTypePtr JoinGetOverloadResolver::getReturnType(const ColumnsWithTypeAndName & arguments) const |
77 | { |
78 | auto [storage_join, attr_name] = getJoin(arguments, context); |
79 | auto join = storage_join->getJoin(); |
80 | return join->joinGetReturnType(attr_name); |
81 | } |
82 | |
83 | |
84 | void ExecutableFunctionJoinGet::execute(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) |
85 | { |
86 | auto ctn = block.getByPosition(arguments[2]); |
87 | if (isColumnConst(*ctn.column)) |
88 | ctn.column = ctn.column->cloneResized(1); |
89 | ctn.name = "" ; // make sure the key name never collide with the join columns |
90 | Block key_block = {ctn}; |
91 | join->joinGet(key_block, attr_name); |
92 | auto & result_ctn = key_block.getByPosition(1); |
93 | if (isColumnConst(*ctn.column)) |
94 | result_ctn.column = ColumnConst::create(result_ctn.column, input_rows_count); |
95 | block.getByPosition(result) = result_ctn; |
96 | } |
97 | |
98 | ExecutableFunctionImplPtr FunctionJoinGet::prepare(const Block &, const ColumnNumbers &, size_t) const |
99 | { |
100 | return std::make_unique<ExecutableFunctionJoinGet>(join, attr_name); |
101 | } |
102 | |
103 | void registerFunctionJoinGet(FunctionFactory & factory) |
104 | { |
105 | factory.registerFunction<JoinGetOverloadResolver>(); |
106 | } |
107 | |
108 | } |
109 | |