梯度反传直通图解

有时我们想在层的输出端放置一个阈值函数。这可能出于多种原因。其中之一是我们想将激活总结为二进制值。这种激活的二值化在自编码器中很有用。

然而,阈值化在反向传播过程中会带来问题:阈值函数的导数为零。这种梯度的缺乏导致我们的网络无法学习任何东西。为了解决这个问题,我们可以使用直通估计器 (STE:Straight Through Estimator)。

1、什么是直通估计器?

假设我们想使用以下函数将层的激活二值化:

此函数将为每个大于 0 的值返回 1,否则将返回 0。

如前所述,此函数的问题在于其梯度为零。为了解决这个问题,我们将在反向传递中使用直通估计器。

直通估计器顾名思义就是它估计函数的梯度。具体来说,它忽略阈值函数的导数,并将传入的梯度传递,就好像该函数是恒等函数一样。下图有助于更好地解释它:

你可以看到在反向传递中如何绕过阈值函数。就是这样,这就是直通式估计器的作用。它使阈值函数的梯度看起来像恒等函数的梯度。

2、直通估计器的PyTorch 实现

截至目前,PyTorch 的 API 中尚未包含 STE 的实现。因此,我们必须自己实现它。为此,我们需要创建一个 Function 类和一个 Module 类。Function 类将包含 STE 的前向和后向功能。Module 类是创建和使用 STE Function 对象的地方。我们将在我们的神经网络中使用 STE Module。

以下是 STE Function 类的实现:

class STEFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return (input > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        return F.hardtanh(grad_output)

PyTorch 让我们可以定义具有前向和后向功能的自定义自动求导函数。这里我们为直通式估算器定义了一个自动求导函数。在前向传递中,我们希望将输入张量中的所有值从浮点转换为二进制。在后向传递中,我们希望传递传入的梯度而不对其进行修改。这是为了模仿恒等函数。不过,这里我们对传入的梯度执行 F.hardtanh 操作。此操作将梯度限制在 -1 和 1 之间。我们这样做是为了让梯度不会变得太大。

现在,让我们实现 STE 模块类:

class StraightThroughEstimator(nn.Module):
    def __init__(self):
        super(StraightThroughEstimator, self).__init__()

    def forward(self, x):
            x = STEFunction.apply(x)
            return x

你可以看到,我们在 forward 函数中使用了我们定义的 STE 函数类。要使用 autograd 函数,我们必须将输入传递给 apply 方法。现在,我们可以在神经网络中使用此模块。

使用 STE 的常见方法是在自编码器的瓶颈层内。以下是此类自编码器的实现:

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.ReLU(),
            
            nn.Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            
            nn.Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            
            nn.Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            
            StraightThroughEstimator(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            
            nn.ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            
            nn.ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            
            nn.ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.Tanh(),
        )
        
    def forward(self, x, encode=False, decode=False):
        if encode:
            x = self.encoder(x)
        elif decode:
            x = self.decoder(x)
        else:
            encoding = self.encoder(x)
            x = self.decoder(encoding)
        return x

这个自编码器是为 MNIST 数据集制作的。它将 28x28 图像压缩为具有 512 个通道的 1x1 图像。然后将其解码回 28x28 图像。

我将 STE 放在编码器的末尾。它将把接收到的张量的所有值转换为二进制。你可能已经注意到我使用了一个非常规的前向函数。我添加了两个新参数 encode 和 decrypt,它们要么是 True,要么是 False。如果 encode 设置为 True,网络将返回编码器的输出。同样,如果 decrypt 设置为 True,网络需要有效的编码并将其解码回图像。

我在 MNIST 数据集上对自动编码器进行了 5 个 epoch 的训练,并带有 MSE 损失。以下是测试集上的重建:

如你所见,重建效果非常好。STE 可用于神经网络,且性能不会有太大损失。

完整代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.autograd as autograd
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# dataset preparation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (0.5, ))
])
trainset = datasets.MNIST('dataset/', train=True, download=True, transform=transform)
testset = datasets.MNIST('dataset/', train=False, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)
# defining networks
class STEFunction(autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return (input > 0).float()
    @staticmethod
    def backward(ctx, grad_output):
        return F.hardtanh(grad_output)
class StraightThroughEstimator(nn.Module):
    def __init__(self):
        super(StraightThroughEstimator, self).__init__()
    def forward(self, x):
        x = STEFunction.apply(x)
        return x
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.ReLU(),
            
            nn.Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            
            nn.Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            
            nn.Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            
            StraightThroughEstimator(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            
            nn.ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            
            nn.ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            
            nn.ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)),
            nn.Tanh(),
        )
        
    def forward(self, x, encode=False, decode=False):
        if encode:
            x = self.encoder(x)
        elif decode:
            x = self.decoder(x)
        else:
            encoding = self.encoder(x)
            x = self.decoder(encoding)
        return x
net = Autoencoder().to(device)
optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.5, 0.999))
criterion_MSE = nn.MSELoss().to(device)
# train loop
epoch = 5
for e in range(epoch):
    print(f'Starting epoch {e} of {epoch}')
    for X, y in tqdm(trainloader):
        optimizer.zero_grad()
        X = X.to(device)
        reconstruction = net(X)
        loss = criterion_MSE(reconstruction, X)
        loss.backward()
        optimizer.step()
    print(f'Loss: {loss.item()}')
# test loop
i = 1
fig = plt.figure(figsize=(10, 10))
for X, y in testloader:
    X_in = X.to(device)
    recon = net(X_in).detach().cpu().numpy()
    if i >= 10:
      break
    fig.add_subplot(5, 2, i).set_title('Original')
    plt.imshow(X[0].reshape((28, 28)), cmap="gray")
    fig.add_subplot(5, 2, i+1).set_title('Reconstruction')
    plt.imshow(recon[0].reshape((28, 28)), cmap="gray")
    i += 2
fig.tight_layout()
plt.show()

原文链接:Intuitive Explanation of Straight-Through Estimators with PyTorch Implementation

BimAnt翻译整理,转载请标明出处