1 | #include <Access/RowPolicyContextFactory.h> |
2 | #include <Access/RowPolicyContext.h> |
3 | #include <Access/AccessControlManager.h> |
4 | #include <Parsers/ASTLiteral.h> |
5 | #include <Parsers/ASTFunction.h> |
6 | #include <Parsers/ExpressionListParsers.h> |
7 | #include <Parsers/parseQuery.h> |
8 | #include <Common/Exception.h> |
9 | #include <Common/quoteString.h> |
10 | #include <ext/range.h> |
11 | #include <boost/range/algorithm/copy.hpp> |
12 | #include <boost/range/algorithm_ext/erase.hpp> |
13 | |
14 | |
15 | namespace DB |
16 | { |
17 | namespace |
18 | { |
19 | bool tryGetLiteralBool(const IAST & ast, bool & value) |
20 | { |
21 | try |
22 | { |
23 | if (const ASTLiteral * literal = ast.as<ASTLiteral>()) |
24 | { |
25 | value = !literal->value.isNull() && applyVisitor(FieldVisitorConvertToNumber<bool>(), literal->value); |
26 | return true; |
27 | } |
28 | return false; |
29 | } |
30 | catch (...) |
31 | { |
32 | return false; |
33 | } |
34 | } |
35 | |
36 | ASTPtr applyFunctionAND(ASTs arguments) |
37 | { |
38 | bool const_arguments = true; |
39 | boost::range::remove_erase_if(arguments, [&](const ASTPtr & argument) -> bool |
40 | { |
41 | bool b; |
42 | if (!tryGetLiteralBool(*argument, b)) |
43 | return false; |
44 | const_arguments &= b; |
45 | return true; |
46 | }); |
47 | |
48 | if (!const_arguments) |
49 | return std::make_shared<ASTLiteral>(Field{UInt8(0)}); |
50 | if (arguments.empty()) |
51 | return std::make_shared<ASTLiteral>(Field{UInt8(1)}); |
52 | if (arguments.size() == 1) |
53 | return arguments[0]; |
54 | |
55 | auto function = std::make_shared<ASTFunction>(); |
56 | auto exp_list = std::make_shared<ASTExpressionList>(); |
57 | function->name = "and" ; |
58 | function->arguments = exp_list; |
59 | function->children.push_back(exp_list); |
60 | exp_list->children = std::move(arguments); |
61 | return function; |
62 | } |
63 | |
64 | |
65 | ASTPtr applyFunctionOR(ASTs arguments) |
66 | { |
67 | bool const_arguments = false; |
68 | boost::range::remove_erase_if(arguments, [&](const ASTPtr & argument) -> bool |
69 | { |
70 | bool b; |
71 | if (!tryGetLiteralBool(*argument, b)) |
72 | return false; |
73 | const_arguments |= b; |
74 | return true; |
75 | }); |
76 | |
77 | if (const_arguments) |
78 | return std::make_shared<ASTLiteral>(Field{UInt8(1)}); |
79 | if (arguments.empty()) |
80 | return std::make_shared<ASTLiteral>(Field{UInt8(0)}); |
81 | if (arguments.size() == 1) |
82 | return arguments[0]; |
83 | |
84 | auto function = std::make_shared<ASTFunction>(); |
85 | auto exp_list = std::make_shared<ASTExpressionList>(); |
86 | function->name = "or" ; |
87 | function->arguments = exp_list; |
88 | function->children.push_back(exp_list); |
89 | exp_list->children = std::move(arguments); |
90 | return function; |
91 | } |
92 | |
93 | |
94 | using ConditionIndex = RowPolicy::ConditionIndex; |
95 | static constexpr size_t MAX_CONDITION_INDEX = RowPolicy::MAX_CONDITION_INDEX; |
96 | |
97 | |
98 | /// Accumulates conditions from multiple row policies and joins them using the AND logical operation. |
99 | class ConditionsMixer |
100 | { |
101 | public: |
102 | void add(const ASTPtr & condition, bool is_restrictive) |
103 | { |
104 | if (!condition) |
105 | return; |
106 | |
107 | if (is_restrictive) |
108 | restrictions.push_back(condition); |
109 | else |
110 | permissions.push_back(condition); |
111 | } |
112 | |
113 | ASTPtr getResult() && |
114 | { |
115 | /// Process permissive conditions. |
116 | if (!permissions.empty()) |
117 | restrictions.push_back(applyFunctionOR(std::move(permissions))); |
118 | |
119 | /// Process restrictive conditions. |
120 | if (!restrictions.empty()) |
121 | return applyFunctionAND(std::move(restrictions)); |
122 | return nullptr; |
123 | } |
124 | |
125 | private: |
126 | ASTs permissions; |
127 | ASTs restrictions; |
128 | }; |
129 | } |
130 | |
131 | |
132 | void RowPolicyContextFactory::PolicyInfo::setPolicy(const RowPolicyPtr & policy_) |
133 | { |
134 | policy = policy_; |
135 | |
136 | boost::range::copy(policy->roles, std::inserter(roles, roles.end())); |
137 | all_roles = policy->all_roles; |
138 | boost::range::copy(policy->except_roles, std::inserter(except_roles, except_roles.end())); |
139 | |
140 | for (auto index : ext::range_with_static_cast<ConditionIndex>(0, MAX_CONDITION_INDEX)) |
141 | { |
142 | const String & condition = policy->conditions[index]; |
143 | auto previous_range = std::pair(std::begin(policy->conditions), std::begin(policy->conditions) + index); |
144 | auto previous_it = std::find(previous_range.first, previous_range.second, condition); |
145 | if (previous_it != previous_range.second) |
146 | { |
147 | /// The condition is already parsed before. |
148 | parsed_conditions[index] = parsed_conditions[previous_it - previous_range.first]; |
149 | } |
150 | else |
151 | { |
152 | /// Try to parse the condition. |
153 | try |
154 | { |
155 | ParserExpression parser; |
156 | parsed_conditions[index] = parseQuery(parser, condition, 0); |
157 | } |
158 | catch (...) |
159 | { |
160 | tryLogCurrentException( |
161 | &Poco::Logger::get("RowPolicy" ), |
162 | String("Could not parse the condition " ) + RowPolicy::conditionIndexToString(index) + " of row policy " |
163 | + backQuote(policy->getFullName())); |
164 | } |
165 | } |
166 | } |
167 | } |
168 | |
169 | |
170 | bool RowPolicyContextFactory::PolicyInfo::canUseWithContext(const RowPolicyContext & context) const |
171 | { |
172 | if (roles.count(context.user_name)) |
173 | return true; |
174 | |
175 | if (all_roles && !except_roles.count(context.user_name)) |
176 | return true; |
177 | |
178 | return false; |
179 | } |
180 | |
181 | |
182 | RowPolicyContextFactory::RowPolicyContextFactory(const AccessControlManager & access_control_manager_) |
183 | : access_control_manager(access_control_manager_) |
184 | { |
185 | } |
186 | |
187 | RowPolicyContextFactory::~RowPolicyContextFactory() = default; |
188 | |
189 | |
190 | RowPolicyContextPtr RowPolicyContextFactory::createContext(const String & user_name) |
191 | { |
192 | std::lock_guard lock{mutex}; |
193 | ensureAllRowPoliciesRead(); |
194 | auto context = ext::shared_ptr_helper<RowPolicyContext>::create(user_name); |
195 | contexts.push_back(context); |
196 | mixConditionsForContext(*context); |
197 | return context; |
198 | } |
199 | |
200 | |
201 | void RowPolicyContextFactory::ensureAllRowPoliciesRead() |
202 | { |
203 | /// `mutex` is already locked. |
204 | if (all_policies_read) |
205 | return; |
206 | all_policies_read = true; |
207 | |
208 | subscription = access_control_manager.subscribeForChanges<RowPolicy>( |
209 | [&](const UUID & id, const AccessEntityPtr & entity) |
210 | { |
211 | if (entity) |
212 | rowPolicyAddedOrChanged(id, typeid_cast<RowPolicyPtr>(entity)); |
213 | else |
214 | rowPolicyRemoved(id); |
215 | }); |
216 | |
217 | for (const UUID & id : access_control_manager.findAll<RowPolicy>()) |
218 | { |
219 | auto quota = access_control_manager.tryRead<RowPolicy>(id); |
220 | if (quota) |
221 | all_policies.emplace(id, PolicyInfo(quota)); |
222 | } |
223 | } |
224 | |
225 | |
226 | void RowPolicyContextFactory::rowPolicyAddedOrChanged(const UUID & policy_id, const RowPolicyPtr & new_policy) |
227 | { |
228 | std::lock_guard lock{mutex}; |
229 | auto it = all_policies.find(policy_id); |
230 | if (it == all_policies.end()) |
231 | { |
232 | it = all_policies.emplace(policy_id, PolicyInfo(new_policy)).first; |
233 | } |
234 | else |
235 | { |
236 | if (it->second.policy == new_policy) |
237 | return; |
238 | } |
239 | |
240 | auto & info = it->second; |
241 | info.setPolicy(new_policy); |
242 | mixConditionsForAllContexts(); |
243 | } |
244 | |
245 | |
246 | void RowPolicyContextFactory::rowPolicyRemoved(const UUID & policy_id) |
247 | { |
248 | std::lock_guard lock{mutex}; |
249 | all_policies.erase(policy_id); |
250 | mixConditionsForAllContexts(); |
251 | } |
252 | |
253 | |
254 | void RowPolicyContextFactory::mixConditionsForAllContexts() |
255 | { |
256 | /// `mutex` is already locked. |
257 | boost::range::remove_erase_if( |
258 | contexts, |
259 | [&](const std::weak_ptr<RowPolicyContext> & weak) |
260 | { |
261 | auto context = weak.lock(); |
262 | if (!context) |
263 | return true; // remove from the `contexts` list. |
264 | mixConditionsForContext(*context); |
265 | return false; // keep in the `contexts` list. |
266 | }); |
267 | } |
268 | |
269 | |
270 | void RowPolicyContextFactory::mixConditionsForContext(RowPolicyContext & context) |
271 | { |
272 | /// `mutex` is already locked. |
273 | struct Mixers |
274 | { |
275 | ConditionsMixer mixers[MAX_CONDITION_INDEX]; |
276 | std::vector<UUID> policy_ids; |
277 | }; |
278 | using MapOfMixedConditions = RowPolicyContext::MapOfMixedConditions; |
279 | using DatabaseAndTableName = RowPolicyContext::DatabaseAndTableName; |
280 | using DatabaseAndTableNameRef = RowPolicyContext::DatabaseAndTableNameRef; |
281 | using Hash = RowPolicyContext::Hash; |
282 | |
283 | std::unordered_map<DatabaseAndTableName, Mixers, Hash> map_of_mixers; |
284 | |
285 | for (const auto & [policy_id, info] : all_policies) |
286 | { |
287 | if (info.canUseWithContext(context)) |
288 | { |
289 | const auto & policy = *info.policy; |
290 | auto & mixers = map_of_mixers[std::pair{policy.getDatabase(), policy.getTableName()}]; |
291 | mixers.policy_ids.push_back(policy_id); |
292 | for (auto index : ext::range(0, MAX_CONDITION_INDEX)) |
293 | mixers.mixers[index].add(info.parsed_conditions[index], policy.isRestrictive()); |
294 | } |
295 | } |
296 | |
297 | auto map_of_mixed_conditions = std::make_shared<MapOfMixedConditions>(); |
298 | for (auto & [database_and_table_name, mixers] : map_of_mixers) |
299 | { |
300 | auto database_and_table_name_keeper = std::make_unique<DatabaseAndTableName>(); |
301 | database_and_table_name_keeper->first = database_and_table_name.first; |
302 | database_and_table_name_keeper->second = database_and_table_name.second; |
303 | auto & mixed_conditions = (*map_of_mixed_conditions)[DatabaseAndTableNameRef{database_and_table_name_keeper->first, |
304 | database_and_table_name_keeper->second}]; |
305 | mixed_conditions.database_and_table_name_keeper = std::move(database_and_table_name_keeper); |
306 | mixed_conditions.policy_ids = std::move(mixers.policy_ids); |
307 | for (auto index : ext::range(0, MAX_CONDITION_INDEX)) |
308 | mixed_conditions.mixed_conditions[index] = std::move(mixers.mixers[index]).getResult(); |
309 | } |
310 | |
311 | std::atomic_store(&context.atomic_map_of_mixed_conditions, std::shared_ptr<const MapOfMixedConditions>{map_of_mixed_conditions}); |
312 | } |
313 | |
314 | } |
315 | |