D2L学习笔记-注意力机制

  1. 1. 注意力提示
  2. 2. 注意力汇聚:Nadaraya-Watson核回归
  3. 3. 注意力评分函数
  4. 4. Bahdanau注意力(使用注意力的seq2seq)
  5. 5. 多头注意力
  6. 6. 自注意力和位置编码
  7. 7. Transformer

李沐动手学深度学习(PyTorch)课程学习笔记第十章:注意力机制。

1. 注意力提示

注意力机制(Attention Mechanism)是人们在机器学习模型中嵌入的一种特殊结构,用来自动学习和计算输入数据对输出数据的贡献大小。

非自主性提示是基于环境中物体的突出性和易见性。想象一下,假如我们面前有五个物品:一份报纸、一篇研究论文、一杯咖啡、一本笔记本和一本书,所有纸制品都是黑白印刷的,但咖啡杯是红色的。换句话说,这个咖啡杯在这种视觉环境中是突出和显眼的,不由自主地引起人们的注意,所以我们会把视力最敏锐的地方放到咖啡上。喝咖啡后,我们会变得兴奋并想读书,所以转过头,重新聚焦眼睛,然后看看书,与咖啡杯是由于突出性导致的选择不同,此时选择书是受到了认知和意识的控制,因此注意力在基于自主性提示去辅助选择时将更为谨慎。受试者的主观意愿推动,选择的力量也就更强大。自主性的与非自主性的注意力提示解释了人类的注意力的方式,下面来看看如何通过这两种注意力提示,用神经网络来设计注意力机制的框架。

首先,考虑一个相对简单的状况,即只使用非自主性提示。要想将选择偏向于感官输入,则可以简单地使用参数化的全连接层,甚至是非参数化的最大汇聚层或平均汇聚层。

因此,“是否包含自主性提示”将注意力机制与全连接层或汇聚层区别开来。在注意力机制的背景下,自主性提示被称为查询(query)。给定任何查询,注意力机制通过注意力汇聚(attention pooling)将选择引导至感官输入(sensory inputs,例如中间特征表示)。在注意力机制中,这些感官输入被称为(value)。更通俗的解释,每个值都与一个(key)配对,这可以想象为感官输入的非自主提示。可以通过设计注意力汇聚的方式,便于给定的查询(自主性提示)与键(非自主性提示)进行匹配,这将引导得出最匹配的值(感官输入)。

平均汇聚层可以被视为输入的加权平均值,其中各输入的权重是一样的。实际上,注意力汇聚得到的是加权平均的总和值,其中权重是在给定的查询和不同的键之间计算得出的。为了可视化注意力权重,需要定义一个 show_heatmaps 函数,其输入 matrices 的形状是 (要显示的行数, 要显示的列数, 查询的数目, 键的数目)。下面使用一个简单的例子进行演示,在本例子中,仅当查询和键相同时,注意力权重为1,否则为0:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import matplotlib.pyplot as plt
from d2l import torch as d2l

def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(8, 6), cmap='Reds'):
"""显示矩阵热图"""
d2l.use_svg_display()
num_rows, num_cols = matrices.shape[0], matrices.shape[1]
fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize, sharex=True, sharey=True, squeeze=False)
for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
pcm = ax.imshow(matrix.detach().numpy(), cmap=cmap)
if i == num_rows - 1:
ax.set_xlabel(xlabel)
if j == 0:
ax.set_ylabel(ylabel)
if titles:
ax.set_title(titles[j])
fig.colorbar(pcm, ax=axes, shrink=0.6)
plt.show()

attention_weights = torch.eye(10).reshape((1, 1, 10, 10))
show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')

此外可以使用 Plotly 绘制热力图:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import plotly.graph_objects as go

def show_plotly_heatmaps(x=None, y=None, z=None, colorscale='reds', width=600, height=600, title=None, xtitle=None, ytitle=None):
heatmap_fig = go.Figure(
data=[
go.Heatmap(x=x, y=y, z=z, colorscale=colorscale)
]
)

heatmap_fig.update_layout(
autosize=False, width=width, height=height,
title=title,
xaxis=dict(title=xtitle), yaxis=dict(title=ytitle),
showlegend=True
)

