徐 哲,耿 杰,蔣 雯,張 卓,曾慶捷
(西北工業(yè)大學電子信息學院,西安710072)
圖像分類作為計算機視覺領(lǐng)域最基礎(chǔ)的任務之一,主要通過提取原始圖像的特征并根據(jù)特征學習進行分類[1]。傳統(tǒng)的特征提取方法主要是對圖像的顏色、紋理、局部特征等圖像表層特征進行處理實現(xiàn)的,例如尺度不變特征變換法[2],方向梯度法[3]以及局部二值法[4]等。但是這些特征都是人工設(shè)計的特征,很大程度上靠人類對識別目標的先驗知識進行設(shè)計,具有一定的局限性。隨著大數(shù)據(jù)時代的到來,基于深度學習的圖像分類方法具有對大量復雜數(shù)據(jù)進行處理和表征的能力,能夠有效學習目標的特征信息,從而提高圖像分類的精度[5-8]。
深度學習以大數(shù)據(jù)驅(qū)動的方式進行訓練,對標簽數(shù)據(jù)依賴性較強,而在現(xiàn)實應用中往往難以獲取大量的標簽數(shù)據(jù)。當樣本數(shù)量不足時,深度卷積網(wǎng)絡(luò)模型容易過擬合,導致分類性能較差。生成對抗網(wǎng)絡(luò)(Generative Adversarial Networks,GAN)[9]具有強大的數(shù)據(jù)生成能力,采用博弈對抗的方式,既訓練正常的樣本,也能對抗學習達到納什均衡,從而完成網(wǎng)絡(luò)的訓練。這樣GAN 在訓練時既能夠生成樣本,又能夠提高特征提取能力,可以用來解決小樣本條件下網(wǎng)絡(luò)過擬合的問題。但GAN 網(wǎng)絡(luò)還存在穩(wěn)定性差和依賴標簽數(shù)據(jù)的問題,不能直接應用于分類任務中。
針對GAN 存在的問題,有不少學者從網(wǎng)絡(luò)框架和理論模型兩個角度對GAN 進行了改進。從網(wǎng)絡(luò)框架角度,Radford 等[10]提出了深度卷積生成對抗網(wǎng)絡(luò)(Deep Convolutional GAN,DCGAN),將卷積神經(jīng)網(wǎng)絡(luò)應用到生成對抗網(wǎng)絡(luò)中,提高了GAN 訓練的穩(wěn)定性;Shaham 等[11]提出了單圖像生成對抗網(wǎng)絡(luò)(SinGAN),運用一個多尺度金字塔結(jié)構(gòu)的全卷積網(wǎng)絡(luò),能夠?qū)W習到不同尺度的圖像塊分布;Karnewar 等[12]提出了多尺度梯度生成對抗網(wǎng)絡(luò)(Multi-Scale Gradients GAN,MSG-GAN),通過從判別器到生成器的梯度流向多個尺度來解決訓練不穩(wěn)定的問題。從理論模型角度,Arjovsky 等[13]提出Wasserstein 生成對抗網(wǎng)絡(luò)(Wasserstein GAN,WGAN),使用Earth-Mover 距離代替JS(Jensen-Shannon)散度來計算生成樣本分布與真實樣本分布之間的距離,緩解了GAN 訓練不穩(wěn)定和梯度消失的問題;Goodfellow 等[14]提出了自注意力生成對抗網(wǎng)絡(luò)(Self-Attention Generative Adversarial Networks,SAGAN),通過引入自注意力機制來增大深度卷積網(wǎng)絡(luò)的感受野,從而更好地獲取圖像的全局信息。
為進一步提高圖像分類的準確率,解決GAN 訓練穩(wěn)定性差的問題,本文提出一種聯(lián)合訓練生成對抗網(wǎng)絡(luò)(Co-Training Generative Adversarial Networks,CT-GAN)的半監(jiān)督分類方法,設(shè)計兩個判別器進行聯(lián)合訓練,以消除單個判別器存在的分布誤差問題,同時利用大量無標簽數(shù)據(jù)和少量標簽數(shù)據(jù)進行半監(jiān)督學習,設(shè)計新的監(jiān)督損失和無監(jiān)督損失以優(yōu)化網(wǎng)絡(luò)模型,能夠?qū)W習到泛化能力較強、性能更好的模型,在一定程度上減小網(wǎng)絡(luò)對標簽數(shù)據(jù)的依賴,提高網(wǎng)絡(luò)的分類準確率。
生成對抗網(wǎng)絡(luò)是由Goodfellow 等[9]在2014年提出的無監(jiān)督生成模型,由一個生成器(Generator)和一個判別器(Discriminator)構(gòu)成。生成器依據(jù)樣本的數(shù)據(jù)分布來生成盡可能逼真的偽數(shù)據(jù),判別器用于判別輸入數(shù)據(jù)是真實數(shù)據(jù)還是生成器生成的偽數(shù)據(jù),生成器和判別器經(jīng)過博弈對抗達到納什均衡,此時生成的數(shù)據(jù)能夠擬合真實樣本的數(shù)據(jù)分布。GAN 的網(wǎng)絡(luò)結(jié)構(gòu)如圖1所示。
圖1 GAN 的網(wǎng)絡(luò)結(jié)構(gòu)Fig.1 Structure of GAN
生成器G和判別器D通常可以由卷積神經(jīng)網(wǎng)絡(luò)或者函數(shù)表示,G輸入隨機噪聲z用于生成偽數(shù)據(jù)G(z),D對輸入的真實數(shù)據(jù)x和偽數(shù)據(jù)G(z)判別真?zhèn)?,輸出其屬于真實樣本的概率。生成器G和判別器D通過損失函數(shù)相互博弈對抗進行訓練,其優(yōu)化過程是極大極小博弈的過程,目標函數(shù)為:
其中:x表示真實數(shù)據(jù),Pdata為x的數(shù)據(jù)分布,z表示服從標準正態(tài)分布的隨機噪聲,Pz為z的數(shù)據(jù)分布,G(z)表示生成器生成的偽數(shù)據(jù),D(·)表示判別器判別輸入樣本來自真實樣本的概率。對于判別器D而言,其希望判別的準確率越高,即希望D(x) 越接近1,D(G(z)) 越接近0,此時V(D,G)取極大值。對于生成器G而言,生成的能力越強,生成的數(shù)據(jù)分布越接近真實的數(shù)據(jù)分布,即希望D(G(z))越接近1 越好,此時V(D,G)取極小值。
當V(D,G)取到極大極小值時,生成對抗網(wǎng)絡(luò)達到納什均衡,此時生成的數(shù)據(jù)能夠擬合真實數(shù)據(jù)分布。
半監(jiān)督生成對抗網(wǎng)絡(luò)(Semi-Supervised Learning with Generative Adversarial Networks,SGAN)[15]是由Odena 提出的半監(jiān)督生成模型,其對原始GAN 網(wǎng)絡(luò)進行改進,引入半監(jiān)督學習,將標簽數(shù)據(jù)和無標簽數(shù)據(jù)共同輸入到判別器中進行訓練,并輸出K+1 維帶有類別信息的分類結(jié)果。SGAN 的網(wǎng)絡(luò)結(jié)構(gòu)如圖2 所示。
在SGAN 中,隨機噪聲z通過生成器生成的偽數(shù)據(jù)G(z)與K類標簽數(shù)據(jù)xl和無標簽數(shù)據(jù)xu共同輸入到判別器中進行訓練,在判別器的最后一層使用softmax 非線性分類器,最終輸出K+1維分類結(jié)果{l1,l2,…,lK+1},其中前K維輸出代表對應類的置信度,第K+1 維代表判定為“偽”的置信度。
圖2 SGAN 的網(wǎng)絡(luò)結(jié)構(gòu)Fig.2 Structure of SGAN
SGAN 采用了半監(jiān)督訓練方式,利用少量標簽數(shù)據(jù)和大量無標簽數(shù)據(jù)同時進行網(wǎng)絡(luò)訓練,從而提高半監(jiān)督分類的準確率。但有研究表明,SGAN 仍存在訓練不穩(wěn)定的問題[16],主要表現(xiàn)在訓練過程中可能出現(xiàn)梯度消失,導致網(wǎng)絡(luò)不收斂的問題。這一問題的原因是SGAN 在訓練過程中,單個判別器可能存在較大的分布誤差,從而造成梯度消失,判別器網(wǎng)絡(luò)不收斂。其中,分布誤差是指判別器對樣本類別預測時的概率分布誤差。一般情況下,判別器預測樣本類別的分布誤差都可以通過訓練迭代,逐漸消除其對網(wǎng)絡(luò)訓練的影響。但當出現(xiàn)較大的分布誤差時,判別器網(wǎng)絡(luò)會對樣本產(chǎn)生較大的誤判,造成梯度消失,使得判別器網(wǎng)絡(luò)不收斂,影響其分類性能。
為進一步提高圖像分類的準確率,解決SGAN 訓練不穩(wěn)定的問題,本文提出一種聯(lián)合訓練生成對抗網(wǎng)絡(luò)(Co-training GAN,CT-GAN)的半監(jiān)督分類方法,CT-GAN 的網(wǎng)絡(luò)結(jié)構(gòu)如圖3所示。
圖3 CT-GAN 的網(wǎng)絡(luò)結(jié)構(gòu)Fig.3 Structure of CT-GAN
在CT-GAN 中,采用了兩個判別器D1,D2進行聯(lián)合訓練,能夠有效提升網(wǎng)絡(luò)訓練穩(wěn)定性的同時提高圖像分類的準確率。判別器D1,D2共享同一個生成器G,同時兩個判別器的網(wǎng)絡(luò)結(jié)構(gòu)和初始參數(shù)設(shè)為相同。不同的是,將標簽數(shù)據(jù)和無標簽數(shù)據(jù)的順序打亂后分別輸入到判別器D1,D2中,即保證在訓練過程中兩個判別器是動態(tài)變化的。CT-GAN 采用兩個判別器進行聯(lián)合訓練,在訓練過程計算損失函數(shù)時,取兩個判別器損失的平均值,以消除單個判別器存在的分布誤差。同時在訓練過程中,兩個判別器不僅僅輸出K+1維分類結(jié)果,還設(shè)置了一個置信度閾值,如果生成數(shù)據(jù)的置信度高于該閾值,則賦予其偽標簽并加入到初始標簽數(shù)據(jù)集中,在訓練過程中就能夠擴充數(shù)據(jù)集,加快網(wǎng)絡(luò)收斂。
對于CT-GAN 的生成器G而言,G的能力越強,生成的圖像越接近真實圖像,即希望D(G(z))越接近1 越好,此時V(D,G)取極小值。由此可得到生成器的損失為:
同時為了讓生成器生成的數(shù)據(jù)分布更接近真實數(shù)據(jù)的統(tǒng)計分布,采用特征匹配[17]的方法對生成器的損失進行約束,定義特征匹配損失為:
其中:fj(·)表示判別器Dj在全連接層前的最后一層輸出的特征值。這樣,CT-GAN 生成器的總損失為:
對于CT-GAN 的判別器損失函數(shù),采取監(jiān)督損失和無監(jiān)督損失相結(jié)合的方式給出。對于判別器的監(jiān)督損失,需要加入標簽信息,因此以交叉熵的形式定義如下:
其中:yi表示第i維標簽,Dj(xi)表示判別器Dj判別標簽數(shù)據(jù)的標簽結(jié)果為第i維的概率。
對于無監(jiān)督損失,CT-GAN 需要判別無標簽數(shù)據(jù)的類別標簽??紤]到兩個判別器聯(lián)合訓練的情況,CT-GAN 判別器的無監(jiān)督損失定義如下:
其中:yi′表示判別器前一次迭代時判別無標簽數(shù)據(jù)的類別為第i維,Dj(xi)表示判別器Dj判別標簽數(shù)據(jù)的標簽結(jié)果為第i維的概率。
由式(5)和式(6)可得CT-GAN 判別器的總損失函數(shù)為:
由CT-GAN 生成器總損失函數(shù)和判別器總損失函數(shù)相加,可以得到CT-GAN 整體的損失函數(shù)如下:
CT-GAN 網(wǎng)絡(luò)的聯(lián)合訓練示意如圖4 所示。對于生成器生成的偽數(shù)據(jù)而言,判別器只需判斷其真?zhèn)?,不判別其類別,所以在此聯(lián)合訓練中暫不考慮偽數(shù)據(jù)的輸入,只考慮標簽數(shù)據(jù)和無標簽數(shù)據(jù)的輸入。
圖4 網(wǎng)絡(luò)聯(lián)合訓練示意圖Fig.4 Schematic of co-training method
在CT-GAN 中,為保證判別器D1,D2訓練時是動態(tài)變化的,首先將標簽數(shù)據(jù)和無標簽數(shù)據(jù)的順序打亂得到標簽樣本L1,L2和無標簽樣本U1,U2,分別輸入到判別器D1,D2中進行聯(lián)合訓練。以判別器D1為例,訓練過程按照以下步驟進行訓練:
(1)利用標簽樣本L1訓練判別器D1。標簽樣本L1輸入到判別器D1中,輸出L1分類結(jié)果,計算判別器的監(jiān)督損失以訓練判別器D1;
(2)利用判別器D1來預測無標簽樣本U1的標簽。判別器D1將前一次迭代得到的U1分類結(jié)果轉(zhuǎn)化為獨熱向量并認為是當前無標簽樣本U1的標簽,與當前得到的U1分類結(jié)果共同計算判別器的無監(jiān)督損失,從而不斷優(yōu)化預測無標簽樣本U1的標簽;
(3)利用無標簽樣本U1擴充標簽樣本L2。設(shè)置一個置信度閾值,對每次迭代得到的無標簽樣本U1的分類結(jié)果進行置信度判斷,如果大于該置信度閾值,則賦予其偽標簽并加入到對應的標簽樣本L2中繼續(xù)訓練,這樣在訓練過程中就可以擴充數(shù)據(jù)集,加快網(wǎng)絡(luò)收斂。
CT-GAN 模型通過判別器D1,D2的聯(lián)合訓練,一方面可以消除單個判別器存在的分布誤差,提高判別器訓練的穩(wěn)定性;另一方面,利用無標簽數(shù)據(jù)在訓練時擴充標簽數(shù)據(jù)集,能夠加快網(wǎng)絡(luò)收斂。因此,CT-GAN 模型能夠充分利用少量標簽數(shù)據(jù)的標簽信息和大量無標簽數(shù)據(jù)的分布信息來獲取整個樣本的特征分布,從而進一步提高網(wǎng)絡(luò)識別的精度。
本文實驗所使用的數(shù)據(jù)集為CIFAR-10 和SVHN 數(shù)據(jù)集,其中CIFAR10 數(shù)據(jù)集是一個包含10 個類別32×32 的彩色圖像數(shù)據(jù)集,共計60 000 張圖像,其中40 000 張作為訓練集,20 000張作為測試集,即每個類別有4 000 張訓練樣本和2 000 張測試樣本。SVHN 數(shù)據(jù)集是一個真實街景數(shù)字數(shù)據(jù)集,包含10 個類別32×32 的彩色圖像,共計99 289 張圖片,其中73 257 張作為訓練集,26 032 張作為測試集。
數(shù)據(jù)集中的每張圖像均包含一個類別信息,即均為標簽數(shù)據(jù)。為滿足本文實驗要求,對訓練集中圖像進行預處理,按一定的比例隨機去除部分標簽數(shù)據(jù)的類別信息,得到無標簽數(shù)據(jù),CIFAR-10 數(shù)據(jù)集和SVHN 數(shù)據(jù)集的預處理方案如表1 所示。其中CIFAR-10 數(shù)據(jù)集在各類別標簽數(shù)量分別為10,100,250,500,1 000 和2 000 時分別進行實驗,SVHN 數(shù)據(jù)集在各類別標簽數(shù)量分別為100 和1 000 時分別進行實驗,以研究不同數(shù)量標簽數(shù)據(jù)對網(wǎng)絡(luò)的影響。
表1 CIFAR10數(shù)據(jù)集各類別標簽數(shù)據(jù)的數(shù)量及所占比例Tab.1 Amount and proportion of labeled data in each category of the CIFAR10 data
本實驗采用一個RTX 2080Ti 的GPU 進行訓練,共訓練200 個epochs,且設(shè)置batch size 為128,即每個epoch 迭代313 次。設(shè)置初始學習率為0.000 2,并在迭代50 000 次和90 000 次時分別衰減為原來的1/10。采用Adam 優(yōu)化算法對網(wǎng)絡(luò)進行優(yōu)化,其中一階動量設(shè)為0.5,二階動量設(shè)為0.999。模型采用基于PyTorch 的深度學習框架實現(xiàn)。
在CIFAR-10 數(shù)據(jù)集上,CT-GAN 模型的生成器框架和判別器框架分別如圖5(a)和(b)所示。生成器的輸入為(128,100)的隨機噪聲,首先通過(100,8 192)的全連接層得到(128,8 192)的張量,經(jīng)過維度轉(zhuǎn)換得到維度為(128,128,8,8)的圖像,經(jīng)過兩次上采樣操作和三次步長為1的3×3 卷積核的卷積操作后得到維度為(128,3,32,32)的圖像,其中每次完成卷積操作后都使用批歸一化(Batch Normalization)操作并加入ReLU 激活函數(shù)。最后一層通過Tanh 激活函數(shù)輸出生成數(shù)據(jù)G(z)。
判別器的輸入為128 張大小為32×32 的3通道RGB 彩色圖像,其維度為(128,3,32,32),經(jīng)過四次步長為2 的3×3 卷積核的卷積操作,最終輸出圖像維度為(128,128,2,2),其中每次完成卷積操作后都加入LeakyReLU 激活函數(shù)和Dropout 操作以防止過擬合,而除了首次卷積不使用批歸一化外,其余卷積操作后都使用批歸一化。將卷積輸出圖像進行維度轉(zhuǎn)換得到維度為(128,512)的張量,通過(512,10)的全連接層和softmax 分類器得到分類結(jié)果,同時通過(512,1)的全連接層和Sigmoid 分類器判別真?zhèn)巍?/p>
圖5 CT-GAN 模型的生成器和判別器框架Fig.5 Structure of generator and discriminator in CTGAN
在CIFAR-10 數(shù)據(jù)集上的實驗首先按照4.1節(jié)中的數(shù)據(jù)集預處理方案,對數(shù)據(jù)集中的圖像按一定的比例去除部分標簽數(shù)據(jù)的標簽信息,構(gòu)成無標簽數(shù)據(jù)。在各類別標簽數(shù)量分別為10,100,250,500,1 000 和2 000 時分別進行實驗,以研究不同數(shù)量的標簽數(shù)據(jù)下CT-GAN 模型的性能。如圖6 和圖7 給出了在各類標簽數(shù)據(jù)數(shù)量分別為10,100,250,500,1 000 和2 000 時的CT-GAN 判別器和生成器損失變化曲線。
分析圖6 可知,在不同數(shù)量的標簽數(shù)據(jù)下,CT-GAN 的判別器損失在一定迭代次數(shù)后都達到了穩(wěn)定,標簽數(shù)據(jù)越少,損失趨于穩(wěn)定需要迭代的次數(shù)也越少。這是因為當標簽數(shù)量越少時,整個數(shù)據(jù)所含的類別信息也就越少,判別器可以學習的信息也相應減少,導致?lián)p失收斂速度加快。雖然不同標簽數(shù)量下?lián)p失收斂所需的迭代次數(shù)不同,但是其損失收斂值大致相同。這說明標簽數(shù)量對CT-GAN 的判別器的訓練影響很小,在一定程度上CT-GAN 模型能夠減小對標簽數(shù)據(jù)的依賴。分析圖7 可知,在不同數(shù)量的標簽數(shù)據(jù)下,CT-GAN 的生成器損失值逐漸減小并收斂到較低水平。
為了驗證本文方法的有效性,利用CIFAR-10 數(shù)據(jù)集對比了不同數(shù)量標簽數(shù)據(jù)下CT-GAN模型與相關(guān)的深度網(wǎng)絡(luò)模型的分類效果,其分類準確率如表2 所示。實驗在不同條件下分別進行了20 次重復實驗,計算平均精度和方差。
圖6 CT-GAN 判別器損失變化曲線Fig.6 Discriminator loss of CT-GAN
圖7 CT-GAN 生成器損失變化曲線Fig.7 Generator loss of CT-GAN
分析表2 的分類準確率可知,本文提出的CT-GAN 模型在CIFAR-10 數(shù)據(jù)集上的分類精度更高,在不同數(shù)量的標簽數(shù)據(jù)下的分類精度都有不同程度的提升,在標簽數(shù)據(jù)數(shù)量僅為10 時,就可以達到47.6%的分類精度,相比SGAN 模型提高了6.5%,這說明CT-GAN 模型能夠有效提升在標簽數(shù)據(jù)極少情況下的分類準確率,在一定程度上解決了GAN 網(wǎng)絡(luò)在小樣本條件下的過擬合問題。
表2 CIFAR-10 數(shù)據(jù)集上不同數(shù)量標簽樣本的半監(jiān)督分類精度Tab.2 Using different number of labeled data when semi-supervised training on CIFAR-10(%)
為更好地說明本文所提算法的有效性,在SVHN 數(shù)據(jù)集上進行實驗。按照4.1 節(jié)中的數(shù)據(jù)集預處理方案,在各類別標簽數(shù)量分別為100和1 000 時分別進行實驗,以研究不同數(shù)量的標簽數(shù)據(jù)下CT-GAN 模型的性能。表3 為SVHN數(shù)據(jù)集上不同數(shù)量標簽樣本的半監(jiān)督分類精度。實驗在不同條件下分別進行了20 次重復實驗,計算平均精度和方差。
分析表3 可知,本文所提方法CT-GAN 模型在SVHN 數(shù)據(jù)集上的分類性能優(yōu)異,在不同數(shù)量的標簽數(shù)據(jù)下的分類精度都達到了較高水平,特別是當標簽樣本數(shù)量僅為100 時,即少量標簽樣本的情況下,達到了77.7%,相較于其他算法分別高38.33%,21.40%,6.34%和13.85%,進一步說明CT-GAN 模型能夠在少量標簽樣本條件下有效提升網(wǎng)絡(luò)的分類精度。同時,CT-GAN 在不同標簽樣本數(shù)量下的分類精度誤差都在0.1%左右,相較于其他對比方法,本文所提模型訓練更加穩(wěn)定。
表3 SVHN 數(shù)據(jù)集上不同數(shù)量標簽樣本的半監(jiān)督分類精度Tab.3 Classification accuracy of different number of labeled data on SVHN(%)
本文提出了一種基于聯(lián)合訓練生成對抗網(wǎng)絡(luò)(CT-GAN)的半監(jiān)督分類方法,通過兩個判別器的聯(lián)合訓練來消除單個判別器存在的分布誤差,同時利用無標簽數(shù)據(jù)來擴充標簽數(shù)據(jù)集,可以有效提升半監(jiān)督分類的精度。實驗結(jié)果表明,在少量標簽樣本條件下,CT-GAN 模型能夠有效提升圖像分類精度,在一定程度上降低了GAN 網(wǎng)絡(luò)對標簽數(shù)據(jù)的依賴。此外,在不同數(shù)量的標簽數(shù)據(jù)下,CT-GAN 模型都取得了較好的分類效果,多種情況下的分類準確率相比其他方法都有一定程度提升,說明了本文模型的有效性。