1 | /* boost random/binomial_distribution.hpp header file |
2 | * |
3 | * Copyright Steven Watanabe 2010 |
4 | * Distributed under the Boost Software License, Version 1.0. (See |
5 | * accompanying file LICENSE_1_0.txt or copy at |
6 | * http://www.boost.org/LICENSE_1_0.txt) |
7 | * |
8 | * See http://www.boost.org for most recent version including documentation. |
9 | * |
10 | * $Id$ |
11 | */ |
12 | |
13 | #ifndef BOOST_RANDOM_BINOMIAL_DISTRIBUTION_HPP_INCLUDED |
14 | #define BOOST_RANDOM_BINOMIAL_DISTRIBUTION_HPP_INCLUDED |
15 | |
#include <boost/config/no_tr1/cmath.hpp>
#include <cstdlib>
#include <iosfwd>
#include <limits>

#include <boost/random/detail/config.hpp>
#include <boost/random/uniform_01.hpp>
22 | |
23 | #include <boost/random/detail/disable_warnings.hpp> |
24 | |
25 | namespace boost { |
26 | namespace random { |
27 | |
namespace detail {

// Precomputed values of fc(k) for k = 0..9, the error term of the
// Stirling approximation of log(k!):
//
//   table[k] = log(k!) - [ (k + 1/2) log(k + 1) - (k + 1) + log(2*pi)/2 ]
//
// binomial_distribution::fc() reads this table for k < 10 and falls back
// to a series in 1/(k+1) for larger k.
template<typename RealType>
struct binomial_table
{
    static const RealType table[10];
};

template<typename RealType>
const RealType binomial_table<RealType>::table[10] =
{
    0.08106146679532726,    // k = 0
    0.04134069595540929,    // k = 1
    0.02767792568499834,    // k = 2
    0.02079067210376509,    // k = 3
    0.01664469118982119,    // k = 4
    0.01387612882307075,    // k = 5
    0.01189670994589177,    // k = 6
    0.01041126526197209,    // k = 7
    0.009255462182712733,   // k = 8
    0.008330563433362871    // k = 9
};

} // namespace detail
50 | |
51 | /** |
52 | * The binomial distribution is an integer valued distribution with |
53 | * two parameters, @c t and @c p. The values of the distribution |
54 | * are within the range [0,t]. |
55 | * |
56 | * The distribution function is |
57 | * \f$\displaystyle P(k) = {t \choose k}p^k(1-p)^{t-k}\f$. |
58 | * |
59 | * The algorithm used is the BTRD algorithm described in |
60 | * |
61 | * @blockquote |
62 | * "The generation of binomial random variates", Wolfgang Hormann, |
63 | * Journal of Statistical Computation and Simulation, Volume 46, |
64 | * Issue 1 & 2 April 1993 , pages 101 - 110 |
65 | * @endblockquote |
66 | */ |
67 | template<class IntType = int, class RealType = double> |
68 | class binomial_distribution { |
69 | public: |
70 | typedef IntType result_type; |
71 | typedef RealType input_type; |
72 | |
73 | class param_type { |
74 | public: |
75 | typedef binomial_distribution distribution_type; |
76 | /** |
77 | * Construct a param_type object. @c t and @c p |
78 | * are the parameters of the distribution. |
79 | * |
80 | * Requires: t >=0 && 0 <= p <= 1 |
81 | */ |
82 | explicit param_type(IntType t_arg = 1, RealType p_arg = RealType (0.5)) |
83 | : _t(t_arg), _p(p_arg) |
84 | {} |
85 | /** Returns the @c t parameter of the distribution. */ |
86 | IntType t() const { return _t; } |
87 | /** Returns the @c p parameter of the distribution. */ |
88 | RealType p() const { return _p; } |
89 | #ifndef BOOST_RANDOM_NO_STREAM_OPERATORS |
90 | /** Writes the parameters of the distribution to a @c std::ostream. */ |
91 | template<class CharT, class Traits> |
92 | friend std::basic_ostream<CharT,Traits>& |
93 | operator<<(std::basic_ostream<CharT,Traits>& os, |
94 | const param_type& parm) |
95 | { |
96 | os << parm._p << " " << parm._t; |
97 | return os; |
98 | } |
99 | |
100 | /** Reads the parameters of the distribution from a @c std::istream. */ |
101 | template<class CharT, class Traits> |
102 | friend std::basic_istream<CharT,Traits>& |
103 | operator>>(std::basic_istream<CharT,Traits>& is, param_type& parm) |
104 | { |
105 | is >> parm._p >> std::ws >> parm._t; |
106 | return is; |
107 | } |
108 | #endif |
109 | /** Returns true if the parameters have the same values. */ |
110 | friend bool operator==(const param_type& lhs, const param_type& rhs) |
111 | { |
112 | return lhs._t == rhs._t && lhs._p == rhs._p; |
113 | } |
114 | /** Returns true if the parameters have different values. */ |
115 | friend bool operator!=(const param_type& lhs, const param_type& rhs) |
116 | { |
117 | return !(lhs == rhs); |
118 | } |
119 | private: |
120 | IntType _t; |
121 | RealType _p; |
122 | }; |
123 | |
124 | /** |
125 | * Construct a @c binomial_distribution object. @c t and @c p |
126 | * are the parameters of the distribution. |
127 | * |
128 | * Requires: t >=0 && 0 <= p <= 1 |
129 | */ |
130 | explicit binomial_distribution(IntType t_arg = 1, |
131 | RealType p_arg = RealType(0.5)) |
132 | : _t(t_arg), _p(p_arg) |
133 | { |
134 | init(); |
135 | } |
136 | |
137 | /** |
138 | * Construct an @c binomial_distribution object from the |
139 | * parameters. |
140 | */ |
141 | explicit binomial_distribution(const param_type& parm) |
142 | : _t(parm.t()), _p(parm.p()) |
143 | { |
144 | init(); |
145 | } |
146 | |
147 | /** |
148 | * Returns a random variate distributed according to the |
149 | * binomial distribution. |
150 | */ |
151 | template<class URNG> |
152 | IntType operator()(URNG& urng) const |
153 | { |
154 | if(use_inversion()) { |
155 | if(0.5 < _p) { |
156 | return _t - invert(_t, 1-_p, urng); |
157 | } else { |
158 | return invert(_t, _p, urng); |
159 | } |
160 | } else if(0.5 < _p) { |
161 | return _t - generate(urng); |
162 | } else { |
163 | return generate(urng); |
164 | } |
165 | } |
166 | |
167 | /** |
168 | * Returns a random variate distributed according to the |
169 | * binomial distribution with parameters specified by @c param. |
170 | */ |
171 | template<class URNG> |
172 | IntType operator()(URNG& urng, const param_type& parm) const |
173 | { |
174 | return binomial_distribution(parm)(urng); |
175 | } |
176 | |
177 | /** Returns the @c t parameter of the distribution. */ |
178 | IntType t() const { return _t; } |
179 | /** Returns the @c p parameter of the distribution. */ |
180 | RealType p() const { return _p; } |
181 | |
182 | /** Returns the smallest value that the distribution can produce. */ |
183 | IntType min BOOST_PREVENT_MACRO_SUBSTITUTION() const { return 0; } |
184 | /** Returns the largest value that the distribution can produce. */ |
185 | IntType max BOOST_PREVENT_MACRO_SUBSTITUTION() const { return _t; } |
186 | |
187 | /** Returns the parameters of the distribution. */ |
188 | param_type param() const { return param_type(_t, _p); } |
189 | /** Sets parameters of the distribution. */ |
190 | void param(const param_type& parm) |
191 | { |
192 | _t = parm.t(); |
193 | _p = parm.p(); |
194 | init(); |
195 | } |
196 | |
197 | /** |
198 | * Effects: Subsequent uses of the distribution do not depend |
199 | * on values produced by any engine prior to invoking reset. |
200 | */ |
201 | void reset() { } |
202 | |
203 | #ifndef BOOST_RANDOM_NO_STREAM_OPERATORS |
204 | /** Writes the parameters of the distribution to a @c std::ostream. */ |
205 | template<class CharT, class Traits> |
206 | friend std::basic_ostream<CharT,Traits>& |
207 | operator<<(std::basic_ostream<CharT,Traits>& os, |
208 | const binomial_distribution& bd) |
209 | { |
210 | os << bd.param(); |
211 | return os; |
212 | } |
213 | |
214 | /** Reads the parameters of the distribution from a @c std::istream. */ |
215 | template<class CharT, class Traits> |
216 | friend std::basic_istream<CharT,Traits>& |
217 | operator>>(std::basic_istream<CharT,Traits>& is, binomial_distribution& bd) |
218 | { |
219 | bd.read(is); |
220 | return is; |
221 | } |
222 | #endif |
223 | |
224 | /** Returns true if the two distributions will produce the same |
225 | sequence of values, given equal generators. */ |
226 | friend bool operator==(const binomial_distribution& lhs, |
227 | const binomial_distribution& rhs) |
228 | { |
229 | return lhs._t == rhs._t && lhs._p == rhs._p; |
230 | } |
231 | /** Returns true if the two distributions could produce different |
232 | sequences of values, given equal generators. */ |
233 | friend bool operator!=(const binomial_distribution& lhs, |
234 | const binomial_distribution& rhs) |
235 | { |
236 | return !(lhs == rhs); |
237 | } |
238 | |
239 | private: |
240 | |
241 | /// @cond show_private |
242 | |
243 | template<class CharT, class Traits> |
244 | void read(std::basic_istream<CharT, Traits>& is) { |
245 | param_type parm; |
246 | if(is >> parm) { |
247 | param(parm); |
248 | } |
249 | } |
250 | |
251 | bool use_inversion() const |
252 | { |
253 | // BTRD is safe when np >= 10 |
254 | return m < 11; |
255 | } |
256 | |
257 | // computes the correction factor for the Stirling approximation |
258 | // for log(k!) |
259 | static RealType fc(IntType k) |
260 | { |
261 | if(k < 10) return detail::binomial_table<RealType>::table[k]; |
262 | else { |
263 | RealType ikp1 = RealType(1) / (k + 1); |
264 | return (RealType(1)/12 |
265 | - (RealType(1)/360 |
266 | - (RealType(1)/1260)*(ikp1*ikp1))*(ikp1*ikp1))*ikp1; |
267 | } |
268 | } |
269 | |
270 | void init() |
271 | { |
272 | using std::sqrt; |
273 | using std::pow; |
274 | |
275 | RealType p = (0.5 < _p)? (1 - _p) : _p; |
276 | IntType t = _t; |
277 | |
278 | m = static_cast<IntType>((t+1)*p); |
279 | |
280 | if(use_inversion()) { |
281 | q_n = pow((1 - p), static_cast<RealType>(t)); |
282 | } else { |
283 | btrd.r = p/(1-p); |
284 | btrd.nr = (t+1)*btrd.r; |
285 | btrd.npq = t*p*(1-p); |
286 | RealType sqrt_npq = sqrt(btrd.npq); |
287 | btrd.b = 1.15 + 2.53 * sqrt_npq; |
288 | btrd.a = -0.0873 + 0.0248*btrd.b + 0.01*p; |
289 | btrd.c = t*p + 0.5; |
290 | btrd.alpha = (2.83 + 5.1/btrd.b) * sqrt_npq; |
291 | btrd.v_r = 0.92 - 4.2/btrd.b; |
292 | btrd.u_rv_r = 0.86*btrd.v_r; |
293 | } |
294 | } |
295 | |
296 | template<class URNG> |
297 | result_type generate(URNG& urng) const |
298 | { |
299 | using std::floor; |
300 | using std::abs; |
301 | using std::log; |
302 | |
303 | while(true) { |
304 | RealType u; |
305 | RealType v = uniform_01<RealType>()(urng); |
306 | if(v <= btrd.u_rv_r) { |
307 | u = v/btrd.v_r - 0.43; |
308 | return static_cast<IntType>(floor( |
309 | (2*btrd.a/(0.5 - abs(u)) + btrd.b)*u + btrd.c)); |
310 | } |
311 | |
312 | if(v >= btrd.v_r) { |
313 | u = uniform_01<RealType>()(urng) - 0.5; |
314 | } else { |
315 | u = v/btrd.v_r - 0.93; |
316 | u = ((u < 0)? -0.5 : 0.5) - u; |
317 | v = uniform_01<RealType>()(urng) * btrd.v_r; |
318 | } |
319 | |
320 | RealType us = 0.5 - abs(u); |
321 | IntType k = static_cast<IntType>(floor((2*btrd.a/us + btrd.b)*u + btrd.c)); |
322 | if(k < 0 || k > _t) continue; |
323 | v = v*btrd.alpha/(btrd.a/(us*us) + btrd.b); |
324 | RealType km = abs(k - m); |
325 | if(km <= 15) { |
326 | RealType f = 1; |
327 | if(m < k) { |
328 | IntType i = m; |
329 | do { |
330 | ++i; |
331 | f = f*(btrd.nr/i - btrd.r); |
332 | } while(i != k); |
333 | } else if(m > k) { |
334 | IntType i = k; |
335 | do { |
336 | ++i; |
337 | v = v*(btrd.nr/i - btrd.r); |
338 | } while(i != m); |
339 | } |
340 | if(v <= f) return k; |
341 | else continue; |
342 | } else { |
343 | // final acceptance/rejection |
344 | v = log(v); |
345 | RealType rho = |
346 | (km/btrd.npq)*(((km/3. + 0.625)*km + 1./6)/btrd.npq + 0.5); |
347 | RealType t = -km*km/(2*btrd.npq); |
348 | if(v < t - rho) return k; |
349 | if(v > t + rho) continue; |
350 | |
351 | IntType nm = _t - m + 1; |
352 | RealType h = (m + 0.5)*log((m + 1)/(btrd.r*nm)) |
353 | + fc(m) + fc(_t - m); |
354 | |
355 | IntType nk = _t - k + 1; |
356 | if(v <= h + (_t+1)*log(static_cast<RealType>(nm)/nk) |
357 | + (k + 0.5)*log(nk*btrd.r/(k+1)) |
358 | - fc(k) |
359 | - fc(_t - k)) |
360 | { |
361 | return k; |
362 | } else { |
363 | continue; |
364 | } |
365 | } |
366 | } |
367 | } |
368 | |
369 | template<class URNG> |
370 | IntType invert(IntType t, RealType p, URNG& urng) const |
371 | { |
372 | RealType q = 1 - p; |
373 | RealType s = p / q; |
374 | RealType a = (t + 1) * s; |
375 | RealType r = q_n; |
376 | RealType u = uniform_01<RealType>()(urng); |
377 | IntType x = 0; |
378 | while(u > r) { |
379 | u = u - r; |
380 | ++x; |
381 | RealType r1 = ((a/x) - s) * r; |
382 | // If r gets too small then the round-off error |
383 | // becomes a problem. At this point, p(i) is |
384 | // decreasing exponentially, so if we just call |
385 | // it 0, it's close enough. Note that the |
386 | // minimum value of q_n is about 1e-7, so we |
387 | // may need to be a little careful to make sure that |
388 | // we don't terminate the first time through the loop |
389 | // for float. (Hence the test that r is decreasing) |
390 | if(r1 < std::numeric_limits<RealType>::epsilon() && r1 < r) { |
391 | break; |
392 | } |
393 | r = r1; |
394 | } |
395 | return x; |
396 | } |
397 | |
398 | // parameters |
399 | IntType _t; |
400 | RealType _p; |
401 | |
402 | // common data |
403 | IntType m; |
404 | |
405 | union { |
406 | // for btrd |
407 | struct { |
408 | RealType r; |
409 | RealType nr; |
410 | RealType npq; |
411 | RealType b; |
412 | RealType a; |
413 | RealType c; |
414 | RealType alpha; |
415 | RealType v_r; |
416 | RealType u_rv_r; |
417 | } btrd; |
418 | // for inversion |
419 | RealType q_n; |
420 | }; |
421 | |
422 | /// @endcond |
423 | }; |
424 | |
425 | } |
426 | |
427 | // backwards compatibility |
428 | using random::binomial_distribution; |
429 | |
430 | } |
431 | |
432 | #include <boost/random/detail/enable_warnings.hpp> |
433 | |
434 | #endif |
435 | |