следующий код очень медленный...
Как я могу его ускорить?
Ниже вы увидите весь код, который я написал, и то, как я использовал JAX и ПенниЛейн. Вы увидите схему, которую я создал, и модель, которую я построил, и я считаю, что они верны, но любые предложения будут очень признательны. Код работает, но он слишком медленный, и я не могу понять, почему он не может завершить даже первую эпоху обучения из-за своей медлительности.
import pennylane as qml
from pennylane import numpy as np
import jax
import jax.numpy as jnp
from jax import random
from tensorflow import keras
import tensorflow as tf
from jax.lib import xla_bridge
# Check if JAX is using the GPU
print(xla_bridge.get_backend().platform) # It should return 'gpu' if CUDA is correctly configured.
# Parameters
n_qubits = 4 # Number of qubits (equivalent to the number of pixels for a 3x3 kernel)
n_layers = 2 # Number of layers in the circuit
n_train = 100 # Number of training samples
n_test = 50 # Number of test samples
# Load and preprocess MNIST data
mnist_dataset = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist_dataset.load_data()
# Add an extra dimension for convolution channels using jax numpy
train_images = jnp.array(train_images[..., None])
test_images = jnp.array(test_images[..., None])
# Reduce dataset size for testing
train_images = train_images[:n_train]
train_labels = train_labels[:n_train]
test_images = test_images[:n_test]
test_labels = test_labels[:n_test]
# Normalize pixels between 0 and 1
train_images = train_images / 255.0
test_images = test_images / 255.0
print(train_images.shape)
print(train_labels.shape)
# Quantum device
dev = qml.device("default.qubit", wires=n_qubits)
# Quantum circuit
@jax.jit
@qml.qnode(dev, interface="jax")
def circuit(inputs, weights):
# Data embedding
qml.AngleEmbedding(inputs, wires=range(n_qubits), rotation="Y")
# Entanglement layers
qml.BasicEntanglerLayers(weights, wires=range(n_qubits))
# PauliZ operator measurements
return [qml.expval(qml.PauliZ(wires=i)) for i in range(n_qubits)]
# Class for quantum convolutional layer
class QuantumConv2D:
def __init__(self, kernel_size, stride, n_qubits, n_layers):
self.kernel_size = kernel_size
self.stride = stride
self.n_qubits = n_qubits
self.n_layers = n_layers
self.weight_shapes = {"weights": (n_layers, n_qubits)}
#print("Initialized QuantumConv2D")
def __call__(self, X, weights):
#print("Starting QuantumConv2D call")
batch_size, height, width, channels = X.shape
output_height = (height - self.kernel_size) // self.stride + 1
output_width = (width - self.kernel_size) // self.stride + 1
outputs = jnp.empty((batch_size, output_height, output_width, self.n_qubits), dtype=jnp.float32)
# Iterate over image blocks
for i in range(0, height - self.kernel_size + 1, self.stride):
#print(f"Processing row: {i}")
for j in range(0, width - self.kernel_size + 1, self.stride):
#print(f"Processing column: {j}")
patch = X[:, i:i+self.kernel_size, j:j+self.kernel_size, :]
patch = patch.reshape((batch_size, -1))
# Calculate output for each example in the batch
for b in range(batch_size):
#print(f"Calling circuit batch {b}, row {i}, column {j}")
outputs = outputs.at[b, i // self.stride, j // self.stride, :].set(
circuit(patch[b], weights)
)
#print("Ending QuantumConv2D call")
return outputs
# Class for the full model
class QuantumCNN:
def __init__(self, input_shape, n_qubits, n_layers, kernel_size=2, stride=1, output_size=10):
#print("Initializing QuantumCNN")
self.qconv = QuantumConv2D(kernel_size, stride, n_qubits, n_layers)
self.n_qubits = n_qubits
self.output_size = output_size
height, width = input_shape[1:3]
output_height = (height - kernel_size) // stride + 1
output_width = (width - kernel_size) // stride + 1
self.flatten_size = output_height * output_width * n_qubits
self.fc_weights = random.normal(random.PRNGKey(42), (self.flatten_size, output_size))
def forward(self, X, q_weights):
#print("Starting forward QuantumCNN")
qconv_output = self.qconv(X, q_weights)
#print("QuantumConv2D output calculated")
qconv_output_flattened = qconv_output.reshape(X.shape[0], -1)
logits = jnp.dot(qconv_output_flattened, self.fc_weights)
#print("Logits calculated")
return logits
# Loss function
#@jax.jit
def cross_entropy_loss(logits, labels):
# Convert labels to one-hot
one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
return -jnp.mean(jnp.sum(one_hot_labels * jax.nn.log_softmax(logits), axis=-1))
# Parameter update function
@jax.jit
def update(params, X, y, learning_rate):
# Compute the gradient of the loss with respect to the parameters
#print("Measuring Loss")
grads = jax.grad(loss_fn)(params, X, y)
# Update parameters with SGD
return jax.tree_map(lambda p, g: p - learning_rate * g, params, grads)
# Generic loss function
@jax.jit
def loss_fn(params, X, y):
q_weights, fc_weights = params
input_shape = X.shape
model = QuantumCNN(input_shape=input_shape, n_qubits=n_qubits, n_layers=n_layers)
logits = model.forward(X, q_weights)
return cross_entropy_loss(logits, y)
# Initialization parameters
q_weights = random.normal(random.PRNGKey(0), (n_layers, n_qubits))
params = (q_weights, random.normal(random.PRNGKey(42), (n_qubits * 24 * 24, 10)))
# Training
learning_rate = 0.01
num_epochs = 1
batch_size = 10
@jax.jit
def train_step(params, X_batch, y_batch, learning_rate):
# Compute gradients of the loss function
grads = jax.grad(loss_fn)(params, X_batch, y_batch)
# Print gradients to monitor them (optional, can be very detailed)
print("Gradients:", grads)
# Update parameters
params = jax.tree_map(lambda p, g: p - learning_rate * g, params, grads)
# Print updated parameters (optional)
print("Updated Parameters:", params)
return params
# Run training with optimized step
for epoch in range(num_epochs):
print(f"Epoch {epoch + 1}/{num_epochs}")
for i in range(0, len(train_images), batch_size):
X_batch = train_images[i:i+batch_size]
y_batch = train_labels[i:i+batch_size]
# Calculate updated parameters
params = train_step(params, X_batch, y_batch, learning_rate)
# You can also print the loss after each batch (if loss_fn is well defined)
if i % 100 == 0: # Print every 100 batches
loss_value = loss_fn(params, X_batch, y_batch)
print(f"Batch {i}/{len(train_images)} - Loss: {loss_value}")
# print("Starting training")
# for epoch in range(num_epochs):
# print(f"Epoch {epoch+1}/{num_epochs}")
# for i in range(0, len(train_images), batch_size):
# #print(f"Processing batch {i // batch_size + 1}")
# X_batch = train_images[i:i+batch_size]
# y_batch = train_labels[i:i+batch_size]
# params = update(params, X_batch, y_batch, learning_rate)
# batch_loss = loss_fn(params, X_batch, y_batch)
# print(f"Batch {i // batch_size + 1} Loss: {batch_loss:.4f}")
# train_loss = loss_fn(params, train_images, train_labels)
# print(f"Epoch {epoch+1} Loss: {train_loss:.4f}")
# Testing
q_weights, fc_weights = params
# Initialize the model with input shape
input_shape = train_images.shape # Assume train_images has shape (batch_size, height, width, channels)
print(f"Input shape: {input_shape}")
model = QuantumCNN(input_shape=input_shape, n_qubits=n_qubits, n_layers=n_layers)
# Calculate logits for test examples
logits = model.forward(test_images, q_weights)
predictions = jnp.argmax(logits, axis=1)
accuracy = jnp.mean(predictions == test_labels)
print(f"Test Accuracy: {accuracy * 100:.2f}%")
# Show the correct and predicted result for the first N examples
N = 10 # Number of examples to display
for i in range(N):
print(f"Example {i+1}: True Label: {test_labels[i]}, Predicted: {predictions[i]}")
Мне хотелось бы понять причины низкой производительности и высоких вычислительных затрат, с которыми я сталкиваюсь, несмотря на использование графических процессоров мощных серверов для запуска кода. Кажется, что процесс занимает неожиданно много времени, и я ищу информацию о потенциальной неэффективности или узких местах. Я был бы очень признателен за любые рекомендации или стратегии по оптимизации кода или внесению изменений для повышения скорости и снижения вычислительной нагрузки на протяжении всего процесса.
следующий код очень медленный... Как я могу его ускорить? Ниже вы увидите весь код, который я написал, и то, как я использовал JAX и ПенниЛейн. Вы увидите схему, которую я создал, и модель, которую я построил, и я считаю, что они верны, но любые предложения будут очень признательны. Код работает, но он слишком медленный, и я не могу понять, почему он не может завершить даже первую эпоху обучения из-за своей медлительности. [code]import pennylane as qml from pennylane import numpy as np import jax import jax.numpy as jnp from jax import random from tensorflow import keras import tensorflow as tf from jax.lib import xla_bridge
# Check if JAX is using the GPU print(xla_bridge.get_backend().platform) # It should return 'gpu' if CUDA is correctly configured.
# Parameters n_qubits = 4 # Number of qubits (equivalent to the number of pixels for a 3x3 kernel) n_layers = 2 # Number of layers in the circuit n_train = 100 # Number of training samples n_test = 50 # Number of test samples
# Load and preprocess MNIST data mnist_dataset = keras.datasets.mnist (train_images, train_labels), (test_images, test_labels) = mnist_dataset.load_data()
# Add an extra dimension for convolution channels using jax numpy train_images = jnp.array(train_images[..., None]) test_images = jnp.array(test_images[..., None])
# Iterate over image blocks for i in range(0, height - self.kernel_size + 1, self.stride): #print(f"Processing row: {i}") for j in range(0, width - self.kernel_size + 1, self.stride): #print(f"Processing column: {j}") patch = X[:, i:i+self.kernel_size, j:j+self.kernel_size, :] patch = patch.reshape((batch_size, -1))
# Calculate output for each example in the batch for b in range(batch_size): #print(f"Calling circuit batch {b}, row {i}, column {j}") outputs = outputs.at[b, i // self.stride, j // self.stride, :].set( circuit(patch[b], weights) )
# Loss function #@jax.jit def cross_entropy_loss(logits, labels): # Convert labels to one-hot one_hot_labels = jax.nn.one_hot(labels, num_classes=10) return -jnp.mean(jnp.sum(one_hot_labels * jax.nn.log_softmax(logits), axis=-1))
# Parameter update function @jax.jit def update(params, X, y, learning_rate): # Compute the gradient of the loss with respect to the parameters #print("Measuring Loss") grads = jax.grad(loss_fn)(params, X, y) # Update parameters with SGD return jax.tree_map(lambda p, g: p - learning_rate * g, params, grads)
# Generic loss function @jax.jit def loss_fn(params, X, y): q_weights, fc_weights = params input_shape = X.shape model = QuantumCNN(input_shape=input_shape, n_qubits=n_qubits, n_layers=n_layers) logits = model.forward(X, q_weights) return cross_entropy_loss(logits, y)
@jax.jit def train_step(params, X_batch, y_batch, learning_rate): # Compute gradients of the loss function grads = jax.grad(loss_fn)(params, X_batch, y_batch)
# Print gradients to monitor them (optional, can be very detailed) print("Gradients:", grads)
# Run training with optimized step for epoch in range(num_epochs): print(f"Epoch {epoch + 1}/{num_epochs}") for i in range(0, len(train_images), batch_size): X_batch = train_images[i:i+batch_size] y_batch = train_labels[i:i+batch_size]
# You can also print the loss after each batch (if loss_fn is well defined) if i % 100 == 0: # Print every 100 batches loss_value = loss_fn(params, X_batch, y_batch) print(f"Batch {i}/{len(train_images)} - Loss: {loss_value}")
# print("Starting training") # for epoch in range(num_epochs): # print(f"Epoch {epoch+1}/{num_epochs}") # for i in range(0, len(train_images), batch_size): # #print(f"Processing batch {i // batch_size + 1}") # X_batch = train_images[i:i+batch_size] # y_batch = train_labels[i:i+batch_size]
# Show the correct and predicted result for the first N examples N = 10 # Number of examples to display for i in range(N): print(f"Example {i+1}: True Label: {test_labels[i]}, Predicted: {predictions[i]}")
[/code] Мне хотелось бы понять причины низкой производительности и высоких вычислительных затрат, с которыми я сталкиваюсь, несмотря на использование графических процессоров мощных серверов для запуска кода. Кажется, что процесс занимает неожиданно много времени, и я ищу информацию о потенциальной неэффективности или узких местах. Я был бы очень признателен за любые рекомендации или стратегии по оптимизации кода или внесению изменений для повышения скорости и снижения вычислительной нагрузки на протяжении всего процесса.
Я использую конвейер в scikit-learn, чтобы объединить масштабирование функций с классификатором. Это хорошо работает для логистической регрессии, но мне любопытно, будет ли этот подход эффективно обобщаться на более сложные модели, такие как...
Я использую конвейер в scikit-learn, чтобы объединить масштабирование функций с классификатором. Это хорошо работает для логистической регрессии, но мне любопытно, будет ли этот подход эффективно обобщаться на более сложные модели, такие как...
Я пытаюсь запустить keras-tutorial Вероятностные байесовские нейронные сети, чтобы получить представление о байесовских нейронных сетях (BNN). Учебное пособие содержит блокнот Google-Colab, поэтому вы можете запустить его прямо в браузере. Однако...
Я пытаюсь запустить вероятностные байесовские нейронные сети Керас, чтобы получить понимание байесовских нейронных сетей (BNN). Учебное пособие содержит ноутбук Google-Colab, поэтому вы можете запустить его непосредственно в браузере. Однако, когда...
Это коды, которые я использовал через Python Чтобы попробовать нейронные сети, чтобы найти прогнозируемый доход на основе этих предикторов. Когда я попытался установить модель, я получил следующую ошибку:
valueError : Исключение встречается при...