1#ifndef __PRNG_EXTENSION_HPP__
2#define __PRNG_EXTENSION_HPP__
4#include "Kokkos_Macros.hpp"
5#include "common/maths.hpp"
7#include <Kokkos_Core.hpp>
8#include <Kokkos_MathematicalConstants.hpp>
9#include <Kokkos_Random.hpp>
11#include <common/traits.hpp>
50 template <
typename T,
typename F,
class DeviceType>
53 &&
requires(
const T& obj, Kokkos::Random_XorShift1024<DeviceType>& gen)
88 template <FloatingPo
intType F>
89 KOKKOS_INLINE_FUNCTION F
96 constexpr F a = 0.147;
97 constexpr F inv_a = 1. / a;
98 constexpr F tmp = (2 / (M_PI * a));
99 const double ln1mx2 = Kokkos::log((1. - x) * (1. + x));
100 const F term1 = tmp + (0.5 * ln1mx2);
101 const F term2 = inv_a * ln1mx2;
102 return Kokkos::copysign(
103 Kokkos::sqrt(Kokkos::sqrt(term1 * term1 - term2) - term1), x);
144 template <FloatingPo
intType F>
145 KOKKOS_INLINE_FUNCTION F
148 constexpr F erfinv_lb = -5;
149 constexpr F erfinv_up = 5;
151 = Kokkos::clamp(
erfinv(2 * p - 1), erfinv_lb, erfinv_up);
153 Kokkos::isfinite(stddev * Kokkos::numbers::sqrt2 * clamped_inverse));
154 return mean + stddev * Kokkos::numbers::sqrt2 * clamped_inverse;
171 template <FloatingPo
intType F>
172 KOKKOS_INLINE_FUNCTION F
176 constexpr double inv_sqrt_2_pi = 0.3989422804014327;
177 return inv_sqrt_2_pi * Kokkos::exp(-0.5 * x * x);
195 template <FloatingPo
intType F>
196 KOKKOS_INLINE_FUNCTION F
200 return 0.5 * (1 + Kokkos::erf(x / Kokkos::numbers::sqrt2));
217 template <
class DeviceType>
218 KOKKOS_INLINE_FUNCTION F
219 draw(Kokkos::Random_XorShift1024<DeviceType>& gen)
const
232 template <
class DeviceType>
233 static KOKKOS_INLINE_FUNCTION F
236 return gen.drand(
min,
max);
243 KOKKOS_INLINE_FUNCTION F
246 return (
min +
max) / F(2);
253 KOKKOS_INLINE_FUNCTION F
263 KOKKOS_INLINE_FUNCTION F
278 template <FloatingPo
intType F>
struct Normal
289 template <
class DeviceType>
290 KOKKOS_INLINE_FUNCTION F
291 draw(Kokkos::Random_XorShift1024<DeviceType>& gen)
const
304 template <
class DeviceType>
305 static KOKKOS_INLINE_FUNCTION F
315 KOKKOS_INLINE_FUNCTION F
325 KOKKOS_INLINE_FUNCTION F
335 KOKKOS_INLINE_FUNCTION F
345 KOKKOS_INLINE_FUNCTION F
371 KOKKOS_INLINE_FUNCTION
constexpr
380 template <
class DeviceType>
381 KOKKOS_INLINE_FUNCTION F
382 draw(Kokkos::Random_XorShift1024<DeviceType>& gen)
const
387 template <
class DeviceType>
388 static KOKKOS_INLINE_FUNCTION F
396 const F rand =
static_cast<F
>(gen.drand());
409 F pl = 0.5 * Kokkos::erfc(-zl / Kokkos::numbers::sqrt2);
412 &&
"Truncated normal draw leads to Nan of Inf with given parameters");
413 F pu = 0.5 * Kokkos::erfc(-zu / Kokkos::numbers::sqrt2);
416 &&
"Truncated normal draw leads to Nan of Inf with given parameters");
417 F p = rand * (pu - pl) + pl;
421 &&
"Truncated normal draw leads to Nan of Inf with given parameters");
427 KOKKOS_INLINE_FUNCTION F
436 KOKKOS_INLINE_FUNCTION F
444 const auto tmp = (phi_a - phi_b) / Z;
445 const auto tmp2 = (alpha * phi_a - beta * phi_b) / Z;
446 return sigma *
sigma * (1 - tmp2 - tmp * tmp);
448 KOKKOS_INLINE_FUNCTION F
484 template <
class DeviceType>
485 static KOKKOS_INLINE_FUNCTION F
501 template <
class DeviceType>
502 KOKKOS_INLINE_FUNCTION F
503 draw(Kokkos::Random_XorShift1024<DeviceType>& gen)
const
508 KOKKOS_INLINE_FUNCTION F
515 KOKKOS_INLINE_FUNCTION F
520 KOKKOS_INLINE_FUNCTION F
539 template <
class DeviceType>
540 KOKKOS_INLINE_FUNCTION F
541 draw(Kokkos::Random_XorShift1024<DeviceType>& gen)
const
543 return Kokkos::exp(gen.normal(
mu,
sigma));
546 KOKKOS_INLINE_FUNCTION F
551 KOKKOS_INLINE_FUNCTION F
555 return (Kokkos::exp(sigma2) - 1.) * Kokkos::exp(2. *
mu + sigma2);
557 KOKKOS_INLINE_FUNCTION F
561 return (Kokkos::exp(sigma2) + 2.)
562 * Kokkos::sqrt(Kokkos::exp(sigma2) - 1.);
580 template <
class DeviceType>
581 KOKKOS_INLINE_FUNCTION F
582 draw(Kokkos::Random_XorShift1024<DeviceType>& gen)
const
585 const double Z0 = gen.normal();
586 const double Z1 = gen.normal();
588 const double scale_factor = Kokkos::sqrt(1. + delta * delta);
589 const double X = (Z0 + delta * Kokkos::abs(Z1)) / scale_factor;
592 KOKKOS_INLINE_FUNCTION F
597 * Kokkos::sqrt(2 / M_PI);
599 KOKKOS_INLINE_FUNCTION F
603 return omega *
omega * (1 - 2 * delta * delta / M_PI);
605 KOKKOS_INLINE_FUNCTION F
609 return ((4 - M_PI) * Kokkos::pow(delta * std::sqrt(2 / M_PI), 3))
610 / Kokkos::pow(1 - 2 * delta * delta / M_PI, 1.5);
614 static_assert(ProbabilityLaw<SkewNormal<float>, float,
ComputeSpace>);
629 template <
class DeviceType>
630 KOKKOS_INLINE_FUNCTION F
631 draw(Kokkos::Random_XorShift1024<DeviceType>& gen)
const
633 const float rnd = gen.frand();
637 [[nodiscard]] KOKKOS_INLINE_FUNCTION F
642 [[nodiscard]] KOKKOS_INLINE_FUNCTION F
648 [[nodiscard]] KOKKOS_INLINE_FUNCTION F
655 static_assert(ProbabilityLaw<Exponential<float>, float,
ComputeSpace>);
Concept for probability distribution laws.
Definition prng_extension.hpp:52
KOKKOS_INLINE_FUNCTION float _ln(float x)
Definition maths.hpp:11
Kokkos-compatible methods to draw from specific probability distributions.
Definition prng_extension.hpp:18
KOKKOS_INLINE_FUNCTION F erfinv(F x)
Computes an approximation of the inverse error function.
Definition prng_extension.hpp:90
KOKKOS_INLINE_FUNCTION F norminv(F p, F mean, F stddev)
Computes the inverse CDF (probit function) of a normal distribution.
Definition prng_extension.hpp:146
KOKKOS_INLINE_FUNCTION F std_normal_pdf(F x)
Computes the standard normal probability density function (PDF).
Definition prng_extension.hpp:173
KOKKOS_INLINE_FUNCTION F std_normal_cdf(F x)
Computes the standard normal cumulative distribution function (CDF).
Definition prng_extension.hpp:197
Kokkos::DefaultExecutionSpace ComputeSpace
Definition alias.hpp:22
Represents an exponential probability distribution.
Definition prng_extension.hpp:624
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:649
static constexpr bool use_kokkos_log
Definition prng_extension.hpp:627
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:638
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:643
F lambda
Definition prng_extension.hpp:625
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:631
Represents a log-normal probability distribution (the exponential of a normally distributed variable).
Definition prng_extension.hpp:535
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:541
F sigma
Definition prng_extension.hpp:537
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:547
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:558
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:552
F mu
Definition prng_extension.hpp:536
Represents a normal (Gaussian) probability distribution.
Definition prng_extension.hpp:279
F mu
Mean.
Definition prng_extension.hpp:280
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Draws a random sample from the distribution.
Definition prng_extension.hpp:291
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F mu, F sigma)
Static method to draw a sample from N(μ, σ).
Definition prng_extension.hpp:306
KOKKOS_INLINE_FUNCTION F mean() const
Returns the mean of the distribution.
Definition prng_extension.hpp:316
KOKKOS_INLINE_FUNCTION F skewness() const
Returns the skewness of the distribution.
Definition prng_extension.hpp:346
KOKKOS_INLINE_FUNCTION F stddev() const
Returns the standard deviation of the distribution.
Definition prng_extension.hpp:336
KOKKOS_INLINE_FUNCTION F var() const
Returns the variance of the distribution.
Definition prng_extension.hpp:326
F sigma
Standard deviation.
Definition prng_extension.hpp:281
F inverse_factor
Definition prng_extension.hpp:471
F scale_factor
Definition prng_extension.hpp:470
TruncatedNormal< F > dist
Definition prng_extension.hpp:472
constexpr ScaledTruncatedNormal(F factor, F m, F s, F l, F u)
Definition prng_extension.hpp:475
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:521
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:516
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F factor, F mu, F sigma, F lower, F upper)
Definition prng_extension.hpp:486
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:503
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:509
Represents a skew-normal probability distribution (a normal distribution generalized with a shape/skewness parameter).
Definition prng_extension.hpp:575
F alpha
Definition prng_extension.hpp:578
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:606
F omega
Definition prng_extension.hpp:577
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:600
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:593
F xi
Definition prng_extension.hpp:576
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:582
Represents a TruncatedNormal (Gaussian) probability distribution.
Definition prng_extension.hpp:362
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:428
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:449
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:382
KOKKOS_INLINE_FUNCTION constexpr TruncatedNormal(F m, F s, F l, F u)
Definition prng_extension.hpp:372
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:437
F upper
Definition prng_extension.hpp:367
F lower
Definition prng_extension.hpp:366
F sigma
Definition prng_extension.hpp:365
F mu
Definition prng_extension.hpp:364
TruncatedNormal()=default
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F mu, F sigma, F lower, F upper)
Definition prng_extension.hpp:389