PyTorch: попытка создать объединенный набор данных с разными преобразованиями приводит к тому, что оба набора данных имеPython

Программы на Python
Ответить
Anonymous
 PyTorch: попытка создать объединенный набор данных с разными преобразованиями приводит к тому, что оба набора данных име

Сообщение Anonymous »

Я новичок в PyTorch и пытаюсь создать набор данных, для которого данный образец имеет как немаскированные, так и замаскированные данные, связанные с ним, или, другими словами, первый фрагмент данных является просто исходным образцом, а ко второму фрагменту данных применено некоторое преобразование, маскирующее некоторые записи.
Я создал собственный класс JointDataset, который затем передается в DataLoader для использования в обучении. Ниже я включил MWE, который является сокращенным вариантом руководства, которому я следовал и на котором основывался. У меня возникла одна проблема, которую я не могу отладить: при распаковке двух наборов данных с помощью загрузчика данных немаскированные и замаскированные данные идентичны. Мне удалось сузить это до того факта, что train_set и val_set каким-то образом перезаписываются своими замаскированными аналогами, когда вызывается процедура data.Subset для их создания с использованием исходных индексов. Может как-то объяснить, почему это происходит и как я могу заставить этот MWE дать желаемый результат?
# standard libraries
import sys
import numpy as np
from typing import Callable

# PyTorch data loading
import torch
import torch.utils.data as data
from torchvision import transforms

# custom masking function which masks an element of the sample with probability 0.25
def add_masking(sample):
noisy_sample = sample
prob = 0.25
for i in range(sample.shape[0]):
if (np.random.uniform(0.0, 1.0) < prob):
noisy_sample = -1.0
return noisy_sample

# very basic custom dataset
class MyDataset(data.Dataset):
def __init__(self, array: np.array, transform: Callable = None):
self.data = array
self.transform = transform

def __len__(self):
return self.data.shape[1]

def __getitem__(self, idx):
# idx sample is located in the idx column of the data
sample = self.data[:,idx]
if self.transform:
sample = self.transform(sample)
return sample

# generate dataset
num_examples = 100
num_features = 10
X = np.random.rand(num_features, num_examples)
train_dataset = MyDataset(array=X, transform=None)

# split dataset into training and validation
train_set, val_set = data.random_split(train_dataset, [80, 20], generator=torch.Generator().manual_seed(42))
train_indices = train_set.indices
val_indices = val_set.indices

# mask dataset and split it according to the split described by train_indices and val_indices
train_masked_dataset = MyDataset(array=X, transform=transforms.Lambda(lambda x: add_masking(x)))
# these calls seem to replace val_set and train_set with the masked data...
train_masked_set = data.Subset(train_masked_dataset, train_indices)
val_masked_set = data.Subset(train_masked_dataset, val_indices)

# concatenate masked and unmasked data
class JointDataset(data.Dataset):
def __init__(self, dataset1, dataset2):
self.dataset1 = dataset1
self.dataset2 = dataset2
assert len(self.dataset1) == len(self.dataset2)

def __len__(self):
return len(self.dataset1)

def __getitem__(self, index):
data1 = self.dataset1[index]
data2 = self.dataset2[index]
return data1, data2

# combine unmasked and masked data
train_set_combined = JointDataset(train_set, train_masked_set)
val_set_combined = JointDataset(val_set, val_masked_set)

# define data loaders
def numpy_collate(batch):
if isinstance(batch[0], np.ndarray):
return np.stack(batch)
elif isinstance(batch[0], (tuple,list)):
transposed = zip(*batch)
return [numpy_collate(samples) for samples in transposed]
else:
return np.array(batch)

train_loader = data.DataLoader(train_set_combined, batch_size=16, shuffle=True, drop_last=True, pin_memory=True, num_workers=32, collate_fn=numpy_collate, persistent_workers=True)
val_loader = data.DataLoader(val_set_combined, batch_size=16, shuffle=False, drop_last=False, num_workers=32, collate_fn=numpy_collate)

data_iter = iter(train_loader)
batch_unmasked, batch_masked = next(data_iter)

# checking the first batch. the unmasked data is somehow masked...
print(batch_unmasked[0:2,:].T)
print(batch_masked[0:2,:].T)


Подробнее здесь: https://stackoverflow.com/questions/798 ... sults-in-b
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»