91欧美超碰AV自拍|国产成年人性爱视频免费看|亚洲 日韩 欧美一厂二区入|人人看人人爽人人操aV|丝袜美腿视频一区二区在线看|人人操人人爽人人爱|婷婷五月天超碰|97色色欧美亚州A√|另类A√无码精品一级av|欧美特级日韩特级

0
  • 聊天消息
  • 系統(tǒng)消息
  • 評(píng)論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會(huì)員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識(shí)你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

如何進(jìn)行MLM訓(xùn)練

深度學(xué)習(xí)自然語言處理 ? 來源:CSDN ? 作者:常鴻宇 ? 2022-08-13 10:54 ? 次閱讀
加入交流群
微信小助手二維碼

掃碼添加小助手

加入工程師交流群

1. 關(guān)于MLM

1.1 背景

作為 Bert 預(yù)訓(xùn)練的兩大任務(wù)之一,MLMNSP 大家應(yīng)該并不陌生。其中,NSP 任務(wù)在后續(xù)的一些預(yù)訓(xùn)練任務(wù)中經(jīng)常被嫌棄,例如 Roberta 中將 NSP 任務(wù)直接放棄,Albert 中將 NSP 替換成了句子順序預(yù)測(cè)。

這主要是因?yàn)?NSP 作為一個(gè)分類任務(wù)過于簡單,對(duì)模型的學(xué)習(xí)并沒有太大的幫助,而 MLM 則被多數(shù)預(yù)訓(xùn)練模型保留下來。由 Roberta的實(shí)驗(yàn)結(jié)果也可以證明,Bert 的主要能力應(yīng)該是來自于 MLM 任務(wù)的訓(xùn)練。

Bert為代表的預(yù)訓(xùn)練語言模型是在大規(guī)模語料的基礎(chǔ)上訓(xùn)練以獲得的基礎(chǔ)的學(xué)習(xí)能力,而實(shí)際應(yīng)用時(shí),我們所面臨的語料或許具有某些特殊性,這就使得重新進(jìn)行 MLM 訓(xùn)練具有了必要性。

1.2 如何進(jìn)行MLM訓(xùn)練

1.2.1 什么是MLM

MLM 的訓(xùn)練,在不同的預(yù)訓(xùn)練模型中其實(shí)是有所不同的。今天介紹的內(nèi)容以最基礎(chǔ)的 Bert 為例。

Bert的MLM是靜態(tài)mask,而在后續(xù)的其他預(yù)訓(xùn)練模型中,這一策略通常被替換成了動(dòng)態(tài)mask。除此之外還有 whole word mask 的模型,這些都不在今天的討論范圍內(nèi)。

所謂 mask language model 的任務(wù),通俗來講,就是將句子中的一部分token替換掉,然后根據(jù)句子的剩余部分,試圖去還原這部分被mask的token

1.2.2 如何Mask

mask 的比例一般是15%,這一比例也被后續(xù)的多數(shù)模型所繼承,而在最初BERT 的論文中,沒有對(duì)這一比例的界定給出具體的說明。在我的印象中,似乎是知道后來同樣是Google提出的 T5 模型的論文中,對(duì)此進(jìn)行了解釋,對(duì) mask 的比例進(jìn)行了實(shí)驗(yàn),最終得出結(jié)論,15%的比例是最合理的(如果我記錯(cuò)了,還請(qǐng)指正)。

15%的token選出之后,并不是所有的都替換成[mask]標(biāo)記符。實(shí)際操作是:

  • 從這15%選出的部分中,將其中的80%替換成[mask];
  • 10%替換成一個(gè)隨機(jī)的token;
  • 剩下的10%保留原來的token。

這樣做可以提高模型的魯棒性。這個(gè)比例也可以自己控制。

到這里可能有同學(xué)要問了,既然有10%保留不變的話,為什么不干脆只選擇15%*90% = 13.5%的token呢?如果看完后面的代碼,就會(huì)很清楚地理解這個(gè)問題了。

先說結(jié)論:因?yàn)?MLM 的任務(wù)是將選出的這15%的token全部進(jìn)行預(yù)測(cè),不管這個(gè)token是否被替換成了[mask],也就是說,即使它被保留了原樣,也還是需要被預(yù)測(cè)的

2. 代碼部分

2.1 背景

