Проблемы при логическом индексации в JAX, получение NoncroteBreteanIndexerrorPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Проблемы при логическом индексации в JAX, получение NoncroteBreteanIndexerror

Сообщение Anonymous »

В настоящее время я пытаюсь создать CustomProblem , наследуя от класса BaseProblem в Tensorneat, которая является библиотекой на основе JAX. Пытаясь реализовать функцию оценки этого класса, я использую логическую маску, но у меня возникают проблемы с тем, чтобы она работала. Мой код приводит к jax.errors.nonconcretebooleanindexerr: массивовые логические индексы должны быть конкретными; Получил CHAPEDARRAY (BOOL [N, N]) , что, я думаю, связано с тем, что некоторые из моих массивов не имеют определенной формы. Как мне это обойти?import numpy as np

ran_int = np.random.randint(1, 5, size=(2, 2))
print(ran_int)

ran_bool = np.random.randint(0,2, size=(2,2), dtype=bool)
print(ran_bool)

a = (ran_int[ran_bool]>0).astype(int)
print(a)
< /code>
Это может дать выход, подобный этим: < /p>
[[2 2]
[3 4]]
[[ True False]
[ True True]]
[1 1 1] #Is 1D and has less elements than before boolean mask was applied!

Но в JAX такой же способ мышления приводит к ошибке nonconcreteanIndexError , которую я получил. переопределить ">#NB! len(labels) = len(inputs) = n
def evaluate(self, state, randkey, act_func, params):
# do batch forward for all inputs (using jax.vamp).
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
state, params, self.inputs
) # should be shape (n, 1)

#calculating pairwise labels and predictions
pairwise_labels = self.labels - self.labels.T # shape (n, n)
pairwise_predictions = predict - predict.T # shape (n, n)

#finding which pairs to keep
pairs_to_keep = jnp.abs(pairwise_labels) > self.threshold
print(pairs_to_keep.shape) #this prints (n, n)

pairwise_labels = pairwise_labels[pairs_to_keep] #ERROR HAPPENS HERE
pairwise_labels = jnp.where(pairwise_labels > 0, True, False)
print(pairwise_labels.shape) #want this to print a 1D array that potentially has less elements than n*n depending on the boolean mask

pairwise_predictions = pairwise_predictions[pairs_to_keep] #WOULD HAPPEN HERE TOO IF THIS PART WAS FIRST
pairwise_predictions = jax.nn.sigmoid(pairwise_predictions)
print(pairwise_predictions.shape) #want this to print a 1D array that potentially has less elements than n*n depending on the boolean mask

# calculate loss
loss = binary_cross_entropy(pairwise_predictions, pairwise_labels) # shape (n)

# reduce loss to a scalar
loss = jnp.mean(loss)

# return negative loss as fitness
# TensorNEAT maximizes fitness, equivalent to minimizing loss
return -loss

Я рассматривал возможность использования jnp.where для решения проблемы, но полученная пара pairwise_labels и pairwise_predictions имеет другую форму, чем я ожидаю ( а именно (n, n) ), как видно в коде ниже:
#NB! len(labels) = len(inputs) = n
def evaluate(self, state, randkey, act_func, params):
# do batch forward for all inputs (using jax.vamp).
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
state, params, self.inputs
) # should be shape (n, 1)

#calculating pairwise labels and predictions
pairwise_labels = self.labels - self.labels.T # shape (n, n)
pairwise_predictions = predict - predict.T # shape (n, n)

#finding which pairs to keep
pairs_to_keep = jnp.abs(pairwise_labels) > self.threshold
print(pairs_to_keep.shape) #this prints (n, n)

pairwise_labels = jnp.where(pairs_to_keep, pairwise_labels, -jnp.inf) #one problem is that now I have -inf instead of discarding the element entirely
pairwise_labels = jnp.where(pairwise_labels > 0, True, False)
print(pairwise_labels.shape) # shape (n, n)

pairwise_predictions = jnp.where(pairs_to_keep, pairwise_predictions, -jnp.inf) #one problem is that now I have -inf instead of discarding the element entirely
pairwise_predictions = jax.nn.sigmoid(pairwise_predictions)
print(pairwise_predictions.shape) # shape (n, n)

# calculate loss
loss = binary_cross_entropy(pairwise_predictions, pairwise_labels) # shape (n ,n)

# reduce loss to a scalar
loss = jnp.mean(loss)

# return negative loss as fitness
# TensorNEAT maximizes fitness, equivalent to minimizing loss
return -loss

