Как эффективно реализовать заполнение вперед в pytorchPython

Программы на Python
Ответить
Anonymous
 Как эффективно реализовать заполнение вперед в pytorch

Сообщение Anonymous »

Как я могу эффективно реализовать логику прямого заполнения (вдохновленную пандами ffill) для NxLxC векторной формы (пакет, измерение последовательности, канал). Поскольку каждая последовательность каналов независима, это может быть эквивалентно работе с тензорной формой (N*C)xL.
При вычислении должна сохраняться переменная горелки, чтобы фактический выходной сигнал был дифференцируемым.< /p>
Мне удалось сделать что-то с расширенной индексацией, но это L**2 по памяти и количеству операций, поэтому не очень хорошо и удобно для графического процессора.

Пример:
Предполагая, что у вас есть последовательность [0,1,2,0,0,3,0,4,0,0,0,5,6,0] в тензорной форме 1x14 при прямом заполнении вы получите последовательность [0,1,2,2,2,3,3,4,4,4,4,5,6,6].
Еще один пример в форме 2x4 — это [[0, 1, 0, 3], [1, 2, 0, 3]], который следует заполнить вперед в [[0, 1, 1, 3], [1 , 2, 2, 3]].

Метод, используемый сегодня:
Мы используем следующий код, который сильно неоптимизирован, но все же быстрее, чем неоптимизированный векторизованные циклы:
def last_zero_sequence_start_indices(t: torch.Tensor) -> torch.Tensor:
"""
Given a 3D tensor `t`, this function returns a two-dimensional tensor where each entry represents
the starting index of the last contiguous sequence of zeros up to and including the current index.
If there's no zero at the current position, the value is the tensor's length.

In essence, for each position in `t`, the function pinpoints the beginning of the last contiguous
sequence of zeros up to that position.

Args:
- t (torch.Tensor): Input tensor with shape [Batch, Channel, Time].

Returns:
- torch.Tensor: Three-dimensional tensor with shape [Batch, Channel, Time] indicating the starting position of
the last sequence of zeros up to each index in `t`.
"""

# Create a mask indicating the start of each zero sequence
start_of_zero_sequence = (t == 0) & torch.cat([
torch.full(t.shape[:-1] + (1,), True, device=t.device),
t[..., :-1] != 0,
], dim=2)

# Duplicate this mask into a TxT matrix
duplicated_mask = start_of_zero_sequence.unsqueeze(2).repeat(1, 1, t.size(-1), 1)

# Extract the lower triangular part of this matrix (including the diagonal)
lower_triangular = torch.tril(duplicated_mask)

# For each row, identify the index of the rightmost '1' (start of the last zero sequence up to that row)
indices = t.size(-1) - 1 - lower_triangular.int().flip(dims=[3]).argmax(dim=3)

return indices


Подробнее здесь: https://stackoverflow.com/questions/772 ... in-pytorch
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»