張 佳,張麗紅
(山西大學(xué) 物理電子工程學(xué)院,山西 太原 030006)
文本生成圖像是將融合噪聲向量的文本描述信息輸入到生成對(duì)抗網(wǎng)絡(luò),生成相應(yīng)圖像。傳統(tǒng)方法主要基于基本生成對(duì)抗網(wǎng)絡(luò)(Generative Adversarial Networks, GAN)[1],由生成器和判別器兩部分組成,通常采用淺層卷積神經(jīng)網(wǎng)絡(luò)(Convolutional Neural Networks, CNN)構(gòu)成,雖然可以基本完成圖像生成任務(wù),但生成圖像質(zhì)量較低。
目前,國(guó)內(nèi)外對(duì)文本生成圖像的研究主要通過(guò)增加網(wǎng)絡(luò)深度和改進(jìn)生成器網(wǎng)絡(luò)提高圖像生成質(zhì)量。Scott Reed等[2]首次提出將循環(huán)神經(jīng)網(wǎng)絡(luò)(Recurrent Neural Network,RNN)作為文本編碼器,以CNN組成生成對(duì)抗網(wǎng)絡(luò),作為主干網(wǎng)絡(luò)的網(wǎng)絡(luò)模型,能夠基本實(shí)現(xiàn)文本生成圖像;Han Zhang等[3]提出StackGAN模型,其主干網(wǎng)絡(luò)是將生成對(duì)抗網(wǎng)絡(luò)基本模塊進(jìn)行兩次堆疊,使圖像生成過(guò)程分為兩階段進(jìn)行,第一階段生成低分辨率圖像,第二階段以第一階段為基礎(chǔ)生成高分辨率圖像。Dunlu Peng等[4]提出新的文本視覺(jué)特征融合方法,分別將詞語(yǔ)特征、語(yǔ)句特征與視覺(jué)特征進(jìn)行融合,通過(guò)3次堆疊生成對(duì)抗網(wǎng)絡(luò)模塊進(jìn)行圖像生成。
由于上述方法生成的圖像分辨率較低,分階段生成圖像訓(xùn)練過(guò)程繁瑣,計(jì)算量過(guò)大。因此,受文獻(xiàn)[3]中條件增強(qiáng)模塊和文獻(xiàn)[5]中卷積注意力機(jī)制的啟發(fā),本文采用注意力機(jī)制和條件增強(qiáng)模塊改進(jìn)生成器網(wǎng)絡(luò),在文本特征和視覺(jué)特征的融合過(guò)程中,加強(qiáng)生成圖像和給定文本描述之間的語(yǔ)義一致性,不需要進(jìn)行多次文本視覺(jué)特征融合,訓(xùn)練過(guò)程簡(jiǎn)潔,在相關(guān)數(shù)據(jù)集上得到優(yōu)良的實(shí)驗(yàn)結(jié)果。
基于條件增強(qiáng)和注意力機(jī)制的深度融合生成對(duì)抗網(wǎng)絡(luò)的整體網(wǎng)絡(luò)架構(gòu)如圖1 所示,整個(gè)網(wǎng)絡(luò)模型由文本處理網(wǎng)絡(luò)和生成對(duì)抗網(wǎng)絡(luò)組成。
圖1 基于條件增強(qiáng)和注意力機(jī)制的深度融合生成對(duì)抗網(wǎng)絡(luò)結(jié)構(gòu)
文本處理網(wǎng)絡(luò)由文本編碼器和條件增強(qiáng)模塊組成。文本編碼器采用雙向長(zhǎng)短期記憶網(wǎng)絡(luò)(Bidirectional Long Short-Term Memory,BiLSTM)[5]對(duì)文本進(jìn)行特征提取,條件增強(qiáng)模塊(Conditioning Augmentation,CA)[6]進(jìn)一步豐富文本語(yǔ)義信息。
生成對(duì)抗網(wǎng)絡(luò)由生成器G和判別器D組成。生成器由上采樣殘差塊、注意力機(jī)制和卷積層組成。該網(wǎng)絡(luò)有兩個(gè)輸入,一是文本處理網(wǎng)絡(luò)所得文本特征,二是服從高斯分布的隨機(jī)噪聲向量Z~N(0,1)。兩者在上采樣殘差塊中逐步融合得到高分辨率圖像特征,輸入注意力機(jī)制中對(duì)文本描述的關(guān)鍵信息進(jìn)行處理,再通過(guò)卷積層生成圖像。判別器由卷積層和下采樣殘差塊組成,將生成器所得生成圖像通過(guò)卷積層進(jìn)行特征提取,下采樣殘差塊對(duì)特征進(jìn)行下采樣并融合文本特征,判別生成圖像與真實(shí)圖像,結(jié)合MA-GP損失設(shè)計(jì)對(duì)抗損失函數(shù)進(jìn)行網(wǎng)絡(luò)評(píng)估,反饋更新生成器參數(shù)生成更高質(zhì)量的圖像。
2.1.1 BiLSTM網(wǎng)絡(luò)
文本處理網(wǎng)絡(luò)采用雙向長(zhǎng)短期記憶網(wǎng)絡(luò)作為文本編碼器,其目標(biāo)是從語(yǔ)句中學(xué)習(xí)文本特征表示,將語(yǔ)句中的單詞依次輸入BiLSTM網(wǎng)絡(luò)。該網(wǎng)絡(luò)由若干長(zhǎng)短期記憶模塊(Long short-term memory,LSTM)組成,該模塊能夠捕捉雙向語(yǔ)義信息,丟棄需要遺忘的信息并記憶新的信息,使有效信息得以傳遞。LSTM網(wǎng)絡(luò)結(jié)構(gòu)如圖2 所示,主要由遺忘門ft、記憶門it、輸出門ot組成。
圖2 長(zhǎng)短期記憶網(wǎng)絡(luò)模型結(jié)構(gòu)
遺忘門
ft=σ(wf×[ht-1,xt]+bf).
(1)
記憶門
it=σ(wi×[ht-1,xt]+bi).
(2)
記憶內(nèi)容更新單元
(3)
當(dāng)前時(shí)刻記憶單元
(4)
輸出門
ot=σ(wo×[ht-1,xt]+bo).
(5)
當(dāng)前記憶單元輸出
ht=ot×tanh(ct),
(6)
LSTM模塊數(shù)量由單詞數(shù)決定,如圖3 所示,以3個(gè)單詞為例,由3個(gè)前向LSTM模塊和3個(gè)后向LSTM模塊構(gòu)成BiLSTM模型。由LSTML得到前向隱向量特征{hL0,hL1,hL2},由LSTMR得到后向隱向量特征{hR0,hR1,hR2},兩者拼接為{[hL0,hR0],[hL1,hR1],[hL2,hR2]},即文本特征φt={h0,h1,h2}。
圖3 雙向長(zhǎng)短期記憶網(wǎng)絡(luò)模型結(jié)構(gòu)
2.1.2 條件增強(qiáng)模塊
圖4 條件增強(qiáng)模塊
(7)
式中:?表示按元素相乘;ε表示服從數(shù)學(xué)期望為μ;方差為σ2的正態(tài)分布,記為ε~N(μ,σ2)。取μ=0,σ=1時(shí),該分布為標(biāo)準(zhǔn)正態(tài)分布ε~N(0,1),并且此隨機(jī)噪聲維度為100×1×1。
生成對(duì)抗網(wǎng)絡(luò)由生成器和判別器組成,二者交替訓(xùn)練以相互競(jìng)爭(zhēng)。生成器不斷優(yōu)化生成判別器難以區(qū)分的圖像,盡可能再現(xiàn)真實(shí)數(shù)據(jù)分布,同時(shí)促使判別器不斷優(yōu)化以區(qū)分真實(shí)圖像和生成圖像??傮w而言,訓(xùn)練過(guò)程類似于二者交替進(jìn)行最小、最大博弈。
2.2.1 生成器G
生成器有兩個(gè)輸入。一是文本處理網(wǎng)絡(luò)所得文本特征,二是從與ε相同的標(biāo)準(zhǔn)正態(tài)分布中取樣,得到維度同為100×1×1的隨機(jī)噪聲向量Z~N(0,1),通過(guò)全連接層輸出尺寸為(ngf*8)×4×4的特征張量,其中ngf為生成器的特征數(shù)目,初始大小為64。生成器由上采樣殘差塊、卷積注意力模塊(Convolutional Block Attention Module,CBAM)和卷積層(卷積核大小為3×3)組成。
上采樣殘差塊設(shè)置為7層,在該模塊中逐步加深文本特征與視覺(jué)特征融合以得到高分辨率圖像特征。上采樣殘差塊由上采樣、殘差網(wǎng)絡(luò)和深度融合塊(Deep Fusion Blocks,DFBlocks)組成,模塊結(jié)構(gòu)如圖5(a)所示。上采樣采用緊鄰插值法,放大倍率設(shè)置為2,逐步將分辨率從4×4上采樣至256×256。使用殘差網(wǎng)絡(luò)可以緩解由于網(wǎng)絡(luò)層數(shù)加深而導(dǎo)致的梯度消失問(wèn)題。DFBlock的模型結(jié)構(gòu)如圖5(b)所示。該模型由兩個(gè)仿射層和ReLU層依次堆疊組成,文本特征作為條件作用于仿射層,使該模塊可以更好地發(fā)揮作用,有助于充分利用文本信息,實(shí)現(xiàn)更有效的特征融合。
圖5 上采樣殘差塊
圖6 仿射層原理圖
(8)
(9)
式中:AFF表示仿射變換;xi表示視覺(jué)特征通道數(shù)。
將經(jīng)過(guò)上采樣殘差塊的高分辨率圖像特征輸入到CBAM注意力模塊,該模塊是一種有效而簡(jiǎn)單的前饋卷積神經(jīng)網(wǎng)絡(luò)注意力機(jī)制。給定輸入特征,通過(guò)卷積運(yùn)算可以使得該特征在通道和空間兩個(gè)維度上進(jìn)行特征細(xì)化。CBAM總體模型如圖7 所示,該模塊有兩個(gè)順序子模塊:通道注意力機(jī)制和空間注意力機(jī)制,通過(guò)順序連接方式將二者結(jié)合。實(shí)驗(yàn)結(jié)果表明,將通道注意力模塊置于空間注意力模塊之前效果更好。將輸入特征F∈RC×H×W輸入CBAM模塊,依次輸出一維通道注意特征F′∈RC×1×1和二維空間注意特征F″∈R1×H×W??傔^(guò)程如式(10)
圖7 CBAM模塊
F′=MC(F)?F,F″=MS(F′)?F′,
(10)
式中: ?表示按元素相乘;F′是經(jīng)過(guò)通道注意力機(jī)制MC的中間特征;F″是經(jīng)過(guò)空間注意力機(jī)制MS的細(xì)化特征。
圖8 通道注意力機(jī)制模塊
Mc(F)=σ(MLP(AvgPool(F))+
(11)
圖9 空間注意力機(jī)制模塊
MS(F″)=
σ(f7×7([AvgPool(F′);MaxPool(F′)]))=
(12)
式中:σ表示sigmoid激活函數(shù);f7×7表示濾波器大小為7×7的卷積運(yùn)算;[;]表示特征融合。
2.2.2 判別器D
判別器的輸入為生成器的生成圖像。判別器由下采樣殘差塊和二維卷積層(3×3)組成。下采樣殘差塊將下采樣嵌入到殘差網(wǎng)絡(luò)中。圖像通過(guò)卷積層(3×3)進(jìn)行特征提取,所得圖像特征分辨率為128×128,下采樣殘差塊設(shè)置為6層,在殘差網(wǎng)絡(luò)中通過(guò)一個(gè)二維的平均池化下采樣操作得到分辨率為4×4的圖像特征,在此基礎(chǔ)上融合文本特征,輸入判別器中得到對(duì)抗損失,并融合匹配感知零中心梯度損失(Matching-Aware zero-centered Gradient Penalty,MA-GP)對(duì)網(wǎng)絡(luò)進(jìn)行評(píng)估,判定文本圖像語(yǔ)義一致性,調(diào)整生成器參數(shù)得到更高質(zhì)量的圖像。
在二維數(shù)據(jù)空間中有4種數(shù)據(jù)對(duì):文本匹配的生成圖像、文本不匹配的生成圖像、文本匹配的真實(shí)圖像、文本不匹配的真實(shí)圖像。為從給定文本描述中生成文本匹配的真實(shí)圖像,判別器應(yīng)將該數(shù)據(jù)點(diǎn)放在損失函數(shù)最小點(diǎn),并將其他數(shù)據(jù)點(diǎn)置于高點(diǎn)。為更好地區(qū)分生成圖像和真實(shí)圖像,在原本損失函數(shù)的基礎(chǔ)上,引入MA-GP損失。該損失函數(shù)應(yīng)用的數(shù)據(jù)點(diǎn)為文本匹配的真實(shí)圖像,即
(13)
該損失是一種基于判別器的正則化策略。實(shí)驗(yàn)證明,在對(duì)抗損失的基礎(chǔ)上融合MA-GP損失,可以使得判別器提高圖像判別能力。與其他模型方法相比,MA-GP沒(méi)有引入額外網(wǎng)絡(luò)計(jì)算文本圖像的語(yǔ)義一致性,因此,不會(huì)增加文本生成圖像過(guò)程中的網(wǎng)絡(luò)復(fù)雜度和訓(xùn)練參數(shù)。
使用Adam[7]優(yōu)化網(wǎng)絡(luò),β1=0.0,β2=0.9,生成器的學(xué)習(xí)率設(shè)置為0.000 1,判別器的學(xué)習(xí)率設(shè)置為0.000 4。對(duì)CUB birds 200數(shù)據(jù)集進(jìn)行500個(gè)輪次的訓(xùn)練,批量處理數(shù)量為24,對(duì)MSCOCO數(shù)據(jù)集進(jìn)行300個(gè)輪次的訓(xùn)練,批量處理數(shù)量為12。
該網(wǎng)絡(luò)損失函數(shù)為
LD=L文本匹配真實(shí)圖像+L文本匹配生成圖像+L文本不匹配損失+
(14)
實(shí)驗(yàn)在MSCOCO和CUB birds 200兩個(gè)數(shù)據(jù)集進(jìn)行。CUB birds 200數(shù)據(jù)集包含200種鳥(niǎo)類,11 788幅圖像,每幅圖像有10種語(yǔ)言描述,將150種鳥(niǎo)類圖像共8 855幅圖像作為訓(xùn)練集,50種鳥(niǎo)類圖像共2 933幅圖像作為測(cè)試集。MSCOCO 2017數(shù)據(jù)集包含91種類別,每幅圖像有5種語(yǔ)言描述,其中117 266 幅圖像作為訓(xùn)練集,40 670 幅圖像作為測(cè)試集。
采用圖像分?jǐn)?shù)(Image Score,IS)和弗雷歇初始距離(Fréchet Inception Distance,F(xiàn)ID)評(píng)估文本生成圖像模型的性能。
IS計(jì)算生成圖像的清晰度,更高的IS意味著生成圖像質(zhì)量更高。
IS=exp(Ex~PgDKL((y|x)‖p(y))),
(15)
式中:x表示從生成圖像數(shù)據(jù)分布中采樣的數(shù)據(jù);y是預(yù)訓(xùn)練網(wǎng)絡(luò)預(yù)測(cè)的圖像標(biāo)簽;IS計(jì)算條件分布p(y|x)和邊緣分布p(y)之間的KL散度。如果模型能夠生成多樣化和真實(shí)的圖像,那么兩個(gè)分布之間的KL差異將很大。
FID計(jì)算生成圖像的特征向量和真實(shí)圖像的特征向量之間的距離。該距離越近,表明模型的多樣性越好。
(16)
對(duì)基于條件增強(qiáng)和注意力機(jī)制的深度融合生成網(wǎng)絡(luò)進(jìn)行訓(xùn)練和測(cè)試。部分實(shí)驗(yàn)結(jié)果如圖10、表1~表3 所示。
圖10 中(a),(b)為本文網(wǎng)絡(luò)在CUB birds 200數(shù)據(jù)集的實(shí)驗(yàn)結(jié)果圖,圖(c),(d)為在MSCOCO數(shù)據(jù)集的實(shí)驗(yàn)結(jié)果圖,(e),(f),(g),(h)分別為基礎(chǔ)生成對(duì)抗網(wǎng)絡(luò)、stackGAN,DM-GAN,AttnGAN對(duì)于鳥(niǎo)類數(shù)據(jù)集的生成效果,可以觀察到本文網(wǎng)絡(luò)模型生成圖像細(xì)節(jié)效果最優(yōu)。
圖10 深度融合網(wǎng)絡(luò)運(yùn)行結(jié)果圖
表1 為其他方法與本文方法在鳥(niǎo)類數(shù)據(jù)集上的IS指標(biāo),通過(guò)對(duì)模型的改進(jìn),IS指標(biāo)提升了0.6,體現(xiàn)出本文模型的優(yōu)勢(shì)。
表1 本文方法與其他模型的評(píng)價(jià)指標(biāo)比較
表2 為其他方法與本文方法在鳥(niǎo)類數(shù)據(jù)集和COCO數(shù)據(jù)集的FID指標(biāo),通過(guò)對(duì)模型的改進(jìn),對(duì)于CUB birds 200數(shù)據(jù)集FID指標(biāo)提升了0.59,對(duì)于COCO數(shù)據(jù)集FID指標(biāo)提升了2.05,體現(xiàn)出本文模型的優(yōu)勢(shì)。
表2 本文方法與其他模型的評(píng)價(jià)指標(biāo)比較
表3 為MA-GP損失的消融研究結(jié)果對(duì)比。Baseline表示不引入該損失的網(wǎng)絡(luò)模型,Baseline+MA-GP表示引入該損失的網(wǎng)絡(luò)模型。實(shí)驗(yàn)結(jié)果表明,該損失函數(shù)的引入有助于提升圖像生成效果。
表3 MA-GP損失消融結(jié)果對(duì)比
本文網(wǎng)絡(luò)在CUB birds 200數(shù)據(jù)集上生成器和判別器的損失函數(shù)曲線如圖11、圖12 所示。其中生成器的損失函數(shù)值趨向于2,判別器的損失函數(shù)值趨向于1.1,網(wǎng)絡(luò)收斂。
圖11 生成器損失函數(shù)曲線
圖12 判別器損失函數(shù)曲線
本文設(shè)計(jì)了一種基于條件增強(qiáng)和注意力機(jī)制的深度融合生成對(duì)抗網(wǎng)絡(luò)用于文本生成圖像的方法。對(duì)于文本處理網(wǎng)絡(luò),通過(guò)BiLSTM網(wǎng)絡(luò)進(jìn)行文本特征提取,然后,使用條件增強(qiáng)模塊豐富文本語(yǔ)義信息。對(duì)于生成器,將所得文本特征融合噪聲向量輸入上采樣殘差塊得到高分辨率圖像特征。使用注意力機(jī)制對(duì)特征進(jìn)行調(diào)整,之后通過(guò)卷積層得到生成圖像。對(duì)于判別器,對(duì)生成圖像進(jìn)行特征提取,通過(guò)下采樣殘差塊降低特征分辨率,將對(duì)抗損失與MA-GP損失相結(jié)合,對(duì)模型進(jìn)行優(yōu)化。實(shí)驗(yàn)結(jié)果表明,該網(wǎng)絡(luò)模型的IS和FID指標(biāo)均優(yōu)于其他網(wǎng)絡(luò)模型。