鄭宗新
(重慶師范大學(xué)計(jì)算機(jī)與信息科學(xué)學(xué)院,重慶 401331)
隨著深度學(xué)習(xí)的發(fā)展,深度學(xué)習(xí)在生活中的應(yīng)用越來(lái)越廣泛。面對(duì)復(fù)雜的任務(wù)場(chǎng)景,深度學(xué)習(xí)的運(yùn)算量也隨之增大?,F(xiàn)有關(guān)于神經(jīng)網(wǎng)絡(luò)的分布式研究多是關(guān)于訓(xùn)練階段的,關(guān)于神經(jīng)網(wǎng)絡(luò)推理階段的分布式研究較少。推理階段運(yùn)算量較大的解決方法一方面是通過(guò)優(yōu)化網(wǎng)絡(luò)結(jié)構(gòu)[1-4],設(shè)計(jì)高效簡(jiǎn)潔的模型來(lái)減少運(yùn)算量;另一方面通過(guò)分布式架構(gòu)使得多設(shè)備協(xié)同工作增加運(yùn)算力。在推理階段的分布式研究主要是通過(guò)不同的設(shè)備間(邊緣節(jié)點(diǎn)和云服務(wù)器、邊緣節(jié)點(diǎn)和邊緣節(jié)點(diǎn))的相互協(xié)助,加快邊緣節(jié)點(diǎn)推理速度。
基于動(dòng)態(tài)卸載的分布式[5-7]網(wǎng)絡(luò)設(shè)計(jì)通過(guò)分析神經(jīng)網(wǎng)絡(luò)的層的運(yùn)算量,把神經(jīng)網(wǎng)絡(luò)模型縱向劃分為兩部分,運(yùn)算量較大的卷積層部分放在云服務(wù)器運(yùn)算,利用云服務(wù)器計(jì)算速度快的特點(diǎn),在云服務(wù)器計(jì)算后將計(jì)算數(shù)據(jù)發(fā)送到邊緣設(shè)備繼續(xù)另一部分運(yùn)算量較小的計(jì)算,將計(jì)算延遲和通信延遲進(jìn)行平衡,得到最優(yōu)化的計(jì)算速度。另一種方法是通過(guò)將特征圖進(jìn)行區(qū)域的劃分[8-9],將輸入數(shù)據(jù)橫向劃分為不同的區(qū)塊發(fā)送給不同的設(shè)備進(jìn)行運(yùn)算,最后一層時(shí)進(jìn)行拼接。每個(gè)設(shè)備的運(yùn)算量都較普通運(yùn)算減少了,因此運(yùn)算速度獲得提升。國(guó)內(nèi)的相關(guān)的研究主要是通過(guò)動(dòng)態(tài)卸載,將不同階段的運(yùn)算放置于不同的設(shè)備,從而并行計(jì)算[10]。
以往的分布式推理研究主要在提升運(yùn)行速度方面,容錯(cuò)率較低,一旦通信中斷便無(wú)法完成推理。在無(wú)人機(jī)、自動(dòng)駕駛等方面是無(wú)法接受的。對(duì)此本文提出一種基于知識(shí)蒸餾的神經(jīng)網(wǎng)絡(luò)設(shè)計(jì)方法,與其余設(shè)備協(xié)同運(yùn)算時(shí)具有較高的準(zhǔn)確率,當(dāng)通信不穩(wěn)定時(shí)可離線運(yùn)行,有著可以接受的準(zhǔn)確率。
知識(shí)蒸餾(Knowledge Distillation,KD)[11]是通過(guò)將訓(xùn)練數(shù)據(jù)輸入一個(gè)訓(xùn)練好的、高準(zhǔn)確率的教師模型,得到教師模型的輸出結(jié)果,學(xué)生模型根據(jù)輸出結(jié)果進(jìn)行學(xué)習(xí)。教師模型輸出為軟標(biāo)簽(soft-target),其中包含了教師模型本身的信息,相比于訓(xùn)練集原有的硬標(biāo)簽(hard-target)信息量更大,因此訓(xùn)練時(shí)效率更高。
表1 硬標(biāo)簽和軟標(biāo)簽
網(wǎng)絡(luò)剪枝(Network Pruning)通過(guò)去除重要性較低的連接,降低神經(jīng)網(wǎng)絡(luò)模型的運(yùn)算量。網(wǎng)絡(luò)剪枝對(duì)于一個(gè)連接的重要程度的評(píng)價(jià),一般是通過(guò)這個(gè)連接的參數(shù)絕對(duì)值的大小[12]、濾波器中位數(shù)[13]等信息來(lái)判斷?,F(xiàn)有網(wǎng)絡(luò)剪枝方法多是依據(jù)參數(shù)自身信息進(jìn)行判別[14],而忽略了其他信息。因此在裁剪較大的時(shí)候,準(zhǔn)確率下降嚴(yán)重。如圖1 所示。
圖1 刪除不同比例的連接后的準(zhǔn)確率
一個(gè)連接的參數(shù)絕對(duì)值越大,一般來(lái)說(shuō)對(duì)準(zhǔn)確率的影響就越大。若一個(gè)模型中參數(shù)值的分布較為均勻,每個(gè)連接都對(duì)準(zhǔn)確率的影響差距不大,刪除小部分連接會(huì)導(dǎo)致準(zhǔn)確率大幅下降。對(duì)于這個(gè)問(wèn)題,本文提出一種促進(jìn)參數(shù)中較大值的訓(xùn)練算法(Promote Maxi?mum Weight SGD,PMW-SGD),通過(guò)在反向傳播時(shí),根據(jù)參數(shù)的絕對(duì)值進(jìn)行排序,根據(jù)相對(duì)大小來(lái)對(duì)應(yīng)不同的學(xué)習(xí)率。公式如下:
其中w為模型參數(shù),Δw為更新的梯度,p為與參數(shù)絕對(duì)值大小相關(guān)的量。
通過(guò)將模型中參數(shù)值較大的一部分變得更大,使得這小部分連接對(duì)準(zhǔn)確率的貢獻(xiàn)較大,在刪除大部分連接后模型仍然有較高的準(zhǔn)確率。
使用網(wǎng)絡(luò)剪枝刪除部分全連接層參數(shù),通過(guò)刪除不同比例的參數(shù)從而得到不同的子模型;不同子模型參數(shù)數(shù)量不相同,一般參數(shù)越多的子模型準(zhǔn)確率越高,如圖2 所示。在本文中,使用上一小節(jié)中經(jīng)過(guò)PMWSGD 訓(xùn)練后的模型,按照參數(shù)的權(quán)重絕對(duì)值進(jìn)行排序,剪枝掉大部分權(quán)重絕對(duì)值較小的連接,根據(jù)剪枝的比例不同,得到不同準(zhǔn)確率的子模型。
圖2 完整模型分解為三個(gè)子模型
本文通過(guò)PyTorch 框架,在ResNet18 和LeNet 模型及CIFAR10 數(shù)據(jù)集上進(jìn)行算法有效性驗(yàn)證。
首先將訓(xùn)練好的ResNet18 模型作為教師模型,LeNet 作為學(xué)生模型,進(jìn)行知識(shí)蒸餾,先采用minibatch SGD 梯度下降算法訓(xùn)練。在初步經(jīng)過(guò)50 次迭代訓(xùn)練后采用PMW-SGD 梯度下降算法對(duì)全連接層的參數(shù)進(jìn)行知識(shí)蒸餾的訓(xùn)練,參數(shù)分布如圖3 所示。
圖3 mini-batch SGD和PMW-SGD訓(xùn)練算法訓(xùn)練后的參數(shù)分布
在使用PMW-SGD 算法后全連接層中的參數(shù)絕對(duì)值較大的一部分變得更加大,對(duì)應(yīng)節(jié)點(diǎn)的重要性變高,對(duì)于準(zhǔn)確率的貢獻(xiàn)因此變大。在刪除部分全連接層的參數(shù)時(shí),保留的節(jié)點(diǎn)主要為權(quán)重絕對(duì)值較大的,因此準(zhǔn)確率較mini-batch SGD 算法高。如圖4 所示。
圖4 mini-batch SGD與PMW-SGD訓(xùn)練后的模型刪除不同比例參數(shù)后的準(zhǔn)確率
首先通過(guò)網(wǎng)絡(luò)剪枝將上小節(jié)中訓(xùn)練好的ResNet18模型全連接層參數(shù)進(jìn)行剪枝,按照參數(shù)權(quán)重的絕對(duì)值進(jìn)行排序,從小到大將全連接層剪枝95%得到子模型A;剪枝85%得到子模型B。從而將ResNet18 分解為兩個(gè)子模型A 和B;其中A 模型中節(jié)點(diǎn)較少,因此準(zhǔn)確率相對(duì)較低;B 模型節(jié)點(diǎn)較多,準(zhǔn)確率較高。詳細(xì)信息如表2 所示。
表2 兩個(gè)子模型的信息
基于LeNet 構(gòu)造兩個(gè)模型,分別為L(zhǎng)eNetA 和LeNetB;其全連接層節(jié)點(diǎn)數(shù)分別25 和50 個(gè)。使用知識(shí)蒸餾讓LeNetA 模型全連接層節(jié)點(diǎn)學(xué)習(xí)子模型A 中全連接層節(jié)點(diǎn)的輸出;LeNetB 模型全連接層節(jié)點(diǎn)學(xué)習(xí)子模型B 中去掉子模型A 中的25 個(gè)節(jié)點(diǎn)后的全連接層節(jié)點(diǎn)的輸出;最后將兩個(gè)模型作為一個(gè)整體進(jìn)行微調(diào)訓(xùn)練。
然后將上述方法中的LeNetB 模型換成更加復(fù)雜的EfficientNet 模型,在模型中添加節(jié)點(diǎn)總數(shù)為50 的全連接層。進(jìn)行與上述相同的訓(xùn)練過(guò)程。
普通數(shù)據(jù)集訓(xùn)練LeNetA 模型、知識(shí)蒸餾訓(xùn)練LeNetA 模型和本文方法訓(xùn)練后結(jié)果如圖5。
圖5 不同訓(xùn)練方法下的準(zhǔn)確率變化
其中普通訓(xùn)練和知識(shí)蒸餾訓(xùn)練LeNetA 模型的準(zhǔn)確率分別為:74.4%和74.3%。在本文訓(xùn)練方法中第一階段訓(xùn)練LeNetA 模型的準(zhǔn)確率為69.3%,在第二階段LeNetB 模型加入訓(xùn)練后準(zhǔn)確率為77.8%;在最后一階段整體微調(diào)后,準(zhǔn)確率達(dá)到78.1%。用EfficientNet 模型替換LeNetB 模型后準(zhǔn)確率為83.4%,微調(diào)后準(zhǔn)確率的84.9%。結(jié)果如表3 所示。
表3 不同訓(xùn)練方法下的準(zhǔn)確率
可以看出,通過(guò)本文方法設(shè)計(jì)的分布式神經(jīng)網(wǎng)絡(luò)與多個(gè)設(shè)備協(xié)同計(jì)算時(shí),使用更加復(fù)雜的神經(jīng)網(wǎng)絡(luò)模型進(jìn)行協(xié)同運(yùn)算時(shí)可達(dá)到的準(zhǔn)確率較高,對(duì)此適用于通信條件良好時(shí)通過(guò)與云服務(wù)器協(xié)同運(yùn)算達(dá)到較高的準(zhǔn)確率;通信情況一般時(shí)通過(guò)與附近的邊緣設(shè)備協(xié)同運(yùn)算,有良好的準(zhǔn)確率。協(xié)同運(yùn)算的準(zhǔn)確率都比原始模型較高;在出現(xiàn)干擾等情況無(wú)法與其他設(shè)備協(xié)同計(jì)算時(shí),單機(jī)運(yùn)算的準(zhǔn)確率較原始模型稍低,仍在可接受范圍內(nèi)。