heatmap_fig.show()

attention_weights = torch.eye(10)
show_plotly_heatmaps(z=attention_weights, title='Attention Weights Heatmap', xtitle='Keys', ytitle='Queries')

2. 注意力汇聚:Nadaraya-Watson核回归

上节介绍了框架下的注意力机制的主要成分:查询(自主提示)和键(非自主提示)之间的交互形成了注意力汇聚;注意力汇聚有选择地聚合了值(感官输入)以生成最终的输出。本节将介绍注意力汇聚的更多细节,以便从宏观上了解注意力机制在实践中的运作方式。具体来说,1964年提出的 Nadaraya-Watson 核回归模型是一个简单但完整的例子,可以用于演示具有注意力机制的机器学习,其理论介绍可见:注意力汇聚:Nadaraya-Watson核回归

首先生成一个非线性函数的人工数据集:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
from torch import nn
from d2l import torch as d2l

n_train = 50 # 训练样本数
x_train, _ = torch.sort(torch.rand(n_train) * 5) # 排序后的训练样本

def f(x):
return 2 * torch.sin(x) + x**0.8

y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,)) # 训练样本的输出
x_test = torch.arange(0, 5, 0.1) # 测试样本
y_truth = f(x_test) # 测试样本的真实输出
n_test = len(x_test) # 测试样本数
print(n_test) # 50

函数 plot_kernel_reg 将绘制所有的训练样本(样本由圆圈表示),不带噪声项的真实数据生成函数(标记为 Truth),以及学习得到的预测函数(标记为 Pred)。先使用最简单的估计器来解决回归问题,即基于平均汇聚来计算所有训练样本输出值的平均值:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def plot_kernel_reg(y_hat):
fig = go.Figure(
data=[
go.Scatter(x=x_test, y=y_truth, mode='lines', name='Truth'),
go.Scatter(x=x_test, y=y_hat, mode='lines', name='Pred'),
go.Scatter(x=x_train, y=y_train, mode='markers', name='Sample', opacity=0.5)
]
)
fig.update_layout(
autosize=False, width=1200, height=800,
xaxis=dict(title='x'), yaxis=dict(title='y'),
showlegend=True
)
fig.show()
# d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],
# xlim=[0, 5], ylim=[-1, 5])
# d2l.plt.plot(x_train, y_train, 'o', alpha=0.5)

y_hat = torch.repeat_interleave(y_train.mean(), n_test)
plot_kernel_reg(y_hat)

显然,平均汇聚忽略了输入,Nadaraya-Watson 核回归根据输入的位置对输出进行加权,是一个非参数模型。接下来,我们将基于这个非参数的注意力汇聚模型来绘制预测结果。从绘制的结果会发现新的模型预测线是平滑的,并且比平均汇聚的预测更接近真实。

1
2
3
4
5
6
7
8
9
# X_repeat.shape: (n_test, n_train)
# 每一行都包含着相同的测试输入(例如:同样的查询)
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
# x_train包含着键,attention_weights.shape: (n_test, n_train)
# 每一行都包含着要在给定的每个查询的值(y_train)之间分配的注意力权重
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)
# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
y_hat = torch.matmul(attention_weights, y_train) # y_hat.shape: torch.Size([50])
plot_kernel_reg(y_hat)

非参数的 Nadaraya-Watson 核回归具有一致性(consistency)的优点:如果有足够的数据,此模型会收敛到最优结果。尽管如此,我们还是可以轻松地将可学习的参数集成到注意力汇聚中。

为了更有效地计算小批量数据的注意力,我们可以利用深度学习开发框架中提供的批量矩阵乘法:

1
2
3
X = torch.ones((2, 1, 4))
Y = torch.ones((2, 4, 6))
print(torch.bmm(X, Y).shape) # torch.Size([2, 1, 6])

在注意力机制的背景中,我们可以使用小批量矩阵乘法来计算小批量数据中的加权平均值:

1
2
3
4
5
weights = torch.ones((2, 10)) * 0.1
values = torch.arange(20.0).reshape((2, 10))
print(torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1)))
# tensor([[[ 4.5000]],
# [[14.5000]]])

