本博客是作者复现《S3Net: A Single Stream Structure for Depth Guided Image Relighting》的训练代码的笔记。
S3net项目的程序结构
一、搭建网络模型
二、训练网络模型
1、获取数据集dataloader、获取模型model、获取优化器optimizer、获取学习率调整器scheduler
2、使用数据集跑n个epoch
(1)跑1个eopch (遍历一遍数据集)
获取x,y
正向传播得到y’(model.forward)
计算损失(get_loss)
反向传播(optimizer.zero_grad,loss.backward,optimizer.step)
(2)动态调整学习率(scheduler.step)
(3)定期保存模型(torch.load,model.load_state_dict)
(4)打印日志到控制台(tqdn进度条技术)
3、保存实验数据到磁盘(MetricRecorder类)
(1)保存损失、PSNR、SSIM等到.csv文件
(2)保存输入图片、预测图片、目标图片为.png
三、测试网络模型
1、如何加载模型和保存模型
函数:保存模型:
torch.save({‘state_dict’:network.state_dict()}, save_path)是下述代码中最重要的API
1 2 3 4 5 6 7 8
| def save_model(self, save_dir, network, epoch): save_filename = '%s_net.pth' % (epoch)//模型文件名 save_path = os.path.join(save_dir, save_filename) torch.save({'state_dict':network.state_dict()}, save_path)
|
用以上save_model函数定期保存模型:
技巧:在每个epoch保存模型时,同时保存latest模型,万一中断训练,方便加载模型、继续训练。
1 2 3 4 5 6 7 8
| # 定期保存模型 # if self.metric_recorder.update_best_model('PSNR'): # self.model.save(self.option.model_path, 'best') if epoch % self.option.save_freq == 0 and epoch != 0: self.save_model(self.option.model_path,self.model, 'latest') self.save_model(self.option.model_path,self.model, epoch) np.savetxt(self.iter_path, (epoch, self.n_total_iter), delimiter=',', fmt='%d') print('成功保存模型:epoch %d, iters %d' % (epoch, self.n_total_iter))
|
函数:加载模型:
torch.load(save_path)和model.load_state_dict(checkpoint[‘state_dict’])是下述代码中最重要的API
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| from network.res2net2 import Dehaze3
def get_model(self): model = Dehaze3().to(self.option.device) if self.option.is_pretrain_model: save_path = os.path.join(self.option.model_path, 'latest_net.pth') checkpoint = torch.load(save_path) self.start_epoch, self.n_total_iter = np.loadtxt(self.iter_path, delimiter=',', dtype=int)//iter.txt保存之前训练保存的最后的模型的epoch和iter model.load_state_dict(checkpoint['state_dict']) print('成功预加载网络模型!') else: self.start_epoch = 0 self.n_total_iter = 0 print('成功创建网络模型!') return model
|
2、如何输出实验数据到.csv
MetricRecorder类:用于记录数据
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
| class MetricRecorder(): def add_current_scalar(self, log_dict:dict): for tag, value in log_dict.items(): self.scalarDict[tag].append(value) def _write_to_csv(self,epoch_num,validation): if self.save_to_csv: csv_name = 'val_result.csv' if validation else 'train_result.csv'//得到csv文件的名字 self.csv_helper.save_one_epoch(epoch_num, log_dict=self.scalarDict,csv_name=csv_name) def write_one_epoch(self, epoch_num, validation=False): if self.use_tb_log: self._write_to_tensorboard(epoch_num,validation) if self.save_to_csv: self._write_to_csv(epoch_num,validation) if self.save_to_png: self._write_to_png(epoch_num)
|
使用MetricRecorder类得到数据,并保存数据:
初始化MetricRecorder类
1 2 3 4 5 6 7 8 9 10 11 12
| class Trainer(): def __init__(self,option:argparse.Namespace): self.metric_recorder = MetricRecorder(self.option.output_path, use_tb_log=False, save_to_csv=True, save_to_png=True, csv_name=None, write_header=self.option.is_pretrain_model )
|
得到数据字典logDict,用self.metric_recorder.add_current_scalar函数获取到数据字典logDict,使MetricRecorder类里的函数_write_to_csv能使用logDict数据。
1 2 3 4 5
| logDict = {'loss': losses['loss'].item(), "loss_chaL1": losses['loss_chaL1'].item(), "loss_wssim": losses['loss_wssim'].item(),"loss_pre": losses['loss_pre'].item(), "PSNR": curr_psnr, "SSIM": curr_ssim} self.metric_recorder.add_current_scalar(logDict)
|
调用metric_recorder.write_one_epoch,保存每个回合的数据:
1 2
| # 保存该回合的数据 self.metric_recorder.write_one_epoch(epoch, validation=False)
|
3、如何保存预测图片
调用metric_recorder.add_current_imgs获取图片名称和图片的字典,使metric_recorder里的_write_to_png函数能使用图片,并保存
1 2 3 4 5
| # 输出图片 if i == epoch_size-1:# 当i是最后一个批次时保存图片 # 添加本回合的生成的图片 imgDict = {'ori_image': ori_image, 'guide_image': guide_image, 'pre_image': pre_image,'truth_img': truth_img} self.metric_recorder.add_current_imgs(imgDict) # 记录图片
|
调用metric_recorder.write_one_epoch,保存每个回合的数据:
1 2 3
| # 保存该回合的数据 self.metric_recorder.write_one_epoch(epoch, validation=False)
|
4、如何使用进度条功能