当前位置:首页 >> 硬件技术 >> 【TensorFlow-windows】学习笔记六——变分自编码器,锋哲v6t

【TensorFlow-windows】学习笔记六——变分自编码器,锋哲v6t

cpugpu芯片开发光刻机 硬件技术 1
文件名:【TensorFlow-windows】学习笔记六——变分自编码器,锋哲v6t 【TensorFlow-windows】学习笔记六——变分自编码器

#前言

对理论没兴趣的直接看代码吧,理论一堆,而且还有点复杂,我自己的描述也不一定准确,但是代码就两三句话搞定了。

国际惯例,参考博文

论文:Tutorial on Variational Autoencoders

【干货】一文读懂什么是变分自编码器

CS598LAZ - Variational Autoencoders

MusicVAE: Creating a palette for musical scores with machine learning

【Learning Notes】变分自编码器(Variational Auto-Encoder,VAE)

花式解释AutoEncoder与VAE

#理论

##基础知识

似然函数(引自百度百科)

似然函数是关于统计模型中的参数的函数,表示模型参数的似然性。在给定输出xxx时,关于参数θ\thetaθ的似然函数L(θ∣x)L(\theta|x)L(θx)在数值上等于给定参数θ\thetaθ后变量XXX的概率: L(θ∣x)=P(X=x∣θ)L(\theta|x)=P(X=x|\theta) L(θx)=P(X=xθ) 有两个比较有趣的说法来区分概率与似然的关系,比如抛硬币的例子:

概率说法:对于“一枚正反对称的硬币上抛十次”这种事件,问硬币落地时十次都是正面向上的“概率”是多少似然说法:对于“一枚硬币上抛十次”,问这枚硬币正反面对称的“似然”程度是多少。 极大似然估计

(摘自西瓜书)两大学派:

频率主义学:参数是固定的,通过优化似然函数来确定参数贝叶斯学派:参数是变化的,且本身具有某种分布,先假设参数服从某个先验分布,然后基于观测到的数据来计算参数的后验分布

极大似然估计(Maximum Likelihood Estimation,MLE)源自频率主义学。

假设DDD是第ccc类样本的集合,比如所有的数字333的图片集合,假设它们是独立同分布的,则参数θ\thetaθ对于数据集DDD的似然是: P(D∣θ)=∏x∈DP(x∣θ)P(D|\theta)=\prod_{x\in D}P(x|\theta) P(Dθ)=xDP(xθ) 极大似然估计就是寻找一个θ\thetaθ使得样本xxx出现的概率最大

但是上面的连乘比较难算,这就出现了对数似然: L(θ)=log⁡P(D∣θ)=∑x∈Dlog⁡P(x∣D)L(\theta)=\log P(D|\theta)=\sum_{x\in D}\log P(x|D) L(θ)=logP(Dθ)=xDlogP(xD) 我们的目标就是求参数θ\thetaθ的极大似然估计θ^\hat{\theta}θ^ θ^=arg⁡max⁡θL(θ)\hat{\theta}=\arg \max_{\theta}L(\theta) θ^=argθmaxL(θ) 例子:在连续属性情况下,如果样本集合概率密度函数p(x∣c)∼N(μ,σ2)p(x|c)\sim N(\mu,\sigma^2)p(xc)N(μ,σ2),那么参数μ,σ2\mu,\sigma^2μ,σ2的极大似然估计就是 μ^=1∣D∣∑x∈Dxσ^2=1∣D∣∑x∈D(x−μ^)(x−μ^)T\begin{aligned} \hat{\mu}&=\frac{1}{|D|}\sum_{x\in D}x\\ \hat{\sigma}^2&=\frac{1}{|D|}\sum_{x \in D}(x-\hat{\mu})(x-\hat{\mu})^T \end{aligned} μ^σ^2=D1xDx=D1xD(xμ^)(xμ^)T 其实就是计算均值和方差了。这样想,这些样本就服从这个高斯分布,那么把高斯分布直接当做参数,一定能够大概率得到此类样本,也就是说用333的样本所服从的高斯分布作为模型参数一定能使333出现的概率P(x∣θ)P(x|\theta)P(xθ)最大。

###期望值最大化算法(EM)

这一部分简单说一下即可,详细的在我前面的博客HMM——前向后向算法中有介绍,主要有两步:

