/* * Encog(tm) Core v3.4 - Java Version * http://www.heatonresearch.com/encog/ * https://github.com/encog/encog-java-core * Copyright 2008-2016 Heaton Research, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * For more information on Heaton Research copyrights, licenses * and trademarks visit: * http://www.heatonresearch.com/copyright */ package org.encog.ml.importance; import org.encog.EncogError; import org.encog.mathutil.error.ErrorCalculation; import org.encog.mathutil.randomize.generate.GenerateRandom; import org.encog.mathutil.randomize.generate.MersenneTwisterGenerateRandom; import org.encog.ml.MLContext; import org.encog.ml.MLRegression; import org.encog.ml.data.MLData; import org.encog.ml.data.MLDataPair; import org.encog.ml.data.MLDataSet; import org.encog.ml.data.basic.BasicMLData; import org.encog.ml.data.basic.BasicMLDataPair; import org.encog.ml.data.basic.BasicMLDataSet; import org.encog.util.EngineArray; import org.encog.util.simple.EncogUtility; import java.util.ArrayList; import java.util.List; import java.util.Random; /** * Perturbation feature importance can be used to determine the importance of features for any type of regression or * classification model, with any compatible dataset. This method works by evaluating the performance of the model * when the values of each input feature are scrambled (randomly permuted across the rows of the dataset). Features * that are more important will result in worse errors when their data are scrambled. * * Source: * Breiman, L. (2001). Random forests.
Machine learning, 45(1), 5-32. */ public class PerturbationFeatureImportanceCalc extends AbstractFeatureImportance { /** * Random number generator. */ private GenerateRandom rnd = new MersenneTwisterGenerateRandom(); private double[] shuffleColumn; /** * {@inheritDoc} */ @Override public void performRanking() { throw new EncogError("This algorithm requires a dataset to measure performance against, please call performRanking with a dataset."); } private double calculateRegressionError(MLDataSet dataset, int perturbFeature) { // init as needed final ErrorCalculation errorCalculation = new ErrorCalculation(); if( getModel() instanceof MLContext) ((MLContext)getModel()).clearContext(); // copy the perturb column for(int i=0;i<dataset.size();i++) { this.shuffleColumn[i] = dataset.get(i).getInput().getData(perturbFeature); } // evaluate MLData featureVector = new BasicMLData(dataset.getInputSize()); try { int n = dataset.size(); for(int i=0;i<n;i++) { // Get training element MLDataPair pair = dataset.get(i); EngineArray.arrayCopy(pair.getInput().getData(),featureVector.getData()); // Shuffle if( i!=(n-1)) { int j = rnd.nextInt(dataset.size() - i); double t = this.shuffleColumn[i]; this.shuffleColumn[i] = this.shuffleColumn[j]; this.shuffleColumn[j] = t; featureVector.setData(perturbFeature, this.shuffleColumn[i]); } // Evaluate final MLData actual = getModel().compute(featureVector); errorCalculation.updateError(actual.getData(), pair.getIdeal() .getData(), pair.getSignificance()); } } catch(EncogError e) { return Double.NaN; } return errorCalculation.calculate(); } /** * {@inheritDoc} */ @Override public void performRanking(MLDataSet theDataset) { this.shuffleColumn = new double[theDataset.size()]; double max = 0; for(int i=0;i<getModel().getInputCount();i++) { FeatureRank fr = getFeatures().get(i); //MLDataSet p = generatePermutation(theDataset,i); double e = calculateRegressionError(theDataset,i); fr.setTotalWeight(e); max = Math.max(max,e); } for(FeatureRank 
fr:getFeatures()) { fr.setImportancePercent(fr.getTotalWeight()/max); } } /** * @return The random number generator. */ public GenerateRandom getRnd() { return rnd; } /** * Set the random number generator. * @param rnd The random number generator. */ public void setRnd(GenerateRandom rnd) { this.rnd = rnd; } }