/***********************************************************************

	This file is part of KEEL-software, the Data Mining tool for regression,
	classification, clustering, pattern mining and so on.

	Copyright (C) 2004-2010

	F. Herrera (herrera@decsai.ugr.es)
	L. Sánchez (luciano@uniovi.es)
	J. Alcalá-Fdez (jalcala@decsai.ugr.es)
	S. García (sglopez@ujaen.es)
	A. Fernández (alberto.fernandez@ujaen.es)
	J. Luengo (julianlm@decsai.ugr.es)

	This program is free software: you can redistribute it and/or modify
	it under the terms of the GNU General Public License as published by
	the Free Software Foundation, either version 3 of the License, or
	(at your option) any later version.

	This program is distributed in the hope that it will be useful,
	but WITHOUT ANY WARRANTY; without even the implied warranty of
	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
	GNU General Public License for more details.

	You should have received a copy of the GNU General Public License
	along with this program.  If not, see http://www.gnu.org/licenses/

**********************************************************************/

package keel.Algorithms.Instance_Generation.LVQ;

import keel.Algorithms.Instance_Generation.Basic.PrototypeSet;
import keel.Algorithms.Instance_Generation.Basic.Prototype;
import keel.Algorithms.Instance_Generation.Basic.PrototypeGenerationAlgorithm;
import keel.Algorithms.Instance_Generation.*;
import keel.Algorithms.Instance_Generation.utilities.*;
import org.core.*;
import keel.Algorithms.Instance_Generation.utilities.KNN.*;
import java.util.ArrayList;
import java.util.*;

/**
 * Implements the LVQTC algorithm on top of {@link LVQ1}.
 * <p>
 * Every epoch resets the per-prototype counters, runs {@code iterations}
 * corrections (reward / penalize scaled by the number of times the nearest
 * prototype has won so far), prunes prototypes that won fewer than
 * {@code retentionThreshold} times and finally creates new prototypes from
 * the misclassified training instances.
 *
 * @author diegoj
 */
public class LVQTC extends LVQ1
{
    /** Alpha R parameter: numerator of the step size used by {@link #reward}. */
    private double alpha_r = LVQTC.ALPHA_DEFAULT_VALUE;

    /** Alpha W parameter: numerator of the step size used by {@link #penalize}. */
    private double alpha_w = LVQTC.ALPHA_DEFAULT_VALUE;

    /** Number of epochs executed by {@link #reduceSet()}. */
    private int epoches = 4;

    /**
     * Retention threshold. Prototypes whose total win counter is below it are
     * removed by {@link #neuronPruning}; it is also the count a wrong class
     * must exceed before {@link #neuronCreation} spawns a new prototype.
     */
    private int retentionThreshold = 3;

    /**
     * Output classes of the training data set. Kept per instance so that two
     * generators built on different data sets do not overwrite each other.
     */
    private ArrayList<Double> posibleClasses = null;

    /** For each prototype, number of times it was the winner for a training instance of each class. */
    private HashMap<Prototype, HashMap<Double, Integer>> counter = null;

    /** For each prototype, running total of all its per-class counters. */
    private HashMap<Prototype, Integer> sumCounter = null;

    /** For each prototype, the training instances of a different class for which it was the winner. */
    private HashMap<Prototype, PrototypeSet> wrong = null;

    /**
     * Constructor based on the training data set and the parameters.
     * After the superclass has consumed its own parameters, reads in order:
     * alpha_w (double), retention threshold (int) and number of epochs (int).
     * alpha_r is taken from the inherited {@code alpha_0}.
     * @param traDataSet Training data prototypes.
     * @param parameters Parameters of the algorithm.
     */
    public LVQTC(PrototypeSet traDataSet, Parameters parameters)
    {
        super(traDataSet, parameters);
        algorithmName="LVQTC";
        this.alpha_r = this.alpha_0;
        this.alpha_w = parameters.getNextAsDouble();
        this.retentionThreshold = parameters.getNextAsInt();
        this.epoches = parameters.getNextAsInt();
        posibleClasses = traDataSet.getPosibleValuesOfOutput();
        counter = new HashMap<Prototype, HashMap<Double, Integer>>();
        sumCounter = new HashMap<Prototype, Integer>();
        wrong = new HashMap<Prototype, PrototypeSet>();
    }

