# Code:
class LoadDataset(Dataset):
    """Paired image/mask dataset read from two directories.

    Images and masks are paired purely by position after sorting each
    directory's file paths. Corresponding files must therefore sort into the
    same order, e.g. by using identical file names in both directories.

    Args:
        img_dir: Directory containing the input images.
        mask_dir: Directory containing the masks.
        apply_transforms: Optional callable invoked as
            ``apply_transforms(image, mask)``. It must return an
            ``(image, mask)`` pair, e.g. a ``torchvision.transforms.v2``
            pipeline.

    Raises:
        ValueError: If the two directories contain a different number of
            files, so images and masks cannot be paired one-to-one.
    """

    def __init__(self, img_dir, mask_dir, apply_transforms=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transforms = apply_transforms
        self.img_paths, self.mask_paths = self.__get_all_paths()
        self.__pil_to_tensor = transforms.PILToTensor()
        # Cast to float32 with value rescaling (scale=True).
        self.__float_tensor = transforms.ToDtype(torch.float32, scale=True)
        self.__grayscale = transforms.Grayscale()

    @staticmethod
    def __list_files(directory):
        """Return the sorted paths of the regular files directly inside *directory*."""
        with os.scandir(directory) as entries:
            # DirEntry.path is os.path.join(directory, entry.name).
            return sorted(entry.path for entry in entries if entry.is_file())

    @staticmethod
    def __to_plain(value):
        """Strip any torch.Tensor subclass (e.g. a tv_tensor) off *value*.

        Non-tensor values are returned unchanged.
        """
        if isinstance(value, torch.Tensor):
            return value.as_subclass(torch.Tensor)
        return value

    def __get_all_paths(self):
        """Return ``(img_paths, mask_paths)``, the sorted file paths of both directories.

        Raises:
            ValueError: If the directories hold different numbers of files.
        """
        img_paths = self.__list_files(self.img_dir)
        mask_paths = self.__list_files(self.mask_dir)
        # Pairing is positional, so unequal counts would silently match
        # images with the wrong masks.
        if len(img_paths) != len(mask_paths):
            raise ValueError(
                f"found {len(img_paths)} images in {self.img_dir!r} but "
                f"{len(mask_paths)} masks in {self.mask_dir!r}"
            )
        return img_paths, mask_paths

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.img_paths)

    def __getitem__(self, index):
        """Load pair *index* as float tensors, then apply the joint transforms.

        Returns:
            ``(img_tensor, mask_tensor)``: the image as a float32 tensor, and
            the mask as a float32 tensor passed through ``Grayscale``.
        """
        img_path, mask_path = self.img_paths[index], self.mask_paths[index]
        # Image.open is lazy and may keep the file handle open; the with-blocks
        # close each file deterministically once its pixels are in a tensor.
        with Image.open(img_path) as img_pil:
            img_tensor = self.__float_tensor(self.__pil_to_tensor(img_pil))
        with Image.open(mask_path) as mask_pil:
            mask_tensor = self.__float_tensor(self.__pil_to_tensor(mask_pil))
        mask_tensor = self.__grayscale(mask_tensor)
        if self.transforms is not None:
            # transforms.v2 transforms only the FIRST plain tensor in a call
            # and passes every other plain tensor through untouched, so the
            # mask would not be flipped/cropped together with the image.
            # Wrapping the pair as tv_tensors.Image / tv_tensors.Mask makes v2
            # apply the same random parameters to both, using mask kernels
            # for the mask.
            from torchvision import tv_tensors  # local: file-level imports not visible here

            img_out, mask_out = self.transforms(
                tv_tensors.Image(img_tensor), tv_tensors.Mask(mask_tensor)
            )
            # Return plain tensors, as callers received before the wrapping.
            img_tensor = self.__to_plain(img_out)
            mask_tensor = self.__to_plain(mask_out)
        return img_tensor, mask_tensor
# Code:
# NOTE(review): presumably the transform passed to LoadDataset as
# `apply_transforms` — confirm. As a bare expression statement its result is
# discarded, so this line on its own has no effect on the dataset.
transforms.RandomHorizontalFlip()
# Code:
def __getitem__(self, index):
    """Load pair *index*, apply the joint transforms to the PIL images, then convert.

    NOTE(review): this is the alternative ``LoadDataset.__getitem__``. The
    double-underscore attributes (``self.__pil_to_tensor`` etc.) are
    name-mangled to ``_LoadDataset__...`` only when this ``def`` sits inside
    the ``LoadDataset`` class body, so it must be placed there to work.

    Returns:
        ``(img_tensor, mask_tensor)``: the image as a float32 tensor, and the
        mask as a float32 tensor passed through ``Grayscale``.
    """
    img_path, mask_path = self.img_paths[index], self.mask_paths[index]
    # Image.open is lazy and may keep the file handle open. Both files stay
    # open only until their pixels have been copied into tensors; the
    # with-block then closes them deterministically.
    with Image.open(img_path) as img_file, Image.open(mask_path) as mask_file:
        img_pil, mask_pil = img_file, mask_file
        if self.transforms is not None:
            # Both inputs are PIL images here, which transforms.v2 always
            # treats as images, so image and mask get the same random transform.
            img_pil, mask_pil = self.transforms(img_pil, mask_pil)
        img_tensor = self.__pil_to_tensor(img_pil)
        mask_tensor = self.__pil_to_tensor(mask_pil)
    img_tensor = self.__float_tensor(img_tensor)
    mask_tensor = self.__float_tensor(mask_tensor)
    mask_tensor = self.__grayscale(mask_tensor)
    return img_tensor, mask_tensor
# More details here: https://stackoverflow.com/questions/793 ... vision-tra (URL truncated in the source)
# (Mobile version — site navigation text from the scraped page)