/* * Encog(tm) Core v3.4 - Java Version * http://www.heatonresearch.com/encog/ * https://github.com/encog/encog-java-core * Copyright 2008-2016 Heaton Research, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * For more information on Heaton Research copyrights, licenses * and trademarks visit: * http://www.heatonresearch.com/copyright */ package org.encog.ml.importance; import org.encog.EncogError; import org.encog.ml.data.MLDataSet; import org.encog.neural.networks.BasicNetwork; import java.util.ArrayList; import java.util.Comparator; import java.util.List; /** * A feature ranking algorithm based on the weights of a neural network. This algorithm can only be used for neural * networks, and it cannot calculate importance relative to a new dataset. * * Sources: * Garson, D. G. (1991). Interpreting neural network connection weights. * Goh, A. (1995). Back-propagation neural networks for modeling complex systems. Artificial Intelligence in * Engineering, 9(3), 143-151. * */ public class NeuralFeatureImportanceCalc extends AbstractFeatureImportance { /** * {@inheritDoc} */ @Override public void performRanking() { // Reset rankings for (FeatureRank rank : getFeatures()) { rank.setImportancePercent(0); rank.setTotalWeight(0); } if( ! (getModel() instanceof BasicNetwork) ) { throw new EncogError("This algorithm only works for classes of type BasicNetwork"); } BasicNetwork network = (BasicNetwork)getModel(); // Sum weights for each input neuron for (int inputNueron = 0; inputNueron < network.getInputCount(); inputNueron++) { FeatureRank ranking = getFeatures().get(inputNueron); for (int nextNeuron = 0; nextNeuron < network.getLayerNeuronCount(1); nextNeuron++) { double i_h = network.getWeight(0, inputNueron, nextNeuron); double h_o = network.getWeight(1, nextNeuron, 0); ranking.addWeight(i_h * h_o); } } // sum total weight to input neurons. double max = 0; for (FeatureRank rank : getFeatures() ) { max = Math.max(max, Math.abs(rank.getTotalWeight())); } // calculate each feature's importance percent for (FeatureRank rank : getFeatures()) { rank.setImportancePercent(Math.abs(rank.getTotalWeight()) / max); } } /** * {@inheritDoc} */ @Override public void performRanking(MLDataSet theDataset) { performRanking(); } }