E步:求Q函数Q(θ,θ(i))Q(\theta,\theta^{(i)})Q(θ,θ(i)),这个θ(i)\theta^{(i)}θ(i)就是当前迭代次数iii对应的参数值,Q函数实际就是对数联合似然函数log⁡P(X,Z∣θ)\log P(X,Z|\theta)logP(X,Zθ)在分布P(Z∣X,θ(i))P(Z|X,\theta^{(i)})P(ZX,θ(i))下的期望 Q(θ,θ(i))=EZ∣X,θ(i)L(θ∣X,Z)=Ez[log⁡P(X,Z∣θ)∣X,θ(i)]=∑ZP(Z∣X,θ(i))log⁡P(X,Z∣θ)\begin{aligned} Q(\theta,\theta^{(i)})&=E_{Z|X,\theta^{(i)}}L(\theta|X,Z)\\ &=E_z\left[\log P(X,Z|\theta)|X,\theta^{(i)}\right]\\ &=\sum_Z P(Z|X,\theta^{(i)})\log P(X,Z|\theta) \end{aligned} Q(θ,θ(i))=EZX,θ(i)L(θX,Z)=Ez[logP(X,Zθ)X,θ(i)]=ZP(ZX,θ(i))logP(X,Zθ)

M步:求使得Q函数最大化的参数θ​\theta​θ,并将其作为下一步的θ(i)​\theta^{(i)}​θ(i) θ(i+1)=arg⁡max⁡θQ(θ,θ(i))\theta^{(i+1)}=\arg\max_\theta Q(\theta,\theta^{(i)}) θ(i+1)=argθmaxQ(θ,θ(i))

从西瓜书上再摘点主要内容过来:

有时候样本的一些属性可以观测到,而另一些属性观测不到,所以就定义未观测变量为隐变量,设XXX为可观测变量,ZZZ为隐变量,θ\thetaθ为模型参数,则可写出对数似然: L(θ∣X,Z)=ln⁡P(X,Z∣θ)L(\theta|X,Z)=\ln P(X,Z|\theta) L(θX,Z)=lnP(X,Zθ) 但是ZZZ又不知道,所以采用边缘化(marginal)方法消除它 L(θ∣X)=ln⁡P(X∣θ)=ln⁡∑ZP(X,Z∣θ)=∑i=1Nln⁡{∑ZP(xi,Z∣θ)}L(\theta|X)=\ln P(X|\theta)=\ln\sum_Z P(X,Z|\theta)=\sum_{i=1}^N\ln\left\{\sum_Z P(x_i,Z|\theta)\right\} L(θX)=lnP(Xθ)=lnZP(X,Zθ)=i=1Nln{ZP(xi,Zθ)} 使用EM算法求解参数的方法是:

基于θ(i)\theta^{(i)}θ(i)推断隐变量ZZZ的期望,记为ZtZ^tZt基于已观测变量XXXZtZ^tZt对参数θ\thetaθ做极大似然估计,求得θ(i+1)\theta^{(i+1)}θ(i+1)

【注】是不是感觉很像坐标下降法

变分推断

(摘自西瓜书)

变分推断是通过使用已知简单分布来逼近需推断的复杂分布,并通过限制近似分布的类型,从而得到一种局部最优,但具有确定解的近似后验分布。

继续看上面的EM算法的M步,我们得到了: θ(i+1)=arg⁡max⁡θQ(θ,θ(i))=arg⁡max⁡θ∑ZP(Z∣x,θ(i))ln⁡P(x,Z∣θ)\begin{aligned} \theta^{(i+1)}&=\arg\max_\theta Q(\theta,\theta^{(i)})\\ &=\arg\max_\theta \sum_ZP\left(Z|x,\theta^{(i)}\right)\ln P(x,Z|\theta) \end{aligned} θ(i+1)=argθmaxQ(θ,θ(i))=argθmaxZP(Zx,θ(i))lnP(x,Zθ) 还记得QQQ函数的意义吧,对数联合似然函数ln⁡P(X,Z∣θ)\ln P(X,Z|\theta)lnP(X,Zθ)在分布P(Z∣X,θ(i))P(Z|X,\theta^{(i)})P(ZX,θ(i))下的期望。当分布P(Z∣X,θ(i))P(Z|X,\theta^{(i)})P(ZX,θ(i))与变量ZZZ的真实后验分布相等的时候,QQQ函数就近似于对数似然函数,因而EM算法能够获得稳定的参数θ\thetaθ,且隐变量ZZZ的分布也能通过该参数获得。

