Я' Я пытаюсь оценить параметр w в функции x -> sin(w*x)^2, чтобы классифицировать целые числа как четные или нечетные, используя pytorch в качестве самостоятельного упражнения. Конечно, существует несколько возможных значений правильного w, включая w = pi/2. Я инициализировал свою сеть (линейную без смещения, за которой следовала активация греха, а затем возводила в квадрат) с w = 1,5, близким к pi/2, в надежде, что она сходится к pi/1 = 1,507... но модель не обучается, независимо от того, как я настраиваю скорость обучения или какой оптимизатор использую.
class Net(torch.nn.Module):
def __init__(self):
super().__init__()
# linear without bias
self.fc1 = torch.nn.Linear(1, 1, bias=False)
# Initialize close to the theoretical solution
with torch.no_grad():
self.fc1.weight.data.fill_(1.5) # Close to π/2 ≈ 1.57
def forward(self, x):
x = self.fc1(x)
return torch.sin(x)**2
net = Net()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.0004)
weights = []
for epoch in range(100):
net.train()
optimizer.zero_grad()
output = net(train_x)
loss = criterion(output, train_y)
loss.backward()
optimizer.step()
weights.append(w)
На графике веса видно, что нет тенденции сходиться к какой-либо точке.
Хотелось бы верить, что я избежал общепринятых подводные камни: я позволил входным и выходным значениям float 32, целевая функция прекрасно изучается с помощью модели, я также пробовал это с другими функциями потерь, но не получилось.
Пожалуйста, помогите мне найти, где я допустил ошибку, вот полный код (экспортированный из блокнота Jupyter):
# %%
import torch
import numpy as np
import pandas as pd
# %%
# Generate data and scale inputs
def generate_data(size):
x = np.random.randint(0, size, size) # Smaller range for better visualization
return x.astype(float), (x % 2).astype(float)
# %%
# Generate datasets
train_x, train_y = generate_data(1000)
val_x, val_y = generate_data(1000)
# Convert to tensors
train_x = torch.tensor(train_x, dtype=torch.float32).reshape(-1, 1)
train_y = torch.tensor(train_y, dtype=torch.float32).reshape(-1, 1)
val_x = torch.tensor(val_x, dtype=torch.float32).reshape(-1, 1)
val_y = torch.tensor(val_y, dtype=torch.float32).reshape(-1, 1)
# %%
class Net(torch.nn.Module):
def __init__(self):
super().__init__()
# linear without bias
self.fc1 = torch.nn.Linear(1, 1, bias=False)
# Initialize close to the theoretical solution
with torch.no_grad():
self.fc1.weight.data.fill_(1.5) # Close to π/2 ≈ 1.57
def forward(self, x):
x = self.fc1(x)
return torch.sin(x)**2
# %%
net = Net()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.0004)
# %%
weights = []
for epoch in range(100):
net.train()
optimizer.zero_grad()
output = net(train_x)
loss = criterion(output, train_y)
loss.backward()
optimizer.step()
net.eval()
with torch.no_grad():
val_output = net(val_x)
val_loss = criterion(val_output, val_y)
if epoch % 1 == 0:
print(f"Epoch {epoch}")
print(f"Loss: {loss.item():.8f} Val Loss: {val_loss.item():.8f}")
w = net.fc1.weight.item()
print(f"Weight: {w:.8f} (target: {np.pi/2:.8f})")
print("---")
weights.append(w)
# %%
# plot weights
import matplotlib.pyplot as plt
plt.plot(weights)
plt.plot([np.pi/2]*len(weights))
# %%
# Test the model
w = net.fc1.weight.item()
print("\nFinal parameters:")
print(f"Weight: {w:.8f} (target: {np.pi/2:.8f})")
# Test on even and odd numbers
test_numbers = np.arange(0, 1500, 1)
net.eval()
with torch.no_grad():
for x in test_numbers:
test_input = torch.tensor([[float(x)]], dtype=torch.float32)
pred = net(test_input).item()
print("
if (x+1) % 60 == 0:
print()
Подробнее здесь: https://stackoverflow.com/questions/791 ... classifier