package func; import dist.*; import dist.Distribution; import dist.DiscreteDistribution; import shared.ConvergenceTrainer; import shared.DataSet; import shared.DataSetDescription; import shared.GradientErrorMeasure; import shared.Instance; import shared.SumOfSquaresError; import shared.Trainer; import func.nn.activation.DifferentiableActivationFunction; import func.nn.activation.HyperbolicTangentSigmoid; import func.nn.backprop.BackPropagationNetwork; import func.nn.backprop.BackPropagationNetworkFactory; import func.nn.backprop.BatchBackPropagationTrainer; import func.nn.backprop.RPROPUpdateRule; import func.nn.backprop.WeightUpdateRule; /** * A neural network classifier * @author Andrew Guillory gtg008g@mail.gatech.edu * @version 1.0 */ public class NeuralNetworkClassifier extends AbstractConditionalDistribution implements FunctionApproximater { /** * The transfer function */ private DifferentiableActivationFunction activationFunction; /** * The hidden node count */ private int hiddenNodeCount; /** * The training weight update rule */ private WeightUpdateRule updateRule; /** * The network */ private BackPropagationNetwork network; /** * Make a new nn classifier * @param hiddenNodeCount the hidden node count * @param activationFunction the activation function * @param updateRule the update rule */ public NeuralNetworkClassifier(int hiddenNodeCount, DifferentiableActivationFunction activationFunction, WeightUpdateRule updateRule) { this.hiddenNodeCount = hiddenNodeCount; this.activationFunction = activationFunction; this.updateRule = updateRule; } /** * Make a new classifier */ public NeuralNetworkClassifier() { this(3, new HyperbolicTangentSigmoid(), new RPROPUpdateRule()); } /** * @see func.FunctionApproximater#estimate(shared.DataSet) */ public void estimate(DataSet set) { if (set.getDescription() == null) { set.setDescription(new DataSetDescription(set)); } int[] topology; if (hiddenNodeCount != 0) { topology = new int[3]; topology[1] = hiddenNodeCount; } else { topology = new int[2]; } topology[0] = set.getDescription().getAttributeTypes().length; if (set.getDescription().getLabelDescription().getDiscreteRange() == 2) { topology[topology.length - 1] = 1; } else { topology[topology.length - 1] = set.getDescription().getLabelDescription().getDiscreteRange(); } network = (new BackPropagationNetworkFactory()) .createClassificationNetwork(topology, activationFunction); GradientErrorMeasure errorMeasure = new SumOfSquaresError(); Trainer trainer = new ConvergenceTrainer( new BatchBackPropagationTrainer( set, network, errorMeasure, updateRule)); trainer.train(); } /** * Get a distribution for the given input * @param input the input * @return the distribution */ public Distribution distributionFor(Instance input) { network.setInputValues(input.getData()); network.run(); if (network.getOutputLayer().getNodeCount() > 1) { return new DiscreteDistribution( network.getOutputValues()); } else { double[] p = new double[2]; p[1] = network.getOutputValues().get(0); p[0] = 1 - p[1]; return new DiscreteDistribution( p); } } /** * @see func.FunctionApproximater#value(shared.Instance) */ public Instance value(Instance i) { return distributionFor(i).mode(); } }