Animating MCMC in Julia
Having fun with base plots and MCMC.
One of my favorite aspects of Julia is how simple it is to build high-quality graphs and plots.
I have used base plotting and {ggplot2} in R, as well as matplotlib in Python, for several years. Those packages are great, and I think Plots.jl has a place among them.
It has several cool features out of the box that are super easy to use, and boasts a “time to first (beautiful) plot” that is competitive.
As the Julia data science community continues to grow, several other plotting libraries have emerged, such as Makie.jl and AlgebraOfGraphics.jl.
I have yet to dive into those, but as with everything in Julia, they look super cool!
In this post I will demonstrate the power of Plots.jl by making an animation of an MCMC sampler exploring the posterior distribution of a simple Normal-Normal model.
For this, we will need to load in several packages for stats, plotting, and performing the MCMC sampling.
We will use the Turing.jl library to define our Bayesian model.
We need some data to fit our model. For this, we will simulate 100 observations from a normal distribution with mean 4 and standard deviation of 2.
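The simulation step can be sketched as follows. This is a minimal example that assumes Distributions.jl for the draws; the seed value and the variable name y are arbitrary choices for illustration, not taken from the post.

```julia
using Distributions, Random

Random.seed!(42)               # arbitrary seed, only for reproducibility
y = rand(Normal(4, 2), 100)    # 100 draws with mean 4 and standard deviation 2
```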
The next thing we need is a model. In Turing.jl this is extremely simple to do.
We simply define our priors and the likelihood distribution and we are good to go!
For this model, we are using a (considerably) uninformative Normal(0,5) prior for μ and an Exponential(1) prior for σ.
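A model along these lines might look like the sketch below. The name normal_model and the explicit loop over observations are my own assumptions; the priors are the ones described above.

```julia
using Turing

# Hypothetical model name; the priors follow the text above.
@model function normal_model(y)
    μ ~ Normal(0, 5)           # prior on the mean
    σ ~ Exponential(1)         # prior on the standard deviation
    for i in eachindex(y)
        y[i] ~ Normal(μ, σ)    # likelihood of each observation
    end
end
```

Calling normal_model(y) with a data vector returns a model instance that can be passed to a sampler.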
Next, we use the sample function to perform inference. We will use the NUTS implementation of MCMC over 2 chains of 100 draws each.
It is important to visualize the trace plots to ensure that the sampler didn’t have issues exploring the posterior distribution. We are looking for furry caterpillar plots on the left, and Normal-ish plots on the right.
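As a rough sketch, sampling and the trace plots could be produced like this. The data and model are restated so the snippet runs on its own; the names y, normal_model, and chain, the seed, and the choice of MCMCThreads() to run the two chains are assumptions on my part.

```julia
using Turing, StatsPlots, Random

Random.seed!(42)
y = rand(Normal(4, 2), 100)    # simulated data, as described earlier

@model function normal_model(y)
    μ ~ Normal(0, 5)
    σ ~ Exponential(1)
    for i in eachindex(y)
        y[i] ~ Normal(μ, σ)
    end
end

# NUTS with default settings, 100 draws per chain, 2 chains
chain = sample(normal_model(y), NUTS(), MCMCThreads(), 100, 2)

# Trace plots on the left, marginal densities on the right
plot(chain)
```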
All looks good. Now let’s dive into a custom visualization of these chains.
Since our model only has two parameters, μ and σ, we can visualize the full posterior distribution on a standard two-dimensional plot.
I really like how Richard McElreath visualizes Hamiltonian Monte Carlo in his blog.
Let’s try to build a plot that looks similar to what he uses to display the different MCMC algorithms.
First, we will define a plot of the posterior domain for our model’s parameters. Ignore the legend for now; we will add lines for our chains later on.
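An empty plot covering the posterior domain could be set up as in this sketch. The axis limits, title, and legend position are illustrative guesses rather than the values used in the post.

```julia
using Plots

post_plot = plot(
    xlims = (3, 5),        # assumed range for μ
    ylims = (1, 3),        # assumed range for σ
    xlabel = "μ",
    ylabel = "σ",
    title = "Posterior",
    legend = :topright,
)
```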
Next we will build a 2D kernel density estimate and add that to our plot. We can do this with the kde function from the KernelDensity.jl library.
All we have to do is extract the parameter traces from our chain object, feed them into that function, and add the result to our plot.
With an eye toward plotting sequential draws in a gif, let’s first define the draws that we are going to plot. For this example, let’s plot the estimated posterior density at the 2nd draw from each chain.
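The density step might look like the following sketch. To keep it runnable on its own, param_trace here is a random stand-in array (draws × parameters × chains) rather than a trace extracted from a real chain, and the variable names are my own.

```julia
using KernelDensity, Plots, Random

Random.seed!(1)
# Stand-in trace: 100 draws × 2 parameters (μ, σ) × 2 chains
param_trace = randn(100, 2, 2) .* reshape([0.2, 0.15], 1, 2, 1) .+
              reshape([4.0, 2.0], 1, 2, 1)

n = 2                                     # use draws 1 through n from each chain
μ_draws = vec(param_trace[1:n, 1, :])
σ_draws = vec(param_trace[1:n, 2, :])

post_kde = kde((μ_draws, σ_draws))        # bivariate kernel density estimate

post_plot = plot(xlabel = "μ", ylabel = "σ")
# Plots expects z with rows along y, so the density matrix is transposed
contour!(post_plot, post_kde.x, post_kde.y, post_kde.density')
```

With only a couple of draws per chain the estimate is very rough; it sharpens as n grows across the frames of the animation.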
The last element to add to the posterior plot is a set of lines indicating the path that the sampler takes through the posterior space.
Our param_trace object has three dimensions: draw, parameter, and chain. We can plot individual chains by indexing the third dimension of that object.
Let’s add the path from the first to the second draw for each chain to our plot.
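A sketch of drawing the per-chain paths, again using a random stand-in for param_trace with the draw × parameter × chain layout described above. The labels and marker style are assumptions.

```julia
using Plots, Random

Random.seed!(1)
# Stand-in trace: 100 draws × 2 parameters (μ, σ) × 2 chains
param_trace = randn(100, 2, 2) .* reshape([0.2, 0.15], 1, 2, 1) .+
              reshape([4.0, 2.0], 1, 2, 1)

post_plot = plot(xlabel = "μ", ylabel = "σ")

# One line per chain, from the first draw to the second
for c in axes(param_trace, 3)
    plot!(post_plot, param_trace[1:2, 1, c], param_trace[1:2, 2, c];
          label = "Chain $c", marker = :circle)
end
```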
Now that we have our posterior plot looking good, let’s add marginal density plots to the plot object. We can use the kde function on each of the parameters from the trace.
Here I demonstrate how the Julia pipe operator |> can be used to iteratively transform your data. We take the parameter trace and transform it to a vector, then pass that through to the kde function.
Note, because we want the sigma density to be rotated 90 degrees and displayed on the right side of our posterior plot, we need to pass the density to the x axis and the domain to the y axis in our plot.
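The marginal densities could be built roughly like this. The trace is a random stand-in, and the plot styling is kept to a minimum; only the piping through vec and kde and the axis swap for σ follow the description above.

```julia
using KernelDensity, Plots, Random

Random.seed!(1)
# Stand-in trace: 100 draws × 2 parameters (μ, σ) × 2 chains
param_trace = randn(100, 2, 2) .* reshape([0.2, 0.15], 1, 2, 1) .+
              reshape([4.0, 2.0], 1, 2, 1)

# Flatten each parameter across chains, then estimate its density
μ_kde = param_trace[:, 1, :] |> vec |> kde
σ_kde = param_trace[:, 2, :] |> vec |> kde

μ_density_plot = plot(μ_kde.x, μ_kde.density; legend = false)

# Density on the x axis and domain on the y axis rotates the curve
σ_density_plot = plot(σ_kde.density, σ_kde.x; legend = false)
```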
Now we need to add these marginal plots to the posterior plot that we made earlier. To do this, we need to define a custom layout using the @layout macro.
We want the μ_density_plot on top of the post_plot, with the σ_density_plot on the right.
We can use the layout defined below to achieve this, along with the orientation keyword to indicate that the last plot (the sigma density) needs to be oriented vertically.
Passing the series of plots into the plot function along with the keyword arguments for layout, orientation, size, etc. produces our final plot.
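One way to arrange three plots like this is sketched below. The three panels are simple stand-in curves so the snippet runs on its own, and the 0.8 width and height fractions and the figure size are guesses. This sketch relies on the σ panel already having its axes swapped, so it does not pass the orientation keyword.

```julia
using Plots

xs = range(3, 5; length = 50)
ys = range(1, 3; length = 50)

# Stand-in panels with the same roles as the plots in the post
μ_density_plot = plot(xs, exp.(-(xs .- 4) .^ 2 ./ 0.1))
post_plot      = plot([3.8, 4.1, 4.0], [1.9, 2.1, 2.0]; xlabel = "μ", ylabel = "σ")
σ_density_plot = plot(exp.(-(ys .- 2) .^ 2 ./ 0.1), ys)   # axes swapped

# Top row: μ density and an empty cell; bottom row: main panel and σ density
l = @layout [a _; b{0.8w,0.8h} c]

final_plot = plot(μ_density_plot, post_plot, σ_density_plot;
                  layout = l, size = (600, 600), legend = false)
```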
Now that all of the pieces are constructed to build the full plot, we need to define a function that puts it all together. The arguments for our function will include the chain object, as well as the draw number and step size to plot.
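A function of this kind might be sketched as follows. The name plot_posterior_frame is hypothetical, and for simplicity it takes a draws × parameters × chains array instead of the chain object itself. Reading step as the number of most recent moves to draw as a path is also my assumption.

```julia
using KernelDensity, Plots, Random

# Hypothetical helper: builds one frame from draws 1:n of every chain.
function plot_posterior_frame(param_trace, n; step = 1)
    μ_all = vec(param_trace[1:n, 1, :])
    σ_all = vec(param_trace[1:n, 2, :])

    # Joint density of the draws seen so far; limits fixed across frames
    post_kde = kde((μ_all, σ_all))
    post_plot = contour(post_kde.x, post_kde.y, post_kde.density';
                        xlims = extrema(param_trace[:, 1, :]),
                        ylims = extrema(param_trace[:, 2, :]),
                        xlabel = "μ", ylabel = "σ", colorbar = false)

    # Path over the most recent `step` moves of each chain
    first_draw = max(1, n - step)
    for c in axes(param_trace, 3)
        plot!(post_plot,
              param_trace[first_draw:n, 1, c], param_trace[first_draw:n, 2, c];
              label = "Chain $c", marker = :circle)
    end

    # Marginal densities; σ has its axes swapped to sit on the right
    μ_kde, σ_kde = kde(μ_all), kde(σ_all)
    μ_density_plot = plot(μ_kde.x, μ_kde.density; legend = false)
    σ_density_plot = plot(σ_kde.density, σ_kde.x; legend = false)

    l = @layout [a _; b{0.8w,0.8h} c]
    return plot(μ_density_plot, post_plot, σ_density_plot;
                layout = l, size = (600, 600))
end

# Usage with a random stand-in trace (100 draws × 2 parameters × 2 chains)
Random.seed!(1)
param_trace = randn(100, 2, 2) .* reshape([0.2, 0.15], 1, 2, 1) .+
              reshape([4.0, 2.0], 1, 2, 1)
frame = plot_posterior_frame(param_trace, 10; step = 1)
```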
Defining the animation is super simple. All we have to do is define a loop that builds the plot frames and add the @animate macro before its definition.
To build the gif, we simply pass in our animation object along with the desired frames per second and we are good to go!
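The animation loop and gif export can be sketched as follows. To keep the snippet short and runnable on its own, plot_frame is a tiny hypothetical stand-in for the full frame-building function, and the trace, frame range, file name, and fps value are all illustrative.

```julia
using Plots, Random

Random.seed!(1)
# Stand-in trace: 100 draws × 2 parameters (μ, σ) × 2 chains
param_trace = randn(100, 2, 2) .* reshape([0.2, 0.15], 1, 2, 1) .+
              reshape([4.0, 2.0], 1, 2, 1)

# Hypothetical stand-in: path of the first chain up to draw n
plot_frame(trace, n) = plot(trace[1:n, 1, 1], trace[1:n, 2, 1];
                            xlabel = "μ", ylabel = "σ",
                            legend = false, marker = :circle)

# Each loop iteration produces one frame of the animation
anim = @animate for n in 2:20
    plot_frame(param_trace, n)
end

gif(anim, "mcmc_animation.gif"; fps = 5)
```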
And that’s all there is to it. In the future I look forward to exploring some of the other cool plotting libraries in Julia.