We have an intractable posterior distribution \(p(x|D)\) that we wish to approximate with \(q(x)\), chosen from a given family of tractable distributions (e.g., a Gaussian).
We define \(\widetilde{p}(x) = p(x|D)p(D) = p(x,D)\), which is easier to compute pointwise, since we don’t need to compute the expensive evidence \(p(D)\).
The goal is to find the \(q\) that minimizes a cost function defined using the KL divergence from Information Theory:
\[J(q) = KL(q\|\widetilde{p}) = \sum_x q(x) \log \frac{q(x)}{\widetilde{p}(x)}\]
or using an integral for continuous distributions.
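For reference, the discrete sum above translates directly into R; here is a minimal sketch (the two probability vectors are made up just for illustration):
# discrete KL divergence: q and p_tilde are probability vectors over the same outcomes
KL_discrete <- function(q, p_tilde) sum(q * log(q / p_tilde))
KL_discrete(c(0.2, 0.5, 0.3), c(0.25, 0.25, 0.5))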
The continuous version can be implemented in R with numerical integration:
# KL divergence for continuous distributions
# (the extra arguments in ... are passed to q, which will be handy when optimizing its parameters)
KL <- function(q, p_tilde, lower, upper, ...) {
  f <- function(x) q(x, ...) * log(q(x, ...) / p_tilde(x))
  integrate(f, lower, upper)$value
}
# an example:
p_tilde <- function(x) dgamma(x, shape=3.0, scale=0.25)
q1 <- function(x) dlnorm(x, 0, 1)
q2 <- function(x) dlnorm(x, 0, .45)
curve(p_tilde, 0, 6, lwd=2, ylab="")
curve(q1, 0, 6, lwd=2, col="red", add=T)
curve(q2, 0, 6, lwd=2, col="green", add=T) # this one 'seems' closer...
KL(q1,p_tilde, lower=1e-3, upper=100)
## [1] 1.709245
KL(q2,p_tilde, lower=1e-3, upper=100) # ...and in fact, KL gives a smaller number
## [1] 0.3400462
Notice that this is not exactly a KL divergence, since \(\widetilde{p}\) is a non-normalized ‘distribution’.
To see that the cost function \(J\) works as desired, let’s expand the expression:
\[ \begin{array}{lcll} J(q) & = & \sum_x q(x) \log \frac{q(x)}{\widetilde{p}(x)} & \\ & = & \sum_x q(x) \log \frac{q(x)}{p(x|D)p(D)} & \\ & = & \sum_x q(x) \left( \log \frac{q(x)}{p(x|D)} - \log p(D) \right) & \\ & = & \sum_x q(x) \log \frac{q(x)}{p(x|D)} - \log p(D) & \color{blue}{ \sum_x q(x) = 1} \\ & = & KL(q\|p(x|D)) - \log p(D) & \end{array} \]
Since \(p(D)\) is a constant, minimizing \(J(q)\) is equivalent to minimizing \(KL(q\|p(x|D))\), and so \(q(x)\) will approach \(p(x|D)\).
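As a quick numerical sanity check (a sketch with a made-up evidence value p_D), we can use the KL function above to verify that \(J(q)\) differs from the true KL divergence by exactly \(-\log p(D)\):
# sanity check (sketch): scale a normalized density by a hypothetical evidence p_D
p_D     <- 0.3                                           # made-up value of p(D)
p       <- function(x) dgamma(x, shape=3.0, scale=0.25)  # normalized 'posterior'
p_tilde <- function(x) p(x) * p_D                        # unnormalized version
q2      <- function(x) dlnorm(x, 0, .45)
KL(q2, p_tilde, lower=1e-3, upper=100) - KL(q2, p, lower=1e-3, upper=100)
-log(p_D)  # the difference above should match this value (up to numerical error)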
optim
Here’s an example where we approximate a gamma distribution using log-normals:
variational_lnorm <- function(p_tilde, lower, upper) {
  q <- dlnorm  # in this example, q is a log-normal
  J <- function(params) {
    KL(q, p_tilde, lower=lower, upper=upper, meanlog=params[1], sdlog=params[2])
  }
  optim(par=c(0, 1), fn=J)$par
}
# an example:
p_tilde <- function(x) dgamma(x, shape=3.0, scale=0.25)
approximation_params <- variational_lnorm(p_tilde, lower=1e-3, upper=100)
# get the resulting approximation:
q <- function(x) dlnorm(x, approximation_params[1], approximation_params[2])
KL(q,p_tilde,1e-3,10) # compute their distance
## [1] 0.02765858
curve(p_tilde, 0, 6, lwd=2, ylab="", ylim=c(0,1.25))
curve(q, 0, 6, lwd=2, col="red", add=T)
We can mimic a Laplace approximation (i.e., a Gaussian approximation) of, say, a given beta:
variational_norm <- function(p_tilde, lower, upper) {
  q <- dnorm  # in this example, q is a normal
  J <- function(params) {
    KL(q, p_tilde, lower=lower, upper=upper, mean=params[1], sd=params[2])
  }
  optim(par=c(0.5, 0.2), fn=J)$par # initial values are tricky, not very stable (see sketch below)
}
p_tilde <- function(x) dbeta(x,11,9)
approximation_params <- variational_norm(p_tilde, lower=1e-3, upper=1-1e-3)
# get the resulting approximation:
q <- function(x) dnorm(x, mean=approximation_params[1], sd=approximation_params[2])
KL(q,p_tilde,0,1) # compute their distance
## [1] 0.004112233
curve(p_tilde, 0, 1, lwd=2, ylab="", ylim=c(0,4))
curve(q, 0, 1, lwd=2, col="red", add=T)
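Regarding the instability noted in the comment above, one standard trick (a sketch of my own, not the post’s code; variational_norm2 is just a hypothetical name) is to optimize over \(\log \sigma\), so that the standard deviation can never become non-positive during the search:
# sketch: re-parameterize sd as exp(log_sd) so it stays positive while optimizing
variational_norm2 <- function(p_tilde, lower, upper) {
  J <- function(params) {
    KL(dnorm, p_tilde, lower=lower, upper=upper, mean=params[1], sd=exp(params[2]))
  }
  res <- optim(par=c(0.5, log(0.2)), fn=J)$par
  c(mean=res[1], sd=exp(res[2]))  # map log_sd back to sd
}
Optimizing in log-space is a common way to impose positivity constraints while still using optim’s unconstrained methods.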
The next example approximates beta distributions with Kumaraswamy distributions, which have a simpler PDF
\[f_\text{kumar}(x|a,b) = abx^{a-1}(1-x^a)^{b-1}\]
and a closed-form CDF:
\[F_\text{kumar}(x|a,b) = 1-(1-x^a)^b\]
In this post, John Cook notes that, since the CDF is easy to invert, it’s simple to generate random samples from \(K(a,b)\): draw \(u \sim U(0,1)\) and return
\[F^{-1}(u) = (1-(1-u)^{1/b})^{1/a}\]
library(extraDistr)
n <- 1e4
u <- runif(n)
a <- 1/2
b <- 1/2
ku <- (1 - (1-u)^(1/b))^(1/a)
hist(ku, breaks=50, prob=T)
q <- function(x) dkumar(x,a,b)
curve(q, 0, 1, lwd=2, col="red", add=T)
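As a quick check (assuming extraDistr uses the same parameterization of \(K(a,b)\)), the closed-form inverse above should agree with the package’s quantile function qkumar:
# sanity check: the closed-form inverse CDF should match extraDistr's qkumar
all.equal((1 - (1-u)^(1/b))^(1/a), qkumar(u, a, b))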
Let’s do the approximation:
variational_kumaraswamy <- function(p_tilde, lower, upper) {
  q <- dkumar
  J <- function(params) {
    # TODO: the KL function sometimes outputs negative values here, possibly
    # because the integral is truncated to [lower, upper]; abs() is included
    # in order to make the optimization work
    abs(KL(q, p_tilde, lower=lower, upper=upper, a=params[1], b=params[2]))
  }
  optim(par=c(0.5, 0.5), fn=J, method="BFGS")$par
}
p_tilde <- function(x) dbeta(x,1/2,1/2)
approximation_params <- variational_kumaraswamy(p_tilde, lower=0.01, upper=0.99)
# get the resulting approximation:
q <- function(x) dkumar(x, approximation_params[1], approximation_params[2])
KL(q,p_tilde,0.01,0.99) # compute their distance
## [1] 1.240327e-16
curve(p_tilde, 0, 1, lwd=2, ylab="", ylim=c(0,4))
curve(q, 0, 1, lwd=2, col="red", add=T)
Another example:
p_tilde <- function(x) dbeta(x,3,3)
approximation_params <- variational_kumaraswamy(p_tilde, lower=0.01, upper=0.99)
# get the resulting approximation:
q <- function(x) dkumar(x, approximation_params[1], approximation_params[2])
#q <- function(x) dkumar(x, 5, 251/40)
KL(q,p_tilde,0.01,0.99) # compute their distance
## [1] 0.001386222
curve(p_tilde, 0, 1, lwd=2, ylab="", ylim=c(0,4))
curve(q, 0, 1, lwd=2, col="red", add=T)
For these two distribution families, however, the method is quite unstable: it diverges for many parameter values…
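One possible mitigation (again a sketch of my own, with variational_kumaraswamy2 as a hypothetical name) is to optimize over \(\log a\) and \(\log b\), so the Kumaraswamy parameters stay positive during the search; whether this actually tames the divergence would need to be checked on the same examples:
# sketch: optimize over log(a) and log(b) so the Kumaraswamy parameters stay positive
variational_kumaraswamy2 <- function(p_tilde, lower, upper) {
  J <- function(params) {
    abs(KL(dkumar, p_tilde, lower=lower, upper=upper,
           a=exp(params[1]), b=exp(params[2])))
  }
  exp(optim(par=c(0, 0), fn=J, method="BFGS")$par)  # map back to (a, b)
}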