BioCMAMC-ST
prng_extension.hpp
1#ifndef __PRNG_EXTENSION_HPP__
2#define __PRNG_EXTENSION_HPP__
3
4#include "Kokkos_Macros.hpp"
5#include "common/maths.hpp"
6#include "mc/alias.hpp"
7#include <Kokkos_Core.hpp>
8#include <Kokkos_MathematicalConstants.hpp>
9#include <Kokkos_Random.hpp>
10#include <cmath>
11#include <common/traits.hpp>
12
17{
// Concept for probability distribution laws (ProbabilityLaw).
// A type T models a law over the scalar type F when, given a const T and a
// Kokkos::Random_XorShift1024<DeviceType> generator, it exposes:
//   - draw(gen)   : one random sample,
//   - mean()      : the expectation,
//   - var()       : the variance,
//   - skewness()  : the third standardized moment,
// each returning exactly F (std::same_as: no implicit conversion accepted).
// NOTE(review): source lines 50-51 (the `concept ProbabilityLaw =` head and
// whatever follows it) are missing from this listing; the requirements shown
// below may not be the complete definition.
49 template <typename T, typename F, class DeviceType>
52 requires(const T& obj, Kokkos::Random_XorShift1024<DeviceType>& gen) {
53 { obj.draw(gen) } -> std::same_as<F>;
54 { obj.mean() } -> std::same_as<F>;
55 { obj.var() } -> std::same_as<F>;
56 { obj.skewness() } -> std::same_as<F>;
57 };
58
78 template <FloatingPointType F> KOKKOS_INLINE_FUNCTION F erfinv(F x)
79 {
80
81 // Use the Winitzki’s method to calculate get an approached expression of
82 // erf(x) and inverse it
83
84 constexpr F a = 0.147;
85 constexpr F inv_a = 1. / a;
86 constexpr F tmp = (2 / (M_PI * a));
87 const double ln1mx2 = Kokkos::log((1. - x) * (1. + x));
88 const F term1 = tmp + (0.5 * ln1mx2);
89 const F term2 = inv_a * ln1mx2;
90 return Kokkos::copysign(
91 Kokkos::sqrt(Kokkos::sqrt(term1 * term1 - term2) - term1), x);
92 }
93
94 // Inverse error function approximation (Using the rational approximation)
95 // template <FloatingPointType F> KOKKOS_INLINE_FUNCTION constexpr F
96 // erf_inv(F x)
97 // {
98 // constexpr F a[4] = {0.147, 0.147, 0.147, 0.147};
99 // constexpr F b[4] = {-1.0, 0.5, -0.5, 1.0};
100 // KOKKOS_ASSERT(x <= -1.0 || x >= 1.0);
101
102 // F z = (x < 0.0) ? -x : x;
103
104 // F t = 2.0 / (Kokkos::numbers::pi * 0.147) + 0.5 * Kokkos::log(1.0 - z);
105 // F result = a[0] * Kokkos::pow(t, b[0]) +
106 // a[1] * Kokkos::pow(t, b[1]) +
107 // a[2] * Kokkos::pow(t, b[2]) +
108 // a[3] * Kokkos::pow(t, b[3]);
109 // KOKKOS_ASSERT(Kokkos::isfinite(result));
110 // return result;
111 // }
112
132 template <FloatingPointType F>
133 KOKKOS_INLINE_FUNCTION F norminv(F p, F mean, F stddev)
134 {
135 constexpr F erfinv_lb = -5;
136 constexpr F erfinv_up = 5;
137 auto clamped_inverse =
138 Kokkos::clamp(erfinv(2 * p - 1), erfinv_lb, erfinv_up);
139 KOKKOS_ASSERT(
140 Kokkos::isfinite(stddev * Kokkos::numbers::sqrt2 * clamped_inverse));
141 return mean + stddev * Kokkos::numbers::sqrt2 * clamped_inverse;
142 }
143
158 template <FloatingPointType F> KOKKOS_INLINE_FUNCTION F std_normal_pdf(F x)
159 {
160 // NOLINTBEGIN(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
161 constexpr double inv_sqrt_2_pi = 0.3989422804014327; // 1/sqrt(2pi)
162 return inv_sqrt_2_pi * Kokkos::exp(-0.5 * x * x);
163 // NOLINTEND(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
164 }
165
180 template <FloatingPointType F> KOKKOS_INLINE_FUNCTION F std_normal_cdf(F x)
181 {
182 // NOLINTBEGIN(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
183 return 0.5 * (1 + Kokkos::erf(x / Kokkos::numbers::sqrt2));
184 // NOLINTEND(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
185 }
186
187 /*DISTRIBUTIONS*/
188
189 template <FloatingPointType F> struct Uniform
190 {
191 F min;
192 F max;
193
200 template <class DeviceType>
201 KOKKOS_INLINE_FUNCTION F
202 draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
203 {
204 return draw_from(gen, min, max);
205 }
206
215 template <class DeviceType>
216 static KOKKOS_INLINE_FUNCTION F
217 draw_from(Kokkos::Random_XorShift1024<DeviceType>& gen, F min, F max)
218 {
219 return gen.drand(min, max);
220 }
221
226 KOKKOS_INLINE_FUNCTION F mean() const
227 {
228 return (min + max) / F(2);
229 }
230
235 KOKKOS_INLINE_FUNCTION F var() const
236 {
237 return (max - min) * (max - min) / F(12);
238 }
239
244 KOKKOS_INLINE_FUNCTION F skewness() const
245 {
246 return F(0);
247 }
248 };
249
258 template <FloatingPointType F> struct Normal
259 {
260 F mu;
262
269 template <class DeviceType>
270 KOKKOS_INLINE_FUNCTION F
271 draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
272 {
273 return draw_from(gen, mu, sigma);
274 }
275
284 template <class DeviceType>
285 static KOKKOS_INLINE_FUNCTION F
286 draw_from(Kokkos::Random_XorShift1024<DeviceType>& gen, F mu, F sigma)
287 {
288 return gen.normal(mu, sigma);
289 }
290
295 KOKKOS_INLINE_FUNCTION F mean() const
296 {
297 return mu;
298 }
299
304 KOKKOS_INLINE_FUNCTION F var() const
305 {
306 return sigma * sigma;
307 }
308
313 KOKKOS_INLINE_FUNCTION F stddev() const
314 {
315 return sigma;
316 }
317
322 KOKKOS_INLINE_FUNCTION F skewness() const
323 {
324 return F(0);
325 }
326 };
327
337 template <FloatingPointType F> struct TruncatedNormal
338 {
339
340 F mu; // Mean
341 F sigma; // Standard deviation
342 F lower; // Standard deviation
343 F upper; // Standard deviation
344
345 TruncatedNormal() = default;
346
347 KOKKOS_INLINE_FUNCTION constexpr TruncatedNormal(F m, F s, F l, F u)
348 : mu(m), sigma(s), lower(l), upper(u)
349 {
350 X_ASSERT(mu > lower);
351 X_ASSERT(mu < upper);
352 KOKKOS_ASSERT(mu > lower && mu < upper);
353 }
354
355 template <class DeviceType>
356 KOKKOS_INLINE_FUNCTION F
357 draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
358 {
359 return draw_from(gen, mu, sigma, lower, upper);
360 }
361
362 template <class DeviceType>
363 static KOKKOS_INLINE_FUNCTION F
364 draw_from(Kokkos::Random_XorShift1024<DeviceType>& gen,
365 F mu,
366 F sigma,
367 F lower,
368 F upper)
369 {
370 // NOLINTBEGIN(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
371 const F rand = static_cast<F>(gen.drand());
372
373 // Max bounded because if sigma <<1 z -> Inf not wanted because
374 // erf/erfc/erfinv are not stable for extrem value Min bounded if
375 // |mu-bound| <<1 z -> 0 which is also not wanted for error function
376
377 F zl = Kokkos::clamp(
378 (lower - mu) / sigma, F(-5e3), F(0)); // upper-mu is by defintion <0
379 F zu = Kokkos::clamp(
380 (upper - mu) / sigma, F(0), F(5e3)); // upper-mu is by defintion >0
381
382 F pl = 0.5 * Kokkos::erfc(-zl / Kokkos::numbers::sqrt2);
383 KOKKOS_ASSERT(
384 Kokkos::isfinite(pl) &&
385 "Truncated normal draw leads to Nan of Inf with given parameters");
386 F pu = 0.5 * Kokkos::erfc(-zu / Kokkos::numbers::sqrt2);
387 KOKKOS_ASSERT(
388 Kokkos::isfinite(pu) &&
389 "Truncated normal draw leads to Nan of Inf with given parameters");
390 F p = rand * (pu - pl) + pl;
391 F x = norminv(p, mu, sigma);
392 KOKKOS_ASSERT(
393 Kokkos::isfinite(x) &&
394 "Truncated normal draw leads to Nan of Inf with given parameters");
395 return x;
396
397 // NOLINTEND(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
398 }
399
400 KOKKOS_INLINE_FUNCTION F mean() const
401 {
402 const F alpha = (lower - mu) / sigma;
403 const F beta = (upper - mu) / sigma;
404 F Z = std_normal_cdf(beta) - std_normal_cdf(alpha);
405 return mu + sigma * (std_normal_pdf(alpha) - std_normal_pdf(beta)) / Z;
406 }
407
408 KOKKOS_INLINE_FUNCTION F var() const
409 {
410 const F alpha = (lower - mu) / sigma;
411 const F beta = (upper - mu) / sigma;
412 F Z = std_normal_cdf(beta) - std_normal_cdf(alpha);
413 const auto phi_b = std_normal_pdf(beta);
414 const auto phi_a = std_normal_pdf(alpha);
415 const auto tmp = (phi_a - phi_b) / Z;
416 const auto tmp2 = (alpha * phi_a - beta * phi_b) / Z;
417 return sigma * sigma * (1 - tmp2 - tmp * tmp);
418 }
419 KOKKOS_INLINE_FUNCTION F skewness() const
420 {
421 return 0.;
422 }
423 };
424
// ScaledTruncatedNormal: a TruncatedNormal evaluated in scaled coordinates.
// The inner `dist` is built from (scale_factor * m, s * scale_factor,
// scale_factor * l, scale_factor * u). Results are mapped back by
// multiplying draws and means by `inverse_factor` and the variance by
// `inverse_factor^2`.
// NOTE(review): the members scale_factor, inverse_factor and dist (source
// lines 439-441) and the initializer line 444 are missing from this listing;
// inverse_factor is presumably 1 / factor -- confirm against the source.
436 template <FloatingPointType F> struct ScaledTruncatedNormal
437 {
438
442
// Constructor: presumably stores factor and its inverse (elided line 444),
// then builds the inner TruncatedNormal on the scaled parameters.
// NOTE(review): unlike TruncatedNormal's constructor, this one is not marked
// KOKKOS_INLINE_FUNCTION; it may not be callable from device code unless
// relaxed constexpr is enabled.
443 constexpr ScaledTruncatedNormal(F factor, F m, F s, F l, F u)
445 dist(scale_factor * m,
446 s * scale_factor,
447 scale_factor * l,
448 scale_factor * u)
449 {
450 }
451
// Static draw: scales every parameter by `factor`, samples, and divides the
// sample by `factor`. The callee on elided line 462 is presumably
// TruncatedNormal<F>::draw_from(gen, ...) -- confirm against the source.
452 template <class DeviceType>
453 static KOKKOS_INLINE_FUNCTION F
454 draw_from(Kokkos::Random_XorShift1024<DeviceType>& gen,
455 F factor,
456 F mu,
457 F sigma,
458 F lower,
459 F upper)
460 {
461 return 1. / factor *
463 factor * mu,
464 factor * sigma,
465 factor * lower,
466 factor * upper);
467 }
468
// Draws from the inner law and maps the sample back to original units.
469 template <class DeviceType>
470 KOKKOS_INLINE_FUNCTION F
471 draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
472 {
473 return inverse_factor * dist.draw(gen);
474 }
475
// Mean in original units: the inner mean divided back by the scale.
476 KOKKOS_INLINE_FUNCTION F mean() const
477 {
478
479 return inverse_factor * dist.mean();
480 }
481
// Variance in original units: variance scales with the square of the factor.
482 KOKKOS_INLINE_FUNCTION F var() const
483 {
484 return (inverse_factor * inverse_factor) * dist.var();
485 }
// NOTE(review): returns 0 unconditionally, although a truncated normal is
// skewed whenever its bounds are asymmetric around mu. Skewness is invariant
// under positive scaling, so delegating to dist.skewness() (sign-flipped for
// a negative factor) is the natural fix.
486 KOKKOS_INLINE_FUNCTION F skewness() const
487 {
488 return 0.;
489 }
490 };
491
499 template <FloatingPointType F> struct LogNormal
500 {
501 F mu; // Mean
502 F sigma; // Standard deviation
503
504 template <class DeviceType>
505 KOKKOS_INLINE_FUNCTION F
506 draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
507 {
508 return Kokkos::exp(gen.normal(mu, sigma));
509 }
510
511 KOKKOS_INLINE_FUNCTION F mean() const
512 {
513 return Kokkos::exp(mu + sigma * sigma / 2);
514 }
515 KOKKOS_INLINE_FUNCTION F var() const
516 {
517 const auto sigma2 = sigma * sigma;
518 return (Kokkos::exp(sigma2) - 1) * Kokkos::exp(2 * mu + sigma2);
519 }
520 KOKKOS_INLINE_FUNCTION F skewness() const
521 {
522 const auto sigma2 = sigma * sigma;
523 return (Kokkos::exp(sigma2) + 2) * Kokkos::sqrt(Kokkos::exp(sigma2) - 1);
524 }
525 };
526
535 template <FloatingPointType F> struct SkewNormal
536 {
537 F xi; // Mean
538 F omega; // Standard deviation
540
541 template <class DeviceType>
542 KOKKOS_INLINE_FUNCTION F
543 draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
544 {
545
546 const double Z0 = gen.normal();
547 const double Z1 = gen.normal();
548 const double delta = alpha / Kokkos::sqrt(1. + alpha * alpha);
549 const double scale_factor = Kokkos::sqrt(1. + delta * delta);
550 const double X = (Z0 + delta * Kokkos::abs(Z1)) / scale_factor;
551 return xi + omega * X;
552 }
553 KOKKOS_INLINE_FUNCTION F mean() const
554 {
555 return xi + omega * (alpha / (Kokkos::sqrt(1 + alpha * alpha))) *
556 Kokkos::sqrt(2 / M_PI);
557 }
558 KOKKOS_INLINE_FUNCTION F var() const
559 {
560 const auto delta = alpha / (Kokkos::sqrt(1 + alpha * alpha));
561 return omega * omega * (1 - 2 * delta * delta / M_PI);
562 }
563 KOKKOS_INLINE_FUNCTION F skewness() const
564 {
565 const auto delta = alpha / (Kokkos::sqrt(1 + alpha * alpha));
566 return ((4 - M_PI) * Kokkos::pow(delta * std::sqrt(2 / M_PI), 3)) /
567 Kokkos::pow(1 - 2 * delta * delta / M_PI, 1.5);
568 }
569 };
570
571 static_assert(ProbabilityLaw<SkewNormal<float>, float, ComputeSpace>);
572
580 template <FloatingPointType F> struct Exponential
581 {
583
584 static constexpr bool use_kokkos_log = true;
585
586 template <class DeviceType>
587 KOKKOS_INLINE_FUNCTION F
588 draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
589 {
590 const float rnd = gen.frand();
591 return F(-1) * CommonMaths::_ln<use_kokkos_log>(rnd) / lambda;
592 }
593
594 [[nodiscard]] KOKKOS_INLINE_FUNCTION F mean() const
595 {
596 return F(1) / lambda;
597 }
598 [[nodiscard]] KOKKOS_INLINE_FUNCTION F var() const
599 {
600
601 return F(1) / (lambda * lambda);
602 }
603 [[nodiscard]] KOKKOS_INLINE_FUNCTION F skewness() const
604 {
605 return 2;
606 }
607 };
608
609 static_assert(ProbabilityLaw<Exponential<float>, float, ComputeSpace>);
610
611} // namespace MC::Distributions
612
613#endif
Definition traits.hpp:20
Concept for probability distribution laws.
Definition prng_extension.hpp:50
KOKKOS_INLINE_FUNCTION float _ln(float x)
Definition maths.hpp:10
Kokkos compatible method to draw from specific probability distribution.
Definition prng_extension.hpp:17
KOKKOS_INLINE_FUNCTION F erfinv(F x)
Computes an approximation of the inverse error function.
Definition prng_extension.hpp:78
KOKKOS_INLINE_FUNCTION F norminv(F p, F mean, F stddev)
Computes the inverse CDF (probit function) of a normal distribution.
Definition prng_extension.hpp:133
KOKKOS_INLINE_FUNCTION F std_normal_pdf(F x)
Computes the standard normal probability density function (PDF).
Definition prng_extension.hpp:158
KOKKOS_INLINE_FUNCTION F std_normal_cdf(F x)
Computes the standard normal cumulative distribution function (CDF).
Definition prng_extension.hpp:180
Kokkos::DefaultExecutionSpace ComputeSpace
Definition alias.hpp:21
Represents an exponential probability distribution.
Definition prng_extension.hpp:581
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:603
static constexpr bool use_kokkos_log
Definition prng_extension.hpp:584
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:594
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:598
F lambda
Definition prng_extension.hpp:582
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:588
Represents a log-normal probability distribution (the exponential of a Gaussian variable).
Definition prng_extension.hpp:500
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:506
F sigma
Definition prng_extension.hpp:502
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:511
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:520
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:515
F mu
Definition prng_extension.hpp:501
Represents a normal (Gaussian) probability distribution.
Definition prng_extension.hpp:259
F mu
Mean.
Definition prng_extension.hpp:260
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Draws a random sample from the distribution.
Definition prng_extension.hpp:271
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F mu, F sigma)
Static method to draw a sample from N(μ, σ).
Definition prng_extension.hpp:286
KOKKOS_INLINE_FUNCTION F mean() const
Returns the mean of the distribution.
Definition prng_extension.hpp:295
KOKKOS_INLINE_FUNCTION F skewness() const
Returns the skewness of the distribution.
Definition prng_extension.hpp:322
KOKKOS_INLINE_FUNCTION F stddev() const
Returns the standard deviation of the distribution.
Definition prng_extension.hpp:313
KOKKOS_INLINE_FUNCTION F var() const
Returns the variance of the distribution.
Definition prng_extension.hpp:304
F sigma
Standard deviation.
Definition prng_extension.hpp:261
F inverse_factor
Definition prng_extension.hpp:440
F scale_factor
Definition prng_extension.hpp:439
TruncatedNormal< F > dist
Definition prng_extension.hpp:441
constexpr ScaledTruncatedNormal(F factor, F m, F s, F l, F u)
Definition prng_extension.hpp:443
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:486
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:482
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F factor, F mu, F sigma, F lower, F upper)
Definition prng_extension.hpp:454
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:471
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:476
Represents a skew-normal probability distribution.
Definition prng_extension.hpp:536
F alpha
Definition prng_extension.hpp:539
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:563
F omega
Definition prng_extension.hpp:538
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:558
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:553
F xi
Definition prng_extension.hpp:537
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:543
Represents a TruncatedNormal (Gaussian) probability distribution.
Definition prng_extension.hpp:338
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:400
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:419
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:357
KOKKOS_INLINE_FUNCTION constexpr TruncatedNormal(F m, F s, F l, F u)
Definition prng_extension.hpp:347
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:408
F upper
Definition prng_extension.hpp:343
F lower
Definition prng_extension.hpp:342
F sigma
Definition prng_extension.hpp:341
F mu
Definition prng_extension.hpp:340
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F mu, F sigma, F lower, F upper)
Definition prng_extension.hpp:364
Definition prng_extension.hpp:190
KOKKOS_INLINE_FUNCTION F mean() const
Returns the mean of the distribution.
Definition prng_extension.hpp:226
F max
Max.
Definition prng_extension.hpp:192
KOKKOS_INLINE_FUNCTION F var() const
Returns the variance of the distribution.
Definition prng_extension.hpp:235
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Draws a random sample from the distribution.
Definition prng_extension.hpp:202
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F min, F max)
Static method to draw a sample from U(min, max).
Definition prng_extension.hpp:217
KOKKOS_INLINE_FUNCTION F skewness() const
Returns the skewness of the distribution.
Definition prng_extension.hpp:244
F min
min
Definition prng_extension.hpp:191