    /**
     * Constructor based on the training data set and explicit parameter values.
     * @param traDataSet Training data prototypes
     * @param it Number of iterations that will execute the algorithm (per epoch).
     * @param percProts New size of the set (% of training data set).
     * @param alpha_r Alpha algorithm parameter used when rewarding.
     * @param alpha_w Alpha algorithm parameter used when penalizing.
     * @param T Retention threshold of the algorithm.
     * @param epoches Number of epochs of the algorithm.
     */
    public LVQTC(PrototypeSet traDataSet, int it, double percProts, double alpha_r, double alpha_w, int T, int epoches)
    {
        super(traDataSet, it, percProts, alpha_r);
        algorithmName="LVQTC";
        this.alpha_r = alpha_r;
        this.alpha_w = alpha_w;
        this.retentionThreshold = T;
        this.epoches = epoches;
        posibleClasses = traDataSet.getPosibleValuesOfOutput();
        counter = new HashMap<Prototype, HashMap<Double, Integer>>();
        sumCounter = new HashMap<Prototype, Integer>();
        wrong = new HashMap<Prototype, PrototypeSet>();
    }

    /**
     * Initializes the counters of a prototype: one zeroed counter per
     * possible class and a zero running total.
     * @param i Prototype whose counters are (re)initialized.
     */
    protected void initCounterOf(Prototype i)
    {
        HashMap<Double, Integer> perClass = new HashMap<Double, Integer>();
        for (Double d : posibleClasses)
        {
            perClass.put(d, 0);
        }
        counter.put(i, perClass);
        // The total is maintained incrementally by incrementCounterOf, so it
        // always matches the per-class counters.
        sumCounter.put(i, 0);
    }

    /**
     * Clears the counters and the set of misclassified instances of every
     * prototype of the given set.
     * @param data Prototypes to be reset.
     */
    private void reset(PrototypeSet data)
    {
        for (Prototype p : data)
        {
            initCounterOf(p);
            wrong.put(p, new PrototypeSet());
        }
    }

    /**
     * Adds up all the values of a per-class counter.
     * @param v Per-class counter.
     * @return Sum of the counters of every class.
     */
    private int sum(HashMap<Double, Integer> v)
    {
        int acc = 0;
        for (Integer i : v.values())
        {
            acc += i;
        }
        return acc;
    }

    /**
     * Returns the number of times a prototype has been the winner since its
     * counters were last initialized.
     * <p>
     * The value is a running total kept up to date by
     * {@link #incrementCounterOf}; a previous version cached the first
     * computed sum and never refreshed it, so the step size never decayed.
     * @param p Prototype to be queried. Its counters must have been initialized.
     * @return Total of the per-class counters of p.
     */
    private int sumOfCounterOf(Prototype p)
    {
        Debug.force(sumCounter.containsKey(p), "ERROR en sumOfCounter");
        return sumCounter.get(p);
    }

    /**
     * Looks for the class, different from the class assigned to p, with the
     * highest counter strictly greater than the retention threshold.
     * @param p Prototype to be inspected.
     * @return Pair whose first element tells whether such a class exists and
     * whose second element is that class (the label of p when none exists).
     */
    private Pair<Boolean,Double> maximumWrongClassCounter(Prototype p)
    {
        HashMap<Double, Integer> h = counter.get(p);
        double classWrong = p.label();
        int max = retentionThreshold;
        boolean found = false;
        for (Map.Entry<Double, Integer> entry : h.entrySet())
        {
            Double klass = entry.getKey();
            int times = entry.getValue();
            // Unbox explicitly so that the classes are compared by value.
            if (klass.doubleValue() != p.assignedClass() && times > max)
            {
                classWrong = klass;
                max = times;
                found = true;
            }
        }
        return new Pair<Boolean,Double>(found,classWrong);
    }