介紹完了基礎(chǔ)內(nèi)容之后,接下來的內(nèi)容,我將基于 transformers 模塊,介紹如何進(jìn)行 mask language model 的訓(xùn)練。

其實(shí) transformers 模塊中,本身是提供了 MLM 訓(xùn)練任務(wù)的,模型都寫好了,只需要調(diào)用它內(nèi)置的 trainerdatasets模塊即可。感興趣的同學(xué)可以去 huggingface 的官網(wǎng)搜索相關(guān)教程。

然而我覺得 datasets 每次調(diào)用的時(shí)候都要去寫數(shù)據(jù)集的py文件,對(duì)arrow的數(shù)據(jù)格式不熟悉的話還很容易出錯(cuò),而且 trainer 我覺得也不是很好用,任何一點(diǎn)小小的修改都挺費(fèi)勁(就是它以為它寫的很完備,考慮了用戶的所有需求,但是實(shí)際上有一些冗余的部分)。

所以我就參考它的實(shí)現(xiàn)方式,把它的代碼拆解,又按照自己的方式重新組織了一下。

2.2 準(zhǔn)備工作

首先在寫核心代碼之前,先做好準(zhǔn)備工作。
import 所有需要的模塊:

import os
import json
import copy
from tqdm.notebook import tqdm

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import BertForMaskedLM, BertTokenizerFast

然后寫一個(gè)config類,將所有參數(shù)集中起來:

class Config:
    def __init__(self):
        pass
    
    def mlm_config(
        self, 
        mlm_probability=0.15, 
        special_tokens_mask=None,
        prob_replace_mask=0.8,
        prob_replace_rand=0.1,
        prob_keep_ori=0.1,
    ):
        """
        :param mlm_probability: 被mask的token總數(shù)
        :param special_token_mask: 特殊token
        :param prob_replace_mask: 被替換成[MASK]的token比率
        :param prob_replace_rand: 被隨機(jī)替換成其他token比率
        :param prob_keep_ori: 保留原token的比率
        """
        assert sum([prob_replace_mask, prob_replace_rand, prob_keep_ori]) == 1,                 ValueError("Sum of the probs must equal to 1.")
        self.mlm_probability = mlm_probability
        self.special_tokens_mask = special_tokens_mask
        self.prob_replace_mask = prob_replace_mask
        self.prob_replace_rand = prob_replace_rand
        self.prob_keep_ori = prob_keep_ori
        
    def training_config(
        self,
        batch_size,
        epochs,
        learning_rate,
        weight_decay,
        device,
    ):
        self.batch_size = batch_size
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.device = device
        
    def io_config(
        self,
        from_path,
        save_path,
    ):
        self.from_path = from_path
        self.save_path = save_path

接著就是設(shè)置各種配置:

config = Config()
config.mlm_config()
config.training_config(batch_size=4, epochs=10, learning_rate=1e-5, weight_decay=0, device='cuda:0')
config.io_config(from_path='/data/BERTmodels/huggingface/chinese_wwm/', 
                 save_path='./finetune_embedding_model/mlm/')

最后創(chuàng)建BERT模型。注意,這里的 tokenizer 就是一個(gè)普通的 tokenizer,而BERT模型則是帶了下游任務(wù)的 BertForMaskedLM,它是 transformers 中寫好的一個(gè)類,

bert_tokenizer = BertTokenizerFast.from_pretrained(config.from_path)
bert_mlm_model = BertForMaskedLM.from_pretrained(config.from_path)

2.3 數(shù)據(jù)集

因?yàn)樯釛壛?code style="font-size:14px;padding:2px 4px;margin:0 2px;color:#1e6bb8;background-color:rgba(27,31,35,.05);font-family:'Operator Mono', Consolas, Monaco, Menlo, monospace;">datasets這個(gè)包,所以我們現(xiàn)在需要自己實(shí)現(xiàn)數(shù)據(jù)的輸入了。方案就是使用 torchDataset 類。這個(gè)類一般在構(gòu)建 DataLoader 的時(shí)候,會(huì)與一個(gè)聚合函數(shù)一起使用,以實(shí)現(xiàn)對(duì)batch的組織。而我這里偷個(gè)懶,就沒有寫聚合函數(shù),batch的組織方法放在dataset中進(jìn)行。

