Как в PyTorch заставить __call__() из nn.Module автоматически копировать подсказки типов и строку документации forward()Python

Программы на Python
Ответить
Anonymous
 Как в PyTorch заставить __call__() из nn.Module автоматически копировать подсказки типов и строку документации forward()

Сообщение Anonymous »

Обычно шаги вперед реализуются в методе front() и вызываются через __call__, как в этом примере

Код: Выделить всё

from torch import nn, FloatTensor, IntTensor

class MyModule(nn.Module):
def __init__(self, ...) -> None:
nn.Module.__init__(self)
...
def forward(self, x: FloatTensor, y: FloatTensor) -> tuple[FloatTensor, IntTensor]:
"""
Args:
x (FloatTensor): in shape of BxTxE
y (FloatTensor): in shape of BxE

Returns:
tuple[FloatTensor, IntTensor]: (sth. in shape of BxT, sth. in shape of B)
"""
... # implementation of forward steps

model = MyModule(...)
...
a, b = model(x, y) # call it through __call__
Однако такие среды IDE, как VSCode, не могут распознавать подсказки типов или строку документации __call__, поскольку это совершенно другой метод без перегрузки.
Хотя в принципе это разумно для Python , это по-прежнему недружелюбно к таким обстоятельствам, как сотрудничество, требующее удобных подсказок по кодированию.
Возможное, но неуклюжее решение — скопировать эту информацию для перегрузки __call__() в каждый nn.Module:

Код: Выделить всё

from torch import nn, FloatTensor, IntTensor

class MyModule(nn.Module):
def __init__(self, ...) -> None:
nn.Module.__init__(self)
...
def forward(self, x: FloatTensor, y: FloatTensor) -> tuple[FloatTensor, IntTensor]:
"""
Args:
x (FloatTensor): in shape of BxTxE
y (FloatTensor): in shape of BxE

Returns:
tuple[FloatTensor, IntTensor]: (sth. in shape of BxT, sth. in shape of B)
"""
... # implementation of forward steps
def __call__(self, x: FloatTensor, y: FloatTensor) -> tuple[FloatTensor, IntTensor]:
"""
Args:
x (FloatTensor): in shape of BxTxE
y (FloatTensor): in shape of BxE

Returns:
tuple[FloatTensor, IntTensor]: (sth. in shape of BxT, sth. in shape of B)
"""
return nn.Module.__call__(self, x, y)

model = MyModule(...)
...
a, b = model(x, y) # call through __call__
Итак, как я могу сообщить Python или VSCode, что __call__() и front() имеют одинаковые типы ввода/вывода и строку документации в любом подклассе nn.Module, не записывая их снова в перегрузку __call__() каждого подкласса?
(Я думаю, возможное решение для строк документации может быть декораторами? А у меня нет? идея копирования подсказок типов)

Подробнее здесь: https://stackoverflow.com/questions/792 ... -type-hint
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»