transformers的TFBertForPreTraining

这个类暂时还没找到什么可以应用的具体任务,应该属于一个基类,下游任务调用的,等看到能用的下游任务回来补全。

暂时看一下源码。

这个库应该时Bert预训练的模型,不同于TFBertModel,这个模型带了两个结构,一个用于NSP(next sentence predict)任务,另一个用于MLM(masked language model)任务。

TFBertForPreTraining

下面让我们来看看TFBertForPreTraining的源码。

源码

首先我们可以看到init中的定义

对比TFBertModel的init:

我们可以看到只是多了一个nsp和mlm。

我们再看call:

TFBertForPreTraining的(由于call中状态判断较多,比如返不返回attention state等,因此只截取部分,返回的参数前面点了个红点):

对比TFBertModel:

因此我们可以明显的看出来差别:

TFBertForPreTraining相比TFBertModel,多了mlm和nsp,分别处理的是隐藏层状态和FC的输出(TFBertMainLayer的返回详见TFBertModel)。

那么是怎么处理的呢?

继续查看mlm的源码:

先不看call,只看init

继续看 :

继续:

我们可以看到,在这里就只是一个Dense+激活+LayerNorm,而call中也只是线性的调用而已。

而上一层TFBertLMpredictionHead的call中,还有一个input embedding,继续再上一层的call就也是个简单的调用。

继续查看nsp的源码:

就只是一个dense层,最后返回2个参数。

演示

由于只是个上游任务模型,没有找到什么下游任务和数据,暂时不做演示,只给一个官方demo,以后找到任务回来补。

整代码


0 条评论

发表回复

Avatar placeholder

您的电子邮箱地址不会被公开。 必填项已用 * 标注