BioCMAMC-ST
prng_extension.hpp
1#ifndef __PRNG_EXTENSION_HPP__
2#define __PRNG_EXTENSION_HPP__
3
4#include "Kokkos_Macros.hpp"
5#include "common/maths.hpp"
6#include "mc/alias.hpp"
7#include <Kokkos_Core.hpp>
8#include <Kokkos_MathematicalConstants.hpp>
9#include <Kokkos_Random.hpp>
10#include <cmath>
11#include <common/traits.hpp>
12
18{
// Requirements placed on a probability-law type T for a floating-point result
// type F and a Kokkos device type DeviceType.
// For a const T `obj` and a Kokkos::Random_XorShift1024<DeviceType> `gen`,
// the following expressions must be valid and return exactly F
// (std::same_as, so no implicit conversion is accepted):
//   obj.draw(gen), obj.mean(), obj.var(), obj.skewness()
// NOTE(review): the listing jumps from line 50 to line 53, so the line naming
// the concept and the constraint that precedes the leading `&&` are not
// visible here -- confirm against the original header.
50 template <typename T, typename F, class DeviceType>
53 && requires(const T& obj, Kokkos::Random_XorShift1024<DeviceType>& gen)
54 {
55 {
56 obj.draw(gen)
57 } -> std::same_as<F>;
58 {
59 obj.mean()
60 } -> std::same_as<F>;
61 {
62 obj.var()
63 } -> std::same_as<F>;
64 {
65 obj.skewness()
66 } -> std::same_as<F>;
67 };
68
88 template <FloatingPointType F>
89 KOKKOS_INLINE_FUNCTION F
90 erfinv(F x)
91 {
92
93 // Use the Winitzki’s method to calculate get an approached expression of
94 // erf(x) and inverse it
95
96 constexpr F a = 0.147;
97 constexpr F inv_a = 1. / a;
98 constexpr F tmp = (2 / (M_PI * a));
99 const double ln1mx2 = Kokkos::log((1. - x) * (1. + x));
100 const F term1 = tmp + (0.5 * ln1mx2);
101 const F term2 = inv_a * ln1mx2;
102 return Kokkos::copysign(
103 Kokkos::sqrt(Kokkos::sqrt(term1 * term1 - term2) - term1), x);
104 }
105
106 // Inverse error function approximation (Using the rational approximation)
107 // template <FloatingPointType F> KOKKOS_INLINE_FUNCTION constexpr F
108 // erf_inv(F x)
109 // {
110 // constexpr F a[4] = {0.147, 0.147, 0.147, 0.147};
111 // constexpr F b[4] = {-1.0, 0.5, -0.5, 1.0};
112 // KOKKOS_ASSERT(x <= -1.0 || x >= 1.0);
113
114 // F z = (x < 0.0) ? -x : x;
115
116 // F t = 2.0 / (Kokkos::numbers::pi * 0.147) + 0.5 * Kokkos::log(1.0 - z);
117 // F result = a[0] * Kokkos::pow(t, b[0]) +
118 // a[1] * Kokkos::pow(t, b[1]) +
119 // a[2] * Kokkos::pow(t, b[2]) +
120 // a[3] * Kokkos::pow(t, b[3]);
121 // KOKKOS_ASSERT(Kokkos::isfinite(result));
122 // return result;
123 // }
124
144 template <FloatingPointType F>
145 KOKKOS_INLINE_FUNCTION F
146 norminv(F p, F mean, F stddev)
147 {
148 constexpr F erfinv_lb = -5;
149 constexpr F erfinv_up = 5;
150 auto clamped_inverse
151 = Kokkos::clamp(erfinv(2 * p - 1), erfinv_lb, erfinv_up);
152 KOKKOS_ASSERT(
153 Kokkos::isfinite(stddev * Kokkos::numbers::sqrt2 * clamped_inverse));
154 return mean + stddev * Kokkos::numbers::sqrt2 * clamped_inverse;
155 }
156
171 template <FloatingPointType F>
172 KOKKOS_INLINE_FUNCTION F
174 {
175 // NOLINTBEGIN(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
176 constexpr double inv_sqrt_2_pi = 0.3989422804014327; // 1/sqrt(2pi)
177 return inv_sqrt_2_pi * Kokkos::exp(-0.5 * x * x);
178 // NOLINTEND(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
179 }
180
195 template <FloatingPointType F>
196 KOKKOS_INLINE_FUNCTION F
198 {
199 // NOLINTBEGIN(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
200 return 0.5 * (1 + Kokkos::erf(x / Kokkos::numbers::sqrt2));
201 // NOLINTEND(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
202 }
203
204 /*DISTRIBUTIONS*/
205
206 template <FloatingPointType F> struct Uniform
207 {
208 F min;
209 F max;
210
217 template <class DeviceType>
218 KOKKOS_INLINE_FUNCTION F
219 draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
220 {
221 return draw_from(gen, min, max);
222 }
223
232 template <class DeviceType>
233 static KOKKOS_INLINE_FUNCTION F
234 draw_from(Kokkos::Random_XorShift1024<DeviceType>& gen, F min, F max)
235 {
236 return gen.drand(min, max);
237 }
238
243 KOKKOS_INLINE_FUNCTION F
244 mean() const
245 {
246 return (min + max) / F(2);
247 }
248
253 KOKKOS_INLINE_FUNCTION F
254 var() const
255 {
256 return (max - min) * (max - min) / F(12);
257 }
258
263 KOKKOS_INLINE_FUNCTION F
264 skewness() const
265 {
266 return F(0);
267 }
268 };
269
278 template <FloatingPointType F> struct Normal
279 {
280 F mu;
282
289 template <class DeviceType>
290 KOKKOS_INLINE_FUNCTION F
291 draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
292 {
293 return draw_from(gen, mu, sigma);
294 }
295
304 template <class DeviceType>
305 static KOKKOS_INLINE_FUNCTION F
306 draw_from(Kokkos::Random_XorShift1024<DeviceType>& gen, F mu, F sigma)
307 {
308 return gen.normal(mu, sigma);
309 }
310
315 KOKKOS_INLINE_FUNCTION F
316 mean() const
317 {
318 return mu;
319 }
320
325 KOKKOS_INLINE_FUNCTION F
326 var() const
327 {
328 return sigma * sigma;
329 }
330
335 KOKKOS_INLINE_FUNCTION F
336 stddev() const
337 {
338 return sigma;
339 }
340
345 KOKKOS_INLINE_FUNCTION F
346 skewness() const
347 {
348 return F(0);
349 }
350 };
351
361 template <FloatingPointType F> struct TruncatedNormal
362 {
363
364 F mu; // Mean
365 F sigma; // Standard deviation
366 F lower; // Standard deviation
367 F upper; // Standard deviation
368
369 TruncatedNormal() = default;
370
371 KOKKOS_INLINE_FUNCTION constexpr
372 TruncatedNormal(F m, F s, F l, F u)
373 : mu(m), sigma(s), lower(l), upper(u)
374 {
375 X_ASSERT(mu > lower);
376 X_ASSERT(mu < upper);
377 KOKKOS_ASSERT(mu > lower && mu < upper);
378 }
379
380 template <class DeviceType>
381 KOKKOS_INLINE_FUNCTION F
382 draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
383 {
384 return draw_from(gen, mu, sigma, lower, upper);
385 }
386
387 template <class DeviceType>
388 static KOKKOS_INLINE_FUNCTION F
389 draw_from(Kokkos::Random_XorShift1024<DeviceType>& gen,
390 F mu,
391 F sigma,
392 F lower,
393 F upper)
394 {
395 // NOLINTBEGIN(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
396 const F rand = static_cast<F>(gen.drand());
397
398 // Max bounded because if sigma <<1 z -> Inf not wanted because
399 // erf/erfc/erfinv are not stable for extrem value Min bounded if
400 // |mu-bound| <<1 z -> 0 which is also not wanted for error function
401
402 F zl = Kokkos::clamp((lower - mu) / sigma,
403 F(-5e3),
404 F(0)); // upper-mu is by defintion <0
405 F zu = Kokkos::clamp((upper - mu) / sigma,
406 F(0),
407 F(5e3)); // upper-mu is by defintion >0
408
409 F pl = 0.5 * Kokkos::erfc(-zl / Kokkos::numbers::sqrt2);
410 KOKKOS_ASSERT(
411 Kokkos::isfinite(pl)
412 && "Truncated normal draw leads to Nan of Inf with given parameters");
413 F pu = 0.5 * Kokkos::erfc(-zu / Kokkos::numbers::sqrt2);
414 KOKKOS_ASSERT(
415 Kokkos::isfinite(pu)
416 && "Truncated normal draw leads to Nan of Inf with given parameters");
417 F p = rand * (pu - pl) + pl;
418 F x = norminv(p, mu, sigma);
419 KOKKOS_ASSERT(
420 Kokkos::isfinite(x)
421 && "Truncated normal draw leads to Nan of Inf with given parameters");
422 return x;
423
424 // NOLINTEND(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
425 }
426
427 KOKKOS_INLINE_FUNCTION F
428 mean() const
429 {
430 const F alpha = (lower - mu) / sigma;
431 const F beta = (upper - mu) / sigma;
432 F Z = std_normal_cdf(beta) - std_normal_cdf(alpha);
433 return mu + sigma * (std_normal_pdf(alpha) - std_normal_pdf(beta)) / Z;
434 }
435
436 KOKKOS_INLINE_FUNCTION F
437 var() const
438 {
439 const F alpha = (lower - mu) / sigma;
440 const F beta = (upper - mu) / sigma;
441 F Z = std_normal_cdf(beta) - std_normal_cdf(alpha);
442 const auto phi_b = std_normal_pdf(beta);
443 const auto phi_a = std_normal_pdf(alpha);
444 const auto tmp = (phi_a - phi_b) / Z;
445 const auto tmp2 = (alpha * phi_a - beta * phi_b) / Z;
446 return sigma * sigma * (1 - tmp2 - tmp * tmp);
447 }
448 KOKKOS_INLINE_FUNCTION F
449 skewness() const
450 {
451 return 0.;
452 }
453 };
454
// Truncated normal distribution evaluated on parameters multiplied by a
// scale factor: the wrapped `dist` is built from factor-scaled parameters and
// every result is multiplied back by `inverse_factor`.
// NOTE(review): listing lines 470-472 (member declarations), 476 (start of
// the constructor initializer list) and 494 (part of draw_from) are missing
// from this view. The index at the end of the file lists `F scale_factor`,
// `F inverse_factor` and `TruncatedNormal<F> dist` as the members.
467 template <FloatingPointType F> struct ScaledTruncatedNormal
468 {
469
473
// Constructor: forwards m, s, l, u, each multiplied by scale_factor, to `dist`.
// NOTE(review): unlike TruncatedNormal's constructor, this one is not marked
// KOKKOS_INLINE_FUNCTION -- confirm whether device-side construction is needed.
474 constexpr
475 ScaledTruncatedNormal(F factor, F m, F s, F l, F u)
477 dist(scale_factor * m,
478 s * scale_factor,
479 scale_factor * l,
480 scale_factor * u)
481 {
482 }
483
// Static draw without an instance: the visible part multiplies a quantity by
// 1. / factor and passes mu, sigma, lower, upper each multiplied by factor.
// NOTE(review): the call that receives these arguments is on the missing
// listing line 494.
484 template <class DeviceType>
485 static KOKKOS_INLINE_FUNCTION F
486 draw_from(Kokkos::Random_XorShift1024<DeviceType>& gen,
487 F factor,
488 F mu,
489 F sigma,
490 F lower,
491 F upper)
492 {
493 return 1. / factor
495 factor * mu,
496 factor * sigma,
497 factor * lower,
498 factor * upper);
499 }
500
// Draws from the wrapped distribution and multiplies by inverse_factor.
501 template <class DeviceType>
502 KOKKOS_INLINE_FUNCTION F
503 draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
504 {
505 return inverse_factor * dist.draw(gen);
506 }
507
// Mean: the wrapped distribution's mean multiplied by inverse_factor.
508 KOKKOS_INLINE_FUNCTION F
509 mean() const
510 {
511
512 return inverse_factor * dist.mean();
513 }
514
// Variance: the wrapped distribution's variance multiplied by inverse_factor^2.
515 KOKKOS_INLINE_FUNCTION F
516 var() const
517 {
518 return (inverse_factor * inverse_factor) * dist.var();
519 }
// Always returns 0.
520 KOKKOS_INLINE_FUNCTION F
521 skewness() const
522 {
523 return 0.;
524 }
525 };
526
534 template <FloatingPointType F> struct LogNormal
535 {
536 F mu; // Mean
537 F sigma; // Standard deviation
538
539 template <class DeviceType>
540 KOKKOS_INLINE_FUNCTION F
541 draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
542 {
543 return Kokkos::exp(gen.normal(mu, sigma));
544 }
545
546 KOKKOS_INLINE_FUNCTION F
547 mean() const
548 {
549 return Kokkos::exp(mu + sigma * sigma / 2.);
550 }
551 KOKKOS_INLINE_FUNCTION F
552 var() const
553 {
554 const auto sigma2 = sigma * sigma;
555 return (Kokkos::exp(sigma2) - 1.) * Kokkos::exp(2. * mu + sigma2);
556 }
557 KOKKOS_INLINE_FUNCTION F
558 skewness() const
559 {
560 const auto sigma2 = sigma * sigma;
561 return (Kokkos::exp(sigma2) + 2.)
562 * Kokkos::sqrt(Kokkos::exp(sigma2) - 1.);
563 }
564 };
565
574 template <FloatingPointType F> struct SkewNormal
575 {
576 F xi; // Mean
577 F omega; // Standard deviation
579
580 template <class DeviceType>
581 KOKKOS_INLINE_FUNCTION F
582 draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
583 {
584
585 const double Z0 = gen.normal();
586 const double Z1 = gen.normal();
587 const double delta = alpha / Kokkos::sqrt(1. + alpha * alpha);
588 const double scale_factor = Kokkos::sqrt(1. + delta * delta);
589 const double X = (Z0 + delta * Kokkos::abs(Z1)) / scale_factor;
590 return xi + omega * X;
591 }
592 KOKKOS_INLINE_FUNCTION F
593 mean() const
594 {
595 return xi
596 + omega * (alpha / (Kokkos::sqrt(1 + alpha * alpha)))
597 * Kokkos::sqrt(2 / M_PI);
598 }
599 KOKKOS_INLINE_FUNCTION F
600 var() const
601 {
602 const auto delta = alpha / (Kokkos::sqrt(1 + alpha * alpha));
603 return omega * omega * (1 - 2 * delta * delta / M_PI);
604 }
605 KOKKOS_INLINE_FUNCTION F
606 skewness() const
607 {
608 const auto delta = alpha / (Kokkos::sqrt(1 + alpha * alpha));
609 return ((4 - M_PI) * Kokkos::pow(delta * std::sqrt(2 / M_PI), 3))
610 / Kokkos::pow(1 - 2 * delta * delta / M_PI, 1.5);
611 }
612 };
613
// Compile-time check that SkewNormal<float> satisfies ProbabilityLaw with
// F = float and DeviceType = ComputeSpace.
614 static_assert(ProbabilityLaw<SkewNormal<float>, float, ComputeSpace>);
615
623 template <FloatingPointType F> struct Exponential
624 {
626
627 static constexpr bool use_kokkos_log = true;
628
629 template <class DeviceType>
630 KOKKOS_INLINE_FUNCTION F
631 draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
632 {
633 const float rnd = gen.frand();
634 return F(-1) * CommonMaths::_ln<use_kokkos_log>(rnd) / lambda;
635 }
636
637 [[nodiscard]] KOKKOS_INLINE_FUNCTION F
638 mean() const
639 {
640 return F(1) / lambda;
641 }
642 [[nodiscard]] KOKKOS_INLINE_FUNCTION F
643 var() const
644 {
645
646 return F(1) / (lambda * lambda);
647 }
648 [[nodiscard]] KOKKOS_INLINE_FUNCTION F
649 skewness() const
650 {
651 return 2;
652 }
653 };
654
// Compile-time check that Exponential<float> satisfies ProbabilityLaw with
// F = float and DeviceType = ComputeSpace.
655 static_assert(ProbabilityLaw<Exponential<float>, float, ComputeSpace>);
656
657} // namespace MC::Distributions
658
659#endif
Definition traits.hpp:20
Concept for probability distribution laws.
Definition prng_extension.hpp:52
KOKKOS_INLINE_FUNCTION float _ln(float x)
Definition maths.hpp:11
Kokkos compatible method to draw from specific probability distribution.
Definition prng_extension.hpp:18
KOKKOS_INLINE_FUNCTION F erfinv(F x)
Computes an approximation of the inverse error function.
Definition prng_extension.hpp:90
KOKKOS_INLINE_FUNCTION F norminv(F p, F mean, F stddev)
Computes the inverse CDF (probit function) of a normal distribution.
Definition prng_extension.hpp:146
KOKKOS_INLINE_FUNCTION F std_normal_pdf(F x)
Computes the standard normal probability density function (PDF).
Definition prng_extension.hpp:173
KOKKOS_INLINE_FUNCTION F std_normal_cdf(F x)
Computes the standard normal cumulative distribution function (CDF).
Definition prng_extension.hpp:197
Kokkos::DefaultExecutionSpace ComputeSpace
Definition alias.hpp:22
Represents an Exponential probability distribution.
Definition prng_extension.hpp:624
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:649
static constexpr bool use_kokkos_log
Definition prng_extension.hpp:627
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:638
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:643
F lambda
Definition prng_extension.hpp:625
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:631
Represents a LogNormal (Gaussian) probability distribution.
Definition prng_extension.hpp:535
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:541
F sigma
Definition prng_extension.hpp:537
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:547
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:558
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:552
F mu
Definition prng_extension.hpp:536
Represents a normal (Gaussian) probability distribution.
Definition prng_extension.hpp:279
F mu
Mean.
Definition prng_extension.hpp:280
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Draws a random sample from the distribution.
Definition prng_extension.hpp:291
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F mu, F sigma)
Static method to draw a sample from N(μ, σ).
Definition prng_extension.hpp:306
KOKKOS_INLINE_FUNCTION F mean() const
Returns the mean of the distribution.
Definition prng_extension.hpp:316
KOKKOS_INLINE_FUNCTION F skewness() const
Returns the skewness of the distribution.
Definition prng_extension.hpp:346
KOKKOS_INLINE_FUNCTION F stddev() const
Returns the standard deviation of the distribution.
Definition prng_extension.hpp:336
KOKKOS_INLINE_FUNCTION F var() const
Returns the variance of the distribution.
Definition prng_extension.hpp:326
F sigma
Standard deviation.
Definition prng_extension.hpp:281
F inverse_factor
Definition prng_extension.hpp:471
F scale_factor
Definition prng_extension.hpp:470
TruncatedNormal< F > dist
Definition prng_extension.hpp:472
constexpr ScaledTruncatedNormal(F factor, F m, F s, F l, F u)
Definition prng_extension.hpp:475
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:521
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:516
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F factor, F mu, F sigma, F lower, F upper)
Definition prng_extension.hpp:486
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:503
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:509
Represents a SkewNormal (Gaussian) probability distribution.
Definition prng_extension.hpp:575
F alpha
Definition prng_extension.hpp:578
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:606
F omega
Definition prng_extension.hpp:577
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:600
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:593
F xi
Definition prng_extension.hpp:576
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:582
Represents a TruncatedNormal (Gaussian) probability distribution.
Definition prng_extension.hpp:362
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:428
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:449
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:382
KOKKOS_INLINE_FUNCTION constexpr TruncatedNormal(F m, F s, F l, F u)
Definition prng_extension.hpp:372
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:437
F upper
Definition prng_extension.hpp:367
F lower
Definition prng_extension.hpp:366
F sigma
Definition prng_extension.hpp:365
F mu
Definition prng_extension.hpp:364
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F mu, F sigma, F lower, F upper)
Definition prng_extension.hpp:389
Definition prng_extension.hpp:207
KOKKOS_INLINE_FUNCTION F mean() const
Returns the mean of the distribution.
Definition prng_extension.hpp:244
F max
Max.
Definition prng_extension.hpp:209
KOKKOS_INLINE_FUNCTION F var() const
Returns the variance of the distribution.
Definition prng_extension.hpp:254
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Draws a random sample from the distribution.
Definition prng_extension.hpp:219
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F min, F max)
Static method to draw a sample from the uniform distribution between min and max.
Definition prng_extension.hpp:234
KOKKOS_INLINE_FUNCTION F skewness() const
Returns the skewness of the distribution.
Definition prng_extension.hpp:264
F min
min
Definition prng_extension.hpp:208