多头自注意力机制详解:手把手计算现代AI的基石

1. 引言

在现代人工智能领域,多头自注意力机制(Multi-Headed Self Attention, MHSA)可以说是最重要的架构范式之一。它是Transformer模型的核心组件,而Transformer又是当前最先进的大型语言模型的基础架构。本文将深入浅出地解析多头自注意力机制的工作原理,通过手动计算的方式,让读者对其内部运作有一个直观而全面的理解。

2. 多头自注意力机制的背景

在深入MHSA之前,我们先简要回顾一下自然语言处理(NLP)领域的相关发展历程。早期的NLP模型主要依赖于循环神经网络(RNN)和长短期记忆网络(LSTM)等序列模型。这些模型虽然能够处理序列数据,但在处理长序列时存在长期依赖问题。

2017年,Google提出了Transformer模型,其核心就是多头自注意力机制。MHSA能够并行处理输入序列,捕捉序列中的长距离依赖关系,大大提高了模型的性能和效率。自此,MHSA成为了NLP领域的主流技术,被广泛应用于各种大型语言模型中。

3. 多头自注意力机制的工作原理

让我们通过一个具体的例子,step by step地计算多头自注意力机制的输出。我们将遵循以下步骤:

  1. 定义输入
  2. 定义可学习参数
  3. 计算查询(Query)、键(Key)和值(Value)
  4. 划分多个注意力头
  5. 计算Z矩阵
  6. 掩码操作
  7. 计算注意力矩阵
  8. 计算注意力头的输出
  9. 拼接多个注意力头的输出

3.1 定义输入

MHSA可以应用于各种类型的数据,但通常情况下,输入是一个向量序列。在自然语言处理中,这通常是词嵌入(word embedding)与位置编码(positional encoding)的组合。

假设我们有一个简单的输入序列,包含3个词,每个词用4维向量表示:

Input = [
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12]
]

这个4×3的矩阵代表了我们的输入序列。

3.2 定义可学习参数

MHSA主要学习三个权重矩阵,用于构造”查询”(Query)、”键”(Key)和”值”(Value)。在本例中,我们假设模型已经学习到了以下权重矩阵:

W_Q = [
    [0.1, 0.2],
    [0.3, 0.4],
    [0.5, 0.6],
    [0.7, 0.8]
]

W_K = [
    [0.1, 0.2],
    [0.3, 0.4],
    [0.5, 0.6],
    [0.7, 0.8]
]

W_V = [
    [0.1, 0.2],
    [0.3, 0.4],
    [0.5, 0.6],
    [0.7, 0.8]
]

这些4×2的矩阵代表了模型的可学习参数。

3.3 计算查询、键和值

接下来,我们将输入与权重矩阵相乘,得到查询、键和值:

Query = Input * W_Q
Key = Input * W_K
Value = Input * W_V

让我们计算Query:

Query = [
    [1*0.1 + 2*0.3 + 3*0.5 + 4*0.7, 1*0.2 + 2*0.4 + 3*0.6 + 4*0.8],
    [5*0.1 + 6*0.3 + 7*0.5 + 8*0.7, 5*0.2 + 6*0.4 + 7*0.6 + 8*0.8],
    [9*0.1 + 10*0.3 + 11*0.5 + 12*0.7, 9*0.2 + 10*0.4 + 11*0.6 + 12*0.8]
]

Query = [
    [5.0, 6.0],
    [13.0, 15.0],
    [21.0, 24.0]
]

同理可得Key和Value:

Key = [
    [5.0, 6.0],
    [13.0, 15.0],
    [21.0, 24.0]
]

Value = [
    [5.0, 6.0],
    [13.0, 15.0],
    [21.0, 24.0]
]

3.4 划分多个注意力头

多头自注意力机制的”多头”体现在这一步。我们将Query、Key和Value划分为多个子矩阵,每个子矩阵对应一个注意力头。在本例中,我们使用两个注意力头:

Query_1 = [
    [5.0],
    [13.0],
    [21.0]
]

Query_2 = [
    [6.0],
    [15.0],
    [24.0]
]

Key_1 = [
    [5.0],
    [13.0],
    [21.0]
]

Key_2 = [
    [6.0],
    [15.0],
    [24.0]
]

Value_1 = [
    [5.0],
    [13.0],
    [21.0]
]

Value_2 = [
    [6.0],
    [15.0],
    [24.0]
]

这样,我们就得到了两组Query、Key和Value,分别用于两个注意力头的计算。

3.5 计算Z矩阵

接下来,我们需要计算Z矩阵,这是构造注意力矩阵的中间步骤。Z矩阵由Query和Key的矩阵乘法得到。我们以第一个注意力头为例:

Z_1 = Query_1 * Key_1^T

Z_1 = [
    [5.0],    [5.0, 13.0, 21.0]
    [13.0], *
    [21.0]
]

