Очень медленные квантовые нейронные сетиPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Очень медленные квантовые нейронные сети

Сообщение Anonymous »

следующий код очень медленный...
Как я могу его ускорить?
Ниже вы увидите весь код, который я написал, и то, как я использовал JAX и ПенниЛейн. Вы увидите схему, которую я создал, и модель, которую я построил, и я считаю, что они верны, но любые предложения будут очень признательны. Код работает, но он слишком медленный, и я не могу понять, почему он не может завершить даже первую эпоху обучения из-за своей медлительности.

Код: Выделить всё

import pennylane as qml
from pennylane import numpy as np
import jax
import jax.numpy as jnp
from jax import random
from tensorflow import keras
import tensorflow as tf
from jax.lib import xla_bridge

# Check if JAX is using the GPU
print(xla_bridge.get_backend().platform)  # It should return 'gpu' if CUDA is correctly configured.

# Parameters
n_qubits = 4    # Number of qubits (equivalent to the number of pixels for a 3x3 kernel)
n_layers = 2    # Number of layers in the circuit
n_train = 100    # Number of training samples
n_test = 50      # Number of test samples

# Load and preprocess MNIST data
mnist_dataset = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist_dataset.load_data()

# Add an extra dimension for convolution channels using jax numpy
train_images = jnp.array(train_images[..., None])
test_images = jnp.array(test_images[..., None])

# Reduce dataset size for testing
train_images = train_images[:n_train]
train_labels = train_labels[:n_train]
test_images = test_images[:n_test]
test_labels = test_labels[:n_test]

# Normalize pixels between 0 and 1
train_images = train_images / 255.0
test_images = test_images / 255.0

print(train_images.shape)
print(train_labels.shape)

# Quantum device
dev = qml.device("default.qubit", wires=n_qubits)

# Quantum circuit
@jax.jit
@qml.qnode(dev, interface="jax")
def circuit(inputs, weights):
# Data embedding
qml.AngleEmbedding(inputs, wires=range(n_qubits), rotation="Y")
# Entanglement layers
qml.BasicEntanglerLayers(weights, wires=range(n_qubits))
# PauliZ operator measurements
return [qml.expval(qml.PauliZ(wires=i)) for i in range(n_qubits)]

# Class for quantum convolutional layer
class QuantumConv2D:
def __init__(self, kernel_size, stride, n_qubits, n_layers):
self.kernel_size = kernel_size
self.stride = stride
self.n_qubits = n_qubits
self.n_layers = n_layers
self.weight_shapes = {"weights": (n_layers, n_qubits)}
#print("Initialized QuantumConv2D")

def __call__(self, X, weights):
#print("Starting QuantumConv2D call")
batch_size, height, width, channels = X.shape
output_height = (height - self.kernel_size) // self.stride + 1
output_width = (width - self.kernel_size) // self.stride + 1

outputs = jnp.empty((batch_size, output_height, output_width, self.n_qubits), dtype=jnp.float32)

# Iterate over image blocks
for i in range(0, height - self.kernel_size + 1, self.stride):
#print(f"Processing row: {i}")
for j in range(0, width - self.kernel_size + 1, self.stride):
#print(f"Processing column:  {j}")
patch = X[:, i:i+self.kernel_size, j:j+self.kernel_size, :]
patch = patch.reshape((batch_size, -1))

# Calculate output for each example in the batch
for b in range(batch_size):
#print(f"Calling circuit batch {b}, row {i}, column {j}")
outputs = outputs.at[b, i // self.stride, j // self.stride, :].set(
circuit(patch[b], weights)
)

#print("Ending QuantumConv2D call")
return outputs

# Class for the full model
class QuantumCNN:
def __init__(self, input_shape, n_qubits, n_layers, kernel_size=2, stride=1, output_size=10):
#print("Initializing QuantumCNN")
self.qconv = QuantumConv2D(kernel_size, stride, n_qubits, n_layers)
self.n_qubits = n_qubits
self.output_size = output_size
height, width = input_shape[1:3]
output_height = (height - kernel_size) // stride + 1
output_width = (width - kernel_size) // stride + 1
self.flatten_size = output_height * output_width * n_qubits
self.fc_weights = random.normal(random.PRNGKey(42), (self.flatten_size, output_size))

def forward(self, X, q_weights):
#print("Starting forward QuantumCNN")
qconv_output = self.qconv(X, q_weights)
#print("QuantumConv2D output calculated")
qconv_output_flattened = qconv_output.reshape(X.shape[0], -1)
logits = jnp.dot(qconv_output_flattened, self.fc_weights)
#print("Logits calculated")
return logits