但是通常情况下,P(Z∣X,θ(i))P(Z|X,\theta^{(i)})P(ZX,θ(i))只是隐变量ZZZ所服从的真实分布的近似,若用Q(Z)Q(Z)Q(Z)表示,则 ln⁡P(X)=L(Q)+KL(Q∣∣P)\ln P(X)=L(Q)+KL(Q||P) lnP(X)=L(Q)+KL(QP) 其中 L(Q)=∫Q(Z)ln⁡{P(X,Z)Q(Z)}dZKL(Q∣∣P)=−∫Q(Z)ln⁡P(Z∣X)Q(Z)dZL(Q)=\int Q(Z)\ln\left\{\frac{P(X,Z)}{Q(Z)}\right\}dZ\\ KL(Q||P)=-\int Q(Z)\ln \frac{P(Z|X)}{Q(Z)}dZ L(Q)=Q(Z)ln{Q(Z)P(X,Z)}dZKL(QP)=Q(Z)lnQ(Z)P(ZX)dZ 但是,这个ZZZ模型可能很复杂,导致E步的P(Z∣X,θ)P(Z|X,\theta)P(ZX,θ)比较难推断,这时候就借用变分推断了,假设ZZZ服从分布 Q(Z)=∏i=1MQi(Zi)Q(Z)=\prod_{i=1}^MQ_i(Z_i) Q(Z)=i=1MQi(Zi) 也就是说多变量ZZZ可拆解为一系列相互独立的多变量ZiZ_iZi,可以另QiQ_iQi是非常简单的分布。

【PS】浅尝辄止了,经过层层理论已经引出了变分自编码的主要思想,变分推断,使用简单分布逼近复杂分布,实际上,变分自编码所使用的简单分布就是高斯分布,用多个高斯分布来逼近隐变量,随后利用服从这些分布的隐变量重构我们想要的数据。

变分自编码

先看优化目标ELBO(Evidence Lower Bound): ELBO=log⁡p(x)−KL[q(z∣x)∣∣p(z∣x)]ELBO=\log p(x)-KL\left[q(z|x)||p(z|x)\right] ELBO=logp(x)KL[q(zx)p(zx)] 其中qqq是假设分布,ppp是真实分布,我们希望最大化第一项而最小化KL距离,所以整个规则就是最大化ELBOELBOELBO,但是这里面有个p(z∣x)p(z|x)p(zx)代表隐变量的真实分布,这个是无法求解的,所以需要简化:

简化结果是: log⁡p(x)−KL(q(z)∣∣p(z∣x))=Ez∼q[log⁡P(x∣z)]−KL(q(z)∣∣p(z))\log p(x)-KL(q(z)||p(z|x))=E_{z\sim q}\left[\log P(x|z)\right]-KL(q(z)||p(z)) logp(x)KL(q(z)p(zx))=Ezq[logP(xz)]KL(q(z)p(z)) 证明:

假设KL距离为KL(q(z)∣∣p(z∣x))=Ez∼q[log⁡q(z)−log⁡p(z∣x)]KL(q(z)||p(z|x))=E_{z\sim q}\left[\log q(z)-\log p(z|x)\right]KL(q(z)p(zx))=Ezq[logq(z)logp(zx)]

那么直接使用贝叶斯准则:

p(z∣x)=p(x∣z)p(z)p(x)p(z|x)=\frac{p(x|z)p(z)}{p(x)}p(zx)=p(x)p(xz)p(z)log⁡p(z∣x)=log⁡p(x∣z)+log⁡p(z)−log⁡p(x)\log p(z|x)=\log p(x|z)+\log p(z)-\log p(x)logp(zx)=logp(xz)+logp(z)logp(x)p(x)p(x)p(x)不依赖于z

可以得到: KL(q(z)∣∣p(z∣x))=Ez∼q(log⁡q(z)−log⁡p(x∣z)−log⁡p(z))+log⁡p(x)KL(q(z)||p(z|x))=E_{z\sim q}(\log q(z)-\log p(x|z)-\log p(z))+\log p(x) KL(q(z)p(zx))=Ezq(logq(z)logp(xz)logp(z))+logp(x) 其中Ez∼q(log⁡q(z)−log⁡p(z))=KL(q(z)∣∣p(z))E_{z\sim q}(\log q(z)-\log p(z))=KL(q(z)||p(z))Ezq(logq(z)logp(z))=KL(q(z)p(z))

