1#include <Access/RowPolicyContextFactory.h>
2#include <Access/RowPolicyContext.h>
3#include <Access/AccessControlManager.h>
4#include <Parsers/ASTLiteral.h>
5#include <Parsers/ASTFunction.h>
6#include <Parsers/ExpressionListParsers.h>
7#include <Parsers/parseQuery.h>
8#include <Common/Exception.h>
9#include <Common/quoteString.h>
10#include <ext/range.h>
11#include <boost/range/algorithm/copy.hpp>
12#include <boost/range/algorithm_ext/erase.hpp>
13
14
15namespace DB
16{
17namespace
18{
19 bool tryGetLiteralBool(const IAST & ast, bool & value)
20 {
21 try
22 {
23 if (const ASTLiteral * literal = ast.as<ASTLiteral>())
24 {
25 value = !literal->value.isNull() && applyVisitor(FieldVisitorConvertToNumber<bool>(), literal->value);
26 return true;
27 }
28 return false;
29 }
30 catch (...)
31 {
32 return false;
33 }
34 }
35
36 ASTPtr applyFunctionAND(ASTs arguments)
37 {
38 bool const_arguments = true;
39 boost::range::remove_erase_if(arguments, [&](const ASTPtr & argument) -> bool
40 {
41 bool b;
42 if (!tryGetLiteralBool(*argument, b))
43 return false;
44 const_arguments &= b;
45 return true;
46 });
47
48 if (!const_arguments)
49 return std::make_shared<ASTLiteral>(Field{UInt8(0)});
50 if (arguments.empty())
51 return std::make_shared<ASTLiteral>(Field{UInt8(1)});
52 if (arguments.size() == 1)
53 return arguments[0];
54
55 auto function = std::make_shared<ASTFunction>();
56 auto exp_list = std::make_shared<ASTExpressionList>();
57 function->name = "and";
58 function->arguments = exp_list;
59 function->children.push_back(exp_list);
60 exp_list->children = std::move(arguments);
61 return function;
62 }
63
64
65 ASTPtr applyFunctionOR(ASTs arguments)
66 {
67 bool const_arguments = false;
68 boost::range::remove_erase_if(arguments, [&](const ASTPtr & argument) -> bool
69 {
70 bool b;
71 if (!tryGetLiteralBool(*argument, b))
72 return false;
73 const_arguments |= b;
74 return true;
75 });
76
77 if (const_arguments)
78 return std::make_shared<ASTLiteral>(Field{UInt8(1)});
79 if (arguments.empty())
80 return std::make_shared<ASTLiteral>(Field{UInt8(0)});
81 if (arguments.size() == 1)
82 return arguments[0];
83
84 auto function = std::make_shared<ASTFunction>();
85 auto exp_list = std::make_shared<ASTExpressionList>();
86 function->name = "or";
87 function->arguments = exp_list;
88 function->children.push_back(exp_list);
89 exp_list->children = std::move(arguments);
90 return function;
91 }
92
93
94 using ConditionIndex = RowPolicy::ConditionIndex;
95 static constexpr size_t MAX_CONDITION_INDEX = RowPolicy::MAX_CONDITION_INDEX;
96
97
98 /// Accumulates conditions from multiple row policies and joins them using the AND logical operation.
99 class ConditionsMixer
100 {
101 public:
102 void add(const ASTPtr & condition, bool is_restrictive)
103 {
104 if (!condition)
105 return;
106
107 if (is_restrictive)
108 restrictions.push_back(condition);
109 else
110 permissions.push_back(condition);
111 }
112
113 ASTPtr getResult() &&
114 {
115 /// Process permissive conditions.
116 if (!permissions.empty())
117 restrictions.push_back(applyFunctionOR(std::move(permissions)));
118
119 /// Process restrictive conditions.
120 if (!restrictions.empty())
121 return applyFunctionAND(std::move(restrictions));
122 return nullptr;
123 }
124
125 private:
126 ASTs permissions;
127 ASTs restrictions;
128 };
129}
130
131
132void RowPolicyContextFactory::PolicyInfo::setPolicy(const RowPolicyPtr & policy_)
133{
134 policy = policy_;
135
136 boost::range::copy(policy->roles, std::inserter(roles, roles.end()));
137 all_roles = policy->all_roles;
138 boost::range::copy(policy->except_roles, std::inserter(except_roles, except_roles.end()));
139
140 for (auto index : ext::range_with_static_cast<ConditionIndex>(0, MAX_CONDITION_INDEX))
141 {
142 const String & condition = policy->conditions[index];
143 auto previous_range = std::pair(std::begin(policy->conditions), std::begin(policy->conditions) + index);
144 auto previous_it = std::find(previous_range.first, previous_range.second, condition);
145 if (previous_it != previous_range.second)
146 {
147 /// The condition is already parsed before.
148 parsed_conditions[index] = parsed_conditions[previous_it - previous_range.first];
149 }
150 else
151 {
152 /// Try to parse the condition.
153 try
154 {
155 ParserExpression parser;
156 parsed_conditions[index] = parseQuery(parser, condition, 0);
157 }
158 catch (...)
159 {
160 tryLogCurrentException(
161 &Poco::Logger::get("RowPolicy"),
162 String("Could not parse the condition ") + RowPolicy::conditionIndexToString(index) + " of row policy "
163 + backQuote(policy->getFullName()));
164 }
165 }
166 }
167}
168
169
170bool RowPolicyContextFactory::PolicyInfo::canUseWithContext(const RowPolicyContext & context) const
171{
172 if (roles.count(context.user_name))
173 return true;
174
175 if (all_roles && !except_roles.count(context.user_name))
176 return true;
177
178 return false;
179}
180
181
182RowPolicyContextFactory::RowPolicyContextFactory(const AccessControlManager & access_control_manager_)
183 : access_control_manager(access_control_manager_)
184{
185}
186
187RowPolicyContextFactory::~RowPolicyContextFactory() = default;
188
189
190RowPolicyContextPtr RowPolicyContextFactory::createContext(const String & user_name)
191{
192 std::lock_guard lock{mutex};
193 ensureAllRowPoliciesRead();
194 auto context = ext::shared_ptr_helper<RowPolicyContext>::create(user_name);
195 contexts.push_back(context);
196 mixConditionsForContext(*context);
197 return context;
198}
199
200
201void RowPolicyContextFactory::ensureAllRowPoliciesRead()
202{
203 /// `mutex` is already locked.
204 if (all_policies_read)
205 return;
206 all_policies_read = true;
207
208 subscription = access_control_manager.subscribeForChanges<RowPolicy>(
209 [&](const UUID & id, const AccessEntityPtr & entity)
210 {
211 if (entity)
212 rowPolicyAddedOrChanged(id, typeid_cast<RowPolicyPtr>(entity));
213 else
214 rowPolicyRemoved(id);
215 });
216
217 for (const UUID & id : access_control_manager.findAll<RowPolicy>())
218 {
219 auto quota = access_control_manager.tryRead<RowPolicy>(id);
220 if (quota)
221 all_policies.emplace(id, PolicyInfo(quota));
222 }
223}
224
225
226void RowPolicyContextFactory::rowPolicyAddedOrChanged(const UUID & policy_id, const RowPolicyPtr & new_policy)
227{
228 std::lock_guard lock{mutex};
229 auto it = all_policies.find(policy_id);
230 if (it == all_policies.end())
231 {
232 it = all_policies.emplace(policy_id, PolicyInfo(new_policy)).first;
233 }
234 else
235 {
236 if (it->second.policy == new_policy)
237 return;
238 }
239
240 auto & info = it->second;
241 info.setPolicy(new_policy);
242 mixConditionsForAllContexts();
243}
244
245
246void RowPolicyContextFactory::rowPolicyRemoved(const UUID & policy_id)
247{
248 std::lock_guard lock{mutex};
249 all_policies.erase(policy_id);
250 mixConditionsForAllContexts();
251}
252
253
254void RowPolicyContextFactory::mixConditionsForAllContexts()
255{
256 /// `mutex` is already locked.
257 boost::range::remove_erase_if(
258 contexts,
259 [&](const std::weak_ptr<RowPolicyContext> & weak)
260 {
261 auto context = weak.lock();
262 if (!context)
263 return true; // remove from the `contexts` list.
264 mixConditionsForContext(*context);
265 return false; // keep in the `contexts` list.
266 });
267}
268
269
270void RowPolicyContextFactory::mixConditionsForContext(RowPolicyContext & context)
271{
272 /// `mutex` is already locked.
273 struct Mixers
274 {
275 ConditionsMixer mixers[MAX_CONDITION_INDEX];
276 std::vector<UUID> policy_ids;
277 };
278 using MapOfMixedConditions = RowPolicyContext::MapOfMixedConditions;
279 using DatabaseAndTableName = RowPolicyContext::DatabaseAndTableName;
280 using DatabaseAndTableNameRef = RowPolicyContext::DatabaseAndTableNameRef;
281 using Hash = RowPolicyContext::Hash;
282
283 std::unordered_map<DatabaseAndTableName, Mixers, Hash> map_of_mixers;
284
285 for (const auto & [policy_id, info] : all_policies)
286 {
287 if (info.canUseWithContext(context))
288 {
289 const auto & policy = *info.policy;
290 auto & mixers = map_of_mixers[std::pair{policy.getDatabase(), policy.getTableName()}];
291 mixers.policy_ids.push_back(policy_id);
292 for (auto index : ext::range(0, MAX_CONDITION_INDEX))
293 mixers.mixers[index].add(info.parsed_conditions[index], policy.isRestrictive());
294 }
295 }
296
297 auto map_of_mixed_conditions = std::make_shared<MapOfMixedConditions>();
298 for (auto & [database_and_table_name, mixers] : map_of_mixers)
299 {
300 auto database_and_table_name_keeper = std::make_unique<DatabaseAndTableName>();
301 database_and_table_name_keeper->first = database_and_table_name.first;
302 database_and_table_name_keeper->second = database_and_table_name.second;
303 auto & mixed_conditions = (*map_of_mixed_conditions)[DatabaseAndTableNameRef{database_and_table_name_keeper->first,
304 database_and_table_name_keeper->second}];
305 mixed_conditions.database_and_table_name_keeper = std::move(database_and_table_name_keeper);
306 mixed_conditions.policy_ids = std::move(mixers.policy_ids);
307 for (auto index : ext::range(0, MAX_CONDITION_INDEX))
308 mixed_conditions.mixed_conditions[index] = std::move(mixers.mixers[index]).getResult();
309 }
310
311 std::atomic_store(&context.atomic_map_of_mixed_conditions, std::shared_ptr<const MapOfMixedConditions>{map_of_mixed_conditions});
312}
313
314}
315