BioCMAMC-ST
prng_extension.hpp
1#ifndef __PRNG_EXTENSION_HPP__
2#define __PRNG_EXTENSION_HPP__
3
4#include "Kokkos_Macros.hpp"
5#include <Kokkos_Core.hpp>
6#include <Kokkos_MathematicalConstants.hpp>
7#include <Kokkos_Random.hpp>
8#include <cmath>
9#include <common/traits.hpp>
10
15{
46 template <typename T, typename F, class DeviceType>
48 FloatingPointType<F> && requires(const T& obj, Kokkos::Random_XorShift1024<DeviceType>& gen) {
49 { obj.draw(gen) } -> std::same_as<F>;
50 { obj.mean() } -> std::same_as<F>;
51 { obj.var() } -> std::same_as<F>;
52 { obj.skewness() } -> std::same_as<F>;
53 };
54
73 template <FloatingPointType F> KOKKOS_INLINE_FUNCTION F erfinv(F x)
74 {
75
76 // Use the Winitzki’s method to calculate get an approached expression of erf(x) and inverse it
77
78 constexpr F a = 0.147;
79 constexpr F inv_a = 1. / a;
80 constexpr F tmp = (2 / (M_PI * a));
81 const double ln1mx2 = Kokkos::log((1. - x) * (1. + x));
82 const F term1 = tmp + (0.5 * ln1mx2);
83 const F term2 = inv_a * ln1mx2;
84 return Kokkos::copysign(Kokkos::sqrt(Kokkos::sqrt(term1 * term1 - term2) - term1), x);
85 }
86
87 // Inverse error function approximation (Using the rational approximation)
88 // template <FloatingPointType F> KOKKOS_INLINE_FUNCTION constexpr F erf_inv(F x)
89 // {
90 // constexpr F a[4] = {0.147, 0.147, 0.147, 0.147};
91 // constexpr F b[4] = {-1.0, 0.5, -0.5, 1.0};
92 // KOKKOS_ASSERT(x <= -1.0 || x >= 1.0);
93
94 // F z = (x < 0.0) ? -x : x;
95
96 // F t = 2.0 / (Kokkos::numbers::pi * 0.147) + 0.5 * Kokkos::log(1.0 - z);
97 // F result = a[0] * Kokkos::pow(t, b[0]) +
98 // a[1] * Kokkos::pow(t, b[1]) +
99 // a[2] * Kokkos::pow(t, b[2]) +
100 // a[3] * Kokkos::pow(t, b[3]);
101 // KOKKOS_ASSERT(Kokkos::isfinite(result));
102 // return result;
103 // }
104
122 template <FloatingPointType F> KOKKOS_INLINE_FUNCTION F norminv(F p, F mean, F stddev)
123 {
124 constexpr F erfinv_lb = -5;
125 constexpr F erfinv_up = 5;
126 auto clamped_inverse = Kokkos::clamp(erfinv(2 * p - 1), erfinv_lb, erfinv_up);
127 KOKKOS_ASSERT(Kokkos::isfinite(stddev * Kokkos::numbers::sqrt2 * clamped_inverse));
128 return mean + stddev * Kokkos::numbers::sqrt2 * clamped_inverse;
129 }
130
145 template <FloatingPointType F> KOKKOS_INLINE_FUNCTION F std_normal_pdf(F x)
146 {
147 // NOLINTBEGIN(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
148 constexpr double inv_sqrt_2_pi = 0.3989422804014327; // 1/sqrt(2pi)
149 return inv_sqrt_2_pi * Kokkos::exp(-0.5 * x * x);
150 // NOLINTEND(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
151 }
152
165 template <FloatingPointType F> KOKKOS_INLINE_FUNCTION F std_normal_cdf(F x)
166 {
167 // NOLINTBEGIN(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
168 return 0.5 * (1 + Kokkos::erf(x / Kokkos::numbers::sqrt2));
169 // NOLINTEND(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
170 }
171
172 /*DISTRIBUTIONS*/
173
182 template <FloatingPointType F> struct Normal
183 {
184 F mu;
186
193 template <class DeviceType>
194 KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
195 {
196 return draw_from(gen, mu, sigma);
197 }
198
207 template <class DeviceType>
208 static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024<DeviceType>& gen,
209 F mu,
210 F sigma)
211 {
212 return gen.normal(mu, sigma);
213 }
214
219 KOKKOS_INLINE_FUNCTION F mean() const
220 {
221 return mu;
222 }
223
228 KOKKOS_INLINE_FUNCTION F var() const
229 {
230 return sigma * sigma;
231 }
232
237 KOKKOS_INLINE_FUNCTION F stddev() const
238 {
239 return sigma;
240 }
241
246 KOKKOS_INLINE_FUNCTION F skewness() const
247 {
248 return F(0);
249 }
250 };
251
252
261 template <FloatingPointType F> struct TruncatedNormal
262 {
263
264 F mu; // Mean
265 F sigma; // Standard deviation
266 F lower; // Standard deviation
267 F upper; // Standard deviation
268
269 KOKKOS_INLINE_FUNCTION constexpr TruncatedNormal(F m, F s, F l, F u)
270 : mu(m), sigma(s), lower(l), upper(u)
271 {
272 X_ASSERT(mu > lower);
273 X_ASSERT(mu < upper);
274 KOKKOS_ASSERT(mu > lower && mu < upper);
275 }
276
277 template <class DeviceType>
278 KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
279 {
280 return draw_from(gen, mu, sigma, lower, upper);
281 }
282
283 template <class DeviceType>
284 static KOKKOS_INLINE_FUNCTION F
285 draw_from(Kokkos::Random_XorShift1024<DeviceType>& gen, F mu, F sigma, F lower, F upper)
286 {
287 // NOLINTBEGIN(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
288 const F rand = static_cast<F>(gen.drand());
289
290 // Max bounded because if sigma <<1 z -> Inf not wanted because erf/erfc/erfinv are not stable
291 // for extrem value Min bounded if |mu-bound| <<1 z -> 0 which is also not wanted for error
292 // function
293
294 F zl = Kokkos::clamp((lower - mu) / sigma, F(-5e3), F(0)); // upper-mu is by defintion <0
295 F zu = Kokkos::clamp((upper - mu) / sigma, F(0), F(5e3)); // upper-mu is by defintion >0
296
297 F pl = 0.5 * Kokkos::erfc(-zl / Kokkos::numbers::sqrt2);
298 KOKKOS_ASSERT(Kokkos::isfinite(pl) &&
299 "Truncated normal draw leads is Nan of Inf with given parameters");
300 F pu = 0.5 * Kokkos::erfc(-zu / Kokkos::numbers::sqrt2);
301 KOKKOS_ASSERT(Kokkos::isfinite(pu) &&
302 "Truncated normal draw leads is Nan of Inf with given parameters");
303 F p = rand * (pu - pl) + pl;
304 F x = norminv(p, mu, sigma);
305 KOKKOS_ASSERT(Kokkos::isfinite(x) &&
306 "Truncated normal draw leads is Nan of Inf with given parameters");
307 return x;
308
309 // NOLINTEND(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
310 }
311
312 KOKKOS_INLINE_FUNCTION F mean() const
313 {
314 const F alpha = (lower - mu) / sigma;
315 const F beta = (upper - mu) / sigma;
316 F Z = std_normal_cdf(beta) - std_normal_cdf(alpha);
317 return mu + sigma * (std_normal_pdf(alpha) - std_normal_pdf(beta)) / Z;
318 }
319
320 KOKKOS_INLINE_FUNCTION F var() const
321 {
322 const F alpha = (lower - mu) / sigma;
323 const F beta = (upper - mu) / sigma;
324 F Z = std_normal_cdf(beta) - std_normal_cdf(alpha);
325 const auto phi_b = std_normal_pdf(beta);
326 const auto phi_a = std_normal_pdf(alpha);
327 const auto tmp = (phi_a - phi_b) / Z;
328 const auto tmp2 = (alpha * phi_a - beta * phi_b) / Z;
329 return sigma * sigma * (1 - tmp2 - tmp * tmp);
330 }
331 KOKKOS_INLINE_FUNCTION F skewness() const
332 {
333 return 0.;
334 }
335 };
336
337
348 template <FloatingPointType F> struct ScaledTruncatedNormal
349 {
350
354
355 constexpr ScaledTruncatedNormal(F factor, F m, F s, F l, F u)
358 {
359 }
360
361 template <class DeviceType>
362 static KOKKOS_INLINE_FUNCTION F draw_from(
363 Kokkos::Random_XorShift1024<DeviceType>& gen, F factor, F mu, F sigma, F lower, F upper)
364 {
365 return 1. / factor *
367 gen, factor * mu, factor * sigma, factor * lower, factor * upper);
368 }
369
370 template <class DeviceType>
371 KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
372 {
373 return inverse_factor * dist.draw(gen);
374 }
375
376 KOKKOS_INLINE_FUNCTION F mean() const
377 {
378
379 return inverse_factor * dist.mean();
380 }
381
382 KOKKOS_INLINE_FUNCTION F var() const
383 {
384 return (inverse_factor * inverse_factor) * dist.var();
385 }
386 KOKKOS_INLINE_FUNCTION F skewness() const
387 {
388 return 0.;
389 }
390 };
391
399 template <FloatingPointType F> struct LogNormal
400 {
401 F mu; // Mean
402 F sigma; // Standard deviation
403
404 template <class DeviceType>
405 KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
406 {
407 return Kokkos::exp(gen.normal(mu, sigma));
408 }
409
410 KOKKOS_INLINE_FUNCTION F mean() const
411 {
412 return Kokkos::exp(mu + sigma * sigma / 2);
413 }
414 KOKKOS_INLINE_FUNCTION F var() const
415 {
416 const auto sigma2 = sigma * sigma;
417 return (Kokkos::exp(sigma2) - 1) * Kokkos::exp(2 * mu + sigma2);
418 }
419 KOKKOS_INLINE_FUNCTION F skewness() const
420 {
421 const auto sigma2 = sigma * sigma;
422 return (Kokkos::exp(sigma2) + 2) * Kokkos::sqrt(Kokkos::exp(sigma2) - 1);
423 }
424 };
425
433 template <FloatingPointType F> struct SkewNormal
434 {
435 F xi; // Mean
436 F omega; // Standard deviation
438
439 template <class DeviceType>
440 KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
441 {
442
443 const double Z0 = gen.normal();
444 const double Z1 = gen.normal();
445 const double delta = alpha / Kokkos::sqrt(1. + alpha * alpha);
446 const double scale_factor = Kokkos::sqrt(1. + delta * delta);
447 const double X = (Z0 + delta * Kokkos::abs(Z1)) / scale_factor;
448 return xi + omega * X;
449 }
450 KOKKOS_INLINE_FUNCTION F mean() const
451 {
452 return xi + omega * (alpha / (Kokkos::sqrt(1 + alpha * alpha))) * Kokkos::sqrt(2 / M_PI);
453 }
454 KOKKOS_INLINE_FUNCTION F var() const
455 {
456 const auto delta = alpha / (Kokkos::sqrt(1 + alpha * alpha));
457 return omega * omega * (1 - 2 * delta * delta / M_PI);
458 }
459 KOKKOS_INLINE_FUNCTION F skewness() const
460 {
461 const auto delta = alpha / (Kokkos::sqrt(1 + alpha * alpha));
462 return ((4 - M_PI) * Kokkos::pow(delta * std::sqrt(2 / M_PI), 3)) /
463 Kokkos::pow(1 - 2 * delta * delta / M_PI, 1.5);
464 }
465 };
466
467} // namespace MC::Distributions
468
469#endif
Definition traits.hpp:20
Concept for probability distribution laws.
Definition prng_extension.hpp:47
Kokkos-compatible utilities to draw from specific probability distributions.
Definition prng_extension.hpp:15
KOKKOS_INLINE_FUNCTION F erfinv(F x)
Computes an approximation of the inverse error function.
Definition prng_extension.hpp:73
KOKKOS_INLINE_FUNCTION F norminv(F p, F mean, F stddev)
Computes the inverse CDF (probit function) of a normal distribution.
Definition prng_extension.hpp:122
KOKKOS_INLINE_FUNCTION F std_normal_pdf(F x)
Computes the standard normal probability density function (PDF).
Definition prng_extension.hpp:145
KOKKOS_INLINE_FUNCTION F std_normal_cdf(F x)
Computes the standard normal cumulative distribution function (CDF).
Definition prng_extension.hpp:165
Represents a log-normal probability distribution (the exponential of a Gaussian variable).
Definition prng_extension.hpp:400
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:405
F sigma
Definition prng_extension.hpp:402
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:410
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:419
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:414
F mu
Definition prng_extension.hpp:401
Represents a normal (Gaussian) probability distribution.
Definition prng_extension.hpp:183
F mu
Mean.
Definition prng_extension.hpp:184
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Draws a random sample from the distribution.
Definition prng_extension.hpp:194
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F mu, F sigma)
Static method to draw a sample from N(μ, σ).
Definition prng_extension.hpp:208
KOKKOS_INLINE_FUNCTION F mean() const
Returns the mean of the distribution.
Definition prng_extension.hpp:219
KOKKOS_INLINE_FUNCTION F skewness() const
Returns the skewness of the distribution.
Definition prng_extension.hpp:246
KOKKOS_INLINE_FUNCTION F stddev() const
Returns the standard deviation of the distribution.
Definition prng_extension.hpp:237
KOKKOS_INLINE_FUNCTION F var() const
Returns the variance of the distribution.
Definition prng_extension.hpp:228
F sigma
Standard deviation.
Definition prng_extension.hpp:185
Represents a Scaled TruncatedNormal (Gaussian) probability distribution.
Definition prng_extension.hpp:349
F inverse_factor
Definition prng_extension.hpp:352
F scale_factor
Definition prng_extension.hpp:351
TruncatedNormal< F > dist
Definition prng_extension.hpp:353
constexpr ScaledTruncatedNormal(F factor, F m, F s, F l, F u)
Definition prng_extension.hpp:355
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:386
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:382
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F factor, F mu, F sigma, F lower, F upper)
Definition prng_extension.hpp:362
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:371
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:376
Represents a skew-normal probability distribution (a Gaussian generalized with a shape parameter).
Definition prng_extension.hpp:434
F alpha
Definition prng_extension.hpp:437
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:459
F omega
Definition prng_extension.hpp:436
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:454
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:450
F xi
Definition prng_extension.hpp:435
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:440
Represents a TruncatedNormal (Gaussian) probability distribution.
Definition prng_extension.hpp:262
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:312
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:331
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:278
KOKKOS_INLINE_FUNCTION constexpr TruncatedNormal(F m, F s, F l, F u)
Definition prng_extension.hpp:269
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:320
F upper
Definition prng_extension.hpp:267
F lower
Definition prng_extension.hpp:266
F sigma
Definition prng_extension.hpp:265
F mu
Definition prng_extension.hpp:264
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F mu, F sigma, F lower, F upper)
Definition prng_extension.hpp:285