Применение `vmap` к модели с` torch.utils.checkpoint.checkpoint`Python

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Применение `vmap` к модели с` torch.utils.checkpoint.checkpoint`

Сообщение Anonymous »

Мне нужно использовать vmap для вычисления градиента по отношению к модели, в которой используется torch.utils.checkpoint.checkpoint .
Я получил следующую ошибку < /p>

runtimeerror: вы пытались пройти VMAP через _noopsaveinputs, но это не
не имеет поддержки VMAP. Пожалуйста, переопределите и реализуйте vmap staticmethod
или установите generate_vmap_rule = true. Для получения дополнительной информации, пожалуйста, см. Примечание, приведенное в ошибке, я попытался добавить Generate_vmap_rule = true в классе _noopsaveinputs , и я получил следующую ошибку Вместо этого < /p>

runtimeerror: torch.func преобразования еще не поддерживают сохраненные тензорные крючки. Пожалуйста, откройте проблему с вашим вариантом использования. /p>
Я нашел проблему GitHub, задающий тот же вопрос. Но через год до сих пор нет обновления. br /> Вот минимальный код для воспроизведения моей проблемы < /p>
import torch
import torch.nn as nn
import torch.optim as optim
import random
from torch.func import vmap
from torch.utils.checkpoint import checkpoint
# from torch.autograd.function import once_differentiable

# Set a fixed seed for reproducibility
seed = 42
torch.manual_seed(seed)
random.seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # For multi-GPU setups
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Define a simple Transformer model
class SimpleTransformer(nn.Module):
def __init__(self, input_dim, model_dim, num_heads, num_layers, output_dim):
super(SimpleTransformer, self).__init__()
self.embedding = nn.Embedding(input_dim, model_dim)
self.transformer = nn.Transformer(d_model=model_dim, nhead=num_heads, num_encoder_layers=num_layers)
self.fc_out = nn.Linear(model_dim, output_dim)

def forward(self, src, tgt):
src = self.embedding(src)
tgt = self.embedding(tgt)
# output = self.transformer(src, tgt)
output = checkpoint(self.transformer, src, tgt, use_reentrant=False)
return self.fc_out(output)

def compute_loss(model,
weights,
buffers,
src,
tgt,
):

# Forward pass
output = torch.func.functional_call(
model, (weights, buffers), args=(src.unsqueeze(1), tgt[:-1].unsqueeze(1))
)

# Reshape output and target for loss calculation
output = output.view(-1, output_dim)
tgt = tgt[1:].view(-1) # Shift target sequence by one
loss = criterion_mean(output, tgt)
print(loss.shape)
return loss

# Hyperparameters
input_dim = 1000 # Vocabulary size
model_dim = 512 # Embedding dimension
num_heads = 8 # Number of attention heads
num_layers = 6 # Number of transformer layers
output_dim = 1000 # Output vocabulary size
seq_length = 10 # Sequence length

# Initialize model, loss function, and optimizer
model = SimpleTransformer(input_dim, model_dim, num_heads, num_layers, output_dim)
criterion_mean = nn.CrossEntropyLoss()
criterion_raw = nn.CrossEntropyLoss(reduction='none')
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Dummy data (batch_size = 2, seq_length = 10)
src = torch.randint(0, input_dim, (seq_length, 2)) # Source sequence
tgt_raw = torch.randint(0, output_dim, (seq_length, 2)) # Target sequence

# Print src and tgt to verify they are the same each time
print("Source sequence (src):")
print("Target sequence (tgt):")

# Forward pass
output = model(src, tgt_raw[:-1, :]) # Exclude the last token in the target sequence

# Reshape output and target for loss calculation
output = output.view(-1, output_dim)
tgt = tgt_raw[1:, :].view(-1) # Shift target sequence by one

weights = dict(model.named_parameters())
buffers = dict(model.named_buffers())

grads_fn = torch.func.grad(compute_loss, has_aux=False, argnums=1)
gs = vmap(grads_fn,
in_dims=(None, None, None, 1, 1),
randomness='different')(
model,
weights,
buffers,
src,
tgt_raw
)


Подробнее здесь: https://stackoverflow.com/questions/794 ... checkpoint
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение
  • Нет модуля с именем utils.utils, utils не является пакетом.
    Anonymous » » в форуме Python
    0 Ответы
    70 Просмотры
    Последнее сообщение Anonymous
  • Нет модуля с именем utils.utils, Utils не является пакетом
    Anonymous » » в форуме Python
    0 Ответы
    37 Просмотры
    Последнее сообщение Anonymous
  • Нет модуля с именем utils.utils, Utils не является пакетом
    Anonymous » » в форуме Python
    0 Ответы
    5 Просмотры
    Последнее сообщение Anonymous
  • Ошибка из Torch.utils.data Импорт UTILS
    Anonymous » » в форуме Python
    0 Ответы
    12 Просмотры
    Последнее сообщение Anonymous
  • Ошибка из Torch.utils.data Импорт UTILS
    Anonymous » » в форуме Python
    0 Ответы
    2 Просмотры
    Последнее сообщение Anonymous

Вернуться в «Python»