「项目复现」S3net的损失函数实现

《S3Net: A Single Stream Structure for Depth Guided Image Relighting》是来自中国台湾的Hao-Hsiang Yang等人发表在CVPR 2021(CCF推荐的A类会议)上的一篇WorkShip论文,本文是其项目训练代码的关键复现流程。

​ 本博客是复现《S3Net: A Single Stream Structure for Depth Guided Image Relighting》的损失函数实现。该项目的S3Net一共使用了三个损失函数,其整体损失如下:
$$
L_{\text {Total }}=\lambda_{1} L_{\text {cha }}+\lambda_{2} L_{W-S S I M}+\lambda_{3} L_{P e r}
$$
​ 其中 $\lambda_{1}$、$\lambda_{2}$ 和 $\lambda_{3}$ 是缩放系数,用于调整三个分量的相对权重。

一、Charbonnier 损失

​ 该损失函数来自于《A general and adaptive robust loss function》,其可以看做是一个高鲁棒性的L1损失函数,该损失函数可以还原全局结构并且可以更鲁棒地处理异常值,其公式如下:
$$
L_{C h a}(I, \hat{I})=\frac{1}{T} \sum_{i}^{T} \sqrt{\left(I_{i}-\hat{I}_{i}\right)^{2}+\epsilon^{2}}
$$
​ 其中$I$ 和$\hat{I}$ 分别代表目标图像和该文网络输出的预测图像, $\epsilon$被视为一个微小的常数(例如$10^{-6}$​),用来实现稳定和鲁棒的收敛。根据这篇超分辨领域的论文《Fast and Accurate Image Super-Resolution with Deep Laplacian Pyramid Networks》,采用该函数可以使得模型的收敛速度加快。其实现代码相对简单“

1
2
3
4
5
6
7
8
9
10
11
class L1_Charbonnier_loss(torch.nn.Module):
"""L1 Charbonnierloss."""
def __init__(self):
super(L1_Charbonnier_loss, self).__init__()
self.eps = 1e-6

def forward(self, X, Y):
diff = torch.add(X, -Y)
error = torch.sqrt(diff * diff + self.eps)
loss = torch.mean(error)
return loss

二、SSIM 损失

​ 该损失函数来自于《Loss functions for image restoration with neural networks》 ,其能够重建局部纹理和细节。 可以表示为:
$$
L_{S S I M}(I, \hat{I})=-\frac{\left(2 \mu_{I} \mu_{\hat{I}}+C_{1}\right)\left(2 \sigma_{I \hat{I}}+C_{2}\right)}{\left(\mu_{I}^{2}+\mu_{\hat{I}}^{2}+C_{1}\right)\left(\sigma_{I}^{2}+\sigma_{\hat{I}}^{2}+C_{2}\right)}
$$
​ 其中 σ 和 µ 表示图像的标准偏差、协方差和均值。

​ 在图像重照明任务中,为了从原始图像中去除阴影,该文扩展了 SSIM 损失函数,以便使网络可以恢复更详细的部分。

​ 该文使用《Y-net: Multiscale feature aggregation network with wavelet structure similarity loss function for single image dehazing》 中的方法将 DWT 组合到 SSIM 损失中,这有利于重建重光照图像的清晰细节。最初,DWT 将预测图像分解为四个不同的小sub-band图像。 操作可以表示为:
$$
\hat{I}^{L L}, \hat{I}^{L H}, \hat{I}^{H L}, \hat{I}^{H H}=\operatorname{DWT}(\hat{I})
$$

​ 其中上标表示来自各个过滤器的输出(例如,$$\hat{I}^{L L}, \hat{I}^{L H}, \hat{I}^{H L}, \hat{I}^{H H}$$)。

