package com.cse10.classifier;

import com.cse10.article.TrainingArticle;
import com.cse10.database.DatabaseHandler;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import org.apache.log4j.Logger;

import java.util.List;

/**
 * Data handler that loads the complete training data set.
 * Created by Chamath on 12/18/2014.
 */
public class GenericDataHandler extends DataHandler {

    /** Logger named after the runtime class of this handler; assigned once in the constructor. */
    private final Logger log;

    /**
     * Creates the handler, setting {@code fileName} to {@code "generic"}.
     * NOTE(review): {@code fileName} is not declared in this class; presumably it is
     * inherited from {@code DataHandler} -- confirm there how it is used.
     */
    public GenericDataHandler() {
        fileName = "generic";
        log = Logger.getLogger(this.getClass());
    }

    /**
     * Logs a one-line description of this handler at INFO level.
     *
     * @return the description text that was logged
     */
    @Override
    protected String printDescription() {
        String description = "Data Handler -> This Data Handler will Load all of the Training Data";
        log.info(description);
        return description;
    }

    /**
     * Fetches every training article through {@link DatabaseHandler#fetchTrainingArticles()}
     * and converts the articles into a data set named {@code "TrainingNews"}.
     * <p>
     * The data set has two attributes: a string attribute {@code "text"} holding the
     * article content, and a nominal attribute {@code "@@class@@"} with the values
     * {@code "crime"} and {@code "other"} holding the article label. The nominal
     * attribute is the last attribute and is set as the class attribute.
     * <p>
     * NOTE(review): labels are presumably always one of the two declared nominal
     * values; confirm how other labels should be treated.
     *
     * @param featureVectorTransformer not used by this implementation
     * @return the data set containing one instance per fetched training article
     */
    public Instances loadTrainingData(FeatureVectorTransformer featureVectorTransformer) {
        printDescription();

        // Load the articles first so the data set can be created with a matching capacity.
        List<TrainingArticle> trainingArticles = DatabaseHandler.fetchTrainingArticles();

        // News text attribute (string-valued).
        Attribute content = new Attribute("text", (FastVector) null);

        // Class attribute (nominal) with its two permitted values.
        FastVector classVal = new FastVector();
        classVal.addElement("crime");
        classVal.addElement("other");
        Attribute classValue = new Attribute("@@class@@", classVal);

        // Attribute order matters: the class attribute is added last.
        FastVector attributeList = new FastVector(2);
        attributeList.addElement(content);
        attributeList.addElement(classValue);

        Instances trainingData = new Instances("TrainingNews", attributeList, trainingArticles.size());
        // Set the class index once, up front; it points at the last attribute.
        trainingData.setClassIndex(trainingData.numAttributes() - 1);

        for (TrainingArticle trainingArticle : trainingArticles) {
            Instance inst = new Instance(trainingData.numAttributes());
            inst.setValue(content, trainingArticle.getContent());
            inst.setValue(classValue, trainingArticle.getLabel());
            inst.setDataset(trainingData);
            trainingData.add(inst);
        }

        return trainingData;
    }
}