Single-Parameter Models | Pyro vs. STAN

Photo by Joey Csunyo on Unsplash

Modeling U.S. cancer-death rates with two Bayesian approaches: MCMC in STAN and SVI in Pyro.

Single-parameter models are an excellent way to get started with probabilistic modeling. These models consist of one parameter that influences our observations and that we can infer from the given data. In this article we compare the performance of two well-established frameworks: the statistical language STAN and the probabilistic programming language Pyro.

Kidney Cancer Data

One old and established dataset covers the cases of kidney cancer in the U.S. from 1980-1989, which is available here (see [1]). For each U.S. county it gives the total population and the number of reported cancer deaths.
Our task is to infer the rate of death from the given data in a Bayesian way.
An elaborate walk-through of the task can be found in section 2.8 of “Bayesian Data Analysis 3” [1].
Our dataframe looks like this:

We have the total number of data points N (all rows in the dataframe). Per county we have the observed death-count (dc), which we refer to as y, and the population (pop), which we later call n.
Since we have epidemiological count data, the Poisson distribution is a good basis for estimating the rate that we want to compute. Our observations y are therefore sampled from a Poisson distribution. The interesting parameter is the \lambda of our Poisson, which we call rate. It is built from a death-rate \theta, which comes from a Gamma distribution (see explanations in Chapter 2 [1]). The short explanation for using a Gamma here is that it is the conjugate prior of the Poisson.
There is more to the special relationship between the two distributions, though.
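The conjugacy mentioned above can be made concrete with a minimal sketch. It assumes the exposure form y ~ Poisson(n * \theta) with prior \theta ~ Gamma(\alpha, \beta), in which case the posterior is again a Gamma with parameters \alpha + y and \beta + n. The function name and the example county (2 deaths in a population of 100,000) are hypothetical; \alpha = 20 and \beta = 430000 are the values from [1] that this article uses later.

```python
def gamma_poisson_posterior(alpha, beta, y, n):
    """Return (shape, rate) of the Gamma posterior for y ~ Poisson(n * theta)."""
    return alpha + y, beta + n


alpha, beta = 20.0, 430000.0  # prior values from [1]
y, n = 2, 100_000             # hypothetical county: 2 deaths, 100,000 people

post_alpha, post_beta = gamma_poisson_posterior(alpha, beta, y, n)

prior_mean = alpha / beta          # mean of a Gamma(shape, rate) is shape / rate
post_mean = post_alpha / post_beta
raw_rate = y / n                   # naive estimate that ignores the prior

print(f"prior mean:     {prior_mean:.2e}")
print(f"raw rate:       {raw_rate:.2e}")
print(f"posterior mean: {post_mean:.2e}")
```

The posterior mean lands between the raw rate and the prior mean, which is the shrinkage effect that makes this prior useful for small counties.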

Putting it together we arrive at the formulation for the observations of death-cases:

y_j \sim \mathrm{Poisson}(n_j \theta_j)

and for the wanted single parameter, the death-rate:

\theta_j \sim \mathrm{Gamma}(\alpha, \beta)

Now that we have seen the task, let's fire up our tools and get to work.

STAN – the classical statistical models

The statistician’s well-established workhorse is the statistical language STAN.
You write your problem as model code and STAN compiles an efficient C++ model under the hood. From a compiled model you can sample and perform inference.
The framework offers different interfaces, for example in Python (PyStan), R (RStan), Julia and others. For this article we will use PyStan so that we can have both models neatly side by side in a notebook.
We start by defining the data that we pass to STAN: the integer N for the total size of the data set, y for the observed deaths and n for the population per county. Both y and n have N entries.

data {
   int N; // total observations
   int y[N]; // observed death-count
   vector[N] n; // population
}

We already know that a rate parameter like the death-rate \theta is bounded in [0, 1]: a rate cannot be less than 0 or higher than 1.
This \theta parameter is then the basis for the actual rate, which accounts for a county’s population. Since this is a transformation of our parameter, we put it into the transformed parameters block:

parameters {
   vector<lower=0,upper=1>[N] theta; // bounded deathrate estimate
}

transformed parameters {
   vector[N] rate=n .* theta; // expected deaths: population times death-rate
}

Now to the magical inference part, which is what the posterior is all about. We sample our \theta from a Gamma distribution. The \alpha (=20) and \beta (=430000) are given in [1], but one could also compute them from the underlying dataset. Notice that the Poisson likelihood takes the transformed parameter rate and not just \theta. This will become apparent in the output later.

model {
   theta ~ gamma(20,430000);
   y ~ poisson(rate);
}
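For reference, the four fragments above can be assembled into one program string. This is a sketch under two assumptions: that the string is called kidney_cancer_code, the name used in the PyStan call later in this article, and that every block is closed with a brace. The array declaration int y[N] is kept exactly as written above; newer Stan releases use a different array syntax, so this form targets the older Stan version bundled with PyStan 2.

```python
# Complete Stan program as a Python string, assembled from the fragments above.
kidney_cancer_code = """
data {
   int N;                            // total observations
   int y[N];                         // observed death-count
   vector[N] n;                      // population
}
parameters {
   vector<lower=0,upper=1>[N] theta; // bounded deathrate estimate
}
transformed parameters {
   vector[N] rate = n .* theta;      // expected deaths per county
}
model {
   theta ~ gamma(20, 430000);
   y ~ poisson(rate);
}
"""

# Cheap sanity check before handing the string to the Stan compiler:
# every opened block has to be closed again.
n_open = kidney_cancer_code.count("{")
n_close = kidney_cancer_code.count("}")
print(n_open, n_close)
```

A missing closing brace is an easy mistake when a model is pasted together from snippets, and it only shows up as a compiler error much later.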

Now that we have a STAN Model as a string, we need the data in the right format. That means arrays of integers from the data, like so:

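N = 25  # number of counties used here; see the warning further below
dc = data["dc"].to_numpy().astype(int)    # observed death-counts
pop = data["pop"].to_numpy().astype(int)  # county populations
kidney_cancer_dat = {"N": N,
                     "y": dc[0:N],
                     "n": pop[0:N]}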

Since everything is set up, we can use PyStan to convert our model to C++ code and use sampling to perform inference. The inference procedure is Markov Chain Monte Carlo, or MCMC. The necessary parameters are the number of samples that we take (the iterations) and the number of chains over which we draw the samples. Each chain is an ordered sequence of draws.

import pystan

# compile the model string to C++, then draw 5000 iterations on each of 4 chains
sm = pystan.StanModel(model_code=kidney_cancer_code)
fit = sm.sampling(data=kidney_cancer_dat, iter=5000, chains=4)
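As a quick check on what these settings produce: by default PyStan 2 discards the first half of each chain as warmup, so only the second half of every chain is kept. The helper below is a hypothetical illustration of that arithmetic and is not part of PyStan.

```python
def kept_draws(iterations, chains, warmup=None):
    """Number of post-warmup draws across all chains."""
    if warmup is None:
        warmup = iterations // 2  # PyStan 2 default: half of the iterations
    return (iterations - warmup) * chains


total = kept_draws(iterations=5000, chains=4)
print(total)  # 10000
```

This is where the figure of 10000 iterations in the diagnostic output comes from.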

After the fitting is done, the first thing to check is whether the model has converged. That means that \hat{R} is close to 1 for every parameter (values clearly above 1 indicate that the chains disagree), or we can just ask the diagnostics tools:


### the below lines are output:
n_eff / iter looks reasonable for all parameters
Rhat looks reasonable for all parameters
0.0 of 10000 iterations ended with a divergence (0.0%)
0 of 10000 iterations saturated the maximum tree depth of 10 (0.0%)
E-BFMI indicated no pathological behavior
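To give the Rhat line in this output some substance, here is a minimal sketch of the classic Gelman-Rubin statistic, which compares the variance between chains with the variance within chains. It is the textbook form and only an approximation of what STAN reports, since STAN uses a refined split-chain variant. The function name and the synthetic chains are assumptions made for illustration.

```python
import random
import statistics


def r_hat(chains):
    """Classic (non-split) Gelman-Rubin statistic for equal-length chains."""
    n = len(chains[0])
    chain_means = [statistics.fmean(c) for c in chains]
    within = statistics.fmean([statistics.variance(c) for c in chains])  # W
    between = n * statistics.variance(chain_means)                       # B
    pooled = (n - 1) / n * within + between / n  # pooled variance estimate
    return (pooled / within) ** 0.5


random.seed(0)
# Four synthetic chains that all explore the same distribution ...
mixed = [[random.gauss(0.0, 1.0) for _ in range(1000)] for _ in range(4)]
# ... and four chains of which two are stuck in a different region.
stuck = [[random.gauss(shift, 1.0) for _ in range(1000)]
         for shift in (0.0, 0.0, 3.0, 3.0)]

print(f"well mixed: {r_hat(mixed):.3f}")
print(f"stuck:      {r_hat(stuck):.3f}")
```

Chains that agree give a value very close to 1, while the stuck chains give a value far above it, which is the situation the diagnostic is meant to flag.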

We can also see that the chains have converged (right plots) from the traces of the performed inference:

Fig. 1 – Traceplot of the \theta posterior computation. The chains on the right side display convergence. Each chain has a density function for theta (left side), and these densities agree on the values.

The chains (Fig. 1 right) should look like “caterpillars madly in love”, and the diagnostics should look good. Both are the case here, so we know that our model has converged. We can now look at the inferred theta values and their individual densities:

import arviz as az

az.plot_density(fit, var_names=["theta"])
Fig. 2 – Posterior densities on the fitted single parameter \theta after inference.

Warning: Do not use the complete dataset as input! This will lead to errors, so we went with N=25 samples for testing purposes. Feel free to test how many samples it takes to break PyStan (or STAN).

For Bayesian regression tasks there are also modules that let you skip the “write-an-elaborate-statistical-model” part. BRMS in R is one of those excellent tools.

Pyro – the probabilistic programming approach

Pyro follows a completely different paradigm. Instead of performing inference through MCMC sampling, we treat our task as an optimization problem.
For this we formulate a model, which describes how the data are generated and thereby defines our posterior (p). Additionally we formulate a so-called guide, which gives us a parameterized distribution (q) that is used for the model-fitting procedure.
In a previous SVI article we already looked at how this optimization works, and we recommend a short recap.

The highly unintuitive part is that instead of simply fitting one parameter we build it from four Pyro parameters:

  1. Firstly, \alpha and \beta give our Gamma distribution its shape. They are initialized with the same values as in the STAN part.
  2. We add two boring trainable parameters p1 and p2, which allow for optimization in the SVI steps. Both parameters are positive, hence constraint=constraints.positive.

When going through the data, the observations themselves are treated as independent events. This is done through Pyro’s plate modeling, which also supports subsampling from our dataset. A nice introduction to this can be found in [2].

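import torch
import pyro
from pyro.distributions import Gamma, Poisson
from torch.distributions import constraints

ϵ = 10e-3  # small offset (10e-3 == 0.01) that keeps the Poisson rate positive
def model(population, deathcount):
    # Gamma parameters, initialized with the values from [1]
    α = pyro.param('α', torch.tensor(20.))
    β = pyro.param('β', torch.tensor(430000.))
    # one positive trainable weight per county (data is the dataframe from above)
    p1= pyro.param('p1', torch.ones(data.shape[0]), constraint=constraints.positive)
    p2 = pyro.param('p2', torch.ones(data.shape[0]), constraint=constraints.positive)
    # counties are independent; each step works on a random subsample of 32
    with pyro.plate('data', data.shape[0], subsample_size=32) as idx:
        n_j = population[idx]
        y_j = deathcount[idx]
        α_j = α + p1[idx]*y_j
        β_j = β + p2[idx]*n_j
        θ_j = pyro.sample("θ", Gamma(α_j, β_j))
        λ_j = 10*n_j*θ_j + ϵ  # expected deaths over the ten-year period
        pyro.sample('obs', Poisson(λ_j), obs=y_j)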

The model performs the sampling of the observations given the Poisson distribution. This is comparable to what STAN was doing, with the difference that the parameters which make up \lambda and the underlying distributions are now trainable Pyro parameters.
In order for our model to not go up in flames we have to add a small number \epsilon to the rate, otherwise the Poisson will be unstable:

λ_j = 10*n_j*θ_j + ϵ
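Why a zero rate is a problem can be seen from the Poisson log-probability, which contains the term y * log(\lambda). The following is a plain-Python illustration of that formula and not Pyro's implementation; how a particular library reacts to a rate of exactly zero may differ. The function name and the example count y = 3 are assumptions.

```python
import math


def poisson_log_pmf(y, rate):
    """log P(Y = y) for Y ~ Poisson(rate); only defined for rate > 0."""
    return y * math.log(rate) - rate - math.lgamma(y + 1)


eps = 10e-3  # same value as ϵ in the model, i.e. 0.01
y = 3        # hypothetical observed count

try:
    poisson_log_pmf(y, 0.0)
    failed = False
except ValueError:  # math.log(0.0) is undefined
    failed = True

safe = poisson_log_pmf(y, 0.0 + eps)  # finite once the offset is added
print(failed, math.isfinite(safe))  # True True
```

The offset trades a tiny bias in the rate for a log-probability that stays finite, so the gradients used by the optimizer remain usable.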

This formulation is not the best way to model the task, but it is closest to the STAN model. One can find a model and guide that do not rely on y and n for computing \alpha and \beta.

Now the guide gives us a parameterized distribution q, which we optimize by maximizing the Evidence Lower Bound, or ELBO. In practice this means minimizing the negative ELBO as a loss. The parameters are the trainable Pyro parameters that we have already seen in the model.

def guide(population, deathcount):
    α = pyro.param('α', torch.tensor(20.))
    β = pyro.param('β', torch.tensor(430000.))
    p1 = pyro.param('p1', torch.ones(data.shape[0]), constraint=constraints.positive)
    p2 = pyro.param('p2', torch.ones(data.shape[0]), constraint=constraints.positive)
    with pyro.plate('data', data.shape[0], subsample_size=32) as idx:
        n_j = population[idx]
        y_j = deathcount[idx]
        α_j = α + p1[idx]*y_j
        β_j = β + p2[idx]*n_j
        θ_j = pyro.sample("θ", Gamma(α_j, β_j))

Now we can run our SVI like so:

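from pyro.infer import SVI, JitTrace_ELBO
from pyro.optim import Adam
from tqdm import tqdm
# the ELBO is estimated from 3 particles in every step
svi = SVI(model, guide, Adam({'lr': 0.025}), JitTrace_ELBO(3))
for i in tqdm(range(1000)):
    loss = svi.step(population, deathcount)
    print(f"Loss: {loss}")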

When we plot the loss, we can see that the model improves over time. The loss returned by each SVI step is an estimate of the negative ELBO, so the lower the loss the better.

Fig. 3 – SVI loss over iteration. The loss is the negative ELBO returned from the SVI step over 1000 iterations.
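Because every SVI step works on a random subsample, the raw loss is noisy, and a trailing average makes the trend easier to read. The sketch below is one simple way to do this. The helper name, the window size and the synthetic loss values are assumptions for illustration; in the notebook one would collect the values returned by svi.step in a list instead.

```python
import random


def moving_average(values, window=50):
    """Trailing mean over the last `window` entries (shorter at the start)."""
    smoothed = []
    running = 0.0
    for i, v in enumerate(values):
        running += v
        if i >= window:
            running -= values[i - window]  # drop the value leaving the window
        smoothed.append(running / min(i + 1, window))
    return smoothed


random.seed(1)
# Synthetic stand-in for a loss curve: a decaying trend plus noise.
losses = [1000.0 / (1 + 0.01 * i) + random.gauss(0.0, 20.0) for i in range(1000)]
smooth = moving_average(losses)

print(f"first window: {smooth[49]:.1f}, last window: {smooth[-1]:.1f}")
```

Plotting the smoothed series next to the raw one helps to judge whether the optimization has levelled off or is still improving.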

As a final check we can verify that we have fitted appropriate parameters.
The \alpha and \beta variables make for a good underlying Gamma distribution for our posterior inference.

Conclusion – What approach is best?

STAN offers a straightforward way to reason about the model. If the mathematical formulation of a posterior is known, implementing the model can be very straightforward.
However, STAN and its MCMC sampling have their limitations. Under the default configuration it was not possible to run our model over all of the data.
Pyro does an excellent job at handling larger datasets efficiently and performing Variational Inference. As a Probabilistic Programming Language it can be written just like any other Python code. The model plus guide are also informative and encapsulate the problem well. Through this we transform the posterior computation into an optimization task and get sensible outputs.

Now that you are familiar with single-parameter models, have fun working with bigger, more complex tasks. The Pyro Examples are a good place to start.

Happy inferring!


[1] A. Gelman, J. B. Carlin, et al., Bayesian Data Analysis, Third Edition.
[2] Pyro Documentation – SVI part II