| 1 | /******************************************************************************* |
| 2 | * Copyright 2018 Intel Corporation |
| 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | *******************************************************************************/ |
| 16 | #ifndef PRIMITIVE_ITERATOR_HPP |
| 17 | #define PRIMITIVE_ITERATOR_HPP |
| 18 | |
| 19 | #include "mkldnn.h" |
| 20 | |
| 21 | #include "c_types_map.hpp" |
| 22 | #include "engine.hpp" |
| 23 | #include "primitive_desc.hpp" |
| 24 | #include "type_helpers.hpp" |
| 25 | |
| 26 | struct mkldnn_primitive_desc_iterator: public mkldnn::impl::c_compatible { |
| 27 | using pd_create_f = mkldnn::impl::engine_t::primitive_desc_create_f; |
| 28 | |
| 29 | mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, const mkldnn::impl::op_desc_t *op_desc, |
| 30 | const mkldnn::impl::primitive_attr_t *attr, const mkldnn::impl::primitive_desc_t *hint_fwd_pd) |
| 31 | : idx_(-1), engine_(engine), pd_(nullptr), op_desc_(op_desc) |
| 32 | , attr_(attr ? *attr : mkldnn::impl::primitive_attr_t()), hint_fwd_pd_(hint_fwd_pd) |
| 33 | , impl_list_(engine_->get_implementation_list()), last_idx_(0) |
| 34 | { |
| 35 | while (impl_list_[last_idx_] != nullptr) ++last_idx_; |
| 36 | } |
| 37 | ~mkldnn_primitive_desc_iterator() { if (pd_) delete pd_; } |
| 38 | |
| 39 | bool operator==(const mkldnn::impl::primitive_desc_iterator_t& rhs) const |
| 40 | { return idx_ == rhs.idx_ && engine_ == rhs.engine_; } |
| 41 | bool operator!=(const mkldnn::impl::primitive_desc_iterator_t& rhs) const |
| 42 | { return !operator==(rhs); } |
| 43 | |
| 44 | mkldnn::impl::primitive_desc_iterator_t end() const |
| 45 | { return mkldnn_primitive_desc_iterator(engine_, last_idx_); } |
| 46 | |
| 47 | mkldnn::impl::primitive_desc_iterator_t &operator++() { |
| 48 | if (pd_) { delete pd_; pd_ = nullptr; } |
| 49 | while (++idx_ != last_idx_) { |
| 50 | auto s = impl_list_[idx_](&pd_, op_desc_, &attr_, engine_, |
| 51 | hint_fwd_pd_); |
| 52 | if (s == mkldnn::impl::status::success) break; |
| 53 | } |
| 54 | return *this; |
| 55 | } |
| 56 | |
| 57 | mkldnn::impl::primitive_desc_t *operator*() const { |
| 58 | if (*this == end() || pd_ == nullptr) return nullptr; |
| 59 | return pd_->clone(); |
| 60 | } |
| 61 | |
| 62 | protected: |
| 63 | int idx_; |
| 64 | mkldnn::impl::engine_t *engine_; |
| 65 | mkldnn::impl::primitive_desc_t *pd_; |
| 66 | const mkldnn::impl::op_desc_t *op_desc_; |
| 67 | const mkldnn::impl::primitive_attr_t attr_; |
| 68 | const mkldnn::impl::primitive_desc_t *hint_fwd_pd_; |
| 69 | const pd_create_f *impl_list_; |
| 70 | int last_idx_; |
| 71 | |
| 72 | private: |
| 73 | mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, int last_idx) |
| 74 | : idx_(last_idx), engine_(engine), pd_(nullptr) |
| 75 | , op_desc_(nullptr), hint_fwd_pd_(nullptr) |
| 76 | , impl_list_(nullptr), last_idx_(last_idx) {} |
| 77 | }; |
| 78 | |
| 79 | #endif |
| 80 | |