定义 Nadaraya-Watson 核回归的带参数版本为:

1
2
3
4
5
6
7
8
9
10
11
class NWKernelRegression(nn.Module):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.w = nn.Parameter(torch.rand((1,), requires_grad=True))

def forward(self, queries, keys, values):
# queries和attention_weights的形状为: (查询个数, “键-值”对个数)
queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))
self.attention_weights = nn.functional.softmax(-((queries - keys) * self.w)**2 / 2, dim=1)
# values的形状为: (查询个数, “键-值”对个数)
return torch.bmm(self.attention_weights.unsqueeze(1), values.unsqueeze(-1)).reshape(-1)

接下来,将训练数据集变换为键和值用于训练注意力模型。在带参数的注意力汇聚模型中,任何一个训练样本的输入都会和除自己以外的所有训练样本的“键-值”对进行计算,从而得到其对应的预测输出:

1
2
3
4
5
6
7
8
# X_tile.shape: (n_train, n_train),每一行都包含着相同的训练输入
X_tile = x_train.repeat(n_train, 1)
# Y_tile.shape: (n_train, n_train),每一行都包含着相同的训练输出
Y_tile = y_train.repeat(n_train, 1)
# keys.shape: (n_train, n_train-1),将对角线元素筛去
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
# values.shape: (n_train, n_train-1),将对角线元素筛去
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))

训练带参数的注意力汇聚模型时,使用平方损失函数和随机梯度下降:

1
2
3
4
5
6
7
8
9
10
net = NWKernelRegression()
loss_function = nn.MSELoss(reduction='none')
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)

for epoch in range(20):
optimizer.zero_grad()
loss = loss_function(net(x_train, keys, values), y_train)
loss.sum().backward()
optimizer.step()
print(f'epoch {epoch + 1}, loss {float(loss.sum()):.6f}')

训练完带参数的注意力汇聚模型后可以发现:在尝试拟合带噪声的训练数据时,预测结果绘制的线不如之前非参数模型的平滑,因为与非参数的注意力汇聚模型相比,带参数的模型加入可学习的参数后,曲线在注意力权重较大的区域变得更不平滑。

1
2
3
4
5
6
# keys.shape: (n_test, n_train),每一行包含着相同的训练输入(例如,相同的键)
keys = x_train.repeat(n_test, 1)
# value.shape: (n_test, n_train)
values = y_train.repeat(n_test, 1)
y_hat = net(x_test, keys, values).detach()
plot_kernel_reg(y_hat)

3. 注意力评分函数

在上一节中使用了高斯核来对查询和键之间的关系建模。高斯核的指数部分可以视为注意力评分函数(attention scoring function),简称评分函数(scoring function),然后把这个函数的输出结果输入到 Softmax 函数中进行运算。通过上述步骤,将得到与键对应的值的概率分布(即注意力权重)。最后,注意力汇聚的输出就是基于这些注意力权重的值的加权和。

选择不同的注意力评分函数会导致不同的注意力汇聚操作。本节将介绍两个流行的评分函数,稍后将用他们来实现更复杂的注意力机制。

正如上面提到的,Softmax 操作用于输出一个概率分布作为注意力权重。在某些情况下,并非所有的值都应该被纳入到注意力汇聚中。例如,为了在机器翻译中高效处理小批量数据集,某些文本序列被填充了没有意义的特殊词元。为了仅将有意义的词元作为值来获取注意力汇聚,可以指定一个有效序列长度(即词元的个数),以便在计算 Softmax 时过滤掉超出指定范围的位置。下面的 masked_softmax 函数实现了这样的掩蔽 Softmax 操作(masked softmax operation),其中任何超出有效长度的位置都被掩蔽并置为0。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import math
import torch
from torch import nn
from d2l import torch as d2l

def masked_softmax(X, valid_lens):
"""通过在最后一个轴上掩蔽元素来执行softmax操作"""
# X: 3D张量,valid_lens: 1D或2D张量
if valid_lens is None:
return nn.functional.softmax(X, dim=-1)
else:
shape = X.shape
if valid_lens.dim() == 1:
valid_lens = torch.repeat_interleave(valid_lens, shape[1]) # [a, b] -> [a, a, ..., b, b, ...]
else:
valid_lens = valid_lens.reshape(-1)
# 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0
X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
return nn.functional.softmax(X.reshape(shape), dim=-1)

