1#ifndef __PRNG_EXTENSION_HPP__
2#define __PRNG_EXTENSION_HPP__
4#include "Kokkos_Macros.hpp"
5#include "common/maths.hpp"
7#include <Kokkos_Core.hpp>
8#include <Kokkos_MathematicalConstants.hpp>
9#include <Kokkos_Random.hpp>
11#include <common/traits.hpp>
50 template <
typename T,
typename F,
class DeviceType>
53 &&
requires(
const T& obj,
54 Kokkos::Random_XorShift1024<DeviceType>& gen) {
55 { obj.draw(gen) } -> std::same_as<F>;
56 { obj.mean() } -> std::same_as<F>;
57 { obj.var() } -> std::same_as<F>;
58 { obj.skewness() } -> std::same_as<F>;
80 template <FloatingPo
intType F>
81 KOKKOS_INLINE_FUNCTION F
88 constexpr F a = 0.147;
89 constexpr F inv_a = 1. / a;
90 constexpr F tmp = (2 / (M_PI * a));
91 const double ln1mx2 = Kokkos::log((1. - x) * (1. + x));
92 const F term1 = tmp + (0.5 * ln1mx2);
93 const F term2 = inv_a * ln1mx2;
94 return Kokkos::copysign(
95 Kokkos::sqrt(Kokkos::sqrt(term1 * term1 - term2) - term1), x);
136 template <FloatingPo
intType F>
137 KOKKOS_INLINE_FUNCTION F
140 constexpr F erfinv_lb = -5;
141 constexpr F erfinv_up = 5;
143 = Kokkos::clamp(
erfinv(2 * p - 1), erfinv_lb, erfinv_up);
145 Kokkos::isfinite(stddev * Kokkos::numbers::sqrt2 * clamped_inverse));
146 return mean + stddev * Kokkos::numbers::sqrt2 * clamped_inverse;
163 template <FloatingPo
intType F>
164 KOKKOS_INLINE_FUNCTION F
168 constexpr double inv_sqrt_2_pi = 0.3989422804014327;
169 return inv_sqrt_2_pi * Kokkos::exp(-0.5 * x * x);
187 template <FloatingPo
intType F>
188 KOKKOS_INLINE_FUNCTION F
192 return 0.5 * (1 + Kokkos::erf(x / Kokkos::numbers::sqrt2));
209 template <
class DeviceType>
210 KOKKOS_INLINE_FUNCTION F
211 draw(Kokkos::Random_XorShift1024<DeviceType>& gen)
const
224 template <
class DeviceType>
225 static KOKKOS_INLINE_FUNCTION F
228 return gen.drand(
min,
max);
235 KOKKOS_INLINE_FUNCTION F
238 return (
min +
max) / F(2);
245 KOKKOS_INLINE_FUNCTION F
255 KOKKOS_INLINE_FUNCTION F
270 template <FloatingPo
intType F>
struct Normal
281 template <
class DeviceType>
282 KOKKOS_INLINE_FUNCTION F
283 draw(Kokkos::Random_XorShift1024<DeviceType>& gen)
const
296 template <
class DeviceType>
297 static KOKKOS_INLINE_FUNCTION F
307 KOKKOS_INLINE_FUNCTION F
317 KOKKOS_INLINE_FUNCTION F
327 KOKKOS_INLINE_FUNCTION F
337 KOKKOS_INLINE_FUNCTION F
371 template <
class DeviceType>
372 KOKKOS_INLINE_FUNCTION F
373 draw(Kokkos::Random_XorShift1024<DeviceType>& gen)
const
378 template <
class DeviceType>
379 static KOKKOS_INLINE_FUNCTION F
387 const F rand =
static_cast<F
>(gen.drand());
400 F pl = 0.5 * Kokkos::erfc(-zl / Kokkos::numbers::sqrt2);
403 &&
"Truncated normal draw leads to Nan of Inf with given parameters");
404 F pu = 0.5 * Kokkos::erfc(-zu / Kokkos::numbers::sqrt2);
407 &&
"Truncated normal draw leads to Nan of Inf with given parameters");
408 F p = rand * (pu - pl) + pl;
412 &&
"Truncated normal draw leads to Nan of Inf with given parameters");
418 KOKKOS_INLINE_FUNCTION F
427 KOKKOS_INLINE_FUNCTION F
435 const auto tmp = (phi_a - phi_b) / Z;
436 const auto tmp2 = (alpha * phi_a - beta * phi_b) / Z;
437 return sigma *
sigma * (1 - tmp2 - tmp * tmp);
439 KOKKOS_INLINE_FUNCTION F
474 template <
class DeviceType>
475 static KOKKOS_INLINE_FUNCTION F
491 template <
class DeviceType>
492 KOKKOS_INLINE_FUNCTION F
493 draw(Kokkos::Random_XorShift1024<DeviceType>& gen)
const
498 KOKKOS_INLINE_FUNCTION F
505 KOKKOS_INLINE_FUNCTION F
510 KOKKOS_INLINE_FUNCTION F
529 template <
class DeviceType>
530 KOKKOS_INLINE_FUNCTION F
531 draw(Kokkos::Random_XorShift1024<DeviceType>& gen)
const
533 return Kokkos::exp(gen.normal(
mu,
sigma));
536 KOKKOS_INLINE_FUNCTION F
541 KOKKOS_INLINE_FUNCTION F
545 return (Kokkos::exp(sigma2) - 1.) * Kokkos::exp(2. *
mu + sigma2);
547 KOKKOS_INLINE_FUNCTION F
551 return (Kokkos::exp(sigma2) + 2.)
552 * Kokkos::sqrt(Kokkos::exp(sigma2) - 1.);
570 template <
class DeviceType>
571 KOKKOS_INLINE_FUNCTION F
572 draw(Kokkos::Random_XorShift1024<DeviceType>& gen)
const
575 const double Z0 = gen.normal();
576 const double Z1 = gen.normal();
578 const double scale_factor = Kokkos::sqrt(1. + delta * delta);
579 const double X = (Z0 + delta * Kokkos::abs(Z1)) / scale_factor;
582 KOKKOS_INLINE_FUNCTION F
587 * Kokkos::sqrt(2 / M_PI);
589 KOKKOS_INLINE_FUNCTION F
593 return omega *
omega * (1 - 2 * delta * delta / M_PI);
595 KOKKOS_INLINE_FUNCTION F
599 return ((4 - M_PI) * Kokkos::pow(delta * std::sqrt(2 / M_PI), 3))
600 / Kokkos::pow(1 - 2 * delta * delta / M_PI, 1.5);
604 static_assert(ProbabilityLaw<SkewNormal<float>, float,
ComputeSpace>);
619 template <
class DeviceType>
620 KOKKOS_INLINE_FUNCTION F
621 draw(Kokkos::Random_XorShift1024<DeviceType>& gen)
const
623 const float rnd = gen.frand();
627 [[nodiscard]] KOKKOS_INLINE_FUNCTION F
632 [[nodiscard]] KOKKOS_INLINE_FUNCTION F
638 [[nodiscard]] KOKKOS_INLINE_FUNCTION F
645 static_assert(ProbabilityLaw<Exponential<float>, float,
ComputeSpace>);
Concept for probability distribution laws.
Definition prng_extension.hpp:52
KOKKOS_INLINE_FUNCTION float _ln(float x)
Definition maths.hpp:11
Kokkos-compatible methods to draw from specific probability distributions.
Definition prng_extension.hpp:18
KOKKOS_INLINE_FUNCTION F erfinv(F x)
Computes an approximation of the inverse error function.
Definition prng_extension.hpp:82
KOKKOS_INLINE_FUNCTION F norminv(F p, F mean, F stddev)
Computes the inverse CDF (probit function) of a normal distribution.
Definition prng_extension.hpp:138
KOKKOS_INLINE_FUNCTION F std_normal_pdf(F x)
Computes the standard normal probability density function (PDF).
Definition prng_extension.hpp:165
KOKKOS_INLINE_FUNCTION F std_normal_cdf(F x)
Computes the standard normal cumulative distribution function (CDF).
Definition prng_extension.hpp:189
Kokkos::DefaultExecutionSpace ComputeSpace
Definition alias.hpp:22
Represents an exponential probability distribution.
Definition prng_extension.hpp:614
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:639
static constexpr bool use_kokkos_log
Definition prng_extension.hpp:617
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:628
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:633
F lambda
Definition prng_extension.hpp:615
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:621
Represents a LogNormal probability distribution (the exponential of a Gaussian variable with parameters mu and sigma).
Definition prng_extension.hpp:525
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:531
F sigma
Definition prng_extension.hpp:527
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:537
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:548
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:542
F mu
Definition prng_extension.hpp:526
Represents a normal (Gaussian) probability distribution.
Definition prng_extension.hpp:271
F mu
Mean.
Definition prng_extension.hpp:272
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Draws a random sample from the distribution.
Definition prng_extension.hpp:283
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F mu, F sigma)
Static method to draw a sample from N(μ, σ²), i.e. a normal distribution with mean μ and standard deviation σ.
Definition prng_extension.hpp:298
KOKKOS_INLINE_FUNCTION F mean() const
Returns the mean of the distribution.
Definition prng_extension.hpp:308
KOKKOS_INLINE_FUNCTION F skewness() const
Returns the skewness of the distribution.
Definition prng_extension.hpp:338
KOKKOS_INLINE_FUNCTION F stddev() const
Returns the standard deviation of the distribution.
Definition prng_extension.hpp:328
KOKKOS_INLINE_FUNCTION F var() const
Returns the variance of the distribution.
Definition prng_extension.hpp:318
F sigma
Standard deviation.
Definition prng_extension.hpp:273
F inverse_factor
Definition prng_extension.hpp:462
F scale_factor
Definition prng_extension.hpp:461
TruncatedNormal< F > dist
Definition prng_extension.hpp:463
constexpr ScaledTruncatedNormal(F factor, F m, F s, F l, F u)
Definition prng_extension.hpp:465
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:511
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:506
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F factor, F mu, F sigma, F lower, F upper)
Definition prng_extension.hpp:476
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:493
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:499
Represents a SkewNormal probability distribution (a skewed generalization of the Gaussian).
Definition prng_extension.hpp:565
F alpha
Definition prng_extension.hpp:568
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:596
F omega
Definition prng_extension.hpp:567
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:590
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:583
F xi
Definition prng_extension.hpp:566
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:572
Represents a TruncatedNormal probability distribution (a Gaussian restricted to the interval [lower, upper]).
Definition prng_extension.hpp:354
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:419
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:440
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:373
KOKKOS_INLINE_FUNCTION constexpr TruncatedNormal(F m, F s, F l, F u)
Definition prng_extension.hpp:363
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:428
F upper
Definition prng_extension.hpp:359
F lower
Definition prng_extension.hpp:358
F sigma
Definition prng_extension.hpp:357
F mu
Definition prng_extension.hpp:356
TruncatedNormal()=default
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F mu, F sigma, F lower, F upper)
Definition prng_extension.hpp:380