Отслеживание потерь теста/значения при обучении модели с помощью JAXPython

Программы на Python
Ответить
Anonymous
 Отслеживание потерь теста/значения при обучении модели с помощью JAX

Сообщение Anonymous »

При использовании JAX для обучения модели машинного обучения мы лишь пытаемся минимизировать потери при обучении.
В то время как в моем требовании, чтобы оценить количество эпох или избежать Из-за перетренированности мне также нужно знать потери при тестировании на каждом этапе обновления параметров. Но опция обратного вызова или отладки, доступная в JAX, явно предполагает, что мне не следует выполнять какие-либо задачи, требующие больших вычислительных ресурсов, например определение потерь и точности тестов.
Приведенная ниже оптимизация работает безупречно

Код: Выделить всё

import pennylane as qml
from pennylane import numpy as np
import jax
from jax import numpy as jnp
import optax
from itertools import combinations
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import log_loss
import matplotlib.pyplot as plt
import matplotlib.colors
import warnings
warnings.filterwarnings("ignore")
np.random.seed(42)

# Load the digits dataset with features (X_digits) and labels (y_digits)
X_digits, y_digits = load_digits(return_X_y=True)

# Create a boolean mask to filter out only the samples where the label is 2 or 6
filter_mask = np.isin(y_digits, [2, 6])

# Apply the filter mask to the features and labels to keep only the selected digits
X_digits = X_digits[filter_mask]
y_digits = y_digits[filter_mask]

# Split the filtered dataset into training and testing sets with 10% of data reserved for testing
X_train, X_test, y_train, y_test = train_test_split(
X_digits, y_digits, test_size=0.1, random_state=42
)

# Normalize the pixel values in the training and testing data
# Convert each image from a 1D array to an 8x8 2D array, normalize pixel values, and scale them
X_train = np.array([thing.reshape([8, 8]) / 16 * 2 * np.pi for thing in X_train])
X_test = np.array([thing.reshape([8, 8]) / 16 * 2 * np.pi for thing in X_test])

# Adjust the labels to be centered around 0 and scaled to be in the range -1 to 1
# The original labels (2 and 6) are mapped to -1 and 1 respectively
y_train = (y_train - 4) / 2
y_test = (y_test - 4) / 2

def feature_map(features):
# Apply Hadamard gates to all qubits to create an equal superposition state
for i in range(len(features[0])):
qml.Hadamard(i)

# Apply angle embeddings based on the feature values
for i in range(len(features)):
# For odd-indexed features, use Z-rotation in the angle embedding
if i % 2:
qml.AngleEmbedding(features=features[i], wires=range(8), rotation="Z")
# For even-indexed features, use X-rotation in the angle embedding
else:
qml.AngleEmbedding(features=features[i], wires=range(8), rotation="X")

# Define the ansatz (quantum circuit ansatz) for parameterized quantum operations
def ansatz(params):
# Apply RY rotations with the first set of parameters
for i in range(8):
qml.RY(params[i], wires=i)

# Apply CNOT gates with adjacent qubits (cyclically connected) to create entanglement
for i in range(8):
qml.CNOT(wires=[(i - 1) % 8, (i) % 8])

# Apply RY rotations with the second set of parameters
for i in range(8):
qml.RY(params[i + 8], wires=i)

# Apply CNOT gates with qubits in reverse order (cyclically connected)
# to create additional entanglement
for i in range(8):
qml.CNOT(wires=[(8 - 2 - i) % 8, (8 - i - 1) % 8])

dev = qml.device("default.qubit", wires=8)

@qml.qnode(dev)
def circuit(params, features):
feature_map(features)
ansatz(params)
return qml.expval(qml.PauliZ(0))

def variational_classifier(weights, bias, x):
return circuit(weights, x) + bias

def square_loss(labels, predictions):
return np.mean((labels - qml.math.stack(predictions)) ** 2)

def accuracy(labels, predictions):
acc = sum([np.sign(l) == np.sign(p) for l, p in zip(labels, predictions)])
acc = acc / len(labels)
return acc

def cost(params, X, Y):
predictions = [variational_classifier(params["weights"], params["bias"], x) for x in X]
return square_loss(Y, predictions)

def acc(params, X, Y):
predictions = [variational_classifier(params["weights"], params["bias"], x) for x in X]
return accuracy(Y, predictions)

np.random.seed(0)
weights = 0.01 * np.random.randn(16)
bias = jnp.array(0.0)
params = {"weights": weights, "bias":  bias}
opt = optax.adam(0.05)
batch_size = 7
num_batch = X_train.shape[0] // batch_size
opt_state = opt.init(params)
X_batched = X_train.reshape([-1, batch_size, 8, 8])
y_batched = y_train.reshape([-1, batch_size])

