本博客是作者复现《S3Net: A Single Stream Structure for Depth Guided Image Relighting》的训练数据集读取代码的笔记。
一、函数test_trainSet()
函数功能:测试类trainDataSetFromTrack2的功能。
1、给出输入原始图像的路径和引导图像路径
1 2
| origin_img_path = '../datasets/alltrain/*.png' guide_img_path = origin_img_path
|
2、根据图片路径和想获取的图片数量获取数据集。
1
| dataset = trainDataSetFromTrack2(origin_img_path, guide_img_path,10)
|
3、用DataLoader获取可以输入神经网络中的数据集
1
| trainloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)
|
4、得到一组样本图片,iter函数将可序列化的对象序列化,next按顺序取序列化后对象的数据。
1
| batchdict = next(iter(trainloader))
|
5、获取原始图像及其深度图、引导图像及其深度图。
1
| ori_image, guide_image, ori_depth, guide_depth = batchdict['x']
|
6、将图片保存到对象路径中
1
| save_img(ori_image,'./1.png')
|
函数代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| def test_trainSet(): # 创建数据集 origin_img_path = '../datasets/alltrain/*.png'# 输入的原始图像的路径 guide_img_path = origin_img_path # 引导图像路径 dataset = trainDataSetFromTrack2(origin_img_path, guide_img_path,10)# 根据图片路径读取数据集 trainloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0) # 输出信息 print("训练集一共有{}/{}={}个的批次,其中{}是mini-batch".format(len(dataset), 1, len(trainloader), 1)) batchdict = next(iter(trainloader))# 得到一组样本数据 ori_image, guide_image, ori_depth, guide_depth = batchdict['x'] img_name = batchdict['img_name'] print(ori_image.shape) print(guide_image.shape) print(ori_depth.shape) print(guide_depth.shape) print('img_name', img_name) save_img(ori_image,'./1.png')
|
二、类trainDataSetFromTrack2
类trainDataSetFromTrack2的功能:实现加载数据集所需的各个函数。
1、类头
该类继承自类Dataset,需要重载函数__init__()、getitem(self, index)、len(self)(这三个函数开头结尾都有两个下划线,typora文档里没显示出来)。
1
| class trainDataSetFromTrack2(Dataset):
|
2、成员函数__init__()
函数代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| def __init__(self, origin_img_path: str, guide_img_path: str, num:int, ): super(trainDataSetFromTrack2, self).__init__() self.origin_img_paths, self.guide_img_paths = self._get_dataset_path(origin_img_path, guide_img_path) self.len = len(self.origin_img_paths) if num > 0 and num < self.len: self.origin_img_paths = self.origin_img_paths[:num] self.guide_img_paths =self.guide_img_paths[:num] self.len = num self.preprocess_fn = data_transform print(f'含有{self.len} 个样本的数据集已被创建')
|
函数功能:
1、获取所有输入的原始图像和引导图像的路径
1
| self.origin_img_paths, self.guide_img_paths = self._get_dataset_path(origin_img_path, guide_img_path)
|
2、获取读取整个数据集的大小
1
| self.len = len(self.origin_img_paths)
|
3、获取指定数量的图片
1 2 3 4
| if num > 0 and num < self.len: self.origin_img_paths = self.origin_img_paths[:num] self.guide_img_paths =self.guide_img_paths[:num] self.len = num
|
4、获取图像预处理函数
1
| self.preprocess_fn = data_transform
|
3、成员函数__getitem__()
函数代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
| # 获取一组图片数据 def __getitem__(self, index): # 获取一组样本的路径 origin_img_path, guide_img_path = self.origin_img_paths[index % self.len], self.guide_img_paths[index % self.len] origin_depth_name = origin_img_path.split('_')[0]+'.npy' # 拼接出原始图像对应深度图的路径:Image000+.npy guide_depth_name = guide_img_path.split('_')[0]+'.npy' # 拼接出指导图像对应深度图的路径: Image001+.npy truth_img_name = origin_img_path.split('_')[0]+'_'+guide_img_path.split('_')[1]+'_'+guide_img_path.split('_')[2]# 拼接出真实图像的路径:原始图像的前缀Image000+指导图像的后缀 # 读取该组样本的RGB图片 ori_image, guide_image,truth_img = map(self._read_rgb_img, (origin_img_path, guide_img_path,truth_img_name))
# 读取该组样本的depth图片 ori_depth, guide_depth = map(self._read_depth_img, (origin_depth_name, guide_depth_name)) # 获取该组样本对应的名称 img_name = origin_img_path.split('\\')[1] return {'x':(ori_image, guide_image, ori_depth, guide_depth), 'y':truth_img, 'img_name':img_name}
|
函数功能:根据序号index,获取一组样本图片。
1、获取原始图像及其深度图、引导图像及其深度图、真实图像的路径
1 2 3 4 5 6 7 8
| # 根据序号index,获取原始图像、引导图像的路径 origin_img_path, guide_img_path = self.origin_img_paths[index % self.len], self.guide_img_paths[index % self.len] # 拼接出原始图像对应深度图的路径:Image000+.npy origin_depth_name = origin_img_path.split('_')[0]+'.npy' # 拼接出指导图像对应深度图的路径: Image001+.npy guide_depth_name = guide_img_path.split('_')[0]+'.npy' # 拼接出真实图像的路径:原始图像的前缀Image000+指导图像的后缀 truth_img_name = origin_img_path.split('_')[0]+'_'+guide_img_path.split('_')[1]+'_'+guide_img_path.split('_')[2]
|
2、# 读取该组样本的RGB图片
1
| ori_image, guide_image,truth_img = map(self._read_rgb_img, (origin_img_path, guide_img_path,truth_img_name))
|
map()相当于调用了函数self._read_rgb_img三次,以上代码还可以写为
1 2 3
| ori_image = self._read_rgb_img(origin_img_path) guide_image = self._read_rgb_img(guide_img_path) truth_img = self._read_rgb_img(truth_img_name)
|
3、读取该组样本的depth图片
1
| ori_depth, guide_depth = map(self._read_depth_img, (origin_depth_name, guide_depth_name))
|
4、返回读取的这组样本图片
1 2 3
| return {'x':(ori_image, guide_image, ori_depth, guide_depth), 'y':truth_img, 'img_name':img_name}
|
4、成员函数 __len__()
函数功能:返回读取图片的数量。
函数代码:
1 2
| def __len__(self): return self.len
|
5、成员函数_read_rgb_img()
类中的成员函数加上一个下划线_,这样类外就不能访问该函数。
函数功能:根据给定的图片路径,获取图片张量。
函数代码:
1 2 3 4 5
| def _read_rgb_img(self,img_path): img = Image.open(str(img_path)) image_tensor = self.preprocess_fn(img) image_tensor = image_tensor[:3, :, :] return image_tensor
|
6、成员函数_read_depth_img()
函数功能:根据给定的图片路径,获取深度图片张量。
函数代码:
1 2 3 4 5 6
| def _read_depth_img(self,depth_path): depth = np.load(depth_path, allow_pickle=True).item()['normalized_depth'] ori_depth = torch.unsqueeze(torch.from_numpy(depth), 0) # 升维(1,1024,1024) #ori_depth = torch.unsqueeze(ori_depth, 0) # 升维(1,1,1024,1024) return ori_depth
|
7、成员函数_get_dataset_path()
函数功能:根据给定的图片文件夹的路径,获取图片文件夹中所有图片的路径。
glob.glob函数:搜索所有满足条件的项。
函数代码:
1 2 3 4 5 6
| def _get_dataset_path(self, input_file_path, target_file_path): origin_img_paths = sorted(glob.glob(input_file_path, recursive=True)) guide_img_paths = glob.glob(target_file_path, recursive=True) random.shuffle(guide_img_paths) #assert len(origin_img_paths) == len(guide_img_paths) return origin_img_paths, guide_img_paths
|
三、数据增强手段
代码
1 2 3
| data_transform = transforms.Compose([ transforms.ToTensor(), ])
|
四、函数save_img()
函数功能:把图片张量tensor_img保存到输出文件夹output_dir中。
函数代码:
1 2 3
| def save_img(tensor_img,output_dir): torchvision.utils.save_image(tensor_img, output_dir)
|