肖振遠,王逸涵,羅建橋,熊 鷹,李柏林
(西南交通大學機械工程學院,成都 610031)
目標檢測作為計算機視覺研究的熱點,在圖像識別和目標追蹤等領域得到了廣泛的應用。近年來隨著卷積神經(jīng)網(wǎng)絡(Convolutional Neural Network,CNN)的發(fā)展,涌現(xiàn)出了諸多優(yōu)秀的目標檢測網(wǎng)絡算法:基于錨框的檢測算法[1-5]和基于無錨框的檢測算法[6,7]。基于錨框的算法又可分為基于兩階的目標檢測算法(如區(qū)域卷積網(wǎng)絡(Regions with Convolutional Neural Network,R-CNN)[1]、快速區(qū)域卷積神經(jīng)網(wǎng)絡(Faster Regions with Convolutional Neural Network,F(xiàn)aster-RCNN)[2]等)和基于一階的目標檢測性算法(如單階多框檢測器(Single Shot multibox Detector,SSD)[3]、單階改進目標檢測器(Single-Shot Refinement Neural Network for Object Detection,RefineDet)[4]以 及YOLOv3(You Only Look Once version 3)[5]等)?;跓o錨框的算法又可分為基于角點的目標檢測算法[6]和基于中心點的目標檢測算法[7]。在模型訓練過程中,這些目標檢測算法通常會遇到一個共同的問題:類間樣本不平衡[8]。類間樣本不平衡即某些類別樣本數(shù)目遠大于其他類別的情況,包括前景和背景間的不平衡以及前景類間樣本不平衡。不平衡問題如不解決,會導致檢測器對小類樣本檢測準確度低下,最終降低模型的性能。
針對類間樣本的不平衡問題,許多學者進行了研究[9-16]。在基于錨框的檢測器中,兩階檢測器通常利用二階級聯(lián)和啟發(fā)式抽樣方法來解決類間樣本不平衡:在第一階段通過生成特定候選目標的方式過濾大量冗余的背景樣本,如學習分割候選目標[9,10]、選擇性搜索[11]、基于邊緣的目標建議[12]等;在第二階段采用啟發(fā)式抽樣方法平衡前景和背景之間的樣本數(shù)量,如固定前景和背景的比例1∶3[3]、在線難樣本挖掘[13]等;單階檢測器通常采用啟發(fā)式抽樣或難樣本挖掘[14]方法從稠密的錨框中有規(guī)律地進行抽樣,以平衡前景和背景之間的樣本數(shù)量。以上方法有效緩解了前景和背景之間類的不平衡,但未考慮前景類間的不平衡問題。此外,焦點損失(Focal Loss)[15]對單階檢測器中的交叉熵損失函數(shù)進行改進,通過控制正負樣本的權重和難易樣本的權重,增大難學習的小類樣本在損失函數(shù)中所占比重,使算法更偏重難學習的小類樣本;但它并未考慮背景樣本的影響。基于無錨框的檢測器的檢測準確率高于基于錨框的檢測器,主要原因在于學習過程中產(chǎn)生了與基于錨框檢測器不同的類間樣本數(shù)量。自適應樣本選擇方法(Adaptive Training Sample Selection,ATSS)[16]對基于錨框的檢測器和基于無錨框的檢測器進行了詳細的分析,通過對樣本進行統(tǒng)計,自適應選擇正負樣本,控制類間樣本數(shù)量,有效緩解了類間樣本的不平衡,縮小了無錨框的檢測器和有錨框的檢測器的性能差異??傊?,上述解決不平衡問題的方法,只考慮了前景類與背景類間的不平衡或者前景類間的不平衡,并沒有綜合考慮前景類與背景類不平衡和前景類間樣本不平衡。
目前流行的目標檢測網(wǎng)絡一般采用分類損失和位置/尺寸回歸損失來完成多任務學習(如Faset-RCNN、SSD)。RefineDet 沿用了經(jīng)典的分類和回歸損失,是一種有代表性的目標檢測方法,而且它在性能上超過了大部分其他目標檢測方法[1-5],具有先進性;但RefineDet 在使用損失函數(shù)進行權重更新時,存在上述提到的前景類與背景類不平衡和前景類間樣本不平衡問題,并且由于是多任務學習,RefineDet 還存在多任務間的不平衡(分類任務和回歸任務間的不平衡)[17]。因此,RefineDet的性能仍然具有提升的潛力。
本文針對RefineDet 損失函數(shù),提出了一種改進的部分加權損失函數(shù)(Subsection Weighted Loss,SWLoss),以緩解類間不平衡數(shù)據(jù)集中小樣本類別檢測性能低的問題,它的主要組成如圖1所示。
圖1 SWLoss組成Fig.1 Composition of SWLoss
SWLoss主要有以下內(nèi)容:1)在RefineDet目標檢測損失函數(shù)中引入類間樣本平衡因子,并以每個訓練批量中不同類別樣本數(shù)量的倒數(shù)作為啟發(fā)式的平衡因子,對分類損失中的不同類別進行加權,從而提高對小樣本類別學習的關注程度;2)在分類損失和回歸損失中引入多任務平衡因子,對分類損失進行加權量化,縮小兩個任務學習速率的差異。
RefineDet 在訓練時,損失函數(shù)主要由三個模塊構成:1)錨框調(diào)整模塊(Anchor Refinement Module,ARM),將最初生成的錨框分為前景和背景,并對錨框的位置進行粗略的調(diào)整;2)目標檢測模塊(Object Detection Module,ODM),對ARM 模塊粗調(diào)后的錨框進行多分類,并對錨框的位置進行精確的調(diào)整;3)連接模塊(Transfer Connection Block,TCB),將ARM 模塊輸出的特征圖進行特征金字塔(Feature Pyramid Network,F(xiàn)PN)[18]操作,變換成ODM 需要的特征圖。其中ARM 和ODM提供了整個模型權重更新所需的損失。RefineDet 進行了兩次分類和回歸調(diào)整,在訓練時采用難樣本挖掘策略使得前景-背景不平衡問題得到了改善,同時兼顧了兩階檢測器的準確率和一階檢測器的速率,因此在目標檢測任務中獲得了較高的檢測性能。
RefineDet 損失函數(shù)Ltotal如式(1),主要由兩部分構成:1)錨框調(diào)整模塊ARM 的損失Larm,包括前期二分類的損失Lb和先驗框粗調(diào)整的損失Lr,如式(2);2)目標檢測模塊ODM 的損失Lodm,包括多分類損失Lm和準確回歸目標位置的損失Lr,如式(3)。
其中:P為錨框i在ARM 中對應的預測置信度,P={pi|i∈錨框索引值};X為錨框i在ARM 中粗調(diào)后的位置信息X={xi|i∈錨框索引值};C為錨框i在ODM 中對應的預測置信度,C={ci|i∈錨框索引值};T為錨框i在ODM 中精調(diào)后的位置信息,T={ti|i∈錨框索引值};Larm和Lodm如式(2)、(3)。
式中:i為錨框的標簽索引;Narm為ARM 中對應的正樣本的數(shù)量,Nodm為ODM 中對應的正樣本的數(shù)量;Lb、Lm、Lr分別代表二分類損失、多分類損失和回歸損失;l*i表示錨框i對應的真實框的類別標簽,l*=;l為二分類損失對前景的編碼值,l=或l=為錨框i對應的真實框的位置信息,g*=
盡管RefineDet 已經(jīng)改善了前景和背景的類別不平衡,但它仍沒有考慮前景類間樣本的不平衡問題。在ODM 中,多分類損失是所有樣本的平均損失,并以此來更新梯度,這種全局損失會使得樣本數(shù)目越少的小類的關注度越低,進而導致網(wǎng)絡模型對小類的檢測準確率越低。此外,ODM 中存在著多任務不平衡問題,會導致分類任務與回歸任務的權重更新速率不同,也會影響檢測準確率。本文針對RefineDet 存在的上述缺陷,提出了部分加權損失函數(shù)SWLoss,以提高網(wǎng)絡的檢測性能。
在對模型訓練時,由式(3)可以看出,在Lodm階段,損失函數(shù)由分類損失Lcla和回歸損失Lreg組成,如式(4):
從式(4)中可得,分類損失是所有樣本分類損失的平均值,在反向傳播過程中,模型參數(shù)通過以下公式進行調(diào)整:
其中:η為學習率lNodm為第Nodm個樣本的損失;wm是第m次更新的權重。
如式(4)所示,每個樣本在權值調(diào)整過程中貢獻相同,這會導致前景中樣本數(shù)較大的類別在權值更新過程中占主導作用,從而使模型權重更新的速率偏向于該類,導致小類樣本識別率降低。這種前景類間不平衡現(xiàn)象在工業(yè)檢測應用中經(jīng)常發(fā)生,例如缺陷檢測[19],正常樣本占絕大多數(shù),而缺陷樣本卻非常少。
為緩解類間樣本不平衡問題(前景類與背景類不平衡和前景類間樣本不平衡),本文針對目標檢測模塊ODM 的損失Lodm提出一種部分加權損失函數(shù)SWLoss。SWLoss首先在Lodm中引入類間樣本平衡因子,增加小類樣本在損失函數(shù)中所占的比重,提高小類樣本的檢測效果,如式(7)所示:
其中:n為每批量訓練樣本類別的總數(shù);j為樣本的類別,j=0時表示背景類;Lj對應每批量訓練樣本中每種類別的總樣本損失;1/βj為第j類引入的類間樣本平衡因子,是一個具有啟發(fā)性的代表類間樣本不平衡的值,βj為每批量訓練樣本中每種類別的樣本數(shù),即以每批量訓練樣本中每種類別的樣本數(shù)作為各類的懲罰因子,其中背景類選為設置正負樣本比例倍的平衡因子。如式(7)所示,總損失SWLoss 為每種類別損失的加權和,當βj選為每批量訓練樣本中每種類別j的樣本數(shù)時,每種類別的損失相當于取該類的平均損失,間接使得小類總樣本損失在總損失中所占比重增加,大類損失總樣本損失在總損失中所占比重減少,從而平衡了類間樣本不均衡的問題。
此外,Lodm誤差函數(shù)并沒有考慮多任務間的不平衡問題。不同任務之間的難度、損失大小各不相同,最優(yōu)化損失函數(shù)時,不同任務之間的最佳區(qū)間不相容,會導致不同任務權重更新速率不同。為解決上述問題,SWLoss 引入多任務平衡因子對不同任務進行加權量化,使兩者更新速率盡量同步。最終的損失函數(shù)SWLoss如式(8)所示:
其中:κ為動態(tài)多任務平衡因子,決定著分類損失在整個損失的比重。如式(9)所示,κ越大意味著分類損失更新得越快,κ越小意味著分類損失更新得越慢。取κ=n,n為每批量訓練樣本類別的總數(shù)。
在反向傳播中,SWLoss 的權重更新如式(9)所示。相比式(5)和(6)中?Lodm(wm),由于每個類別的損失懲罰因子不同,使得每個類在權值調(diào)整過程中貢獻不同,權重更新速率也會不同,也就意味著每個類擁有不同的學習率。因此,該損失保持了原始類間樣本數(shù)量不平衡的同時,提高了網(wǎng)絡對小類的關注度。此外,多任務平衡因子κ的引入,使分類和回歸任務之間的更新速率變得可調(diào)節(jié)。
為分析SWLoss在不平衡數(shù)據(jù)集上的目標檢測性能,在兩個有代表性的數(shù)據(jù)集上進行實驗,并與其他損失函數(shù)的目標檢測效果進行對比。
分別在公開數(shù)據(jù)集(Pattern Analysis,Statical Modeling and Computational Learning,Visual Object Classes Challenge 2007,Pascal VOC2007)和人工采集的包裝盒點陣字符數(shù)據(jù)集上進行實驗。VOC2007 數(shù)據(jù)集覆蓋了20 個目標類別的9 963張圖片。按9∶1 比例在VOC2007 上劃分訓練和測試集,具體數(shù)量信息如圖2 所示,圖中橫坐標縮寫了部分目標類別名稱,包 括:Aerop(Aero Plane),Bicyc(Bicycle),DinT(Dining Table),Motor(Motor Bike),Potted(Potted Plant),TVmon(Tv Monitor)。圖2中的直方圖展示了訓練集中各類目標的數(shù)量??梢钥吹?,Person、Car 等類別的樣本數(shù)量大幅超過Sheep、Chair 等類別的樣本數(shù)量,因此,VOC2007 中不同目標的樣本數(shù)量嚴重不平衡。點陣字符數(shù)據(jù)集記錄了食品包裝盒上的生產(chǎn)日期,包括12個目標類別,如圖3所示的類別名稱。為體現(xiàn)SWLoss 損失函數(shù)的適應性,在不同時間段分別收集訓練和測試樣本共500 張圖片,9 000 個點陣字符。訓練集來自時間跨度為2018-07-02T12:11:00—12:15:00 的200 張有效樣本,訓練集中的字符標注信息通過圖像處理軟件Halcon[20]自動獲得。人工剔除標注不準確的失效樣本后,訓練集中各類字符數(shù)量如圖3 所示。由于字符6 和9 僅出現(xiàn)在訓練集時間跨度的秒單位上,因此圖3 中字符6 和9 的數(shù)量相對其他字符較少,各類字符樣本數(shù)量嚴重不平衡。測試集來自時間跨度為2018-07-02T13:55:00—13:59:59 的300 張有效樣本,該時間跨度內(nèi)字符6 和9 較多,因而可以驗證網(wǎng)絡對6 和9 的檢測能力。同樣,采用Halcon標注測試集字符,用于計算檢測精度。
網(wǎng)絡優(yōu)化器采用隨機梯度下降法,動量和權重衰減分別設置為0.9、0.000 5。以32批量在VOC2007上訓練,前80 000次迭代學習率為10-3,后20 000 次學習率為10-4,最后20 000次學習率為10-5。由于點陣字符圖像尺寸較大(1 296×966),以8批量在點陣字符上迭代訓練7 500次,初始學習率為10-3,每迭代10次,學習率衰減0.05。
選用其他損失函數(shù)替換SWLoss 進行實驗,對比方法包括:1)RefineDet中原有的損失函數(shù),記為Loss0;2)按類別數(shù)量比例對輸出概率進行加權的概率期望損失[21],記為PELoss;3)基于曲線下面積(Area Under Curves,AUC)優(yōu)化的加權成對損失[22],記為WPLoss;4)同時考慮類間不平衡和難樣本挖掘的Focal Loss,記為Floss。所有算法參數(shù)均采用對應文獻中推薦的配置。
由于Focal Loss同時考慮類間不平衡和難樣本挖掘,可采用SWLoss 損失函數(shù)中的類別不平衡因子1/βi替換Focal Loss中的類別系數(shù)?;赟WLoss的實驗包括:1)本文提出的部分加權損失函數(shù)SWLoss,記為SWLoss1;2)忽略不平衡因子的損失函數(shù),即κ=1 時,記為SWLoss0;3)Focal Loss 中的類別系數(shù)替換為1/β,記為SWLoss0+Floss;4)Focal Loss 中的類別系數(shù)替換為1/β,同時考慮不平衡因子,記為SWLoss1+Floss。
實驗結果采用平均精度均值(mean Average Precision,mAP)進行評價,如式(13)所示,其中:TP(True Positives)為預測框與標簽框正確的匹配;FP(False Positives)為預測框將背景預測為目標;FN(False Negatives)指需要模型檢測出的物體,沒有檢測出;P(Precision)為測試精度;R(Recall)為召回率。
圖2 對比了SWLoss 與RefineDet 原始損失Loss0 在VOC2007上的檢測結果。由圖2可知,相比原始損失函數(shù),在SWLoss 損失函數(shù)中,樣本嚴重不平衡的小類Sheep、Cow 的準確度得到了提升,大類Person的準確度略微下降,整體的準確度由表1 可知,提升了1.01 個百分點,原因是SWLoss 損失函數(shù)引入了類間樣本平衡因子,增加了小類樣本在整體損失函數(shù)中的權重,降低了大類樣本的權重,從而在梯度更新中,使得小類樣本的更新速度得到提升。這說明本文引入的類間樣本平衡因子能夠提高網(wǎng)絡關注小類樣本的能力,改善大類樣本對小類樣本權重更新的覆蓋,緩解類間樣本的不平衡。
圖2 VOC2007數(shù)據(jù)集上各類樣本數(shù)量及測試精度Fig.2 Numbers and test accuracies of different classes of samples in VOC2007 dataset
表1還列出了SWLoss與其他對比方法的mAP。由表1可知:1)SWLoss0 明顯超過Loss0,SWLoss1 進一步提高了性能,說明提升主要來自分類損失,平衡多任務能進一步提高性能。2)FLoss 超過SWLoss1,說明考慮難樣本挖掘能夠提高性能。3)Floss+SWLoss0 和Floss+SWLoss1 均超過Floss,原因是Focal loss 類別系數(shù)是全局類別比例,SWLoss 損失函數(shù)基于批量中的類別比例,能更準確保證每次更新時類間樣本的平衡,因此SWLoss 損失函數(shù)在處理類間樣本不平衡問題上超過Floss。4)SWLoss1 超過PEloss 和WPloss,原因是PEloss 通過加權系數(shù)直接調(diào)整輸出概率,小類樣本損失占比依然很??;WPloss 通過AUC 損失更多關注難分類樣本,對小類樣本的關注程度不足,而SWLoss損失函數(shù)直接提高小類樣本損失的占比,因此SWLoss 損失函數(shù)在處理類別不平衡問題時具有優(yōu)勢。
表1 采用不同損失函數(shù)時的mA 單位:%Tab.1 mAP when using different loss functions unit:%
圖3 和表1 展示了包裝盒點陣字符上的實驗結果。由圖3 可以看出:小樣本字符類別6 和9 的精度大幅提高,原因是SWLoss 損失函數(shù)提高了6 和9 的損失在整體損失的占比,在梯度更新時使得網(wǎng)絡權重的更新偏向于6和9的相關權重,因此對于該種存在較大不平衡的類間不平衡,SWLoss 損失函數(shù)的改善效果較明顯。由表1 可以看出:點陣字符數(shù)據(jù)集的檢測結果展現(xiàn)了與VOC2007 數(shù)據(jù)集類似的對比趨勢,說明了SWLoss 損失函數(shù)對不同類間不平衡數(shù)據(jù)集的適應性。具體來說:首先,大幅提高了原始RefineDet 性能;然后,優(yōu)于對比損失函數(shù)PEloss、WPLoss;最后,結合Focal Loss獲得的最好結果,表明基于批量類別比例的有效性。
圖3 點陣字符數(shù)據(jù)集上各類樣本數(shù)量及測試精度Fig.3 Numbers and test accuracies of different classes of samples in dot-matrix character dataset
為直觀分析所提算法效果,圖4 展示了點陣字符上的檢測效果??梢钥吹?,原始網(wǎng)絡RefineDet 在檢測時產(chǎn)生了部分多余的檢測框如圖4(a)和(b),并且對數(shù)字6 和9 產(chǎn)生了誤檢和漏檢現(xiàn)象;圖4(c)和(d)為SWLoss損失函數(shù)的檢測結果,它在獲得高準確率的同時沒有產(chǎn)生多余的檢測框,且檢測框的位置相對準確,這是由于SWLoss損失函數(shù)在對網(wǎng)絡反向傳播更新權重時引入了多任務平衡因子權重進行加權量化,縮小了分類任務和回歸任務的差距。
圖4 字符檢測結果Fig.4 Character detection results
本文針對RefineDet 目標檢測網(wǎng)絡對類間不平衡數(shù)據(jù)集檢測時,存在小樣本類別檢測性能低的問題,提出SWLoss 損失函數(shù)。該損失函數(shù)在目標檢測模塊的損失函數(shù)中引入了類間樣本平衡因子和多任務平衡因子,分別解決類間樣本不平衡問題以及多任務間不平衡問題。實驗表明SWLoss 損失函數(shù)有效地緩解了RefineDet 檢測網(wǎng)絡中的不平衡問題,能夠明顯提高小類樣本的檢測精度。
本文雖然關注了類間數(shù)據(jù)不平衡問題,但未考慮同類樣本的數(shù)據(jù)變化,即如何區(qū)分簡單樣本和復雜樣本。合理的損失函數(shù)應該更加側重難分類的復雜樣本,因此,今后將研究能夠關注復雜樣本的損失函數(shù),進一步提高網(wǎng)絡特征學習的魯棒性。