# Loss function
#@jax.jit
def cross_entropy_loss(logits, labels):
# Convert labels to one-hot
one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
return -jnp.mean(jnp.sum(one_hot_labels * jax.nn.log_softmax(logits), axis=-1))

# Parameter update function
@jax.jit
def update(params, X, y, learning_rate):
# Compute the gradient of the loss with respect to the parameters
#print("Measuring Loss")
grads = jax.grad(loss_fn)(params, X, y)
# Update parameters with SGD
return jax.tree_map(lambda p, g: p - learning_rate * g, params, grads)

# Generic loss function
@jax.jit
def loss_fn(params, X, y):
q_weights, fc_weights = params
input_shape = X.shape
model = QuantumCNN(input_shape=input_shape, n_qubits=n_qubits, n_layers=n_layers)
logits = model.forward(X, q_weights)
return cross_entropy_loss(logits, y)

# Initialization parameters
q_weights = random.normal(random.PRNGKey(0), (n_layers, n_qubits))
params = (q_weights, random.normal(random.PRNGKey(42), (n_qubits * 24 * 24, 10)))

# Training
learning_rate = 0.01
num_epochs = 1
batch_size = 10

@jax.jit
def train_step(params, X_batch, y_batch, learning_rate):
# Compute gradients of the loss function
grads = jax.grad(loss_fn)(params, X_batch, y_batch)

# Print gradients to monitor them (optional, can be very detailed)
print("Gradients:", grads)

# Update parameters
params = jax.tree_map(lambda p, g: p - learning_rate * g, params, grads)

# Print updated parameters (optional)
print("Updated Parameters:", params)

return params

# Run training with optimized step
for epoch in range(num_epochs):
print(f"Epoch {epoch + 1}/{num_epochs}")
for i in range(0, len(train_images), batch_size):
X_batch = train_images[i:i+batch_size]
y_batch = train_labels[i:i+batch_size]

# Calculate updated parameters
params = train_step(params, X_batch, y_batch, learning_rate)

# You can also print the loss after each batch (if loss_fn is well defined)
if i % 100 == 0:  # Print every 100 batches
loss_value = loss_fn(params, X_batch, y_batch)
print(f"Batch {i}/{len(train_images)} - Loss:  {loss_value}")

# print("Starting training")
# for epoch in range(num_epochs):
#     print(f"Epoch {epoch+1}/{num_epochs}")
#     for i in range(0, len(train_images), batch_size):
#         #print(f"Processing batch {i // batch_size + 1}")
#         X_batch = train_images[i:i+batch_size]
#         y_batch = train_labels[i:i+batch_size]

#         params = update(params, X_batch, y_batch, learning_rate)

#         batch_loss = loss_fn(params, X_batch, y_batch)
#         print(f"Batch {i // batch_size + 1} Loss: {batch_loss:.4f}")

#     train_loss = loss_fn(params, train_images, train_labels)
#     print(f"Epoch {epoch+1} Loss: {train_loss:.4f}")

# Testing
q_weights, fc_weights = params

# Initialize the model with input shape
input_shape = train_images.shape  # Assume train_images has shape (batch_size, height, width, channels)
print(f"Input shape: {input_shape}")
model = QuantumCNN(input_shape=input_shape, n_qubits=n_qubits, n_layers=n_layers)

# Calculate logits for test examples
logits = model.forward(test_images, q_weights)

predictions = jnp.argmax(logits, axis=1)

accuracy = jnp.mean(predictions == test_labels)
print(f"Test Accuracy: {accuracy * 100:.2f}%")

# Show the correct and predicted result for the first N examples
N = 10  # Number of examples to display
for i in range(N):
print(f"Example {i+1}: True Label: {test_labels[i]}, Predicted: {predictions[i]}")

Мне хотелось бы понять причины низкой производительности и высоких вычислительных затрат, с которыми я сталкиваюсь, несмотря на использование графических процессоров мощных серверов для запуска кода. Кажется, что процесс занимает неожиданно много времени, и я ищу информацию о потенциальной неэффективности или узких местах. Я был бы очень признателен за любые рекомендации или стратегии по оптимизации кода или внесению изменений для повышения скорости и снижения вычислительной нагрузки на протяжении всего процесса.


Подробнее здесь: https://stackoverflow.com/questions/792 ... l-networks
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»