In this article we take BERT apart, examining its structure and the number of parameters in each layer.

We use bert-base (hidden size 768) as the example:

Plain BERT:

The model summary looks like this (the repeated middle layers are omitted):

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
Input-Token (InputLayer)        [(None, None)]       0
__________________________________________________________________________________________________
Input-Segment (InputLayer)      [(None, None)]       0
__________________________________________________________________________________________________
Embedding-Token (Embedding)     multiple             16226304    Input-Token[0][0]
                                                                 MLM-Norm[0][0]
__________________________________________________________________________________________________
Embedding-Segment (Embedding)   (None, None, 768)    1536        Input-Segment[0][0]
__________________________________________________________________________________________________
Embedding-Token-Segment (Add)   (None, None, 768)    0           Embedding-Token[0][0]
                                                                 Embedding-Segment[0][0]
__________________________________________________________________________________________________
Embedding-Position (PositionEmb (None, None, 768)    393216      Embedding-Token-Segment[0][0]
__________________________________________________________________________________________________
Embedding-Norm (LayerNormalizat (None, None, 768)    1536        Embedding-Position[0][0]
__________________________________________________________________________________________________
Embedding-Dropout (Dropout)     (None, None, 768)    0           Embedding-Norm[0][0]
__________________________________________________________________________________________________
Transformer-0-MultiHeadSelfAtte (None, None, 768)    2362368     Embedding-Dropout[0][0]
                                                                 Embedding-Dropout[0][0]
                                                                 Embedding-Dropout[0][0]
__________________________________________________________________________________________________
Transformer-0-MultiHeadSelfAtte (None, None, 768)    0           Transformer-0-MultiHeadSelfAttent
__________________________________________________________________________________________________
Transformer-0-MultiHeadSelfAtte (None, None, 768)    0           Embedding-Dropout[0][0]
                                                                 Transformer-0-MultiHeadSelfAttent
__________________________________________________________________________________________________
Transformer-0-MultiHeadSelfAtte (None, None, 768)    1536        Transformer-0-MultiHeadSelfAttent
__________________________________________________________________________________________________
Transformer-0-FeedForward (Feed (None, None, 768)    4722432     Transformer-0-MultiHeadSelfAttent
__________________________________________________________________________________________________
Transformer-0-FeedForward-Dropo (None, None, 768)    0           Transformer-0-FeedForward[0][0]
__________________________________________________________________________________________________
Transformer-0-FeedForward-Add ( (None, None, 768)    0           Transformer-0-MultiHeadSelfAttent
                                                                 Transformer-0-FeedForward-Dropout
__________________________________________________________________________________________________
Transformer-0-FeedForward-Norm  (None, None, 768)    1536        Transformer-0-FeedForward-Add[0][
__________________________________________________________________________________________________
Transformer-1-MultiHeadSelfAtte (None, None, 768)    2362368     Transformer-0-FeedForward-Norm[0]
                                                                 Transformer-0-FeedForward-Norm[0]
                                                                 Transformer-0-FeedForward-Norm[0]
__________________________________________________________________________________________________
Transformer-1-MultiHeadSelfAtte (None, None, 768)    0           Transformer-1-MultiHeadSelfAttent
__________________________________________________________________________________________________
Transformer-1-MultiHeadSelfAtte (None, None, 768)    0           Transformer-0-FeedForward-Norm[0]
                                                                 Transformer-1-MultiHeadSelfAttent
__________________________________________________________________________________________________
Transformer-1-MultiHeadSelfAtte (None, None, 768)    1536        Transformer-1-MultiHeadSelfAttent
__________________________________________________________________________________________________
Transformer-1-FeedForward (Feed (None, None, 768)    4722432     Transformer-1-MultiHeadSelfAttent
__________________________________________________________________________________________________
Transformer-1-FeedForward-Dropo (None, None, 768)    0           Transformer-1-FeedForward[0][0]
__________________________________________________________________________________________________
Transformer-1-FeedForward-Add ( (None, None, 768)    0           Transformer-1-MultiHeadSelfAttent
                                                                 Transformer-1-FeedForward-Dropout
__________________________________________________________________________________________________
Transformer-1-FeedForward-Norm  (None, None, 768)    1536        Transformer-1-FeedForward-Add[0][
__________________________________________________________________________________________________
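
The layer names above (Embedding-Token, Transformer-0-MultiHeadSelfAttention, MLM-Dense, ...) match bert4keras's naming scheme, so a summary like this can be reproduced roughly as follows. This is a minimal sketch under assumptions: the config/checkpoint paths are placeholders, and with_mlm=True is assumed so that the MLM head layers show up.

from bert4keras.models import build_transformer_model

# Placeholder paths -- point these at a downloaded bert-base checkpoint.
config_path = 'bert_config.json'
checkpoint_path = 'bert_model.ckpt'

# with_mlm=True adds the MLM-Dense / MLM-Norm / MLM-Bias layers and reuses the
# token-embedding weights for the output projection (hence Embedding-Token
# appearing with two inbound connections in the summary).
model = build_transformer_model(config_path, checkpoint_path, with_mlm=True)
model.summary()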

Below we briefly break down the parameters of each part:

First, the input:

In the embedding part, BERT uses three components: the token embedding, the token type embedding (used to distinguish the two sentences), and the position embedding.

The token embedding is (using a vocabulary size of 21128 as the example):

vocab size * embedding size = 21128 * 768 = 16226304.

__________________________________________________________________________________________________
Embedding-Token (Embedding)     multiple             16226304    Input-Token[0][0]
                                                                 MLM-Norm[0][0]   

token type:

Sentences are marked with 0 and 1 (e.g., to tell the two sentences apart in the NSP task):

2 * 768 = 1536.

__________________________________________________________________________________________________
Embedding-Segment (Embedding)   (None, None, 768)    1536        Input-Segment[0][0]

position embedding:

max length * embedding size = 512*768=393216

__________________________________________________________________________________________________
Embedding-Position (PositionEmb (None, None, 768)    393216      Embedding-Token-Segment[0][0]

BERT then applies a Layer Normalization to the embeddings, which adds another 2 * 768 parameters (the scale γ and shift β vectors):

__________________________________________________________________________________________________
Embedding-Norm (LayerNormalizat (None, None, 768)    1536        Embedding-Position[0][0]               
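
As a quick sanity check, the four numbers above can be reproduced with a few lines of arithmetic (the values are the bert-base settings used in this article):

vocab_size, segments, max_position, hidden = 21128, 2, 512, 768

token_embedding    = vocab_size * hidden     # 16226304
segment_embedding  = segments * hidden       # 1536
position_embedding = max_position * hidden   # 393216
embedding_norm     = 2 * hidden              # gamma + beta: 1536

print(token_embedding, segment_embedding, position_embedding, embedding_norm)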

With the embedding parameters sorted out, next come the Transformer parameters; for simplicity, only one layer is described here.

First, the multi-head attention:

bert-base uses 12 attention heads, each with a Q/K/V dimension of 64; at the end an output matrix O is also needed to combine the 12 heads.

The total parameter count is therefore: embedding size * head num * qkv size * 3 (the three projection matrices) + (head num * qkv size) * embedding size (processing the concatenated multi-head output) + Q/K/V/O biases = 768*12*64*3 + 12*64*768 + 768*4 = 2362368

(The final 768*4 covers the biases of the Q, K and V projections plus the bias of the final O matrix.)

__________________________________________________________________________________________________
Transformer-0-MultiHeadSelfAtte (None, None, 768)    2362368     Embedding-Dropout[0][0]
                                                                 Embedding-Dropout[0][0]
                                                                 Embedding-Dropout[0][0]
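
The same calculation as a few lines of Python (head_num = 12 and head_size = 64 are the bert-base settings stated above):

hidden, head_num, head_size = 768, 12, 64

# Q, K and V projections: a weight matrix plus a bias vector each
qkv_params = 3 * (hidden * head_num * head_size + head_num * head_size)
# output projection O over the concatenated heads: weight plus bias
o_params = head_num * head_size * hidden + hidden

print(qkv_params + o_params)   # 2362368, matching the summary above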

Next comes the LayerNorm after the multi-head attention: 2 * 768 = 1536

__________________________________________________________________________________________________
Transformer-0-MultiHeadSelfAtte (None, None, 768)    1536        Transformer-0-MultiHeadSelfAttent

Next is the feed-forward (fully connected) layer:

BERT uses the conventional intermediate size of 4 * input, i.e., 4 * 768 = 3072.

So this part has: embedding size * hidden size + bias + hidden size * embedding size + bias = 768*3072 + 3072 + 3072*768 + 768 = 4722432 parameters

__________________________________________________________________________________________________
Transformer-0-FeedForward (Feed (None, None, 768)    4722432     Transformer-0-MultiHeadSelfAttent
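
And the same check for the feed-forward block (intermediate size 3072 = 4 * 768 as stated above):

hidden, intermediate = 768, 4 * 768   # 3072

# up-projection (weight + bias) followed by down-projection (weight + bias)
ffn_params = (hidden * intermediate + intermediate) + (intermediate * hidden + hidden)
print(ffn_params)   # 4722432, matching the summary above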

Then the LayerNorm: 2 * 768 = 1536

__________________________________________________________________________________________________
Transformer-0-FeedForward-Norm  (None, None, 768)    1536        Transformer-0-FeedForward-Add[0][  

Then comes the next Transformer layer, and so on.

BERT with Conditional Layer Normalization:

With Conditional Layer Normalization, each LayerNormalization layer in BERT grows to 198144 parameters.

__________________________________________________________________________________________________
Embedding-Norm (LayerNormalizat (None, None, 768)    198144      Embedding-Position[0][0]
                                                                 reshape[0][0]

Since β and γ themselves are unchanged and still account for 1536 parameters, let's work out where the extra 196608 parameters come from.

Because β and γ each receive the same kind of transformation, they contribute the same number of parameters, so the count we need to explain narrows down to 196608 / 2 = 98304.

As mentioned earlier, the 128-dimensional condition vector c has to be projected up to 768 dimensions; ignoring the bias and doing only a matrix transformation (a single dense layer without a bias is exactly a matrix transformation), that is precisely 128 * 768 = 98304 parameters.
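
A quick check of that arithmetic (a sketch assuming, as described above, that the 128-dimensional condition vector c is mapped to the γ and β vectors by two separate bias-free dense layers):

hidden, condition_size = 768, 128

plain_ln  = 2 * hidden                      # the original gamma and beta: 1536
cond_proj = 2 * (condition_size * hidden)   # two bias-free 128 -> 768 projections: 196608

print(plain_ln + cond_proj)                 # 198144, matching the summary above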

For details, see Conditional Layer Normalization.

BERT's MLM task:

__________________________________________________________________________________________________
Transformer-11-FeedForward-Norm (None, None, 768)    1536        Transformer-11-FeedForward-Add[0]
__________________________________________________________________________________________________
MLM-Dense (Dense)               (None, None, 768)    590592      Transformer-11-FeedForward-Norm[0
__________________________________________________________________________________________________
MLM-Norm (LayerNormalization)   (None, None, 768)    1536        MLM-Dense[0][0]
__________________________________________________________________________________________________
MLM-Bias (BiasAdd)              (None, None, 21128)  21128       Embedding-Token[1][0]
__________________________________________________________________________________________________
MLM-Activation (Activation)     (None, None, 21128)  0           MLM-Bias[0][0]
__________________________________________________________________________________________________
cross_entropy (CrossEntropy)    (None, None, 21128)  0           Input-Token[0][0]
                                                                 MLM-Activation[0][0]
==================================================================================================     

The main additions here are the MLM-Dense, MLM-Norm and MLM-Bias layers.

MLM-Dense has 768 * 768 + 768 = 590592 parameters.

The other two are straightforward: MLM-Norm is an ordinary LayerNorm (2 * 768 = 1536), and MLM-Bias adds one bias per vocabulary entry (21128).
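
As a final check, the MLM head in a few lines of arithmetic:

hidden, vocab_size = 768, 21128

mlm_dense = hidden * hidden + hidden   # 590592
mlm_norm  = 2 * hidden                 # 1536
mlm_bias  = vocab_size                 # one bias per vocabulary entry: 21128

print(mlm_dense, mlm_norm, mlm_bias)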

See also: how BERT's MLM task is implemented.

