Using R’s C Functions in Stan Models
Motivation
The R programming language provides robust, precise, and efficient implementations of many mathematical functions and distributions. While these can be used in Stan models by interfacing with an R session, this involves significant overhead and can result in slow sampling.
Given that the majority of R’s core functions and distributions are implemented in C, we can use Stan’s external C++ framework to call them directly, avoiding both the need for an R session and its overhead.
R Function: Quantile Functions - log inputs
When evaluating quantile functions, we often prefer to provide the input probability on the log scale to avoid numerical issues with underflow or overflow. For the present example, we will use R’s qnorm() and qt() functions to calculate the quantile function for a standard-normal distribution and a standard Student-t distribution, where the input probability is provided on the log scale.
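As a rough illustration of the numerical issue, the standalone C++ sketch below (standard library only, independent of R and Stan) checks whether a log-probability survives the round trip to the natural scale in double precision. The helper name `representable_on_natural_scale` is made up for this example.

```cpp
#include <cassert>
#include <cmath>

// Returns true when exp(logp) is still a usable probability in double
// precision, i.e. it has neither underflowed to 0 nor rounded to exactly 1.
// (Hypothetical helper for illustration; not part of R or Stan.)
bool representable_on_natural_scale(double logp) {
  assert(logp <= 0.0);  // log-probabilities are never positive
  const double p = std::exp(logp);
  return p > 0.0 && p < 1.0;
}
```

Under these assumptions, a log-probability of -800 underflows to exactly 0 when exponentiated and one of -1e-20 rounds to exactly 1, so both return false. A quantile function given either natural-scale value would return an infinite result, whereas the log-scale input still carries the information.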
The R signature for the qnorm() function is:
qnorm(p, mean, sd, lower.tail, log.p)
This maps directly to the underlying C implementation:
double qnorm(double, double, double, int, int);
Gradients
We will also need to calculate gradients with respect to the inputs if we want to use the function with parameters in a Stan model. The gradient of a quantile function with respect to the probability is the reciprocal of the density evaluated at the quantile, and the adjustment for inputs on the log scale follows from the chain rule:
\[ \begin{aligned} \frac{\text{d}}{\text{d}p}F^{-1}(e^p) &= \left.\frac{\text{d}}{\text{d}x}F^{-1}(x)\right|_{x=e^p} \cdot \frac{\text{d}}{\text{d}p}e^p \\ &= \frac{e^p}{f(F^{-1}(e^p))} \end{aligned} \]
This means that we can also use R’s corresponding density functions, dnorm() and dt(), to define the gradients.
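Because the density in the denominator can itself underflow far in the tails, it can help to evaluate this gradient on the log scale. Writing \(q = F^{-1}(e^p)\), the same quantity is:
\[ \frac{e^p}{f(q)} = \exp\left(p - \log f(q)\right) \]
so only the log-density is needed, which dnorm() and dt() can return directly through their log argument.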
Stan - External C++
To use the density and quantile functions in our external C++, we simply include R’s math header and then define our functions as usual:
#include <stan/math.hpp>
#include <Rmath.h>

// Double-only versions: quantiles for a log-scale probability
// (trailing arguments: lower_tail = 1, log_p = 1)
double qnorm_logp(double p, std::ostream* pstream__) {
  return qnorm(p, 0, 1, 1, 1);
}
double qt_logp(double p, double df, std::ostream* pstream__) {
  return qt(p, df, 1, 1);
}

// Autodiff versions: the value comes from the double overload and the
// callback adds the gradient contribution during the reverse pass
stan::math::var qt_logp(stan::math::var p, double df, std::ostream* pstream__) {
  return stan::math::make_callback_var(
    qt_logp(p.val(), df, pstream__),
    [p, df](auto& vi) mutable {
      // Calculate gradient on log-scale for numerical stability
      p.adj() += vi.adj() * std::exp(p.val() - dt(vi.val(), df, 1));
    }
  );
}
stan::math::var qnorm_logp(stan::math::var p, std::ostream* pstream__) {
  return stan::math::make_callback_var(
    qnorm_logp(p.val(), pstream__),
    [p](auto& vi) mutable {
      // Calculate gradient on log-scale for numerical stability
      p.adj() += vi.adj() * std::exp(p.val() - dnorm(vi.val(), 0, 1, 1));
    }
  );
}
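To make the callback pattern more concrete, here is a toy reverse-mode tape in plain C++. It is only a sketch of the general idea (store the value, then add to the input's adjoint when the reverse pass runs) and is not how Stan Math is implemented. `Node`, `Tape`, `make_callback_node`, and `toy_exp` are all hypothetical names invented for this illustration.

```cpp
#include <cassert>
#include <cmath>
#include <functional>
#include <memory>
#include <vector>

// One value on the tape together with its accumulated adjoint.
struct Node {
  double val;
  double adj;
};

// A toy reverse-mode tape (illustrative only; not Stan's implementation).
struct Tape {
  std::vector<std::unique_ptr<Node>> nodes;
  std::vector<std::function<void()>> callbacks;

  // Register a plain value with a zero adjoint.
  Node* make_var(double val) {
    nodes.push_back(std::unique_ptr<Node>(new Node{val, 0.0}));
    return nodes.back().get();
  }

  // Register a result plus the callback that propagates its adjoint.
  Node* make_callback_node(double val, std::function<void(Node&)> reverse) {
    Node* out = make_var(val);
    callbacks.push_back([out, reverse]() { reverse(*out); });
    return out;
  }

  // Seed the output adjoint with 1 and run the callbacks in reverse order.
  void grad(Node* out) {
    assert(out != nullptr);
    out->adj = 1.0;
    for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it) (*it)();
  }
};

// Example function: exp(x), whose derivative equals its own value.
Node* toy_exp(Tape& tape, Node* x) {
  return tape.make_callback_node(std::exp(x->val), [x](Node& vi) {
    x->adj += vi.adj * vi.val;  // chain rule: upstream adjoint times d/dx exp(x)
  });
}
```

After calling `tape.grad(y)` on the result of `toy_exp(tape, x)`, `x->adj` holds the derivative of exp at `x->val`. The quantile functions above follow the same shape, with the derivative replaced by the log-scale expression in their callbacks.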
Stan - Stan Model
We will use the following (nonsensical) Stan model to test the values and gradients of the implementation:
functions {
  real qnorm_logp(real logp);
  real qt_logp(real logp, data real df);
}
data {
  int use_normal;
}
parameters {
  real<upper=0> log_p;
}
transformed parameters {
  real qnorm_test = qnorm_logp(log_p);
  real qt_test = qt_logp(log_p, 3);
}
model {
  target += use_normal ? qnorm_test : qt_test;
}
Stan - Compilation & Linking
For the Stan model to use the C functions from R, we need to pass additional flags to the compiler and linker when building the model. Thankfully, R has a built-in function, R.home(), which returns the directories these flags need to point to:
cpp_options = list(
  paste0("CPPFLAGS += -I", shQuote(R.home("include"))),
  paste0("LDLIBS += -L", shQuote(R.home("lib")), " -lR")
)
We can then pass these directly to cmdstanr, along with our model and external C++, for compilation:
mod <- cmdstanr::cmdstan_model("rmath.stan",
                               user_header = "ext_header.hpp",
                               stanc_options = list("allow-undefined"),
                               cpp_options = cpp_options,
                               force_recompile = TRUE)
Stan - Validation
To test our implementation, we can fit the model for a small number of iterations and check that the calculated quantiles are consistent with those returned by qnorm() and qt() in R directly. The summary below shows that the values match to the displayed precision.
fit <- mod$sample(data = list(use_normal = 0), chains = 1,
                  iter_warmup = 50, iter_sampling = 50,
                  show_messages = FALSE,
                  show_exceptions = FALSE)
Warning: 11 of 50 (22.0%) transitions ended with a divergence.
See https://mc-stan.org/misc/warnings for details.
Warning: 1 of 1 chains had an E-BFMI less than 0.3.
See https://mc-stan.org/misc/warnings for details.
fit$draws(variables = c("log_p", "qnorm_test", "qt_test")) |>
  posterior::mutate_variables(qnorm_true = qnorm(log_p, log.p=TRUE),
                              qt_true = qt(log_p, df = 3, log.p=TRUE)) |>
  posterior::summarise_draws()
# A tibble: 5 × 10
variable mean median sd mad q5 q95 rhat ess_bulk
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 log_p -1.16e-7 -1.09e-7 1.84e-8 0 -1.66e-7 -1.09e-7 1.12 5.48
2 qnorm_test 5.17e+0 5.18e+0 2.56e-2 0 5.10e+0 5.18e+0 1.12 5.48
3 qt_test 2.13e+2 2.16e+2 9.15e+0 0 1.88e+2 2.16e+2 1.12 5.48
4 qnorm_true 5.17e+0 5.18e+0 2.56e-2 0 5.10e+0 5.18e+0 1.12 5.48
5 qt_true 2.13e+2 2.16e+2 9.15e+0 0 1.88e+2 2.16e+2 1.12 5.48
# ℹ 1 more variable: ess_tail <dbl>
Next, to validate the specification of the gradients, we can use cmdstanr’s $diagnose() method to check that our gradient calculations match those from finite differencing, which also shows close agreement:
mod$diagnose(data = list(use_normal = 1))$gradients()
param_idx value model finite_diff error
1 0 -0.99888 0.275729 0.275729 -3.99911e-11
mod$diagnose(data = list(use_normal = 0))$gradients()
param_idx value model finite_diff error
1 0 -0.833686 0.143222 0.143222 6.39204e-11
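The same kind of finite-difference comparison can be mimicked outside Stan for the gradient formula itself. The sketch below is a rough standalone analogue in plain C++, not what CmdStan runs internally. It assumes a standard-normal distribution and replaces R's qnorm() with a slow bisection on the log-CDF, so it is only reliable for moderate log-probabilities. All function names are hypothetical.

```cpp
#include <cassert>
#include <cmath>

// log(sqrt(2 * pi)), the normalising constant of the standard-normal density.
const double kLogSqrt2Pi = 0.9189385332046727;

// Log of the standard-normal CDF, computed via erfc.
double log_pnorm(double x) {
  return std::log(0.5 * std::erfc(-x / std::sqrt(2.0)));
}

// Log of the standard-normal density.
double log_dnorm(double x) {
  return -0.5 * x * x - kLogSqrt2Pi;
}

// Standard-normal quantile for a log-probability, found by bisection on
// log_pnorm (dependency-free stand-in for qnorm(logp, 0, 1, 1, 1)).
double qnorm_logp_bisect(double logp) {
  assert(logp < 0.0);
  double lo = -30.0, hi = 8.0;
  for (int i = 0; i < 200; ++i) {
    const double mid = 0.5 * (lo + hi);
    if (log_pnorm(mid) < logp) lo = mid; else hi = mid;
  }
  return 0.5 * (lo + hi);
}

// Analytic gradient on the log scale: exp(logp - log f(q)).
double grad_analytic(double logp) {
  const double q = qnorm_logp_bisect(logp);
  return std::exp(logp - log_dnorm(q));
}

// Central finite-difference approximation of the same gradient.
double grad_finite_diff(double logp, double h) {
  return (qnorm_logp_bisect(logp + h) - qnorm_logp_bisect(logp - h)) / (2.0 * h);
}
```

Comparing `grad_analytic(-1.0)` with `grad_finite_diff(-1.0, 1e-5)` should show close agreement. These numbers are not directly comparable to the $diagnose() output above, which differentiates the model's log density with respect to the unconstrained parameter.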