    /**
     * Increment the counter of a prototype for a selected class, keeping the
     * running total of the prototype in sync.
     * @param i Prototype which class-ocurrences-counter will be modified. It should to be nearest prototype to the training prototype (in Kohonen's notation m_c).
     * @param _class Class which ocurrences will be incremented. It should to be the training prototype (in Kohonen's notation x).
     */
    private void incrementCounterOf(Prototype i, double _class)
    {
        Debug.force(counter.containsKey(i), "No contiene la clave");
        HashMap<Double, Integer> perClass = counter.get(i);
        int oldValue = perClass.get(_class);
        perClass.put(_class, oldValue+1);
        sumCounter.put(i, sumCounter.get(i) + 1);
    }

    /**
     * Applies the LVQTC reward to prototype m: moves it towards x by
     * alpha_r divided by the number of times m has been the winner.
     * @param m Rewarded prototype (nearest to x). IT IS MODIFIED.
     * @param x Original prototype.
     */
    @Override
    protected void reward(Prototype m, Prototype x)
    {
        int q_i = sumOfCounterOf(m);
        Debug.force(q_i>0,"CERAPIO en reward");
        m.set(m.add((x.sub(m)).mul(alpha_r/q_i)));
    }

    /**
     * Applies LVQTC penalization to prototype m: moves it away from x by
     * alpha_w divided by the number of times m has been the winner.
     * @param m Penalized prototype (nearest to x). IT IS MODIFIED.
     * @param x Original prototype.
     */
    @Override
    protected void penalize(Prototype m, Prototype x)
    {
        int q_i = sumOfCounterOf(m);
        Debug.force(q_i>0,"CERAPIO en penalize");
        m.set(m.sub((x.sub(m)).mul(alpha_w/q_i)));
    }

    /**
     * Records a training instance that prototype p won although it has a
     * different class. The stored instances are averaged later by
     * {@link #neuronCreation}.
     * @param p Winner prototype.
     * @param newWrong Training instance of a different class.
     */
    void updateCentroidOfWrongClass(Prototype p, Prototype newWrong)
    {
        // The stored set is modified in place, no need to put it back.
        wrong.get(p).add(newWrong);
    }

    /**
     * Corrects the instance using a particular method: the nearest prototype
     * gets its counter for the class of i incremented and is then rewarded
     * (same label) or penalized (different label).
     * @param i is a instance of the instance set.
     * @param tData is the training data set. IS MODIFIED.
     */
    @Override
    protected void correct(Prototype i, PrototypeSet tData)
    {
        Prototype nearest = KNN._1nn(i, tData);
        double i_label = i.label();
        // Count the win before computing the step, so the divisor is never 0.
        incrementCounterOf(nearest, i_label);
        double nearest_prot_label = nearest.label();
        if (i_label != nearest_prot_label)
        {
            penalize(nearest,i);
            updateCentroidOfWrongClass(nearest,i);
        }
        else
        {
            reward(nearest,i);
        }
    }

    /**
     * Removes the prototypes whose total counter is below the retention
     * threshold. If none would survive, the prototype with the highest
     * (positive) counter is kept instead, when there is one.
     * @param data Prototypes to be pruned. It is not modified.
     * @return New set with the surviving prototypes.
     */
    protected PrototypeSet neuronPruning(PrototypeSet data)
    {
        PrototypeSet edited = new PrototypeSet();
        Prototype pMC = null;
        int mc = 0;
        for (Prototype p : data)
        {
            int currentCounter = sum(counter.get(p));
            if (currentCounter >= retentionThreshold)
            {
                edited.add(p);
            }
            if (mc < currentCounter)
            {
                mc = currentCounter;
                pMC = p;
            }
        }
        // pMC stays null when data is empty or no prototype ever won; do not
        // insert a null prototype in that case.
        if (edited.size() == 0 && pMC != null)
        {
            edited.add(pMC);
        }
        return edited;
    }

