如何简单易懂地理解变分推断(variational inference)?

  正在学,把网上优质文章整理了一下。

  我们经常利用贝叶斯公式求posterior distribution P ( Z X ) P(Z | X)

P ( Z X ) = p ( X , Z ) z p ( X , Z = z ) d z P(Z | X)=\frac{p(X, Z)}{\int_{z} p(X, Z=z) d z}

  但posterior distribution P ( Z X ) P(Z | X) 求解用贝叶斯的方法是比较困难的,因为我们需要去计算 z p ( X = x , Z = z ) d z \int_{z} p(X=x, Z=z) d z ,而 Z Z 通常会是一个高维的随机变量,这个积分计算起来就非常困难。在贝叶斯统计中,所有的对于未知量的推断(inference)问题可以看做是对后验概率(posterior)的计算。因此提出了Variational Inference来计算posterior distribution

  那Variational Inference怎么做的呢?其核心思想主要包括两步:

  1. 假设一个分布 q ( z ; λ ) q(z ; \lambda) (这个分布是我们搞得定的,搞不定的就没意义了)
  2. 通过改变分布的参数 λ \lambda ,使 q ( z ; λ ) q(z ; \lambda) 靠近 p ( z x ) p(z|x)

  总结称一句话就是,为真实的后验分布引入了一个参数话的模型。 即:用一个简单的分布 q ( z ; λ ) q(z ; \lambda) 拟合复杂的分布 p ( z x ) p(z|x)

  这种策略将计算 p ( z x ) p(z|x) 的问题转化成优化问题了

λ = arg min λ divergence ( p ( z x ) , q ( z ; λ ) ) \lambda^{*}=\arg \min _{\lambda} \operatorname{divergence}(p(z | x), q(z ; \lambda))

  收敛后,就可以用 q ( z ; λ ) q(z;\lambda) 来代替 p ( z x ) p(z|x) 了。

KL散度

  而用一个分布去拟合另一个分布通常需要衡量这两个分布之间的相似性,通常采用KL散度,当然还有其他的一些方法,像JS散度这种。下面介绍KL散度

  机器学习中比较重要的一个概念—相对熵(relative entropy)。相对熵又被称为KL散度(Kullback–Leibler divergence) 或信息散度 (information divergence),是两个概率分布间差异的非对称性度量 。在信息论中,相对熵等价于两个概率分布的信息熵的差值,若其中一个概率分布为真实分布,另一个为理论(拟合)分布,则此时相对熵等于交叉熵与真实分布的信息熵之差,表示使用理论分布拟合真实分布时产生的信息损耗 。其公式如下:

D K L ( p q ) = i = 1 N [ p ( x i ) log p ( x i ) p ( x i ) log q ( x i ) ] D_{K L}(p \| q)=\sum_{i=1}^{N}\left[p\left(x_{i}\right) \log p\left(x_{i}\right)-p\left(x_{i}\right) \log q\left(x_{i}\right)\right]

  合并之后表示为:

D K L ( p q ) = i = 1 N p ( x i ) log ( p ( x i ) q ( x i ) ) D_{K L}(p \| q)=\sum_{i=1}^{N} p\left(x_{i}\right) \log \left(\frac{p\left(x_{i}\right)}{q\left(x_{i}\right)}\right)

  假设理论拟合出来的事件概率分布 q ( x ) q(x) 跟真实的分布 p ( x ) p(x) 一模一样,即 p ( x ) = q ( x ) p(x)=q(x) ,那么 p ( x i ) log q ( x i ) p\left(x_{i}\right) \log q\left(x_{i}\right) 就等于真实事件的信息熵,这一点显而易见。在理论拟合出来的事件概率分布跟真实的一模一样的时候,相对熵等于0。而拟合出来不太一样的时候,相对熵大于0。其证明如下:

i = 1 N p ( x i ) log q ( x i ) p ( x i ) i = 1 N p ( x i ) ( q ( x i ) p ( x i ) 1 ) = i = 1 N [ p ( x i ) q ( x i ) ] = 0 \sum_{i=1}^{N} p\left(x_{i}\right) \log \frac{q\left(x_{i}\right)}{p\left(x_{i}\right)} \leq \sum_{i=1}^{N} p\left(x_{i}\right)\left(\frac{q\left(x_{i}\right)}{p\left(x_{i}\right)}-1\right)=\sum_{i=1}^{N}\left[p\left(x_{i}\right)-q\left(x_{i}\right)\right]=0

  其中第一个不等式是由 l n ( x ) x 1 ln(x) \leq x -1 推导出来的,只在 p ( x i ) = q ( x i ) p(x_{i})=q(x_{i}) 时取到等号。

  这个性质很关键,因为它正是深度学习梯度下降法需要的特性。假设神经网络拟合完美了,那么它就不再梯度下降,而不完美则因为它大于0而继续下降。

  但它有不好的地方,就是它是不对称的。也就是用 P P 来拟合 Q Q 和用 Q Q 来拟合 P P 的相对熵居然不一样,而他们的距离是一样的。这也就是说,相对熵的大小并不跟距离有一一对应的关系。

