寫(xiě)在前面
文本分類(lèi)是NLP中一個(gè)非常重要的任務(wù),也是非常適合入坑NLP的第一個(gè)完整項(xiàng)目。
文本分類(lèi)看似簡(jiǎn)單,但實(shí)則里面有好多門(mén)道。作者水平有限,只能將平時(shí)用到的方法和trick在此做個(gè)記錄和分享,并且盡可能提供給出簡(jiǎn)潔、清晰的代碼實(shí)現(xiàn)。希望各位看官都能有所收獲。
本文主要討論文本分類(lèi)中處理樣本不均衡和提升模型魯棒性的trick。
1. 緩解樣本不均衡
樣本不均衡現(xiàn)象
假如我們要實(shí)現(xiàn)一個(gè)新聞?wù)?fù)面判斷的文本二分類(lèi)器,負(fù)面新聞的樣本比例較少,可能2W條新聞?dòng)?00條甚至更少的樣本屬于負(fù)例。這種現(xiàn)象就是樣本不均衡。
在樣本不均衡場(chǎng)景下,樣本會(huì)呈現(xiàn)一個(gè)長(zhǎng)尾分布(如圖中所示會(huì)出現(xiàn)長(zhǎng)長(zhǎng)的尾巴),頭部的標(biāo)簽包含了大量的樣本,而尾部的標(biāo)簽擁有很少的樣本,這種現(xiàn)象也叫長(zhǎng)尾現(xiàn)象。岔開(kāi)說(shuō)下,聽(tīng)過(guò)二八定律的人大多知道長(zhǎng)尾現(xiàn)象其實(shí)很普遍,比如80%的財(cái)富掌握在20%的人手中。

樣本不均衡問(wèn)題
樣本不均衡會(huì)帶來(lái)很多問(wèn)題。模型訓(xùn)練的本質(zhì)是最小化損失函數(shù),當(dāng)某個(gè)類(lèi)別的樣本數(shù)量非常龐大,損失函數(shù)的值大部分被其所影響,導(dǎo)致的結(jié)果就是模型分類(lèi)會(huì)傾向于該類(lèi)別(樣本量較大的類(lèi)別)。
咱拿上面文本分類(lèi)的例子來(lái)說(shuō)明?,F(xiàn)在有2W條用戶(hù)搜索的樣本,其中100條是負(fù)面新聞,即負(fù)樣本,那么當(dāng)模型全部將樣本預(yù)測(cè)為正例,也能得到 99.5% 的準(zhǔn)確率。但實(shí)際上這個(gè)模型跟盲猜沒(méi)什么區(qū)別,而我們的目的是讓模型能夠正確的區(qū)分正例和負(fù)例。
1.1 模型層面解決樣本不均衡
在模型層面解決樣本不均衡問(wèn)題,可以選擇加入 Focal Loss 學(xué)習(xí)難學(xué)樣本,具體原理可以參考文章《何愷明大神的「Focal Loss」,如何更好地理解?》[1]。
1.1.1 Focal Loss pytorch代碼實(shí)現(xiàn)
classFocalLoss(nn.Module): """Multi-classFocallossimplementation""" def__init__(self,gamma=2,weight=None,reduction='mean',ignore_index=-100): super(FocalLoss,self).__init__() self.gamma=gamma self.weight=weight self.ignore_index=ignore_index self.reduction=reduction defforward(self,input,target): """ input:[N,C] target:[N,] """ log_pt=torch.log_softmax(input,dim=1) pt=torch.exp(log_pt) log_pt=(1-pt)**self.gamma*log_pt loss=torch.nn.functional.nll_loss(log_pt,target,self.weight,reduction=self.reduction,ignore_index=self.ignore_index) returnloss
代碼鏈接:blog_code/nlp/focal_loss.py[2]
1.2 數(shù)據(jù)層面解決樣本不均衡
假如我們的正樣本只有100條,而負(fù)樣本可能有1W條。如果不采取任何策略,那么我們就是使用這1.01W條樣本去訓(xùn)練模型。從數(shù)據(jù)層面解決樣本不均衡的問(wèn)題核心是通過(guò)人為控制正負(fù)樣本的比例,分成欠采樣和過(guò)采樣兩種。
1.2.1 欠采樣
簡(jiǎn)單隨機(jī)
欠采樣的基本做法是這樣的,現(xiàn)在我們的正負(fù)樣本比例為1:100。如果我們想讓正負(fù)樣本比例不超過(guò)1:10,那么模型訓(xùn)練的時(shí)候數(shù)量比較少的正樣本也就是100條全部使用,而負(fù)樣本隨機(jī)挑選1000條。
通過(guò)這樣人為的方式,我們把樣本的正負(fù)比例強(qiáng)行控制在了1:10。需要注意的是,這種方式存在一個(gè)問(wèn)題:為了強(qiáng)行控制樣本比例我們生生的舍去了那9000條負(fù)樣本,這對(duì)于模型來(lái)說(shuō)是莫大的損失。
迭代預(yù)分類(lèi)
相比于簡(jiǎn)單的對(duì)負(fù)樣本隨機(jī)采樣的欠采樣方法,實(shí)際工作中更推薦使用迭代預(yù)分類(lèi)的方式來(lái)采樣負(fù)樣本。具體流程如下圖所示:

