Automatic Differentiation with RTMB
Andrew Johnson
2026-01-25
Introduction
Stan’s algorithms, including MCMC sampling (NUTS), optimization (L-BFGS), variational inference, Pathfinder, and Laplace approximation, are gradient-based methods. This means they require not only the log-probability function but also its gradient (the vector of partial derivatives with respect to each parameter) to work efficiently.
StanEstimators provides three ways to compute gradients:
- Finite differences (default): Automatic but slow. Approximates gradients by evaluating the function at slightly perturbed parameter values.
- Analytical gradients: Fast and accurate, but requires you to manually derive and code the gradient function.
- RTMB automatic differentiation: Fast and automatic. Uses the RTMB package to compute exact gradients via automatic differentiation (AD).
Installing RTMB
To use RTMB with StanEstimators, you need to install the RTMB, withr, and future packages:
install.packages(c("RTMB", "withr", "future"))

Once installed, simply set grad_fun = "RTMB" in any StanEstimators function to enable automatic differentiation.
For basic usage of StanEstimators, see the Getting Started vignette.
Poisson Regression
Next, we’ll examine a generalized linear model (GLM) for count data. Poisson regression uses a log-link function: $\log(\mu_i) = X_i \beta$, where $\mu_i$ is the expected count.
Performance Comparison
# Starting values: zeros for the three model parameters (pars[1:3] in the summary)
inits_pois <- rep(0, 3)
# Finite differences
# Baseline run: gradients approximated numerically (the default grad_fun)
timing_pois_fd <- system.time({
fit_pois_fd <- stan_sample(poisson_loglik, inits_pois,
additional_args = list(y = y_pois, X = X),
num_chains = 1, seed = 1234)
})
# RTMB
# Same model, data, and seed; only grad_fun changes, so the timing
# difference isolates the cost of the gradient computation
timing_pois_rtmb <- system.time({
fit_pois_rtmb <- stan_sample(poisson_loglik, inits_pois,
grad_fun = "RTMB",
additional_args = list(y = y_pois, X = X),
num_chains = 1, seed = 1234)
})Results
# Tabulate elapsed (wall-clock) seconds for each method. Extract the timing
# with [["elapsed"]] rather than [3]: system.time() returns a *named* vector,
# and the "elapsed" name otherwise leaks into the data.frame row names,
# producing the stray "elapsed" cell visible in the rendered tables.
timing_results_pois <- data.frame(
  Method = c("Finite Differences", "RTMB"),
  Time_seconds = c(timing_pois_fd[["elapsed"]], timing_pois_rtmb[["elapsed"]]),
  Speedup = c(1, timing_pois_fd[["elapsed"]] / timing_pois_rtmb[["elapsed"]])
)
knitr::kable(timing_results_pois, digits = 2,
             caption = "Performance comparison for Poisson regression")