@jax.jit
def update_step_jit(i, args):
params, opt_state, data, targets, batch_no = args
_data = data[batch_no % num_batch]
_targets = targets[batch_no % num_batch]
_, grads = jax.value_and_grad(cost)(params, _data, _targets)
updates, opt_state = opt.update(grads, opt_state)
params = optax.apply_updates(params, updates)
return (params, opt_state, data, targets, batch_no + 1)

@jax.jit
def optimization_jit(params, data, targets):
opt_state = opt.init(params)
args = (params, opt_state, data, targets, 0)
(params, opt_state, _, _, _) = jax.lax.fori_loop(0, 200, update_step_jit, args)
return params

params = optimization_jit(params, X_batched, y_batched, print_training = True)
Хотя это и неэффективно, я попытался включить больше вычислений в функцию print_fn() следующим образом:

Код: Выделить всё

@jax.jit
def update_step_jit(i, args):
params, opt_state, data, targets, batch_no, print_training = args
_data = data[batch_no % num_batch]
_targets = targets[batch_no % num_batch]
loss_val, grads = jax.value_and_grad(cost)(params, _data, _targets)
updates, opt_state = opt.update(grads, opt_state)
params = optax.apply_updates(params, updates)

def print_fn():
jax.debug.print("Step: {i}, Train Loss: {loss_val}", i=i, loss_val=loss_val)

# Calculate accuracy and loss for training and test sets
train_accuracy = acc(params, X_train_ndarray, y_train_ndarray)
test_predictions = [variational_classifier(params["weights"], params["bias"], x) for x in X_test_ndarray]
test_loss = square_loss(y_test_ndarray, test_predictions)
test_accuracy = accuracy(y_test_ndarray, test_predictions)

jax.debug.print("Step: {i}, Train Accuracy {train_accuracy}", i=i, train_accuracy = train_accuracy)
jax.debug.print("Step: {i}, Test Accuracy {test_accuracy}", i=i, test_accuracy = test_accuracy)
jax.debug.print("Step: {i}, Test Loss {test_loss}", i=i, test_loss = test_loss)

# if print_training=True, print the loss every 5 steps
jax.lax.cond((jnp.mod(i, 1) == 0) & print_training, print_fn, lambda: None)
return (params, opt_state, data, targets, batch_no + 1, print_training)

@jax.jit
def optimization_jit(params, data, targets, print_training = False):
opt_state = opt.init(params)
args = (params, opt_state, data, targets, 0, print_training)
(params, opt_state, _, _, _, _) = jax.lax.fori_loop(0, 30, update_step_jit, args)
return params

params = optimization_jit(params, X_batched, y_batched, print_training = True)
Выдает странную ошибку типа:

Код: Выделить всё

TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[]
The error occurred while tracing the function print_fn at ...\AppData\Local\Temp\ipykernel_37368\1395230403.py:10 for cond. This value became a tracer due to JAX operations on these lines:

operation a:i32[] b:i32[] = pjit[
name=divmod
jaxpr={ lambda ; c:i32[] d:i32[]. let
e:i32[] = pjit[
name=floor_divide
jaxpr={ lambda ; f:i32[] g:i32[]. let
h:i32[] = div f g
i:i32[] = sign f
j:i32[] = sign g
k:bool[] = ne i j
l:i32[] = rem f g
m:bool[] = ne l 0
n:bool[] = convert_element_type[new_dtype=bool weak_type=False] k
o:bool[] = convert_element_type[new_dtype=bool weak_type=False] m
p:bool[] = and n o
q:i32[] = sub h 1
r:i32[] = pjit[
name=_where
jaxpr={ lambda ; s:bool[] t:i32[] u:i32[]. let
v:i32[] = select_n s u t
in (v,) }
] p q h
in (r,) }
] c d
w:i32[] = pjit[
name=remainder
jaxpr={ lambda ; x:i32[] y:i32[]. let
z:bool[] = eq y 0
ba:i32[] = pjit[
name=_where
jaxpr={ lambda ; bb:bool[] bc:i32[] bd:i32[].  let
be:i32[] = select_n bb bd bc
in (be,) }
] z 1 y
bf:i32[] = rem x ba
bg:bool[] = ne bf 0
bh:bool[] = lt bf 0
bi:bool[] = lt ba 0
bj:bool[] = ne bh bi
bk:bool[] = and bj bg
bl:i32[] = add bf ba
bm:i32[] = select_n bk bf bl
in (bm,) }
] c d
in (e, w) }
] bn bo
from line ...\AppData\Local\Temp\ipykernel_37368\1395230403.py:12:27 (update_step_jit..print_fn)
Мое время обучения резко увеличится, если я не буду использовать JAX.
Таким образом, с JAX я могу получить производительность на тестовом наборе только в конце обучения , и я никак не могу получить тестовые потери в середине обучения или есть обходной путь?
Для минимально воспроизводимого примера вы можете попробовать запустить приведенный пример в этом демонстрационном коде.

Подробнее здесь: https://stackoverflow.com/questions/791 ... l-with-jax
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»