/*
* Apache License
* Version 2.0, January 2004
* http://www.apache.org/licenses/
*
* Copyright 2013 Aurelian Tutuianu
* Copyright 2014 Aurelian Tutuianu
* Copyright 2015 Aurelian Tutuianu
* Copyright 2016 Aurelian Tutuianu
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rapaio.experiment.ml.regression.nnet;
import rapaio.core.RandomSource;
import rapaio.data.Frame;
import rapaio.data.Var;
import rapaio.data.VarType;
import rapaio.ml.common.Capabilities;
import rapaio.ml.regression.AbstractRegression;
import rapaio.ml.regression.RFit;
import rapaio.ml.regression.Regression;
import java.util.Arrays;
import static rapaio.sys.WS.formatFlex;
/**
* User: Aurelian Tutuianu <padreati@yahoo.com>
*/
@Deprecated
public class MultiLayerPerceptronRegression extends AbstractRegression {
private final int[] layerSizes;
private final NetNode[][] net;
int runs = 0;
private TFunction function = TFunction.SIGMOID;
private double learningRate = 1.0;
public MultiLayerPerceptronRegression(int... layerSizes) {
this.layerSizes = layerSizes;
if (layerSizes.length < 2) {
throw new IllegalArgumentException("neural net must have at least 2 layers (including input layer)");
}
// build design
net = new NetNode[layerSizes.length][];
for (int i = 0; i < layerSizes.length; i++) {
int add = (i != net.length - 1) ? 1 : 0;
net[i] = new NetNode[layerSizes[i] + add];
for (int j = 0; j < net[i].length; j++) {
net[i][j] = new NetNode();
}
}
// wire-up nodes
for (int i = 0; i < net.length; i++) {
if (i == 0) {
for (int j = 0; j < net[i].length; j++) {
net[i][j].setInputs(null);
if (j == 0) {
net[i][j].value = 1.;
}
}
continue;
}
if (i == net.length - 1) {
for (int j = 0; j < net[i].length; j++) {
net[i][j].setInputs(net[i - 1]);
}
continue;
}
for (int j = 0; j < net[i].length; j++) {
if (j == 0) {
net[i][j].setInputs(null);
net[i][j].value = 1.;
continue;
}
net[i][j].setInputs(net[i - 1]);
}
}
}
@Override
public Regression newInstance() {
return new MultiLayerPerceptronRegression(layerSizes);
}
@Override
public String name() {
return "MultiLayerPerceptronRegression";
}
@Override
public String fullName() {
StringBuilder sb = new StringBuilder();
sb.append(name()).append("{");
sb.append("function=").append(function.name()).append(", ");
sb.append("learningRate=").append(formatFlex(learningRate)).append(", ");
sb.append("layerSizes=").append(Arrays.deepToString(Arrays.stream(layerSizes).mapToObj(i -> i).toArray()));
sb.append("}");
return sb.toString();
}
@Override
public Capabilities capabilities() {
return new Capabilities()
.withInputTypes(VarType.NUMERIC, VarType.INDEX, VarType.BINARY, VarType.ORDINAL)
.withTargetTypes(VarType.NUMERIC)
.withInputCount(1, 1_000_000)
.withTargetCount(1, 1_000_000)
.withAllowMissingInputValues(false)
.withAllowMissingTargetValues(false);
}
public MultiLayerPerceptronRegression withFunction(TFunction function) {
this.function = function;
return this;
}
public MultiLayerPerceptronRegression withLearningRate(double learningRate) {
this.learningRate = learningRate;
return this;
}
public MultiLayerPerceptronRegression withRuns(int runs) {
this.runs = runs;
return this;
}
@Override
protected boolean coreTrain(Frame df, Var weights) {
for (String varName : df.varNames()) {
if (df.var(varName).type().isNominal()) {
throw new IllegalArgumentException("perceptrons can't train nominal features");
}
}
// validate
if (this.targetNames().length != net[net.length - 1].length) {
throw new IllegalArgumentException("target var names does not fit output nodes");
}
if (inputNames().length != net[0].length - 1) {
throw new IllegalArgumentException("input var names does not fit input nodes");
}
// learn network
int pos;
for (int kk = 0; kk < runs; kk++) {
pos = RandomSource.nextInt(df.rowCount());
// set inputs
for (int i = 0; i < inputNames().length; i++) {
if (df.missing(pos, inputName(i))) {
throw new RuntimeException("detected NaN in input values");
}
net[0][i + 1].value = df.value(pos, inputName(i));
}
// feed forward
for (int i = 1; i < net.length; i++) {
for (int j = 0; j < net[i].length; j++) {
if (net[i][j].inputs != null) {
double t = 0;
for (int k = 0; k < net[i][j].inputs.length; k++) {
t += net[i][j].inputs[k].value * net[i][j].weights[k];
}
net[i][j].value = function.compute(t);
}
}
}
// back propagate
for (NetNode[] layer : net) {
for (NetNode node : layer) {
node.gamma = 0;
}
}
int last = net.length - 1;
for (int i = 0; i < net[last].length; i++) {
double expected = df.value(pos, targetName(i));
double actual = net[last][i].value;
net[last][i].gamma = function.differential(actual) * (expected - actual);
}
for (int i = last - 1; i > 0; i--) {
for (int j = 0; j < net[i].length; j++) {
double sum = 0;
for (int k = 0; k < net[i + 1].length; k++) {
if (net[i + 1][k].weights == null) continue;
sum += net[i + 1][k].weights[j] * net[i + 1][k].gamma;
}
net[i][j].gamma = function.differential(net[i][j].value) * sum;
}
}
for (int i = net.length - 1; i > 0; i--) {
for (int j = 0; j < net[i].length; j++) {
if (net[i][j].weights != null) {
for (int k = 0; k < net[i][j].weights.length; k++) {
net[i][j].weights[k] +=
learningRate * net[i][j].inputs[k].value * net[i][j].gamma;
}
}
}
}
}
return true;
}
@Override
protected RFit coreFit(final Frame df, final boolean withResiduals) {
RFit pred = RFit.build(this, df, withResiduals);
for (int pos = 0; pos < df.rowCount(); pos++) {
// set inputs
for (int i = 0; i < inputNames().length; i++) {
net[0][i + 1].value = df.value(pos, inputName(i));
}
// feed forward
for (int i = 1; i < net.length; i++) {
for (int j = 0; j < net[i].length; j++) {
if (net[i][j].inputs != null) {
double t = 0;
for (int k = 0; k < net[i][j].inputs.length; k++) {
t += net[i][j].inputs[k].value * net[i][j].weights[k];
}
net[i][j].value = function.compute(t);
}
}
}
for (int i = 0; i < targetNames().length; i++) {
pred.fit(targetName(i)).setValue(pos, net[net.length - 1][i].value);
}
}
pred.buildComplete();
return pred;
}
@Override
public String summary() {
throw new IllegalArgumentException("not implemented");
}
}
/**
 * A single neuron of the perceptron network: holds its activation value,
 * its upstream nodes, one weight per upstream node, and the
 * back-propagated error term.
 */
@Deprecated
class NetNode {

    // Current activation; starts at a small random value.
    double value = RandomSource.nextDouble() / 10.;
    // Upstream nodes feeding this one; null for input and bias nodes.
    NetNode[] inputs;
    // One weight per upstream node; null when there are no inputs.
    double[] weights;
    // Error term accumulated during back-propagation.
    double gamma;

    /**
     * Wires this node to its upstream layer. A null argument marks an
     * input/bias node (no weights); otherwise one small random weight is
     * created per upstream node.
     */
    public void setInputs(NetNode[] inputs) {
        this.inputs = inputs;
        if (inputs == null) {
            weights = null;
        } else {
            weights = new double[inputs.length];
            Arrays.setAll(weights, i -> RandomSource.nextDouble() / 10.);
        }
    }
}