首先我們會(huì)使用全部的正樣本和從負(fù)例候選集中隨機(jī)采樣一部分負(fù)樣本(這里假如是100條)去訓(xùn)練第一輪分類(lèi)器;
然后用第一輪分類(lèi)器去預(yù)測(cè)負(fù)例候選集剩余的9900條數(shù)據(jù),把9900條負(fù)例中預(yù)測(cè)為正例的樣本(也就是預(yù)測(cè)錯(cuò)誤的樣本)再隨機(jī)采樣100條和第一輪訓(xùn)練的數(shù)據(jù)放到一起去訓(xùn)練第二輪分類(lèi)器;
同樣的方法用第二輪分類(lèi)器去預(yù)測(cè)負(fù)例候選集剩余的9800條數(shù)據(jù),直到訓(xùn)練的第N輪分類(lèi)器可以全部識(shí)別負(fù)例候選集,這就是使用迭代預(yù)分類(lèi)的方式進(jìn)行欠采樣。
相比于隨機(jī)欠采樣來(lái)說(shuō),迭代預(yù)分類(lèi)的欠采樣方式能最大限度地利用負(fù)樣本中差異性較大的負(fù)樣本,從而在控制正負(fù)樣本比例的基礎(chǔ)上采樣出了最有代表意義的負(fù)樣本。
欠采樣的方式整體來(lái)說(shuō)或多或少的會(huì)損失一些樣本,對(duì)于那些需要控制樣本量級(jí)的場(chǎng)景下比較合適。如果沒(méi)有嚴(yán)格控制樣本量級(jí)的要求那么下面的過(guò)采樣可能會(huì)更加適合你。
1.2.2 過(guò)采樣
過(guò)采樣和欠采樣比較類(lèi)似,都是人工干預(yù)控制樣本的比例,不同的是過(guò)采樣不會(huì)損失樣本。
還拿上面的例子,現(xiàn)在有正樣本100條,負(fù)樣本1W條,最簡(jiǎn)單的過(guò)采樣方式是我們使用全部的負(fù)樣本1W條。但是,為了維持正負(fù)樣本比例,我們會(huì)從正樣本中有放回的重復(fù)采樣,直到獲取了1000條正樣本,也就是說(shuō)有些正樣本可能會(huì)被重復(fù)采樣到,這樣就能保持1:10的正負(fù)樣本比例了。這是最簡(jiǎn)單的過(guò)采樣方式,這種方式可能會(huì)存在嚴(yán)重的過(guò)擬合。
實(shí)際的場(chǎng)景中會(huì)通過(guò)樣本增強(qiáng)的技術(shù)來(lái)增加正樣本。
2. 提升模型魯棒性
提升模型魯棒性的方法有很多,其中對(duì)抗訓(xùn)練、知識(shí)蒸餾、防止模型過(guò)擬合和多模型融合是常見(jiàn)的穩(wěn)定提升方式。
2.1 對(duì)抗訓(xùn)練
對(duì)抗訓(xùn)練是一種能有效提高模型魯棒性和泛化能力的訓(xùn)練手段,其基本原理是通過(guò)在原始輸入上增加對(duì)抗擾動(dòng),得到對(duì)抗樣本,再利用對(duì)抗樣本進(jìn)行訓(xùn)練,從而提高模型的表現(xiàn)。
由于自然語(yǔ)言文本是離散的,一般會(huì)把對(duì)抗擾動(dòng)添加到嵌入層上。為了最大化對(duì)抗樣本的擾動(dòng)能力,利用梯度上升的方式生成對(duì)抗樣本。為了避免擾動(dòng)過(guò)大,將梯度做了歸一化處理。

