張目飛,李 廷,蘇 鵬
(1 浪潮云信息技術股份公司 服務研發(fā)部,濟南 250000;2 山東浪潮新基建科技有限公司,濟南 250000)
隨著個人智能設備和圖像相關應用的普及,會產(chǎn)生大量的圖像數(shù)據(jù),如何高效、合理地對這些圖像數(shù)據(jù)進行合理的分類是一項技術難題。在過去的幾年中,深度神經(jīng)網(wǎng)絡(DNN)在計算機視覺和模式識別任務中,如:圖像分類、語義分割、對象檢測應用廣泛。卷積神經(jīng)網(wǎng)絡中的卷積層能夠捕獲圖像的局部特征,以獲得與輸入維度相似的空間表示,使用全連接層和softmax 分類層生成概率表示,來達到分類效果[1]。He 等[2]提出了深度殘差網(wǎng)絡ResNet34,引入了殘差結構,可以更好地學習殘差信息,并在后續(xù)層中使用這些殘差信息,提高了圖像分類的性能,為深度學習領域帶來了新的思路和方法。
許多基于深度神經(jīng)網(wǎng)絡,在網(wǎng)絡學習過程中添加注意力機制來獲得圖像中感興趣區(qū)域,通過選擇給定輸入的特征通道、區(qū)域來自動提取相關特征[3]。Woo 等[4]將注意力機 制模塊集成 到CNN中,提高網(wǎng)絡的特征表達能力,從而提高了圖像分類的準確率;Wang[5]提出了殘差注意網(wǎng)絡,殘差結構可以使網(wǎng)絡更好地學習圖像中的特征,通過添加注意力模塊來學習圖像中的局部區(qū)域特征;Park 等[6]提出了一種新的注意力機制,可以在空間和通道維度上同時進行特征加權,更加準確地捕捉到圖像中的重要信息;Xi 等[7]提出用殘差注意模塊進行特征提取,以增強分類任務中的關鍵特征,抑制無用的特征;Liang[8]提出將自下而上和自上而下的前饋注意力殘差模塊用于圖像分類。以上工作說明殘差結構和注意力機制都可以幫助模型更好地學習圖像特征,提高圖像分類的準確性。
隨著數(shù)據(jù)集規(guī)模的增大和類別的增多,訓練一個高準確率的分類模型變得越來越困難。傳統(tǒng)的數(shù)據(jù)增強方法對原始圖像進行幾何變換或者對圖像進行隨機擾動,雖然可以增加數(shù)據(jù)集的樣本量,提高分類模型的準確率,但是這些方法無法生成新的數(shù)據(jù)分布。而生成網(wǎng)絡是一種可以學習數(shù)據(jù)分布的生成模型,可以生成新的樣本,從而擴大數(shù)據(jù)集并且增加數(shù)據(jù)多樣性,從而可以提高分類模型的泛化性[9]。因此,本文提出一個深度殘差注意力生成網(wǎng)絡來生成圖像數(shù)據(jù),對數(shù)據(jù)進行必要的數(shù)據(jù)增強,利用ResNet34 網(wǎng)絡進行圖像分類。
本文提出了一個深度殘差注意力生成網(wǎng)絡模型用于圖像數(shù)據(jù)增強,主要結構包括生成器、判別器和殘差注意力模塊。生成器包含4 個反卷積層(DConv)和3 個殘差注意力模型(SPAM),殘差注意力模型能夠對圖像的重點區(qū)域進行特別關注,以生成高質量的圖像,在生成器的最后一層使用Tanh 函數(shù)將數(shù)據(jù)映射到[-1,1]的區(qū)間內(nèi);判別器包括4 個卷積層(Conv),能夠提取圖像細節(jié)特征。深度殘差注意力生成網(wǎng)絡模型結構如圖1 所示。
圖1 深度殘差注意力生成網(wǎng)絡模型結構Fig.1 Deep residual attention generation network model
生成網(wǎng)絡由生成器和判別器組成。生成器將隨機向量Z作為輸入,學習真實數(shù)據(jù)分布p(x)從而合成逼真的圖像;判別器區(qū)分生成的圖像與真實的圖像,其輸出表示從真實分布p(x)提取樣本y的概率。生成網(wǎng)絡的最終目標是讓生成器生成和真實圖像相同的數(shù)據(jù)分布,而判別器無法判定圖像為真實圖像還是生成圖像,達到一個納什平衡。在生成器和判別器相互博弈的過程中,生成網(wǎng)絡的目標函數(shù)定義為公式(1):
其中,p(x)表示真實數(shù)據(jù)分布;p(z)表示生成數(shù)據(jù)分布;D(x)表示判別器運算;G(z)表示生成器運算。
本文隨機選取Z=100 維的隨機數(shù)據(jù)作為生成器的輸入,經(jīng)過生成器生成圖像;判別器網(wǎng)絡的輸入為生成圖像和真實圖像,判別器網(wǎng)絡指導生成器合成圖像,鼓勵生成器捕捉更為精細的特征細節(jié),使得生成器生成的圖像和真實圖像難以區(qū)分。
殘差注意力模型使具有相似特征的區(qū)域相互增強,以突出全局視野中的感興趣區(qū)域,殘差注意力模型如圖2 所示。通過sigmoid 函數(shù)可以得到一個[0,1]的系數(shù),給每個通道或空間分配不同的權重,可以給每個特征圖分配不同的重要程度。
圖2 殘差注意力模型Fig.2 Residual attention model
本文設C × H × W為殘差注意力模型的輸入,C為特征圖的數(shù)量,H和W分別表示為圖像的高度和寬度;通過卷積和批量歸一化運算對輸入的特征進行處理,利用Sigmoid函數(shù)得到空間注意系數(shù)S;將輸入的特征圖和通過注意力模型得到的特征圖利用殘差結構進行融合,得到最終的殘差空間注意力特征表示,公式(2)和公式(3):
其中,X表示空間注意模型的輸入,Conv 表示卷積運算。
首先,對輸入圖像進行數(shù)據(jù)預處理,主要包括:將圖像裁剪為28×28 的大小,并進行隨機旋轉和對比度增強;其次,將預處理的數(shù)據(jù)送入到深度殘差注意力生成網(wǎng)絡中進行數(shù)據(jù)增強。深度殘差注意力生成網(wǎng)絡通過學習圖像不變性特征,合成高質量的數(shù)據(jù),注意力機制對圖像的感興趣區(qū)域進行重點關注;生成器通過學習隨機數(shù)據(jù)來生成感興趣的圖像分布,判別器學習真實樣本的分布,辨別生成器生成的圖像;同時訓練生成器和判別器,促使兩者競爭,在理想情況下,生成器可以生成近似于真實的圖像數(shù)據(jù),而判別器不能將真實圖像與生成圖像區(qū)分,從而達到納什平衡,達到數(shù)據(jù)增強的目的;最后,利用ResNet34 網(wǎng)絡對增強的圖像數(shù)據(jù)進行分類。
本文使用PyTorch 深度學習框架來訓練模型,GPU 為NVIDIA Tesla V100,顯存為32 GB。采用Adam 算法優(yōu)化損失函數(shù),采用小批量樣本的方式訓練深度學習模型,batch_size 設置為64,在訓練的過程中采用固定步長策略調(diào)整學習率,初始學習率設置為0.000 1,gamma 值為0.85,L2 正則化系數(shù)設置為0.000 1,迭代次數(shù)為50 000 次。
本文采用的數(shù)據(jù)集為MNIST 數(shù)據(jù)集和cirfar10數(shù)據(jù)集。MNIST 數(shù)據(jù)集一共有70 000張圖片,其中60 000 張作為訓練集,10 000 張作為測試集,每張圖片由28×28 的0~9 的手寫數(shù)字圖片組成;cirfar10數(shù)據(jù)集由60 000 張32×32 的彩色圖片組成,一共有十個類別,每個類別有6 000 張圖片,其中50 000 張圖片作為訓練集,10 000 張圖片作為測試集。
使用深度殘差注意力生成網(wǎng)絡分別對MNIST和cirfar10 數(shù)據(jù)集中的圖像進行圖像增強,使得圖像的特征更加多樣,對MNIST 數(shù)據(jù)集進行數(shù)據(jù)增強的效果如圖3 所示,對cirfar10 數(shù)據(jù)進行數(shù)據(jù)增強的效果如圖4 所示。
圖3 MNIST 數(shù)據(jù)集數(shù)據(jù)增強的效果Fig.3 Effect of data enhancement of MNIST dataset
圖4 cirfar10 數(shù)據(jù)集數(shù)據(jù)增強的效果Fig.4 Effect of data enhancement on the cirfar10 dataset
從圖3 和圖4 可以看出,使用深度殘差注意力生成網(wǎng)絡對MNIST 和cirfar10 數(shù)據(jù)集進行數(shù)據(jù)增強,具有很強的視覺可讀性,同時也具有較清晰的紋理特征,實現(xiàn)了數(shù)據(jù)增強,擴充了數(shù)據(jù)集。
為了驗證本文模型數(shù)據(jù)增強后的MNIST 以及cirfar10 數(shù)據(jù)在分類方面的效果,選擇 CNN、ResNet18、ResNet34、ResNet50 和ResNet101 作為分類網(wǎng)絡做對比實驗。第一組測試增強數(shù)據(jù)的分類準確率;第二組,測試原始數(shù)據(jù)的分類準確率;第三組,將增強數(shù)據(jù)和原始數(shù)據(jù)各拿出50%組成新的數(shù)據(jù)集進行測試,實驗結果見表1 和表2。
表1 MNIST 數(shù)據(jù)集分類準確率實驗結果(%)Tab.1 Experimental results of classification accuracy of MNIST dataset(%)
通過表1 和表2 可以看出,使用深度殘差注意力生成網(wǎng)絡進行數(shù)據(jù)增強能夠提高數(shù)據(jù)集的分類效果,證明本文提出的模型是切實有效的。利用本文模型進行數(shù)據(jù)增強的數(shù)據(jù)和原始數(shù)據(jù)相結合,在MNIST 數(shù)據(jù)集上達到了98.95% 的準確率,在cirfar10 數(shù)據(jù)集上達到了92.68%的準確率。
表2 cirfar10 數(shù)據(jù)集分類準確率實驗結果(%)Tab.2 Experimental results of classification accuracy(%)for the cirfar10 dataset
本文提出了一種深度殘差注意力生成網(wǎng)絡用于數(shù)據(jù)增強,從而提高分類的準確率。實驗結果證明,該模型在MNIST 數(shù)據(jù)集上獲得了98.95%的準確率,準確率提高了0.93 個百分點;在cirfar10 數(shù)據(jù)集上獲得了92.68%的準確率,準確率提高了0.65 個百分點。本文模型的提出,為數(shù)據(jù)增強提供了一種解決思路和方式。