    /**
     * Creates a new prototype for each prototype that has a wrong class
     * counted more times than the retention threshold. The new prototype is
     * the average of the misclassified instances recorded for that prototype,
     * labeled with the most counted wrong class.
     * @param data Current prototypes. IS MODIFIED (new prototypes are appended).
     * @return The same set received, including the new prototypes.
     */
    protected PrototypeSet neuronCreation(PrototypeSet data)
    {
        // Collected apart so that data is not modified while being iterated.
        PrototypeSet newPrototypes = new PrototypeSet();
        for (Prototype p : data)
        {
            Pair<Boolean, Double> isWrong = maximumWrongClassCounter(p);
            if (isWrong.first())
            {
                Prototype w = (wrong.get(p)).avg();
                w.setLabel(isWrong.second());
                newPrototypes.add(w);
            }
        }
        for (Prototype newP : newPrototypes)
        {
            data.add(newP);
        }
        return data;
    }

    /**
     * Executes one epoch: {@code iterations} times, extracts an instance of
     * the training data set and corrects the given set with it.
     * @param outputDataSet Prototypes to be corrected. IS MODIFIED.
     * @return The same set received.
     */
    protected PrototypeSet doEpoche(PrototypeSet outputDataSet)
    {
        for (int it = 0; it < iterations; ++it)
        {
            Prototype instance = extract(trainingDataSet);
            correct(instance, outputDataSet);
        }
        return outputDataSet;
    }

    /**
     * Execute the method and returns the output instance set
     * @return a instance set modified from the training instance set by a LVQ method
     */
    @Override
    public PrototypeSet reduceSet()
    {
        PrototypeSet outputDataSet = initDataSet();
        for (int e = 0; e < epoches; ++e)
        {
            // Counters must cover every current prototype, including the
            // ones created at the end of the previous epoch.
            reset(outputDataSet);
            outputDataSet = doEpoche(outputDataSet);
            outputDataSet = neuronPruning(outputDataSet);
            outputDataSet = neuronCreation(outputDataSet);
        }
        return outputDataSet;
    }

    /**
     * General main for all the prototoype generators
     * Arguments:
     * 0: Filename with the training data set to be condensed.
     * 1: Filename with the test data set.
     * 2: Seed of the random number generator.
     * 3: Number of iterations per epoch.
     * 4: Percentage of prototypes (0-100).
     * 5: alpha_r (0-1).
     * 6: alpha_w (0-1).
     * 7: Retention threshold.
     * 8: Number of epochs.
     * @param args Arguments of the main function.
     */
    public static void main(String[] args)
    {
        Parameters.setUse("LVQTC", "<seed> <iterations per epoch> <% of prots> <alpha_r> <alpha_w> <retention threshold> <number of epoches>");
        Parameters.assertBasicArgs(args);
        Debug.setStdDebugMode(false);

        PrototypeSet training = PrototypeGenerationAlgorithm.readPrototypeSet(args[0]);
        PrototypeSet test = PrototypeGenerationAlgorithm.readPrototypeSet(args[1]);

        long seed = Parameters.assertExtendedArgAsInt(args,2,"seed",0,Long.MAX_VALUE);
        int iter = Parameters.assertExtendedArgAsInt(args,3,"number of iterations per epoch", 1, Integer.MAX_VALUE);
        double pcProt = Parameters.assertExtendedArgAsDouble(args,4,"% of prototypes", 0, 100);
        double alphaR = Parameters.assertExtendedArgAsDouble(args,5,"alpha_r", 0, 1);
        double alphaW = Parameters.assertExtendedArgAsDouble(args,6,"alpha_w", 0, 1);
        int Q = Parameters.assertExtendedArgAsInt(args,7,"retention threshold (Q)", 1, Integer.MAX_VALUE);
        int epoches = Parameters.assertExtendedArgAsInt(args,8,"number of epoches of the algorithm",1,Integer.MAX_VALUE);

        LVQTC.setSeed(seed);
        LVQTC generator = new LVQTC(training, iter, pcProt, alphaR, alphaW, Q, epoches);

        PrototypeSet resultingSet = generator.execute();
        int accuracy1NN = KNN.classficationAccuracy(resultingSet, test);
        generator.showResultsOfAccuracy(Parameters.getFileName(), accuracy1NN, test);
    }
}