求解

  中间引入了KL散度,但是我们本文的目的还是来求这个变分推理,不要走偏了。下面涉及一些公式等价转换:

log P ( x ) = log P ( x , z ) log P ( z x ) = log P ( x , z ) Q ( z ; λ ) log P ( z x ) Q ( z ; λ ) \begin{aligned} \log P(x) &=\log P(x, z)-\log P(z | x) \\ &=\log \frac{P(x, z)}{Q(z ; \lambda)}-\log \frac{P(z | x)}{Q(z ; \lambda)} \end{aligned}

  等式两边同时对 Q ( z ) Q(z) 求期望,得:

E q ( z ; λ ) log P ( x ) = E q ( z ; λ ) log P ( x , z ) E q ( z ; λ ) log P ( z x ) log P ( x ) = E q ( z ; λ ) log p ( x , z ) q ( z ; λ ) E q ( z ; λ ) log p ( z x ) q ( z ; λ ) = K L ( q ( z ; λ ) p ( z x ) ) + E q ( z ; λ ) log p ( x , z ) q ( z ; λ ) log P ( x ) = K L ( q ( z ; λ ) p ( z x ) ) + E q ( z ; λ ) log p ( x , z ) q ( z ; λ ) \begin{aligned} \mathbb{E}_{q(z ; \lambda)} \log P(x) &=\mathbb{E}_{q(z ; \lambda)} \log P(x, z)-\mathbb{E}_{q(z ; \lambda)} \log P(z | x) \\ \log P(x) &=\mathbb{E}_{q(z ; \lambda)} \log \frac{p(x, z)}{q(z ; \lambda)}-\mathbb{E}_{q(z ; \lambda)} \log \frac{p(z | x)}{q(z ; \lambda)} \\ &=K L(q(z ; \lambda) \| p(z | x))+\mathbb{E}_{q(z ; \lambda)} \log \frac{p(x, z)}{q(z ; \lambda)} \\ \log P(x) &=K L(q(z ; \lambda) \| p(z | x))+\mathbb{E}_{q(z ; \lambda)} \log \frac{p(x, z)}{q(z ; \lambda)} \end{aligned}

  到这里我们需要回顾一下我们的问题,我们的目标是使 q ( z ; λ ) q(z;\lambda) 靠近 p ( z x ) p(z|x) ,就是求解:

min λ K L ( q ( z ; λ ) p ( z x ) ) \min_\lambda KL(q(z;\lambda)||p(z|x))

  而由于 K L ( q ( z ; λ ) p ( z x ) ) KL(q(z;\lambda)||p(z|x)) 中包含 p ( z x ) p(z|x) ,这项非常难求。借助上述公示的推导变形得到的结论:

log P ( x ) = K L ( q ( z ; λ ) p ( z x ) ) + E q ( z ; λ ) log p ( x , z ) q ( z ; λ ) \log P(x) =K L(q(z ; \lambda) \| p(z | x))+\mathbb{E}_{q(z ; \lambda)} \log \frac{p(x, z)}{q(z ; \lambda)}

  将 λ \lambda 看做变量时, log P ( x ) \text{log}P(x) 为常量,所以, min λ K L ( q ( z ; λ ) p ( z x ) ) \min_\lambda KL(q(z;\lambda)||p(z|x)) 等价于 :

max λ E q ( z ; λ ) log p ( x , z ) q ( z ; λ ) \max_\lambda \mathbb E_{q(z;\lambda)}\text{log}\frac{p(x,z)}{q(z;\lambda)}

  现在,variational inference的目标变成:

max λ E q ( z ; λ ) [ log p ( x , z ) log q ( z ; λ ) ] \max_{\lambda}\mathbb E_{q(z;\lambda)}[\text{log}p(x,z)-\text{log}q(z;\lambda)]

   E q ( z ; λ ) [ log p ( x , z ) log q ( z ; λ ) ] \mathbb E_{q(z;\lambda)}[\text{log}p(x,z)-\text{log}q(z;\lambda)] 称为Evidence Lower Bound(ELBO) p ( x ) p(x) 一般被称之为evidence,又因为 K L ( q p ) > = 0 KL(q||p)>=0 , 所以 p ( x ) > = E q ( z ; λ ) [ log p ( x , z ) log q ( z ; λ ) ] p(x)>=E_{q(z;\lambda)}[\text{log}p(x,z)-\text{log}q(z;\lambda)] , 这就是为什么被称为ELBO

