1/*
2 Copyright (c) 2005-2019 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8 http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15*/
16
17#include "harness.h"
18
19#if __TBB_CPF_BUILD
20#define TBB_DEPRECATED_FLOW_NODE_EXTRACTION 1
21#endif
22
23#include "tbb/flow_graph.h"
24#include "tbb/task.h"
25#include "tbb/atomic.h"
26
// Number of distinct values (0 .. N-1) each sender pushes through the
// broadcast node; also the length of every receiver's counter array.
static const int N = 1000;
// Exclusive upper bound on receivers per round: rounds attach 1 .. R-1 receivers.
static const int R = 4;
29
30class int_convertable_type : private NoAssign {
31
32 int my_value;
33
34public:
35
36 int_convertable_type( int v ) : my_value(v) {}
37 operator int() const { return my_value; }
38
39};
40
41
42template< typename T >
43class counting_array_receiver : public tbb::flow::receiver<T> {
44
45 tbb::atomic<size_t> my_counters[N];
46 tbb::flow::graph& my_graph;
47
48public:
49
50 counting_array_receiver(tbb::flow::graph& g) : my_graph(g) {
51 for (int i = 0; i < N; ++i )
52 my_counters[i] = 0;
53 }
54
55 size_t operator[]( int i ) {
56 size_t v = my_counters[i];
57 return v;
58 }
59
60 tbb::task * try_put_task( const T &v ) __TBB_override {
61 ++my_counters[(int)v];
62 return const_cast<tbb::task *>(tbb::flow::internal::SUCCESSFULLY_ENQUEUED);
63 }
64
65 tbb::flow::graph& graph_reference() __TBB_override {
66 return my_graph;
67 }
68
69#if TBB_DEPRECATED_FLOW_NODE_EXTRACTION
70 typedef typename tbb::flow::receiver<T>::built_predecessors_type built_predecessors_type;
71 built_predecessors_type mbp;
72 built_predecessors_type &built_predecessors() __TBB_override { return mbp; }
73 typedef typename tbb::flow::receiver<T>::predecessor_list_type predecessor_list_type;
74 typedef typename tbb::flow::receiver<T>::predecessor_type predecessor_type;
75 void internal_add_built_predecessor(predecessor_type &) __TBB_override {}
76 void internal_delete_built_predecessor(predecessor_type &) __TBB_override {}
77 void copy_predecessors(predecessor_list_type &) __TBB_override {}
78 size_t predecessor_count() __TBB_override { return 0; }
79#endif
80 void reset_receiver(tbb::flow::reset_flags /*f*/) __TBB_override { }
81
82};
83
84template< typename T >
85void test_serial_broadcasts() {
86
87 tbb::flow::graph g;
88 tbb::flow::broadcast_node<T> b(g);
89
90 for ( int num_receivers = 1; num_receivers < R; ++num_receivers ) {
91 std::vector< counting_array_receiver<T> > receivers(num_receivers, counting_array_receiver<T>(g));
92#if TBB_DEPRECATED_FLOW_NODE_EXTRACTION
93 ASSERT(b.successor_count() == 0, NULL);
94 ASSERT(b.predecessor_count() == 0, NULL);
95 typename tbb::flow::broadcast_node<T>::successor_list_type my_succs;
96 b.copy_successors(my_succs);
97 ASSERT(my_succs.size() == 0, NULL);
98 typename tbb::flow::broadcast_node<T>::predecessor_list_type my_preds;
99 b.copy_predecessors(my_preds);
100 ASSERT(my_preds.size() == 0, NULL);
101#endif
102
103 for ( int r = 0; r < num_receivers; ++r ) {
104 tbb::flow::make_edge( b, receivers[r] );
105 }
106#if TBB_DEPRECATED_FLOW_NODE_EXTRACTION
107 ASSERT( b.successor_count() == (size_t)num_receivers, NULL);
108#endif
109
110 for (int n = 0; n < N; ++n ) {
111 ASSERT( b.try_put( (T)n ), NULL );
112 }
113
114 for ( int r = 0; r < num_receivers; ++r ) {
115 for (int n = 0; n < N; ++n ) {
116 ASSERT( receivers[r][n] == 1, NULL );
117 }
118 tbb::flow::remove_edge( b, receivers[r] );
119 }
120 ASSERT( b.try_put( (T)0 ), NULL );
121 for ( int r = 0; r < num_receivers; ++r )
122 ASSERT( receivers[0][0] == 1, NULL );
123 }
124
125}
126
127template< typename T >
128class native_body : private NoAssign {
129
130 tbb::flow::broadcast_node<T> &my_b;
131
132public:
133
134 native_body( tbb::flow::broadcast_node<T> &b ) : my_b(b) {}
135
136 void operator()(int) const {
137 for (int n = 0; n < N; ++n ) {
138 ASSERT( my_b.try_put( (T)n ), NULL );
139 }
140 }
141
142};
143
144template< typename T >
145void run_parallel_broadcasts(tbb::flow::graph& g, int p, tbb::flow::broadcast_node<T>& b) {
146 for ( int num_receivers = 1; num_receivers < R; ++num_receivers ) {
147 std::vector< counting_array_receiver<T> > receivers(num_receivers, counting_array_receiver<T>(g));
148
149 for ( int r = 0; r < num_receivers; ++r ) {
150 tbb::flow::make_edge( b, receivers[r] );
151 }
152
153 NativeParallelFor( p, native_body<T>( b ) );
154
155 for ( int r = 0; r < num_receivers; ++r ) {
156 for (int n = 0; n < N; ++n ) {
157 ASSERT( (int)receivers[r][n] == p, NULL );
158 }
159 tbb::flow::remove_edge( b, receivers[r] );
160 }
161 ASSERT( b.try_put( (T)0 ), NULL );
162 for ( int r = 0; r < num_receivers; ++r )
163 ASSERT( (int)receivers[r][0] == p, NULL );
164 }
165}
166
167template< typename T >
168void test_parallel_broadcasts(int p) {
169
170 tbb::flow::graph g;
171 tbb::flow::broadcast_node<T> b(g);
172 run_parallel_broadcasts(g, p, b);
173
174 // test copy constructor
175 tbb::flow::broadcast_node<T> b_copy(b);
176 run_parallel_broadcasts(g, p, b_copy);
177}
178
179// broadcast_node does not allow successors to try_get from it (it does not allow
180// the flow edge to switch) so we only need test the forward direction.
181template<typename T>
182void test_resets() {
183 tbb::flow::graph g;
184 tbb::flow::broadcast_node<T> b0(g);
185 tbb::flow::broadcast_node<T> b1(g);
186 tbb::flow::queue_node<T> q0(g);
187 tbb::flow::make_edge(b0,b1);
188 tbb::flow::make_edge(b1,q0);
189 T j;
190
191 // test standard reset
192 for(int testNo = 0; testNo < 2; ++testNo) {
193 for(T i= 0; i <= 3; i += 1) {
194 b0.try_put(i);
195 }
196 g.wait_for_all();
197 for(T i= 0; i <= 3; i += 1) {
198 ASSERT(q0.try_get(j) && j == i, "Bad value in queue");
199 }
200 ASSERT(!q0.try_get(j), "extra value in queue");
201
202 // reset the graph. It should work as before.
203 if (testNo == 0) g.reset();
204 }
205
206 g.reset(tbb::flow::rf_clear_edges);
207 for(T i= 0; i <= 3; i += 1) {
208 b0.try_put(i);
209 }
210 g.wait_for_all();
211 ASSERT(!q0.try_get(j), "edge between nodes not removed");
212 for(T i= 0; i <= 3; i += 1) {
213 b1.try_put(i);
214 }
215 g.wait_for_all();
216 ASSERT(!q0.try_get(j), "edge between nodes not removed");
217}
218
219#if TBB_DEPRECATED_FLOW_NODE_EXTRACTION
220void test_extract() {
221 int dont_care;
222 tbb::flow::graph g;
223 tbb::flow::broadcast_node<int> b0(g);
224 tbb::flow::broadcast_node<int> b1(g);
225 tbb::flow::broadcast_node<int> b2(g);
226 tbb::flow::broadcast_node<int> b3(g);
227 tbb::flow::broadcast_node<int> b4(g);
228 tbb::flow::broadcast_node<int> b5(g);
229 tbb::flow::queue_node<int> q0(g);
230 tbb::flow::make_edge(b0,b1);
231 tbb::flow::make_edge(b0,b2);
232 tbb::flow::make_edge(b1,b3);
233 tbb::flow::make_edge(b1,b4);
234 tbb::flow::make_edge(b2,b4);
235 tbb::flow::make_edge(b2,b5);
236 tbb::flow::make_edge(b3,q0);
237 tbb::flow::make_edge(b4,q0);
238 tbb::flow::make_edge(b5,q0);
239
240 /* b3 */
241 /* / \ */
242 /* b1 \ */
243 /* / \ \ */
244 /* b0 b4---q0 */
245 /* \ / / */
246 /* b2 / */
247 /* \ / */
248 /* b5 */
249
250 g.wait_for_all();
251 b0.try_put(1);
252 g.wait_for_all();
253 for( int i = 0; i < 4; ++i ) {
254 int j;
255 ASSERT(q0.try_get(j) && j == 1, "missing or incorrect message");
256 }
257 ASSERT(!q0.try_get(dont_care), "extra message in queue");
258 ASSERT(b0.predecessor_count() == 0 && b0.successor_count() == 2, "improper count for b0");
259 ASSERT(b1.predecessor_count() == 1 && b1.successor_count() == 2, "improper count for b1");
260 ASSERT(b2.predecessor_count() == 1 && b2.successor_count() == 2, "improper count for b2");
261 ASSERT(b3.predecessor_count() == 1 && b3.successor_count() == 1, "improper count for b3");
262 ASSERT(b4.predecessor_count() == 2 && b4.successor_count() == 1, "improper count before extract of b4");
263 ASSERT(b5.predecessor_count() == 1 && b5.successor_count() == 1, "improper count for b5");
264 b4.extract(); // remove from tree of nodes.
265 ASSERT(b0.predecessor_count() == 0 && b0.successor_count() == 2, "improper count for b0 after");
266 ASSERT(b1.predecessor_count() == 1 && b1.successor_count() == 1, "improper succ count for b1 after");
267 ASSERT(b2.predecessor_count() == 1 && b2.successor_count() == 1, "improper succ count for b2 after");
268 ASSERT(b3.predecessor_count() == 1 && b3.successor_count() == 1, "improper succ count for b3 after");
269 ASSERT(b4.predecessor_count() == 0 && b4.successor_count() == 0, "improper succ count after extract");
270 ASSERT(b5.predecessor_count() == 1 && b5.successor_count() == 1, "improper succ count for b5 after");
271
272 /* b3 */
273 /* / \ */
274 /* b1 \ */
275 /* / \ */
276 /* b0 q0 */
277 /* \ / */
278 /* b2 / */
279 /* \ / */
280 /* b5 */
281
282 b0.try_put(1);
283 g.wait_for_all();
284 for( int i = 0; i < 2; ++i ) {
285 int j;
286 ASSERT(q0.try_get(j) && j == 1, "missing or incorrect message");
287 }
288 ASSERT(!q0.try_get(dont_care), "extra message in queue");
289 tbb::flow::make_edge(b0,b4);
290 tbb::flow::make_edge(b4,q0);
291 g.wait_for_all();
292 ASSERT(b0.predecessor_count() == 0 && b0.successor_count() == 3, "improper count for b0 after");
293 ASSERT(b1.predecessor_count() == 1 && b1.successor_count() == 1, "improper succ count for b1 after");
294 ASSERT(b2.predecessor_count() == 1 && b2.successor_count() == 1, "improper succ count for b2 after");
295 ASSERT(b3.predecessor_count() == 1 && b3.successor_count() == 1, "improper succ count for b3 after");
296 ASSERT(b4.predecessor_count() == 1 && b4.successor_count() == 1, "improper succ count after extract");
297 ASSERT(b5.predecessor_count() == 1 && b5.successor_count() == 1, "improper succ count for b5 after");
298
299 /* b3 */
300 /* / \ */
301 /* b1 \ */
302 /* / \ */
303 /* b0---b4---q0 */
304 /* \ / */
305 /* b2 / */
306 /* \ / */
307 /* b5 */
308
309 b0.try_put(1);
310 g.wait_for_all();
311 for( int i = 0; i < 3; ++i ) {
312 int j;
313 ASSERT(q0.try_get(j) && j == 1, "missing or incorrect message");
314 }
315 ASSERT(!q0.try_get(dont_care), "extra message in queue");
316}
317#endif // TBB_DEPRECATED_FLOW_NODE_EXTRACTION
318
319int TestMain() {
320 if( MinThread<1 ) {
321 REPORT("number of threads must be positive\n");
322 exit(1);
323 }
324
325 test_serial_broadcasts<int>();
326 test_serial_broadcasts<float>();
327 test_serial_broadcasts<int_convertable_type>();
328
329 for( int p=MinThread; p<=MaxThread; ++p ) {
330 test_parallel_broadcasts<int>(p);
331 test_parallel_broadcasts<float>(p);
332 test_parallel_broadcasts<int_convertable_type>(p);
333 }
334
335 test_resets<int>();
336 test_resets<float>();
337#if TBB_DEPRECATED_FLOW_NODE_EXTRACTION
338 test_extract();
339#endif
340
341 return Harness::Done;
342}
343