При использовании JAX для обучения модели машинного обучения мы лишь пытаемся минимизировать потери при обучении.
Однако мне, чтобы оценить необходимое количество эпох и избежать переобучения, нужно также знать потери на тестовой выборке на каждом шаге обновления параметров. Но описание средств обратного вызова и отладки, доступных в JAX, явно указывает, что в них не следует выполнять задачи, требующие больших вычислительных ресурсов, например определение потерь и точности на тестовой выборке.
Приведенная ниже оптимизация работает безупречно:
import pennylane as qml
from pennylane import numpy as np
import jax
from jax import numpy as jnp
import optax
from itertools import combinations
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import log_loss
import matplotlib.pyplot as plt
import matplotlib.colors
import warnings
warnings.filterwarnings("ignore")
np.random.seed(42)
import time
# Fetch the scikit-learn digits data and keep only the samples labelled 2 or 6.
X_digits, y_digits = load_digits(return_X_y=True)
filter_mask = np.isin(y_digits, [2, 6])
X_digits, y_digits = X_digits[filter_mask], y_digits[filter_mask]

# Hold out 10% of the filtered samples for testing (fixed random_state).
X_train, X_test, y_train, y_test = train_test_split(
    X_digits,
    y_digits,
    test_size=0.1,
    random_state=42,
)


def _pixels_to_angles(flat_image):
    """Reshape one flat image to 8x8 and rescale its values by 2*pi/16."""
    return flat_image.reshape([8, 8]) / 16 * 2 * np.pi


# Turn every image into an 8x8 grid of rotation angles.
X_train = np.array([_pixels_to_angles(image) for image in X_train])
X_test = np.array([_pixels_to_angles(image) for image in X_test])

# Re-centre the labels: 2 -> -1 and 6 -> +1.
y_train = (y_train - 4) / 2
y_test = (y_test - 4) / 2
def feature_map(features):
    """Encode a 2-D block of angles into the circuit.

    A Hadamard gate is applied on one wire per column of ``features``.
    Each row is then loaded with ``qml.AngleEmbedding`` on wires 0-7,
    using X rotations for even row indices and Z rotations for odd ones.
    """
    # Start from an equal superposition on one wire per feature column.
    n_columns = len(features[0])
    for wire in range(n_columns):
        qml.Hadamard(wire)

    # Alternate the rotation axis row by row: even rows -> X, odd rows -> Z.
    n_rows = len(features)
    for row in range(n_rows):
        axis = "Z" if row % 2 else "X"
        qml.AngleEmbedding(features=features[row], wires=range(8), rotation=axis)
# Trainable part of the circuit (written for 8 wires and 16 parameters).
def ansatz(params):
    """Apply two RY layers, each followed by a cyclic ring of CNOT gates.

    ``params[0:8]`` drive the first RY layer and ``params[8:16]`` the second.
    The first CNOT ring runs over targets 0..7, the second over targets 7..0;
    in both rings the control is the wire just below the target (mod 8).
    """
    wires = range(8)

    # First rotation layer.
    for w in wires:
        qml.RY(params[w], wires=w)

    # Entangle neighbours in ascending target order: (7,0), (0,1), ..., (6,7).
    for target in wires:
        qml.CNOT(wires=[(target - 1) % 8, target])

    # Second rotation layer.
    for w in wires:
        qml.RY(params[w + 8], wires=w)

    # Entangle neighbours in descending target order: (6,7), (5,6), ..., (7,0).
    for target in reversed(wires):
        qml.CNOT(wires=[(target - 1) % 8, target])
# 8-wire "default.qubit" device on which the QNode below is executed.
dev = qml.device("default.qubit", wires=8)


@qml.qnode(dev)
def circuit(params, features):
    """Run the feature map and the ansatz, then measure <Z> on wire 0."""
    feature_map(features)
    ansatz(params)
    observable = qml.PauliZ(0)
    return qml.expval(observable)
def variational_classifier(weights, bias, x):
    """Return the circuit output for sample ``x`` shifted by ``bias``."""
    expectation = circuit(weights, x)
    return expectation + bias
def square_loss(labels, predictions):
    """Return the mean squared difference between labels and predictions.

    ``predictions`` is stacked with ``qml.math.stack`` before the subtraction.
    """
    stacked = qml.math.stack(predictions)
    residuals = labels - stacked
    return np.mean(residuals ** 2)
