Neural network accuracy changes when saving and loading from a file (JAVA)

Post by Anonymous »

I believe the problem is in the constructor MultNnet(String path, Activation[] activations, Activation[] derivatives)
and/or in the method void save(String path). When I compute the accuracy several times within the same run of the program it never changes; it only changes after saving the model and loading it back (a minimal round-trip check is sketched after the listing):
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.PrintWriter;
import java.util.concurrent.ThreadLocalRandom;

import org.ejml.simple.SimpleMatrix;

interface Activation {
    double activate(double x);
}

interface Interpretor {
    double interpretor(SimpleMatrix output);
}

public class MultNnet {
    SimpleMatrix[] weights;
    SimpleMatrix[] biases;
    Activation[] activations;
    Activation[] derivatives; // the derivative functions will receive the output of the activation as their input

    private static final double EPSILON = 1e-3;

    public MultNnet(int inputSize, int[] neuronSizes, Activation[] activations, Activation[] derivatives) {
        weights = new SimpleMatrix[neuronSizes.length];
        biases = new SimpleMatrix[neuronSizes.length];
        weights[0] = new SimpleMatrix(neuronSizes[0], inputSize);
        biases[0] = new SimpleMatrix(neuronSizes[0], 1);
        initWeights(weights[0]);
        initWeights(biases[0]);
        for (int i = 1; i < weights.length; i++) {
            weights[i] = new SimpleMatrix(neuronSizes[i], neuronSizes[i - 1]);
            biases[i] = new SimpleMatrix(neuronSizes[i], 1);
            initWeights(weights[i]);
            initWeights(biases[i]);
        }

        this.activations = activations;
        this.derivatives = derivatives;
    }

    private static void initWeights(SimpleMatrix weights) {
        // uniform init in [-limit, limit) with limit = sqrt(3 / fanIn)
        double limit = Math.sqrt(3.0 / weights.getNumCols());
        for (int i = 0; i < weights.getNumRows(); i++)
            for (int j = 0; j < weights.getNumCols(); j++)
                weights.set(i, j, ThreadLocalRandom.current().nextDouble(-limit, limit));
    }

    // File format: a header line "LayerNo:<n>", then one line per layer of the form
    // "<rows>,<cols>:" followed by the weight entries in row-major order and then
    // the bias entries, all comma-separated
    public void save(String path) {
        try (PrintWriter out = new PrintWriter(new FileWriter(path))) {
            out.println("LayerNo:" + weights.length);
            for (int i = 0; i < weights.length; i++) {
                String line = weights[i].getNumRows() + "," + weights[i].getNumCols() + ":";
                for (int j = 0; j < weights[i].getNumRows(); j++)
                    for (int k = 0; k < weights[i].getNumCols(); k++)
                        line += weights[i].get(j, k) + ",";
                for (int j = 0; j < biases[i].getNumRows(); j++)
                    line += biases[i].get(j, 0) + (j < biases[i].getNumElements() - 1 ? "," : "");
                out.println(line);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    // Loads the format written by save(): for each layer the weight entries come
    // first, then the bias entries
    public MultNnet(String path, Activation[] activations, Activation[] derivatives) {
        this.activations = activations;
        this.derivatives = derivatives;
        try (BufferedReader in = new BufferedReader(new FileReader(path))) {
            int layerNo = Integer.parseInt(in.readLine().split(":")[1]);
            weights = new SimpleMatrix[layerNo];
            biases = new SimpleMatrix[layerNo];
            for (int i = 0; i < layerNo; i++) {
                String[] line = in.readLine().split(":");
                String[] info = line[0].split(",");
                String[] entries = line[1].split(",");
                weights[i] = new SimpleMatrix(Integer.parseInt(info[0]), Integer.parseInt(info[1]));
                biases[i] = new SimpleMatrix(weights[i].getNumRows(), 1);
                for (int j = 0; j < weights[i].getNumElements(); j++)
                    weights[i].set(j, Double.parseDouble(entries[j]));
                for (int j = 0; j < biases[i].getNumElements(); j++)
                    // the offset j must stay inside the index: "entries[n] + j" would
                    // string-concatenate j onto the first bias entry before parsing
                    // (e.g. "0.5" + 3 -> 0.53), silently corrupting every bias on load
                    biases[i].set(j, Double.parseDouble(entries[weights[i].getNumElements() + j]));
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public SimpleMatrix predict(SimpleMatrix input) {
        SimpleMatrix output = input;
        for (int i = 0; i < weights.length; i++) {
            final int fi = i; // needed to use it in the lambda expression
            output = weights[fi].mult(output).plus(biases[fi]).elementOp((int j, int k, double x) -> activations[fi].activate(x));
        }
        return output;
    }

    public void train(SimpleMatrix[] inputs, SimpleMatrix[] targets, int epochs, double learningRate) {
        for (int i = 0; i < epochs; i++) {
            SimpleMatrix[] gradients = new SimpleMatrix[weights.length];
            SimpleMatrix[] biasGradients = new SimpleMatrix[biases.length];

            for (int j = 0; j < gradients.length; j++) {
                gradients[j] = new SimpleMatrix(weights[j].getNumRows(), weights[j].getNumCols());
                biasGradients[j] = new SimpleMatrix(biases[j].getNumRows(), 1);
            }

            double[] losses = new double[weights.length];

            SimpleMatrix[] outputs = new SimpleMatrix[weights.length];

            for (int j = 0; j < inputs.length; j++) {
                SimpleMatrix input = inputs[j];

                // forward pass, keeping every layer's output for backpropagation
                for (int k = 0; k < weights.length; k++) {
                    final int fk = k; // needed to use it in the lambda expression
                    outputs[k] = weights[k].mult(input).plus(biases[k]).elementOp((int m, int n, double x) -> activations[fk].activate(x));
                    input = outputs[k];
                }

                // backward pass from the output layer down to layer 1
                SimpleMatrix error = outputs[outputs.length - 1].minus(targets[j]);
                for (int k = weights.length - 1; k > 0; k--) {
                    double norm = error.normF();
                    losses[k] += norm * norm;
                    final int fk = k; // needed to use it in the lambda expression
                    SimpleMatrix biasDerivative = error.elementMult(outputs[k].elementOp((int m, int n, double x) -> derivatives[fk].activate(x))).scale(2);
                    SimpleMatrix weightDerivative = biasDerivative.mult(outputs[k - 1].transpose());
                    biasGradients[k] = biasGradients[k].plus(biasDerivative.divide(inputs.length));
                    gradients[k] = gradients[k].plus(weightDerivative.divide(inputs.length));
                    error = weights[k].transpose().mult(biasDerivative);
                }

                // first layer: its input is the sample itself, not a previous output
                double norm = error.normF();
                losses[0] += norm * norm;
                SimpleMatrix biasDerivative = error.elementMult(outputs[0].elementOp((int m, int n, double x) -> derivatives[0].activate(x))).scale(2);
                SimpleMatrix weightDerivative = biasDerivative.mult(inputs[j].transpose());
                biasGradients[0] = biasGradients[0].plus(biasDerivative.divide(inputs.length));
                gradients[0] = gradients[0].plus(weightDerivative.divide(inputs.length));
            }

            // per-layer step: the mean squared loss divided by the squared gradient
            // norms gives a loss-proportional, norm-normalized step size
            for (int j = 0; j < weights.length; j++) {
                double gradientNorm = gradients[j].normF();
                double biasGradientNorm = biasGradients[j].normF();
                double coefficient = losses[j] / inputs.length / (gradientNorm * gradientNorm + biasGradientNorm * biasGradientNorm + EPSILON);
                weights[j] = weights[j].minus(gradients[j].scale(coefficient * learningRate));
                biases[j] = biases[j].minus(biasGradients[j].scale(coefficient * learningRate));
            }
        }
    }

    // swap each position with a uniformly random index (simple in-place shuffle)
    private <T> void shuffle(T[] v) {
        for (int i = 0; i < v.length; i++) {
            int randomIndex = ThreadLocalRandom.current().nextInt(v.length);
            T tmp = v[i];
            v[i] = v[randomIndex];
            v[randomIndex] = tmp;
        }
    }

    public void stochasticTrain(SimpleMatrix[][] inputBatches, SimpleMatrix[][] outputBatches, int epochs, double learningRate) {
        Integer[] batchIndices = new Integer[inputBatches.length];
        for (int i = 0; i < batchIndices.length; i++)
            batchIndices[i] = i;
        while (epochs-- > 0) {
            shuffle(batchIndices);
            for (int i = 0; i < batchIndices.length; i++)
                train(inputBatches[batchIndices[i]], outputBatches[batchIndices[i]], 1, learningRate);
        }
    }

    public double accuracy(SimpleMatrix[] inputs, SimpleMatrix[] targets, Interpretor interpretor) {
        int correct = 0;
        for (int i = 0; i < inputs.length; i++) {
            SimpleMatrix prediction = predict(inputs[i]);
            if (interpretor.interpretor(prediction) == interpretor.interpretor(targets[i]))
                correct++;
        }
        return correct * 100.0 / inputs.length;
    }

    public double accuracy(SimpleMatrix[][] inputs, SimpleMatrix[][] targets, Interpretor interpretor) {
        int correct = 0;
        for (int i = 0; i < inputs.length; i++)
            for (int j = 0; j < inputs[i].length; j++) {
                SimpleMatrix prediction = predict(inputs[i][j]);
                if (interpretor.interpretor(prediction) == interpretor.interpretor(targets[i][j]))
                    correct++;
            }
        // assumes every batch has the same size as inputs[0]
        return correct * 100.0 / inputs.length / inputs[0].length;
    }
}
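
To check whether the save/load round trip itself alters the parameters, a small equality test helps. This is a minimal sketch, not part of the original code: sameParameters is a hypothetical helper one could drop into MultNnet, and it assumes the fields and constructors from the listing above. Double.toString/Double.parseDouble round-trip doubles exactly, so a correct loader should reproduce every entry bit for bit:

    static boolean sameParameters(MultNnet a, MultNnet b) {
        if (a.weights.length != b.weights.length)
            return false;
        for (int i = 0; i < a.weights.length; i++)
            // tolerance 0: the text format loses no precision, so save + load
            // should give back exactly identical matrices
            if (!a.weights[i].isIdentical(b.weights[i], 0)
                    || !a.biases[i].isIdentical(b.biases[i], 0))
                return false;
        return true;
    }

Running net.save(path), reloading with new MultNnet(path, activations, derivatives), and comparing the two networks with this helper isolates the loader: if it returns false while predictions within a single run stay stable, the serialization code is what changes the accuracy.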


More details here: https://stackoverflow.com/questions/787 ... -from-file