Я боюсь, что различные формы pairwise_predictions и pairwise_labels после использования Jnp.where приведут к другой потере, чем если бы я только что использовал Логическая маска, как и в NP . Существует также тот факт, что я получаю еще одну ошибку, которая произойдет позже в трубопроводе с помощью выходного значения. >. Это любопытно обходит путем изменения pairs_to_pee = jnp.abs (pairwise_labels)> self.threshold to pairs_to_peer = jnp.abs (pairwise_labels - неверно. < /p>
Ниже приведен какой-то код, который должен быть достаточно для настройки минимального примера выполнения, который похож на мою настройку: < /p>
from tensorneat import algorithm, genome, common
from tensorneat.pipeline import Pipeline
from tensorneat.genome.gene.node import DefaultNode
from tensorneat.genome.gene.conn import DefaultConn
from tensorneat.genome.operations import mutation
import jax, jax.numpy as jnp
from tensorneat.problem import BaseProblem

def binary_cross_entropy(prediction, target):
return -(target * jnp.log(prediction) + (1 - target) * jnp.log(1 - prediction))

# Define the custom Problem
class CustomProblem(BaseProblem):

jitable = True # necessary

def __init__(self, inputs, labels, threshold):
self.inputs = jnp.array(inputs) #nb! already has shape (n, 768)
self.labels = jnp.array(labels).reshape((-1,1)) #nb! has shape (n), must be transformed to have shape (n, 1)
self.threshold = threshold

def evaluate(self, state, randkey, act_func, params):
# do batch forward for all inputs (using jax.vamp).
predict = jax.vmap(act_func, in_axes=(None, None, 0))(
state, params, self.inputs
) # should be shape (len(labels), 1)

#calculating pairwise labels and predictions
pairwise_labels = self.labels - self.labels.T # shape (len(labels), len(labels))
pairwise_predictions = predict - predict.T # shape (len(inputs), len(inputs))

#finding which pairs to keep
pairs_to_keep = jnp.abs(pairwise_labels) > self.threshold #this is the thing I actually want
#pairs_to_keep = jnp.abs(pairwise_labels - pairwise_predictions) > self.threshold #weird fix to circumvent ValueError: max() iterable argument is empty when using jnp.where for pairwise_labels and pairwise_predictions
print(pairs_to_keep.shape)

pairwise_labels = pairwise_labels[pairs_to_keep] #normal boolean mask that doesnt work
#pairwise_labels = jnp.where(pairs_to_keep, pairwise_labels, -jnp.inf) #using jnp.where to circumvent NonConcreteBooleanIndexError, but gives different shape than I want
pairwise_labels = jnp.where(pairwise_labels > 0, True, False)
print(pairwise_labels.shape)

pairwise_predictions = pairwise_predictions[pairs_to_keep] #normal boolean mask that doesnt work
#pairwise_predictions = jnp.where(pairs_to_keep, pairwise_predictions, -jnp.inf) #using jnp.where to circumvent NonConcreteBooleanIndexError, but gives different shape than I want
pairwise_predictions = jax.nn.sigmoid(pairwise_predictions)
print(pairwise_predictions.shape)

# calculate loss
loss = binary_cross_entropy(pairwise_predictions, pairwise_labels) # shape (len(labels), len(labels))

# reduce loss to a scalar
loss = jnp.mean(loss)

# return negative loss as fitness
# TensorNEAT maximizes fitness, equivalent to minimizing loss
return -loss

@property
def input_shape(self):
# the input shape that the act_func expects
return (self.inputs.shape[1],)

@property
def output_shape(self):
# the output shape that the act_func returns
return (1,)

def show(self, state, randkey, act_func, params, *args, **kwargs):
# showcase the performance of one individual
predict = jax.vmap(act_func, in_axes=(None, None, 0))(state, params, self.inputs)

loss = jnp.mean(jnp.square(predict - self.labels))

n_elements = 5
if n_elements > len(self.inputs):
n_elements = len(self.inputs)

msg = f"Looking at {n_elements} first elements of input\n"
for i in range(n_elements):
msg += f"for input i: {i}, target: {self.labels}, predict: {predict}\n"
msg += f"total loss: {loss}\n"
print(msg)

algorithm = algorithm.NEAT(
pop_size=10,
survival_threshold=0.2,
min_species_size=2,
compatibility_threshold=3.0,
species_elitism=2,
genome=genome.DefaultGenome(
num_inputs=768,
num_outputs=1,
max_nodes=769, # must at least be same as inputs and outputs
max_conns=768, # must be 768 connections for the network to be fully connected
output_transform=common.ACT.sigmoid,
mutation=mutation.DefaultMutation(
# no allowing adding or deleting nodes
node_add=0.0,
node_delete=0.0,
# set mutation rates for edges to 0.5
conn_add=0.5,
conn_delete=0.5,
),
node_gene=DefaultNode(),
conn_gene=DefaultConn(),
),
)

INPUTS = jax.random.uniform(jax.random.PRNGKey(0), (100, 768)) #the input data x
LABELS = jax.random.uniform(jax.random.PRNGKey(0), (100)) #the annotated labels y

problem = CustomProblem(INPUTS, LABELS, 0.25)

print("Setting up pipeline and running it")
print("-----------------------------------------------------------------------")
pipeline = Pipeline(
algorithm,
problem,
generation_limit=1,
fitness_target=1,
seed=42,
)

state = pipeline.setup()
# run until termination
state, best = pipeline.auto_run(state)
# show results
pipeline.show(state, best)


Подробнее здесь: https://stackoverflow.com/questions/794 ... indexerror
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение
  • Проблемы при логическом индексации в JAX, получение NonconcretebooleanIndexerror
    Anonymous » » в форуме Python
    0 Ответы
    9 Просмотры
    Последнее сообщение Anonymous
  • Проблемы при логическом индексации в JAX, получение NonconcretebooleanIndexerror
    Anonymous » » в форуме Python
    0 Ответы
    12 Просмотры
    Последнее сообщение Anonymous
  • Проблемы при логическом индексации в JAX, получение NonconcretebooleanIndexerror
    Anonymous » » в форуме Python
    0 Ответы
    10 Просмотры
    Последнее сообщение Anonymous
  • Проблемы при логическом индексации в JAX, получение NonconcretebooleanIndexerror
    Anonymous » » в форуме Python
    0 Ответы
    7 Просмотры
    Последнее сообщение Anonymous
  • Перегрузка оператора индексации индексации C ++ [] таким образом, чтобы разрешить ответы на обновления
    Anonymous » » в форуме C++
    0 Ответы
    1 Просмотры
    Последнее сообщение Anonymous

Вернуться в «Python»