91欧美超碰AV自拍|国产成年人性爱视频免费看|亚洲 日韩 欧美一厂二区入|人人看人人爽人人操aV|丝袜美腿视频一区二区在线看|人人操人人爽人人爱|婷婷五月天超碰|97色色欧美亚州A√|另类A√无码精品一级av|欧美特级日韩特级

電子發(fā)燒友App

硬聲App

掃碼添加小助手

加入工程師交流群

0
  • 聊天消息
  • 系統(tǒng)消息
  • 評論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示
創(chuàng)作
電子發(fā)燒友網(wǎng)>電子資料下載>電子資料>PyTorch教程15.10之預(yù)訓(xùn)練BERT

PyTorch教程15.10之預(yù)訓(xùn)練BERT

2023-06-05 | pdf | 0.15 MB | 次下載 | 免費(fèi)

資料介紹

借助15.8 節(jié)中實現(xiàn)的 BERT 模型和15.9 節(jié)中從 WikiText-2 數(shù)據(jù)集生成的預(yù)訓(xùn)練示例 ,我們將在本節(jié)中在 WikiText-2 數(shù)據(jù)集上預(yù)訓(xùn)練 BERT。

import torch
from torch import nn
from d2l import torch as d2l
from mxnet import autograd, gluon, init, np, npx
from d2l import mxnet as d2l

npx.set_np()

首先,我們將 WikiText-2 數(shù)據(jù)集加載為用于屏蔽語言建模和下一句預(yù)測的小批量預(yù)訓(xùn)練示例。批量大小為 512,BERT 輸入序列的最大長度為 64。請注意,在原始 BERT 模型中,最大長度為 512。

batch_size, max_len = 512, 64
train_iter, vocab = d2l.load_data_wiki(batch_size, max_len)
batch_size, max_len = 512, 64
train_iter, vocab = d2l.load_data_wiki(batch_size, max_len)
Downloading ../data/wikitext-2-v1.zip from https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip...

15.10.1。預(yù)訓(xùn)練 BERT

原始 BERT 有兩個不同模型大小的版本 Devlin et al. , 2018。基礎(chǔ)模型(BERTBASE) 使用 12 層(Transformer 編碼器塊),具有 768 個隱藏單元(隱藏大?。┖?12 個自注意力頭。大模型(BERTLARGE) 使用 24 層,有 1024 個隱藏單元和 16 個自注意力頭。值得注意的是,前者有 1.1 億個參數(shù),而后者有 3.4 億個參數(shù)。為了便于演示,我們定義了一個小型 BERT,使用 2 層、128 個隱藏單元和 2 個自注意力頭。

net = d2l.BERTModel(len(vocab), num_hiddens=128,
          ffn_num_hiddens=256, num_heads=2, num_blks=2, dropout=0.2)
devices = d2l.try_all_gpus()
loss = nn.CrossEntropyLoss()
net = d2l.BERTModel(len(vocab), num_hiddens=128, ffn_num_hiddens=256,
          num_heads=2, num_blks=2, dropout=0.2)
devices = d2l.try_all_gpus()
net.initialize(init.Xavier(), ctx=devices)
loss = gluon.loss.SoftmaxCELoss()

在定義訓(xùn)練循環(huán)之前,我們定義了一個輔助函數(shù) _get_batch_loss_bert給定訓(xùn)練示例的碎片,此函數(shù)計算掩碼語言建模和下一句預(yù)測任務(wù)的損失。請注意,BERT 預(yù)訓(xùn)練的最終損失只是掩碼語言建模損失和下一句預(yù)測損失的總和。

#@save
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X,
             segments_X, valid_lens_x,
             pred_positions_X, mlm_weights_X,
             mlm_Y, nsp_y):
  # Forward pass
  _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,
                 valid_lens_x.reshape(-1),
                 pred_positions_X)
  # Compute masked language model loss
  mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) *\
  mlm_weights_X.reshape(-1, 1)
  mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)
  # Compute next sentence prediction loss
  nsp_l = loss(nsp_Y_hat, nsp_y)
  l = mlm_l + nsp_l
  return mlm_l, nsp_l, l
