方 芹,繆寧杰,董仲星,鄭樹松,王佳敏, 羅文東,周 霖
(1.國網(wǎng)浙江省電力有限公司雙創(chuàng)中心,浙江 杭州 310051) (2.國網(wǎng)浙江省電力有限公司杭州供電公司,浙江 杭州 310009) (3.浙江光珀智能科技有限公司,浙江 杭州 311100) (4.杭州致成電子科技有限公司,浙江 杭州 310009) (5.北京大道合創(chuàng)科技有限責(zé)任公司,北京 100085)
一般來說,處理高度非線性的任務(wù)需要深層次的神經(jīng)網(wǎng)絡(luò),因為深層次的網(wǎng)絡(luò)模型能夠擬合更為復(fù)雜的輸入與輸出之間的關(guān)系。密集預(yù)測獲益于各種深度卷積神經(jīng)網(wǎng)絡(luò)的快速發(fā)展[1-2],骨干網(wǎng)絡(luò)提取的特征越好,在后續(xù)密集預(yù)測時效果也越好。出于這個原因,許多學(xué)者通過不斷加深骨干網(wǎng)絡(luò)來獲取輸入圖片高層次特征,然而這會導(dǎo)致模型推理效率低下,需要數(shù)十個浮點運算來計算每幅圖像。另外,許多學(xué)者由于設(shè)備資源限制,無法訓(xùn)練這種深層次網(wǎng)絡(luò)。
知識蒸餾的目的是研究不同神經(jīng)網(wǎng)絡(luò)之間的信息傳遞。Hinton等[3]首先提出了知識蒸餾的概念,通過訓(xùn)練一個大型網(wǎng)絡(luò)(教師網(wǎng)絡(luò))來幫助小型網(wǎng)絡(luò)(學(xué)生網(wǎng)絡(luò))訓(xùn)練。其基本原理是首先訓(xùn)練一個深層次的大型神經(jīng)網(wǎng)絡(luò),然后使用教師網(wǎng)絡(luò)的預(yù)測概率分布[3]、中間層的特征表示[4]或者網(wǎng)絡(luò)的結(jié)構(gòu)信息[5],作為學(xué)生網(wǎng)絡(luò)的額外監(jiān)督,以輔助學(xué)生網(wǎng)絡(luò)完成自身的訓(xùn)練過程。這一原理最近也被應(yīng)用于大規(guī)模分布式模型的訓(xùn)練過程[6-7],用于多層間或多個訓(xùn)練狀態(tài)之間的知識傳遞。此外,知識蒸餾還被用來將容易訓(xùn)練的大網(wǎng)絡(luò)提煉成更難訓(xùn)練的小網(wǎng)絡(luò)[8]。
人體姿態(tài)估計是密集預(yù)測中的一項基本任務(wù),其目的是在一幅圖像中定位人的所有關(guān)鍵點(如手腕、手肘等),應(yīng)用領(lǐng)域十分廣泛,可應(yīng)用于虛擬現(xiàn)實、人機交互、動作檢測和自動駕駛等[9-11]。目前的人體姿態(tài)估計網(wǎng)絡(luò)可以分為自頂向下和自下向上兩類。
自頂向下:自頂向下的姿態(tài)估計網(wǎng)絡(luò)分為兩個階段。首先用目標(biāo)估計網(wǎng)絡(luò)檢測出圖片中的人,并用包圍盒把人框出來。然后對每個包圍盒里的人用姿態(tài)估計網(wǎng)絡(luò)估計出對應(yīng)的姿態(tài)。文獻(xiàn)[12]提出了深度高分辨網(wǎng)絡(luò)HRNet,該網(wǎng)絡(luò)在整個訓(xùn)練過程中保持特征圖的分辨率,并在姿態(tài)估計任務(wù)中得到了較好的結(jié)果。文獻(xiàn)[13]建議網(wǎng)絡(luò)同時預(yù)測關(guān)節(jié)點熱圖和每個關(guān)節(jié)點與標(biāo)簽的偏差,然后利用偏差校正預(yù)測熱圖得到最終的預(yù)測結(jié)果。文獻(xiàn)[14]用堆疊的沙漏網(wǎng)絡(luò)與跳躍連接來提高整體性能。文獻(xiàn)[15]使用金字塔殘差模塊來獲取多尺度信息。文獻(xiàn)[16]提出了一個簡單的姿態(tài)估計網(wǎng)絡(luò),使用轉(zhuǎn)置卷積來得到高分辨率熱圖。
自下而上:自下而上的網(wǎng)絡(luò)直接預(yù)測圖中的所有關(guān)節(jié)點,然后用算法將關(guān)節(jié)點組裝成不同的人。文獻(xiàn)[16]提出了兩個分支多階段的網(wǎng)絡(luò),一個用于關(guān)節(jié)熱圖預(yù)測,一個用于組合關(guān)節(jié)點。文獻(xiàn)[17]使用空洞殘差網(wǎng)絡(luò)直接學(xué)習(xí)每個關(guān)節(jié)點的二維偏移向量來對關(guān)節(jié)點進(jìn)行分組。文獻(xiàn)[18]使用一個局部強度場來定位關(guān)節(jié)點,使用一個部件關(guān)聯(lián)場來將身體的各個部件組合起來。文獻(xiàn)[19]在HRNet的基礎(chǔ)上提出了HigherHRNet,通過多分辨率監(jiān)督的方式訓(xùn)練網(wǎng)絡(luò),然后使用文獻(xiàn)[20]的網(wǎng)絡(luò)對檢測到的關(guān)節(jié)點進(jìn)行分組。
盡管HRNet、HigherHRNet等網(wǎng)絡(luò)在姿態(tài)估計的任務(wù)中得到了較高的精度,但它們的參數(shù)量十分龐大,以至于訓(xùn)練這些網(wǎng)絡(luò)需要消耗很大的計算資源。由于知識蒸餾可以把大型網(wǎng)絡(luò)的知識轉(zhuǎn)移到小型網(wǎng)絡(luò)中,并且不需要很多的計算資源,因此本文提出了一種基于知識蒸餾的輕量級人體姿態(tài)估計網(wǎng)絡(luò),以HigherHRNet作為教師網(wǎng)絡(luò)來指導(dǎo)監(jiān)督網(wǎng)絡(luò)。
本文提出的基于知識蒸餾的輕量級人體姿態(tài)估計網(wǎng)絡(luò)框架如圖1所示,該框架主體由兩個HigherHRNet構(gòu)成:一個預(yù)訓(xùn)練好的HigherHRNet作為教師網(wǎng)絡(luò);一個簡化版的HigherHRNet作為學(xué)生網(wǎng)絡(luò),學(xué)習(xí)教師網(wǎng)絡(luò)中的結(jié)構(gòu)知識和標(biāo)簽信息。
圖1 基于知識蒸餾的輕量級人體姿態(tài)估計的網(wǎng)絡(luò)框架流圖
HigherHRNet是目前最先進(jìn)的姿態(tài)估計網(wǎng)絡(luò)[21],該網(wǎng)絡(luò)具有訓(xùn)練時多分辨率監(jiān)督、推理時多分辨率融合預(yù)測的特點,能夠較好地解決自下而上多人姿態(tài)估計中尺度變換的問題,并且能夠精確定位出關(guān)節(jié)點。
教師網(wǎng)絡(luò)的結(jié)構(gòu)如圖1的上半部分所示。首先,輸入一張圖片,以數(shù)字1表示圖片完整的分辨率,經(jīng)過Stem,圖片的分辨率變?yōu)樵瓐D的1/4,Stem由兩個卷積塊和4個殘差卷積模塊構(gòu)成。然后,以該分辨率的特征圖作為網(wǎng)絡(luò)的第一分支,從高分辨率到低分辨率,生成多個不同分辨率的分支(圖1中有3個分支),并將這些分支并行地連接起來。通過反復(fù)地進(jìn)行多尺度融合,從并行的分辨率特征圖中可以學(xué)到知識,從而得到魯棒性強的、豐富的高分辨率。
在得到圖片的高分辨率表示之后(分辨率為1/4),HigherHRNet進(jìn)行了第一階段的預(yù)測,得到預(yù)測熱圖和分組熱圖。然后,將預(yù)測結(jié)果和上一步的特征圖串聯(lián),通過1個轉(zhuǎn)置卷積模塊和多個殘差卷積塊得到第二個預(yù)測熱圖(分辨率為1/2)。最后,使用不同分辨率的關(guān)節(jié)熱圖標(biāo)簽來監(jiān)督訓(xùn)練網(wǎng)絡(luò)。
人體姿態(tài)估計網(wǎng)絡(luò)通常由多個具有相同結(jié)構(gòu)的塊組成,如Hourglass和HigherHRNet。由于在整體結(jié)構(gòu)中部署了大量重復(fù)的塊,因此現(xiàn)有的設(shè)計并不具有成本效益,從而導(dǎo)致了表達(dá)能力和計算成本之間的次優(yōu)權(quán)衡。例如:Hourglass由8個沙漏結(jié)構(gòu)堆疊而成,每個階段結(jié)構(gòu)都有9個殘差塊;HigherHRNet的每個分支由多個重復(fù)的殘差塊組成。
本文的學(xué)生網(wǎng)絡(luò)采用簡化版的教師網(wǎng)絡(luò),即簡化版的HigherHRNet。學(xué)生網(wǎng)絡(luò)中的殘差卷積模塊只有教師網(wǎng)絡(luò)中的一半,因此訓(xùn)練只需要較少的計算資源。
學(xué)生網(wǎng)絡(luò)使用Pytorch進(jìn)行訓(xùn)練,教師網(wǎng)絡(luò)使用官網(wǎng)提供的預(yù)訓(xùn)練模型[22]。網(wǎng)絡(luò)使用ADAM優(yōu)化器,基礎(chǔ)學(xué)習(xí)率為0.001,并分別在200和260個訓(xùn)練周期時降低學(xué)習(xí)率,一共訓(xùn)練300個周期,批量大小為12。
在圖像推理階段,使用與文獻(xiàn)[19]一樣的網(wǎng)絡(luò),通過多熱圖聯(lián)合預(yù)測的方式來預(yù)測人體的姿態(tài)。學(xué)生網(wǎng)絡(luò)預(yù)測了兩個階段的關(guān)節(jié)點熱圖,由于兩個階段預(yù)測熱圖的分辨率不一致,因此需要先對第一階段的熱圖進(jìn)行采樣,然后把它與第二階段的預(yù)測熱圖融合得到最終的人體姿態(tài)預(yù)測結(jié)果。
假定網(wǎng)絡(luò)的輸入圖片為X,X∈3×H×W,其中H和W分別代表輸入圖片的高和寬。教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)經(jīng)過多分支多分辨率融合模塊后,分別得到第一階段的預(yù)測結(jié)果MT1和通道數(shù)34由前17張關(guān)節(jié)點熱圖和后17張分組熱圖組成。Loss1只使用教師網(wǎng)絡(luò)的預(yù)測關(guān)節(jié)熱圖作為學(xué)生網(wǎng)絡(luò)的額外監(jiān)督,所以定義Loss1為:
(1)
在得到學(xué)生網(wǎng)絡(luò)第一階段的預(yù)測結(jié)果MS1后,使用對應(yīng)的關(guān)節(jié)標(biāo)簽監(jiān)督預(yù)測結(jié)果MS1的前17張預(yù)測關(guān)節(jié)熱圖,因此定義Loss2為:
(2)
(3)
由此,學(xué)生網(wǎng)絡(luò)的最終聯(lián)合損失Loss定義為:
Loss=α·Loss1+β·Loss2+γ·Loss3+Lg
(4)
式中:α,β,γ分別為對應(yīng)損失的權(quán)重,本文中α和γ設(shè)置為1/4,β設(shè)置為3/4;Lg為三元組損失,通常取1。
COCO數(shù)據(jù)集是在復(fù)雜的環(huán)境干擾下收集得到的,因此要求網(wǎng)絡(luò)能夠在復(fù)雜的條件下估計定位出圖片中所有人的關(guān)節(jié)點[23]。該數(shù)據(jù)集總共包含超過200 000張圖像,250 000個帶有17個關(guān)鍵點的人。該數(shù)據(jù)集被分為57 000個訓(xùn)練集、5 000個驗證集和20 000個測試集。學(xué)生網(wǎng)絡(luò)在訓(xùn)練集上進(jìn)行訓(xùn)練,并報告了在驗證集上的實驗結(jié)果。
COCO關(guān)鍵點相似度(object keypoint similarity,OKS),與目標(biāo)檢測中的IoU類似,OKS可以表示預(yù)測出來的關(guān)節(jié)點和標(biāo)簽圖片中的關(guān)節(jié)點的重合程度,其值越接近1越好。
(5)
式中:exp()為指數(shù)函數(shù);n為關(guān)節(jié)點的序號,dn為標(biāo)注關(guān)節(jié)點和預(yù)測關(guān)節(jié)點之間的歐氏距離;s為所占面積;kn為第n個關(guān)節(jié)點的歸一化因子,可通過對數(shù)據(jù)集進(jìn)行標(biāo)準(zhǔn)差得到,反映了當(dāng)前關(guān)節(jié)點對與整體的影響程度。
首先在COCO數(shù)據(jù)集上對知識蒸餾出來的學(xué)生網(wǎng)絡(luò)進(jìn)行驗證,實驗結(jié)果見表1。表中:AP0.5為所有圖像中人物預(yù)測的關(guān)鍵點位置和真實位置的相似性在0.5以上的平均準(zhǔn)確率,AP0.75為所有圖像中人物預(yù)測的關(guān)鍵點位置和真實位置的相似性在0.75以上的平均準(zhǔn)確率,AP為AP0.5,AP0.55,AP0.6,AP0.65,AP0.7,AP0.75,AP0.8,AP0.85,AP0.9,AP0.95的平均準(zhǔn)確率,APM表示像素面積在[32×32,96×96]的人物預(yù)測準(zhǔn)確度,APL表示像素面積大于96×96的人物預(yù)測準(zhǔn)確度。教師網(wǎng)絡(luò)是一個大型的網(wǎng)絡(luò),所以它能夠達(dá)到較高的精度。未蒸餾的學(xué)生網(wǎng)絡(luò)是指直接使用標(biāo)簽數(shù)據(jù)進(jìn)行訓(xùn)練,沒有額外使用教師網(wǎng)絡(luò)的預(yù)測特征圖監(jiān)督?;谡麴s的學(xué)生網(wǎng)絡(luò)即本文所設(shè)計的網(wǎng)絡(luò),使用標(biāo)簽和教師網(wǎng)絡(luò)預(yù)測的特征圖聯(lián)合監(jiān)督訓(xùn)練學(xué)生網(wǎng)絡(luò)??梢钥吹?,基于蒸餾的網(wǎng)絡(luò)比未蒸餾的網(wǎng)絡(luò)提高了1.3%,這說明教師網(wǎng)絡(luò)的監(jiān)督是有作用的。值得注意的是,雖然學(xué)生網(wǎng)絡(luò)的精度比教師網(wǎng)絡(luò)低了許多,但本文的目的是訓(xùn)練一個簡單姿態(tài)估計網(wǎng)絡(luò),給訓(xùn)練資源不足的學(xué)者提供一個有效的蒸餾訓(xùn)練網(wǎng)絡(luò),該網(wǎng)絡(luò)比直接訓(xùn)練學(xué)生網(wǎng)絡(luò)具有更高的精度。另一方面,深層次的神經(jīng)網(wǎng)絡(luò)(教師網(wǎng)絡(luò))能夠較好地處理姿態(tài)估計任務(wù),而簡化的網(wǎng)絡(luò)(學(xué)生網(wǎng)絡(luò))并不能達(dá)到教師網(wǎng)絡(luò)的精度。這也說明了姿態(tài)估計是一個高度非線性的任務(wù),使用淺層網(wǎng)絡(luò)并不能準(zhǔn)確地對人體姿態(tài)進(jìn)行預(yù)測。
表1 COCO數(shù)據(jù)集上不同網(wǎng)絡(luò)精度比較
除了定量分析,本文還進(jìn)行了定性分析,結(jié)果如圖2所示。從圖中可以看出,教師網(wǎng)絡(luò)預(yù)測的結(jié)果最好,未蒸餾的學(xué)生網(wǎng)絡(luò)最差。
圖2 預(yù)測結(jié)果可視化
圖3為教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)預(yù)測的關(guān)節(jié)點,第一列為原始圖片,第二至第十八列分別預(yù)測鼻子、左眼、右眼等??梢钥吹?,深層次的教師網(wǎng)絡(luò)的預(yù)測結(jié)果接近標(biāo)簽,而淺層次的學(xué)生網(wǎng)絡(luò)僅能預(yù)測圖片中一部分關(guān)節(jié)點。
圖3 預(yù)測熱圖可視化
網(wǎng)絡(luò)模型的參數(shù)量是一個十分重要的參數(shù),表2中報告了教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)的模型參數(shù)量(Params)和網(wǎng)絡(luò)需要計算的浮點運算數(shù)(giga floating-point operations per second, GFLOPs)。從表中可以看到,由于教師網(wǎng)絡(luò)是深層次網(wǎng)絡(luò),所以它的模型參數(shù)量、浮點運算數(shù)和推理時間(Inference)都大于學(xué)生網(wǎng)絡(luò)。因此,本文能夠在計算資源不足的情況下訓(xùn)練學(xué)生網(wǎng)絡(luò)。
表2 模型參數(shù)量、浮點運算數(shù)和推理時間
本文提出了一個基于知識蒸餾的輕量級姿態(tài)估計網(wǎng)絡(luò),該網(wǎng)絡(luò)由標(biāo)簽和教師網(wǎng)絡(luò)預(yù)測熱圖聯(lián)合監(jiān)督訓(xùn)練得到。通過知識蒸餾的方式訓(xùn)練的學(xué)生網(wǎng)絡(luò)能夠比直接訓(xùn)練得到的學(xué)生網(wǎng)絡(luò)得到更高的人體姿態(tài)估計精度。此外,本文設(shè)計的學(xué)生網(wǎng)絡(luò)是一個較為簡單、常見的姿態(tài)估計網(wǎng)絡(luò),能夠幫助學(xué)者在計算資源不足的情況下得到較好的姿態(tài)估計精度。研究結(jié)果表明,使用知識蒸餾得到的學(xué)生網(wǎng)絡(luò)能夠較為有效地估計出人體關(guān)節(jié)點。