/* * SingleClassifierDrift.java * Copyright (C) 2008 University of Waikato, Hamilton, New Zealand * @author Manuel Baena (mbaena@lcc.uma.es) * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ package tr.gov.ulakbim.jDenetX.classifiers; import tr.gov.ulakbim.jDenetX.AbstractMOAObject; import tr.gov.ulakbim.jDenetX.core.Measurement; import tr.gov.ulakbim.jDenetX.core.SizeOf; import tr.gov.ulakbim.jDenetX.options.ClassOption; import tr.gov.ulakbim.jDenetX.options.MultiChoiceOption; import weka.core.Instance; import weka.core.Utils; /** * Class for handling concept drift datasets with a wrapper on a * classifier.<p>data * <p/> * Valid options are:<p> * <p/> * -l classname <br> * Specify the full class name of a classifier as the basis for * the concept drift classifier.<p> * * @author Manuel Baena (mbaena@lcc.uma.es) * @version 1.1 */ public class SingleClassifierDrift extends AbstractClassifier { public class DriftDetectionMethod extends AbstractMOAObject { private static final long serialVersionUID = 1L; public static final int DDM_INCONTROL_LEVEL = 0; public static final int DDM_WARNING_LEVEL = 1; public static final int DDM_OUTCONTROL_LEVEL = 2; public int computeNextVal(boolean prediction) { return 0; } //@Override public void getModelDescription(StringBuilder out, int indent) { } public void getDescription(StringBuilder sb, int indent) { // TODO Auto-generated method stub } } public class JGamaMethod extends DriftDetectionMethod { /** * */ private static final long serialVersionUID = -3518369648142099719L; private static final int JGAMAMETHOD_MINNUMINST = 30; private int m_n; private double m_p; private double m_s; private double m_psmin; private double m_pmin; private double m_smin; public JGamaMethod() { initialize(); } private void initialize() { m_n = 1; m_p = 1; m_s = 0; m_psmin = Double.MAX_VALUE; m_pmin = Double.MAX_VALUE; m_smin = Double.MAX_VALUE; } @Override public int computeNextVal(boolean prediction) { if (prediction == false) { m_p = m_p + (1.0 - m_p) / (double) m_n; } else { m_p = m_p - (m_p) / (double) m_n; } m_s = Math.sqrt(m_p * (1 - m_p) / (double) m_n); m_n++; //System.out.print(prediction + " " + m_n + " " + (m_p+m_s) + " "); if (m_n < JGAMAMETHOD_MINNUMINST) { return DDM_INCONTROL_LEVEL; } if (m_p + m_s <= m_psmin) { m_pmin = m_p; m_smin = m_s; m_psmin = m_p + m_s; } if (m_n > JGAMAMETHOD_MINNUMINST && m_p + m_s > m_pmin + 3 * m_smin) { initialize(); return DDM_OUTCONTROL_LEVEL; } else if (m_p + m_s > m_pmin + 2 * m_smin) { return DDM_WARNING_LEVEL; } else { return DDM_INCONTROL_LEVEL; } } } public class EDDM extends DriftDetectionMethod { /** * */ private static final long serialVersionUID = 140980267062162000L; private static final double FDDM_OUTCONTROL = 0.9; private static final double FDDM_WARNING = 0.95; private static final double FDDM_MINNUMINSTANCES = 30; private double m_numErrors; private int m_minNumErrors = 30; private int m_n; private int m_d; private int m_lastd; private double m_mean; private double m_stdTemp; private double m_m2smax; private int m_lastLevel; public EDDM() { initialize(); } private void initialize() { m_n = 1; m_numErrors = 0; m_d = 0; m_lastd = 0; m_mean = 0.0; m_stdTemp = 0.0; m_m2smax = 0.0; m_lastLevel = DDM_INCONTROL_LEVEL; } @Override public int computeNextVal(boolean prediction) { //System.out.print(prediction + " " + m_n + " " + probability + " "); m_n++; if (prediction == false) { m_numErrors += 1; m_lastd = m_d; m_d = m_n - 1; int distance = m_d - m_lastd; double oldmean = m_mean; m_mean = m_mean + ((double) distance - m_mean) / m_numErrors; m_stdTemp = m_stdTemp + (distance - m_mean) * (distance - oldmean); double std = Math.sqrt(m_stdTemp / m_numErrors); double m2s = m_mean + 2 * std; //System.out.print(m_numErrors + " " + m_mean + " " + std + " " + m2s + " " + m_m2smax + " "); if (m2s > m_m2smax) { if (m_n > FDDM_MINNUMINSTANCES) { m_m2smax = m2s; } m_lastLevel = DDM_INCONTROL_LEVEL; //System.out.print(1 + " "); } else { double p = m2s / m_m2smax; //System.out.print(p + " "); if (m_n > FDDM_MINNUMINSTANCES && m_numErrors > m_minNumErrors && p < FDDM_OUTCONTROL) { initialize(); return DDM_OUTCONTROL_LEVEL; } else if (m_n > FDDM_MINNUMINSTANCES && m_numErrors > m_minNumErrors && p < FDDM_WARNING) { m_lastLevel = DDM_WARNING_LEVEL; return DDM_WARNING_LEVEL; } else { m_lastLevel = DDM_INCONTROL_LEVEL; return DDM_INCONTROL_LEVEL; } } } else { //System.out.print(m_numErrors + " " + m_mean + " " + Math.sqrt(m_stdTemp/m_numErrors) + " " + (m_mean + 2*Math.sqrt(m_stdTemp/m_numErrors)) + " " + m_m2smax + " "); //System.out.print(((m_mean + 2*Math.sqrt(m_stdTemp/m_numErrors))/m_m2smax) + " "); } return m_lastLevel; } } private static final long serialVersionUID = 1L; public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', "Classifier to train.", Classifier.class, "NaiveBayes"); public MultiChoiceOption driftDetectionMethodOption = new MultiChoiceOption( "driftDetectionMethod", 'd', "Drift detection method to use.", new String[]{ "DDM", "EDDM"}, new String[]{ "DDM: Joao Gama Drift Detection Method", "EDDM: Early Drift Detection Method"}, 0); public Classifier classifier; protected Classifier newclassifier; protected DriftDetectionMethod driftDetectionMethod; protected boolean newClassifierReset; @Override public int measureByteSize() { int size = (int) SizeOf.sizeOf(this); size += classifier.measureByteSize(); size += newclassifier.measureByteSize(); return size; } @Override public void resetLearningImpl() { this.classifier = (Classifier) getPreparedClassOption(this.baseLearnerOption); this.newclassifier = (Classifier) getPreparedClassOption(this.baseLearnerOption); this.classifier.resetLearning(); this.newclassifier.resetLearning(); this.driftDetectionMethod = newDriftDetectionMethod(); newClassifierReset = false; } @Override public void trainOnInstanceImpl(Instance inst) { int trueClass = (int) inst.classValue(); boolean prediction; if (Utils.maxIndex(this.classifier.getVotesForInstance(inst)) == trueClass) { prediction = true; } else { prediction = false; } switch (this.driftDetectionMethod.computeNextVal(prediction)) { case DriftDetectionMethod.DDM_WARNING_LEVEL: //System.out.println("1 0 W"); if (newClassifierReset == true) { this.newclassifier.resetLearning(); newClassifierReset = false; } this.newclassifier.trainOnInstance(inst); break; case DriftDetectionMethod.DDM_OUTCONTROL_LEVEL: //System.out.println("0 1 O"); this.classifier = null; this.classifier = this.newclassifier; if (this.classifier instanceof WEKAClassifier) { ((WEKAClassifier) this.classifier).buildClassifier(); } this.newclassifier = (Classifier) getPreparedClassOption(this.baseLearnerOption); this.newclassifier.resetLearning(); break; case DriftDetectionMethod.DDM_INCONTROL_LEVEL: //System.out.println("0 0 I"); newClassifierReset = true; break; default: //System.out.println("ERROR!"); } this.classifier.trainOnInstance(inst); } public double[] getVotesForInstance(Instance inst) { return this.classifier.getVotesForInstance(inst); } public boolean isRandomizable() { return true; } @Override public void getModelDescription(StringBuilder out, int indent) { ((AbstractClassifier) this.classifier).getModelDescription(out, indent); } @Override protected Measurement[] getModelMeasurementsImpl() { return ((AbstractClassifier) this.classifier).getModelMeasurementsImpl(); } protected DriftDetectionMethod newDriftDetectionMethod() { switch (this.driftDetectionMethodOption.getChosenIndex()) { case 0: return new JGamaMethod(); case 1: return new EDDM(); default: break; } return new DriftDetectionMethod(); } }