#@save
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X_shards,
             segments_X_shards, valid_lens_x_shards,
             pred_positions_X_shards, mlm_weights_X_shards,
             mlm_Y_shards, nsp_y_shards):
  mlm_ls, nsp_ls, ls = [], [], []
  for (tokens_X_shard, segments_X_shard, valid_lens_x_shard,
     pred_positions_X_shard, mlm_weights_X_shard, mlm_Y_shard,
     nsp_y_shard) in zip(
    tokens_X_shards, segments_X_shards, valid_lens_x_shards,
    pred_positions_X_shards, mlm_weights_X_shards, mlm_Y_shards,
    nsp_y_shards):
    # Forward pass
    _, mlm_Y_hat, nsp_Y_hat = net(
      tokens_X_shard, segments_X_shard, valid_lens_x_shard.reshape(-1),
      pred_positions_X_shard)
    # Compute masked language model loss
    mlm_l = loss(
      mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y_shard.reshape(-1),
      mlm_weights_X_shard.reshape((-1, 1)))
    mlm_l = mlm_l.sum() / (mlm_weights_X_shard.sum() + 1e-8)
    # Compute next sentence prediction loss
    nsp_l = loss(nsp_Y_hat, nsp_y_shard)
    nsp_l = nsp_l.mean()
    mlm_ls.append(mlm_l)
    nsp_ls.append(nsp_l)
    ls.append(mlm_l + nsp_l)
    npx.waitall()
  return mlm_ls, nsp_ls, ls

調(diào)用上述兩個輔助函數(shù),以下 函數(shù)定義了在 WikiText-2 ( ) 數(shù)據(jù)集上train_bert預(yù)訓(xùn)練 BERT ( ) 的過程。訓(xùn)練 BERT 可能需要很長時間。與在函數(shù)中指定訓(xùn)練的時期數(shù)不同 (參見第 14.1 節(jié)),以下函數(shù)的輸入指定訓(xùn)練的迭代步數(shù)。nettrain_itertrain_ch13num_steps

def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):
  net(*next(iter(train_iter))[:4])
  net = nn.DataParallel(net, device_ids=devices).to(devices[0])
  trainer = torch.optim.Adam(net.parameters(), lr=0.01)
  step, timer = 0, d2l.Timer()
  animator = d2l.Animator(xlabel='step', ylabel='loss',
              xlim=[1, num_steps], legend=['mlm', 'nsp'])
  # Sum of masked language modeling losses, sum of next sentence prediction
  # losses, no. of sentence pairs, count
  metric = d2l.Accumulator(4)
  num_steps_reached = False
  while step < num_steps and not num_steps_reached:
    for tokens_X, segments_X, valid_lens_x, pred_positions_X,\
      mlm_weights_X, mlm_Y, nsp_y in train_iter:
      tokens_X = tokens_X.to(devices[0])
      segments_X = segments_X.to(devices[0])
      valid_lens_x = valid_lens_x.to(devices[0])
      pred_positions_X = pred_positions_X.to(devices[0])
      mlm_weights_X = mlm_weights_X.to(devices[0])
      mlm_Y, nsp_y = mlm_Y.to(devices[0]), nsp_y.to(devices[0])
      trainer.zero_grad()
      timer.start()
      mlm_l, nsp_l, l = _get_batch_loss_bert(
        net, loss, vocab_size, tokens_X, segments_X, valid_lens_x,
        pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)
      l.backward()
      trainer.step()
      metric.add(mlm_l, nsp_l, tokens_X.shape[0], 1)
      timer.stop()
      animator.add(step + 1,
             (metric[0] / metric[3], metric[1] / metric[3]))
      step += 1
      if step == num_steps:
        num_steps_reached = True
        break

  print(f'MLM loss {metric[0] / metric[3]:.3f}, '
     f'NSP loss {metric[1] / metric[3]:.3f}')
  print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on '
     f'{str(devices)}')
