package com.cse10.classifier; import weka.classifiers.Classifier; import weka.classifiers.Evaluation; import weka.classifiers.RandomizableIteratedSingleClassifierEnhancer; import weka.core.Instance; import weka.core.Instances; import weka.core.SerializationHelper; import org.apache.log4j.Logger; import java.util.Random; /** * super class for ensemble classifiers * Created by Chamath on 2/3/2015. */ public abstract class EnsembleClassifierHandler extends ClassifierHandler { protected RandomizableIteratedSingleClassifierEnhancer randomizableIteratedSingleClassifierEnhancer; private Logger log; public EnsembleClassifierHandler() { } /** * @param numOfIterations * @param classifier */ public void configure(int numOfIterations, Classifier classifier) { randomizableIteratedSingleClassifierEnhancer.setNumIterations(numOfIterations); randomizableIteratedSingleClassifierEnhancer.setClassifier(classifier); } /** * @param filteredTrainingData * @param numOfFolds * @return */ public Evaluation crossValidateClassifier(Instances filteredTrainingData, int numOfFolds) { //perform cross validation Evaluation evaluation = null; try { evaluation = new Evaluation(filteredTrainingData); evaluation.crossValidateModel(randomizableIteratedSingleClassifierEnhancer, filteredTrainingData, numOfFolds, new Random(1)); log.info(evaluation.toSummaryString()); log.info(evaluation.weightedAreaUnderROC()); double[][] confusionMatrix = evaluation.confusionMatrix(); for (int i = 0; i < 2; i++) { for (int j = 0; j < 2; j++) { log.info(confusionMatrix[i][j] + " "); } log.info("\n"); } log.info("accuracy for crime class= " + (confusionMatrix[0][0] / (confusionMatrix[0][1] + confusionMatrix[0][0])) * 100 + "%"); log.info("accuracy for other class= " + (confusionMatrix[1][1] / (confusionMatrix[1][1] + confusionMatrix[1][0])) * 100 + "%"); } catch (Exception e) { e.printStackTrace(); } return evaluation; } /** * build ensemble classifier with given training data and save it * * @param filteredTrainingData * @param isSaving * @return */ public void buildEnsemble(Instances filteredTrainingData, boolean isSaving) { try { randomizableIteratedSingleClassifierEnhancer.buildClassifier(filteredTrainingData); //save classifier if (isSaving) { SerializationHelper.write("Classifier\\src\\main\\resources\\models\\adaBoost.model", randomizableIteratedSingleClassifierEnhancer); } } catch (Exception e) { e.printStackTrace(); } } /** * classify the given news article * * @param filteredTestInstance * @return double */ public double classifyInstance(Instance filteredTestInstance) { double result = -1.0; try { result = randomizableIteratedSingleClassifierEnhancer.classifyInstance(filteredTestInstance); } catch (Exception e) { e.printStackTrace(); } return result; } public RandomizableIteratedSingleClassifierEnhancer getModel() { return randomizableIteratedSingleClassifierEnhancer; } }