Post

RQ-VAE方法详解

RQ-VAE方法详解

RQ-VAE(Residual Quantized Variational Autoencoder)方法详解

一、 方法概述

Residual-Quantized Variational AutoEncoder(RQ-VAE)是一种结合了残差连接(Residual Connection)和量化技术(Quantization)的变分自编码器(VAE)。 它旨在通过引入这些技术,提高模型的生成效率和效果,同时减少潜在空间的维度,使模型更加紧凑和高效。

二、 具体步骤及作用

步骤1:编码器处理输入

  • 操作:输入数据x通过编码器E生成连续潜在向量z = E(x)
  • 作用:提取高层特征表示,维度通常低于输入数据
  • 数学表达:\(z ∈ R^{H×W×C}\)

步骤2:多阶段残差量化

循环执行N次(N=量化层数):

  1. 当前层量化
    • 查找码本B_n中最接近的向量: \(q_n = argmin_{b∈B_n} ||r_{n-1} - b||²\)
    • 作用:离散化当前残差信号
  2. 残差计算
    • \(r_n = r_{n-1} - q_n\)(初始化r_0 = z)
    • 作用:保留未被当前层捕获的细节信息
  3. 结果聚合
    • 累积量化结果:\(Q = concat(q_1, q_2, ..., q_n)\)
    • 作用:构建分层离散表示

步骤3:解码器重建

  • 操作:聚合结果Q通过解码器D生成重建数据x’ = D(Q)
  • 作用:将离散表示映射回数据空间
  • 损失函数

    \[L = ||x - x'||² + ∑_{n=1}^N ||sg[r_{n-1}] - q_n||² + β∑_{n=1}^N ||r_{n-1} - sg[q_n]||²\]

三、核心优势

特点说明
渐进式量化分层捕获不同粒度特征
码本高效利用每层码本专注特定残差
低重建误差残差机制保留细节信息
可扩展性通过增加层数提升质量

四、实现注意事项

  1. 码本设计:建议不同层使用独立码本,尺寸逐层递减(如[512,256,128])
  2. 梯度传导:采用Straight-Through Estimator处理量化不可导问题
  3. 残差归一化:推荐对每层残差进行Layer Normalization
  4. 渐进训练:可先训练单层,逐步添加新层加速收敛

五、VAE代码实现

为了实现RQ-VAE,我们需要首先了解VAE的实现

VAE主要包含三个部分:编码器,重参数化采样器和解码器

1. 编码器

1
2
3
4
5
6
7
8
9
10
11
12
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x)) # 非线性映射到隐藏层
        mu = self.fc_mu(h) # 映射到潜在空间均值
        logvar = self.fc_logvar(h) # 映射到潜在空间方差log
        return mu, logvar # 返回均值和方差log

2. 重参数化采样器

1
2
3
4
def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar) # 得到标准差
    eps = torch.randn_like(std) # 随意采样得到标准正态分布 
    return mu + eps * std # 根据均值和标准差,恢复潜在空间的对应的正态分布  

3. 解码器

1
2
3
4
5
6
7
8
9
10
11
# 定义解码器
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        x_recon = torch.sigmoid(self.fc2(h)) # 经过两次非线性映射,重建原始输入
        return x_recon

4.完整的VAE模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 定义 VAE 模型
class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)
        
    def reparameterize(mu, logvar):
        std = torch.exp(0.5 * logvar) # 得到标准差
        eps = torch.randn_like(std) # 随意采样得到标准正态分布 
        return mu + eps * std # 根据均值和标准差,恢复潜在空间的对应的正态分布   

    def forward(self, x):
        mu, logvar = self.encoder(x) # 编码器decode得到均值和方差log
        z = self.reparameterize(mu, logvar) # 重参数化采样器,得到潜在空间的z
        x_recon = self.decoder(z) # 解码器decode得到重建的输入x_recon
        return x_recon, mu, logvar

5. 损失函数

1
2
3
4
5
# 定义损失函数,包括重构损失和 KL 散度损失。
def vae_loss(x_recon, x, mu, logvar):
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_loss
This post is licensed under CC BY 4.0 by the author.

Trending Tags