Мне нужно использовать vmap для вычисления градиента по отношению к модели, в которой используется torch.utils.checkpoint.checkpoint .
Я получил следующую ошибку < /p>
runtimeerror: вы пытались пройти VMAP через _noopsaveinputs, но это не
не имеет поддержки VMAP. Пожалуйста, переопределите и реализуйте vmap staticmethod
или установите generate_vmap_rule = true. Для получения дополнительной информации, пожалуйста, см. Примечание, приведенное в ошибке, я попытался добавить Generate_vmap_rule = true в классе _noopsaveinputs , и я получил следующую ошибку Вместо этого < /p>
runtimeerror: torch.func преобразования еще не поддерживают сохраненные тензорные крючки. Пожалуйста, откройте проблему с вашим вариантом использования. /p>
Я нашел проблему GitHub, задающий тот же вопрос. Но через год до сих пор нет обновления. br /> Вот минимальный код для воспроизведения моей проблемы < /p>
import torch
import torch.nn as nn
import torch.optim as optim
import random
from torch.func import vmap
from torch.utils.checkpoint import checkpoint
# from torch.autograd.function import once_differentiable
# Set a fixed seed for reproducibility
seed = 42
torch.manual_seed(seed)
random.seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # For multi-GPU setups
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Define a simple Transformer model
class SimpleTransformer(nn.Module):
def __init__(self, input_dim, model_dim, num_heads, num_layers, output_dim):
super(SimpleTransformer, self).__init__()
self.embedding = nn.Embedding(input_dim, model_dim)
self.transformer = nn.Transformer(d_model=model_dim, nhead=num_heads, num_encoder_layers=num_layers)
self.fc_out = nn.Linear(model_dim, output_dim)
def forward(self, src, tgt):
src = self.embedding(src)
tgt = self.embedding(tgt)
# output = self.transformer(src, tgt)
output = checkpoint(self.transformer, src, tgt, use_reentrant=False)
return self.fc_out(output)
def compute_loss(model,
weights,
buffers,
src,
tgt,
):
# Forward pass
output = torch.func.functional_call(
model, (weights, buffers), args=(src.unsqueeze(1), tgt[:-1].unsqueeze(1))
)
# Reshape output and target for loss calculation
output = output.view(-1, output_dim)
tgt = tgt[1:].view(-1) # Shift target sequence by one
loss = criterion_mean(output, tgt)
print(loss.shape)
return loss
# Hyperparameters
input_dim = 1000 # Vocabulary size
model_dim = 512 # Embedding dimension
num_heads = 8 # Number of attention heads
num_layers = 6 # Number of transformer layers
output_dim = 1000 # Output vocabulary size
seq_length = 10 # Sequence length
# Initialize model, loss function, and optimizer
model = SimpleTransformer(input_dim, model_dim, num_heads, num_layers, output_dim)
criterion_mean = nn.CrossEntropyLoss()
criterion_raw = nn.CrossEntropyLoss(reduction='none')
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Dummy data (batch_size = 2, seq_length = 10)
src = torch.randint(0, input_dim, (seq_length, 2)) # Source sequence
tgt_raw = torch.randint(0, output_dim, (seq_length, 2)) # Target sequence
# Print src and tgt to verify they are the same each time
print("Source sequence (src):")
print("Target sequence (tgt):")
# Forward pass
output = model(src, tgt_raw[:-1, :]) # Exclude the last token in the target sequence
# Reshape output and target for loss calculation
output = output.view(-1, output_dim)
tgt = tgt_raw[1:, :].view(-1) # Shift target sequence by one
weights = dict(model.named_parameters())
buffers = dict(model.named_buffers())
grads_fn = torch.func.grad(compute_loss, has_aux=False, argnums=1)
gs = vmap(grads_fn,
in_dims=(None, None, None, 1, 1),
randomness='different')(
model,
weights,
buffers,
src,
tgt_raw
)
Подробнее здесь: https://stackoverflow.com/questions/794 ... checkpoint
Применение `vmap` к модели с` torch.utils.checkpoint.checkpoint` ⇐ Python
-
- Похожие темы
- Ответы
- Просмотры
- Последнее сообщение