为了演示此函数是如何工作的,考虑由两个2*4矩阵表示的样本,这两个样本的有效长度分别为2和3。经过掩蔽 Softmax 操作,超出有效长度的值都被掩蔽为0:

1
2
3
4
5
print(masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3])))
# tensor([[[0.3292, 0.6708, 0.0000, 0.0000],
# [0.5249, 0.4751, 0.0000, 0.0000]],
# [[0.3104, 0.4577, 0.2318, 0.0000],
# [0.3227, 0.3408, 0.3365, 0.0000]]])

同样,也可以使用二维张量,为矩阵样本中的每一行指定有效长度:

1
2
3
4
5
print(masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]])))
# tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
# [0.4203, 0.2752, 0.3045, 0.0000]],
# [[0.4234, 0.5766, 0.0000, 0.0000],
# [0.2979, 0.1618, 0.2246, 0.3157]]])

接下来将介绍加性注意力与缩放点积注意力,其理论分析可见:注意力评分函数

一般来说,当查询和键是不同长度的矢量时,可以使用加性注意力作为评分函数。将查询和键连结起来后输入到一个多层感知机(MLP)中,感知机包含一个隐藏层,其隐藏单元数是一个超参数。通过使用 tanh 作为激活函数,并且禁用偏置项:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class AdditiveAttention(nn.Module):
"""加性注意力"""
def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
super(AdditiveAttention, self).__init__(**kwargs)
self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
self.w_v = nn.Linear(num_hiddens, 1, bias=False)
self.dropout = nn.Dropout(dropout)

def forward(self, queries, keys, values, valid_lens):
queries, keys = self.W_q(queries), self.W_k(keys) # (2, 1, 8), (2, 10, 8)
# 维度扩展后使用广播方式进行求和
# 扩展后的queries.shape: (batch_size, 查询的个数, 1, num_hidden)
# 扩展后的key.shape: (batch_size, 1, “键-值”对的个数, num_hiddens)
features = queries.unsqueeze(2) + keys.unsqueeze(1)
features = torch.tanh(features) # (2, 1, 10, 8)
# self.w_v仅有一个输出,因此从形状中移除最后那个大小为1的维度
# scores.shape: (batch_size, 查询的个数, “键-值”对的个数)
scores = self.w_v(features).squeeze(-1) # (2, 1, 10, 1) -> (2, 1, 10)
self.attention_weights = masked_softmax(scores, valid_lens) # (2, 1, 10)
# values.shape: (batch_size, “键-值”对的个数, 值的维度)
return torch.bmm(self.dropout(self.attention_weights), values)

queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))
# values的小批量,两个值矩阵是相同的
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
valid_lens = torch.tensor([2, 6])

attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8, dropout=0.1)
attention.eval()
print(attention(queries, keys, values, valid_lens))
# tensor([[[ 2.0000, 3.0000, 4.0000, 5.0000]],
# [[10.0000, 11.0000, 12.0000, 13.0000]]], grad_fn=<BmmBackward0>)

尽管加性注意力包含了可学习的参数,但由于本例子中每个键都是相同的,所以注意力权重是均匀的,由指定的有效长度决定:

1
2
# show_plotly_heatmaps函数在第一节中定义
show_plotly_heatmaps(z=attention.attention_weights.detach().reshape((2, 10)), height=300, xtitle='Keys', ytitle='Queries')

使用点积可以得到计算效率更高的评分函数,但是点积操作要求查询和键具有相同的长度,为了演示 DotProductAttention 类,我们使用与先前加性注意力例子中相同的键、值和有效长度。对于点积操作,我们令查询的特征维度与键的特征维度大小相同:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class DotProductAttention(nn.Module):
"""缩放点积注意力"""
def __init__(self, dropout, **kwargs):
super(DotProductAttention, self).__init__(**kwargs)
self.dropout = nn.Dropout(dropout)

