众所周知,Bert只能处理最长512长度的文本,那么如何处理超长文本呢?

我们可以使用nezha,因为它是基于相对位置编码的。

我们还可以将Bert的绝对位置编码处理一下,使其能够处理超长文本。

思路如下:

我们将超长文本映射到多层(k为文本索引),通过位置编码pos(i)和pos(j)来确定k。

通过一个[0,1]的超参 α 来定义i和j的比例( α ≠0.5,否则无法区分(i,j)和(j,i))。

因此我们可以将新的位置编码设为:

其中,u为embedding的原始数据(暂时这么理解),q为新的位置编码,(i-1)*n+j=k。

到这里比较好理解,但是我们还有个问题:

由于我们希望在长度不足512的时候,u和原始位置编码(<512)一致的,也就是说,当i=1时,q(新的embedding)和p(原始512的embedding)是相等的。因此我们可以算出u为:

因此,当我们取u为上式时,当长度小于512和原始embedding一致。当长度大于512时,可以通过变换获得新的embedding。

具体代码实现为:

首先变换p为u:

embeddings = self.embeddings - alpha * self.embeddings[:1]
embeddings = embeddings / (1 - alpha)

随后通过u计算q:

embeddings_x = K.gather(embeddings, position_ids // self.input_dim)
embeddings_y = K.gather(embeddings, position_ids % self.input_dim)
embeddings = alpha * embeddings_x + (1 - alpha) * embeddings_y

通过实验表明, α 取0.4时效果最好。

而实验效果表明,在很小的代价下,模型就能够理解这种变换,达到与之前(长度为512)相近的结果。

当然,代价也是有的,由于attention的长度增加,attention又是O(n2)的空间复杂度,因此显存占用也会有一些增加。

完整的代码:

alpha = 0.4 if self.hierarchical is True else self.hierarchical
embeddings = self.embeddings - alpha * self.embeddings[:1]
embeddings = embeddings / (1 - alpha)
embeddings_x = K.gather(embeddings, position_ids // self.input_dim)
embeddings_y = K.gather(embeddings, position_ids % self.input_dim)
embeddings = alpha * embeddings_x + (1 - alpha) * embeddings_y

感谢苏建林大佬的博客,本文为阅读原文时的笔记。


0 条评论

发表回复

Avatar placeholder

您的电子邮箱地址不会被公开。 必填项已用 * 标注