ELBO

  ELBO公式表达为:

E q ( z ; λ ) [ log p ( x , z ) log q ( z ; λ ) ] \mathbb E_{q(z;\lambda)}[\text{log}p(x,z)-\text{log}q(z;\lambda)]

  原公式可表示为:

log P ( x ) = K L ( q ( z ; λ ) p ( z x ) ) + E q ( z ; λ ) log p ( x , z ) q ( z ; λ ) \log P(x) =K L(q(z ; \lambda) \| p(z | x))+\mathbb{E}_{q(z ; \lambda)} \log \frac{p(x, z)}{q(z ; \lambda)}

  引入ELBO表示为:

log P ( x ) = K L ( q ( z ; λ ) p ( z x ) ) + E L B O \log P(x) =K L(q(z ; \lambda) \| p(z | x))+ELBO

  实际上EM算法(Expectation-Maximization)就是利用了这一特征,它分为交替进行的两步:E step假设模型参数不变, q ( z ) = p ( z x ) q(z)=p(z|x) ,计算对数似然率,在M step再做ELBO相对于模型参数的优化。与变分法比较,EM算法假设了当模型参数固定时, p ( z x ) p(z|x) 是易计算的形式,而变分法并无这一限制,对于条件概率难于计算的情况,变分法仍然有效。

  那如何来求解上述公式呢?下面介绍平均场(mean-field)、蒙特卡洛、和黑盒变分推断 (Black Box Variational Inference) 的方法。

平均场变分族(mean-field variational family)

  之前我们说我们选择一族合适的近似概率分布 q ( Z ; λ ) q(Z;\lambda) ,那么实际问题中,我们可以选择什么形式的 q ( Z ; λ ) q(Z;\lambda) 呢?

  一个简单而有效的变分族为平均场变分族(mean-field variational family)。它假设了隐藏变量间是相互独立的:

q ( Z ; λ ) = k = 1 K q k ( Z k ; λ k ) q(Z;\lambda) = \prod_{k=1}^{K}q_k(Z_k;\lambda_k)

  这个假设看起来似乎比较强,但实际应用范围还是比较广泛,我们可以将其延展为将有实际相互关联的隐藏变量分组,而化为各组联合分布的乘积形式即可。

  利用ELBO和平均场假设,我们就可以利用coordinate ascent variational inference(简称CAVI)方法来处理:

  • 利用条件概率分布的链式法则有

p ( z 1 : m , x 1 : n ) = p ( x 1 : n ) j = 1 m p ( z j z 1 : ( j 1 ) , x 1 : n ) p\left(z_{1: m}, x_{1: n}\right)=p\left(x_{1: n}\right) \prod_{j=1}^{m} p\left(z_{j} | z_{1:(j-1)}, x_{1: n}\right)

  • 变分分布的期望为

E [ log q ( z 1 : m ) ] = j = 1 m E j [ log q ( z j ) ] E\left[\log q\left(z_{1: m}\right)\right]=\sum_{j=1}^{m} E_{j}\left[\log q\left(z_{j}\right)\right]

  将其代入ELBO的定义得到:

E L B O = logp ( x 1 : n ) + j = 1 m E [ log p ( z j z 1 : ( j 1 ) , x 1 : n ) ] E j [ log q ( z j ) ] E L B O=\operatorname{logp}\left(x_{1: n}\right)+\sum_{j=1}^{m} E\left[\log p\left(z_{j} | z_{1:(j-1)}, x_{1: n}\right)\right]-E_{j}\left[\log q\left(z_{j}\right)\right]

  将其对 z k z_{k} 求导并令导数为零有:

d E L B O d q ( z k ) = E k [ log p ( z k z k , x ) ] log q ( z k ) 1 = 0 \frac{d E L B O}{d q\left(z_{k}\right)}=E_{-k}\left[\log p\left(z_{k} | z_{-k}, x\right)\right]-\log q\left(z_{k}\right)-1=0

  由此得到coordinate ascent 的更新法则为:

q ( z k ) exp E k [ log p ( z k , z k , x ) ] q^{*}\left(z_{k}\right) \propto \exp E_{-k}\left[\log p\left(z_{k}, z_{-k}, x\right)\right]

  我们可以利用这一法则不断的固定其他的 z z 的坐标来更新当前的坐标对应的 z z 值,这与Gibbs Sampling过程类似,不过Gibbs Sampling是不断的从条件概率中采样,而CAVI算法中是不断的用如下形式更新:

