「项目复现」S3net的损失函数实现
《S3Net: A Single Stream Structure for Depth Guided Image Relighting》是来自中国台湾的Hao-Hsiang Yang等人发表在CVPR 2021(CCF推荐的A类会议)上的一篇WorkShip论文,本文是其项目训练代码的关键复现流程。
本博客是复现《S3Net: A Single Stream Structure for Depth Guided Image Relighting》的损失函数实现。该项目的S3Net一共使用了三个损失函数,其整体损失如下:
$$
L_{\text {Total }}=\lambda_{1} L_{\text {cha }}+\lambda_{2} L_{W-S S I M}+\lambda_{3} L_{P e r}
$$
其中 $\lambda_{1}$、$\lambda_{2}$ 和 $\lambda_{3}$ 是缩放系数,用于调整三个分量的相对权重。
一、Charbonnier 损失
该损失函数来自于《A general and adaptive robust loss function》,其可以看做是一个高鲁棒性的L1损失函数,该损失函数可以还原全局结构并且可以更鲁棒地处理异常值,其公式如下:
$$
L_{C h a}(I, \hat{I})=\frac{1}{T} \sum_{i}^{T} \sqrt{\left(I_{i}-\hat{I}_{i}\right)^{2}+\epsilon^{2}}
$$
其中$I$ 和$\hat{I}$ 分别代表目标图像和该文网络输出的预测图像, $\epsilon$被视为一个微小的常数(例如$10^{-6}$),用来实现稳定和鲁棒的收敛。根据这篇超分辨领域的论文《Fast and Accurate Image Super-Resolution with Deep Laplacian Pyramid Networks》,采用该函数可以使得模型的收敛速度加快。其实现代码相对简单“
1 | class L1_Charbonnier_loss(torch.nn.Module): |
二、SSIM 损失
该损失函数来自于《Loss functions for image restoration with neural networks》 ,其能够重建局部纹理和细节。 可以表示为:
$$
L_{S S I M}(I, \hat{I})=-\frac{\left(2 \mu_{I} \mu_{\hat{I}}+C_{1}\right)\left(2 \sigma_{I \hat{I}}+C_{2}\right)}{\left(\mu_{I}^{2}+\mu_{\hat{I}}^{2}+C_{1}\right)\left(\sigma_{I}^{2}+\sigma_{\hat{I}}^{2}+C_{2}\right)}
$$
其中 σ 和 µ 表示图像的标准偏差、协方差和均值。
在图像重照明任务中,为了从原始图像中去除阴影,该文扩展了 SSIM 损失函数,以便使网络可以恢复更详细的部分。
该文使用《Y-net: Multiscale feature aggregation network with wavelet structure similarity loss function for single image dehazing》 中的方法将 DWT 组合到 SSIM 损失中,这有利于重建重光照图像的清晰细节。最初,DWT 将预测图像分解为四个不同的小sub-band图像。 操作可以表示为:
$$
\hat{I}^{L L}, \hat{I}^{L H}, \hat{I}^{H L}, \hat{I}^{H H}=\operatorname{DWT}(\hat{I})
$$
其中上标表示来自各个过滤器的输出(例如,$$\hat{I}^{L L}, \hat{I}^{L H}, \hat{I}^{H L}, \hat{I}^{H H}$$)。
$$\hat{I}^{H L}, \hat{I}^{L H}, \hat{I}^{H H}$$分别是水平边缘、垂直边缘和角点检测的高通滤波器。 fLL 被视为下采样操作。 此外,DWT 可以不断分解$$\hat{I}^{L L}$$ 以生成具有不同尺度和频率信息的图像。 这一步写成:
$$
\hat{I}{i+1}^{L L}, \hat{I}{i+1}^{L H}, \hat{I}{i+1}^{H L}, \hat{I}{i+1}^{H H}=\operatorname{DWT}\left(\hat{I}{i}^{L L}\right)
$$
其中下标 i 表示第 i 次 DWT 迭代的输出。 上述 SSIM 损失项是根据原始图像对和各种子带图像对计算得出的。 SSIM损失和DWT的融合整合为:
$$
\begin{array}{l}
L{W-S S I M}(I, \hat{I})=\sum_{0}^{r} \gamma_{i} L_{\mathrm{SSIM}}\left(I_{i}^{w}, \hat{I}{i}^{w}\right) \
w \in{L L, H L, L H, H H}
\end{array}
$$
其中$\gamma{i}$ 基于原文来控制不同补丁的重要性。
这里的实现我们参考了wavelet_ssim的实现。
三、感知损失
该损失函数来自于2016年代ECCV会议的《Perceptual losses for real-time style transfer and super-resolution》,该论文在图像转换问题中使用感知损失(perceptual loss)函数代替之前的逐像素(per-pixel)损失函数,结果在速度和图片质量上均得到了大幅度提升。
感知损失定义为
$$
L_{P e r}(I, \hat{I})=\mid(\operatorname{VGG}(I)-\operatorname{VGG}(\hat{I}) \mid
$$
其中$\mid·\mid$ 是绝对值。
该损失函数利用从预训练的深度神经网络(例如 VGG19 (《Very deep convolutional networks for large-scale image recognition》))中获得的多尺度特征,然后使用L1损失来测量预测图像和目标图像之间的视觉特征差异,从而使得训练的图像尽可能地逼近目标图像。该项目使用在ImageNet 上预训练的 VGG19 被用作损失函数网络。
首先使用代码获取vgg19模型:
1 | import torch |
vgg19整体结构分为’features’, ‘avgpool’, 和 ‘classifier’三大部分,而计算损失函数只需要用到’features’部分,打印其结果如下:
1 | Sequential( |
这里我们使用[DRN项目中的vggloss](DeepRelight/networks.py at master · WangLiwen1994/DeepRelight (github.com))的实现来获取多尺度特征,如下:
1 | class Vgg19(nn.Module): |