BioCMAMC-ST
prng_extension.hpp
1#ifndef __PRNG_EXTENSION_HPP__
2#define __PRNG_EXTENSION_HPP__
3
4#include "Kokkos_Macros.hpp"
5#include "common/maths.hpp"
6#include "mc/alias.hpp"
7#include <Kokkos_Core.hpp>
8#include <Kokkos_MathematicalConstants.hpp>
9#include <Kokkos_Random.hpp>
10#include <cmath>
11#include <common/traits.hpp>
12
18{
// ProbabilityLaw: requirements on a distribution type T whose scalar type is F.
// Given a const T and a Kokkos::Random_XorShift1024<DeviceType>& generator,
// T must provide:
//   - draw(gen)   : one sample from the distribution
//   - mean()      : the distribution mean
//   - var()       : the distribution variance
//   - skewness()  : the distribution skewness
// Each must return exactly F (std::same_as, so no implicit conversions).
// NOTE(review): the head of the concept (original lines 51-52) is missing from
// this listing; the leading `&&` shows at least one more constraint precedes
// the requires-expression — presumably on F or T, confirm against the header.
50 template <typename T, typename F, class DeviceType>
53 && requires(const T& obj,
54 Kokkos::Random_XorShift1024<DeviceType>& gen) {
55 { obj.draw(gen) } -> std::same_as<F>;
56 { obj.mean() } -> std::same_as<F>;
57 { obj.var() } -> std::same_as<F>;
58 { obj.skewness() } -> std::same_as<F>;
59 };
60
80 template <FloatingPointType F>
81 KOKKOS_INLINE_FUNCTION F
82 erfinv(F x)
83 {
84
85 // Use the Winitzki’s method to calculate get an approached expression of
86 // erf(x) and inverse it
87
88 constexpr F a = 0.147;
89 constexpr F inv_a = 1. / a;
90 constexpr F tmp = (2 / (M_PI * a));
91 const double ln1mx2 = Kokkos::log((1. - x) * (1. + x));
92 const F term1 = tmp + (0.5 * ln1mx2);
93 const F term2 = inv_a * ln1mx2;
94 return Kokkos::copysign(
95 Kokkos::sqrt(Kokkos::sqrt(term1 * term1 - term2) - term1), x);
96 }
97
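98 // Unused draft of an inverse error function ("rational approximation"), kept for reference only. It is not a working approximation as written: the assertion accepts only |x| >= 1 (the opposite of the valid domain) and the coefficients look like placeholders. Use erfinv() above.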
99 // template <FloatingPointType F> KOKKOS_INLINE_FUNCTION constexpr F
100 // erf_inv(F x)
101 // {
102 // constexpr F a[4] = {0.147, 0.147, 0.147, 0.147};
103 // constexpr F b[4] = {-1.0, 0.5, -0.5, 1.0};
104 // KOKKOS_ASSERT(x <= -1.0 || x >= 1.0);
105
106 // F z = (x < 0.0) ? -x : x;
107
108 // F t = 2.0 / (Kokkos::numbers::pi * 0.147) + 0.5 * Kokkos::log(1.0 - z);
109 // F result = a[0] * Kokkos::pow(t, b[0]) +
110 // a[1] * Kokkos::pow(t, b[1]) +
111 // a[2] * Kokkos::pow(t, b[2]) +
112 // a[3] * Kokkos::pow(t, b[3]);
113 // KOKKOS_ASSERT(Kokkos::isfinite(result));
114 // return result;
115 // }
116
136 template <FloatingPointType F>
137 KOKKOS_INLINE_FUNCTION F
138 norminv(F p, F mean, F stddev)
139 {
140 constexpr F erfinv_lb = -5;
141 constexpr F erfinv_up = 5;
142 auto clamped_inverse
143 = Kokkos::clamp(erfinv(2 * p - 1), erfinv_lb, erfinv_up);
144 KOKKOS_ASSERT(
145 Kokkos::isfinite(stddev * Kokkos::numbers::sqrt2 * clamped_inverse));
146 return mean + stddev * Kokkos::numbers::sqrt2 * clamped_inverse;
147 }
148
163 template <FloatingPointType F>
164 KOKKOS_INLINE_FUNCTION F
166 {
167 // NOLINTBEGIN(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
168 constexpr double inv_sqrt_2_pi = 0.3989422804014327; // 1/sqrt(2pi)
169 return inv_sqrt_2_pi * Kokkos::exp(-0.5 * x * x);
170 // NOLINTEND(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
171 }
172
187 template <FloatingPointType F>
188 KOKKOS_INLINE_FUNCTION F
190 {
191 // NOLINTBEGIN(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
192 return 0.5 * (1 + Kokkos::erf(x / Kokkos::numbers::sqrt2));
193 // NOLINTEND(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
194 }
195
196 /*DISTRIBUTIONS*/
197
198 template <FloatingPointType F> struct Uniform
199 {
200 F min;
201 F max;
202
209 template <class DeviceType>
210 KOKKOS_INLINE_FUNCTION F
211 draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
212 {
213 return draw_from(gen, min, max);
214 }
215
224 template <class DeviceType>
225 static KOKKOS_INLINE_FUNCTION F
226 draw_from(Kokkos::Random_XorShift1024<DeviceType>& gen, F min, F max)
227 {
228 return gen.drand(min, max);
229 }
230
235 KOKKOS_INLINE_FUNCTION F
236 mean() const
237 {
238 return (min + max) / F(2);
239 }
240
245 KOKKOS_INLINE_FUNCTION F
246 var() const
247 {
248 return (max - min) * (max - min) / F(12);
249 }
250
255 KOKKOS_INLINE_FUNCTION F
256 skewness() const
257 {
258 return F(0);
259 }
260 };
261
270 template <FloatingPointType F> struct Normal
271 {
272 F mu;
274
281 template <class DeviceType>
282 KOKKOS_INLINE_FUNCTION F
283 draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
284 {
285 return draw_from(gen, mu, sigma);
286 }
287
296 template <class DeviceType>
297 static KOKKOS_INLINE_FUNCTION F
298 draw_from(Kokkos::Random_XorShift1024<DeviceType>& gen, F mu, F sigma)
299 {
300 return gen.normal(mu, sigma);
301 }
302
307 KOKKOS_INLINE_FUNCTION F
308 mean() const
309 {
310 return mu;
311 }
312
317 KOKKOS_INLINE_FUNCTION F
318 var() const
319 {
320 return sigma * sigma;
321 }
322
327 KOKKOS_INLINE_FUNCTION F
328 stddev() const
329 {
330 return sigma;
331 }
332
337 KOKKOS_INLINE_FUNCTION F
338 skewness() const
339 {
340 return F(0);
341 }
342 };
343
353 template <FloatingPointType F> struct TruncatedNormal
354 {
355
356 F mu; // Mean
357 F sigma; // Standard deviation
358 F lower; // Standard deviation
359 F upper; // Standard deviation
360
361 TruncatedNormal() = default;
362
363 KOKKOS_INLINE_FUNCTION constexpr TruncatedNormal(F m, F s, F l, F u)
364 : mu(m), sigma(s), lower(l), upper(u)
365 {
366 X_ASSERT(mu > lower);
367 X_ASSERT(mu < upper);
368 KOKKOS_ASSERT(mu > lower && mu < upper);
369 }
370
371 template <class DeviceType>
372 KOKKOS_INLINE_FUNCTION F
373 draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
374 {
375 return draw_from(gen, mu, sigma, lower, upper);
376 }
377
378 template <class DeviceType>
379 static KOKKOS_INLINE_FUNCTION F
380 draw_from(Kokkos::Random_XorShift1024<DeviceType>& gen,
381 F mu,
382 F sigma,
383 F lower,
384 F upper)
385 {
386 // NOLINTBEGIN(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
387 const F rand = static_cast<F>(gen.drand());
388
389 // Max bounded because if sigma <<1 z -> Inf not wanted because
390 // erf/erfc/erfinv are not stable for extrem value Min bounded if
391 // |mu-bound| <<1 z -> 0 which is also not wanted for error function
392
393 F zl = Kokkos::clamp((lower - mu) / sigma,
394 F(-5e3),
395 F(0)); // upper-mu is by defintion <0
396 F zu = Kokkos::clamp((upper - mu) / sigma,
397 F(0),
398 F(5e3)); // upper-mu is by defintion >0
399
400 F pl = 0.5 * Kokkos::erfc(-zl / Kokkos::numbers::sqrt2);
401 KOKKOS_ASSERT(
402 Kokkos::isfinite(pl)
403 && "Truncated normal draw leads to Nan of Inf with given parameters");
404 F pu = 0.5 * Kokkos::erfc(-zu / Kokkos::numbers::sqrt2);
405 KOKKOS_ASSERT(
406 Kokkos::isfinite(pu)
407 && "Truncated normal draw leads to Nan of Inf with given parameters");
408 F p = rand * (pu - pl) + pl;
409 F x = norminv(p, mu, sigma);
410 KOKKOS_ASSERT(
411 Kokkos::isfinite(x)
412 && "Truncated normal draw leads to Nan of Inf with given parameters");
413 return x;
414
415 // NOLINTEND(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
416 }
417
418 KOKKOS_INLINE_FUNCTION F
419 mean() const
420 {
421 const F alpha = (lower - mu) / sigma;
422 const F beta = (upper - mu) / sigma;
423 F Z = std_normal_cdf(beta) - std_normal_cdf(alpha);
424 return mu + sigma * (std_normal_pdf(alpha) - std_normal_pdf(beta)) / Z;
425 }
426
427 KOKKOS_INLINE_FUNCTION F
428 var() const
429 {
430 const F alpha = (lower - mu) / sigma;
431 const F beta = (upper - mu) / sigma;
432 F Z = std_normal_cdf(beta) - std_normal_cdf(alpha);
433 const auto phi_b = std_normal_pdf(beta);
434 const auto phi_a = std_normal_pdf(alpha);
435 const auto tmp = (phi_a - phi_b) / Z;
436 const auto tmp2 = (alpha * phi_a - beta * phi_b) / Z;
437 return sigma * sigma * (1 - tmp2 - tmp * tmp);
438 }
439 KOKKOS_INLINE_FUNCTION F
440 skewness() const
441 {
442 return 0.;
443 }
444 };
445
// Truncated normal that is sampled in a rescaled space.
//
// The constructor builds `dist` as
// TruncatedNormal(scale_factor*m, s*scale_factor, scale_factor*l, scale_factor*u).
// draw(), mean() and var() map results back through inverse_factor.
// NOTE(review): the member declarations (scale_factor, inverse_factor, dist)
// and the start of the constructor's initializer list are missing from this
// listing. Presumably scale_factor == factor and inverse_factor == 1 / factor,
// since draw_from below uses 1. / factor explicitly — confirm.
// NOTE(review): factor is expected to be > 0. A negative factor swaps the
// scaled bounds, so TruncatedNormal's constructor assertion
// (mu > lower && mu < upper) would fail.
458 template <FloatingPointType F> struct ScaledTruncatedNormal
459 {
460
464
// NOTE(review): unlike the other members, this constructor is not marked
// KOKKOS_INLINE_FUNCTION. Whether it can be called from device code depends
// on the compiler's constexpr handling — TODO confirm if device-side
// construction is needed.
465 constexpr ScaledTruncatedNormal(F factor, F m, F s, F l, F u)
467 dist(scale_factor * m,
468 s * scale_factor,
469 scale_factor * l,
470 scale_factor * u)
471 {
472 }
473
// Stateless variant: draws from a truncated normal whose parameters are all
// multiplied by `factor`, then rescales the sample by 1 / factor. The callee
// sits on a line missing from this listing; presumably
// TruncatedNormal<F>::draw_from — confirm.
474 template <class DeviceType>
475 static KOKKOS_INLINE_FUNCTION F
476 draw_from(Kokkos::Random_XorShift1024<DeviceType>& gen,
477 F factor,
478 F mu,
479 F sigma,
480 F lower,
481 F upper)
482 {
483 return 1. / factor
485 factor * mu,
486 factor * sigma,
487 factor * lower,
488 factor * upper);
489 }
490
// Sample from the scaled distribution, then map it back by inverse_factor.
491 template <class DeviceType>
492 KOKKOS_INLINE_FUNCTION F
493 draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
494 {
495 return inverse_factor * dist.draw(gen);
496 }
497
// Scaling a variable by c scales its mean by c.
498 KOKKOS_INLINE_FUNCTION F
499 mean() const
500 {
501
502 return inverse_factor * dist.mean();
503 }
504
// Scaling a variable by c scales its variance by c^2.
505 KOKKOS_INLINE_FUNCTION F
506 var() const
507 {
508 return (inverse_factor * inverse_factor) * dist.var();
509 }
// NOTE(review): always returns 0. That is the true skewness only when the
// truncation is symmetric about mu; otherwise a truncated normal is skewed.
510 KOKKOS_INLINE_FUNCTION F
511 skewness() const
512 {
513 return 0.;
514 }
515 };
516
524 template <FloatingPointType F> struct LogNormal
525 {
526 F mu; // Mean
527 F sigma; // Standard deviation
528
529 template <class DeviceType>
530 KOKKOS_INLINE_FUNCTION F
531 draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
532 {
533 return Kokkos::exp(gen.normal(mu, sigma));
534 }
535
536 KOKKOS_INLINE_FUNCTION F
537 mean() const
538 {
539 return Kokkos::exp(mu + sigma * sigma / 2.);
540 }
541 KOKKOS_INLINE_FUNCTION F
542 var() const
543 {
544 const auto sigma2 = sigma * sigma;
545 return (Kokkos::exp(sigma2) - 1.) * Kokkos::exp(2. * mu + sigma2);
546 }
547 KOKKOS_INLINE_FUNCTION F
548 skewness() const
549 {
550 const auto sigma2 = sigma * sigma;
551 return (Kokkos::exp(sigma2) + 2.)
552 * Kokkos::sqrt(Kokkos::exp(sigma2) - 1.);
553 }
554 };
555
564 template <FloatingPointType F> struct SkewNormal
565 {
566 F xi; // Mean
567 F omega; // Standard deviation
569
570 template <class DeviceType>
571 KOKKOS_INLINE_FUNCTION F
572 draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
573 {
574
575 const double Z0 = gen.normal();
576 const double Z1 = gen.normal();
577 const double delta = alpha / Kokkos::sqrt(1. + alpha * alpha);
578 const double scale_factor = Kokkos::sqrt(1. + delta * delta);
579 const double X = (Z0 + delta * Kokkos::abs(Z1)) / scale_factor;
580 return xi + omega * X;
581 }
582 KOKKOS_INLINE_FUNCTION F
583 mean() const
584 {
585 return xi
586 + omega * (alpha / (Kokkos::sqrt(1 + alpha * alpha)))
587 * Kokkos::sqrt(2 / M_PI);
588 }
589 KOKKOS_INLINE_FUNCTION F
590 var() const
591 {
592 const auto delta = alpha / (Kokkos::sqrt(1 + alpha * alpha));
593 return omega * omega * (1 - 2 * delta * delta / M_PI);
594 }
595 KOKKOS_INLINE_FUNCTION F
596 skewness() const
597 {
598 const auto delta = alpha / (Kokkos::sqrt(1 + alpha * alpha));
599 return ((4 - M_PI) * Kokkos::pow(delta * std::sqrt(2 / M_PI), 3))
600 / Kokkos::pow(1 - 2 * delta * delta / M_PI, 1.5);
601 }
602 };
603
604 static_assert(ProbabilityLaw<SkewNormal<float>, float, ComputeSpace>);
605
613 template <FloatingPointType F> struct Exponential
614 {
616
617 static constexpr bool use_kokkos_log = true;
618
619 template <class DeviceType>
620 KOKKOS_INLINE_FUNCTION F
621 draw(Kokkos::Random_XorShift1024<DeviceType>& gen) const
622 {
623 const float rnd = gen.frand();
624 return F(-1) * CommonMaths::_ln<use_kokkos_log>(rnd) / lambda;
625 }
626
627 [[nodiscard]] KOKKOS_INLINE_FUNCTION F
628 mean() const
629 {
630 return F(1) / lambda;
631 }
632 [[nodiscard]] KOKKOS_INLINE_FUNCTION F
633 var() const
634 {
635
636 return F(1) / (lambda * lambda);
637 }
638 [[nodiscard]] KOKKOS_INLINE_FUNCTION F
639 skewness() const
640 {
641 return 2;
642 }
643 };
644
645 static_assert(ProbabilityLaw<Exponential<float>, float, ComputeSpace>);
646
647} // namespace MC::Distributions
648
649#endif
Definition traits.hpp:20
Concept for probability distribution laws.
Definition prng_extension.hpp:52
KOKKOS_INLINE_FUNCTION float _ln(float x)
Definition maths.hpp:11
Kokkos-compatible methods to draw from specific probability distributions.
Definition prng_extension.hpp:18
KOKKOS_INLINE_FUNCTION F erfinv(F x)
Computes an approximation of the inverse error function.
Definition prng_extension.hpp:82
KOKKOS_INLINE_FUNCTION F norminv(F p, F mean, F stddev)
Computes the inverse CDF (probit function) of a normal distribution.
Definition prng_extension.hpp:138
KOKKOS_INLINE_FUNCTION F std_normal_pdf(F x)
Computes the standard normal probability density function (PDF).
Definition prng_extension.hpp:165
KOKKOS_INLINE_FUNCTION F std_normal_cdf(F x)
Computes the standard normal cumulative distribution function (CDF).
Definition prng_extension.hpp:189
Kokkos::DefaultExecutionSpace ComputeSpace
Definition alias.hpp:22
Represents an Exponential probability distribution.
Definition prng_extension.hpp:614
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:639
static constexpr bool use_kokkos_log
Definition prng_extension.hpp:617
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:628
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:633
F lambda
Definition prng_extension.hpp:615
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:621
Represents a LogNormal probability distribution (the exponential of a Gaussian variable).
Definition prng_extension.hpp:525
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:531
F sigma
Definition prng_extension.hpp:527
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:537
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:548
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:542
F mu
Definition prng_extension.hpp:526
Represents a normal (Gaussian) probability distribution.
Definition prng_extension.hpp:271
F mu
Mean.
Definition prng_extension.hpp:272
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Draws a random sample from the distribution.
Definition prng_extension.hpp:283
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F mu, F sigma)
Static method to draw a sample from N(μ, σ).
Definition prng_extension.hpp:298
KOKKOS_INLINE_FUNCTION F mean() const
Returns the mean of the distribution.
Definition prng_extension.hpp:308
KOKKOS_INLINE_FUNCTION F skewness() const
Returns the skewness of the distribution.
Definition prng_extension.hpp:338
KOKKOS_INLINE_FUNCTION F stddev() const
Returns the standard deviation of the distribution.
Definition prng_extension.hpp:328
KOKKOS_INLINE_FUNCTION F var() const
Returns the variance of the distribution.
Definition prng_extension.hpp:318
F sigma
Standard deviation.
Definition prng_extension.hpp:273
F inverse_factor
Definition prng_extension.hpp:462
F scale_factor
Definition prng_extension.hpp:461
TruncatedNormal< F > dist
Definition prng_extension.hpp:463
constexpr ScaledTruncatedNormal(F factor, F m, F s, F l, F u)
Definition prng_extension.hpp:465
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:511
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:506
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F factor, F mu, F sigma, F lower, F upper)
Definition prng_extension.hpp:476
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:493
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:499
Represents a SkewNormal probability distribution (a Gaussian generalized by a shape parameter).
Definition prng_extension.hpp:565
F alpha
Definition prng_extension.hpp:568
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:596
F omega
Definition prng_extension.hpp:567
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:590
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:583
F xi
Definition prng_extension.hpp:566
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:572
Represents a TruncatedNormal probability distribution (a Gaussian restricted to [lower, upper]).
Definition prng_extension.hpp:354
KOKKOS_INLINE_FUNCTION F mean() const
Definition prng_extension.hpp:419
KOKKOS_INLINE_FUNCTION F skewness() const
Definition prng_extension.hpp:440
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Definition prng_extension.hpp:373
KOKKOS_INLINE_FUNCTION constexpr TruncatedNormal(F m, F s, F l, F u)
Definition prng_extension.hpp:363
KOKKOS_INLINE_FUNCTION F var() const
Definition prng_extension.hpp:428
F upper
Definition prng_extension.hpp:359
F lower
Definition prng_extension.hpp:358
F sigma
Definition prng_extension.hpp:357
F mu
Definition prng_extension.hpp:356
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F mu, F sigma, F lower, F upper)
Definition prng_extension.hpp:380
Definition prng_extension.hpp:199
KOKKOS_INLINE_FUNCTION F mean() const
Returns the mean of the distribution.
Definition prng_extension.hpp:236
F max
Max.
Definition prng_extension.hpp:201
KOKKOS_INLINE_FUNCTION F var() const
Returns the variance of the distribution.
Definition prng_extension.hpp:246
KOKKOS_INLINE_FUNCTION F draw(Kokkos::Random_XorShift1024< DeviceType > &gen) const
Draws a random sample from the distribution.
Definition prng_extension.hpp:211
static KOKKOS_INLINE_FUNCTION F draw_from(Kokkos::Random_XorShift1024< DeviceType > &gen, F min, F max)
Static method to draw a sample from U(min, max).
Definition prng_extension.hpp:226
KOKKOS_INLINE_FUNCTION F skewness() const
Returns the skewness of the distribution.
Definition prng_extension.hpp:256
F min
Min.
Definition prng_extension.hpp:200