In this article we take BERT apart and look closely at its structure and the number of parameters in each layer.
We use bert-base (hidden size 768) as the example:
Plain BERT:
The model summary is as follows (repeated layers omitted):
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
Input-Token (InputLayer) [(None, None)] 0
__________________________________________________________________________________________________
Input-Segment (InputLayer) [(None, None)] 0
__________________________________________________________________________________________________
Embedding-Token (Embedding) multiple 16226304 Input-Token[0][0]
MLM-Norm[0][0]
__________________________________________________________________________________________________
Embedding-Segment (Embedding) (None, None, 768) 1536 Input-Segment[0][0]
__________________________________________________________________________________________________
Embedding-Token-Segment (Add) (None, None, 768) 0 Embedding-Token[0][0]
Embedding-Segment[0][0]
__________________________________________________________________________________________________
Embedding-Position (PositionEmb (None, None, 768) 393216 Embedding-Token-Segment[0][0]
__________________________________________________________________________________________________
Embedding-Norm (LayerNormalizat (None, None, 768) 1536 Embedding-Position[0][0]
__________________________________________________________________________________________________
Embedding-Dropout (Dropout) (None, None, 768) 0 Embedding-Norm[0][0]
__________________________________________________________________________________________________
Transformer-0-MultiHeadSelfAtte (None, None, 768) 2362368 Embedding-Dropout[0][0]
Embedding-Dropout[0][0]
Embedding-Dropout[0][0]
__________________________________________________________________________________________________
Transformer-0-MultiHeadSelfAtte (None, None, 768) 0 Transformer-0-MultiHeadSelfAttent
__________________________________________________________________________________________________
Transformer-0-MultiHeadSelfAtte (None, None, 768) 0 Embedding-Dropout[0][0]
Transformer-0-MultiHeadSelfAttent
__________________________________________________________________________________________________
Transformer-0-MultiHeadSelfAtte (None, None, 768) 1536 Transformer-0-MultiHeadSelfAttent
__________________________________________________________________________________________________
Transformer-0-FeedForward (Feed (None, None, 768) 4722432 Transformer-0-MultiHeadSelfAttent
__________________________________________________________________________________________________
Transformer-0-FeedForward-Dropo (None, None, 768) 0 Transformer-0-FeedForward[0][0]
__________________________________________________________________________________________________
Transformer-0-FeedForward-Add ( (None, None, 768) 0 Transformer-0-MultiHeadSelfAttent
Transformer-0-FeedForward-Dropout
__________________________________________________________________________________________________
Transformer-0-FeedForward-Norm (None, None, 768) 1536 Transformer-0-FeedForward-Add[0][
__________________________________________________________________________________________________
Transformer-1-MultiHeadSelfAtte (None, None, 768) 2362368 Transformer-0-FeedForward-Norm[0]
Transformer-0-FeedForward-Norm[0]
Transformer-0-FeedForward-Norm[0]
__________________________________________________________________________________________________
Transformer-1-MultiHeadSelfAtte (None, None, 768) 0 Transformer-1-MultiHeadSelfAttent
__________________________________________________________________________________________________
Transformer-1-MultiHeadSelfAtte (None, None, 768) 0 Transformer-0-FeedForward-Norm[0]
Transformer-1-MultiHeadSelfAttent
__________________________________________________________________________________________________
Transformer-1-MultiHeadSelfAtte (None, None, 768) 1536 Transformer-1-MultiHeadSelfAttent
__________________________________________________________________________________________________
Transformer-1-FeedForward (Feed (None, None, 768) 4722432 Transformer-1-MultiHeadSelfAttent
__________________________________________________________________________________________________
Transformer-1-FeedForward-Dropo (None, None, 768) 0 Transformer-1-FeedForward[0][0]
__________________________________________________________________________________________________
Transformer-1-FeedForward-Add ( (None, None, 768) 0 Transformer-1-MultiHeadSelfAttent
Transformer-1-FeedForward-Dropout
__________________________________________________________________________________________________
Transformer-1-FeedForward-Norm (None, None, 768) 1536 Transformer-1-FeedForward-Add[0][
__________________________________________________________________________________________________
Below we briefly break down the parameters of each part.
First, the input:
For the embeddings, BERT uses three parts: the token embedding, the token type embedding (used to distinguish the two sentences), and the position embedding.
The token embedding (taking a vocabulary size of 21128 as the example) is:
vocab size * embedding size = 21128 * 768 = 16226304.
__________________________________________________________________________________________________
Embedding-Token (Embedding) multiple 16226304 Input-Token[0][0]
MLM-Norm[0][0]
token type:
0 and 1 are used to mark the sentences (e.g., to distinguish the two sentences in the NSP task):
2 * 768 = 1536.
__________________________________________________________________________________________________
Embedding-Segment (Embedding) (None, None, 768) 1536 Input-Segment[0][0]
position embedding:
max length * embedding size = 512 * 768 = 393216.
__________________________________________________________________________________________________
Embedding-Position (PositionEmb (None, None, 768) 393216 Embedding-Token-Segment[0][0]
BERT then applies a layer normalization in the embedding part, which adds another 2 * 768 parameters (the scale γ and the shift β).
__________________________________________________________________________________________________
Embedding-Norm (LayerNormalizat (None, None, 768) 1536 Embedding-Position[0][0]
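As a quick sanity check, the following minimal Python sketch (assuming the same vocab size 21128, hidden size 768 and max length 512 as above) reproduces the embedding-side counts from the summary:

# Reproduce the embedding-side parameter counts shown above.
vocab_size, hidden, max_len, n_segments = 21128, 768, 512, 2

token_emb    = vocab_size * hidden   # Embedding-Token:    16226304
segment_emb  = n_segments * hidden   # Embedding-Segment:  1536
position_emb = max_len * hidden      # Embedding-Position: 393216
emb_norm     = 2 * hidden            # Embedding-Norm (gamma + beta): 1536

print(token_emb, segment_emb, position_emb, emb_norm)
# -> 16226304 1536 393216 1536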
With the embedding parameters sorted out, we turn to the Transformer parameters; for brevity, only one layer is described here.
First, the multi-head attention:
bert-base uses 12 attention heads with a Q/K/V dimension of 64 per head, plus an output matrix O at the end to combine the 12 heads.
The total is therefore: embedding size * head num * qkv size * 3 [the three projection matrices] + (head num * qkv size) * embedding size [projecting the concatenated heads] + Q/K/V/O biases = 768*12*64*3 + 12*64*768 + 768*4 = 2362368.
(The final 768*4 are the biases of the Q, K and V projections and of the final O matrix.)
__________________________________________________________________________________________________
Transformer-0-MultiHeadSelfAtte (None, None, 768) 2362368 Embedding-Dropout[0][0]
Embedding-Dropout[0][0]
Embedding-Dropout[0][0]
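The same calculation as a minimal sketch (assuming 12 heads of size 64 and hidden size 768):

# Per-layer multi-head self-attention parameters.
hidden, n_heads, head_size = 768, 12, 64

qkv_weights = hidden * (n_heads * head_size) * 3   # W_Q, W_K, W_V projection matrices
out_weights = (n_heads * head_size) * hidden       # W_O, combining the 12 heads
biases      = hidden * 4                           # b_Q, b_K, b_V, b_O

print(qkv_weights + out_weights + biases)          # -> 2362368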
Next is the layer normalization after the multi-head attention: 2 * 768 = 1536.
__________________________________________________________________________________________________
Transformer-0-MultiHeadSelfAtte (None, None, 768) 1536 Transformer-0-MultiHeadSelfAttent
Next comes the feed-forward (fully connected) layer:
BERT uses the conventional 4 * input size for the hidden dimension of this layer, i.e., 4 * 768 = 3072.
This part therefore has: embedding size * hidden size + bias + hidden size * embedding size + bias = 768*3072 + 3072 + 3072*768 + 768 = 4722432.
__________________________________________________________________________________________________
Transformer-0-FeedForward (Feed (None, None, 768) 4722432 Transformer-0-MultiHeadSelfAttent
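The same figure as a short sketch (assuming hidden size 768 and an intermediate size of 4 * 768):

# Per-layer feed-forward parameters.
hidden = 768
intermediate = 4 * hidden                      # 3072

ffn = (hidden * intermediate + intermediate    # first dense: weights + bias
       + intermediate * hidden + hidden)       # second dense: weights + bias
print(ffn)                                     # -> 4722432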
Then the layer normalization: 2 * 768 = 1536.
__________________________________________________________________________________________________
Transformer-0-FeedForward-Norm (None, None, 768) 1536 Transformer-0-FeedForward-Add[0][
After that comes the next Transformer layer, and so on.
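Putting the pieces together, a rough sketch of the overall count (summing only the layers shown in the summary, i.e., the embedding part plus 12 Transformer layers, without the pooler or the MLM head):

# Rough total for bert-base with a 21128-word vocabulary.
embeddings = 16226304 + 1536 + 393216 + 1536   # token + segment + position + norm
per_layer  = 2362368 + 1536 + 4722432 + 1536   # attention + LN + feed-forward + LN

print(embeddings + 12 * per_layer)             # -> 101677056, roughly 102M parameters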
BERT's Conditional Layer Normalization:
With Conditional Layer Normalization, each LayerNormalization in BERT grows to 198144 parameters.
__________________________________________________________________________________________________
Embedding-Norm (LayerNormalizat (None, None, 768) 198144 Embedding-Position[0][0]
reshape[0][0]
Since β and γ themselves are unchanged and still account for 1536 parameters, let's analyze where the extra 196608 come from.
Because the same transformation is applied to β and to γ, the two halves contain the same number of parameters, so the amount left to analyze narrows further to 98304.
As mentioned before, the 128-dimensional condition vector c has to be projected up to 768 dimensions; if we ignore the bias and only do a matrix transformation (a single-layer network without bias is just a matrix multiplication), that is exactly 768 * 128 = 98304 parameters, as the sketch below confirms.
See Conditional Layer Normalization for details.
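A minimal sketch of this breakdown (assuming a 128-dimensional condition vector c and hidden size 768):

# Where the 198144 Conditional LayerNorm parameters come from.
hidden, cond_dim = 768, 128

base      = 2 * hidden          # the ordinary gamma and beta: 1536
gamma_map = cond_dim * hidden   # dense (no bias) mapping c to a gamma offset: 98304
beta_map  = cond_dim * hidden   # dense (no bias) mapping c to a beta offset:  98304

print(base + gamma_map + beta_map)   # -> 198144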
BERT's MLM task:
__________________________________________________________________________________________________
Transformer-11-FeedForward-Norm (None, None, 768) 1536 Transformer-11-FeedForward-Add[0]
__________________________________________________________________________________________________
MLM-Dense (Dense) (None, None, 768) 590592 Transformer-11-FeedForward-Norm[0
__________________________________________________________________________________________________
MLM-Norm (LayerNormalization) (None, None, 768) 1536 MLM-Dense[0][0]
__________________________________________________________________________________________________
MLM-Bias (BiasAdd) (None, None, 21128) 21128 Embedding-Token[1][0]
__________________________________________________________________________________________________
MLM-Activation (Activation) (None, None, 21128) 0 MLM-Bias[0][0]
__________________________________________________________________________________________________
cross_entropy (CrossEntropy) (None, None, 21128) 0 Input-Token[0][0]
MLM-Activation[0][0]
==================================================================================================
The main additions here are MLM-Dense, MLM-Norm and MLM-Bias.
MLM-Dense is 768*768 + 768 = 590592.
The other two are straightforward: MLM-Norm is another 2 * 768 = 1536, and MLM-Bias adds one bias per vocabulary entry, i.e., 21128.
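Note that the output projection itself introduces no new weight matrix: Embedding-Token is reused (hence its "multiple" output shape and its connection to MLM-Norm in the summary), so only the pieces below are added. A minimal sketch (assuming hidden size 768 and vocab size 21128):

# Extra parameters introduced by the MLM head.
hidden, vocab_size = 768, 21128

mlm_dense = hidden * hidden + hidden   # MLM-Dense: 590592
mlm_norm  = 2 * hidden                 # MLM-Norm (gamma + beta): 1536
mlm_bias  = vocab_size                 # MLM-Bias: one bias per vocabulary token: 21128

print(mlm_dense, mlm_norm, mlm_bias)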