q ( z k ) exp E [ log ( conditional ) ] q^{*}\left(z_{k}\right) \propto \exp E[\log (\text {conditional})]

  其完整算法如下所示:

CAVI算法流程

MCMC

  MCMC方法是利用马尔科夫链取样来近似后验概率,变分法是利用优化结果来近似后验概率,那么我们什么时候用MCMC,什么时候用变分法呢?

  首先,MCMC相较于变分法计算上消耗更大,但是它可以保证取得与目标分布相同的样本,而变分法没有这个保证:它只能寻找到近似于目标分布一个密度分布,但同时变分法计算上更快,由于我们将其转化为了优化问题,所以可以利用诸如随机优化(stochastic optimization)或分布优化(distributed optimization)等方法快速的得到结果。所以当数据量较小时,我们可以用MCMC方法消耗更多的计算力但得到更精确的样本。当数据量较大时,我们用变分法处理比较合适。

  另一方面,后验概率的分布形式也影响着我们的选择。比如对于有多个峰值的混合模型,MCMC可能只注重其中的一个峰而不能很好的描述其他峰值,而变分法对于此类问题即使样本量较小也可能优于MCMC方法。

黑盒变分推断(BBVI)

  ELBO公式表达为:

E q ( z ; λ ) [ log p ( x , z ) log q ( z ; λ ) ] \mathbb E_{q(z;\lambda)}[\text{log}p(x,z)-\text{log}q(z;\lambda)]

  对用参数 θ \theta 替代 λ \lambda ,并对其求导:

θ ELBO ( θ ) = θ E q ( log p ( x , z ) log q θ ( z ) ) \nabla_{\theta} \operatorname{ELBO}(\theta)=\nabla_{\theta} \mathbb{E}_{q}\left(\log p(x, z)-\log q_{\theta}(z)\right)

  直接展开计算如下:

θ q θ ( z ) ( log p ( x , z ) log q θ ( z ) ) d z = θ [ q θ ( z ) ( log p ( x , z ) log q θ ( z ) ) ] d z = θ ( q θ ( z ) log p ( x , z ) ) θ ( q θ ( z ) log q θ ( z ) ) d z = q θ ( z ) θ log p ( x , z ) q θ ( z ) θ log q θ ( z ) q θ ( z ) θ d z \begin{aligned} & \frac{\partial}{\partial \theta} \int q_{\theta}(z)\left(\log p(x, z)-\log q_{\theta}(z)\right) d z \\ =& \int \frac{\partial}{\partial \theta}\left[q_{\theta}(z)\left(\log p(x, z)-\log q_{\theta}(z)\right)\right] d z \\ =& \int \frac{\partial}{\partial \theta}\left(q_{\theta}(z) \log p(x, z)\right)-\frac{\partial}{\partial \theta}\left(q_{\theta}(z) \log q_{\theta}(z)\right) d z \\ =& \int \frac{\partial q_{\theta}(z)}{\partial \theta} \log p(x, z)-\frac{\partial q_{\theta}(z)}{\partial \theta} \log q_{\theta}(z)-\frac{\partial q_{\theta}(z)}{\partial \theta} d z \end{aligned}

  由于:

q θ ( z ) θ d z = θ q θ ( z ) d z = θ 1 = 0 \int \frac{\partial q_{\theta}(z)}{\partial \theta} d z=\frac{\partial}{\partial \theta} \int q_{\theta}(z) d z=\frac{\partial}{\partial \theta} 1=0

  因此:

θ ELBO ( θ ) = q θ ( z ) θ ( log p ( x , z ) log q θ ( z ) ) d z = q θ ( z ) log q θ ( z ) θ ( log p ( x , z ) log q θ ( z ) ) d z = q θ ( z ) θ log q θ ( z ) ( log p ( x , z ) log q θ ( z ) ) d z = E q [ θ log q θ ( z ) ( log p ( x , z ) log q θ ( z ) ) ] \begin{aligned} \nabla_{\theta} \operatorname{ELBO}(\theta) &=\int \frac{\partial q_{\theta}(z)}{\partial \theta}\left(\log p(x, z)-\log q_{\theta}(z)\right) d z \\ &=\int q_{\theta}(z) \frac{\partial \log q_{\theta}(z)}{\partial \theta}\left(\log p(x, z)-\log q_{\theta}(z)\right) d z \\ &=\int q_{\theta}(z) \nabla_{\theta} \log q_{\theta}(z)\left(\log p(x, z)-\log q_{\theta}(z)\right) d z \\ &=\mathbb{E}_{q}\left[\nabla_{\theta} \log q_{\theta}(z)\left(\log p(x, z)-\log q_{\theta}(z)\right)\right] \end{aligned}

相关文章
相关标签/搜索