在這個(gè)類中,有一個(gè) mask tokens 的方法,作用是從數(shù)據(jù)中選擇出所有需要mask 的token,并且采用三種mask方式中的一個(gè)。這個(gè)方法是從transformers 中拿出來的,將其從類方法轉(zhuǎn)為靜態(tài)方法測(cè)試之后,再將其放在自己的這個(gè)類中為我們所用。仔細(xì)閱讀這一段代碼,也就可以回答1.2.2 中提出的那個(gè)問題了。

取batch的原理很簡單,一開始我們將原始數(shù)據(jù)deepcopy備份一下,然后每次從中截取一個(gè)batch的大小,這個(gè)時(shí)候的當(dāng)前數(shù)據(jù)就少了一個(gè)batch,我們定義這個(gè)類的長度為當(dāng)前長度除以batch size向下取整,所以當(dāng)類的長度變?yōu)?的時(shí)候,就說明這一個(gè)epoch的所有step都已經(jīng)執(zhí)行結(jié)束,要進(jìn)行下一個(gè)epoch的訓(xùn)練,此時(shí),再將當(dāng)前數(shù)據(jù)變?yōu)樵紨?shù)據(jù),就可以實(shí)現(xiàn)對(duì)epoch的循環(huán)了。

class TrainDataset(Dataset):
    """
    注意:由于沒有使用data_collator,batch放在dataset里邊做,
    因而在dataloader出來的結(jié)果會(huì)多套一層batch維度,傳入模型時(shí)注意squeeze掉
    """
    def __init__(self, input_texts, tokenizer, config):
        self.input_texts = input_texts
        self.tokenizer = tokenizer
        self.config = config
        self.ori_inputs = copy.deepcopy(input_texts)
        
    def __len__(self):
        return len(self.input_texts) // self.config.batch_size
    
    def __getitem__(self, idx):
        batch_text = self.input_texts[: self.config.batch_size]
        features = self.tokenizer(batch_text, max_length=512, truncation=True, padding=True, return_tensors='pt')
        inputs, labels = self.mask_tokens(features['input_ids'])
        batch = {"inputs": inputs, "labels": labels}
        self.input_texts = self.input_texts[self.config.batch_size: ]
        if not len(self):
            self.input_texts = self.ori_inputs
        
        return batch
        
    def mask_tokens(self, inputs):
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = torch.full(labels.shape, self.config.mlm_probability)
        if self.config.special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = self.config.special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, self.config.prob_replace_mask)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        current_prob = self.config.prob_replace_rand / (1 - self.config.prob_replace_mask)
        indices_random = torch.bernoulli(torch.full(labels.shape, current_prob)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

然后取一些用于訓(xùn)練的語料,格式很簡單,就是把所有文本放在一個(gè)list里邊,注意長度不要超過512個(gè)token,不然多出來的部分就浪費(fèi)掉了。可以做適當(dāng)?shù)念A(yù)處理。

[
"這是一條文本",
"這是另一條文本",
...,
]

然后構(gòu)建dataloader:

train_dataset = TrainDataset(training_texts, bert_tokenizer, config)
train_dataloader = DataLoader(train_dataset)

2.4 訓(xùn)練

構(gòu)建一個(gè)訓(xùn)練方法,輸入?yún)?shù)分別是我們實(shí)例化好的待訓(xùn)練模型,數(shù)據(jù)集,還有config:

def train(model, train_dataloader, config):
    """
    訓(xùn)練
    :param model: nn.Module
    :param train_dataloader: DataLoader
    :param config: Config
    ---------------
    ver: 2021-11-08
    by: changhongyu
    """
    assert config.device.startswith('cuda') or config.device == 'cpu', ValueError("Invalid device.")
    device = torch.device(config.device)
    
    model.to(device)
    
    if not len(train_dataloader):
        raise EOFError("Empty train_dataloader.")
        
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
        {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}]
    
    optimizer = AdamW(params=optimizer_grouped_parameters, lr=config.learning_rate, weight_decay=config.weight_decay)
    
    for cur_epc in tqdm(range(int(config.epochs)), desc="Epoch"):
        training_loss = 0
        print("Epoch: {}".format(cur_epc+1))
        model.train()
        for step, batch in enumerate(tqdm(train_dataloader, desc='Step')):
            input_ids = batch['inputs'].squeeze(0).to(device)
            labels = batch['labels'].squeeze(0).to(device)
            loss = model(input_ids=input_ids, labels=labels).loss
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            model.zero_grad()
            training_loss += loss.item()
        print("Training loss: ", training_loss)

