BioCMAMC-ST
prng_extension.hpp
1#ifndef __PRNG_EXTENSION_HPP__
2#define __PRNG_EXTENSION_HPP__
3
4#include "Kokkos_Macros.hpp"
5#include "mc/alias.hpp"
6#include <Kokkos_Core.hpp>
7#include <Kokkos_MathematicalConstants.hpp>
8#include <Kokkos_Random.hpp>
9#include <cmath>
10#include <common/traits.hpp>
11
16{
47 template <typename T, typename F, class DeviceType>
49 FloatingPointType<F> && requires(const T& obj, Kokkos::Random_XorShift1024<DeviceType>& gen) {
50 { obj.draw(gen) } -> std::same_as<F>;
51 { obj.mean() } -> std::same_as<F>;
52 { obj.var() } -> std::same_as<F>;
53 { obj.skewness() } -> std::same_as<F>;
54 };
55
74 template <FloatingPointType F> KOKKOS_INLINE_FUNCTION F erfinv(F x)
75 {
76
77 // Use the Winitzki’s method to calculate get an approached expression of erf(x) and inverse it
78
79 constexpr F a = 0.147;
80 constexpr F inv_a = 1. / a;
81 constexpr F tmp = (2 / (M_PI * a));
82 const double ln1mx2 = Kokkos::log((1. - x) * (1. + x));
83 const F term1 = tmp + (0.5 * ln1mx2);
84 const F term2 = inv_a * ln1mx2;
85 return Kokkos::copysign(Kokkos::sqrt(Kokkos::sqrt(term1 * term1 - term2) - term1), x);
86 }
87
88 // Inverse error function approximation (Using the rational approximation)
89 // template <FloatingPointType F> KOKKOS_INLINE_FUNCTION constexpr F erf_inv(F x)
90 // {
91 // constexpr F a[4] = {0.147, 0.147, 0.147, 0.147};
92 // constexpr F b[4] = {-1.0, 0.5, -0.5, 1.0};
93 // KOKKOS_ASSERT(x <= -1.0 || x >= 1.0);
94
95 // F z = (x < 0.0) ? -x : x;
96
97 // F t = 2.0 / (Kokkos::numbers::pi * 0.147) + 0.5 * Kokkos::log(1.0 - z);
98 // F result = a[0] * Kokkos::pow(t, b[0]) +
99 // a[1] * Kokkos::pow(t, b[1]) +
100 // a[2] * Kokkos::pow(t, b[2]) +
101 // a[3] * Kokkos::pow(t, b[3]);
102 // KOKKOS_ASSERT(Kokkos::isfinite(result));
103 // return result;
104 // }
105
123 template <FloatingPointType F> KOKKOS_INLINE_FUNCTION F norminv(F p, F mean, F stddev)
124 {
125 constexpr F erfinv_lb = -5;
126 constexpr F erfinv_up = 5;
127 auto clamped_inverse = Kokkos::clamp(erfinv(2 * p - 1), erfinv_lb, erfinv_up);
128 KOKKOS_ASSERT(Kokkos::isfinite(stddev * Kokkos::numbers::sqrt2 * clamped_inverse));
129 return mean + stddev * Kokkos::numbers::sqrt2 * clamped_inverse;
130 }
131
146 template <FloatingPointType F> KOKKOS_INLINE_FUNCTION F std_normal_pdf(F x)
147 {
148 // NOLINTBEGIN(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
149 constexpr double inv_sqrt_2_pi = 0.3989422804014327; // 1/sqrt(2pi)
150 return inv_sqrt_2_pi * Kokkos::exp(-0.5 * x * x);
151 // NOLINTEND(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
152 }
153
166 template <FloatingPointType F> KOKKOS_INLINE_FUNCTION F std_normal_cdf(F x)
167 {
168 // NOLINTBEGIN(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
169 return 0.5 * (1 + Kokkos::erf(x / Kokkos::numbers::sqrt2));
170 // NOLINTEND(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
171 }
172
173 /*DISTRIBUTIONS*/
174
183 template <FloatingPointType F> struct Normal
184 {
185 F mu;
187
194 template <class DeviceType>
195 KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
196 {
197 return draw_from(gen, mu, sigma);
198 }
199
208 template <class DeviceType>
209 static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024<DeviceType>& gen,
210 F mu,
211 F sigma)
212 {
213 return gen.normal(mu, sigma);
214 }
215
220 KOKKOS_INLINE_FUNCTION F mean() const
221 {
222 return mu;
223 }
224
229 KOKKOS_INLINE_FUNCTION F var() const
230 {
231 return sigma * sigma;
232 }
233
238 KOKKOS_INLINE_FUNCTION F stddev() const
239 {
240 return sigma;
241 }
242
247 KOKKOS_INLINE_FUNCTION F skewness() const
248 {
249 return F(0);
250 }
251 };
252
253
262 template <FloatingPointType F> struct TruncatedNormal
263 {
264
265 F mu; // Mean
266 F sigma; // Standard deviation
267 F lower; // Standard deviation
268 F upper; // Standard deviation
269
270
271 KOKKOS_INLINE_FUNCTION constexpr TruncatedNormal(F m, F s, F l, F u)
272 : mu(m), sigma(s), lower(l), upper(u)
273 {
274 X_ASSERT(mu > lower);
275 X_ASSERT(mu < upper);
276 KOKKOS_ASSERT(mu > lower && mu < upper);
277 }
278
279 template <class DeviceType>
280 KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
281 {
282 return draw_from(gen, mu, sigma, lower, upper);
283 }
284
285 template <class DeviceType>
286 static KOKKOS_INLINE_FUNCTION F
287 draw_from(Kokkos::Random_XorShift1024<DeviceType>& gen, F mu, F sigma, F lower, F upper)
288 {
289 // NOLINTBEGIN(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
290 const F rand = static_cast<F>(gen.drand());
291
292 // Max bounded because if sigma <<1 z -> Inf not wanted because erf/erfc/erfinv are not stable
293 // for extrem value Min bounded if |mu-bound| <<1 z -> 0 which is also not wanted for error
294 // function
295
296 F zl = Kokkos::clamp((lower - mu) / sigma, F(-5e3), F(0)); // upper-mu is by defintion <0
297 F zu = Kokkos::clamp((upper - mu) / sigma, F(0), F(5e3)); // upper-mu is by defintion >0
298
299 F pl = 0.5 * Kokkos::erfc(-zl / Kokkos::numbers::sqrt2);
300 KOKKOS_ASSERT(Kokkos::isfinite(pl) &&
301 "Truncated normal draw leads to Nan of Inf with given parameters");
302 F pu = 0.5 * Kokkos::erfc(-zu / Kokkos::numbers::sqrt2);
303 KOKKOS_ASSERT(Kokkos::isfinite(pu) &&
304 "Truncated normal draw leads to Nan of Inf with given parameters");
305 F p = rand * (pu - pl) + pl;
306 F x = norminv(p, mu, sigma);
307 KOKKOS_ASSERT(Kokkos::isfinite(x) &&
308 "Truncated normal draw leads to Nan of Inf with given parameters");
309 return x;
310
311 // NOLINTEND(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
312 }
313
314 KOKKOS_INLINE_FUNCTION F mean() const
315 {
316 const F alpha = (lower - mu) / sigma;
317 const F beta = (upper - mu) / sigma;
318 F Z = std_normal_cdf(beta) - std_normal_cdf(alpha);
319 return mu + sigma * (std_normal_pdf(alpha) - std_normal_pdf(beta)) / Z;
320 }
321
322 KOKKOS_INLINE_FUNCTION F var() const
323 {
324 const F alpha = (lower - mu) / sigma;
325 const F beta = (upper - mu) / sigma;
326 F Z = std_normal_cdf(beta) - std_normal_cdf(alpha);
327 const auto phi_b = std_normal_pdf(beta);
328 const auto phi_a = std_normal_pdf(alpha);
329 const auto tmp = (phi_a - phi_b) / Z;
330 const auto tmp2 = (alpha * phi_a - beta * phi_b) / Z;
331 return sigma * sigma * (1 - tmp2 - tmp * tmp);
332 }
333 KOKKOS_INLINE_FUNCTION F skewness() const
334 {
335 return 0.;
336 }
337 };
338
339
350 template <FloatingPointType F> struct ScaledTruncatedNormal
351 {
352
356
357 constexpr ScaledTruncatedNormal(F factor, F m, F s, F l, F u)
360 {
361 }
362
363 template <class DeviceType>
364 static KOKKOS_INLINE_FUNCTION F draw_from(
365 Kokkos::Random_XorShift1024<DeviceType>& gen, F factor, F mu, F sigma, F lower, F upper)
366 {
367 return 1. / factor *
369 gen, factor * mu, factor * sigma, factor * lower, factor * upper);
370 }
371
372 template <class DeviceType>
373 KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
374 {
375 return inverse_factor * dist.draw(gen);
376 }
377
378 KOKKOS_INLINE_FUNCTION F mean() const
379 {
380
381 return inverse_factor * dist.mean();
382 }
383
384 KOKKOS_INLINE_FUNCTION F var() const
385 {
386 return (inverse_factor * inverse_factor) * dist.var();
387 }
388 KOKKOS_INLINE_FUNCTION F skewness() const
389 {
390 return 0.;
391 }
392 };
393
401 template <FloatingPointType F> struct LogNormal
402 {
403 F mu; // Mean
404 F sigma; // Standard deviation
405
406 template <class DeviceType>
407 KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
408 {
409 return Kokkos::exp(gen.normal(mu, sigma));
410 }
411
412 KOKKOS_INLINE_FUNCTION F mean() const
413 {
414 return Kokkos::exp(mu + sigma * sigma / 2);
415 }
416 KOKKOS_INLINE_FUNCTION F var() const
417 {
418 const auto sigma2 = sigma * sigma;
419 return (Kokkos::exp(sigma2) - 1) * Kokkos::exp(2 * mu + sigma2);
420 }
421 KOKKOS_INLINE_FUNCTION F skewness() const
422 {
423 const auto sigma2 = sigma * sigma;
424 return (Kokkos::exp(sigma2) + 2) * Kokkos::sqrt(Kokkos::exp(sigma2) - 1);
425 }
426 };
427
435 template <FloatingPointType F> struct SkewNormal
436 {
437 F xi; // Mean
438 F omega; // Standard deviation
440
441 template <class DeviceType>
442 KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
443 {
444
445 const double Z0 = gen.normal();
446 const double Z1 = gen.normal();
447 const double delta = alpha / Kokkos::sqrt(1. + alpha * alpha);
448 const double scale_factor = Kokkos::sqrt(1. + delta * delta);
449 const double X = (Z0 + delta * Kokkos::abs(Z1)) / scale_factor;
450 return xi + omega * X;
451 }
452 KOKKOS_INLINE_FUNCTION F mean() const
453 {
454 return xi + omega * (alpha / (Kokkos::sqrt(1 + alpha * alpha))) * Kokkos::sqrt(2 / M_PI);
455 }
456 KOKKOS_INLINE_FUNCTION F var() const
457 {
458 const auto delta = alpha / (Kokkos::sqrt(1 + alpha * alpha));
459 return omega * omega * (1 - 2 * delta * delta / M_PI);
460 }
461 KOKKOS_INLINE_FUNCTION F skewness() const
462 {
463 const auto delta = alpha / (Kokkos::sqrt(1 + alpha * alpha));
464 return ((4 - M_PI) * Kokkos::pow(delta * std::sqrt(2 / M_PI), 3)) /
465 Kokkos::pow(1 - 2 * delta * delta / M_PI, 1.5);
466 }
467 };
468
469 static_assert(ProbabilityLaw<SkewNormal<float>, float, ComputeSpace>);
470
471} // namespace MC::Distributions
472
473#endif
Definition traits.hpp:20
Concept for probability distribution laws.
Definition prng_extension.hpp:48
Kokkos compatible method to draw from specific probability distribution.
Definition prng_extension.hpp:16
KOKKOS_INLINE_FUNCTION F erfinv(F x)
Computes an approximation of the inverse error function.
Definition prng_extension.hpp:74
KOKKOS_INLINE_FUNCTION F norminv(F p, F mean, F stddev)
Computes the inverse CDF (probit function) of a normal distribution.
Definition prng_extension.hpp:123
KOKKOS_INLINE_FUNCTION F std_normal_pdf(F x)
Computes the standard normal probability density function (PDF).
Definition prng_extension.hpp:146
KOKKOS_INLINE_FUNCTION F std_normal_cdf(F x)
Computes the standard normal cumulative distribution function (CDF).
Definition prng_extension.hpp:166
Represents a log-normal probability distribution (the exponential of a Gaussian variate).
Definition prng_extension.hpp:402
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:407
F sigma
Definition prng_extension.hpp:404
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:412
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:421
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:416
F mu
Definition prng_extension.hpp:403
Represents a normal (Gaussian) probability distribution.
Definition prng_extension.hpp:184
F mu
Mean.
Definition prng_extension.hpp:185
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Draws a random sample from the distribution.
Definition prng_extension.hpp:195
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F mu, F sigma)
Static method to draw a sample from N(μ, σ).
Definition prng_extension.hpp:209
KOKKOS_INLINE_FUNCTION F mean() const
Returns the mean of the distribution.
Definition prng_extension.hpp:220
KOKKOS_INLINE_FUNCTION F skewness() const
Returns the skewness of the distribution.
Definition prng_extension.hpp:247
KOKKOS_INLINE_FUNCTION F stddev() const
Returns the standard deviation of the distribution.
Definition prng_extension.hpp:238
KOKKOS_INLINE_FUNCTION F var() const
Returns the variance of the distribution.
Definition prng_extension.hpp:229
F sigma
Standard deviation.
Definition prng_extension.hpp:186
Represents a Scaled TruncatedNormal (Gaussian) probability distribution.
Definition prng_extension.hpp:351
F inverse_factor
Definition prng_extension.hpp:354
F scale_factor
Definition prng_extension.hpp:353
TruncatedNormal< F > dist
Definition prng_extension.hpp:355
constexpr ScaledTruncatedNormal(F factor, F m, F s, F l, F u)
Definition prng_extension.hpp:357
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:388
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:384
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F factor, F mu, F sigma, F lower, F upper)
Definition prng_extension.hpp:364
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:373
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:378
Represents a skew-normal probability distribution.
Definition prng_extension.hpp:436
F alpha
Definition prng_extension.hpp:439
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:461
F omega
Definition prng_extension.hpp:438
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:456
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:452
F xi
Definition prng_extension.hpp:437
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:442
Represents a TruncatedNormal (Gaussian) probability distribution.
Definition prng_extension.hpp:263
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:314
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:333
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:280
KOKKOS_INLINE_FUNCTION constexpr TruncatedNormal(F m, F s, F l, F u)
Definition prng_extension.hpp:271
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:322
F upper
Definition prng_extension.hpp:268
F lower
Definition prng_extension.hpp:267
F sigma
Definition prng_extension.hpp:266
F mu
Definition prng_extension.hpp:265
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F mu, F sigma, F lower, F upper)
Definition prng_extension.hpp:287