| Method | Time_seconds | Speedup | |
|---|---|---|---|
| Finite Differences | 10.05 | 1.0 | |
| elapsed | RTMB | 2.39 | 4.2 |
summary(fit_pois_rtmb)
#> # A tibble: 4 × 10
#> variable mean median sd mad q5 q95 rhat ess_bulk
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 lp__ -315. -314. 1.22 1.09 -317. -313. 1.00 453.
#> 2 pars[1] 0.543 0.544 0.0609 0.0609 0.436 0.637 1.01 482.
#> 3 pars[2] 1.15 1.15 0.0472 0.0488 1.08 1.23 1.01 515.
#> 4 pars[3] -0.734 -0.734 0.0515 0.0521 -0.817 -0.651 1.00 580.
#> # ℹ 1 more variable: ess_tail <dbl>RTMB handles the matrix operations and log-link function automatically, providing improved performance (roughly a 4x speedup in this run) while correctly recovering the true parameter values.
Logistic Regression
Logistic regression models binary outcomes using a logit link: $\operatorname{logit}(p_i) = X_i \beta$, where $p_i$ is the probability of success.
Performance Comparison
# Starting values: zeros for the three model parameters (pars[1:3] in the summary)
inits_logit <- rep(0, 3)
# Finite differences
# Baseline run: gradients approximated numerically (the default grad_fun)
timing_logit_fd <- system.time({
fit_logit_fd <- stan_sample(logistic_loglik, inits_logit,
additional_args = list(y = y_binom, X = X_logit),
num_chains = 1, seed = 1234)
})
# RTMB
# Same model, data, and seed; only grad_fun changes, so the timing
# difference isolates the cost of the gradient computation
timing_logit_rtmb <- system.time({
fit_logit_rtmb <- stan_sample(logistic_loglik, inits_logit,
grad_fun = "RTMB",
additional_args = list(y = y_binom, X = X_logit),
num_chains = 1, seed = 1234)
})Results
# Tabulate elapsed (wall-clock) seconds for each method. Extract the timing
# with [["elapsed"]] rather than [3]: system.time() returns a *named* vector,
# and the "elapsed" name otherwise leaks into the data.frame row names,
# producing the stray "elapsed" cell visible in the rendered tables.
timing_results_logit <- data.frame(
  Method = c("Finite Differences", "RTMB"),
  Time_seconds = c(timing_logit_fd[["elapsed"]], timing_logit_rtmb[["elapsed"]]),
  Speedup = c(1, timing_logit_fd[["elapsed"]] / timing_logit_rtmb[["elapsed"]])
)
knitr::kable(timing_results_logit, digits = 2,
             caption = "Performance comparison for Logistic regression")
| Method | Time_seconds | Speedup | |
|---|---|---|---|
| Finite Differences | 5.40 | 1.0 | |
| elapsed | RTMB | 0.91 | 5.9 |
summary(fit_logit_rtmb)
#> # A tibble: 4 × 10
#> variable mean median sd mad q5 q95 rhat ess_bulk
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 lp__ -152. -152. 1.31 1.09 -155. -151. 1.00 546.
#> 2 pars[1] 0.224 0.225 0.151 0.150 -0.0283 0.462 1.00 736.
#> 3 pars[2] 1.56 1.55 0.206 0.207 1.24 1.89 1.00 740.
#> 4 pars[3] -0.915 -0.920 0.176 0.174 -1.20 -0.614 1.00 600.
#> # ℹ 1 more variable: ess_tail <dbl>Gaussian Mixture Model
Mixture models represent complex latent structure and demonstrate RTMB’s benefits for challenging models. We’ll fit a two-component Gaussian mixture.
The model is: $y_i \sim \pi \,\mathcal{N}(\mu_1, \sigma_1^2) + (1 - \pi)\,\mathcal{N}(\mu_2, \sigma_2^2)$.
Defining the Log-Likelihood
mixture_loglik <- function(pars, y) {
  # Log-likelihood of a two-component Gaussian mixture.
  #
  # pars: length-5 vector: pars[1] = mixing proportion pi (in [0, 1]),
  #       pars[2:3] = component means, pars[4:5] = component sds (positive).
  # y:    numeric vector of observations.
  # Returns the scalar log-likelihood
  #   sum_i log( pi * N(y_i; mu1, sigma1) + (1 - pi) * N(y_i; mu2, sigma2) ).
  # Transform parameters to satisfy constraints
  pi <- pars[1] # mixing proportion in [0,1]
  mu1 <- pars[2]
  mu2 <- pars[3]
  sigma1 <- pars[4] # positive
  sigma2 <- pars[5] # positive
  # Per-observation component log-densities, weighted on the log scale
  log_lik1 <- dnorm(y, mu1, sigma1, log = TRUE) + log(pi)
  log_lik2 <- dnorm(y, mu2, sigma2, log = TRUE) + log(1 - pi)
  # Combine with log-sum-exp: the original sum(log(exp(l1) + exp(l2)))
  # underflows to log(0) = -Inf for observations far from both components
  # (both exp() terms vanish below ~exp(-745) in double precision).
  # NOTE(review): pmax is assumed AD-compatible under RTMB -- confirm.
  m <- pmax(log_lik1, log_lik2)
  sum(m + log(exp(log_lik1 - m) + exp(log_lik2 - m)))
}
Performance Comparison
# Initialize near true values (mixture models can have multimodality)
inits_mix <- c(0.3, -2, 3, 1, 1.5)
# Finite differences
# Bounds keep the mixing proportion in [0, 1] and both sds positive;
# gradients approximated numerically (the default grad_fun)
timing_mix_fd <- system.time({
fit_mix_fd <- stan_sample(mixture_loglik, inits_mix,
lower = c(0, -Inf, -Inf, 0, 0),
upper = c(1, Inf, Inf, Inf, Inf),
additional_args = list(y = y_mix),
num_chains = 1, seed = 1234)
})
# RTMB
# Same model, bounds, and seed; only grad_fun changes, so the timing
# difference isolates the cost of the gradient computation
timing_mix_rtmb <- system.time({
fit_mix_rtmb <- stan_sample(mixture_loglik, inits_mix,
lower = c(0, -Inf, -Inf, 0, 0),
upper = c(1, Inf, Inf, Inf, Inf),
grad_fun = "RTMB",
additional_args = list(y = y_mix),
num_chains = 1, seed = 1234)
})Results
# Tabulate elapsed (wall-clock) seconds for each method. Extract the timing
# with [["elapsed"]] rather than [3]: system.time() returns a *named* vector,
# and the "elapsed" name otherwise leaks into the data.frame row names,
# producing the stray "elapsed" cell visible in the rendered tables.
timing_results_mix <- data.frame(
  Method = c("Finite Differences", "RTMB"),
  Time_seconds = c(timing_mix_fd[["elapsed"]], timing_mix_rtmb[["elapsed"]]),
  Speedup = c(1, timing_mix_fd[["elapsed"]] / timing_mix_rtmb[["elapsed"]])
)
knitr::kable(timing_results_mix, digits = 2,
             caption = "Performance comparison for Gaussian Mixture")
