「项目复现」S3net的训练数据集读取

​ 本博客是作者复现《S3Net: A Single Stream Structure for Depth Guided Image Relighting》的训练数据集读取代码的笔记。

一、函数test_trainSet()

函数功能:测试类trainDataSetFromTrack2的功能。

1、给出输入原始图像的路径和引导图像路径

1
2
origin_img_path = '../datasets/alltrain/*.png'# 输入的原始图像的路径
guide_img_path = origin_img_path # 引导图像路径

2、根据图片路径和想获取的图片数量获取数据集。

1
dataset = trainDataSetFromTrack2(origin_img_path, guide_img_path,10)

3、用DataLoader获取可以输入神经网络中的数据集

1
trainloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)

4、得到一组样本图片,iter函数将可序列化的对象序列化,next按顺序取序列化后对象的数据。

1
batchdict = next(iter(trainloader))

5、获取原始图像及其深度图、引导图像及其深度图。

1
ori_image, guide_image, ori_depth, guide_depth = batchdict['x']

6、将图片保存到对象路径中

1
save_img(ori_image,'./1.png')

函数代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def test_trainSet():
# 创建数据集
origin_img_path = '../datasets/alltrain/*.png'# 输入的原始图像的路径
guide_img_path = origin_img_path # 引导图像路径
dataset = trainDataSetFromTrack2(origin_img_path, guide_img_path,10)# 根据图片路径读取数据集
trainloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)
# 输出信息
print("训练集一共有{}/{}={}个的批次,其中{}是mini-batch".format(len(dataset), 1, len(trainloader), 1))
batchdict = next(iter(trainloader))# 得到一组样本数据
ori_image, guide_image, ori_depth, guide_depth = batchdict['x']
img_name = batchdict['img_name']
print(ori_image.shape)
print(guide_image.shape)
print(ori_depth.shape)
print(guide_depth.shape)
print('img_name', img_name)
save_img(ori_image,'./1.png')

二、类trainDataSetFromTrack2

类trainDataSetFromTrack2的功能:实现加载数据集所需的各个函数。

1、类头

该类继承自类Dataset,需要重载函数__init__()、getitem(self, index)、len(self)(这三个函数开头结尾都有两个下划线,typora文档里没显示出来)。

1
class trainDataSetFromTrack2(Dataset):

2、成员函数__init__()

函数代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def __init__(self,
origin_img_path: str, # 输入文件所在的路径
guide_img_path: str, # 输出文件所在的路径
num:int,# 读取的图片数量
):
super(trainDataSetFromTrack2, self).__init__()
# 获取所有图片的路径
self.origin_img_paths, self.guide_img_paths = self._get_dataset_path(origin_img_path, guide_img_path)
self.len = len(self.origin_img_paths)
# 选取指定数量的图片
if num > 0 and num < self.len:
self.origin_img_paths = self.origin_img_paths[:num]
self.guide_img_paths =self.guide_img_paths[:num]
self.len = num
# 获取图像预处理函数
self.preprocess_fn = data_transform
print(f'含有{self.len} 个样本的数据集已被创建')

函数功能

1、获取所有输入的原始图像和引导图像的路径

1
self.origin_img_paths, self.guide_img_paths = self._get_dataset_path(origin_img_path, guide_img_path)

2、获取读取整个数据集的大小

1
self.len = len(self.origin_img_paths)

3、获取指定数量的图片

1
2
3
4
if num > 0 and num < self.len:
self.origin_img_paths = self.origin_img_paths[:num]
self.guide_img_paths =self.guide_img_paths[:num]
self.len = num

4、获取图像预处理函数

1
self.preprocess_fn = data_transform

3、成员函数__getitem__()

函数代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# 获取一组图片数据
def __getitem__(self, index):
# 获取一组样本的路径
origin_img_path, guide_img_path = self.origin_img_paths[index % self.len], self.guide_img_paths[index % self.len]
origin_depth_name = origin_img_path.split('_')[0]+'.npy' # 拼接出原始图像对应深度图的路径:Image000+.npy
guide_depth_name = guide_img_path.split('_')[0]+'.npy' # 拼接出指导图像对应深度图的路径: Image001+.npy
truth_img_name = origin_img_path.split('_')[0]+'_'+guide_img_path.split('_')[1]+'_'+guide_img_path.split('_')[2]# 拼接出真实图像的路径:原始图像的前缀Image000+指导图像的后缀
# 读取该组样本的RGB图片
ori_image, guide_image,truth_img = map(self._read_rgb_img, (origin_img_path, guide_img_path,truth_img_name))