所以就能继续简化那个KL(q(z)∣∣p(z∣x))KL(q(z)||p(z|x))KL(q(z)p(zx))了: log⁡p(x)−KL(q(z)∣p(z∣x))=Ez∼q[log⁡p(x∣z)]−KL(q(z)∣∣p(z))\log p(x)-KL(q(z)|p(z|x))=E_{z\sim q}[\log p(x|z)]-KL(q(z)||p(z)) logp(x)KL(q(z)p(zx))=Ezq[logp(xz)]KL(q(z)p(z)) 证毕

但是我们的优化目标是 ELBO=log⁡p(x)−KL[q(z∣x)∣∣p(z∣x)]ELBO=\log p(x)-KL\left[q(z|x)||p(z|x)\right] ELBO=logp(x)KL[q(zx)p(zx)] 发现一个是q(z)q(z)q(z)一个是q(z∣x)q(z|x)q(zx),怎么办呢?看论文第8页有这样一句话:

Note that X is fixed, and Q can be any distribution, not just a distribution which does a good job mapping X to the z’s that can produce X. Since we’re interested in inferring P(X), it makes sense to construct a Q which does depend on X, and in particular, one which makes D [Q(z)k|P(z|X)] small

翻译一下意思就是:XXX是固定的(因为它是样本集),Q也可是任意分布,并非仅是能够生成XXX的分布,因为我们想推断P(X)P(X)P(X),那么构建一个依赖于XXXQQQ分布是可行的,还能让KL(Q(z)∣∣P(z∣X))KL(Q(z)||P(z|X))KL(Q(z)P(zX))较小: log⁡p(x)−KL(q(z∣x)∣p(z∣x))=Ez∼q[log⁡p(x∣z)]−KL(q(z∣x)∣∣p(z))\log p(x)-KL(q(z|x)|p(z|x))=E_{z\sim q}[\log p(x|z)]-KL(q(z|x)||p(z)) logp(x)KL(q(zx)p(zx))=Ezq[logp(xz)]KL(q(zx)p(z)) 这个式子就是变分自编码的核心了

这样我们就知道了优化目标(等号右边的时候),我们看看变换后的式子为什么能够计算?

首先没了p(z∣x)p(z|x)p(zx),其次每一项都能计算,我们挨个来看:

如何计算q(z∣x)q(z|x)q(zx)? 我们可以使用神经网络逼近q(z∣x)q(z|x)q(zx),假设q(z∣x)q(z|x)q(zx)服从高斯分布N(μ,σ)N(\mu,\sigma)N(μ,σ)

神经网络的输出就是均值μ\muμ和方差σ\sigmaσ输入是图片,输出是分布

计算q(z∣x)q(z|x)q(zx)就是编码过程了

如果计算p(x∣z)?用一个神经网络去逼近p(x|z)? 用一个神经网络去逼近p(xz)p(x|z),假设神经网络输出是,假设神经网络输出是f(z)$ 假设p(x∣z)p(x|z)p(xz)服从另一种高斯分布

x=f(z)+ηx=f(z)+\etax=f(z)+η,其中η∼N(0,I)\eta\sim N(0,I)ηN(0,I)简化成l2l_2l2损失:∣∣X−f(z)∣∣2||X-f(z)||^2Xf(z)2

计算p(x∣z)p(x|z)p(xz)就是解码过程了

最终损失就是 L=∣∣X−f(z)∣∣2−λ⋅KL(q(z∣x)∣∣p(z))L=||X-f(z)||^2-\lambda\cdot KL(q(z|x)||p(z)) L=Xf(z)2λKL(q(zx)p(z)) 在这里,我们先不看这个最终损失的式子,我们去瞅瞅未经过l2l_2l2简化的的优化目标 ELBO=Ez∼q[log⁡p(x∣z)]−KL(q(z∣x)∣∣p(z))ELBO=E_{z\sim q}[\log p(x|z)]-KL(q(z|x)||p(z)) ELBO=Ezq[logp(xz)]KL(q(zx)p(z))