| Method | Time_seconds | Speedup | |
|---|---|---|---|
| Finite Differences | 14.73 | 1.00 | |
| elapsed | RTMB | 1.75 | 8.43 |
summary(fit_mix_rtmb)
#> # A tibble: 6 × 10
#> variable mean median sd mad q5 q95 rhat ess_bulk
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 lp__ -909. -909. 1.75 1.60 -913. -907. 1.00 441.
#> 2 pars[1] 0.294 0.293 0.0264 0.0248 0.251 0.338 1.01 776.
#> 3 pars[2] -2.05 -2.05 0.134 0.123 -2.27 -1.82 1.01 623.
#> 4 pars[3] 2.95 2.95 0.110 0.110 2.77 3.14 1.00 878.
#> 5 pars[4] 1.08 1.07 0.105 0.104 0.930 1.26 1.01 917.
#> 6 pars[5] 1.51 1.51 0.0872 0.0842 1.37 1.66 1.00 725.
#> # ℹ 1 more variable: ess_tail <dbl>Time Series: AR(1) Model
An autoregressive model of order 1 (AR(1)) captures temporal dependence: $y_t = \phi y_{t-1} + \epsilon_t$, $\epsilon_t \sim \mathcal{N}(0, \sigma^2)$, where $|\phi| < 1$ for stationarity.
Defining the Log-Likelihood
The stationarity constraint on φ is enforced through the lower and upper bounds passed to stan_sample(), rather than by a tanh() transform inside the log-likelihood.
ar1_loglik <- function(pars, y) {
  # Log-likelihood of a stationary, zero-mean AR(1) process.
  #
  # pars: length-2 vector: pars[1] = autocorrelation phi (in (-1, 1),
  #       enforced by the sampler's bounds), pars[2] = innovation sd (positive).
  # y:    numeric vector of observations, length >= 1.
  # Returns the scalar log-likelihood.
  phi <- pars[1]
  sigma <- pars[2]
  n <- length(y)
  # First observation from the stationary marginal N(0, sigma^2 / (1 - phi^2))
  ll <- dnorm(y[1], 0, sigma / sqrt(1 - phi^2), log = TRUE)
  # Subsequent observations: y[t] | y[t-1] ~ N(phi * y[t-1], sigma^2).
  # Vectorised instead of the original for-loop: faster in interpreted R,
  # and avoids the 2:n footgun (2:1 would double-count when n == 1).
  if (n > 1) {
    ll <- ll + sum(dnorm(y[-1], phi * y[-n], sigma, log = TRUE))
  }
  ll
}
Performance Comparison
inits_ar <- c(0.5, 1)
# Finite differences
# BUG FIX: the original used upper = c(0, Inf), capping phi at 0. That pins
# the sampler against the bound (posterior phi ~ -0.007 here) while the
# unbounded Pathfinder run below estimates phi ~ 0.75. The stationarity
# constraint is |phi| < 1, so the upper bound for phi must be 1.
timing_ar_fd <- system.time({
fit_ar_fd <- stan_sample(ar1_loglik, inits_ar,
lower = c(-1, 0),
upper = c(1, Inf),
additional_args = list(y = y_ar),
num_chains = 1, seed = 1234)
})
# RTMB
# Same model, bounds, and seed; only grad_fun changes
timing_ar_rtmb <- system.time({
fit_ar_rtmb <- stan_sample(ar1_loglik, inits_ar,
lower = c(-1, 0),
upper = c(1, Inf),
grad_fun = "RTMB",
additional_args = list(y = y_ar),
num_chains = 1, seed = 1234)
})
Results
# Tabulate elapsed (wall-clock) seconds for each method. Extract the timing
# with [["elapsed"]] rather than [3]: system.time() returns a *named* vector,
# and the "elapsed" name otherwise leaks into the data.frame row names,
# producing the stray "elapsed" cell visible in the rendered tables.
timing_results_ar <- data.frame(
  Method = c("Finite Differences", "RTMB"),
  Time_seconds = c(timing_ar_fd[["elapsed"]], timing_ar_rtmb[["elapsed"]]),
  Speedup = c(1, timing_ar_fd[["elapsed"]] / timing_ar_rtmb[["elapsed"]])
)
knitr::kable(timing_results_ar, digits = 2,
             caption = "Performance comparison for AR(1) model")
