1/*******************************************************************************
2* Copyright 2018 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16#ifndef PRIMITIVE_ITERATOR_HPP
17#define PRIMITIVE_ITERATOR_HPP
18
19#include "mkldnn.h"
20
21#include "c_types_map.hpp"
22#include "engine.hpp"
23#include "primitive_desc.hpp"
24#include "type_helpers.hpp"
25
26struct mkldnn_primitive_desc_iterator: public mkldnn::impl::c_compatible {
27 using pd_create_f = mkldnn::impl::engine_t::primitive_desc_create_f;
28
29 mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, const mkldnn::impl::op_desc_t *op_desc,
30 const mkldnn::impl::primitive_attr_t *attr, const mkldnn::impl::primitive_desc_t *hint_fwd_pd)
31 : idx_(-1), engine_(engine), pd_(nullptr), op_desc_(op_desc)
32 , attr_(attr ? *attr : mkldnn::impl::primitive_attr_t()), hint_fwd_pd_(hint_fwd_pd)
33 , impl_list_(engine_->get_implementation_list()), last_idx_(0)
34 {
35 while (impl_list_[last_idx_] != nullptr) ++last_idx_;
36 }
37 ~mkldnn_primitive_desc_iterator() { if (pd_) delete pd_; }
38
39 bool operator==(const mkldnn::impl::primitive_desc_iterator_t& rhs) const
40 { return idx_ == rhs.idx_ && engine_ == rhs.engine_; }
41 bool operator!=(const mkldnn::impl::primitive_desc_iterator_t& rhs) const
42 { return !operator==(rhs); }
43
44 mkldnn::impl::primitive_desc_iterator_t end() const
45 { return mkldnn_primitive_desc_iterator(engine_, last_idx_); }
46
47 mkldnn::impl::primitive_desc_iterator_t &operator++() {
48 if (pd_) { delete pd_; pd_ = nullptr; }
49 while (++idx_ != last_idx_) {
50 auto s = impl_list_[idx_](&pd_, op_desc_, &attr_, engine_,
51 hint_fwd_pd_);
52 if (s == mkldnn::impl::status::success) break;
53 }
54 return *this;
55 }
56
57 mkldnn::impl::primitive_desc_t *operator*() const {
58 if (*this == end() || pd_ == nullptr) return nullptr;
59 return pd_->clone();
60 }
61
62protected:
63 int idx_;
64 mkldnn::impl::engine_t *engine_;
65 mkldnn::impl::primitive_desc_t *pd_;
66 const mkldnn::impl::op_desc_t *op_desc_;
67 const mkldnn::impl::primitive_attr_t attr_;
68 const mkldnn::impl::primitive_desc_t *hint_fwd_pd_;
69 const pd_create_f *impl_list_;
70 int last_idx_;
71
72private:
73 mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, int last_idx)
74 : idx_(last_idx), engine_(engine), pd_(nullptr)
75 , op_desc_(nullptr), hint_fwd_pd_(nullptr)
76 , impl_list_(nullptr), last_idx_(last_idx) {}
77};
78
79#endif
80