調(diào)用它訓(xùn)練幾輪:

train(model=bert_mlm_model, train_dataloader=train_dataloader, config=config)

2.5 保存和加載

使用過預(yù)訓(xùn)練模型的同學(xué)應(yīng)該都了解,普通的bert有兩項(xiàng)輸出,分別是:

  • 每一個(gè)token對(duì)應(yīng)的768維編碼結(jié)果;
  • 以及用于表征整個(gè)句子的句子特征。

其中,這個(gè)句子特征是由模型中的一個(gè) Pooler 模塊對(duì)原句池化得來的??墒沁@個(gè)Pooler的訓(xùn)練,并不是由 MLM 任務(wù)來的,而是由 NSP任務(wù)中來的。

由于沒有 NSP 任務(wù),所以無法對(duì) Pooler 進(jìn)行訓(xùn)練,故而沒有必要在模型中加入 Pooler。所以在保存的時(shí)候需要分別保存 embedding和 encoder, 加載的時(shí)候也需要分別讀取 embedding 和 encoder,這樣訓(xùn)練出來的模型拿不到 CLS 層的句子表征。如果需要的話,可以手動(dòng)pooling 。

torch.save(bert_mlm_model.bert.embeddings.state_dict(), os.path.join(config.save_path, 'bert_mlm_ep_{}_eb.bin'.format(config.epochs)))
torch.save(bert_mlm_model.bert.encoder.state_dict(), os.path.join(config.save_path, 'bert_mlm_ep_{}_ec.bin'.format(config.epochs)))

加載的話,也是實(shí)例化完bert模型之后,用bert的 embedding 組件和 encoder 組件分別讀取這兩個(gè)權(quán)重文件即可。

到這里,本期內(nèi)容就全部結(jié)束了,希望看完這篇博客的同學(xué),能夠?qū)?Bert 的基礎(chǔ)原理有更深入的了解。

審核編輯 :李倩


聲明:本文內(nèi)容及配圖由入駐作者撰寫或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點(diǎn)僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問題,請(qǐng)聯(lián)系本站處理。 舉報(bào)投訴
  • 模型
    +關(guān)注

    關(guān)注

    1

    文章

    3757

    瀏覽量

    52130
  • 語言模型
    +關(guān)注

    關(guān)注

    0

    文章

    572

    瀏覽量

    11323
  • mask
    +關(guān)注

    關(guān)注

    0

    文章

    10

    瀏覽量

    3227

原文標(biāo)題:2. 代碼部分

文章出處:【微信號(hào):zenRRan,微信公眾號(hào):深度學(xué)習(xí)自然語言處理】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。

收藏 人收藏
加入交流群
微信小助手二維碼

掃碼添加小助手