​ $$\hat{I}^{H L}, \hat{I}^{L H}, \hat{I}^{H H}$$分别是水平边缘、垂直边缘和角点检测的高通滤波器。 fLL 被视为下采样操作。 此外,DWT 可以不断分解$$\hat{I}^{L L}$$ 以生成具有不同尺度和频率信息的图像。 这一步写成:
$$
\hat{I}{i+1}^{L L}, \hat{I}{i+1}^{L H}, \hat{I}{i+1}^{H L}, \hat{I}{i+1}^{H H}=\operatorname{DWT}\left(\hat{I}{i}^{L L}\right)
$$
​ 其中下标 i 表示第 i 次 DWT 迭代的输出。 上述 SSIM 损失项是根据原始图像对和各种子带图像对计算得出的。 SSIM损失和DWT的融合整合为:
$$
\begin{array}{l}
L
{W-S S I M}(I, \hat{I})=\sum_{0}^{r} \gamma_{i} L_{\mathrm{SSIM}}\left(I_{i}^{w}, \hat{I}{i}^{w}\right) \
w \in{L L, H L, L H, H H}
\end{array}
$$
其中$\gamma
{i}$​ 基于原文来控制不同补丁的重要性。

​ 这里的实现我们参考了wavelet_ssim的实现。

三、感知损失

​ 该损失函数来自于2016年代ECCV会议的《Perceptual losses for real-time style transfer and super-resolution》,该论文在图像转换问题中使用感知损失(perceptual loss)函数代替之前的逐像素(per-pixel)损失函数,结果在速度和图片质量上均得到了大幅度提升。

​ 感知损失定义为
$$
L_{P e r}(I, \hat{I})=\mid(\operatorname{VGG}(I)-\operatorname{VGG}(\hat{I}) \mid
$$

​ 其中$\mid·\mid$ 是绝对值。

​ 该损失函数利用从预训练的深度神经网络(例如 VGG19 (《Very deep convolutional networks for large-scale image recognition》))中获得的多尺度特征,然后使用L1损失来测量预测图像和目标图像之间的视觉特征差异,从而使得训练的图像尽可能地逼近目标图像。该项目使用在ImageNet 上预训练的 VGG19 被用作损失函数网络。

​ 首先使用代码获取vgg19模型:

1
2
3
4
5
import torch
import torchvision.models as models
# 加载预训练的模型
vgg_model = models.vgg19(pretrained=True)
print(vgg_model.features)

​ vgg19整体结构分为’features’, ‘avgpool’, 和 ‘classifier’三大部分,而计算损失函数只需要用到’features’部分,打印其结果如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace)
(16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(17): ReLU(inplace)
(18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace)
(23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(24): ReLU(inplace)
(25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(26): ReLU(inplace)
(27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace)
(30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(31): ReLU(inplace)
(32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(33): ReLU(inplace)
(34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(35): ReLU(inplace)
(36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)

这里我们使用[DRN项目中的vggloss](DeepRelight/networks.py at master · WangLiwen1994/DeepRelight (github.com))的实现来获取多尺度特征,如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
class Vgg19(nn.Module):
def __init__(self, requires_grad=False):
super(Vgg19, self).__init__()
# vgg_pretrained_features = models.vgg19(pretrained=True).features # #pretrained是true,导入预训练模型
# 加载预训练模型
vgg19_model = models.vgg19(pretrained=True)
# 获取中间层特征
vgg_pretrained_features = vgg19_model.features

self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
# 获取中间层特征,把不同层的特征分别加入不同的模块
for x in range(2):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(2, 7):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(7, 12):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(12, 21):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(21, 28):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
# 设置所有参数都不需要计算梯度,使得之后不进行反向传播及权重更新
if not requires_grad:
for param in self.parameters():
param.requires_grad = False

def forward(self, X):

h_relu1 = self.slice1(X)
h_relu2 = self.slice2(h_relu1)
h_relu3 = self.slice3(h_relu2)
#h_relu4 = self.slice4(h_relu3)
#h_relu5 = self.slice5(h_relu4)
#out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
out = [h_relu1, h_relu2, h_relu3]#, h_relu4, h_relu5]
return out
class VGGLoss(nn.Module):
def __init__(self, gpu_ids):
super(VGGLoss, self).__init__()
self.vgg = Vgg19().cuda()
self.criterion = nn.L1Loss() # L1损失,平均绝对值损失
self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

def forward(self, x, y):
x_vgg, y_vgg = self.vgg(x), self.vgg(y)
loss = 0.0
for i in range(len(x_vgg)):
loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
return loss