Z_1 = [
    [5.0*5.0, 5.0*13.0, 5.0*21.0],
    [13.0*5.0, 13.0*13.0, 13.0*21.0],
    [21.0*5.0, 21.0*13.0, 21.0*21.0]
]

Z_1 = [
    [25, 65, 105],
    [65, 169, 273],
    [105, 273, 441]
]

为了防止Z矩阵的值随着序列长度的增加而过大,我们通常会将Z矩阵除以序列长度的平方根。在本例中,序列长度为3,所以我们将Z_1除以$\sqrt{3}$:

Z_1 = [
    [14.43, 37.53, 60.62],
    [37.53, 97.58, 157.62],
    [60.62, 157.62, 254.62]
]

同理可得Z_2。

3.6 掩码操作

在某些应用场景中,如语言模型预测下一个词时,我们需要进行掩码操作,以确保模型在预测时不会”看到”未来的信息。这通常通过将Z矩阵中的某些位置设置为负无穷来实现。在本例中,我们假设不需要掩码操作。

3.7 计算注意力矩阵

注意力矩阵是通过对Z矩阵的每一行进行softmax运算得到的。softmax函数的定义如下:

$softmax(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}$

让我们以Z_1的第一行为例计算softmax:

row = [14.43, 37.53, 60.62]
e_row = [1,850,752, 20,215,372,293, 1,800,537,936,918]
sum_e_row = 1,822,604,060,963

softmax(row) = [
    1,850,752 / 1,822,604,060,963,
    20,215,372,293 / 1,822,604,060,963,
    1,800,537,936,918 / 1,822,604,060,963
]

softmax(row) ≈ [0.001, 0.011, 0.988]

对Z_1的每一行都进行这样的计算,我们就得到了注意力矩阵Attention_1:

Attention_1 ≈ [
    [0.001, 0.011, 0.988],
    [0.000, 0.000, 1.000],
    [0.000, 0.000, 1.000]
]

同理可得Attention_2。

3.8 计算注意力头的输出

得到注意力矩阵后,我们将其与Value相乘,得到每个注意力头的输出:

Output_1 = Attention_1 * Value_1

Output_1 ≈ [
    [0.001*5.0 + 0.011*13.0 + 0.988*21.0],
    [0.000*5.0 + 0.000*13.0 + 1.000*21.0],
    [0.000*5.0 + 0.000*13.0 + 1.000*21.0]
]

Output_1 ≈ [
    [20.86],
    [21.00],
    [21.00]
]

同理可得Output_2。

3.9 拼接多个注意力头的输出

最后,我们将所有注意力头的输出拼接起来,得到多头自注意力机制的最终输出:

Final_Output = [Output_1 | Output_2]

Final_Output ≈ [
    [20.86, 24.00],
    [21.00, 24.00],
    [21.00, 24.00]
]

这个3×2的矩阵就是多头自注意力机制的输出结果。

4. 多头自注意力机制的优势

通过上述计算过程,我们可以看出多头自注意力机制具有以下优势:

  1. 并行计算: MHSA可以并行处理输入序列中的所有元素,大大提高了计算效率。
  2. 捕捉多种关系: 通过使用多个注意力头,模型可以同时关注输入序列中的不同特征和关系。
  3. 长距离依赖: MHSA可以有效捕捉序列中的长距离依赖关系,克服了RNN等传统模型的局限性。
  4. 灵活性: MHSA可以应用于各种类型的序列数据,不仅限于自然语言处理。
  5. 可解释性: 注意力权重可以提供模型决策过程的一定解释性,有助于理解模型的工作原理。

5. 多头自注意力机制的应用

MHSA在自然语言处理领域有广泛的应用,包括但不限于:

  1. 机器翻译: Transformer模型在机器翻译任务中取得了突破性的成果。
  2. 文本生成: GPT系列模型使用了基于MHSA的架构,能够生成高质量的文本。
  3. 问答系统: BERT等模型在问答任务中表现出色,为智能问答系统提供了强大的支持。
  4. 文本分类: MHSA可以有效捕捉文本的语义特征,提高分类准确率。
  5. 语音识别: 在语音识别任务中,MHSA也展现出了优秀的性能。
  6. 图像处理: 虽然最初设计用于NLP任务,但MHSA也被成功应用于计算机视觉领域。

6. 结论

多头自注意力机制是现代人工智能,特别是自然语言处理领域的核心技术之一。通过本文的详细计算过程,我们深入了解了MHSA的工作原理。尽管实际应用中的计算规模要大得多,但基本原理是相同的。

理解MHSA的工作原理对于深入学习和应用先进的AI技术至关重要。随着技术的不断发展,我们可以期待MHSA在更多领域发挥重要作用,推动人工智能技术的进步。

参考文献

  1. Warfield, D. (2024). Multi-Headed Self Attention — By Hand. Intuitively and Exhaustively Explained. https://iaee.substack.com/p/multi-headed-self-attention-by-hand
  2. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., … & Polosukhin, I. (2017). Attention is all you need. Advances in neural information processing systems, 30.

发表评论