加入工程師交流群

    評(píng)論

    相關(guān)推薦
    熱點(diǎn)推薦

    對(duì)于設(shè)備上的舊固件如何進(jìn)行備份和恢復(fù)?

    對(duì)于設(shè)備上的舊固件,如何進(jìn)行備份和恢復(fù)?
    發(fā)表于 12-12 08:23

    在使用CW32L083系列微控制器時(shí)如何進(jìn)行系統(tǒng)復(fù)位和看門狗定時(shí)器配置?

    在使用CW32L083系列微控制器時(shí),如何進(jìn)行系統(tǒng)復(fù)位和看門狗定時(shí)器配置?
    發(fā)表于 12-10 06:46

    單片機(jī)如何進(jìn)行加解密鑰操作,一般使用哪種形式,具體流程是什么樣子的?

    目前單片機(jī)如何進(jìn)行加解密鑰操作,一般使用哪種形式,具體流程是什么樣子的?
    發(fā)表于 12-04 06:09

    L083最低功耗是多少,應(yīng)該如何進(jìn)行低功耗設(shè)計(jì)?有哪些注意事項(xiàng)?

    L083最低功耗是多少,應(yīng)該如何進(jìn)行低功耗設(shè)計(jì)?有哪些注意事項(xiàng)?
    發(fā)表于 11-12 07:29

    在Ubuntu20.04系統(tǒng)中訓(xùn)練神經(jīng)網(wǎng)絡(luò)模型的一些經(jīng)驗(yàn)

    模型。 我們使用MNIST數(shù)據(jù)集,訓(xùn)練一個(gè)卷積神經(jīng)網(wǎng)絡(luò)(CNN)模型,用于手寫數(shù)字識(shí)別。一旦模型被訓(xùn)練并保存,就可以用于對(duì)新圖像進(jìn)行推理和預(yù)測(cè)。要使用生成的模型進(jìn)行推理,可以按照以下步
    發(fā)表于 10-22 07:03

    何進(jìn)行聲音定位?

    文章主要介紹了如何利用一種簡單的TDOA算法進(jìn)行聲音點(diǎn)位,并使用數(shù)據(jù)采集卡進(jìn)行聲音定位的實(shí)驗(yàn)。
    的頭像 發(fā)表于 09-23 15:47 ?1843次閱讀
    如<b class='flag-5'>何進(jìn)行</b>聲音定位?

    2KW逆變側(cè)功率管的損耗如何進(jìn)行計(jì)算詳細(xì)公式免費(fèi)下載

    本文檔的主要內(nèi)容詳細(xì)介紹的是2KW逆變側(cè)功率管的損耗如何進(jìn)行計(jì)算詳細(xì)公式免費(fèi)下載。
    發(fā)表于 08-29 16:18 ?34次下載

    何進(jìn)行YOLO模型轉(zhuǎn)換?

    我目前使用的轉(zhuǎn)模型代碼如下 from ultralytics import YOLOimport cv2import timeimport nncaseimport# 加載預(yù)訓(xùn)練的YOLO模型
    發(fā)表于 08-14 06:03

    在對(duì)廬山派K230的SD卡data文件夾進(jìn)行刪除和新件文件夾時(shí)無法操作,且訓(xùn)練時(shí)線程異常,怎么解決?

    1.我的SD卡可以正常啟動(dòng),也可以收到來自于廬山派攝像頭拍攝的照片,SD卡有data和sdcard分區(qū) 2.但是一旦進(jìn)行刪除或?qū)懭耄蜁?huì)斷開,我一開始在data/data/images這個(gè)目錄
    發(fā)表于 08-01 08:03

    使用 ai cude 里面自帶的案例訓(xùn)練UI顯示異常的原因?怎么解決?

    案例的配置是默認(rèn)的,顯示訓(xùn)練ui更改顯示異常
    發(fā)表于 06-23 06:21

    k210在線訓(xùn)練的算法是yolo5嗎?

    k210在線訓(xùn)練的算法是yolo5嗎
    發(fā)表于 06-16 08:25

    OCR識(shí)別訓(xùn)練完成后給的是空壓縮包,為什么?

    OCR識(shí)別 一共弄了26張圖片,都標(biāo)注好了,點(diǎn)擊開始訓(xùn)練,顯示訓(xùn)練成功了,也將壓縮包發(fā)到郵箱了,下載下來后,壓縮包里面是空的 OCR圖片20幾張圖太少了。麻煩您多添加點(diǎn),參考我們的ocr識(shí)別訓(xùn)練數(shù)據(jù)集 請(qǐng)問
    發(fā)表于 05-28 06:46

    海思SD3403邊緣計(jì)算AI數(shù)據(jù)訓(xùn)練概述

    AI數(shù)據(jù)訓(xùn)練:基于用戶特定應(yīng)用場景,用戶采集照片或視頻,通過AI數(shù)據(jù)訓(xùn)練工程師**(用戶公司****員工)** ,進(jìn)行特征標(biāo)定后,將標(biāo)定好的訓(xùn)練樣本,通過AI
    發(fā)表于 04-28 11:11

    請(qǐng)問STM32WBA65如何進(jìn)行matter的學(xué)習(xí)?

    STM32WBA65如何進(jìn)行matter的學(xué)習(xí)?相關(guān)的支持都有哪些?有一個(gè)X-CUBE-MATTER,可是這個(gè)沒有集成在STM32CubeMX中
    發(fā)表于 04-24 07:22

    使用CAN以及CANIF配置了S32K310的CAN驅(qū)動(dòng)模塊,如何進(jìn)行報(bào)文的接收呢?

    我使用CAN以及CANIF配置了S32K310的CAN驅(qū)動(dòng)模塊。我知道調(diào)用CAN_Write()函數(shù)進(jìn)行報(bào)文的發(fā)送,但我存有以下的一些問題: 1.我該如何進(jìn)行報(bào)文的接收呢?我看到有一些文章說能夠通過
    發(fā)表于 03-21 07:24