# Requirements
```r
require(rstan)
require(tidyverse)
require(bayesplot)
require(ggplot2)
require(loo)
require(stats)
require(posterior)
require(ggpubr)

rng_seed_r <- 999 # rng seed for data simulation
rng_seed_stan <- 123 # rng seed for MCMC
set.seed(rng_seed_r)
```
```r
# A function to get a data frame with param draws and chain ids
get_draws <- function(fit, pars) {
  draws_3d <- rstan::extract(fit, pars = pars, permute = FALSE)
  n_draws <- dim(draws_3d)[1]
  n_chains <- dim(draws_3d)[2]
  draws_2d <- c()
  chain_id <- rep(1:n_chains, each = n_draws)
  for (j in seq_len(n_chains)) {
    draws_2d <- rbind(draws_2d, draws_3d[, j, ])
  }
  draws_2d <- cbind(as.factor(chain_id), draws_2d)
  colnames(draws_2d) <- c("chain_id", pars)
  data.frame(draws_2d)
}
```
```r
# Colored pairs plot of MCMC draws
scatter_colored <- function(df, x, y, color_by) {
  aest <- aes_string(x = x, y = y, color = color_by)
  a <- ggplot(df, aest) + geom_point() + scale_color_viridis_c()
  return(a)
}
```
The required `posterior` package is not on CRAN but can be installed via `install.packages("posterior", repos = c("https://mc-stan.org/r-packages/", getOption("repos")))`.
An ordinary differential equation (ODE)
\[\begin{equation} \frac{d \textbf{x}(t)}{dt} = f_{\theta}(\textbf{x}(t), t) \end{equation}\]
accompanied by an initial value \(\textbf{x}(t_0) = \textbf{x}_0\) is an example of an implicit function definition. If certain smoothness requirements for \(f_{\theta}\) are satisfied, there exists a unique solution for \(\textbf{x}(t)\), but it usually has no closed form. If \(\textbf{x}(t)\) needs to be evaluated at some time point \(t \neq t_0\), a numerical solver is needed.
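To make this concrete, here is a minimal sketch of a numerical ODE solve in R, using the `deSolve` package (which is not part of this case study) on the toy equation \(dx/dt = -\theta x\):

```r
# Sketch only (not used below): solve dx/dt = -theta * x numerically on a
# grid of output times with the deSolve package.
library(deSolve)
rhs <- function(t, x, theta) list(-theta * x) # right-hand side f_theta
sol <- deSolve::ode(y = c(x = 1), times = seq(0, 5, by = 0.1),
                    func = rhs, parms = 0.5)
head(sol)
```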
It is common in Bayesian nonlinear ODE models that evaluating the likelihood requires solving \(\textbf{x}(t)\) numerically at several time points. When performing Bayesian inference for the parameters \(\theta\) (and possibly other model parameters) using Stan, the system needs to be solved numerically on each log posterior probability evaluation. Furthermore, Stan needs to compute gradients for these solutions. Try as we might, these computations are often expensive and frequently become the limiting factor in whatever model they are involved in. The same problem can also occur with any other type of model that requires numerically solving an implicitly defined function or variable; partial differential equation (PDE) models are an example.
The numerical methods for ODEs and PDEs usually involve some discretization in space/time, which affects the solver accuracy. Denser discretization means more accuracy but also more computation. Alternatively, methods can estimate their error and adapt their step size so that a given tolerance is achieved. The latter is what the built-in ODE solvers in Stan do.
The simplest things we might do to speed up our calculations are increasing the timestep, coarsening the discretization, or increasing the tolerance of the solvers. That immediately leaves us with the question: is this okay? Has changing the numerical method affected our parameter inference results? Was our original method giving correct inference results to begin with? Are the default tolerances in Stan suitable for the problem at hand?
The solution provided by a numerical method is always an approximation to the true solution \(\textbf{x}(t)\). This is why our posterior probability density evaluations are also approximations, and the whole MCMC inference can be thought of as biased to some degree. However, we can consider the inference results correct if making the numerical method more accurate does not affect the statistical properties of the posterior draws.
How can something like this be checked? The first problem is that it might not be computationally very attractive to run sampling repeatedly, gradually increasing the solver accuracy. If the model parameters are fixed, we can verify that the solution at all points in space/time is appropriately close to a more accurate reference solution. That isn’t so much of a problem in and of itself, but we are doing statistics, and so we need to know that the solution is accurate enough across all relevant parts of parameter space. Additionally, it is not known beforehand where the “relevant parts of parameter space” are!
The problem of validating the use of a numerical method for a Bayesian model is therefore significantly more complicated than in the classical numerical analysis world. The point of this case study is to show how, by adding one additional tool, namely Pareto-smoothed importance sampling (PSIS) (Yao et al. 2018; Vehtari et al. 2019), we can solve this problem.
Let \(M\) be the model for which we would like to perform inference, but which we cannot evaluate, since the likelihood is defined implicitly through an ODE or PDE system that is not analytically tractable. MCMC inference for \(M\) can actually be seen as inference for another model \(M_{high}\), which is the same model as \(M\) but with a numerical solver in place of the exact solution, and which can therefore be evaluated.
Our workflow addresses the problem of defining the high-precision numerical method in \(M_{high}\) so that \(M_{high}\) can be trusted to have essentially the same posterior as \(M\). We define a way to perform inference for \(M_{high}\) without needing to compute gradients or HMC trajectories for it. This involves another model \(M_{low}\), which is again the same model, except that \(M_{low}\) uses a cheaper and less accurate numerical method (or just looser tolerances and/or a coarser discretization) to compute the required ODE or PDE solutions, and is therefore faster to fit. The posterior densities are denoted \(p_{low}\) and \(p_{high}\), respectively.
To understand how PSIS comes into play, we must first discuss importance sampling. If we want to compute expectations under the high precision model, we can take draws from the low precision model and reweight them according to the importance weights \(\frac{p_{high}}{p_{low}}\). If the two models are too different, the reweighting will produce noisy estimates that are not useful. PSIS, and particularly the Pareto \(k\) diagnostic (denoted \(\hat{k}\)), is the tool that tells us when we can or cannot rely on the importance weights. If \(\hat{k} < 0.5\) we are safe to do the importance sampling, if \(0.5 \leq \hat{k} < 0.7\) the importance sampling will start to converge more slowly, and if \(\hat{k} \geq 0.7\) the importance sampling estimates are unreliable. For simplicity we will only consider the \(\hat{k} < 0.5\) threshold.
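To make this concrete, the check amounts to just a few lines; a minimal sketch, where `log_p_high()`, `log_p_low()`, and the data frame `draws_low` are hypothetical placeholders:

```r
# Sketch of the PSIS check. log_p_high() and log_p_low() are hypothetical
# functions evaluating the two (unnormalized) log posterior densities at one
# draw; draws_low holds the MCMC draws from M_low.
log_ratios <- sapply(seq_len(nrow(draws_low)), function(i) {
  log_p_high(draws_low[i, ]) - log_p_low(draws_low[i, ])
})
psis_fit <- loo::psis(log_ratios)
psis_fit$diagnostics$pareto_k # reweighting is reliable if this is < 0.5
```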
Ideally, \(M_{high}\) would involve a numerical method that we can trust completely in all parts of the parameter space so that, as long as \(\hat{k} < 0.5\), importance weights can be used to reweight the low precision approximation \(p_{low}\) to the high precision approximation \(p_{high}\). We can think of \(M_{high}\) as a reference model, because it is the baseline to which we compare. It is difficult in practice to know if a given model is a good reference model in all parts of the parameter space, due to the curse of dimensionality and the fact that the analysed system can have different properties in different parts of the parameter space. For example, ODEs can qualitatively change their behaviour as a function of the parameters (bifurcation), or become stiff or chaotic in some parameter regions. Accuracy can be checked at a given set of parameters fairly easily, but not over a high dimensional parameter space. Under these conditions it is necessary to compromise and develop a reference model that works only over a range of parameter space, but even then it is hard to know a priori what that range is.
We propose the following workflow:

1. Fit the low precision model \(M_{low}\) with MCMC to obtain posterior draws.
2. Construct a reference model \(M_{high}\) and verify that its numerical method is sufficiently accurate at each of those draws.
3. Compute the log importance ratios \(\log p_{high} - \log p_{low}\) for each draw.
4. Run PSIS on the ratios and check the \(\hat{k}\) diagnostic. If \(\hat{k} > 0.5\), increase the precision of \(M_{low}\) and return to step 1.
5. Resample the draws according to the importance weights to obtain draws from \(p_{high}\).

The next two sections of this case study outline how to apply this workflow to do fast but reliable inference, first for an ODE model (an SIR model of disease spread) and then for a PDE model (1D heat diffusion). The importance sampling diagnostics are handled with the `loo` package and the resampling is handled with the `posterior` package.
Here we study a classic Susceptible-Infected-Recovered (SIR) model of disease spread. The code is adapted from that in the Stan Case Study (Grinsztajn et al. 2020), which provides an introduction to disease transmission modeling in Stan in general.
For the purposes of this case study, the goal is to use a very low precision ODE solver to do inference and then check it afterwards against a high precision solver. This is useful in practice if sampling with the high precision solver itself would take an inordinate amount of time.
The states of the ODE are amounts of susceptible (S), infected (I) and recovered (R) people. The dynamics are given by the ODE system:
\[\begin{align} \frac{dS}{dt} &= -\beta \cdot I \cdot \frac{S}{N_{pop}} \\ \frac{dI}{dt} &= \beta \cdot I \cdot \frac{S}{N_{pop}} - \gamma \cdot I \\ \frac{dR}{dt} &= \gamma \cdot I, \end{align}\] where \(N_{pop}\) is the population size. The parameters \(\beta\) and \(\gamma\) will be estimated from time series observations of the number of infected people (I).
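To make the correspondence between the equations and code explicit, the right-hand side can be written as a plain R function; this mirrors the `stan_sir` function in the model code shown below:

```r
# SIR right-hand side as a plain R function, mirroring the equations above.
sir_rhs <- function(S, I, R, beta, gamma, N_pop) {
  c(dS = -beta * I * S / N_pop,
    dI =  beta * I * S / N_pop - gamma * I,
    dR =  gamma * I)
}
```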
```r
model <- stan_model("sir.stan")
```
We print the entire Stan code for our model here.
```r
cat(model@model_code)
```
## // sir.stan
##
## functions {
## // SIR system right-hand side
## real[] stan_sir(real t, real[] y, real[] theta, real[] x_r, int[] x_i) {
## real S = y[1];
## real I = y[2];
## real R = y[3];
## real M = x_i[1];
## real beta = theta[1];
## real gamma = theta[2];
## real dS_dt = -beta * I * S / M;
## real dI_dt = beta * I * S / M - gamma * I;
## real dR_dt = gamma * I;
## return { dS_dt, dI_dt, dR_dt };
## }
##
## // Solve the SIR system
## vector stan_solve_sir(data real[] ts, real[] theta,
## data real[] x_r,
## data real rtol, data real atol, data int max_num_steps) {
## int N = num_elements(ts);
## int M = 1000; // population size
## int I0 = 20; // number of infected on day 0
## int x_i[1] = { M }; // population size
## real y0[3] = { M - I0, I0, 0.0 }; // S, I, R on day 0
## real f[N, 3] = integrate_ode_rk45(stan_sir, y0, 0.0, ts, theta,
## x_r, x_i, rtol, atol, max_num_steps);
## return(to_vector(f[, 2]));
## }
## }
##
## data {
## int<lower=1> N; // Number of observations
## real t_data[N]; // Observation times
## int y_data[N]; // Counts of infected people
## real<lower=0.0> rtol;
## real<lower=0.0> atol;
## int<lower=1> max_num_steps;
## }
##
## transformed data {
## real x_r[0];
## }
##
## parameters {
## real<lower=0> beta;
## real<lower=0> gamma;
## real<lower=0> phi;
## }
##
## transformed parameters{
## vector[N] mu = stan_solve_sir(t_data, { beta, gamma },
## x_r, rtol, atol, max_num_steps);
## }
##
## model {
## beta ~ normal(2, 1);
## gamma ~ normal(0.4, 0.5);
## phi ~ lognormal(1, 1);
##
## // Add small positive number to solution to avoid negative numbers
## y_data ~ neg_binomial_2(mu + 2.0 * atol, phi);
## }
We will import a function from `sir.stan` for solving the ODE. It uses the `integrate_ode_rk45` function that is built into Stan. The actual function exposed from the Stan model (`stan_solve_sir`) is a bit awkward, so we rewrap it here in a way that is easier to use.
```r
expose_stan_functions(model)
```
```r
# Solve the SIR system
# - theta = c(beta, gamma), parameters
# - opts = c(rtol, atol, max_num_steps), solver options
# - ts = vector of output time points
solve_sir <- function(ts, theta, opts) {
  stan_solve_sir(ts, theta, c(0), opts[1], opts[2], opts[3])
}
```
The RK45 solver in Stan is an adaptive time step solver, which estimates the local error of the solution and adapts its step size so that the local error estimates are less than `atol + rtol * abs(y)`, where `y` is the ODE solution, and `atol` and `rtol` are called absolute and relative tolerance, respectively. These tolerances need to be given, and affect both the accuracy and computational cost of the solution. In general, `rtol` is the tolerance on the relative error the solver can make when `y` is far from zero. When `abs(y)` is small (of the order of `atol` or smaller), there is no need to achieve the relative tolerance.
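In other words, the acceptance criterion an adaptive solver applies at each step looks roughly like the following sketch (for illustration only; this is not the actual solver internals):

```r
# A step is accepted when the estimated local error is below the mixed
# absolute/relative tolerance, componentwise.
step_accepted <- function(err_estimate, y, atol, rtol) {
  all(abs(err_estimate) <= atol + rtol * abs(y))
}
```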
A third control parameter, `max_num_steps`, determines the maximum number of steps that can be taken to achieve the tolerance. In practice, we have observed that setting this to a much smaller value than the default can lower the warmup times of some chains by several orders of magnitude. This can be because it possibly helps in rejecting or quickly getting out of initial low-probability parameter regions, where the ODE solutions blow up and achieving the tolerances would require a much larger number of steps than in the good parameter region.
We can just pick some options and quickly run and plot a solution just to get a feel for what the system looks like:
```r
plot_sir <- function(t, y) {
  y %>%
    as_tibble() %>%
    setNames(c("I")) %>%
    mutate(Day = t) %>%
    ggplot(aes(Day, I)) +
    geom_line() +
    geom_point() +
    ylab("Infected people")
}

ts <- seq(0.1, 16, by = 0.1)
theta_true <- c(1, 0.2) # true parameter values
opts <- c(1e-4, 1e-4, 100)
ys <- solve_sir(ts, theta_true, opts)
plot_sir(ts, ys)
```
To test fitting our model we will create noisy measurements of the number of infected people (I) at each day. If we’re going to generate data from our model we better have an accurate ODE solver, otherwise we’re just generating data from some weird approximate model.
The simplest way to check that an `atol` and `rtol` are suitable is to do a solve at one tolerance level, repeat the solve at a much smaller (more precise) tolerance, and then look at the maximum absolute error at any output point. We will create a function to do this automatically:
```r
check_reliability_sir <- function(theta, opts) {
  ts <- seq(0.1, 16, by = 0.1)
  y_hat <- solve_sir(ts, theta, opts)
  opts_strict <- c(opts[1] / 10, opts[2] / 10, opts[3])
  y_hat_strict <- solve_sir(ts, theta, opts_strict)
  max_abs_err <- max(abs(y_hat - y_hat_strict))
  return(max_abs_err)
}
```
We can study the maximum absolute error compared to a solution with 10 times smaller tolerances, as a function of `tol = atol = rtol`. The value of `max_num_steps` is kept constant, but if the solver should fail to compute a solution within that many steps, an error is thrown and the value needs to be increased.
```r
mae_true <- c()
tols <- 10^(-c(1:12))
for (tol in tols) {
  opts <- c(tol, tol, 1e7)
  mae_true <- c(mae_true, check_reliability_sir(theta_true, opts))
}

qplot(tols, mae_true, geom = c("point", "line")) +
  scale_x_log10() +
  scale_y_log10() +
  ylab("Max. absolute error") +
  xlab("Tolerance")
```
From this and our prior knowledge of infectious diseases, we assert that \(10^{-6}\) is a good enough value to use for `atol` and `rtol` during simulation. Certainly we do not expect to have a count of the infected population accurate to \(10^{-4}\) people.
We generate the observed number of infected people (cases) at each time point \(t_i\) from a negative binomial distribution with mean equal to the ODE solution for \(I(t_i)\) and dispersion parameter \(\phi = 5\).
```r
atol <- 1e-6
rtol <- 1e-6
opts <- c(atol, rtol, 1e7)

N <- 16 # number of data points
t_data <- seq(1, N)
dispersion <- 5 # noise parameter for negative binomial
mu <- solve_sir(t_data, theta_true, opts)
y_data <- rnbinom(length(t_data), mu = mu, size = dispersion)

tibble(t = t_data, mu = mu, y = y_data) %>%
  ggplot() +
  geom_line(aes(t, mu), col = "firebrick") +
  geom_point(aes(t, y)) +
  xlab("Day") +
  ylab("Infected people") +
  ggtitle("Simulated data as points \nUnderlying solution as lines")
```
We also define an R function that computes the likelihood given the data, parameter values and solver options.
```r
# Likelihood function for the SIR model
# - t_data = vector of measurement times
# - y_data = vector of measurements of number of infected people
# - params = parameter vector c(beta, gamma, phi)
# - opts = c(rtol, atol, max_num_steps), solver options
log_likelihood_sir <- function(t_data, y_data, params, opts) {
  theta <- params[1:2]
  phi <- params[3]
  y_hat <- solve_sir(t_data, theta, opts)
  log_lh <- sum(dnbinom(y_data, size = phi, mu = y_hat, log = TRUE))
  return(log_lh)
}
```
As a reminder, our mission in fitting this ODE is to use a low precision solver. It is always tempting to use low precision solvers when working with ODEs because they (usually) run faster. The difficulty becomes how to deal with the coarser approximation. Does the lower precision cause an effect that matters? If so, can it be corrected and how? These are the questions the workflow here will allow us to answer.
The first step in the workflow is to take any low precision approximation (\(M_{low}\)) and fit the data. Remember, all our numerical methods are approximations, and so we refer to this model specifically as a low precision model. We will check it against a higher precision model later. In this case, we will use `rtol = 1e-4`, `atol = 1e-3`.
```r
opts_low <- c(1e-4, 1e-3, 100)
stan_data <- list(
  N = length(t_data),
  t_data = t_data,
  y_data = y_data,
  rtol = opts_low[1],
  atol = opts_low[2],
  max_num_steps = opts_low[3]
)
fit1 <- rstan::sampling(model,
  stan_data,
  seed = rng_seed_stan,
  cores = 4
)
```
<- c("beta", "gamma", "phi")
pars <- get_draws(fit1, pars)
draws1 print(fit1, pars = c("beta", "gamma", "phi"))
## Inference for Stan model: sir.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
## beta 1.24 0.00 0.17 1.00 1.13 1.21 1.32 1.68 1686 1
## gamma 0.21 0.00 0.02 0.16 0.19 0.21 0.22 0.26 1888 1
## phi 5.37 0.04 2.01 2.29 3.95 5.10 6.50 10.04 2284 1
##
## Samples were drawn using NUTS(diag_e) at Fri Feb 26 22:26:57 2021.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
Before we can check if the importance sampling correction is possible, we need to have a reference model (\(M_{high}\)) to compare against. That means we need a version of the model with tolerances such that it is suitably accurate across all the posterior draws generated from the low precision model.
In this case, `rtol = atol = 1e-6` was accurate enough when generating the data, so let's check if it is accurate for all the draws in this posterior.
```r
# Check reliability at each draw and add MAEs to data frame
compute_errors_sir <- function(df, opts) {
  num_draws <- nrow(df)
  mae <- c()
  for (i in 1:num_draws) {
    theta <- as.numeric(df[i, c("beta", "gamma")])
    mae_i <- check_reliability_sir(theta, opts)
    mae <- c(mae, mae_i)
  }
  df$mae <- mae
  return(df)
}

opts_high <- c(1e-6, 1e-6, 1e8)
ode_pars <- c("beta", "gamma")
draws1 <- compute_errors_sir(draws1, opts_high)

p1 <- scatter_colored(draws1, "beta", "gamma", "mae")
p2 <- scatter_colored(draws1, "beta", "phi", "mae")
p3 <- scatter_colored(draws1, "gamma", "phi", "mae")
ggarrange(p1, p2, p3)
```
We can plot this as a distribution and see that `rtol = atol = 1e-6` keeps us under an absolute error of one milliperson. This seems accurate enough.

```r
qplot(draws1$mae, geom = "histogram")
```
With the reference model in place, it is time to compute the importance weights \(\frac{p_{high}}{p_{low}}\) for each post-warmup draw. This is simple: just compute the log density of the reference model and the log density of the low precision model and take the difference (we work with the log of the importance ratios, \(\log p_{high} - \log p_{low}\), for numerical stability).
The hidden downside is that it might take some time to compute the log densities of the reference model for each draw. It should still be much faster than sampling with the reference model itself, since we don't need to compute gradients or HMC trajectories, we don't evaluate proposals that end up rejected, and we skip the whole warmup period. Therefore we likely won't have to do accurate ODE solves in the ill-behaved parameter regions, where a huge number of steps would be needed to achieve the tolerances. Another upside is that the calculations can be done in parallel for each draw.
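For example, the per-draw evaluations could be parallelized with the base `parallel` package; a sketch, using the helpers defined in this section (note that `mc.cores > 1` is not available on Windows):

```r
# Sketch: evaluate the reference log likelihood for each draw in parallel.
# log_likelihood_sir(), draws1 and opts_high are as defined in this section.
log_lh_high <- unlist(parallel::mclapply(seq_len(nrow(draws1)), function(i) {
  params_i <- as.numeric(draws1[i, c("beta", "gamma", "phi")])
  log_likelihood_sir(t_data, y_data, params_i, opts_high)
}, mc.cores = 4))
```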
If the priors are kept the same between the reference and low precision model, then those can be left out of this calculation (they will cancel).
```r
# Compute log likelihood ratio for each draw and add to data frame
log_ratios_sir <- function(df) {
  num_draws <- nrow(df)
  log_lh_low <- rep(0, num_draws)
  log_lh_high <- rep(0, num_draws)
  for (i in seq_len(num_draws)) {
    params_i <- as.numeric(df[i, c("beta", "gamma", "phi")])
    log_lh_low[i] <- log_likelihood_sir(t_data, y_data, params_i, opts_low)
    log_lh_high[i] <- log_likelihood_sir(t_data, y_data, params_i, opts_high)
  }
  df$log_ratio <- log_lh_high - log_lh_low
  return(df)
}
```
```r
draws1 <- log_ratios_sir(draws1)
p1 <- scatter_colored(draws1, "beta", "gamma", "log_ratio")
p2 <- scatter_colored(draws1, "beta", "phi", "log_ratio")
p3 <- scatter_colored(draws1, "gamma", "phi", "log_ratio")
ggarrange(p1, p2, p3)
```
We can plot the log importance ratios and see that they are all close to zero (which means our approximation was not too bad).

```r
qplot(draws1$log_ratio, geom = "histogram")
```
With the importance ratios calculated, we can check if they are usable or not with the PSIS \(\hat{k}\) diagnostic.
```r
r_eff1 <- loo::relative_eff(x = exp(-draws1$log_ratio), draws1$chain_id)
psis1 <- loo::psis(draws1$log_ratio, r_eff = r_eff1)
print(psis1$diagnostics)
```
## $pareto_k
## [1] -0.01929457
##
## $n_eff
## [1] 2810.404
```r
print(psis1)
```
## Computed from 4000 by 1 log-weights matrix
##
## All Pareto k estimates are good (k < 0.5).
## See help('pareto-k-diagnostic') for details.
\(\hat{k} < 0.5\), and so importance sampling should be reliable.
At this point we have a weighted set of posterior draws. It is usually easier to work with a set of draws than a set of weighted draws, so we can resample our weighted draws into a set of unweighted draws using `posterior::resample_draws`. The effective sample size will be slightly lowered by such a resampling, but unweighted draws are more convenient to work with.
Just because it is possible to do an importance sampling correction on a set of draws does not mean that unweighted statistics on these draws are safe to use. In this case, the results are not much different, but it should not be forgotten:
```r
w1 <- exp(draws1$log_ratio)
draws1_list <- rstan::extract(fit1, c("beta", "gamma", "phi"))
draws_df1 <- posterior::as_draws_df(draws1_list)
resampled_df1 <- posterior::resample_draws(draws_df1, weights = w1)

print(draws_df1 %>% posterior::summarize_draws())
```
## # A tibble: 3 x 10
## variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 beta 1.24 1.21 0.170 0.138 1.03 1.54 1.00 4144. 3698.
## 2 gamma 0.206 0.206 0.0241 0.0234 0.169 0.247 1.00 4246. 3692.
## 3 phi 5.37 5.10 2.01 1.87 2.58 8.98 1.00 3677. 3602.
```r
print(resampled_df1 %>% posterior::summarize_draws())
```
## # A tibble: 3 x 10
## variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 beta 1.24 1.21 0.170 0.138 1.03 1.54 1.00 3972. 3560.
## 2 gamma 0.206 0.206 0.0241 0.0233 0.169 0.247 1.00 4088. 3619.
## 3 phi 5.38 5.10 2.02 1.87 2.58 9.00 1.00 3574. 3406.
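Alternatively, expectations under the reference model can be estimated directly from the smoothed importance weights without resampling. As a quick sketch, the self-normalized estimate of the posterior mean of `beta` would be:

```r
# Self-normalized importance sampling estimate of E[beta] under M_high,
# using the PSIS-smoothed log weights computed earlier.
w_psis <- as.vector(exp(psis1$log_weights))
sum(w_psis * draws1$beta) / sum(w_psis)
```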
In this example we consider the diffusion of heat (\(u(t, x)\)) in a rod (\(x \in [0, L]\)).
For the purposes of this case study, the goal is to use a PDE solver in a model that has no automatic error control, only fixed discretization controls. This problem comes up from time to time when problems demand a custom solver be written – it is not always easy to tack on error control algorithms.
In this hypothetical experiment, the rod is cooled to room temperature and then heated from the left side. After some time the temperature profile of the rod is measured and from this the thermal diffusivity \(K\) will be estimated.
The dynamics are governed by the 1D heat equation:
\[\begin{align} \frac{\partial u}{\partial t} &= K \cdot \frac{\partial^2 u}{\partial x^2} \\ u(0, x) &= 0 \\ u(t, 0) &= 1 \\ u(t, L) &= 0 \end{align}\]
All of the computations in this example are going to be done with a method of lines discretization of this problem and a backwards Euler integrator. The appropriate math is described in the online lecture notes ATM 623: Climate Modeling by Brian E. J. Rose, though any introductory PDE reference should suffice.
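Concretely, writing \(u_i^n \approx u(n \Delta t,\, i \Delta x)\) and \(K^* = K \Delta t / \Delta x^2\), each backward Euler step requires solving the tridiagonal linear system
\[\begin{equation} -K^* u_{i-1}^{n+1} + (1 + 2 K^*)\, u_i^{n+1} - K^* u_{i+1}^{n+1} = u_i^n, \end{equation}\]
with the boundary values folded into the right-hand side. This is exactly the symmetric tridiagonal system (diagonal \(1 + 2K^*\), off-diagonals \(-K^*\)) solved by `stan_solve_tridiag_be` in the Stan code below.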
For convenience we have defined a Stan function that solves equations above and computes the measured temperatures in the system given a timestep, a spatial discretization, a hypothetical diffusivity, a measurement time, and a list of measurement points.
<- stan_model("diffusion.stan")
model expose_stan_functions(model)
cat(model@model_code)
## // diffusion.stan
##
## functions {
## // Solve a symmetric tridiagonal linear system Ax = d with constant
## // secondary diagonals
## //
## // a = the constant value of both secondary diagonals of matrix A
## // b = diagonal of matrix A
## vector stan_solve_tridiag_be(real a, vector b, vector d){
## int n = num_elements(b);
## vector[n] x = rep_vector(0.0, n);
## real w;
## int idx;
## vector[n] bb = b;
## vector[n] dd = d;
##
## // Forward sweep
## for (i in 2:n) {
## w = a / bb[i - 1];
## bb[i] = bb[i] - w*a;
## dd[i] = dd[i] - w*dd[i - 1];
## }
##
## // Back substitution
## x[n] = dd[n]/bb[n];
## idx = n - 1;
## while(idx > 0) {
## x[idx] = (dd[idx] - a*x[idx + 1]) / bb[idx];
## idx = idx - 1;
## }
## return(x);
## }
##
## // Backward Euler method for solving the 1D diffusion problem
## //
## // u_init, initial conditions
## // dt, timestep
## // dx, spatial discretization
## // T_max, max time
## // K, diffusion constant
## // ul, left boundary condition
## // ur, right boundary condition
## vector solve_pde(real dt, int Nx, real K, real T_meas, vector x_meas) {
## real L = 1.0; // length of rod
## real ul = 1.0; // left boundary condition
## real ur = 0.0; // right boundary condition
##
## real dx = L / (Nx + 1);
## real K_star = K * dt / (dx^2);
## real t = 0.0;
##
## vector[Nx] u = rep_vector(0, Nx);
## vector[Nx] u_prev = u;
##
## vector[rows(x_meas)] solution;
##
## // Create the diagonal of the tridiagonal matrix A
## vector[Nx] A_diag = rep_vector(1.0 + 2.0 * K_star, Nx);
##
## // Iterate time step
## while(t < T_meas) {
## vector[Nx] b = u;
## u_prev = u;
##
## b[1] += ul * K_star;
## b[Nx] += ur * K_star;
## // Update u and t
## u = stan_solve_tridiag_be(-K_star, A_diag, b);
## t = t + dt;
## }
##
## // Use linear interpolation to get solution at T_max not a multiple of dt
## if(T_meas < t) {
## real alpha = (t - T_meas) / dt;
## u = alpha * u_prev + (1.0 - alpha) * u;
## }
##
## // Use linear interpolation to get solution at measurement points
## {
## int i = 1;
## int j = 0;
## while(i <= rows(x_meas)) {
## if(x_meas[i] < 0.0) {
## solution[i] = ul;
## i += 1;
## } else if(x_meas[i] >= L) {
## solution[i] = ur;
## i += 1;
## } else if(j == 0) {
## if(x_meas[i] < dx) {
## real alpha = (dx - x_meas[i]) / dx;
## solution[i] = alpha * ul + (1.0 - alpha) * u[1];
## i += 1;
## } else {
## j += 1;
## }
## } else if(j + 1 == Nx + 1) {
## if(x_meas[i] < L) {
## real alpha = (L - x_meas[i]) / dx;
## solution[i] = alpha * u[Nx] + (1.0 - alpha) * ur;
## i += 1;
## } else {
## j += 1;
## }
## } else {
## if(x_meas[i] >= j * dx && x_meas[i] < (j + 1) * dx) {
## real alpha = ((j + 1) * dx - x_meas[i]) / dx;
## solution[i] = alpha * u[j] + (1.0 - alpha) * u[j + 1];
## i += 1;
## } else {
## j += 1;
## }
## }
## }
## }
##
## return(solution);
## }
## }
##
## data {
## real dt;
## int Nx;
## int N_meas;
## real T_meas;
## vector[N_meas] x_meas;
##
## vector[N_meas] y;
## }
##
## parameters {
## real<lower = 0.0> K; // diffusion constant
## real<lower = 0.0> sigma; // noise magnitude
## }
##
## transformed parameters {
## vector[N_meas] mu = solve_pde(dt, Nx, K, T_meas, x_meas);
## }
##
## model {
## sigma ~ normal(0, 1.0);
## K ~ normal(0, 1.0);
## y ~ normal(mu, sigma);
## }
```r
dt <- 1.0
Nx <- 10
K <- 1e-1
T_meas <- 0.1
x_meas <- c(-1.0, 0.01, 0.5, 0.99, 1.0, 2.0)

solve_pde(dt, Nx, K, T_meas, x_meas)
```
## [1] 1.0000000000 0.8982480469 0.0200404568 0.0002739004 0.0000000000
## [6] 0.0000000000
The function has the signature `vector solve_pde(dt, Nx, K, T_meas, x_meas)` with arguments:

- `dt` - Timestep
- `Nx` - Number of interior points in spatial discretization
- `K` - Thermal diffusivity
- `T_meas` - Measurement time
- `x_meas` - Measurement points

Assume a true thermal diffusivity \(K_{true} = 0.015\) and that we measure the temperature in the rod at `Nx` points evenly spaced on the rod. We will generate data under these conditions and try to recover the diffusivity later.
First, let’s set up constants and plot a possible solution with measurement points:
```r
dt <- 1e-1
Nx <- 5
L <- 1.0

x_meas <- seq(0.0, L, length = 7)[2:6]
T_meas <- 1.0
K_true <- 0.015

# For these calculations pretend we are measuring everywhere so we can
# see the whole solution
x <- seq(-0.1, 1.1, length = 100)
u0 <- c(rep(1.0, sum(x <= 0.0)), rep(0.0, sum(x > 0.0)))
uT <- solve_pde(dt, Nx, K_true, T_meas, x)

# Solve at only the measurement points
mu <- solve_pde(dt, Nx, K_true, T_meas, x_meas)

# Plot
tibble(x = x, `u(t = 0)` = u0, `u(t = T)` = uT) %>%
  gather(Legend, u, -x) %>%
  ggplot(aes(x, u)) +
  geom_line(aes(color = Legend, group = Legend)) +
  geom_point(data = tibble(x = x_meas, u = mu)) +
  ggtitle("Measurement points indicated in black\nDashed lines indicate boundary of sample") +
  geom_vline(aes(xintercept = 0.0), linetype = "dashed") +
  geom_vline(aes(xintercept = L), linetype = "dashed")
```
The red line shows the initial conditions. Because the solution is discretized to only five points on the rod (seven including the boundaries), we use linear interpolation to get the values at intermediate points (which makes the boundary look a bit strange). The teal line shows the distribution of heat in the rod at time `t = T`, where we plan to take measurements (indicated by the black dots) and make an inference about the unknown thermal diffusivity of the rod.
Now that we can compute solutions to this problem, our first question is: is a given solution accurate enough? The simple way to check is to compute the solution again at a higher space/time resolution and look at the difference. We can define a convenience function that, for a given discretization and experimental configuration, computes a solution, recomputes it at higher precision, and returns the maximum absolute error.
```r
# Function to help determine if dt and Nx are small enough at given K
check_reliability <- function(dt, Nx, K, T_meas, x_meas) {
  mu <- solve_pde(dt, Nx, K, T_meas, x_meas)
  mu_more_accurate <- solve_pde(dt / 2.0, 2 * Nx, K, T_meas, x_meas)
  max_abs_err <- max(abs(mu_more_accurate - mu))
  return(max_abs_err)
}

# Check at K = K_true
check_reliability(dt, Nx, K_true, T_meas, x_meas)
```
## [1] 0.007415297
Is that error good or bad? That is something that must be determined in the context of the application. In this case we are going to assume a measurement noise of \(0.1\), and so we should get our numerical error quite a bit below that.
```r
dt <- 0.01
Nx <- 40
check_reliability(dt, Nx, K_true, T_meas, x_meas)
```
## [1] 0.001039788
This seems good enough for now, but you might further refine your solution. Now to simulate data:
```r
sigma <- 0.1
noise <- rnorm(length(x_meas), 0, sigma)
y <- solve_pde(dt, Nx, K_true, T_meas, x_meas) + noise
```
Now that we have simulated data, it is time to do inference. The first step, similarly as for the ODE, is to fit an initial approximate model to our data. Again, all our calculations are approximations, and so we refer to this model as a low precision model because we will check it against a higher precision model later.
Assume we are very impatient and want this computation to finish quickly, so we use only one timestep and a single spatial point in our discretization:
```r
dt_low <- 1.0
Nx_low <- 1
stan_data <- list(
  dt = dt_low,
  Nx = Nx_low,
  N_meas = length(x_meas),
  T_meas = T_meas,
  x_meas = x_meas,
  y = y
)
fit <- rstan::sampling(model,
  stan_data,
  control = list(adapt_delta = 0.95),
  cores = 4,
  seed = rng_seed_stan
)
```
Let us look at our results:
```r
print(fit, pars = c("K", "sigma"))
```
## Inference for Stan model: diffusion.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
## K 0.52 0.02 0.55 0.01 0.08 0.32 0.79 1.93 1309 1
## sigma 0.47 0.01 0.22 0.20 0.33 0.42 0.56 1.05 1042 1
##
## Samples were drawn using NUTS(diag_e) at Fri Feb 26 22:27:43 2021.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
We remember from earlier that \(K_{true} = 0.015\) and \(\sigma_{true} = 0.1\), so something is off. We will diagnose this using our approximation tools.
Again, to check if the importance sampling correction can be done, we need a reference model that works for all the posterior draws we got from the low precision model. We can develop the reference model using the same technique we did previously (guess a high precision, and check the maximum absolute error between that high precision model and one of even higher precision).
```r
dt_high <- 0.01
Nx_high <- 100

draws <- get_draws(fit, c("K", "sigma"))
K_draws <- draws$K
sigma_draws <- draws$sigma
num_draws <- nrow(draws)

# Compute differences
mae <- c()
for (i in 1:num_draws) {
  mae_i <- check_reliability(dt_high, Nx_high, K_draws[i], T_meas, x_meas)
  mae <- c(mae, mae_i)
}
```
With a simple one parameter model we can plot our approximate errors as a function of K (so we know the solution is suitable everywhere).
```r
error_plot <- function(K, mae) {
  df <- data.frame(K, mae)
  ggplot(df, aes_string(x = "K", y = "mae")) +
    geom_point(col = "#1864de", alpha = 0.5) +
    xlab("K") +
    ylab("Max. absolute error")
}

error_plot(K_draws, mae)
```
The errors here seem low enough.
The importance weights \(\frac{p_{high}}{p_{low}}\) are computed on the log scale. The priors cancel out so we only need to work with log likelihoods.
Again, this step looks simple in this example, but in practice it might be more complicated. It is possible that the reference calculation is done with an entirely different piece of software. For instance, with a PDE, perhaps the reference solution is computed with a well-tested FEM solver in a different software environment entirely.
```r
# Define a function
compute_log_ratios <- function(dt_low, Nx_low, dt_high, Nx_high,
                               K_draws, T_meas, x_meas, y_meas) {
  log_lh_low <- rep(0, num_draws)
  log_lh_high <- rep(0, num_draws)
  for (i in seq_len(num_draws)) {
    mu_low <- solve_pde(dt_low, Nx_low, K_draws[i], T_meas, x_meas)
    mu_high <- solve_pde(dt_high, Nx_high, K_draws[i], T_meas, x_meas)
    log_lh_low[i] <- sum(dnorm(y_meas, mu_low, sigma_draws[i], log = TRUE))
    log_lh_high[i] <- sum(dnorm(y_meas, mu_high, sigma_draws[i], log = TRUE))
  }
  log_ratios <- log_lh_high - log_lh_low
  return(log_ratios)
}

# Apply function
log_ratios <- compute_log_ratios(
  dt_low, Nx_low, dt_high, Nx_high,
  K_draws, T_meas, x_meas, y
)
```
If the \(\hat{k}\) diagnostic is not low enough, it is not possible to do the importance sampling correction and we need to recompute our posterior with a higher resolution model. The `loo` package computes the \(\hat{k}\) diagnostic for us:
```r
r_eff <- loo::relative_eff(x = exp(-log_ratios), chain_id = draws$chain_id)
psis2 <- loo::psis(log_ratios, r_eff = r_eff)
```
## Warning: Some Pareto k diagnostic values are too high. See help('pareto-k-diagnostic') for details.
```r
print(psis2$diagnostics)
```
## $pareto_k
## [1] 1.149851
##
## $n_eff
## [1] 21.66079
```r
print(psis2)
```
## Computed from 4000 by 1 log-weights matrix
## Pareto k diagnostic values:
## Count Pct. Min. n_eff
## (-Inf, 0.5] (good) 0 0.0% <NA>
## (0.5, 0.7] (ok) 0 0.0% <NA>
## (0.7, 1] (bad) 0 0.0% <NA>
## (1, Inf) (very bad) 1 100.0% 22
## See help('pareto-k-diagnostic') for details.
Oh no! \(\hat{k} > 0.5\), and it turns out modeling this process with one timestep and one spatial point was not a good idea. This means we need to up the precision in the low resolution model and go back to Step 1.
```r
dt_low <- 0.1
Nx_low <- 10
stan_data <- list(
  dt = dt_low,
  Nx = Nx_low,
  N_meas = length(x_meas),
  T_meas = T_meas,
  x_meas = x_meas,
  y = y
)
fit <- rstan::sampling(model,
  stan_data,
  control = list(adapt_delta = 0.95),
  cores = 4,
  seed = rng_seed_stan
)
```
Again, we can check our regular diagnostics:
```r
print(fit, pars = c("K", "sigma"))
```
## Inference for Stan model: diffusion.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
## K 0.07 0.01 0.23 0.00 0.01 0.01 0.02 0.83 423 1.01
## sigma 0.18 0.01 0.14 0.06 0.10 0.13 0.20 0.59 378 1.00
##
## Samples were drawn using NUTS(diag_e) at Fri Feb 26 22:27:51 2021.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
Again, we verify our reference solution:
```r
draws <- get_draws(fit, c("K", "sigma"))
K_draws <- draws$K
sigma_draws <- draws$sigma
num_draws <- nrow(draws)

# Compute differences
errors <- c()
for (i in 1:num_draws) {
  mae <- check_reliability(dt_high, Nx_high, K_draws[i], T_meas, x_meas)
  errors <- c(errors, mae)
}

# Plot
error_plot(K_draws, errors)
```
And again we can compute the importance ratios and run the PSIS diagnostics on them:
```r
log_ratios <- compute_log_ratios(
  dt_low, Nx_low, dt_high, Nx_high,
  K_draws, T_meas, x_meas, y
)
r_eff <- loo::relative_eff(x = exp(-log_ratios), chain_id = draws$chain_id)
psis3 <- loo::psis(log_ratios, r_eff = r_eff)
print(psis3$diagnostics)
```
## $pareto_k
## [1] -1.256378
##
## $n_eff
## [1] 1722.142
```r
print(psis3)
```
## Computed from 4000 by 1 log-weights matrix
##
## All Pareto k estimates are good (k < 0.5).
## See help('pareto-k-diagnostic') for details.
And this time \(\hat{k} < 0.5\), so we are good enough!
At this point we have a weighted set of posterior draws. Again, it is usually easier to work with a set of unweighted draws, so we resample our weighted draws using `posterior::resample_draws`.
```r
w <- exp(log_ratios)
draws_list <- rstan::extract(fit, c("K", "sigma"))
draws_df <- posterior::as_draws_df(draws_list)
resampled_df <- posterior::resample_draws(draws_df,
  weights = w
)
```
```r
print(draws_df %>% posterior::summarize_draws())
```
## # A tibble: 2 x 10
## variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 K 0.0680 0.0135 0.230 0.00948 0.00341 0.340 1.00 3107. 3754.
## 2 sigma 0.179 0.132 0.141 0.0635 0.0688 0.474 1.00 4206. 3788.
```r
print(resampled_df %>% posterior::summarize_draws())
```
## # A tibble: 2 x 10
## variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 K 0.0688 0.0135 0.232 0.00963 0.00349 0.343 1.00 2573. 3115.
## 2 sigma 0.179 0.132 0.141 0.0642 0.0691 0.472 1.00 2817. 2875.
And that is that! Happy approximating!
## R version 4.0.2 (2020-06-22)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS 10.16
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] ggpubr_0.4.0 posterior_0.1.3 loo_2.4.1
## [4] bayesplot_1.8.0 forcats_0.5.0 stringr_1.4.0
## [7] dplyr_1.0.3 purrr_0.3.4 readr_1.4.0
## [10] tidyr_1.1.2 tibble_3.0.5 tidyverse_1.3.0
## [13] rstan_2.21.2 ggplot2_3.3.3 StanHeaders_2.21.0-7
##
## loaded via a namespace (and not attached):
## [1] matrixStats_0.57.0 fs_1.5.0 lubridate_1.7.9.2
## [4] httr_1.4.2 tools_4.0.2 backports_1.2.1
## [7] utf8_1.1.4 R6_2.5.0 DBI_1.1.1
## [10] colorspace_2.0-0 withr_2.4.1 tidyselect_1.1.0
## [13] gridExtra_2.3 prettyunits_1.1.1 processx_3.4.5
## [16] curl_4.3 compiler_4.0.2 cli_2.2.0
## [19] rvest_0.3.6 xml2_1.3.2 labeling_0.4.2
## [22] scales_1.1.1 checkmate_2.0.0 ggridges_0.5.3
## [25] callr_3.5.1 digest_0.6.27 foreign_0.8-81
## [28] rmarkdown_2.6 rio_0.5.16 pkgconfig_2.0.3
## [31] htmltools_0.5.1 dbplyr_2.0.0 rlang_0.4.10
## [34] readxl_1.3.1 rstudioapi_0.13 farver_2.0.3
## [37] generics_0.1.0 jsonlite_1.7.2 zip_2.1.1
## [40] car_3.0-10 inline_0.3.17 magrittr_2.0.1
## [43] Matrix_1.3-2 Rcpp_1.0.6 munsell_0.5.0
## [46] fansi_0.4.2 abind_1.4-5 lifecycle_0.2.0
## [49] stringi_1.5.3 yaml_2.2.1 carData_3.0-4
## [52] pkgbuild_1.2.0 plyr_1.8.6 grid_4.0.2
## [55] parallel_4.0.2 crayon_1.3.4 lattice_0.20-41
## [58] cowplot_1.1.1 haven_2.3.1 BH_1.75.0-0
## [61] hms_1.0.0 knitr_1.30 ps_1.5.0
## [64] pillar_1.4.7 ggsignif_0.6.0 codetools_0.2-18
## [67] stats4_4.0.2 reprex_0.3.0 glue_1.4.2
## [70] evaluate_0.14 V8_3.4.0 data.table_1.13.6
## [73] RcppParallel_5.0.2 modelr_0.1.8 vctrs_0.3.6
## [76] cellranger_1.1.0 gtable_0.3.0 assertthat_0.2.1
## [79] xfun_0.20 openxlsx_4.2.3 broom_0.7.3
## [82] RcppEigen_0.3.3.9.1 rstatix_0.6.0 viridisLite_0.3.0
## [85] ellipsis_0.3.1
## Initial rng seed was 999
## Stan rng seed was 123
## In psis1, pareto_k was -0.01929
## In psis2, pareto_k was 1.14985
## In psis3, pareto_k was -1.25638