Post

对比attention实现

对比attention实现

Self-Attention 与 Cross-Attention 的区别及实现细节

应用场景

类型应用场景示例
Self-Attention单序列内部关系建模Transformer 编码器、文本分类
Cross-Attention两个序列之间的交互建模Transformer 解码器、机器翻译/文本生成

核心区别

特性Self-AttentionCross-Attention
输入来源Q/K/V 来自同一输入Q 来自输入 A,K/V 来自输入 B
序列关系单序列内部关系跨序列关系(如编码器-解码器交互)
典型位置Transformer 编码器Transformer 解码器

Python 实现代码 (PyTorch)

1. Self-Attention 实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        qkv = self.qkv(x).chunk(3, dim=-1)  # 拆分为 Q/K/V
        
        # 线性投影 + 多头拆分
        q, k, v = [ 
            t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            for t in qkv
        ]

        # 计算注意力分数
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = F.softmax(scores, dim=-1)
        
        # 加权和 + 合并多头
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        return self.out(out)

Cross-Attention 实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class CrossAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.q = nn.Linear(embed_dim, embed_dim)  # Q 来自输入 x
        self.kv = nn.Linear(embed_dim, 2 * embed_dim)  # K/V 来自 encoder_output

    def forward(self, x, encoder_output):
        batch_size, seq_len, _ = x.shape
        
        # 生成 Q/K/V(注意来源不同)
        q = self.q(x)
        k, v = self.kv(encoder_output).chunk(2, dim=-1)

        # 多头拆分
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # 计算注意力分数
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = F.softmax(scores, dim=-1)
        
        # 加权和 + 合并多头
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        return out

Attention机制实现方式全面对比

对比维度flash_attention_2flex_attentionsdpa
核心算法基于分块计算和重计算的IO优化算法动态稀疏注意力机制PyTorch内置的scaled_dot_product_attention实现
计算复杂度O(N²)理论复杂度,但通过分块减少实际计算量支持稀疏模式时可达O(N√N)O(N²)标准实现,但通过硬件加速优化
内存占用最低(不存储完整attention矩阵)中等(支持稀疏存储)较高(需存储完整attention矩阵)
硬件加速强CUDA优化支持多平台适配深度集成CUDA/cuDNN
序列长度支持最优(支持超长序列)中等(依赖稀疏模式)标准长度(受显存限制)
反向传播支持需重计算前向结果原生支持原生支持
扩展性固定分块策略支持自定义稀疏模式固定标准实现
实现复杂度高(需手工CUDA优化)中等(需定义稀疏策略)低(直接调用API)
适用场景1. 超长序列处理
2. 内存敏感场景
3. 训练场景
1. 稀疏注意力需求
2. 动态模式切换
3. 研究性场景
1. 标准Transformer
2. 推理场景
3. 快速原型开发
精度控制使用FP16/FP32混合精度原生支持自动混合精度依赖框架自动混合精度
框架依赖需要定制CUDA扩展需要特定框架支持深度集成PyTorch
典型应用案例1. LLM训练
2. 长文本处理
1. 视觉Transformer
2. 图神经网络
1. 标准BERT/GPT
2. 移动端部署

关键结论:

  1. 训练场景优先flash_attention_2在内存效率和长序列处理上表现最优
  2. 动态稀疏需求flex_attention提供最灵活的注意力模式配置
  3. 快速开发推荐sdpa凭借PyTorch深度集成实现最佳开发效率
  4. 硬件适配性sdpa > flash_attention_2 > flex_attention
  5. 内存敏感场景flash_attention_2的内存优化策略可节省30-50%显存

注:实际性能表现需结合具体硬件配置(如A100/H100对flash_attention_2有特殊优化)和任务特性

This post is licensed under CC BY 4.0 by the author.

Trending Tags