def accuracy(labels, predictions):
    """Return the fraction of predictions whose sign matches the label's sign.

    The computation uses ``jax.numpy`` so that it also works on traced values,
    e.g. inside ``jax.jit`` or a ``jax.lax.cond`` branch. Calling NumPy's
    ``sign`` on a JAX tracer forces an ``__array__()`` conversion and raises
    ``TracerArrayConversionError``.

    Args:
        labels: Sequence or array of target values.
        predictions: Sequence or array of model outputs, one per label.

    Returns:
        A scalar JAX array holding the share of matching signs.
    """
    label_signs = jnp.sign(jnp.asarray(labels))
    prediction_signs = jnp.sign(jnp.asarray(predictions))
    # The mean of the boolean match mask is the fraction of correct signs.
    return jnp.mean(label_signs == prediction_signs)
def cost(params, X, Y):
    """Return the square loss of the classifier on samples ``X`` with targets ``Y``.

    ``params`` is a dict with the entries ``"weights"`` and ``"bias"``.
    """
    weights, bias = params["weights"], params["bias"]
    outputs = [variational_classifier(weights, bias, sample) for sample in X]
    return square_loss(Y, outputs)
def acc(params, X, Y):
    """Return the sign accuracy of the classifier on samples ``X`` with targets ``Y``.

    ``params`` is a dict with the entries ``"weights"`` and ``"bias"``.
    """
    weights, bias = params["weights"], params["bias"]
    outputs = [variational_classifier(weights, bias, sample) for sample in X]
    return accuracy(Y, outputs)
# Initial parameters: 16 small random circuit weights and a zero bias.
np.random.seed(0)
weights = 0.01 * np.random.randn(16)
bias = jnp.array(0.0)
params = {"weights": weights, "bias": bias}

# Adam optimizer with learning rate 0.05.
opt = optax.adam(0.05)
batch_size = 7
num_batch = X_train.shape[0] // batch_size
opt_state = opt.init(params)

# Group the training data into (num_batch, batch_size, ...) blocks.
# Only the first num_batch * batch_size samples are used, so the reshape also
# works when the training set size is not a multiple of batch_size; when it
# is a multiple, every sample is kept.
X_batched = X_train[: num_batch * batch_size].reshape([num_batch, batch_size, 8, 8])
y_batched = y_train[: num_batch * batch_size].reshape([num_batch, batch_size])
@jax.jit
def update_step_jit(i, args):
    """Run one optimizer step on a single mini-batch.

    Shaped as a ``jax.lax.fori_loop`` body: ``i`` is the loop counter and
    ``args`` is the carried tuple ``(params, opt_state, data, targets,
    batch_no, print_training)``. The same layout is returned, with updated
    ``params`` and ``opt_state`` and ``batch_no`` incremented by one.
    """
    params, opt_state, data, targets, batch_no, print_training = args

    # Cycle through the batches; the index wraps around after num_batch steps.
    batch_idx = batch_no % num_batch
    batch_data = data[batch_idx]
    batch_targets = targets[batch_idx]

    # Loss and gradients on this batch, followed by the optimizer update.
    loss_val, grads = jax.value_and_grad(cost)(params, batch_data, batch_targets)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    def _report():
        jax.debug.print("Step: {i}, Train Loss: {loss_val}", i=i, loss_val=loss_val)

    def _skip():
        return None

    # jnp.mod(i, 1) is always 0, so the training loss is reported on every
    # step whenever print_training is set.
    should_report = (jnp.mod(i, 1) == 0) & print_training
    jax.lax.cond(should_report, _report, _skip)

    return (params, opt_state, data, targets, batch_no + 1, print_training)
@jax.jit
def optimization_jit(params, data, targets, print_training=True):
    """Run 10 update steps from a fresh optimizer state and return the parameters.

    ``data`` and ``targets`` are the batched training inputs and labels that
    ``update_step_jit`` indexes by batch number. ``print_training`` is passed
    through to the step function, which uses it to gate its loss report.
    """
    # Loop carry expected by update_step_jit; the batch counter starts at 0.
    initial_carry = (params, opt.init(params), data, targets, 0, print_training)
    final_carry = jax.lax.fori_loop(0, 10, update_step_jit, initial_carry)
    # Only the trained parameters (first carry entry) are returned.
    return final_carry[0]
