#include <Functions/FunctionJoinGet.h>

#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Interpreters/Context.h>
#include <Interpreters/Join.h>
#include <Storages/StorageJoin.h>

#include <utility>
8
9
10namespace DB
11{
/// Error codes thrown by the joinGet argument checks in this file (the definitions live elsewhere).
namespace ErrorCodes
{
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
17
18static auto getJoin(const ColumnsWithTypeAndName & arguments, const Context & context)
19{
20 if (arguments.size() != 3)
21 throw Exception{"Function joinGet takes 3 arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
22
23 String join_name;
24 if (auto name_col = checkAndGetColumnConst<ColumnString>(arguments[0].column.get()))
25 {
26 join_name = name_col->getValue<String>();
27 }
28 else
29 throw Exception{"Illegal type " + arguments[0].type->getName() + " of first argument of function joinGet, expected a const string.",
30 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
31
32 size_t dot = join_name.find('.');
33 String database_name;
34 if (dot == String::npos)
35 {
36 database_name = context.getCurrentDatabase();
37 dot = 0;
38 }
39 else
40 {
41 database_name = join_name.substr(0, dot);
42 ++dot;
43 }
44 String table_name = join_name.substr(dot);
45 auto table = context.getTable(database_name, table_name);
46 auto storage_join = std::dynamic_pointer_cast<StorageJoin>(table);
47 if (!storage_join)
48 throw Exception{"Table " + join_name + " should have engine StorageJoin", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
49
50 String attr_name;
51 if (auto name_col = checkAndGetColumnConst<ColumnString>(arguments[1].column.get()))
52 {
53 attr_name = name_col->getValue<String>();
54 }
55 else
56 throw Exception{"Illegal type " + arguments[1].type->getName()
57 + " of second argument of function joinGet, expected a const string.",
58 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
59 return std::make_pair(storage_join, attr_name);
60}
61
62FunctionBaseImplPtr JoinGetOverloadResolver::build(const ColumnsWithTypeAndName & arguments, const DataTypePtr &) const
63{
64 auto [storage_join, attr_name] = getJoin(arguments, context);
65 auto join = storage_join->getJoin();
66 DataTypes data_types(arguments.size());
67
68 auto table_lock = storage_join->lockStructureForShare(false, context.getInitialQueryId());
69 for (size_t i = 0; i < arguments.size(); ++i)
70 data_types[i] = arguments[i].type;
71
72 auto return_type = join->joinGetReturnType(attr_name);
73 return std::make_unique<FunctionJoinGet>(table_lock, storage_join, join, attr_name, data_types, return_type);
74}
75
76DataTypePtr JoinGetOverloadResolver::getReturnType(const ColumnsWithTypeAndName & arguments) const
77{
78 auto [storage_join, attr_name] = getJoin(arguments, context);
79 auto join = storage_join->getJoin();
80 return join->joinGetReturnType(attr_name);
81}
82
83
84void ExecutableFunctionJoinGet::execute(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count)
85{
86 auto ctn = block.getByPosition(arguments[2]);
87 if (isColumnConst(*ctn.column))
88 ctn.column = ctn.column->cloneResized(1);
89 ctn.name = ""; // make sure the key name never collide with the join columns
90 Block key_block = {ctn};
91 join->joinGet(key_block, attr_name);
92 auto & result_ctn = key_block.getByPosition(1);
93 if (isColumnConst(*ctn.column))
94 result_ctn.column = ColumnConst::create(result_ctn.column, input_rows_count);
95 block.getByPosition(result) = result_ctn;
96}
97
98ExecutableFunctionImplPtr FunctionJoinGet::prepare(const Block &, const ColumnNumbers &, size_t) const
99{
100 return std::make_unique<ExecutableFunctionJoinGet>(join, attr_name);
101}
102
/// Registers joinGet in the function factory. The registered class is the overload resolver
/// (JoinGetOverloadResolver), whose build() constructs the concrete FunctionJoinGet.
void registerFunctionJoinGet(FunctionFactory & factory)
{
    factory.registerFunction<JoinGetOverloadResolver>();
}
107
108}
109