# queries.shape: (batch_size, 查询的个数, d)
# keys.shape: (batch_size, “键-值”对的个数, d)
# values.shape: (batch_size, “键-值”对的个数, 值的维度)
# valid_lens.shape: (batch_size,)或者(batch_size, 查询的个数)
def forward(self, queries, keys, values, valid_lens=None):
d = queries.shape[-1]
scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
self.attention_weights = masked_softmax(scores, valid_lens)
return torch.bmm(self.dropout(self.attention_weights), values)

queries = torch.normal(0, 1, (2, 1, 2))
attention = DotProductAttention(dropout=0.5)
attention.eval()
print(attention(queries, keys, values, valid_lens))
# tensor([[[ 2.0000, 3.0000, 4.0000, 5.0000]],
# [[10.0000, 11.0000, 12.0000, 13.0000]]])

与加性注意力演示相同,由于键包含的是相同的元素,而这些元素无法通过任何查询进行区分,因此获得了均匀的注意力权重。

4. Bahdanau注意力(使用注意力的seq2seq)

Bahdanau 注意力模型的原理可见:Bahdanau 注意力

下面看看如何定义 Bahdanau 注意力,实现循环神经网络编码器-解码器。其实,我们只需重新定义解码器即可。为了更方便地显示学习的注意力权重,以下 AttentionDecoder 类定义了带有注意力机制解码器的基本接口:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
from torch import nn
from d2l import torch as d2l
import sys
sys.path.append("..")
from util.functions import train_seq2seq, predict_seq2seq, bleu, show_plotly_heatmaps

class AttentionDecoder(d2l.Decoder):
"""带有注意力机制解码器的基本接口"""
def __init__(self, **kwargs):
super(AttentionDecoder, self).__init__(**kwargs)

@property
def attention_weights(self):
raise NotImplementedError

接下来,让我们在接下来的 Seq2SeqAttentionDecoder 类中实现带有 Bahdanau 注意力的循环神经网络解码器。首先,初始化解码器的状态,需要下面的输入:

  • 编码器在所有时间步的最终层隐状态,将作为注意力的键和值;
  • 上一时间步的编码器全层隐状态,将作为初始化解码器的隐状态;
  • 编码器有效长度(排除在注意力池中填充词元)。

在每个解码时间步骤中,解码器上一个时间步的最终层隐状态将用作查询。因此,注意力输出和输入嵌入都连结为循环神经网络解码器的输入。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
class Seq2SeqAttentionDecoder(AttentionDecoder):
def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)
self.attention = d2l.AdditiveAttention(num_hiddens, num_hiddens, num_hiddens, dropout)
self.embedding = nn.Embedding(vocab_size, embed_size)
self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)
self.dense = nn.Linear(num_hiddens, vocab_size)

def init_state(self, enc_outputs, enc_valid_lens, *args):
# outputs.shape: (batch_size, num_steps, num_hiddens)
# hidden_state.shape: (num_layers, batch_size, num_hiddens)
outputs, hidden_state = enc_outputs
return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)

def forward(self, X, state):
# enc_outputs.shape: (batch_size, num_steps, num_hiddens)
# hidden_state.shape: (num_layers, batch_size, num_hiddens)
enc_outputs, hidden_state, enc_valid_lens = state
X = self.embedding(X).permute(1, 0, 2) # X.shape: (num_steps, batch_size, embed_size)
outputs, self._attention_weights = [], []
for x in X:
# query.shape: (batch_size, 1, num_hiddens)
query = torch.unsqueeze(hidden_state[-1], dim=1)
# context.shape: (batch_size, 1, num_hiddens)
context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
# 在特征维度上连结
x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
# 将x变形为(1, batch_size, embed_size + num_hiddens)
out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
outputs.append(out)
self._attention_weights.append(self.attention.attention_weights)
# 全连接层变换后,outputs的形状为(num_steps, batch_size, vocab_size)
outputs = self.dense(torch.cat(outputs, dim=0))
return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]

@property
def attention_weights(self):
return self._attention_weights

接下来,使用包含7个时间步的4个序列输入的小批量测试 Bahdanau 注意力解码器:

