/*
* Copyright [2012-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.neural.networks.BasicNetwork;
import org.encog.util.concurrency.EngineTask;
/**
* Mean standard error worker, for parallel compute sub-error then summing
*/
public class MSEWorker implements EngineTask {
private final BasicNetwork network;
private final MLDataSet dataSet;
private final int low;
private final int high;
private final MLDataPair pair;
private double totalError;
public MSEWorker(BasicNetwork network,
MLDataSet dataSet,
int low, int high) {
this.network = network;
this.dataSet = dataSet;
this.low = low;
this.high = high;
this.pair = BasicMLDataPair.createPair(network.getInputCount(), network.getOutputCount());
this.totalError = 0.0;
}
public void run() {
for (int i = this.low; i <= this.high; i++) {
this.dataSet.getRecord(i, pair);
MLData result = this.network.compute(this.pair.getInput());
double tmp = result.getData()[0] - this.pair.getIdeal().getData()[0];
double mse = tmp * tmp;
this.totalError += mse;
}
}
public double getTotalError() {
return this.totalError;
}
}