Как передать входные данные dim из метода подгонки в обертку скорча?Python

Программы на Python
Ответить
Anonymous
 Как передать входные данные dim из метода подгонки в обертку скорча?

Сообщение Anonymous »

Я пытаюсь включить функциональные возможности PyTorch в среду scikit-learn (в частности, Pipelines и GridSearchCV) и поэтому изучаю skorch. Стандартный пример документации для нейронных сетей выглядит так:

Код: Выделить всё

import torch.nn.functional as F
from torch import nn
from skorch import NeuralNetClassifier

class MyModule(nn.Module):
def __init__(self, num_units=10, nonlin=F.relu):
super(MyModule, self).__init__()

self.dense0 = nn.Linear(20, num_units)
self.nonlin = nonlin
self.dropout = nn.Dropout(0.5)
...
...
self.output = nn.Linear(10, 2)
...
...

где вы явно передаете входные и выходные измерения, жестко закодировав их в конструкторе. Однако на самом деле это не так, как работают интерфейсы scikit-learn, где входные и выходные измерения получаются с помощью метода fit, а не передаются явно конструкторам. В качестве практического примера рассмотрим

Код: Выделить всё

# copied from the documentation
net = NeuralNetClassifier(
MyModule,
max_epochs=10,
lr=0.1,
# Shuffle training data on each epoch
iterator_train__shuffle=True,
)

# any general Pipeline interface
pipeline = Pipeline([
('transformation', AnyTransformer()),
('net', net)
])

gs = GridSearchCV(net, params, refit=False, cv=3, scoring='accuracy')
gs.fit(X, y)
помимо того, что нигде в преобразователях нельзя указывать входные и выходные размерности, преобразователи, применяемые перед моделью, могут изменить размерность обучающего набора (подумайте об уменьшении размерности и тому подобном), поэтому жесткое кодирование ввода и вывода в конструкторе нейронной сети просто не подойдет.

Я неправильно понял, как это происходит? должно работать или иначе, что было бы предложенным решением (я думал об указании конструкторов в методе вперед, где у вас уже есть X, доступный для подгонки, но я не уверен, что это хорошая практика )?

Подробнее здесь: https://stackoverflow.com/questions/600 ... ch-wrapper
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»