1
2
3
4
5
6
7
8
9
encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
encoder.eval()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()
X = torch.zeros((4, 7), dtype=torch.long) # (batch_size, num_steps)
state = decoder.init_state(encoder(X), None)
output, state = decoder(X, state)
print(output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape)
# torch.Size([4, 7, 10]) 3 torch.Size([4, 7, 16]) 2 torch.Size([4, 16])

我们在这里指定超参数,实例化一个带有 Bahdanau 注意力的编码器和解码器,并对这个模型进行机器翻译训练:

1
2
3
4
5
6
7
8
9
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps, lr, num_epochs = 64, 10, 0.005, 300
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = d2l.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device, '../logs/Bahdanau_seq2seq_train_log')

模型训练后,我们用它将几个英语句子翻译成法语并计算它们的 BLEU 分数:

1
2
3
4
5
engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
translation, dec_attention_weight_seq = predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device, True)
print(f'{eng} => {translation}, bleu {bleu(translation, fra, k=2):.3f}')

训练结束后,下面通过可视化注意力权重会发现每个查询都会在键值对上分配不同的权重,这说明在每个解码步中,输入序列的不同部分被选择性地聚集在注意力池中:

1
2
3
4
attention_weights = torch.cat([step[0][0][0] for step in dec_attention_weight_seq], 0).reshape((-1, num_steps))

# 加上一个包含序列结束词元
show_plotly_heatmaps(z=attention_weights[:, :len(engs[-1].split()) + 1].cpu().detach(), xtitle='Keys', ytitle='Queries')

5. 多头注意力

在实践中,当给定相同的查询、键和值的集合时,我们希望模型可以基于相同的注意力机制学习到不同的行为,然后将不同的行为作为知识组合起来,捕获序列内各种范围的依赖关系(例如,短距离依赖和长距离依赖关系)。因此,允许注意力机制组合使用查询、键和值的不同子空间表示(representation subspaces)可能是有益的。

为此,与其只使用单独一个注意力汇聚,我们可以用独立学习得到的 h 组不同的线性投影(linear projections)来变换查询、键和值。然后,这 h 组变换后的查询、键和值将并行地送到注意力汇聚中。最后,将这 h 个注意力汇聚的输出拼接在一起,并且通过另一个可以学习的线性投影进行变换,以产生最终输出。这种设计被称为多头注意力(multihead attention)。对于 h 个注意力汇聚输出,每一个注意力汇聚都被称作一个(head)。基于这种设计,每个头都可能会关注输入的不同部分,可以表示比简单加权平均值更复杂的函数。

多头注意力模型的原理可见:多头注意力

在实现过程中通常选择缩放点积注意力作为每一个注意力头:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
from torch import nn
from d2l import torch as d2l

class MultiHeadAttention(nn.Module):
"""多头注意力"""
def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self.num_heads = num_heads
self.attention = d2l.DotProductAttention(dropout)
self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

def forward(self, queries, keys, values, valid_lens):
# queries/keys/values的形状: (batch_size, 查询或者“键-值”对的个数, query_size/key_size/value_size)
# valid_lens的形状: (batch_size,)或(batch_size, 查询的个数)
# 经过变换后,输出的queries/keys/values的形状: (batch_size * num_heads, 查询或者“键-值”对的个数, num_hiddens / num_heads)
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)

if valid_lens is not None:
# 在轴0,将每一项(标量或者矢量)复制num_heads次
valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)

# output的形状: (batch_size * num_heads, 查询的个数, num_hiddens / num_heads)
output = self.attention(queries, keys, values, valid_lens)
# output_concat的形状: (batch_size, 查询的个数, num_hiddens)
output_concat = transpose_output(output, self.num_heads)

return self.W_o(output_concat)

为了能够使多个头并行计算,上面的 MultiHeadAttention 类将使用下面定义的两个转置函数。具体来说,transpose_output 函数反转了 transpose_qkv 函数的操作:

