1#ifndef __PRNG_EXTENSION_HPP__
2#define __PRNG_EXTENSION_HPP__
4#include "Kokkos_Macros.hpp"
5#include "common/maths.hpp"
7#include <Kokkos_Core.hpp>
8#include <Kokkos_MathematicalConstants.hpp>
9#include <Kokkos_Random.hpp>
11#include <common/traits.hpp>
49 template <
typename T,
typename F,
class DeviceType>
52 requires(
const T& obj, Kokkos::Random_XorShift1024<DeviceType>& gen) {
53 { obj.draw(gen) } -> std::same_as<F>;
54 { obj.mean() } -> std::same_as<F>;
55 { obj.var() } -> std::same_as<F>;
56 { obj.skewness() } -> std::same_as<F>;
78 template <FloatingPo
intType F> KOKKOS_INLINE_FUNCTION F
erfinv(F x)
84 constexpr F a = 0.147;
85 constexpr F inv_a = 1. / a;
86 constexpr F tmp = (2 / (M_PI * a));
87 const double ln1mx2 = Kokkos::log((1. - x) * (1. + x));
88 const F term1 = tmp + (0.5 * ln1mx2);
89 const F term2 = inv_a * ln1mx2;
90 return Kokkos::copysign(
91 Kokkos::sqrt(Kokkos::sqrt(term1 * term1 - term2) - term1), x);
132 template <FloatingPo
intType F>
133 KOKKOS_INLINE_FUNCTION F
norminv(F p, F mean, F stddev)
135 constexpr F erfinv_lb = -5;
136 constexpr F erfinv_up = 5;
137 auto clamped_inverse =
138 Kokkos::clamp(
erfinv(2 * p - 1), erfinv_lb, erfinv_up);
140 Kokkos::isfinite(stddev * Kokkos::numbers::sqrt2 * clamped_inverse));
141 return mean + stddev * Kokkos::numbers::sqrt2 * clamped_inverse;
161 constexpr double inv_sqrt_2_pi = 0.3989422804014327;
162 return inv_sqrt_2_pi * Kokkos::exp(-0.5 * x * x);
183 return 0.5 * (1 + Kokkos::erf(x / Kokkos::numbers::sqrt2));
200 template <
class DeviceType>
201 KOKKOS_INLINE_FUNCTION F
202 draw(Kokkos::Random_XorShift1024<DeviceType>& gen)
const
215 template <
class DeviceType>
216 static KOKKOS_INLINE_FUNCTION F
219 return gen.drand(
min,
max);
226 KOKKOS_INLINE_FUNCTION F
mean()
const
228 return (
min +
max) / F(2);
235 KOKKOS_INLINE_FUNCTION F
var()
const
258 template <FloatingPo
intType F>
struct Normal
269 template <
class DeviceType>
270 KOKKOS_INLINE_FUNCTION F
271 draw(Kokkos::Random_XorShift1024<DeviceType>& gen)
const
284 template <
class DeviceType>
285 static KOKKOS_INLINE_FUNCTION F
295 KOKKOS_INLINE_FUNCTION F
mean()
const
304 KOKKOS_INLINE_FUNCTION F
var()
const
355 template <
class DeviceType>
356 KOKKOS_INLINE_FUNCTION F
357 draw(Kokkos::Random_XorShift1024<DeviceType>& gen)
const
362 template <
class DeviceType>
363 static KOKKOS_INLINE_FUNCTION F
371 const F rand =
static_cast<F
>(gen.drand());
377 F zl = Kokkos::clamp(
379 F zu = Kokkos::clamp(
382 F pl = 0.5 * Kokkos::erfc(-zl / Kokkos::numbers::sqrt2);
384 Kokkos::isfinite(pl) &&
385 "Truncated normal draw leads to Nan of Inf with given parameters");
386 F pu = 0.5 * Kokkos::erfc(-zu / Kokkos::numbers::sqrt2);
388 Kokkos::isfinite(pu) &&
389 "Truncated normal draw leads to Nan of Inf with given parameters");
390 F p = rand * (pu - pl) + pl;
393 Kokkos::isfinite(x) &&
394 "Truncated normal draw leads to Nan of Inf with given parameters");
400 KOKKOS_INLINE_FUNCTION F
mean()
const
408 KOKKOS_INLINE_FUNCTION F
var()
const
415 const auto tmp = (phi_a - phi_b) / Z;
416 const auto tmp2 = (alpha * phi_a - beta * phi_b) / Z;
417 return sigma *
sigma * (1 - tmp2 - tmp * tmp);
452 template <
class DeviceType>
453 static KOKKOS_INLINE_FUNCTION F
469 template <
class DeviceType>
470 KOKKOS_INLINE_FUNCTION F
471 draw(Kokkos::Random_XorShift1024<DeviceType>& gen)
const
476 KOKKOS_INLINE_FUNCTION F
mean()
const
482 KOKKOS_INLINE_FUNCTION F
var()
const
504 template <
class DeviceType>
505 KOKKOS_INLINE_FUNCTION F
506 draw(Kokkos::Random_XorShift1024<DeviceType>& gen)
const
508 return Kokkos::exp(gen.normal(
mu,
sigma));
511 KOKKOS_INLINE_FUNCTION F
mean()
const
515 KOKKOS_INLINE_FUNCTION F
var()
const
518 return (Kokkos::exp(sigma2) - 1) * Kokkos::exp(2 *
mu + sigma2);
523 return (Kokkos::exp(sigma2) + 2) * Kokkos::sqrt(Kokkos::exp(sigma2) - 1);
541 template <
class DeviceType>
542 KOKKOS_INLINE_FUNCTION F
543 draw(Kokkos::Random_XorShift1024<DeviceType>& gen)
const
546 const double Z0 = gen.normal();
547 const double Z1 = gen.normal();
549 const double scale_factor = Kokkos::sqrt(1. + delta * delta);
550 const double X = (Z0 + delta * Kokkos::abs(Z1)) / scale_factor;
553 KOKKOS_INLINE_FUNCTION F
mean()
const
556 Kokkos::sqrt(2 / M_PI);
558 KOKKOS_INLINE_FUNCTION F
var()
const
561 return omega *
omega * (1 - 2 * delta * delta / M_PI);
566 return ((4 - M_PI) * Kokkos::pow(delta * std::sqrt(2 / M_PI), 3)) /
567 Kokkos::pow(1 - 2 * delta * delta / M_PI, 1.5);
571 static_assert(ProbabilityLaw<SkewNormal<float>, float,
ComputeSpace>);
586 template <
class DeviceType>
587 KOKKOS_INLINE_FUNCTION F
588 draw(Kokkos::Random_XorShift1024<DeviceType>& gen)
const
590 const float rnd = gen.frand();
594 [[nodiscard]] KOKKOS_INLINE_FUNCTION F
mean()
const
598 [[nodiscard]] KOKKOS_INLINE_FUNCTION F
var()
const
603 [[nodiscard]] KOKKOS_INLINE_FUNCTION F
skewness()
const
609 static_assert(ProbabilityLaw<Exponential<float>, float,
ComputeSpace>);
Concept for probability distribution laws: a type satisfies it when it provides draw(gen), mean(), var() and skewness(), each returning F.
Definition prng_extension.hpp:50
KOKKOS_INLINE_FUNCTION float _ln(float x)
Definition maths.hpp:10
Kokkos-compatible methods to draw from specific probability distributions.
Definition prng_extension.hpp:17
KOKKOS_INLINE_FUNCTION F erfinv(F x)
Computes a closed-form approximation of the inverse error function (Winitzki's approximation with a = 0.147).
Definition prng_extension.hpp:78
KOKKOS_INLINE_FUNCTION F norminv(F p, F mean, F stddev)
Computes the inverse CDF (probit function) of a normal distribution.
Definition prng_extension.hpp:133
KOKKOS_INLINE_FUNCTION F std_normal_pdf(F x)
Computes the standard normal probability density function (PDF).
Definition prng_extension.hpp:158
KOKKOS_INLINE_FUNCTION F std_normal_cdf(F x)
Computes the standard normal cumulative distribution function (CDF).
Definition prng_extension.hpp:180
Kokkos::DefaultExecutionSpace ComputeSpace
Definition alias.hpp:21
Represents an exponential probability distribution.
Definition prng_extension.hpp:581
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:603
static constexpr bool use_kokkos_log
Definition prng_extension.hpp:584
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:594
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:598
F lambda
Definition prng_extension.hpp:582
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:588
Represents a log-normal probability distribution, i.e. the distribution of exp(X) where X is normal with parameters mu and sigma.
Definition prng_extension.hpp:500
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:506
F sigma
Definition prng_extension.hpp:502
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:511
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:520
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:515
F mu
Definition prng_extension.hpp:501
Represents a normal (Gaussian) probability distribution.
Definition prng_extension.hpp:259
F mu
Mean.
Definition prng_extension.hpp:260
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Draws a random sample from the distribution.
Definition prng_extension.hpp:271
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F mu, F sigma)
Static method to draw a sample from N(μ, σ).
Definition prng_extension.hpp:286
KOKKOS_INLINE_FUNCTION F mean() const
Returns the mean of the distribution.
Definition prng_extension.hpp:295
KOKKOS_INLINE_FUNCTION F skewness() const
Returns the skewness of the distribution.
Definition prng_extension.hpp:322
KOKKOS_INLINE_FUNCTION F stddev() const
Returns the standard deviation of the distribution.
Definition prng_extension.hpp:313
KOKKOS_INLINE_FUNCTION F var() const
Returns the variance of the distribution.
Definition prng_extension.hpp:304
F sigma
Standard deviation.
Definition prng_extension.hpp:261
F inverse_factor
Definition prng_extension.hpp:440
F scale_factor
Definition prng_extension.hpp:439
TruncatedNormal< F > dist
Definition prng_extension.hpp:441
constexpr ScaledTruncatedNormal(F factor, F m, F s, F l, F u)
Definition prng_extension.hpp:443
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:486
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:482
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F factor, F mu, F sigma, F lower, F upper)
Definition prng_extension.hpp:454
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:471
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:476
Represents a skew-normal probability distribution (location xi, scale omega, shape alpha).
Definition prng_extension.hpp:536
F alpha
Definition prng_extension.hpp:539
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:563
F omega
Definition prng_extension.hpp:538
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:558
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:553
F xi
Definition prng_extension.hpp:537
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:543
Represents a normal (Gaussian) probability distribution truncated to the interval [lower, upper].
Definition prng_extension.hpp:338
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:400
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:419
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:357
KOKKOS_INLINE_FUNCTION constexpr TruncatedNormal(F m, F s, F l, F u)
Definition prng_extension.hpp:347
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:408
F upper
Definition prng_extension.hpp:343
F lower
Definition prng_extension.hpp:342
F sigma
Definition prng_extension.hpp:341
F mu
Definition prng_extension.hpp:340
TruncatedNormal()=default
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F mu, F sigma, F lower, F upper)
Definition prng_extension.hpp:364