/* * WEKAClassifier.java * Copyright (C) 2009 University of Waikato, Hamilton, New Zealand * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz) * @author FracPete (fracpete at waikato dot ac dot nz) * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ package tr.gov.ulakbim.jDenetX.classifiers; import tr.gov.ulakbim.jDenetX.core.Measurement; import tr.gov.ulakbim.jDenetX.core.SizeOf; import tr.gov.ulakbim.jDenetX.options.IntOption; import tr.gov.ulakbim.jDenetX.options.WEKAClassOption; import weka.classifiers.AbstractClassifier; import weka.classifiers.Classifier; import weka.classifiers.UpdateableClassifier; import weka.core.Instance; import weka.core.Instances; public class WEKAClassifier extends AbstractClassifier { private static final long serialVersionUID = 1L; public WEKAClassOption baseLearnerOption = new WEKAClassOption("baseLearner", 'l', "Classifier to train.", weka.classifiers.Classifier.class, "weka.classifiers.bayes.NaiveBayesUpdateable"); public IntOption widthOption = new IntOption("width", 'w', "Size of Window for training learner.", 0, 0, Integer.MAX_VALUE); public IntOption widthInitOption = new IntOption("widthInit", 'i', "Size of first Window for training learner.", 1000, 0, Integer.MAX_VALUE); public IntOption sampleFrequencyOption = new IntOption("sampleFrequency", 'f', "How many instances between samples of the learning performance.", 0, 0, Integer.MAX_VALUE); protected Classifier classifier; protected int numberInstances; protected Instances instancesBuffer; protected boolean isClassificationEnabled; protected boolean isBufferStoring; public int measureByteSize() { int size = (int) SizeOf.sizeOf(this); //size += classifier.measureByteSize(); return size; } public void resetLearningImpl() { try { //System.out.println(baseLearnerOption.getValue()); String[] options = weka.core.Utils.splitOptions(baseLearnerOption.getValueAsCLIString()); createWekaClassifier(options); } catch (Exception e) { System.err.println("Creating a new classifier: " + e.getMessage()); } numberInstances = 0; isClassificationEnabled = false; this.isBufferStoring = true; } public void trainOnInstanceImpl(Instance inst) { try { if (numberInstances == 0) { this.instancesBuffer = new Instances(inst.dataset()); if (classifier instanceof UpdateableClassifier) { classifier.buildClassifier(instancesBuffer); this.isClassificationEnabled = true; } else { this.isBufferStoring = true; } } numberInstances++; if (classifier instanceof UpdateableClassifier) { if (numberInstances > 0) { ((UpdateableClassifier) classifier).updateClassifier(inst); } } else { if (numberInstances == widthInitOption.getValue()) { //Build first time Classifier buildClassifier(); isClassificationEnabled = true; //Continue to store instances if (sampleFrequencyOption.getValue() != 0) { isBufferStoring = true; } } if (widthOption.getValue() == 0) { //Used from SingleClassifierDrift if (isBufferStoring == true) { instancesBuffer.add(inst); } } else { //Used form WekaClassifier without using SingleClassifierDrift int numInstances = numberInstances % sampleFrequencyOption.getValue(); if (sampleFrequencyOption.getValue() == 0) { numInstances = numberInstances; } if (numInstances == 0) { //Begin to store instances isBufferStoring = true; } if (isBufferStoring == true && numInstances <= widthOption.getValue()) { //Store instances instancesBuffer.add(inst); } if (numInstances == widthOption.getValue()) { //Build Classifier buildClassifier(); isClassificationEnabled = true; this.instancesBuffer = new Instances(inst.dataset()); } } } } catch (Exception e) { System.err.println("Training: " + e.getMessage()); } } public void buildClassifier() { try { if ((classifier instanceof UpdateableClassifier) == false) { Classifier auxclassifier = AbstractClassifier.makeCopy(classifier); auxclassifier.buildClassifier(instancesBuffer); classifier = auxclassifier; isBufferStoring = false; } } catch (Exception e) { System.err.println("Building WEKA Classifier: " + e.getMessage()); } } public double[] getVotesForInstance(Instance inst) { double[] votes = new double[inst.numClasses()]; if (isClassificationEnabled == false) { for (int i = 0; i < inst.numClasses(); i++) { votes[i] = 1.0 / inst.numClasses(); } } else { try { votes = this.classifier.distributionForInstance(inst); } catch (Exception e) { System.err.println(e.getMessage()); } } return votes; } public boolean isRandomizable() { return false; } public void getModelDescription(StringBuilder out, int indent) { if (classifier != null) { out.append(classifier.toString()); } } protected Measurement[] getModelMeasurementsImpl() { Measurement[] m = new Measurement[0]; return m; } public void createWekaClassifier(String[] options) throws Exception { String classifierName = options[0]; String[] newoptions = options.clone(); newoptions[0] = ""; this.classifier = AbstractClassifier.forName(classifierName, newoptions); } @Override public void buildClassifier(Instances data) throws Exception { //To change body of implemented methods use File | Settings | File Templates. } }