# 读取该组样本的depth图片
ori_depth, guide_depth = map(self._read_depth_img, (origin_depth_name, guide_depth_name))
# 获取该组样本对应的名称
img_name = origin_img_path.split('\\')[1]
return {'x':(ori_image, guide_image, ori_depth, guide_depth),
'y':truth_img,
'img_name':img_name}

函数功能:根据序号index,获取一组样本图片。

1、获取原始图像及其深度图、引导图像及其深度图、真实图像的路径

1
2
3
4
5
6
7
8
# 根据序号index,获取原始图像、引导图像的路径
origin_img_path, guide_img_path = self.origin_img_paths[index % self.len], self.guide_img_paths[index % self.len]
# 拼接出原始图像对应深度图的路径:Image000+.npy
origin_depth_name = origin_img_path.split('_')[0]+'.npy'
# 拼接出指导图像对应深度图的路径: Image001+.npy
guide_depth_name = guide_img_path.split('_')[0]+'.npy'
# 拼接出真实图像的路径:原始图像的前缀Image000+指导图像的后缀
truth_img_name = origin_img_path.split('_')[0]+'_'+guide_img_path.split('_')[1]+'_'+guide_img_path.split('_')[2]

2、# 读取该组样本的RGB图片

1
ori_image, guide_image,truth_img = map(self._read_rgb_img, (origin_img_path, guide_img_path,truth_img_name))

map()相当于调用了函数self._read_rgb_img三次,以上代码还可以写为

1
2
3
ori_image = self._read_rgb_img(origin_img_path)
guide_image = self._read_rgb_img(guide_img_path)
truth_img = self._read_rgb_img(truth_img_name)

3、读取该组样本的depth图片

1
ori_depth, guide_depth = map(self._read_depth_img, (origin_depth_name, guide_depth_name))

4、返回读取的这组样本图片

1
2
3
return {'x':(ori_image, guide_image, ori_depth, guide_depth),
'y':truth_img,
'img_name':img_name}

4、成员函数 __len__()

函数功能:返回读取图片的数量。

函数代码:

1
2
def __len__(self):
return self.len

5、成员函数_read_rgb_img()

类中的成员函数加上一个下划线_,这样类外就不能访问该函数。

函数功能:根据给定的图片路径,获取图片张量。

函数代码:

1
2
3
4
5
def _read_rgb_img(self,img_path):
img = Image.open(str(img_path)) # (1024,1024,4)
image_tensor = self.preprocess_fn(img) # tensor,size=(4,1024,1024)
image_tensor = image_tensor[:3, :, :] # tensor,size=(3,1024,1024)
return image_tensor

6、成员函数_read_depth_img()

函数功能:根据给定的图片路径,获取深度图片张量。

函数代码:

1
2
3
4
5
6
def _read_depth_img(self,depth_path):
depth = np.load(depth_path, allow_pickle=True).item()['normalized_depth']
ori_depth = torch.unsqueeze(torch.from_numpy(depth), 0) # 升维(1,1024,1024)
#ori_depth = torch.unsqueeze(ori_depth, 0) # 升维(1,1,1024,1024)
return ori_depth

7、成员函数_get_dataset_path()

函数功能:根据给定的图片文件夹的路径,获取图片文件夹中所有图片的路径。

glob.glob函数:搜索所有满足条件的项。

函数代码:

1
2
3
4
5
6
def _get_dataset_path(self, input_file_path, target_file_path):
origin_img_paths = sorted(glob.glob(input_file_path, recursive=True))
guide_img_paths = glob.glob(target_file_path, recursive=True)
random.shuffle(guide_img_paths)
#assert len(origin_img_paths) == len(guide_img_paths)
return origin_img_paths, guide_img_paths

三、数据增强手段

代码

1
2
3
data_transform = transforms.Compose([
transforms.ToTensor(),
])

四、函数save_img()

函数功能:把图片张量tensor_img保存到输出文件夹output_dir中。

函数代码:

1
2
3
def save_img(tensor_img,output_dir):
# 保存图像
torchvision.utils.save_image(tensor_img, output_dir)