Чем в этом случае неправильно использовать session.run()?
Код: Выделить всё
public void Train(NDArray currentState, int action, float reward, NDArray nextState, bool done)
{
var graph = session.graph.as_default();
simMenu.LogMessage($"Training started with currentState: {currentState}, action: {action}, reward: {reward}, nextState: {nextState}, done: {done}");
try
{
currentState = np.reshape(currentState, new int[] { 1, stateSize });
currentState = currentState.astype(np.float32);
// Get the Q-values for the current state
var qValuesCurrent = session.run(output, new FeedItem(state, currentState));
float currentQValue = qValuesCurrent.ToArray()[action];
// Get the Q-values for the next state
nextState = np.reshape(nextState, new int[] { 1, stateSize });
nextState = nextState.astype(np.float32);
var qValuesNext = session.run(output, new FeedItem(state, nextState)); //here is where the error occurs
float maxFutureQValue = np.max(qValuesNext.ToArray());
// Calculate the target Q-value
float targetQValue = reward;
if (!done)
{
targetQValue += discountFactor * maxFutureQValue;
}
// Update the Q-value using the Q-learning formula
var qValuesArray = qValuesCurrent.ToArray();
qValuesArray[action] = currentQValue + learningRate * (targetQValue - currentQValue);
var updatedQValues = np.array(qValuesArray).reshape(1, actionSize);
// Train the model
session.run(optimizer, new FeedItem(state, currentState), new FeedItem(qTarget, updatedQValues));
}
catch (Exception ex)
{
simMenu.LogMessage($"An error occurred during training: {ex.Message}");
throw;
}
Сеанс и график инициализированы успешно.Обучение началось с текущего состояния: [140.0723, 30.92329, 30.92329, 51.61395, 40.36087], действие: 2, награда: 1, следующее состояние: [140.1722, 30.92329, 30.92329, 51.61395, 40.50 206], сделано: False Во время обучения произошла ошибка. :
отслеживание переменных при ошибке
Здесь указана вся функция -
Код: Выделить всё
using NumSharp;
using Tensorflow;
using static Tensorflow.Binding;
using System;
using Coursework;
public class QLearningAgent
{
private readonly int stateSize; // Size of the state
private readonly int actionSize; // Number of actions
private readonly float learningRate; // Learning rate
private readonly float discountFactor = 0.99f; // Discount factor for future rewards
private readonly float explorationProbability; // Epsilon for exploration
private readonly Session session; // TensorFlow session
private readonly Tensor state; // Placeholder for the current state
private readonly Tensor qTarget; // Placeholder for the target Q-values
private readonly SimMenu simMenu;
private Tensor output; // Output layer for Q-values
private Operation optimizer;
public QLearningAgent(int stateSize, int actionSize, float learningRate, SimMenu simMenu, float explorationProbability = 0.1f)
{
this.stateSize = stateSize;
this.actionSize = actionSize;
this.learningRate = learningRate;
this.explorationProbability = explorationProbability;
this.simMenu = simMenu;
// TensorFlow graph construction
var graph = tf.Graph().as_default();
// Initialize session
session = tf.Session(graph);
session.run(tf.global_variables_initializer()); //session.run() works fine here
this.state = tf.placeholder(tf.float32, shape: new int[] { -1, stateSize });
this.qTarget = tf.placeholder(tf.float32, shape: new int[] { -1, actionSize });
// Define the neural network structurek
var hidden = tf.keras.layers.Dense(24, activation: "relu").Apply(state); // Hidden layer
this.output = tf.keras.layers.Dense(actionSize, activation: "linear").Apply(hidden); // Output layer for Q-values
// Loss and optimizer
var loss = tf.reduce_mean(tf.square(output - qTarget));
this.optimizer = tf.train.AdamOptimizer(learningRate).minimize(loss);
simMenu.LogMessage("Session and graph initialized successfully.");
}
public int GetAction(NDArray currentState)
{
// Epsilon-greedy action selection
if (new Random().NextDouble() < explorationProbability)
{
// Explore: select a random action
return new Random().Next(actionSize);
}
else
{
// Exploit: select the action with the highest Q-value
var graph = session.graph.as_default();
// Get the Q-values for the current state
var feedDict = new FeedItem(state, currentState);
var qValues = session.run(output, feedDict); // Session.run() doesn't work here either however it goes to Train() before GetAction()
var qValuesArray = qValues.ToArray(); // Convert Tensorflow.NumPy.NDArray to a float array
return np.argmax(np.array(qValuesArray)); // Return the index of the action with the highest Q-value
}
}
public void Train(NDArray currentState, int action, float reward, NDArray nextState, bool done)
{
var graph = session.graph.as_default();
simMenu.LogMessage($"Training started with currentState: {currentState}, action: {action}, reward: {reward}, nextState: {nextState}, done: {done}");
try
{
currentState = np.reshape(currentState, new int[] { 1, stateSize });
currentState = currentState.astype(np.float32);
// Get the Q-values for the current state
var qValuesCurrent = session.run(output, new FeedItem(state, currentState));
float currentQValue = qValuesCurrent.ToArray()[action];
// Get the Q-values for the next state
nextState = np.reshape(nextState, new int[] { 1, stateSize });
nextState = nextState.astype(np.float32);
var qValuesNext = session.run(output, new FeedItem(state, nextState));
float maxFutureQValue = np.max(qValuesNext.ToArray());
// Calculate the target Q-value
float targetQValue = reward;
if (!done)
{
targetQValue += discountFactor * maxFutureQValue;
}
// Update the Q-value using the Q-learning formula
var qValuesArray = qValuesCurrent.ToArray();
qValuesArray[action] = currentQValue + learningRate * (targetQValue - currentQValue);
var updatedQValues = np.array(qValuesArray).reshape(1, actionSize);
// Train the model
session.run(optimizer, new FeedItem(state, currentState), new FeedItem(qTarget, updatedQValues));
}
catch (Exception ex)
{
simMenu.LogMessage($"An error occurred during training: {ex.Message}");
throw;
}
}
}
Мое описание проблемы, возможно, не самое лучшее и полное, поскольку это мой первый пост, спасибо, что нашли время прочитать. Если вам нужна дополнительная информация, просто спросите!
Подробнее здесь: https://stackoverflow.com/questions/791 ... -run-usage
Мобильная версия