def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):
  trainer = gluon.Trainer(net.collect_params(), 'adam',
              {'learning_rate': 0.01})
  step, timer = 0, d2l.Timer()
  animator = d2l.Animator(xlabel='step', ylabel='loss',
              xlim=[1, num_steps], legend=['mlm', 'nsp'])
  # Sum of masked language modeling losses, sum of next sentence prediction
  # losses, no. of sentence pairs, count
  metric = d2l.

函數(shù) 數(shù)據(jù)集 pytorch
加入交流群
微信小助手二維碼

掃碼添加小助手

加入工程師交流群

下載該資料的人也在下載 下載該資料的人還在閱讀
更多 >

評論

查看更多

下載排行

本周

  1. 1新一代網(wǎng)絡(luò)可視化(NPB 2.0)
  2. 3.40 MB  |  1次下載  |  免費(fèi)
  3. 2冷柜-電氣控制系統(tǒng)講解
  4. 13.68 MB   |  1次下載  |  10 積分
  5. 3MDD品牌三極管MMBT3906數(shù)據(jù)手冊
  6. 2.33 MB  |  次下載  |  免費(fèi)
  7. 4MDD品牌三極管S9012數(shù)據(jù)手冊
  8. 2.62 MB  |  次下載  |  免費(fèi)
  9. 5LAT1218 如何選擇和設(shè)置外部晶體適配 BlueNRG-X
  10. 0.60 MB   |  次下載  |  3 積分
  11. 6LAT1216 Blue NRG-1/2 系列芯片 Flash 操作與 BLE 事件的互斥處理
  12. 0.89 MB   |  次下載  |  3 積分
  13. 7收音環(huán)繞擴(kuò)音機(jī) AVR-1507手冊
  14. 2.50 MB   |  次下載  |  免費(fèi)
  15. 8MS1000TA 超聲波測量模擬前端芯片技術(shù)手冊
  16. 0.60 MB   |  次下載  |  免費(fèi)

本月

  1. 1愛華AIWA HS-J202維修手冊
  2. 3.34 MB   |  37次下載  |  免費(fèi)
  3. 2PC5502負(fù)載均流控制電路數(shù)據(jù)手冊
  4. 1.63 MB   |  23次下載  |  免費(fèi)
  5. 3NB-IoT芯片廠商的資料說明
  6. 0.31 MB   |  22次下載  |  1 積分
  7. 4UWB653Pro USB口測距通信定位模塊規(guī)格書
  8. 838.47 KB  |  5次下載  |  免費(fèi)
  9. 5蘇泊爾DCL6907(即CHK-S007)單芯片電磁爐原理圖資料
  10. 0.04 MB   |  4次下載  |  1 積分
  11. 6蘇泊爾DCL6909(即CHK-S009)單芯片電磁爐原理圖資料
  12. 0.08 MB   |  2次下載  |  1 積分
  13. 7100W準(zhǔn)諧振反激式恒流電源電路圖資料
  14. 0.09 MB   |  2次下載  |  1 積分
  15. 8FS8025B USB的PD和OC快充協(xié)議電壓誘騙控制器IC技術(shù)手冊
  16. 1.81 MB   |  1次下載  |  免費(fèi)

總榜

  1. 1matlab軟件下載入口
  2. 未知  |  935137次下載  |  10 積分
  3. 2開源硬件-PMP21529.1-4 開關(guān)降壓/升壓雙向直流/直流轉(zhuǎn)換器 PCB layout 設(shè)計
  4. 1.48MB  |  420064次下載  |  10 積分
  5. 3Altium DXP2002下載入口
  6. 未知  |  233089次下載  |  10 積分
  7. 4電路仿真軟件multisim 10.0免費(fèi)下載
  8. 340992  |  191439次下載  |  10 積分
  9. 5十天學(xué)會AVR單片機(jī)與C語言視頻教程 下載
  10. 158M  |  183353次下載  |  10 積分
  11. 6labview8.5下載
  12. 未知  |  81602次下載  |  10 積分
  13. 7Keil工具M(jìn)DK-Arm免費(fèi)下載
  14. 0.02 MB  |  73822次下載  |  10 積分
  15. 8LabVIEW 8.6下載
  16. 未知  |  65991次下載  |  10 積分