1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file |
3 | // distributed with this work for additional information |
4 | // regarding copyright ownership. The ASF licenses this file |
5 | // to you under the Apache License, Version 2.0 (the |
6 | // "License"); you may not use this file except in compliance |
7 | // with the License. You may obtain a copy of the License at |
8 | // |
9 | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | // |
11 | // Unless required by applicable law or agreed to in writing, |
12 | // software distributed under the License is distributed on an |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | // KIND, either express or implied. See the License for the |
15 | // specific language governing permissions and limitations |
16 | // under the License. |
17 | |
18 | #include "parquet/schema.h" |
19 | #include "parquet/schema-internal.h" |
20 | |
21 | #include <algorithm> |
22 | #include <memory> |
23 | #include <sstream> |
24 | #include <string> |
25 | #include <utility> |
26 | |
27 | #include "parquet/exception.h" |
28 | #include "parquet/thrift.h" |
29 | |
30 | using parquet::format::SchemaElement; |
31 | |
32 | namespace parquet { |
33 | |
34 | namespace schema { |
35 | |
36 | // ---------------------------------------------------------------------- |
37 | // ColumnPath |
38 | |
39 | std::shared_ptr<ColumnPath> ColumnPath::FromDotString(const std::string& dotstring) { |
40 | std::stringstream ss(dotstring); |
41 | std::string item; |
42 | std::vector<std::string> path; |
43 | while (std::getline(ss, item, '.')) { |
44 | path.push_back(item); |
45 | } |
46 | return std::shared_ptr<ColumnPath>(new ColumnPath(std::move(path))); |
47 | } |
48 | |
49 | std::shared_ptr<ColumnPath> ColumnPath::FromNode(const Node& node) { |
50 | // Build the path in reverse order as we traverse the nodes to the top |
51 | std::vector<std::string> rpath_; |
52 | const Node* cursor = &node; |
53 | // The schema node is not part of the ColumnPath |
54 | while (cursor->parent()) { |
55 | rpath_.push_back(cursor->name()); |
56 | cursor = cursor->parent(); |
57 | } |
58 | |
59 | // Build ColumnPath in correct order |
60 | std::vector<std::string> path(rpath_.crbegin(), rpath_.crend()); |
61 | return std::make_shared<ColumnPath>(std::move(path)); |
62 | } |
63 | |
64 | std::shared_ptr<ColumnPath> ColumnPath::extend(const std::string& node_name) const { |
65 | std::vector<std::string> path; |
66 | path.reserve(path_.size() + 1); |
67 | path.resize(path_.size() + 1); |
68 | std::copy(path_.cbegin(), path_.cend(), path.begin()); |
69 | path[path_.size()] = node_name; |
70 | |
71 | return std::shared_ptr<ColumnPath>(new ColumnPath(std::move(path))); |
72 | } |
73 | |
74 | std::string ColumnPath::ToDotString() const { |
75 | std::stringstream ss; |
76 | for (auto it = path_.cbegin(); it != path_.cend(); ++it) { |
77 | if (it != path_.cbegin()) { |
78 | ss << "." ; |
79 | } |
80 | ss << *it; |
81 | } |
82 | return ss.str(); |
83 | } |
84 | |
85 | const std::vector<std::string>& ColumnPath::ToDotVector() const { return path_; } |
86 | |
87 | // ---------------------------------------------------------------------- |
88 | // Base node |
89 | |
90 | const std::shared_ptr<ColumnPath> Node::path() const { |
91 | // TODO(itaiin): Cache the result, or more precisely, cache ->ToDotString() |
92 | // since it is being used to access the leaf nodes |
93 | return ColumnPath::FromNode(*this); |
94 | } |
95 | |
96 | bool Node::EqualsInternal(const Node* other) const { |
97 | return type_ == other->type_ && name_ == other->name_ && |
98 | repetition_ == other->repetition_ && logical_type_ == other->logical_type_; |
99 | } |
100 | |
101 | void Node::SetParent(const Node* parent) { parent_ = parent; } |
102 | |
103 | // ---------------------------------------------------------------------- |
104 | // Primitive node |
105 | |
106 | PrimitiveNode::PrimitiveNode(const std::string& name, Repetition::type repetition, |
107 | Type::type type, LogicalType::type logical_type, int length, |
108 | int precision, int scale, int id) |
109 | : Node(Node::PRIMITIVE, name, repetition, logical_type, id), |
110 | physical_type_(type), |
111 | type_length_(length) { |
112 | std::stringstream ss; |
113 | |
114 | // PARQUET-842: In an earlier revision, decimal_metadata_.isset was being |
115 | // set to true, but Impala will raise an incompatible metadata in such cases |
116 | memset(&decimal_metadata_, 0, sizeof(decimal_metadata_)); |
117 | |
118 | // Check if the physical and logical types match |
119 | // Mapping referred from Apache parquet-mr as on 2016-02-22 |
120 | switch (logical_type) { |
121 | case LogicalType::NONE: |
122 | // Logical type not set |
123 | break; |
124 | case LogicalType::UTF8: |
125 | case LogicalType::JSON: |
126 | case LogicalType::BSON: |
127 | if (type != Type::BYTE_ARRAY) { |
128 | ss << LogicalTypeToString(logical_type); |
129 | ss << " can only annotate BYTE_ARRAY fields" ; |
130 | throw ParquetException(ss.str()); |
131 | } |
132 | break; |
133 | case LogicalType::DECIMAL: |
134 | if ((type != Type::INT32) && (type != Type::INT64) && (type != Type::BYTE_ARRAY) && |
135 | (type != Type::FIXED_LEN_BYTE_ARRAY)) { |
136 | ss << "DECIMAL can only annotate INT32, INT64, BYTE_ARRAY, and FIXED" ; |
137 | throw ParquetException(ss.str()); |
138 | } |
139 | if (precision <= 0) { |
140 | ss << "Invalid DECIMAL precision: " << precision |
141 | << ". Precision must be a number between 1 and 38 inclusive" ; |
142 | throw ParquetException(ss.str()); |
143 | } |
144 | if (scale < 0) { |
145 | ss << "Invalid DECIMAL scale: " << scale |
146 | << ". Scale must be a number between 0 and precision inclusive" ; |
147 | throw ParquetException(ss.str()); |
148 | } |
149 | if (scale > precision) { |
150 | ss << "Invalid DECIMAL scale " << scale; |
151 | ss << " cannot be greater than precision " << precision; |
152 | throw ParquetException(ss.str()); |
153 | } |
154 | decimal_metadata_.isset = true; |
155 | decimal_metadata_.precision = precision; |
156 | decimal_metadata_.scale = scale; |
157 | break; |
158 | case LogicalType::DATE: |
159 | case LogicalType::TIME_MILLIS: |
160 | case LogicalType::UINT_8: |
161 | case LogicalType::UINT_16: |
162 | case LogicalType::UINT_32: |
163 | case LogicalType::INT_8: |
164 | case LogicalType::INT_16: |
165 | case LogicalType::INT_32: |
166 | if (type != Type::INT32) { |
167 | ss << LogicalTypeToString(logical_type); |
168 | ss << " can only annotate INT32" ; |
169 | throw ParquetException(ss.str()); |
170 | } |
171 | break; |
172 | case LogicalType::TIME_MICROS: |
173 | case LogicalType::TIMESTAMP_MILLIS: |
174 | case LogicalType::TIMESTAMP_MICROS: |
175 | case LogicalType::UINT_64: |
176 | case LogicalType::INT_64: |
177 | if (type != Type::INT64) { |
178 | ss << LogicalTypeToString(logical_type); |
179 | ss << " can only annotate INT64" ; |
180 | throw ParquetException(ss.str()); |
181 | } |
182 | break; |
183 | case LogicalType::INTERVAL: |
184 | if ((type != Type::FIXED_LEN_BYTE_ARRAY) || (length != 12)) { |
185 | ss << "INTERVAL can only annotate FIXED_LEN_BYTE_ARRAY(12)" ; |
186 | throw ParquetException(ss.str()); |
187 | } |
188 | break; |
189 | case LogicalType::ENUM: |
190 | if (type != Type::BYTE_ARRAY) { |
191 | ss << "ENUM can only annotate BYTE_ARRAY fields" ; |
192 | throw ParquetException(ss.str()); |
193 | } |
194 | break; |
195 | case LogicalType::NA: |
196 | // NA can annotate any type |
197 | break; |
198 | default: |
199 | ss << LogicalTypeToString(logical_type); |
200 | ss << " can not be applied to a primitive type" ; |
201 | throw ParquetException(ss.str()); |
202 | } |
203 | if (type == Type::FIXED_LEN_BYTE_ARRAY) { |
204 | if (length <= 0) { |
205 | ss << "Invalid FIXED_LEN_BYTE_ARRAY length: " << length; |
206 | throw ParquetException(ss.str()); |
207 | } |
208 | type_length_ = length; |
209 | } |
210 | } |
211 | |
212 | bool PrimitiveNode::EqualsInternal(const PrimitiveNode* other) const { |
213 | bool is_equal = true; |
214 | if ((physical_type_ != other->physical_type_) || |
215 | (logical_type_ != other->logical_type_)) { |
216 | return false; |
217 | } |
218 | if (logical_type_ == LogicalType::DECIMAL) { |
219 | is_equal &= (decimal_metadata_.precision == other->decimal_metadata_.precision) && |
220 | (decimal_metadata_.scale == other->decimal_metadata_.scale); |
221 | } |
222 | if (physical_type_ == Type::FIXED_LEN_BYTE_ARRAY) { |
223 | is_equal &= (type_length_ == other->type_length_); |
224 | } |
225 | return is_equal; |
226 | } |
227 | |
228 | bool PrimitiveNode::Equals(const Node* other) const { |
229 | if (!Node::EqualsInternal(other)) { |
230 | return false; |
231 | } |
232 | return EqualsInternal(static_cast<const PrimitiveNode*>(other)); |
233 | } |
234 | |
235 | void PrimitiveNode::Visit(Node::Visitor* visitor) { visitor->Visit(this); } |
236 | |
237 | void PrimitiveNode::VisitConst(Node::ConstVisitor* visitor) const { |
238 | visitor->Visit(this); |
239 | } |
240 | |
241 | // ---------------------------------------------------------------------- |
242 | // Group node |
243 | |
244 | bool GroupNode::EqualsInternal(const GroupNode* other) const { |
245 | if (this == other) { |
246 | return true; |
247 | } |
248 | if (this->field_count() != other->field_count()) { |
249 | return false; |
250 | } |
251 | for (int i = 0; i < this->field_count(); ++i) { |
252 | if (!this->field(i)->Equals(other->field(i).get())) { |
253 | return false; |
254 | } |
255 | } |
256 | return true; |
257 | } |
258 | |
259 | bool GroupNode::Equals(const Node* other) const { |
260 | if (!Node::EqualsInternal(other)) { |
261 | return false; |
262 | } |
263 | return EqualsInternal(static_cast<const GroupNode*>(other)); |
264 | } |
265 | |
266 | int GroupNode::FieldIndex(const std::string& name) const { |
267 | auto search = field_name_to_idx_.find(name); |
268 | if (search == field_name_to_idx_.end()) { |
269 | // Not found |
270 | return -1; |
271 | } |
272 | return search->second; |
273 | } |
274 | |
275 | int GroupNode::FieldIndex(const Node& node) const { |
276 | auto search = field_name_to_idx_.equal_range(node.name()); |
277 | for (auto it = search.first; it != search.second; ++it) { |
278 | const int idx = it->second; |
279 | if (&node == field(idx).get()) { |
280 | return idx; |
281 | } |
282 | } |
283 | return -1; |
284 | } |
285 | |
286 | void GroupNode::Visit(Node::Visitor* visitor) { visitor->Visit(this); } |
287 | |
288 | void GroupNode::VisitConst(Node::ConstVisitor* visitor) const { visitor->Visit(this); } |
289 | |
290 | // ---------------------------------------------------------------------- |
291 | // Node construction from Parquet metadata |
292 | |
293 | struct NodeParams { |
294 | explicit NodeParams(const std::string& name) : name(name) {} |
295 | |
296 | const std::string& name; |
297 | Repetition::type repetition; |
298 | LogicalType::type logical_type; |
299 | }; |
300 | |
301 | static inline NodeParams GetNodeParams(const format::SchemaElement* element) { |
302 | NodeParams params(element->name); |
303 | |
304 | params.repetition = FromThrift(element->repetition_type); |
305 | if (element->__isset.converted_type) { |
306 | params.logical_type = FromThrift(element->converted_type); |
307 | } else { |
308 | params.logical_type = LogicalType::NONE; |
309 | } |
310 | return params; |
311 | } |
312 | |
313 | std::unique_ptr<Node> GroupNode::FromParquet(const void* opaque_element, int node_id, |
314 | const NodeVector& fields) { |
315 | const format::SchemaElement* element = |
316 | static_cast<const format::SchemaElement*>(opaque_element); |
317 | NodeParams params = GetNodeParams(element); |
318 | return std::unique_ptr<Node>(new GroupNode(params.name, params.repetition, fields, |
319 | params.logical_type, node_id)); |
320 | } |
321 | |
322 | std::unique_ptr<Node> PrimitiveNode::FromParquet(const void* opaque_element, |
323 | int node_id) { |
324 | const format::SchemaElement* element = |
325 | static_cast<const format::SchemaElement*>(opaque_element); |
326 | NodeParams params = GetNodeParams(element); |
327 | |
328 | std::unique_ptr<PrimitiveNode> result = |
329 | std::unique_ptr<PrimitiveNode>(new PrimitiveNode( |
330 | params.name, params.repetition, FromThrift(element->type), params.logical_type, |
331 | element->type_length, element->precision, element->scale, node_id)); |
332 | |
333 | // Return as unique_ptr to the base type |
334 | return std::unique_ptr<Node>(result.release()); |
335 | } |
336 | |
337 | void GroupNode::ToParquet(void* opaque_element) const { |
338 | format::SchemaElement* element = static_cast<format::SchemaElement*>(opaque_element); |
339 | element->__set_name(name_); |
340 | element->__set_num_children(field_count()); |
341 | element->__set_repetition_type(ToThrift(repetition_)); |
342 | if (logical_type_ != LogicalType::NONE) { |
343 | element->__set_converted_type(ToThrift(logical_type_)); |
344 | } |
345 | } |
346 | |
347 | void PrimitiveNode::ToParquet(void* opaque_element) const { |
348 | format::SchemaElement* element = static_cast<format::SchemaElement*>(opaque_element); |
349 | |
350 | element->__set_name(name_); |
351 | element->__set_repetition_type(ToThrift(repetition_)); |
352 | if (logical_type_ != LogicalType::NONE) { |
353 | element->__set_converted_type(ToThrift(logical_type_)); |
354 | } |
355 | element->__set_type(ToThrift(physical_type_)); |
356 | if (physical_type_ == Type::FIXED_LEN_BYTE_ARRAY) { |
357 | element->__set_type_length(type_length_); |
358 | } |
359 | if (decimal_metadata_.isset) { |
360 | element->__set_precision(decimal_metadata_.precision); |
361 | element->__set_scale(decimal_metadata_.scale); |
362 | } |
363 | } |
364 | |
365 | // ---------------------------------------------------------------------- |
366 | // Schema converters |
367 | |
368 | std::unique_ptr<Node> FlatSchemaConverter::Convert() { |
369 | const SchemaElement& root = elements_[0]; |
370 | |
371 | // Validate the root node |
372 | if (root.num_children == 0) { |
373 | throw ParquetException("Root node did not have children" ); |
374 | } |
375 | |
376 | // Relaxing this restriction as some implementations don't set this |
377 | // if (root.repetition_type != FieldRepetitionType::REPEATED) { |
378 | // throw ParquetException("Root node was not FieldRepetitionType::REPEATED"); |
379 | // } |
380 | |
381 | return NextNode(); |
382 | } |
383 | |
384 | std::unique_ptr<Node> FlatSchemaConverter::NextNode() { |
385 | const SchemaElement& element = Next(); |
386 | |
387 | int node_id = next_id(); |
388 | |
389 | const void* opaque_element = static_cast<const void*>(&element); |
390 | |
391 | if (element.num_children == 0) { |
392 | // Leaf (primitive) node |
393 | return PrimitiveNode::FromParquet(opaque_element, node_id); |
394 | } else { |
395 | // Group |
396 | NodeVector fields; |
397 | for (int i = 0; i < element.num_children; ++i) { |
398 | std::unique_ptr<Node> field = NextNode(); |
399 | fields.push_back(NodePtr(field.release())); |
400 | } |
401 | return GroupNode::FromParquet(opaque_element, node_id, fields); |
402 | } |
403 | } |
404 | |
405 | const format::SchemaElement& FlatSchemaConverter::Next() { |
406 | if (pos_ == length_) { |
407 | throw ParquetException("Malformed schema: not enough SchemaElement values" ); |
408 | } |
409 | return elements_[pos_++]; |
410 | } |
411 | |
412 | std::shared_ptr<SchemaDescriptor> FromParquet(const std::vector<SchemaElement>& schema) { |
413 | FlatSchemaConverter converter(&schema[0], static_cast<int>(schema.size())); |
414 | std::unique_ptr<Node> root = converter.Convert(); |
415 | |
416 | std::shared_ptr<SchemaDescriptor> descr = std::make_shared<SchemaDescriptor>(); |
417 | descr->Init(std::shared_ptr<GroupNode>(static_cast<GroupNode*>(root.release()))); |
418 | |
419 | return descr; |
420 | } |
421 | |
422 | void ToParquet(const GroupNode* schema, std::vector<format::SchemaElement>* out) { |
423 | SchemaFlattener flattener(schema, out); |
424 | flattener.Flatten(); |
425 | } |
426 | |
427 | class SchemaVisitor : public Node::ConstVisitor { |
428 | public: |
429 | explicit SchemaVisitor(std::vector<format::SchemaElement>* elements) |
430 | : elements_(elements) {} |
431 | |
432 | void Visit(const Node* node) override { |
433 | format::SchemaElement element; |
434 | node->ToParquet(&element); |
435 | elements_->push_back(element); |
436 | |
437 | if (node->is_group()) { |
438 | const GroupNode* group_node = static_cast<const GroupNode*>(node); |
439 | for (int i = 0; i < group_node->field_count(); ++i) { |
440 | group_node->field(i)->VisitConst(this); |
441 | } |
442 | } |
443 | } |
444 | |
445 | private: |
446 | std::vector<format::SchemaElement>* elements_; |
447 | }; |
448 | |
449 | SchemaFlattener::SchemaFlattener(const GroupNode* schema, |
450 | std::vector<format::SchemaElement>* out) |
451 | : root_(schema), elements_(out) {} |
452 | |
453 | void SchemaFlattener::Flatten() { |
454 | SchemaVisitor visitor(elements_); |
455 | root_->VisitConst(&visitor); |
456 | } |
457 | |
458 | // ---------------------------------------------------------------------- |
459 | // Schema printing |
460 | |
461 | class SchemaPrinter : public Node::ConstVisitor { |
462 | public: |
463 | explicit SchemaPrinter(std::ostream& stream, int indent_width) |
464 | : stream_(stream), indent_(0), indent_width_(2) {} |
465 | |
466 | void Visit(const Node* node) override; |
467 | |
468 | private: |
469 | void Visit(const PrimitiveNode* node); |
470 | void Visit(const GroupNode* node); |
471 | |
472 | void Indent(); |
473 | |
474 | std::ostream& stream_; |
475 | |
476 | int indent_; |
477 | int indent_width_; |
478 | }; |
479 | |
480 | static void PrintRepLevel(Repetition::type repetition, std::ostream& stream) { |
481 | switch (repetition) { |
482 | case Repetition::REQUIRED: |
483 | stream << "required" ; |
484 | break; |
485 | case Repetition::OPTIONAL: |
486 | stream << "optional" ; |
487 | break; |
488 | case Repetition::REPEATED: |
489 | stream << "repeated" ; |
490 | break; |
491 | default: |
492 | break; |
493 | } |
494 | } |
495 | |
496 | static void PrintType(const PrimitiveNode* node, std::ostream& stream) { |
497 | switch (node->physical_type()) { |
498 | case Type::BOOLEAN: |
499 | stream << "boolean" ; |
500 | break; |
501 | case Type::INT32: |
502 | stream << "int32" ; |
503 | break; |
504 | case Type::INT64: |
505 | stream << "int64" ; |
506 | break; |
507 | case Type::INT96: |
508 | stream << "int96" ; |
509 | break; |
510 | case Type::FLOAT: |
511 | stream << "float" ; |
512 | break; |
513 | case Type::DOUBLE: |
514 | stream << "double" ; |
515 | break; |
516 | case Type::BYTE_ARRAY: |
517 | stream << "binary" ; |
518 | break; |
519 | case Type::FIXED_LEN_BYTE_ARRAY: |
520 | stream << "fixed_len_byte_array(" << node->type_length() << ")" ; |
521 | break; |
522 | default: |
523 | break; |
524 | } |
525 | } |
526 | |
527 | static void PrintLogicalType(const PrimitiveNode* node, std::ostream& stream) { |
528 | auto lt = node->logical_type(); |
529 | if (lt == LogicalType::DECIMAL) { |
530 | stream << " (" << LogicalTypeToString(lt) << "(" << node->decimal_metadata().precision |
531 | << "," << node->decimal_metadata().scale << "))" ; |
532 | } else if (lt != LogicalType::NONE) { |
533 | stream << " (" << LogicalTypeToString(lt) << ")" ; |
534 | } |
535 | } |
536 | |
537 | void SchemaPrinter::Visit(const PrimitiveNode* node) { |
538 | PrintRepLevel(node->repetition(), stream_); |
539 | stream_ << " " ; |
540 | PrintType(node, stream_); |
541 | stream_ << " " << node->name(); |
542 | PrintLogicalType(node, stream_); |
543 | stream_ << ";" << std::endl; |
544 | } |
545 | |
546 | void SchemaPrinter::Visit(const GroupNode* node) { |
547 | if (!node->parent()) { |
548 | stream_ << "message " << node->name() << " {" << std::endl; |
549 | } else { |
550 | PrintRepLevel(node->repetition(), stream_); |
551 | stream_ << " group " << node->name(); |
552 | auto lt = node->logical_type(); |
553 | if (lt != LogicalType::NONE) { |
554 | stream_ << " (" << LogicalTypeToString(lt) << ")" ; |
555 | } |
556 | stream_ << " {" << std::endl; |
557 | } |
558 | |
559 | indent_ += indent_width_; |
560 | for (int i = 0; i < node->field_count(); ++i) { |
561 | node->field(i)->VisitConst(this); |
562 | } |
563 | indent_ -= indent_width_; |
564 | Indent(); |
565 | stream_ << "}" << std::endl; |
566 | } |
567 | |
568 | void SchemaPrinter::Indent() { |
569 | if (indent_ > 0) { |
570 | std::string spaces(indent_, ' '); |
571 | stream_ << spaces; |
572 | } |
573 | } |
574 | |
575 | void SchemaPrinter::Visit(const Node* node) { |
576 | Indent(); |
577 | if (node->is_group()) { |
578 | Visit(static_cast<const GroupNode*>(node)); |
579 | } else { |
580 | // Primitive |
581 | Visit(static_cast<const PrimitiveNode*>(node)); |
582 | } |
583 | } |
584 | |
585 | void PrintSchema(const Node* schema, std::ostream& stream, int indent_width) { |
586 | SchemaPrinter printer(stream, indent_width); |
587 | printer.Visit(schema); |
588 | } |
589 | |
590 | } // namespace schema |
591 | |
592 | using schema::ColumnPath; |
593 | using schema::GroupNode; |
594 | using schema::Node; |
595 | using schema::NodePtr; |
596 | using schema::PrimitiveNode; |
597 | |
598 | void SchemaDescriptor::Init(std::unique_ptr<schema::Node> schema) { |
599 | Init(NodePtr(schema.release())); |
600 | } |
601 | |
602 | class SchemaUpdater : public Node::Visitor { |
603 | public: |
604 | explicit SchemaUpdater(const std::vector<ColumnOrder>& column_orders) |
605 | : column_orders_(column_orders), leaf_count_(0) {} |
606 | |
607 | void Visit(Node* node) override { |
608 | if (node->is_group()) { |
609 | GroupNode* group_node = static_cast<GroupNode*>(node); |
610 | for (int i = 0; i < group_node->field_count(); ++i) { |
611 | group_node->field(i)->Visit(this); |
612 | } |
613 | } else { // leaf node |
614 | PrimitiveNode* leaf_node = static_cast<PrimitiveNode*>(node); |
615 | leaf_node->SetColumnOrder(column_orders_[leaf_count_++]); |
616 | } |
617 | } |
618 | |
619 | private: |
620 | const std::vector<ColumnOrder>& column_orders_; |
621 | int leaf_count_; |
622 | }; |
623 | |
624 | void SchemaDescriptor::updateColumnOrders(const std::vector<ColumnOrder>& column_orders) { |
625 | if (static_cast<int>(column_orders.size()) != num_columns()) { |
626 | throw ParquetException("Malformed schema: not enough ColumnOrder values" ); |
627 | } |
628 | SchemaUpdater visitor(column_orders); |
629 | const_cast<GroupNode*>(group_node_)->Visit(&visitor); |
630 | } |
631 | |
632 | void SchemaDescriptor::Init(const NodePtr& schema) { |
633 | schema_ = schema; |
634 | |
635 | if (!schema_->is_group()) { |
636 | throw ParquetException("Must initialize with a schema group" ); |
637 | } |
638 | |
639 | group_node_ = static_cast<const GroupNode*>(schema_.get()); |
640 | leaves_.clear(); |
641 | |
642 | for (int i = 0; i < group_node_->field_count(); ++i) { |
643 | BuildTree(group_node_->field(i), 0, 0, group_node_->field(i)); |
644 | } |
645 | } |
646 | |
647 | bool SchemaDescriptor::Equals(const SchemaDescriptor& other) const { |
648 | if (this->num_columns() != other.num_columns()) { |
649 | return false; |
650 | } |
651 | |
652 | for (int i = 0; i < this->num_columns(); ++i) { |
653 | if (!this->Column(i)->Equals(*other.Column(i))) { |
654 | return false; |
655 | } |
656 | } |
657 | |
658 | return true; |
659 | } |
660 | |
661 | void SchemaDescriptor::BuildTree(const NodePtr& node, int16_t max_def_level, |
662 | int16_t max_rep_level, const NodePtr& base) { |
663 | if (node->is_optional()) { |
664 | ++max_def_level; |
665 | } else if (node->is_repeated()) { |
666 | // Repeated fields add a definition level. This is used to distinguish |
667 | // between an empty list and a list with an item in it. |
668 | ++max_rep_level; |
669 | ++max_def_level; |
670 | } |
671 | |
672 | // Now, walk the schema and create a ColumnDescriptor for each leaf node |
673 | if (node->is_group()) { |
674 | const GroupNode* group = static_cast<const GroupNode*>(node.get()); |
675 | for (int i = 0; i < group->field_count(); ++i) { |
676 | BuildTree(group->field(i), max_def_level, max_rep_level, base); |
677 | } |
678 | } else { |
679 | // Primitive node, append to leaves |
680 | leaves_.push_back(ColumnDescriptor(node, max_def_level, max_rep_level, this)); |
681 | leaf_to_base_.emplace(static_cast<int>(leaves_.size()) - 1, base); |
682 | leaf_to_idx_.emplace(node->path()->ToDotString(), |
683 | static_cast<int>(leaves_.size()) - 1); |
684 | } |
685 | } |
686 | |
687 | ColumnDescriptor::ColumnDescriptor(const schema::NodePtr& node, |
688 | int16_t max_definition_level, |
689 | int16_t max_repetition_level, |
690 | const SchemaDescriptor* schema_descr) |
691 | : node_(node), |
692 | max_definition_level_(max_definition_level), |
693 | max_repetition_level_(max_repetition_level) { |
694 | if (!node_->is_primitive()) { |
695 | throw ParquetException("Must be a primitive type" ); |
696 | } |
697 | primitive_node_ = static_cast<const PrimitiveNode*>(node_.get()); |
698 | } |
699 | |
700 | bool ColumnDescriptor::Equals(const ColumnDescriptor& other) const { |
701 | return primitive_node_->Equals(other.primitive_node_) && |
702 | max_repetition_level() == other.max_repetition_level() && |
703 | max_definition_level() == other.max_definition_level(); |
704 | } |
705 | |
706 | const ColumnDescriptor* SchemaDescriptor::Column(int i) const { |
707 | DCHECK(i >= 0 && i < static_cast<int>(leaves_.size())); |
708 | return &leaves_[i]; |
709 | } |
710 | |
711 | int SchemaDescriptor::ColumnIndex(const std::string& node_path) const { |
712 | auto search = leaf_to_idx_.find(node_path); |
713 | if (search == leaf_to_idx_.end()) { |
714 | // Not found |
715 | return -1; |
716 | } |
717 | return search->second; |
718 | } |
719 | |
720 | int SchemaDescriptor::ColumnIndex(const Node& node) const { |
721 | auto search = leaf_to_idx_.equal_range(node.path()->ToDotString()); |
722 | for (auto it = search.first; it != search.second; ++it) { |
723 | const int idx = it->second; |
724 | if (&node == Column(idx)->schema_node().get()) { |
725 | return idx; |
726 | } |
727 | } |
728 | return -1; |
729 | } |
730 | |
731 | const schema::Node* SchemaDescriptor::GetColumnRoot(int i) const { |
732 | DCHECK(i >= 0 && i < static_cast<int>(leaves_.size())); |
733 | return leaf_to_base_.find(i)->second.get(); |
734 | } |
735 | |
736 | std::string SchemaDescriptor::ToString() const { |
737 | std::ostringstream ss; |
738 | PrintSchema(schema_.get(), ss); |
739 | return ss.str(); |
740 | } |
741 | |
742 | int ColumnDescriptor::type_scale() const { |
743 | return primitive_node_->decimal_metadata().scale; |
744 | } |
745 | |
746 | int ColumnDescriptor::type_precision() const { |
747 | return primitive_node_->decimal_metadata().precision; |
748 | } |
749 | |
750 | int ColumnDescriptor::type_length() const { return primitive_node_->type_length(); } |
751 | |
752 | const std::shared_ptr<ColumnPath> ColumnDescriptor::path() const { |
753 | return primitive_node_->path(); |
754 | } |
755 | |
756 | } // namespace parquet |
757 | |