其中, 為嵌入向量。在實(shí)際訓(xùn)練過(guò)程中,我們會(huì)在訓(xùn)練完一個(gè)batch的原始輸入數(shù)據(jù)時(shí),保存當(dāng)前batch對(duì)輸入詞向量的梯度,得到對(duì)抗樣本后,再使用對(duì)抗樣本進(jìn)行對(duì)抗訓(xùn)練。
2.1.1 對(duì)抗訓(xùn)練pytorch代碼實(shí)現(xiàn)
class FGM():
def __init__(self, model):
self.model = model
self.backup = {}
def attack(self, epsilon=1., emb_name='emb'):
for name, param in self.model.named_parameters():
if param.requires_grad and emb_name in name:
self.backup[name] = param.data.clone()
norm = torch.norm(param.grad)
if norm != 0:
r_at = epsilon * param.grad / norm
param.data.add_(r_at)
def restore(self, emb_name='emb'):
for name, param in self.model.named_parameters():
if param.requires_grad and emb_name in name:
assert name in self.backup
param.data = self.backup[name]
self.backup = {}
訓(xùn)練中加入幾行代碼
# 初始化 fgm = FGM(model) for batch_input, batch_label in data: # 正常訓(xùn)練 loss = model(batch_input, batch_label) loss.backward() # 對(duì)抗訓(xùn)練 fgm.attack() # 修改embedding # optimizer.zero_grad() # 梯度累加,不累加去掉注釋 loss_sum = model(batch_input, batch_label) loss_sum.backward() # 累加對(duì)抗訓(xùn)練的梯度 fgm.restore() # 恢復(fù)Embedding的參數(shù) optimizer.step() optimizer.zero_grad()
代碼鏈接:blog_code/nlp/at.py [3]
2.2 知識(shí)蒸餾
與對(duì)抗訓(xùn)練類(lèi)似,知識(shí)蒸餾也是一種常用的提高模型泛化能力的訓(xùn)練方法。
知識(shí)蒸餾這個(gè)概念最早由Hinton在2015年提出。一開(kāi)始,知識(shí)蒸餾通往往應(yīng)用在模型壓縮方面,利用訓(xùn)練好的復(fù)雜模型(teacher model)輸出作為監(jiān)督信號(hào)去訓(xùn)練另一個(gè)簡(jiǎn)單模型(student model),從而將teacher學(xué)習(xí)到的知識(shí)遷移到student。
Tommaso在18年提出,若student和teacher的模型完全相同,蒸餾后則會(huì)對(duì)模型的表現(xiàn)有一定程度上的提升。
2.3 防止模型過(guò)擬合
2.3.1 正則化
L1和L2正則化
L1正則化可以得到稀疏解,L2正則化可以得到平滑解,原因參考文章《為什么L1稀疏,L2平滑?》[4]。
2.3.2 Dropout
Dropout是指在深度學(xué)習(xí)網(wǎng)絡(luò)的訓(xùn)練過(guò)程中,對(duì)于神經(jīng)網(wǎng)絡(luò)單元,按照一定的概率將其暫時(shí)從網(wǎng)絡(luò)中丟棄。
Dropout為什么能防止過(guò)擬合,可以通過(guò)以下幾個(gè)方面來(lái)解釋?zhuān)?/p>
它強(qiáng)迫一個(gè)神經(jīng)單元,和隨機(jī)挑選出來(lái)的其他神經(jīng)單元共同工作,達(dá)到好的效果。消除減弱了神經(jīng)元節(jié)點(diǎn)間的聯(lián)合適應(yīng)性,增強(qiáng)了泛化能力;
類(lèi)似于bagging的集成效果;
對(duì)于每一個(gè)dropout后的網(wǎng)絡(luò),進(jìn)行訓(xùn)練時(shí),相當(dāng)于做了Data Augmentation,因?yàn)?,總可以找到一個(gè)樣本,使得在原始的網(wǎng)絡(luò)上也能達(dá)到dropout單元后的效果。比如,對(duì)于某一層,dropout一些單元后,形成的結(jié)果是(1.5,0,2.5,0,1,2,0),其中0是被drop的單元,那么總能找到一個(gè)樣本,使得結(jié)果也是如此。這樣,每一次dropout其實(shí)都相當(dāng)于增加了樣本。
Dropout在測(cè)試時(shí),并不會(huì)隨機(jī)丟棄神經(jīng)元,而是使用全部所有的神經(jīng)元,同時(shí),所有的權(quán)重值都乘上1-p,p代表的是隨機(jī)失活率。
2.3.3 數(shù)據(jù)增強(qiáng)
數(shù)據(jù)增強(qiáng)即需要得到更多的符合要求的數(shù)據(jù),即和已有的數(shù)據(jù)是獨(dú)立同分布的,或者近似獨(dú)立同分布的。一般有以下方法:
1)從數(shù)據(jù)源頭采集更多數(shù)據(jù);
2)復(fù)制原有數(shù)據(jù)并加上隨機(jī)噪聲;
3)重采樣;
4)根據(jù)當(dāng)前數(shù)據(jù)集估計(jì)數(shù)據(jù)分布參數(shù),使用該分布產(chǎn)生更多數(shù)據(jù)等。
2.3.4 Early stopping
在模型對(duì)訓(xùn)練數(shù)據(jù)集迭代收斂之前停止迭代來(lái)防止過(guò)擬合。因?yàn)樵诔跏蓟W(wǎng)絡(luò)的時(shí)候一般都是初始為較小的權(quán)值,訓(xùn)練時(shí)間越長(zhǎng),部分網(wǎng)絡(luò)權(quán)值可能越大。如果我們?cè)诤线m時(shí)間停止訓(xùn)練,就可以將網(wǎng)絡(luò)的能力限制在一定范圍內(nèi)。
2.3.5 交叉驗(yàn)證
交叉驗(yàn)證的基本思想就是將原始數(shù)據(jù)進(jìn)行分組,一部分做為訓(xùn)練集來(lái)訓(xùn)練模型,另一部分做為測(cè)試集來(lái)評(píng)價(jià)模型。我們常用的交叉驗(yàn)證方法有簡(jiǎn)單交叉驗(yàn)證、S折交叉驗(yàn)證和留一交叉驗(yàn)證。
2.3.6 Batch Normalization
一種非常有用的正則化方法,可以讓大型的卷積網(wǎng)絡(luò)訓(xùn)練速度加快很多倍,同時(shí)收斂后分類(lèi)的準(zhǔn)確率也可以大幅度的提高。
BN在訓(xùn)練某層時(shí),會(huì)對(duì)每一個(gè)mini-batch數(shù)據(jù)進(jìn)行標(biāo)準(zhǔn)化(normalization)處理,使輸出規(guī)范到 的正態(tài)分布,減少了Internal convariate shift(內(nèi)部神經(jīng)元分布的改變),傳統(tǒng)的深度神經(jīng)網(wǎng)絡(luò)在訓(xùn)練是每一層的輸入的分布都在改變,因此訓(xùn)練困難,只能選擇用一個(gè)很小的學(xué)習(xí)速率,但是每一層用了BN后,可以有效的解決這個(gè)問(wèn)題,學(xué)習(xí)速率可以增大很多倍。
2.3.7 選擇合適的網(wǎng)絡(luò)結(jié)構(gòu)
通過(guò)減少網(wǎng)絡(luò)層數(shù)、神經(jīng)元個(gè)數(shù)、全連接層數(shù)等降低網(wǎng)絡(luò)容量。
3.多模型融合
Baggging &Boosting,將弱分類(lèi)器融合之后形成一個(gè)強(qiáng)分類(lèi)器,而且融合之后的效果會(huì)比最好的弱分類(lèi)器更好,三個(gè)臭皮匠頂一個(gè)諸葛亮。
-
模型
+關(guān)注
關(guān)注
1文章
3753瀏覽量
52116 -
代碼
+關(guān)注
關(guān)注
30文章
4968瀏覽量
73998 -
nlp
+關(guān)注
關(guān)注
1文章
491瀏覽量
23282
原文標(biāo)題:2. 提升模型魯棒性
文章出處:【微信號(hào):zenRRan,微信公眾號(hào):深度學(xué)習(xí)自然語(yǔ)言處理】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
pyhanlp文本分類(lèi)與情感分析
NLPIR平臺(tái)在文本分類(lèi)方面的技術(shù)解析
不均衡數(shù)據(jù)集上基于子域?qū)W習(xí)的復(fù)合分類(lèi)模型
結(jié)合BERT模型的中文文本分類(lèi)算法
融合文本分類(lèi)和摘要的多任務(wù)學(xué)習(xí)摘要模型
基于不同神經(jīng)網(wǎng)絡(luò)的文本分類(lèi)方法研究對(duì)比
膠囊網(wǎng)絡(luò)在小樣本做文本分類(lèi)中的應(yīng)用(下)
基于主題分布優(yōu)化的模糊文本分類(lèi)方法
如何解決樣本不均的問(wèn)題?
文本分類(lèi)中處理樣本不均衡和提升模型魯棒性的trick
評(píng)論