| Method | Time_seconds | Speedup | |
|---|---|---|---|
| Finite Differences | 41.77 | 1.00 | |
| elapsed | RTMB | 1.03 | 40.51 |
summary(fit_ar_rtmb)
#> # A tibble: 3 × 10
#> variable mean median sd mad q5 q95 rhat ess_bulk
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 lp__ -373. -3.73e+2 0.990 0.767 -3.75e+2 -3.72e+2 0.999 453.
#> 2 pars[1] -0.00660 -4.55e-3 0.00660 0.00475 -1.97e-2 -3.52e-4 1.01 473.
#> 3 pars[2] 1.53 1.53e+0 0.0748 0.0743 1.41e+0 1.66e+0 1.00 540.
#> # ℹ 1 more variable: ess_tail <dbl>Quick Approximation with Pathfinder
RTMB also works with Pathfinder, Stan’s fast variational inference method:
# Pathfinder with RTMB gradients. A seed is set for reproducibility,
# consistent with every other fitting call in this vignette.
fit_ar_path <- stan_pathfinder(ar1_loglik, inits_ar,
                               grad_fun = "RTMB",
                               additional_args = list(y = y_ar),
                               seed = 1234)
summary(fit_ar_path)
#> # A tibble: 5 × 10
#> variable mean median sd mad q5 q95 rhat ess_bulk
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 lp_approx__ 3.20 3.50 1.01 0.781 1.18 4.19 1.00 739.
#> 2 lp__ -284. -284. 0.926 0.787 -286. -283. 1.00 758.
#> 3 path__ 2.46 2 1.14 1.48 1 4 2.70 1.19
#> 4 pars[1] 0.750 0.748 0.0467 0.0477 0.674 0.831 1.00 786.
#> 5 pars[2] 1.00 1.00 0.0506 0.0493 0.925 1.09 1.000 552.
#> # ℹ 1 more variable: ess_tail <dbl>