「项目复现」S3net的训练代码实现

​ 本博客是作者复现《S3Net: A Single Stream Structure for Depth Guided Image Relighting》的训练代码的笔记。

S3net项目的程序结构

一、搭建网络模型

二、训练网络模型

1、获取数据集dataloader获取模型model获取优化器optimizer获取学习率调整器scheduler

2、使用数据集跑n个epoch

​ (1)跑1个eopch (遍历一遍数据集)

​ 获取x,y

​ 正向传播得到y’(model.forward)

计算损失(get_loss)

​ 反向传播(optimizer.zero_grad,loss.backward,optimizer.step)

​ (2)动态调整学习率(scheduler.step)

​ (3)定期保存模型(torch.load,model.load_state_dict)

​ (4)打印日志到控制台(tqdn进度条技术)

3、保存实验数据到磁盘(MetricRecorder类)

​ (1)保存损失、PSNR、SSIM等到.csv文件

​ (2)保存输入图片、预测图片、目标图片为.png

三、测试网络模型

1、如何加载模型和保存模型

函数:保存模型:

torch.save({‘state_dict’:network.state_dict()}, save_path)是下述代码中最重要的API

1
2
3
4
5
6
7
8
# 保存模型
def save_model(self, save_dir, network, epoch):
# 获取保存文件路径
save_filename = '%s_net.pth' % (epoch)//模型文件名
save_path = os.path.join(save_dir, save_filename)
# 保存网络模型
torch.save({'state_dict':network.state_dict()}, save_path)

用以上save_model函数定期保存模型:

技巧:在每个epoch保存模型时,同时保存latest模型,万一中断训练,方便加载模型、继续训练。

1
2
3
4
5
6
7
8
# 定期保存模型
# if self.metric_recorder.update_best_model('PSNR'):
# self.model.save(self.option.model_path, 'best')
if epoch % self.option.save_freq == 0 and epoch != 0:
self.save_model(self.option.model_path,self.model, 'latest')
self.save_model(self.option.model_path,self.model, epoch)
np.savetxt(self.iter_path, (epoch, self.n_total_iter), delimiter=',', fmt='%d')
print('成功保存模型:epoch %d, iters %d' % (epoch, self.n_total_iter))

函数:加载模型:

torch.load(save_path)和model.load_state_dict(checkpoint[‘state_dict’])是下述代码中最重要的API

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# 导入网络模型
from network.res2net2 import Dehaze3
# 获取模型
def get_model(self):
# 加载网络模型
model = Dehaze3().to(self.option.device)
# 是否加载预训练好的模型
if self.option.is_pretrain_model:
# 得到保存模型的路径
save_path = os.path.join(self.option.model_path, 'latest_net.pth')
# 加载之前保存好的模型
checkpoint = torch.load(save_path)
self.start_epoch, self.n_total_iter = np.loadtxt(self.iter_path, delimiter=',', dtype=int)//iter.txt保存之前训练保存的最后的模型的epoch和iter
model.load_state_dict(checkpoint['state_dict'])
print('成功预加载网络模型!')
else:
self.start_epoch = 0
self.n_total_iter = 0 # 训练的总迭代次数
print('成功创建网络模型!')
return model

2、如何输出实验数据到.csv

MetricRecorder类:用于记录数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class MetricRecorder():
# 把实验数据添加到scalarDict字典中
def add_current_scalar(self, log_dict:dict):
for tag, value in log_dict.items():
self.scalarDict[tag].append(value)
#把scalarDict字典中的每条实验数据添加到csv文件中
def _write_to_csv(self,epoch_num,validation):
if self.save_to_csv:
csv_name = 'val_result.csv' if validation else 'train_result.csv'//得到csv文件的名字
self.csv_helper.save_one_epoch(epoch_num, log_dict=self.scalarDict,csv_name=csv_name)
# self.save_to_csv为True,则调用 self._write_to_csv()把实验数据添加到csv中
def write_one_epoch(self, epoch_num, validation=False):
if self.use_tb_log:
self._write_to_tensorboard(epoch_num,validation)
if self.save_to_csv:
self._write_to_csv(epoch_num,validation)
if self.save_to_png:
self._write_to_png(epoch_num)

使用MetricRecorder类得到数据,并保存数据:

初始化MetricRecorder类

1
2
3
4
5
6
7
8
9
10
11
12
# save_to_csv=True,  # 是否保存到.csv确认把实验数据保存到csv文件中
class Trainer():
def __init__(self,option:argparse.Namespace):
# 初始化数据记录器
self.metric_recorder = MetricRecorder(self.option.output_path,
use_tb_log=False, # 是否使用tb日志
save_to_csv=True, # 是否保存到.csv
save_to_png=True, # 是否保存到.png
csv_name=None,
write_header=self.option.is_pretrain_model
) # 实验描述

得到数据字典logDict,用self.metric_recorder.add_current_scalar函数获取到数据字典logDict,使MetricRecorder类里的函数_write_to_csv能使用logDict数据。

1
2
3
4
5
logDict = {'loss': losses['loss'].item(), "loss_chaL1": losses['loss_chaL1'].item(),
"loss_wssim": losses['loss_wssim'].item(),"loss_pre": losses['loss_pre'].item(),
"PSNR": curr_psnr, "SSIM": curr_ssim}
self.metric_recorder.add_current_scalar(logDict)

调用metric_recorder.write_one_epoch,保存每个回合的数据:

1
2
# 保存该回合的数据
self.metric_recorder.write_one_epoch(epoch, validation=False)

3、如何保存预测图片

调用metric_recorder.add_current_imgs获取图片名称和图片的字典,使metric_recorder里的_write_to_png函数能使用图片,并保存

1
2
3
4
5
# 输出图片
if i == epoch_size-1:# 当i是最后一个批次时保存图片
# 添加本回合的生成的图片
imgDict = {'ori_image': ori_image, 'guide_image': guide_image, 'pre_image': pre_image,'truth_img': truth_img}
self.metric_recorder.add_current_imgs(imgDict) # 记录图片

调用metric_recorder.write_one_epoch,保存每个回合的数据:

1
2
3

# 保存该回合的数据
self.metric_recorder.write_one_epoch(epoch, validation=False)

4、如何使用进度条功能