Keegan Hines

On distributed Gibbs sampling with Spark

Over the last few months, I've become an increasingly big fan of Apache Spark. I find it to be an easy-to-use system for distributed computing and processing large data sets. I'm also a fan of interesting Bayesian models and MCMC methods. Here, I'll describe how simple it is to express some generic Gibbs sampling components in the language of map(), reduce(), and the rest of the gang. As an example, I'll take a look at a simple Bayesian latent variable model and show how we can construct a Gibbs sampler with Spark in order to easily process terabytes of data. This model, while simple, has many of the basic components of more interesting latent variable models (such as LDA or HMM), and the patterns generalize easily. (For a gentler introduction to Bayesian methods and Gibbs sampling, see this primer.)

So let's think about a simple latent variable model - a mixture of Exponential distributions with a finite number of components. That is, for each datapoint \(y_i\), we imagine that \(y_i\) was drawn from one of \(K\) distinct Exponential distributions, each with its own rate parameter \(\theta_k\) and mixture weight \(w_k\),

\[ y_i \sim w_1 \theta_1 \exp(-\theta_1 y_i) + w_2 \theta_2 \exp(-\theta_2 y_i) + ... + w_K \theta_K \exp(-\theta_K y_i). \]

Without loss of generality, let's consider for now just a two-component model. This leaves us with the following generative model,

\[ y_i \sim w_1 \theta_1 \exp(-\theta_1 y_i) + w_2 \theta_2 \exp(-\theta_2 y_i). \]
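To make the generative story concrete, here is a minimal sketch that simulates data from such a mixture using only Python's standard library. The function name sample_mixture and the particular weights and rates are illustrative assumptions, not part of the model above.

```python
import random

def sample_mixture(n, weights, thetas, seed=0):
    """Draw n points from a mixture of Exponentials; returns (labels, values)."""
    rng = random.Random(seed)
    labels, values = [], []
    for _ in range(n):
        # pick a component with probability proportional to its weight w_k
        k = rng.choices(range(len(weights)), weights=weights)[0]
        labels.append(k)
        # expovariate takes the rate theta_k, so the mean of the draw is 1/theta_k
        values.append(rng.expovariate(thetas[k]))
    return labels, values

# illustrative values only: a fast component (rate 500) and a slow one (rate 5)
labels, ys = sample_mixture(1000, weights=[0.6, 0.4], thetas=[500.0, 5.0])
```

In a real data set only the values would be observed; the labels are exactly the hidden quantity that the sampler has to recover.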

Thus, for each model component, we're only concerned with estimating a single parameter, \(\theta_k \), for that component as well as the relative weight of that component, \(w_k\). This model is intentionally simple (impractically so), but has the same basic guts of many latent variable models. Thus, given some data \( \{y_1,y_2,...\} \) denoted \(y_N\), our goal is to estimate the model parameters \( \{w_1,w_2,\theta_1,\theta_2\} \). In the Bayesian approach, we're interested in the posterior distribution \(p(\theta_1, \theta_2, w_1, w_2| y_N ) \).

As is typical in Bayesian mixture models, we will proceed with the trick of data augmentation (a misnomer, as we'll augment the parameters and not the data). We'll add to the model a latent indicator variable, \( s_i\), for each datapoint \(y_i \). This indicator (or label) simply points to one of the components, \( s_i \in \{1,2 \}\), to specify which component likely generated that particular data point.

Our posterior now has many parameters

\[ p(\theta_1, \theta_2, w_1, w_2, s_1,s_2,...,s_N | y_N ), \]

but in the course of Gibbs sampling, we will marginalize out the \(s_i\),

\[ \sum_{s_1,...,s_N} p(\theta_1, \theta_2 , w_1, w_2, s_1,s_2,...,s_N | y_N ) = p(\theta_1, \theta_2, w_1, w_2 | y_N ) , \]

thus resulting in our original posterior. Now that we have the latent indicators, let \(A_j \) denote the set of all \(i\) where \(s_i = j \): the set \(A_j \) groups the datapoints together based on their current latent label \(s_i \). We can then arrive at a collapsed Gibbs sampling scheme where we really only need to think about sampling from the three types of parameters,

\[ p(\theta_1, \theta_2|s_1,...,s_N,w_1,w_2,y_N) = ? ,\\ p(w_1, w_2|\theta_1, \theta_2,s_1,...,s_N, y_N) = ?, \\ p(s_i|\theta_1, \theta_2,w_1,w_2, y_N) = ? \]

Skipping lots of steps here, but proceeding in the typical fashion, we'll use a Gamma prior on the \( \theta_k\), \(p(\theta_k)=Ga(a,b) \), and a flat prior on the \(s_i \) so we arrive at these conditional posteriors,

\[ \theta_k|s_1,...,s_N,y_N \sim Ga \left( a + |A_k| , b + \sum_{i \in A_k} y_i \right ) ,\\ w_1, w_2|... \sim Dir(|A_1|, |A_2| ), \\ s_i|... \sim Cat(p_1,p_2). \]
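As a sketch of one of the skipped steps, the Gamma conditional for \(\theta_k\) follows from multiplying the prior by the Exponential likelihood of the datapoints currently labeled \(k\). This assumes the rate parameterization of the Gamma distribution, \(Ga(a,b) \propto \theta^{a-1} \exp(-b\theta)\),

\[ p(\theta_k | s_1,...,s_N, y_N) \propto \theta_k^{a-1} \exp(-b \theta_k) \prod_{i \in A_k} \theta_k \exp(-\theta_k y_i) = \theta_k^{a + |A_k| - 1} \exp \left( -\theta_k \left( b + \sum_{i \in A_k} y_i \right) \right), \]

which is the kernel of a Gamma distribution with shape \(a + |A_k|\) and rate \(b + \sum_{i \in A_k} y_i\).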

The latent indicators \(s_i\) are sampled from a Categorical distribution with a parameter vector whose elements are the likelihood of observing \(y_i\) under each component. For example, the first component would yield a likelihood of \(L_1 = \theta_1 w_1 \exp(- \theta_1 y_i) \) and the second would yield \(L_2 = \theta_2 w_2 \exp(- \theta_2 y_i) \). The likelihoods are then normalized to produce the elements of the probability vector, \(p_1 = \frac{L_1}{L_1 + L_2} \) and \(p_2 = \frac{L_2}{L_1 + L_2} \).
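Here is a small sketch of that normalization for a single datapoint. The function name label_probabilities is hypothetical, and working in log space is my own choice (not something the derivation requires) so that very small likelihoods do not underflow to zero before dividing.

```python
import math

def label_probabilities(y, thetas, weights):
    """Return p_k proportional to w_k * theta_k * exp(-theta_k * y), normalized to sum to 1."""
    # log of each unnormalized likelihood L_k
    logs = [math.log(w) + math.log(t) - t * y for t, w in zip(thetas, weights)]
    # subtract the largest log-likelihood before exponentiating, for numerical stability
    top = max(logs)
    unnorm = [math.exp(v - top) for v in logs]
    total = sum(unnorm)
    return [u / total for u in unnorm]

# illustrative values only
p = label_probabilities(0.01, thetas=[500.0, 5.0], weights=[0.5, 0.5])
```

The subtraction of the maximum cancels in the ratio, so the result matches \(L_k / (L_1 + L_2)\) whenever the direct computation is representable.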

The \(w_k\) and the \(s_i \) are closely related through the conjugacy of the Dirichlet distribution and the Categorical distribution, so for simplicity we will ignore the \(w_k \). Thus, we arrive at this simplified two-part collapsed Gibbs sampler,

\[ \theta_k \sim Ga \left( a + |A_k| , b + \sum_{i \in A_k} y_i \right) ,\\ s_i \sim Cat(p_1,p_2). \]

All this is to say that a Gibbs sampler for our latent variable model can proceed in two basic parts - we resample the parameters of each component (\( \theta_k \)) given the current data labels, then we resample the data labels (\(s_i\)) given the current model parameters. This pattern will hold true in much more complex latent variable models as well. So our whole Gibbs sampler might be implemented like this pseudocode,

  for loop:
    resample the parameters given the labels
    resample the labels given the parameters
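Before distributing anything, the two-step loop can be sketched on a single machine with only the standard library. Everything named here is an assumption for illustration: the function name gibbs_exponential_mixture, the prior values a=1.0 and b=0.01, the equal component weights, and the simulated data.

```python
import math
import random

def gibbs_exponential_mixture(ys, k=2, iterations=50, a=1.0, b=0.01, seed=0):
    """Two-step Gibbs sampler for a k-component Exponential mixture (equal weights assumed)."""
    rng = random.Random(seed)
    labels = [rng.randrange(k) for _ in ys]  # random initial labels
    thetas = [1.0] * k
    for _ in range(iterations):
        # resample the parameters given the labels:
        # sufficient statistics are the count and the sum within each group
        counts = [0] * k
        sums = [0.0] * k
        for s, y in zip(labels, ys):
            counts[s] += 1
            sums[s] += y
        # gammavariate takes (shape, scale), so the rate b + sum is inverted
        thetas = [rng.gammavariate(a + counts[j], 1.0 / (b + sums[j])) for j in range(k)]
        # resample the labels given the parameters
        new_labels = []
        for y in ys:
            logs = [math.log(t) - t * y for t in thetas]
            top = max(logs)
            probs = [math.exp(v - top) for v in logs]
            new_labels.append(rng.choices(range(k), weights=probs)[0])
        labels = new_labels
    return thetas, labels

# simulated data: 500 draws with rate 500 and 300 draws with rate 5
data_rng = random.Random(1)
ys = [data_rng.expovariate(500.0) for _ in range(500)]
ys += [data_rng.expovariate(5.0) for _ in range(300)]
thetas, labels = gibbs_exponential_mixture(ys)
```

The returned thetas are a single posterior draw from the final iteration, and the component order is arbitrary, so compare them to the true rates after sorting.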

More interestingly, if we look at the requirements for each piece of that for loop, we'll notice that each component can be calculated fairly easily in a distributed computing context.

In particular, let's look first at the resampling for the data labels \(s_i \). \[ s_i \sim Cat(p_1,p_2), \] where \(p_1\) is the relative probability of observing \(y_i\) as drawn from component 1 and \(p_2\) is the relative probability of observing \(y_i\) as drawn from component 2. Calculating these can proceed simply from the current estimate of \(\theta_1\) and \(\theta_2\) - the only information we need are the values of \(\theta_1, \theta_2, y_i\). Then we can resample \(s_i\) for that \(y_i\). This makes clear that the resampling of \(s_i\) is independent of all other datapoints besides \(y_i\). When viewed in this way, we see that this step of the Gibbs sampler is embarrassingly parallel and can be computed with a simple map() operation.

Next, let's look at the resampling for the model parameters \(\theta_k \), \[ \theta_k \sim Ga \left( a + |A_k| , b + \sum_{i \in A_k} y_i \right) . \] Aside from some normalizing constants, notice that this posterior depends on a sum over a subset of the data, \( \sum_{i \in A_k} y_i \). So for each \(k\), to estimate \(\theta_k \) we only need information from the \(y_i\) where \(i \in A_k\). And from that set of \(y_i\), we can easily extract the sufficient statistics required for the computation - just their sum and the size of the set in this case. In the Spark mindset, you might notice two things. First, breaking up the data into groups according to their label \(s_i\) can be done with a groupBy(). Second, within each group, the sufficient statistics can be computed with a simple sum or count, both of which form a commutative monoid, and thus can easily be parallelized and implemented with a reduce(). Most succinctly, both of these actions can be achieved with Spark's reduceByKey(). Then, given the sufficient statistics for each group, we can easily draw a new posterior sample and update our estimates of the parameters \( w_k\) and \(\theta_k\) for each group.
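The monoid point can be illustrated without a cluster. The sketch below uses a hypothetical local helper, reduce_by_key, as a stand-in for Spark's reduceByKey() (it is not the Spark implementation), and checks that combining per-partition results gives the same (sum, count) pairs as a single pass.

```python
def to_stat(record):
    """Turn a (label, y) record into (label, (sum, count)) for a single point."""
    label, y = record
    return label, (y, 1)

def combine(a, b):
    """Associative, commutative merge of two (sum, count) pairs."""
    return (a[0] + b[0], a[1] + b[1])

def reduce_by_key(pairs, fn):
    """Local stand-in for reduceByKey: fold values that share a key with fn."""
    out = {}
    for key, value in pairs:
        out[key] = fn(out[key], value) if key in out else value
    return out

records = [(0, 0.5), (1, 2.0), (0, 1.5), (1, 3.0), (0, 1.0)]

# one pass over all of the data
stats = reduce_by_key(map(to_stat, records), combine)

# the same data split into two "partitions", reduced separately, then merged
left = reduce_by_key(map(to_stat, records[:2]), combine)
right = reduce_by_key(map(to_stat, records[2:]), combine)
merged = reduce_by_key(list(left.items()) + list(right.items()), combine)
```

Because combine does not care how the records were grouped or ordered, each partition can be reduced where it lives and only the small (sum, count) pairs need to be shuffled.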

Now we have everything we need to do proper Gibbs sampling with Spark's map(), reduceByKey() and so on, and can employ this kind of Bayesian latent variable model on large data sets with billions of observations. This particular example is indeed simple, but demonstrates the basic structure of latent variable models that allows the parameter sampling to be done in a distributed way. To finish off this Exponential Mixture Model example, below is a PySpark application that computes just what I've described.

# Distributed Gibbs Sampling with Apache Spark
# Keegan Hines 07/23/15
# dependencies 
import numpy as np
import numpy.random as r
from collections import namedtuple
import itertools
from pyspark import SparkConf, SparkContext

# closures
def componentSample(datapoint, params):
    def likelihood(datapoint, theta):
        return theta * np.exp(-(datapoint * theta))
    # compute likelihood of the datapoint under each component
    L = np.array([likelihood(datapoint, theta) for theta in params])
    # normalize so the likelihoods form a probability vector
    L = L / L.sum()
    # draw a new label for this datapoint from the Categorical distribution
    draw = r.choice(range(len(params)), 1, p=L)
    return (int(draw[0]), datapoint)

def ExponentialMixture(theRDD, k, iterations=10):
    # create a simple holding class for the parameters of the exponential distributions
    Parameters = namedtuple("Parameters", ["theta_" + str(i) for i in range(k)])

    # main Gibbs sampling loop
    for i in range(iterations):
        # estimate the thetas given the latent labels
        # do the reduceByKey to get the sufficient statistics: (label, (sum, count))
        reduction = (theRDD.map(lambda kv: (kv[0], (kv[1], 1)))
                           .reduceByKey(lambda a, b: (a[0] + b[0], a[1] + b[1]))
                           .collect())
        # from the sufficient statistics, draw new posterior samples
        # (sorted by label so theta_j lines up with label j; shape = count, scale = 1/sum)
        # note: this assumes every component currently has at least one datapoint
        thetas = [r.gamma(x[1][1], 1 / x[1][0]) for x in sorted(reduction)]
        # broadcast the new parameter estimates
        params = sc.broadcast(Parameters(*thetas))
        # estimate the latent labels given the thetas
        theRDD = theRDD.map(lambda s: componentSample(s[1], params.value))
    # return the parameter draw from the final iteration
    return Parameters(*thetas)

if __name__ == "__main__":
    conf = SparkConf().setAppName("Distributed Gibbs Sampler")
    sc = SparkContext(conf=conf)
    # fake data: 500 draws with rate 500 and 300 draws with rate 5
    fakeData = np.concatenate([r.exponential(500**-1, 500), r.exponential(5**-1, 300)])
    # pair each datapoint with a random initial label: (label, datapoint)
    anRDD = sc.parallelize(list(zip(r.choice(2, len(fakeData)), fakeData)))
    result = ExponentialMixture(anRDD, 2, iterations=20)
    print("################")
    print(result)
    print("################")