计算第二项的KL散度 我们经常选择q(z∣x)=N(z∣μ(x;θ),Σ(x;θ))q(z|x)=N(z|\mu(x;\theta),\Sigma(x;\theta))q(zx)=N(zμ(x;θ),Σ(x;θ)),这里面μ,Σ\mu,\Sigmaμ,Σ通常是任意确定的函数,且其参数θ\thetaθ能够从数据中学习。通常通过神经网络获取,并且Σ\SigmaΣ被限制为一个对角阵。这样选择的好处是便于计算,仅此而已,那么右边的KL(q(z∣x)∣∣p(z))KL(q(z|x)||p(z))KL(q(zx)p(z))就编程了两个多元高斯分布的KL距离,有闭式解为: D(N(μ0,Σ0)∣∣N(μ1,Σ1))=12(tr(Σ1−1Σ0)+(μ1−μ0)TΣ1−1(μ1−μ0)−k+log⁡(det⁡Σ1det⁡Σ0)D(N(\mu_0,\Sigma_0)||N(\mu_1,\Sigma_1))=\\ \frac{1}{2}\left(tr(\Sigma^{-1}_1\Sigma_0\right)+(\mu_1-\mu_0)^T\Sigma_1^{-1}(\mu_1-\mu_0)-k+\log(\frac{\det \Sigma_1}{\det \Sigma_0}) D(N(μ0,Σ0)N(μ1,Σ1))=21(tr(Σ11Σ0)+(μ1μ0)TΣ11(μ1μ0)k+log(detΣ0detΣ1) 其中kkk是分布的维数,而在变分推断中,经常又被简化成 D(N(μ(x),Σ(x))∣∣N(0,I))=12(tr(Σ(x))+(μ(x))T(μ(x))−k+log⁡det⁡(Σ(x))D(N(\mu(x),\Sigma(x))||N(0,I))=\\ \frac{1}{2}\left(tr(\Sigma(x)\right)+(\mu(x))^T(\mu(x))-k+\log\det(\Sigma(x)) D(N(μ(x),Σ(x))N(0,I))=21(tr(Σ(x))+(μ(x))T(μ(x))k+logdet(Σ(x))

计算第一项Ez∼qE_{z\sim q}Ezq 论文中说这一项的计算有点小技巧(tricky),本来是可以通过采样的方法估计Ez∼q(log⁡p(x∣z))E_{z\sim q}(\log p(x|z))Ezq(logp(xz)),但是只有将很多的zzz通过fff式子(解码部分)输出以后才能得到较好的估计结果,这个计算量很大,因此想到了随机梯度下降,我们可以拿一个样本zzz,将p(x∣z)p(x|z)p(xz)作为Ez∼q(log⁡p(x∣z))E_{z\sim q}(\log p(x|z))Ezq(logp(xz))的估计,所以式子又变成了: Ex∼D(log⁡p(x)−KL(q(z∣x)∣∣p(z∣x)))=Ex∼D[Ez∼q[log⁡p(x∣z)]]−KL(q(z∣x)∣∣p(z))E_{x\sim D}(\log p(x)-KL(q(z|x)||p(z|x)))=\\ E_{x\sim D}\left[E_{z\sim q}\left[\log p(x|z)\right]\right]-KL(q(z|x)||p(z)) ExD(logp(x)KL(q(zx)p(zx)))=ExD[Ezq[logp(xz)]]KL(q(zx)p(z)) 意思就是我们从样本集合DDD中取一个样本xxx来计算,所以对于单个,可以计算下式梯度: log⁡p(x∣z)−KL(q(z∣x)∣∣p(z))\log p(x|z)-KL(q(z|x)||p(z)) logp(xz)KL(q(zx)p(z)) 这样消除了Ez∼qE_{z\sim q}Ezq中对qqq的依赖。

论文中有个图很好

其实log⁡p(x∣z)−KL(q(z∣x)∣∣p(z))\log p(x|z)-KL(q(z|x)||p(z))logp(xz)KL(q(zx)p(z))刚好就是左图,主要就是反传的时候没法计算梯度,看左图红框部分,这一部分是随机采样,是无法计算梯度的,那么文中就说了一个技巧:重新参数化(reparameterization trick),给定了μ(x),Σ(x)\mu(x),\Sigma(x)μ(x),Σ(x)也就是Q(z∣x)Q(z|x)Q(zx)的均值和方差,我们先从N(0,I)N(0,I)N(0,I)中采样,然后计算z=μ(x)+Σ12∗ϵz=\mu(x)+\Sigma^{\frac{1}{2}}*\epsilonz=μ(x)+Σ21ϵ,所以我们又可以计算下式的梯度了: Ex∼D[Eϵ∼N(0,I)[log⁡p(x∣z=μ(x)+Σ1/2(x)∗ϵ)]−KL(q(z∣x)∣∣p(z))]E_{x\sim D}\left[E_{\epsilon\sim N(0,I)}\left[\log p(x|z=\mu(x)+\Sigma^{1/2}(x)*\epsilon)\right]-KL(q(z|x)||p(z)) \right] ExD[EϵN(0,I)[logp(xz=μ(x)+Σ1/2(x)ϵ)]KL(q(zx)p(z))] 这就完成了从左图到右图的转变。

代码实现-模型训练及保存

理论很复杂,但是我们看着右图就能实现,无需看理论,理论只是让我们知道为什么会有右图这种网络结构。按照标准流程来书写代码:

读数据初始化相关参数定义数据接收接口以便测试使用初始化权重和偏置定义基本模块:编码器、采样器、解码器构建模型定义预测函数、损失函数、优化器训练

整个代码很简单,我就只贴部分重点的:

初始化权重偏置 #初始化权重、偏置def glorot_init(shape):return tf.random_normal(shape=shape,stddev=1./tf.sqrt(shape[0]/2.0))#权重weights={'encoder_h1':tf.Variable(glorot_init([num_input,hidden_dim])),'z_mean':tf.Variable(glorot_init([hidden_dim,latent_dim])),'z_std':tf.Variable(glorot_init([hidden_dim,latent_dim])),'decoder_h1':tf.Variable(glorot_init([latent_dim,hidden_dim])),'decoder_out':tf.Variable(glorot_init([hidden_dim,num_input]))}#偏置biases={'encoder_b1':tf.Variable(glorot_init([hidden_dim])),'z_mean':tf.Variable(glorot_init([latent_dim])),'z_std':tf.Variable(glorot_init([latent_dim])),'decoder_b1':tf.Variable(glorot_init([hidden_dim])),'decoder_out':tf.Variable(glorot_init([num_input]))}

注意这里使用了另一种初始化方法,说是Xavier初始化方法,因为直接使用上一篇博客的方法

tf.Variable(tf.random_normal([num_input,num_hidden1])),

训练时候一直给我弹出loss:nan,我也是醉了,以后还是用之前学theano时候采用的fan_in-fan_out方法初始化权重算了。

定义基本模块

注意需要定义编码器、采样器、解码器

#定义编码器def encoder(x):encoder=tf.matmul(x,weights['encoder_h1'])+biases['encoder_b1']encoder=tf.nn.tanh(encoder)z_mean=tf.matmul(encoder,weights['z_mean'])+biases['z_mean']z_std=tf.matmul(encoder,weights['z_std'])+biases['z_std']return z_mean,z_std#定义采样器def sampler(z_mean,z_std):eps=tf.random_normal(tf.shape(z_std),dtype=tf.float32,mean=0,stddev=1.0,name='epsilon')z=z_mean+tf.exp(z_std/2)*epsreturn z#定义解码器def decoder(x):decoder=tf.matmul(x,weights['decoder_h1'])+biases['decoder_b1']decoder=tf.nn.tanh(decoder)decoder=tf.matmul(decoder,weights['decoder_out'])+biases['decoder_out']decoder=tf.nn.sigmoid(decoder)return decoder 构建模型 #构建模型[z_mean,z_std]=encoder(X)#计算均值方差sample_latent=sampler(z_mean,z_std)#采样隐空间decoder_op=decoder(sample_latent)#重构输出 预测函数和损失 #预测函数y_pred=decoder_opy_true=Xtf.add_to_collection('recon',y_pred)#定义损失函数和优化器def vae_loss(x_reconstructed,x_true,z_mean,zstd):#重构损失encode_decode_loss=x_true*tf.log(1e-10+x_reconstructed)\+(1-x_true)*tf.log(1e-10+1-x_reconstructed)encode_decode_loss=-tf.reduce_sum(encode_decode_loss,1)#KL损失kl_div_loss=1+z_std-tf.square(z_mean)-tf.exp(z_std)kl_div_loss=-0.5*tf.reduce_sum(kl_div_loss,1)return tf.reduce_mean(encode_decode_loss+kl_div_loss)loss_op=vae_loss(decoder_op,y_true,z_mean,z_std)optimizer=tf.train.RMSPropOptimizer(learning_rate=learning_rate)train_op=optimizer.minimize(loss_op)

注意这里损失的第一项类似于交叉熵损失:y×log⁡(y^)+(1−y)×(1−log⁡y^)y\times \log(\hat{y})+(1-y)\times(1-\log \hat y)y×log(y^)+(1y)×(1logy^)

关于交叉熵损失和均方差损失的区别,可以看我前面的博客:损失函数梯度对比-均方差和交叉熵

训练和保存模型 #参数初始化init=tf.global_variables_initializer()input_image,input_label=read_images('./mnist/train_labels.txt',batch_size)#训练和保存模型saver=tf.train.Saver()with tf.Session() as sess:sess.run(init)coord=tf.train.Coordinator()tf.train.start_queue_runners(sess=sess,coord=coord)for step in range(1,num_steps):batch_x,batch_y=sess.run([input_image,tf.one_hot(input_label,1,0)])sess.run(train_op,feed_dict={X:batch_x})if step%disp_step==0 or step==1:loss=sess.run(loss_op,feed_dict={X:batch_x})print('step '+str(step)+' ,loss '+'{:.4f}'.format(loss))coord.request_stop()coord.join()print('optimization finished')saver.save(sess,'./VAE_mnist_model/VAE_mnist')

常规的保存方法,没什么说的,训练日志:

step 1 ,loss 616.3002step 1000 ,loss 169.6044step 2000 ,loss 163.5006step 3000 ,loss 166.1648step 4000 ,loss 161.2366step 5000 ,loss 155.1714step 6000 ,loss 153.2840step 7000 ,loss 161.5571step 8000 ,loss 152.0021step 9000 ,loss 159.5550step 10000 ,loss 154.6315step 11000 ,loss 153.8298step 12000 ,loss 141.5825step 13000 ,loss 149.7792step 14000 ,loss 150.9575step 15000 ,loss 151.2249step 16000 ,loss 159.3878step 17000 ,loss 148.7136step 18000 ,loss 148.5801step 19000 ,loss 150.6678step 20000 ,loss 146.3471step 21000 ,loss 156.4142step 22000 ,loss 148.7607step 23000 ,loss 145.4101step 24000 ,loss 153.3523step 25000 ,loss 157.8997step 26000 ,loss 136.9668step 27000 ,loss 155.7835step 28000 ,loss 137.7291step 29000 ,loss 153.1723optimization finished

【很尴尬的事情】偷偷说一句,上面保存错东东了,别打我,但是也不必重新训练,看接下来的蛇皮操作。

#代码实现-模型加载及测试

老样子,先载入模型

sess=tf.Session()new_saver=tf.train.import_meta_graph('./VAE_mnist_model/VAE_mnist.meta')new_saver.restore(sess,'./VAE_mnist_model/VAE_mnist')

获取计算图:

graph=tf.get_default_graph()

看看保存了啥

print (graph.get_all_collection_keys())#['queue_runners', 'recon', 'summaries', 'train_op', 'trainable_variables', 'variables']

准备调用recon函数重构数据。

等等,“重构”?搞错了,这里应该是依据噪声来生成数据的,不是输入一个数据然后重构它,这是AE的做法,我们在VAE中应该称之为生成了,然而不幸的是,我们保存的recon函数接收的是图片输入,无法指定decoder部分所需的分布参数,也就是均值方差,怎么办?手动选择性载入模型,这篇博客有介绍怎么在测试阶段定义网络权重和载入训练好的权重,但是我好像没成功,懒得试了,按照我自己的想法来做。

眼尖的童鞋会发现,我们之前一直只关注recon函数了,忽视了其它keys,很快发现最后两个trainable_variables和variables貌似与我们想要的模型参数有关,我们来输出一下这两个东东里面都保存了啥:

第一个trainable_variables

for i in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):print(i)

输出

<tf.Variable 'Variable:0' shape=(784, 512) dtype=float32_ref><tf.Variable 'Variable_1:0' shape=(512, 2) dtype=float32_ref><tf.Variable 'Variable_2:0' shape=(512, 2) dtype=float32_ref><tf.Variable 'Variable_3:0' shape=(2, 512) dtype=float32_ref><tf.Variable 'Variable_4:0' shape=(512, 784) dtype=float32_ref><tf.Variable 'Variable_5:0' shape=(512,) dtype=float32_ref><tf.Variable 'Variable_6:0' shape=(2,) dtype=float32_ref><tf.Variable 'Variable_7:0' shape=(2,) dtype=float32_ref><tf.Variable 'Variable_8:0' shape=(512,) dtype=float32_ref><tf.Variable 'Variable_9:0' shape=(784,) dtype=float32_ref>

第二个:GLOBAL_VARIABLES

for i in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):print(i) <tf.Variable 'Variable:0' shape=(784, 512) dtype=float32_ref><tf.Variable 'Variable_1:0' shape=(512, 2) dtype=float32_ref><tf.Variable 'Variable_2:0' shape=(512, 2) dtype=float32_ref><tf.Variable 'Variable_3:0' shape=(2, 512) dtype=float32_ref><tf.Variable 'Variable_4:0' shape=(512, 784) dtype=float32_ref><tf.Variable 'Variable_5:0' shape=(512,) dtype=float32_ref><tf.Variable 'Variable_6:0' shape=(2,) dtype=float32_ref><tf.Variable 'Variable_7:0' shape=(2,) dtype=float32_ref><tf.Variable 'Variable_8:0' shape=(512,) dtype=float32_ref><tf.Variable 'Variable_9:0' shape=(784,) dtype=float32_ref><tf.Variable 'Variable/RMSProp:0' shape=(784, 512) dtype=float32_ref><tf.Variable 'Variable/RMSProp_1:0' shape=(784, 512) dtype=float32_ref><tf.Variable 'Variable_1/RMSProp:0' shape=(512, 2) dtype=float32_ref><tf.Variable 'Variable_1/RMSProp_1:0' shape=(512, 2) dtype=float32_ref><tf.Variable 'Variable_2/RMSProp:0' shape=(512, 2) dtype=float32_ref><tf.Variable 'Variable_2/RMSProp_1:0' shape=(512, 2) dtype=float32_ref><tf.Variable 'Variable_3/RMSProp:0' shape=(2, 512) dtype=float32_ref><tf.Variable 'Variable_3/RMSProp_1:0' shape=(2, 512) dtype=float32_ref><tf.Variable 'Variable_4/RMSProp:0' shape=(512, 784) dtype=float32_ref><tf.Variable 'Variable_4/RMSProp_1:0' shape=(512, 784) dtype=float32_ref><tf.Variable 'Variable_5/RMSProp:0' shape=(512,) dtype=float32_ref><tf.Variable 'Variable_5/RMSProp_1:0' shape=(512,) dtype=float32_ref><tf.Variable 'Variable_6/RMSProp:0' shape=(2,) dtype=float32_ref><tf.Variable 'Variable_6/RMSProp_1:0' shape=(2,) dtype=float32_ref><tf.Variable 'Variable_7/RMSProp:0' shape=(2,) dtype=float32_ref><tf.Variable 'Variable_7/RMSProp_1:0' shape=(2,) dtype=float32_ref><tf.Variable 'Variable_8/RMSProp:0' shape=(512,) dtype=float32_ref><tf.Variable 'Variable_8/RMSProp_1:0' shape=(512,) dtype=float32_ref><tf.Variable 'Variable_9/RMSProp:0' shape=(784,) dtype=float32_ref><tf.Variable 'Variable_9/RMSProp_1:0' shape=(784,) dtype=float32_ref>

很容易发现我们只需要从可训练的参数集中获取权重,还记得之前说过的么,我们啥都往sess.run中丢试试,看看能不能取出来值:

a=sess.run(graph.get_collection('trainable_variables'))for i in a:print(i.shape) (784, 512)(512, 2)(512, 2)(2, 512)(512, 784)(512,)(2,)(2,)(512,)(784,)

可以看出参数可以从a中通过索引取出来了。

接下来简单,重新定义一下模型的decoder计算:

latent_dim=2noise_input=tf.placeholder(tf.float32,shape=[None,latent_dim])decoder=tf.matmul(noise_input,a[3])+a[8]decoder=tf.nn.tanh(decoder)decoder=tf.matmul(decoder,a[4])+a[9]decoder=tf.nn.sigmoid(decoder)

然后就可以尝试丢进去一个均值和方差去预测了:

generate=sess.run(decoder,feed_dict={noise_input:[[3,3]]})

可视化

generate=generate*255.0gen_img=generate.reshape(28,28)plt.imshow(gen_img)plt.show()

效果还不错,再测试几个,丢[0,0][0,0][0,0]试试:

[0,5][0,5][0,5]试试:

好了,不玩了,随机性太大了,我都不知道啥噪声能输出啥数字,只有输出的时候才知道。

后记

好玩是好玩,但是我不知道哪个数字对应哪种噪声输入,还是有点郁闷,下一篇我们就去看看有名的搞基网GAN

本文训练代码:链接:https://pan.baidu.com/s/19QSNfT7fgWrU68CV7lXSZA 密码:6c7l

本文测试代码:链接:https://pan.baidu.com/s/1CAxPGnmCTg-OT8TqnXQdVw 密码:tmep

协助本站SEO优化一下,谢谢!
关键词不能为空
同类推荐
«    2025年12月    »
1234567
891011121314
15161718192021
22232425262728
293031
控制面板
您好,欢迎到访网站!
  查看权限
网站分类
搜索
最新留言
文章归档
网站收藏
友情链接