Skip to content

Commit 6abfcf6

Browse files
authored
Implement std::normal_distribution (#6585)
1 parent 10ad1db commit 6abfcf6

28 files changed

+647
-905
lines changed

libcudacxx/include/cuda/std/__random/bernoulli_distribution.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class bernoulli_distribution
3939

4040
class param_type
4141
{
42-
double __p_;
42+
double __p_{};
4343

4444
public:
4545
using distribution_type = bernoulli_distribution;
@@ -90,7 +90,7 @@ class bernoulli_distribution
9090
template <class _URng>
9191
[[nodiscard]] _CCCL_API constexpr result_type operator()(_URng& __g, const param_type& __p) noexcept
9292
{
93-
static_assert(__cccl_random_is_valid_urng<_URng>);
93+
static_assert(__cccl_random_is_valid_urng<_URng>, "URng must meet the UniformRandomBitGenerator requirements");
9494
return ::cuda::std::generate_canonical<double, numeric_limits<double>::digits>(__g) < __p.p();
9595
}
9696

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#ifndef _CUDA_STD___RANDOM_NORMAL_DISTRIBUTION_H
11+
#define _CUDA_STD___RANDOM_NORMAL_DISTRIBUTION_H
12+
13+
#include <cuda/std/detail/__config>
14+
15+
#include <cuda/std/__cmath/logarithms.h>
16+
#include <cuda/std/__cmath/roots.h>
17+
#include <cuda/std/__random/is_valid.h>
18+
#include <cuda/std/__random/uniform_real_distribution.h>
19+
#include <cuda/std/limits>
20+
21+
#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
22+
# pragma GCC system_header
23+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
24+
# pragma clang system_header
25+
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
26+
# pragma system_header
27+
#endif // no system header
28+
29+
#if !_CCCL_COMPILER(NVRTC)
30+
# include <ios>
31+
#endif // !_CCCL_COMPILER(NVRTC)
32+
33+
#include <cuda/std/__cccl/prologue.h>
34+
35+
_CCCL_BEGIN_NAMESPACE_CUDA_STD
36+
37+
template <class _RealType = double>
38+
class normal_distribution
39+
{
40+
static_assert(__libcpp_random_is_valid_realtype<_RealType>, "RealType must be a supported floating-point type");
41+
42+
public:
43+
// types
44+
using result_type = _RealType;
45+
46+
class param_type
47+
{
48+
result_type __mean_;
49+
result_type __stddev_;
50+
51+
public:
52+
using distribution_type = normal_distribution;
53+
54+
_CCCL_API constexpr explicit param_type(result_type __mean = 0, result_type __stddev = 1) noexcept
55+
: __mean_{__mean}
56+
, __stddev_{__stddev}
57+
{}
58+
59+
[[nodiscard]] _CCCL_API constexpr result_type mean() const noexcept
60+
{
61+
return __mean_;
62+
}
63+
[[nodiscard]] _CCCL_API constexpr result_type stddev() const noexcept
64+
{
65+
return __stddev_;
66+
}
67+
68+
[[nodiscard]] _CCCL_API friend constexpr bool operator==(const param_type& __x, const param_type& __y) noexcept
69+
{
70+
return __x.__mean_ == __y.__mean_ && __x.__stddev_ == __y.__stddev_;
71+
}
72+
73+
#if _CCCL_STD_VER <= 2017
74+
[[nodiscard]] _CCCL_API friend constexpr bool operator!=(const param_type& __x, const param_type& __y) noexcept
75+
{
76+
return !(__x == __y);
77+
}
78+
#endif // _CCCL_STD_VER <= 2017
79+
};
80+
81+
private:
82+
param_type __p_{};
83+
result_type __v_{};
84+
bool __v_hot_{false};
85+
86+
public:
87+
_CCCL_API constexpr normal_distribution() noexcept
88+
: normal_distribution{0}
89+
{}
90+
_CCCL_API constexpr explicit normal_distribution(result_type __mean, result_type __stddev = result_type{1}) noexcept
91+
: __p_{param_type(__mean, __stddev)}
92+
{}
93+
_CCCL_API constexpr explicit normal_distribution(const param_type& __p) noexcept
94+
: __p_{__p}
95+
{}
96+
_CCCL_API constexpr void reset() noexcept
97+
{
98+
__v_hot_ = false;
99+
}
100+
101+
// generating functions
102+
template <class _URng>
103+
_CCCL_API constexpr result_type operator()(_URng& __g)
104+
{
105+
return (*this)(__g, __p_);
106+
}
107+
template <class _URng>
108+
_CCCL_API constexpr result_type operator()(_URng& __g, const param_type& __p)
109+
{
110+
static_assert(__cccl_random_is_valid_urng<_URng>, "URng must meet the UniformRandomBitGenerator requirements");
111+
result_type __up = 0;
112+
if (__v_hot_)
113+
{
114+
__v_hot_ = false;
115+
__up = __v_;
116+
}
117+
else
118+
{
119+
uniform_real_distribution<result_type> __uni(-1, 1);
120+
result_type __u = __uni(__g);
121+
result_type __v = __uni(__g);
122+
result_type __s = __u * __u + __v * __v;
123+
while (__s > 1 || __s == 0)
124+
{
125+
__u = __uni(__g);
126+
__v = __uni(__g);
127+
__s = __u * __u + __v * __v;
128+
}
129+
const result_type __fp = ::cuda::std::sqrt(-2 * ::cuda::std::log(__s) / __s);
130+
__v_ = __v * __fp;
131+
__v_hot_ = true;
132+
__up = __u * __fp;
133+
}
134+
return __up * __p.stddev() + __p.mean();
135+
}
136+
137+
// property functions
138+
[[nodiscard]] _CCCL_API constexpr result_type mean() const noexcept
139+
{
140+
return __p_.mean();
141+
}
142+
[[nodiscard]] _CCCL_API constexpr result_type stddev() const noexcept
143+
{
144+
return __p_.stddev();
145+
}
146+
147+
[[nodiscard]] _CCCL_API constexpr param_type param() const noexcept
148+
{
149+
return __p_;
150+
}
151+
_CCCL_API constexpr void param(const param_type& __p) noexcept
152+
{
153+
__p_ = __p;
154+
}
155+
156+
[[nodiscard]] _CCCL_API constexpr result_type min() const noexcept
157+
{
158+
return -numeric_limits<result_type>::infinity();
159+
}
160+
[[nodiscard]] _CCCL_API constexpr result_type max() const noexcept
161+
{
162+
return numeric_limits<result_type>::infinity();
163+
}
164+
165+
[[nodiscard]] _CCCL_API friend constexpr bool
166+
operator==(const normal_distribution& __x, const normal_distribution& __y) noexcept
167+
{
168+
return __x.__p_ == __y.__p_ && __x.__v_hot_ == __y.__v_hot_ && (!__x.__v_hot_ || __x.__v_ == __y.__v_);
169+
}
170+
#if _CCCL_STD_VER <= 2017
171+
[[nodiscard]] _CCCL_API friend constexpr bool
172+
operator!=(const normal_distribution& __x, const normal_distribution& __y) noexcept
173+
{
174+
return !(__x == __y);
175+
}
176+
#endif // _CCCL_STD_VER <= 2017
177+
178+
#if !_CCCL_COMPILER(NVRTC)
179+
template <class _CharT, class _Traits>
180+
friend ::std::basic_ostream<_CharT, _Traits>&
181+
operator<<(::std::basic_ostream<_CharT, _Traits>& __os, const normal_distribution& __x)
182+
{
183+
_CharT __sp = __os.widen(' ');
184+
::std::ios_base::fmtflags __flags = __os.flags();
185+
__os.flags(::std::ios_base::dec | ::std::ios_base::left | ::std::ios_base::fixed);
186+
_CharT __fill = __os.fill(__sp);
187+
::std::streamsize __precision = __os.precision(17); // Max precision for double
188+
__os << __x.mean() << __sp << __x.stddev() << __sp << __x.__v_hot_;
189+
if (__x.__v_hot_)
190+
{
191+
__os << __sp << __x.__v_;
192+
}
193+
__os.precision(__precision);
194+
__os.fill(__fill);
195+
__os.flags(__flags);
196+
return __os;
197+
}
198+
199+
template <class _CharT, class _Traits>
200+
friend ::std::basic_istream<_CharT, _Traits>&
201+
operator>>(::std::basic_istream<_CharT, _Traits>& __is, normal_distribution& __x)
202+
{
203+
using _Istream = ::std::basic_istream<_CharT, _Traits>;
204+
auto __flags = __is.flags();
205+
__is.flags(_Istream::skipws);
206+
result_type __mean;
207+
result_type __stddev;
208+
result_type __vp = 0;
209+
int __v_hot_int = 0;
210+
__is >> __mean >> __stddev >> __v_hot_int;
211+
bool __v_hot = __v_hot_int != 0;
212+
if (__v_hot)
213+
{
214+
__is >> __vp;
215+
}
216+
if (!__is.fail())
217+
{
218+
__x.param(param_type(__mean, __stddev));
219+
__x.__v_hot_ = __v_hot;
220+
__x.__v_ = __vp;
221+
}
222+
__is.flags(__flags);
223+
return __is;
224+
}
225+
#endif // !_CCCL_COMPILER(NVRTC)
226+
};
227+
228+
_CCCL_END_NAMESPACE_CUDA_STD
229+
230+
#include <cuda/std/__cccl/epilogue.h>
231+
232+
#endif // _CUDA_STD___RANDOM_NORMAL_DISTRIBUTION_H

libcudacxx/include/cuda/std/__random/uniform_int_distribution.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ class uniform_int_distribution
226226
template <class _URng>
227227
[[nodiscard]] _CCCL_API constexpr result_type operator()(_URng& __g, const param_type& __p) noexcept
228228
{
229-
static_assert(__cccl_random_is_valid_urng<_URng>);
229+
static_assert(__cccl_random_is_valid_urng<_URng>, "URng must meet the UniformRandomBitGenerator requirements");
230230
using _UIntType = conditional_t<sizeof(result_type) <= sizeof(uint32_t), uint32_t, make_unsigned_t<result_type>>;
231231
const _UIntType __rp = _UIntType(__p.b()) - _UIntType(__p.a()) + _UIntType(1);
232232
if (__rp == 1)

libcudacxx/include/cuda/std/__random/uniform_real_distribution.h

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -45,28 +45,30 @@ class uniform_real_distribution
4545
public:
4646
using distribution_type = uniform_real_distribution;
4747

48-
_CCCL_API explicit param_type(result_type __a = 0, result_type __b = 1) noexcept
49-
: __a_(__a)
50-
, __b_(__b)
48+
_CCCL_API constexpr explicit param_type(result_type __a = 0, result_type __b = 1) noexcept
49+
: __a_{__a}
50+
, __b_{__b}
5151
{}
5252

53-
[[nodiscard]] _CCCL_API result_type a() const noexcept
53+
[[nodiscard]] _CCCL_API constexpr result_type a() const noexcept
5454
{
5555
return __a_;
5656
}
57-
[[nodiscard]] _CCCL_API result_type b() const noexcept
57+
[[nodiscard]] _CCCL_API constexpr result_type b() const noexcept
5858
{
5959
return __b_;
6060
}
6161

62-
[[nodiscard]] _CCCL_API friend bool operator==(const param_type& __x, const param_type& __y) noexcept
62+
[[nodiscard]] _CCCL_API friend constexpr bool operator==(const param_type& __x, const param_type& __y) noexcept
6363
{
6464
return __x.__a_ == __y.__a_ && __x.__b_ == __y.__b_;
6565
}
66-
[[nodiscard]] _CCCL_API friend bool operator!=(const param_type& __x, const param_type& __y) noexcept
66+
#if _CCCL_STD_VER <= 2017
67+
[[nodiscard]] _CCCL_API friend constexpr bool operator!=(const param_type& __x, const param_type& __y) noexcept
6768
{
6869
return !(__x == __y);
6970
}
71+
#endif // _CCCL_STD_VER <= 2017
7072
};
7173

7274
private:
@@ -75,71 +77,73 @@ class uniform_real_distribution
7577
public:
7678
// constructors and reset functions
7779

78-
_CCCL_API uniform_real_distribution() noexcept
80+
_CCCL_API constexpr uniform_real_distribution() noexcept
7981
: uniform_real_distribution(0)
8082
{}
81-
_CCCL_API explicit uniform_real_distribution(result_type __a, result_type __b = 1) noexcept
82-
: __p_(param_type(__a, __b))
83+
_CCCL_API constexpr explicit uniform_real_distribution(result_type __a, result_type __b = 1) noexcept
84+
: __p_{param_type(__a, __b)}
8385
{}
84-
_CCCL_API explicit uniform_real_distribution(const param_type& __p) noexcept
85-
: __p_(__p)
86+
_CCCL_API constexpr explicit uniform_real_distribution(const param_type& __p) noexcept
87+
: __p_{__p}
8688
{}
87-
_CCCL_API void reset() noexcept {}
89+
_CCCL_API constexpr void reset() noexcept {}
8890

8991
// generating functions
9092
template <class _URng>
91-
[[nodiscard]] _CCCL_API result_type operator()(_URng& __g) noexcept
93+
[[nodiscard]] _CCCL_API constexpr result_type operator()(_URng& __g) noexcept
9294
{
9395
return (*this)(__g, __p_);
9496
}
9597

9698
_CCCL_EXEC_CHECK_DISABLE
9799
template <class _URng>
98-
[[nodiscard]] _CCCL_API result_type operator()(_URng& __g, const param_type& __p) noexcept
100+
[[nodiscard]] _CCCL_API constexpr result_type operator()(_URng& __g, const param_type& __p) noexcept
99101
{
100-
static_assert(__cccl_random_is_valid_urng<_URng>, "");
102+
static_assert(__cccl_random_is_valid_urng<_URng>, "URng must meet the UniformRandomBitGenerator requirements");
101103
return (__p.b() - __p.a()) * ::cuda::std::generate_canonical<_RealType, numeric_limits<_RealType>::digits>(__g)
102104
+ __p.a();
103105
}
104106

105107
// property functions
106-
[[nodiscard]] _CCCL_API result_type a() const noexcept
108+
[[nodiscard]] _CCCL_API constexpr result_type a() const noexcept
107109
{
108110
return __p_.a();
109111
}
110-
[[nodiscard]] _CCCL_API result_type b() const noexcept
112+
[[nodiscard]] _CCCL_API constexpr result_type b() const noexcept
111113
{
112114
return __p_.b();
113115
}
114116

115-
[[nodiscard]] _CCCL_API param_type param() const noexcept
117+
[[nodiscard]] _CCCL_API constexpr param_type param() const noexcept
116118
{
117119
return __p_;
118120
}
119-
_CCCL_API void param(const param_type& __p) noexcept
121+
_CCCL_API constexpr void param(const param_type& __p) noexcept
120122
{
121123
__p_ = __p;
122124
}
123125

124-
[[nodiscard]] _CCCL_API result_type min() const noexcept
126+
[[nodiscard]] _CCCL_API constexpr result_type min() const noexcept
125127
{
126128
return a();
127129
}
128-
[[nodiscard]] _CCCL_API result_type max() const noexcept
130+
[[nodiscard]] _CCCL_API constexpr result_type max() const noexcept
129131
{
130132
return b();
131133
}
132134

133-
[[nodiscard]] _CCCL_API friend bool
135+
[[nodiscard]] _CCCL_API friend constexpr bool
134136
operator==(const uniform_real_distribution& __x, const uniform_real_distribution& __y) noexcept
135137
{
136138
return __x.__p_ == __y.__p_;
137139
}
138-
[[nodiscard]] _CCCL_API friend bool
140+
#if _CCCL_STD_VER <= 2017
141+
[[nodiscard]] _CCCL_API friend constexpr bool
139142
operator!=(const uniform_real_distribution& __x, const uniform_real_distribution& __y) noexcept
140143
{
141144
return !(__x == __y);
142145
}
146+
#endif // _CCCL_STD_VER <= 2017
143147
};
144148

145149
#if 0 // Implement streaming

libcudacxx/include/cuda/std/__random_

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
#include <cuda/std/__random/bernoulli_distribution.h>
2525
#include <cuda/std/__random/linear_congruential_engine.h>
26+
#include <cuda/std/__random/normal_distribution.h>
2627
#include <cuda/std/__random/philox_engine.h>
2728
#include <cuda/std/__random/seed_seq.h>
2829
#include <cuda/std/__random/uniform_int_distribution.h>

0 commit comments

Comments
 (0)