1
2
3
4
5
6
7
8
9
10
11
12
13
def transpose_qkv(X, num_heads):
"""为了多注意力头的并行计算而变换形状"""
# 输入X的形状: (batch_size, 查询或者“键-值”对的个数, num_hiddens)
X = X.reshape(X.shape[0], X.shape[1], num_heads, -1) # (batch_size, 查询或者“键-值”对的个数, num_heads, num_hiddens/num_heads)
X = X.permute(0, 2, 1, 3) # (batch_size, num_heads, 查询或者“键-值”对的个数, num_hiddens/num_heads)
return X.reshape(-1, X.shape[2], X.shape[3]) # (batch_size*num_heads, 查询或者“键-值”对的个数, num_hiddens/num_heads)

def transpose_output(X, num_heads):
"""逆转transpose_qkv函数的操作"""
# 输入X的形状: (batch_size*num_heads, 查询或者“键-值”对的个数, num_hiddens/num_heads)
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2]) # (batch_size, num_heads, 查询或者“键-值”对的个数, num_hiddens/num_heads)
X = X.permute(0, 2, 1, 3) # (batch_size, 查询或者“键-值”对的个数, num_heads, num_hiddens/num_heads)
return X.reshape(X.shape[0], X.shape[1], -1) # (batch_size, 查询或者“键-值”对的个数, num_hiddens)

下面使用键和值相同的小例子来测试我们编写的 MultiHeadAttention 类。多头注意力输出的形状是 (batch_size, num_queries, num_hiddens)

1
2
3
4
5
6
7
8
9
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5)
attention.eval()

batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
print(attention(X, Y, Y, valid_lens).shape) # torch.Size([2, 4, 100])

6. 自注意力和位置编码

在深度学习中,经常使用卷积神经网络(CNN)或循环神经网络(RNN)对序列进行编码。想象一下,有了注意力机制之后,我们将词元序列输入注意力池化中,以便同一组词元同时充当查询、键和值。具体来说,每个查询都会关注所有的键-值对并生成一个注意力输出。当查询、键和值来自同一组输入时被称为自注意力(self-attention),也被称为内部注意力(intra-attention)。本节将使用自注意力进行序列编码,以及如何使用序列的顺序作为补充信息。

自注意力模型的原理可见:自注意力和位置编码

下面的代码片段是基于多头注意力对一个张量完成自注意力的计算,张量的形状为 (批量大小, 时间步的数目或词元序列的长度, h),输出与输入的张量形状相同:

1
2
3
4
5
6
7
8
9
10
11
12
13
import matplotlib.pyplot as plt
import torch
from torch import nn
from d2l import torch as d2l
from util.functions import show_plotly_heatmaps

num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5)
attention.eval()

batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
print(attention(X, X, X, valid_lens).shape) # torch.Size([2, 4, 100])

在处理词元序列时,循环神经网络是逐个的重复地处理词元的,而自注意力则因为并行计算而放弃了顺序操作。为了使用序列的顺序信息,通过在输入表示中添加位置编码(positional encoding)来注入绝对的或相对的位置信息。位置编码可以通过学习得到也可以直接固定得到。接下来描述的是基于正弦函数和余弦函数的固定位置编码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class PositionalEncoding(nn.Module):
"""位置编码"""
def __init__(self, num_hiddens, dropout, max_len=1000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(dropout)
# 创建一个足够长的P
self.P = torch.zeros((1, max_len, num_hiddens))
X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) /\
torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
self.P[:, :, 0::2] = torch.sin(X)
self.P[:, :, 1::2] = torch.cos(X)

def forward(self, X):
X = X + self.P[:, :X.shape[1], :].to(X.device)
return self.dropout(X)

在位置嵌入矩阵 P 中,行代表词元在序列中的位置,列代表位置编码的不同维度。从下面的例子中可以看到位置嵌入矩阵的第6列和第7列的频率高于第8列和第9列。第6列和第7列之间的偏移量(第8列和第9列相同)是由于正弦函数和余弦函数的交替:

1
2
3
4
5
6
7
8
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)', figsize=(8, 4),
legend=["Col %d" % d for d in torch.arange(6, 10)])
plt.show()

通过绘制热力图可以看到,位置编码通过使用三角函数在编码维度上降低频率:

1
2
P = P[0, :, :]
show_plotly_heatmaps(z=P, xtitle='Column (encoding dimension)', ytitle='Row (position)', colorscale='Blues')

7. Transformer