* libstdc++-v3/include/ext/random: Add __gnu_cxx::beta_distribution<>
[official-gcc.git] / libstdc++-v3 / include / ext / random
blob 9563e6a05001940b654fb6f865922beb00d2b5cb
1 // Random number extensions -*- C++ -*-
3 // Copyright (C) 2012 Free Software Foundation, Inc.
4 //
5 // This file is part of the GNU ISO C++ Library.  This library is free
6 // software; you can redistribute it and/or modify it under the
7 // terms of the GNU General Public License as published by the
8 // Free Software Foundation; either version 3, or (at your option)
9 // any later version.
11 // This library is distributed in the hope that it will be useful,
12 // but WITHOUT ANY WARRANTY; without even the implied warranty of
13 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 // GNU General Public License for more details.
16 // Under Section 7 of GPL version 3, you are granted additional
17 // permissions described in the GCC Runtime Library Exception, version
18 // 3.1, as published by the Free Software Foundation.
20 // You should have received a copy of the GNU General Public License and
21 // a copy of the GCC Runtime Library Exception along with this program;
22 // see the files COPYING3 and COPYING.RUNTIME respectively.  If not, see
23 // <http://www.gnu.org/licenses/>.
25 /** @file ext/random
26  *  This file is a GNU extension to the Standard C++ Library.
27  */
29 #ifndef _EXT_RANDOM
30 #define _EXT_RANDOM 1
32 #pragma GCC system_header
34 #include <random>
35 #ifdef __SSE2__
36 # include <x86intrin.h>
37 #endif
40 namespace __gnu_cxx _GLIBCXX_VISIBILITY(default)
42 _GLIBCXX_BEGIN_NAMESPACE_VERSION
44   /* Mersenne twister implementation optimized for vector operations.
45    *
46    * Reference: http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/SFMT/
47    */
48   template<typename _UIntType, size_t __m,
49            size_t __pos1, size_t __sl1, size_t __sl2,
50            size_t __sr1, size_t __sr2,
51            uint32_t __msk1, uint32_t __msk2,
52            uint32_t __msk3, uint32_t __msk4,
53            uint32_t __parity1, uint32_t __parity2,
54            uint32_t __parity3, uint32_t __parity4>
55     class simd_fast_mersenne_twister_engine
56     {
57       static_assert(std::is_unsigned<_UIntType>::value, "template argument "
58                     "substituting _UIntType not an unsigned integral type");
59       static_assert(__sr1 < 32, "first right shift too large");
60       static_assert(__sr2 < 16, "second right shift too large");
61       static_assert(__sl1 < 32, "first left shift too large");
62       static_assert(__sl2 < 16, "second left shift too large");
64     public:
65       typedef _UIntType result_type;
67     private:
68       static constexpr size_t m_w = sizeof(result_type) * 8;
69       static constexpr size_t _M_nstate = __m / 128 + 1;
70       static constexpr size_t _M_nstate32 = _M_nstate * 4;
72       static_assert(std::is_unsigned<_UIntType>::value, "template argument "
73                     "substituting _UIntType not an unsigned integral type");
74       static_assert(__pos1 < _M_nstate, "POS1 not smaller than state size");
75       static_assert(16 % sizeof(_UIntType) == 0,
76                     "UIntType size must divide 16");
78     public:
79       static constexpr size_t state_size = _M_nstate * (16
80                                                         / sizeof(result_type));
81       static constexpr result_type default_seed = 5489u;
83       // constructors and member function
84       explicit
85       simd_fast_mersenne_twister_engine(result_type __sd = default_seed)
86       { seed(__sd); }
88       template<typename _Sseq, typename = typename
89         std::enable_if<!std::is_same<_Sseq, simd_fast_mersenne_twister_engine>::value>
90                ::type>
91         explicit
92         simd_fast_mersenne_twister_engine(_Sseq& __q)
93         { seed(__q); }
95       void
96       seed(result_type __sd = default_seed);
98       template<typename _Sseq>
99         typename std::enable_if<std::is_class<_Sseq>::value>::type
100         seed(_Sseq& __q);
102       static constexpr result_type
103       min()
104       { return 0; };
106       static constexpr result_type
107       max()
108       { return std::numeric_limits<result_type>::max(); }
110       void
111       discard(unsigned long long __z);
113       result_type
114       operator()()
115       {
116         if (__builtin_expect(_M_pos >= state_size, 0))
117           _M_gen_rand();
119         return _M_stateT[_M_pos++];
120       }
122 #ifdef __SSE2__
123       friend bool
124       operator==(const simd_fast_mersenne_twister_engine& __lhs,
125                  const simd_fast_mersenne_twister_engine& __rhs)
126       { __m128i __res = _mm_cmpeq_epi8(__lhs._M_state[0], __rhs._M_state[0]);
127         for (size_t __i = 1; __i < __lhs._M_nstate; ++__i)
128           __res = _mm_and_si128(__res, _mm_cmpeq_epi8(__lhs._M_state[__i],
129                                                       __rhs._M_state[__i]));
130         return (_mm_movemask_epi8(__res) == 0xffff
131                 && __lhs._M_pos == __rhs._M_pos); }
132 #else
133       friend bool
134       operator==(const simd_fast_mersenne_twister_engine& __lhs,
135                  const simd_fast_mersenne_twister_engine& __rhs)
136       { return (std::equal(__lhs._M_stateT, __lhs._M_stateT + state_size,
137                            __rhs._M_stateT)
138                 && __lhs._M_pos == __rhs._M_pos); }
139 #endif
141       template<typename _UIntType_2, size_t __m_2,
142                size_t __pos1_2, size_t __sl1_2, size_t __sl2_2,
143                size_t __sr1_2, size_t __sr2_2,
144                uint32_t __msk1_2, uint32_t __msk2_2,
145                uint32_t __msk3_2, uint32_t __msk4_2,
146                uint32_t __parity1_2, uint32_t __parity2_2,
147                uint32_t __parity3_2, uint32_t __parity4_2,
148                typename _CharT, typename _Traits>
149         friend std::basic_ostream<_CharT, _Traits>&
150         operator<<(std::basic_ostream<_CharT, _Traits>& __os,
151                    const __gnu_cxx::simd_fast_mersenne_twister_engine<_UIntType_2,
152                    __m_2, __pos1_2, __sl1_2, __sl2_2, __sr1_2, __sr2_2,
153                    __msk1_2, __msk2_2, __msk3_2, __msk4_2,
154                    __parity1_2, __parity2_2, __parity3_2, __parity4_2>& __x);
156       template<typename _UIntType_2, size_t __m_2,
157                size_t __pos1_2, size_t __sl1_2, size_t __sl2_2,
158                size_t __sr1_2, size_t __sr2_2,
159                uint32_t __msk1_2, uint32_t __msk2_2,
160                uint32_t __msk3_2, uint32_t __msk4_2,
161                uint32_t __parity1_2, uint32_t __parity2_2,
162                uint32_t __parity3_2, uint32_t __parity4_2,
163                typename _CharT, typename _Traits>
164         friend std::basic_istream<_CharT, _Traits>&
165         operator>>(std::basic_istream<_CharT, _Traits>& __is,
166                    __gnu_cxx::simd_fast_mersenne_twister_engine<_UIntType_2,
167                    __m_2, __pos1_2, __sl1_2, __sl2_2, __sr1_2, __sr2_2,
168                    __msk1_2, __msk2_2, __msk3_2, __msk4_2,
169                    __parity1_2, __parity2_2, __parity3_2, __parity4_2>& __x);
171     private:
172       union
173       {
174 #ifdef __SSE2__
175         __m128i _M_state[_M_nstate];
176 #endif
177         uint32_t _M_state32[_M_nstate32];
178         result_type _M_stateT[state_size];
179       } __attribute__ ((__aligned__ (16)));
180       size_t _M_pos;
182       void _M_gen_rand(void);
183       void _M_period_certification();
184   };
187   template<typename _UIntType, size_t __m,
188            size_t __pos1, size_t __sl1, size_t __sl2,
189            size_t __sr1, size_t __sr2,
190            uint32_t __msk1, uint32_t __msk2,
191            uint32_t __msk3, uint32_t __msk4,
192            uint32_t __parity1, uint32_t __parity2,
193            uint32_t __parity3, uint32_t __parity4>
194     inline bool
195     operator!=(const __gnu_cxx::simd_fast_mersenne_twister_engine<_UIntType,
196                __m, __pos1, __sl1, __sl2, __sr1, __sr2, __msk1, __msk2, __msk3,
197                __msk4, __parity1, __parity2, __parity3, __parity4>& __lhs,
198                const __gnu_cxx::simd_fast_mersenne_twister_engine<_UIntType,
199                __m, __pos1, __sl1, __sl2, __sr1, __sr2, __msk1, __msk2, __msk3,
200                __msk4, __parity1, __parity2, __parity3, __parity4>& __rhs)
201     { return !(__lhs == __rhs); }
204   /* Definitions for the SIMD-oriented Fast Mersenne Twister as defined
205    * in the C implementation by Daito and Matsumoto, as both a 32-bit
206    * and 64-bit version.
207    */
208   typedef simd_fast_mersenne_twister_engine<uint32_t, 607, 2,
209                                             15, 3, 13, 3,
210                                             0xfdff37ffU, 0xef7f3f7dU,
211                                             0xff777b7dU, 0x7ff7fb2fU,
212                                             0x00000001U, 0x00000000U,
213                                             0x00000000U, 0x5986f054U>
214     sfmt607;
216   typedef simd_fast_mersenne_twister_engine<uint64_t, 607, 2,
217                                             15, 3, 13, 3,
218                                             0xfdff37ffU, 0xef7f3f7dU,
219                                             0xff777b7dU, 0x7ff7fb2fU,
220                                             0x00000001U, 0x00000000U,
221                                             0x00000000U, 0x5986f054U>
222     sfmt607_64;
225   typedef simd_fast_mersenne_twister_engine<uint32_t, 1279, 7,
226                                             14, 3, 5, 1,
227                                             0xf7fefffdU, 0x7fefcfffU,
228                                             0xaff3ef3fU, 0xb5ffff7fU,
229                                             0x00000001U, 0x00000000U,
230                                             0x00000000U, 0x20000000U>
231     sfmt1279;
233   typedef simd_fast_mersenne_twister_engine<uint64_t, 1279, 7,
234                                             14, 3, 5, 1,
235                                             0xf7fefffdU, 0x7fefcfffU,
236                                             0xaff3ef3fU, 0xb5ffff7fU,
237                                             0x00000001U, 0x00000000U,
238                                             0x00000000U, 0x20000000U>
239     sfmt1279_64;
242   typedef simd_fast_mersenne_twister_engine<uint32_t, 2281, 12,
243                                             19, 1, 5, 1,
244                                             0xbff7ffbfU, 0xfdfffffeU,
245                                             0xf7ffef7fU, 0xf2f7cbbfU,
246                                             0x00000001U, 0x00000000U,
247                                             0x00000000U, 0x41dfa600U>
248     sfmt2281;
250   typedef simd_fast_mersenne_twister_engine<uint64_t, 2281, 12,
251                                             19, 1, 5, 1,
252                                             0xbff7ffbfU, 0xfdfffffeU,
253                                             0xf7ffef7fU, 0xf2f7cbbfU,
254                                             0x00000001U, 0x00000000U,
255                                             0x00000000U, 0x41dfa600U>
256     sfmt2281_64;
259   typedef simd_fast_mersenne_twister_engine<uint32_t, 4253, 17,
260                                             20, 1, 7, 1,
261                                             0x9f7bffffU, 0x9fffff5fU,
262                                             0x3efffffbU, 0xfffff7bbU,
263                                             0xa8000001U, 0xaf5390a3U,
264                                             0xb740b3f8U, 0x6c11486dU>
265     sfmt4253;
267   typedef simd_fast_mersenne_twister_engine<uint64_t, 4253, 17,
268                                             20, 1, 7, 1,
269                                             0x9f7bffffU, 0x9fffff5fU,
270                                             0x3efffffbU, 0xfffff7bbU,
271                                             0xa8000001U, 0xaf5390a3U,
272                                             0xb740b3f8U, 0x6c11486dU>
273     sfmt4253_64;
276   typedef simd_fast_mersenne_twister_engine<uint32_t, 11213, 68,
277                                             14, 3, 7, 3,
278                                             0xeffff7fbU, 0xffffffefU,
279                                             0xdfdfbfffU, 0x7fffdbfdU,
280                                             0x00000001U, 0x00000000U,
281                                             0xe8148000U, 0xd0c7afa3U>
282     sfmt11213;
284   typedef simd_fast_mersenne_twister_engine<uint64_t, 11213, 68,
285                                             14, 3, 7, 3,
286                                             0xeffff7fbU, 0xffffffefU,
287                                             0xdfdfbfffU, 0x7fffdbfdU,
288                                             0x00000001U, 0x00000000U,
289                                             0xe8148000U, 0xd0c7afa3U>
290     sfmt11213_64;
293   typedef simd_fast_mersenne_twister_engine<uint32_t, 19937, 122,
294                                             18, 1, 11, 1,
295                                             0xdfffffefU, 0xddfecb7fU,
296                                             0xbffaffffU, 0xbffffff6U,
297                                             0x00000001U, 0x00000000U,
298                                             0x00000000U, 0x13c9e684U>
299     sfmt19937;
301   typedef simd_fast_mersenne_twister_engine<uint64_t, 19937, 122,
302                                             18, 1, 11, 1,
303                                             0xdfffffefU, 0xddfecb7fU,
304                                             0xbffaffffU, 0xbffffff6U,
305                                             0x00000001U, 0x00000000U,
306                                             0x00000000U, 0x13c9e684U>
307     sfmt19937_64;
310   typedef simd_fast_mersenne_twister_engine<uint32_t, 44497, 330,
311                                             5, 3, 9, 3,
312                                             0xeffffffbU, 0xdfbebfffU,
313                                             0xbfbf7befU, 0x9ffd7bffU,
314                                             0x00000001U, 0x00000000U,
315                                             0xa3ac4000U, 0xecc1327aU>
316     sfmt44497;
318   typedef simd_fast_mersenne_twister_engine<uint64_t, 44497, 330,
319                                             5, 3, 9, 3,
320                                             0xeffffffbU, 0xdfbebfffU,
321                                             0xbfbf7befU, 0x9ffd7bffU,
322                                             0x00000001U, 0x00000000U,
323                                             0xa3ac4000U, 0xecc1327aU>
324     sfmt44497_64;
327   typedef simd_fast_mersenne_twister_engine<uint32_t, 86243, 366,
328                                             6, 7, 19, 1,
329                                             0xfdbffbffU, 0xbff7ff3fU,
330                                             0xfd77efffU, 0xbf9ff3ffU,
331                                             0x00000001U, 0x00000000U,
332                                             0x00000000U, 0xe9528d85U>
333     sfmt86243;
335   typedef simd_fast_mersenne_twister_engine<uint64_t, 86243, 366,
336                                             6, 7, 19, 1,
337                                             0xfdbffbffU, 0xbff7ff3fU,
338                                             0xfd77efffU, 0xbf9ff3ffU,
339                                             0x00000001U, 0x00000000U,
340                                             0x00000000U, 0xe9528d85U>
341     sfmt86243_64;
344   typedef simd_fast_mersenne_twister_engine<uint32_t, 132049, 110,
345                                             19, 1, 21, 1,
346                                             0xffffbb5fU, 0xfb6ebf95U,
347                                             0xfffefffaU, 0xcff77fffU,
348                                             0x00000001U, 0x00000000U,
349                                             0xcb520000U, 0xc7e91c7dU>
350     sfmt132049;
352   typedef simd_fast_mersenne_twister_engine<uint64_t, 132049, 110,
353                                             19, 1, 21, 1,
354                                             0xffffbb5fU, 0xfb6ebf95U,
355                                             0xfffefffaU, 0xcff77fffU,
356                                             0x00000001U, 0x00000000U,
357                                             0xcb520000U, 0xc7e91c7dU>
358     sfmt132049_64;
361   typedef simd_fast_mersenne_twister_engine<uint32_t, 216091, 627,
362                                             11, 3, 10, 1,
363                                             0xbff7bff7U, 0xbfffffffU,
364                                             0xbffffa7fU, 0xffddfbfbU,
365                                             0xf8000001U, 0x89e80709U,
366                                             0x3bd2b64bU, 0x0c64b1e4U>
367     sfmt216091;
369   typedef simd_fast_mersenne_twister_engine<uint64_t, 216091, 627,
370                                             11, 3, 10, 1,
371                                             0xbff7bff7U, 0xbfffffffU,
372                                             0xbffffa7fU, 0xffddfbfbU,
373                                             0xf8000001U, 0x89e80709U,
374                                             0x3bd2b64bU, 0x0c64b1e4U>
375     sfmt216091_64;
378   /**
379    * @brief A beta continuous distribution for random numbers.
380    *
381    * The formula for the beta probability density function is:
382    * @f[
383    *     p(x|\alpha,\beta) = \frac{1}{B(\alpha,\beta)}
384    *                         x^{\alpha - 1} (1 - x)^{\beta - 1}
385    * @f]
386    */
387   template<typename _RealType = double>
388     class beta_distribution
389     {
390       static_assert(std::is_floating_point<_RealType>::value,
391                     "template argument not a floating point type");
393     public:
394       /** The type of the range of the distribution. */
395       typedef _RealType result_type;
396       /** Parameter type. */
397       struct param_type
398       {
399         typedef beta_distribution<_RealType> distribution_type;
400         friend class beta_distribution<_RealType>;
402         explicit
403         param_type(_RealType __alpha_val = _RealType(1),
404                    _RealType __beta_val = _RealType(1))
405         : _M_alpha(__alpha_val), _M_beta(__beta_val)
406         {
407           _GLIBCXX_DEBUG_ASSERT(_M_alpha > _RealType(0));
408           _GLIBCXX_DEBUG_ASSERT(_M_beta > _RealType(0));
409         }
411         _RealType
412         alpha() const
413         { return _M_alpha; }
415         _RealType
416         beta() const
417         { return _M_beta; }
419         friend bool
420         operator==(const param_type& __p1, const param_type& __p2)
421         { return (__p1._M_alpha == __p2._M_alpha
422                   && __p1._M_beta == __p2._M_beta); }
424       private:
425         void
426         _M_initialize();
428         _RealType _M_alpha;
429         _RealType _M_beta;
430       };
432     public:
433       /**
434        * @brief Constructs a beta distribution with parameters
435        * @f$\alpha@f$ and @f$\beta@f$.
436        */
437       explicit
438       beta_distribution(_RealType __alpha_val = _RealType(1),
439                         _RealType __beta_val = _RealType(1))
440       : _M_param(__alpha_val, __beta_val)
441       { }
443       explicit
444       beta_distribution(const param_type& __p)
445       : _M_param(__p)
446       { }
448       /**
449        * @brief Resets the distribution state.
450        */
451       void
452       reset()
453       { }
455       /**
456        * @brief Returns the @f$\alpha@f$ of the distribution.
457        */
458       _RealType
459       alpha() const
460       { return _M_param.alpha(); }
462       /**
463        * @brief Returns the @f$\beta@f$ of the distribution.
464        */
465       _RealType
466       beta() const
467       { return _M_param.beta(); }
469       /**
470        * @brief Returns the parameter set of the distribution.
471        */
472       param_type
473       param() const
474       { return _M_param; }
476       /**
477        * @brief Sets the parameter set of the distribution.
478        * @param __param The new parameter set of the distribution.
479        */
480       void
481       param(const param_type& __param)
482       { _M_param = __param; }
484       /**
485        * @brief Returns the greatest lower bound value of the distribution.
486        */
487       result_type
488       min() const
489       { return result_type(0); }
491       /**
492        * @brief Returns the least upper bound value of the distribution.
493        */
494       result_type
495       max() const
496       { return result_type(1); }
498       /**
499        * @brief Generating functions.
500        */
501       template<typename _UniformRandomNumberGenerator>
502         result_type
503         operator()(_UniformRandomNumberGenerator& __urng)
504         { return this->operator()(__urng, this->param()); }
506       template<typename _UniformRandomNumberGenerator>
507         result_type
508         operator()(_UniformRandomNumberGenerator& __urng,
509                    const param_type& __p);
511       template<typename _ForwardIterator,
512                typename _UniformRandomNumberGenerator>
513         void
514         __generate(_ForwardIterator __f, _ForwardIterator __t,
515                    _UniformRandomNumberGenerator& __urng)
516         { this->__generate(__f, __t, __urng, this->param()); }
518       template<typename _ForwardIterator,
519                typename _UniformRandomNumberGenerator>
520         void
521         __generate(_ForwardIterator __f, _ForwardIterator __t,
522                    _UniformRandomNumberGenerator& __urng,
523                    const param_type& __p)
524         { this->__generate_impl(__f, __t, __urng, __p); }
526       template<typename _UniformRandomNumberGenerator>
527         void
528         __generate(result_type* __f, result_type* __t,
529                    _UniformRandomNumberGenerator& __urng,
530                    const param_type& __p)
531         { this->__generate_impl(__f, __t, __urng, __p); }
533       /**
534        * @brief Return true if two beta distributions have the same
535        *        parameters and the sequences that would be generated
536        *        are equal.
537        */
538       friend bool
539       operator==(const beta_distribution& __d1,
540                  const beta_distribution& __d2)
541       { return __d1.param() == __d2.param(); }
543       /**
544        * @brief Inserts a %beta_distribution random number distribution
545        * @p __x into the output stream @p __os.
546        *
547        * @param __os An output stream.
548        * @param __x  A %beta_distribution random number distribution.
549        *
550        * @returns The output stream with the state of @p __x inserted or in
551        * an error state.
552        */
553       template<typename _RealType1, typename _CharT, typename _Traits>
554         friend std::basic_ostream<_CharT, _Traits>&
555         operator<<(std::basic_ostream<_CharT, _Traits>& __os,
556                    const __gnu_cxx::beta_distribution<_RealType1>& __x);
558       /**
559        * @brief Extracts a %beta_distribution random number distribution
560        * @p __x from the input stream @p __is.
561        *
562        * @param __is An input stream.
563        * @param __x  A %beta_distribution random number generator engine.
564        *
565        * @returns The input stream with @p __x extracted or in an error state.
566        */
567       template<typename _RealType1, typename _CharT, typename _Traits>
568         friend std::basic_istream<_CharT, _Traits>&
569         operator>>(std::basic_istream<_CharT, _Traits>& __is,
570                    __gnu_cxx::beta_distribution<_RealType1>& __x);
572     private:
573       template<typename _ForwardIterator,
574                typename _UniformRandomNumberGenerator>
575         void
576         __generate_impl(_ForwardIterator __f, _ForwardIterator __t,
577                         _UniformRandomNumberGenerator& __urng,
578                         const param_type& __p);
580       param_type _M_param;
581     };
583   /**
584    * @brief Return true if two beta distributions are different.
585    */
586    template<typename _RealType>
587      inline bool
588      operator!=(const __gnu_cxx::beta_distribution<_RealType>& __d1,
589                 const __gnu_cxx::beta_distribution<_RealType>& __d2)
590     { return !(__d1 == __d2); }
594 _GLIBCXX_END_NAMESPACE_VERSION
} // namespace __gnu_cxx
597 #include "random.tcc"
599 #endif /* _EXT_RANDOM */