# Train and time the run. JAX dispatches work asynchronously, so wait for the
# returned parameters to be computed before stopping the clock; otherwise the
# reported time may not include the actual training computation.
start_time = time.time()
params = jax.block_until_ready(optimization_jit(params, X_batched, y_batched))
print("Training Done! \nTime taken:",time.time() - start_time)

# Evaluate the trained parameters on the training and the held-out test set.
var_train_acc = acc(params, X_train, y_train)
print("Training accuracy: ", var_train_acc)
var_test_acc = acc(params, X_test, y_test)
print("Testing accuracy: ", var_test_acc)
Хотя это и неэффективно, я попытался включить в функцию print_fn() дополнительные вычисления (точность и потери на обучающей и тестовой выборках), но это приводит к следующей ошибке:
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[]
The error occurred while tracing the function print_fn at C:\Users\...\AppData\Local\Temp\ipykernel_43468\2623796165.py:10 for cond. This value became a tracer due to JAX operations on these lines:
operation a:f32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] b
from line C:\Users\...\AppData\Local\Temp\ipykernel_43468\1231864253.py:61:15 (ansatz)
operation a:f32[1] = slice[limit_indices=(2,) start_indices=(1,) strides=None] b
from line C:\Users\...\AppData\Local\Temp\ipykernel_43468\1231864253.py:61:15 (ansatz)
operation a:f32[1] = slice[limit_indices=(3,) start_indices=(2,) strides=None] b
from line C:\Users\mysore\AppData\Local\Temp\ipykernel_43468\1231864253.py:61:15 (ansatz)
operation a:f32[1] = slice[limit_indices=(4,) start_indices=(3,) strides=None] b
from line C:\Users\...\AppData\Local\Temp\ipykernel_43468\1231864253.py:61:15 (ansatz)
operation a:f32[1] = slice[limit_indices=(5,) start_indices=(4,) strides=None] b
from line C:\Users\...\AppData\Local\Temp\ipykernel_43468\1231864253.py:61:15 (ansatz)
(Additional originating lines are not shown.)
Мое время обучения резко увеличится, если я не буду использовать JAX.
Получается, что с JAX я могу оценить качество на тестовой выборке только в конце обучения и никак не могу получить тестовые потери в середине обучения? Или для этого есть обходной путь?
Для минимально воспроизводимого примера вы можете попробовать запустить приведенный пример в этом демонстрационном коде.
При использовании JAX для обучения модели машинного обучения мы лишь пытаемся минимизировать потери при обучении. Однако мне, чтобы оценить необходимое количество эпох и избежать переобучения, нужно также знать потери на тестовой выборке на каждом шаге обновления параметров. Но описание средств обратного вызова и отладки, доступных в JAX, явно указывает, что в них не следует выполнять задачи, требующие больших вычислительных ресурсов, например определение потерь и точности на тестовой выборке. Приведенная ниже оптимизация работает безупречно: [code]import pennylane as qml from pennylane import numpy as np import jax from jax import numpy as jnp import optax from itertools import combinations from sklearn.datasets import load_digits from sklearn.model_selection import train_test_split from sklearn.neural_network import MLPClassifier from sklearn.metrics import log_loss import matplotlib.pyplot as plt import matplotlib.colors import warnings warnings.filterwarnings("ignore") np.random.seed(42) import time
# Load the digits dataset with features (X_digits) and labels (y_digits) X_digits, y_digits = load_digits(return_X_y=True)
# Create a boolean mask to filter out only the samples where the label is 2 or 6 filter_mask = np.isin(y_digits, [2, 6])
# Apply the filter mask to the features and labels to keep only the selected digits X_digits = X_digits[filter_mask] y_digits = y_digits[filter_mask]
# Split the filtered dataset into training and testing sets with 10% of data reserved for testing X_train, X_test, y_train, y_test = train_test_split( X_digits, y_digits, test_size=0.1, random_state=42 )
# Normalize the pixel values in the training and testing data # Convert each image from a 1D array to an 8x8 2D array, normalize pixel values, and scale them X_train = np.array([thing.reshape([8, 8]) / 16 * 2 * np.pi for thing in X_train]) X_test = np.array([thing.reshape([8, 8]) / 16 * 2 * np.pi for thing in X_test])
# Adjust the labels to be centered around 0 and scaled to be in the range -1 to 1 # The original labels (2 and 6) are mapped to -1 and 1 respectively y_train = (y_train - 4) / 2 y_test = (y_test - 4) / 2
def feature_map(features): # Apply Hadamard gates to all qubits to create an equal superposition state for i in range(len(features[0])): qml.Hadamard(i)
# Apply angle embeddings based on the feature values for i in range(len(features)): # For odd-indexed features, use Z-rotation in the angle embedding if i % 2: qml.AngleEmbedding(features=features[i], wires=range(8), rotation="Z") # For even-indexed features, use X-rotation in the angle embedding else: qml.AngleEmbedding(features=features[i], wires=range(8), rotation="X")
# Define the ansatz (quantum circuit ansatz) for parameterized quantum operations def ansatz(params): # Apply RY rotations with the first set of parameters for i in range(8): qml.RY(params[i], wires=i)
# Apply CNOT gates with adjacent qubits (cyclically connected) to create entanglement for i in range(8): qml.CNOT(wires=[(i - 1) % 8, (i) % 8])
# Apply RY rotations with the second set of parameters for i in range(8): qml.RY(params[i + 8], wires=i)
# Apply CNOT gates with qubits in reverse order (cyclically connected) # to create additional entanglement for i in range(8): qml.CNOT(wires=[(8 - 2 - i) % 8, (8 - i - 1) % 8])
# Print training loss every 5 steps if print_training is True def print_fn(): jax.debug.print("Step: {i}, Train Loss: {loss_val}", i=i, loss_val=loss_val)
# Calculate accuracy and loss for training and test sets train_accuracy = acc(params, X_train, y_train) test_predictions = jnp.array([variational_classifier(params["weights"], params["bias"], x) for x in X_test]) test_loss = square_loss(y_test, test_predictions) test_accuracy = accuracy(y_test, test_predictions)
jax.debug.print("Step: {i}, Train Accuracy {train_accuracy}", i=i, train_accuracy = train_accuracy) jax.debug.print("Step: {i}, Test Accuracy {test_accuracy}", i=i, test_accuracy = test_accuracy) jax.debug.print("Step: {i}, Test Loss {test_loss}", i=i, test_loss = test_loss)
# if print_training=True, print the loss every 5 steps jax.lax.cond((jnp.mod(i, 1) == 0) & print_training, print_fn, lambda: None) return (params, opt_state, data, targets, batch_no + 1, print_training)
print("Training accuracy: ", var_train_acc) print("Testing accuracy: ", var_test_acc) [/code] Выдает странную ошибку типа: [code]TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[] The error occurred while tracing the function print_fn at C:\Users\...\AppData\Local\Temp\ipykernel_43468\2623796165.py:10 for cond. This value became a tracer due to JAX operations on these lines:
operation a:f32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] b from line C:\Users\...\AppData\Local\Temp\ipykernel_43468\1231864253.py:61:15 (ansatz)
operation a:f32[1] = slice[limit_indices=(2,) start_indices=(1,) strides=None] b from line C:\Users\...\AppData\Local\Temp\ipykernel_43468\1231864253.py:61:15 (ansatz)
operation a:f32[1] = slice[limit_indices=(3,) start_indices=(2,) strides=None] b from line C:\Users\mysore\AppData\Local\Temp\ipykernel_43468\1231864253.py:61:15 (ansatz)
operation a:f32[1] = slice[limit_indices=(4,) start_indices=(3,) strides=None] b from line C:\Users\...\AppData\Local\Temp\ipykernel_43468\1231864253.py:61:15 (ansatz)
operation a:f32[1] = slice[limit_indices=(5,) start_indices=(4,) strides=None] b from line C:\Users\...\AppData\Local\Temp\ipykernel_43468\1231864253.py:61:15 (ansatz)
(Additional originating lines are not shown.) [/code] Мое время обучения резко увеличится, если я не буду использовать JAX. Получается, что с JAX я могу оценить качество на тестовой выборке только в конце обучения и никак не могу получить тестовые потери в середине обучения? Или для этого есть обходной путь? Для минимально воспроизводимого примера вы можете попробовать запустить приведенный пример в этом демонстрационном коде.