package com.cse10.classifier; import com.cse10.article.*; import com.cse10.database.DatabaseConstants; import com.cse10.util.ArticleConverter; import weka.core.Instances; import org.apache.log4j.Logger; import java.util.*; /** * Combine all of the functionality, used by ui handlers * Created by Chamath on 12/20/2014 */ public class ClassifierConfigurator extends Observable { private DataHandler dataHandler; private Instances trainingData; private Instances filteredTrainingData; private SVMClassifierHandler svmClassifierHandler; private GridSearch gridSearch; private boolean isModelBuild; // no need if we use DataHandlerWithSampling data handler, it converts training data into feature vector private FeatureVectorTransformer featureVectorTransformer; //singleton private static ClassifierConfigurator classifierConfigurator; private Logger log; private ClassifierConfigurator() { dataHandler = new GenericDataHandler(); gridSearch = new GridSearch(); svmClassifierHandler = new SVMClassifierHandler(); featureVectorTransformer = new FeatureVectorTransformer(); isModelBuild = false; log = Logger.getLogger(this.getClass()); } /** * to get singleton instance * * @return */ public synchronized static ClassifierConfigurator getInstance() { if (classifierConfigurator == null) { classifierConfigurator = new ClassifierConfigurator(); } return classifierConfigurator; } /** * load training data from the database */ private void loadTrainingData() throws InterruptedException { //check if interrupted checkInterruption(); try { log.info(Thread.currentThread().getName() + " Classifier UI Handler -> Start Data Loading"); trainingData = dataHandler.loadTrainingData(featureVectorTransformer); log.info(Thread.currentThread().getName() + " Classifier UI Handler -> Training Data Details"); log.info(Thread.currentThread().getName() + " Number of Articles= " + trainingData.numInstances()); int crimeCount = 0; int otherCount = 0; for (int i = 0; i < trainingData.numInstances(); i++) { if (trainingData.instance(i).classValue() == 0.0) crimeCount++; else otherCount++; } log.info(Thread.currentThread().getName() + " Number of Crime Articles= " + crimeCount); log.info(Thread.currentThread().getName() + " Number of Other Articles= " + otherCount); log.info(Thread.currentThread().getName() + " Classifier UI Handler -> End of Data Loading"); } catch (Exception e) { e.printStackTrace(); } } /** * filter data, transform training articles to feature vectors * If data handler return filtered training data ( i.e. in DataHandlerWithSampling) then no need to * filter data. */ private void filterData() throws InterruptedException { //check if interrupted checkInterruption(); log.info(Thread.currentThread().getName() + " Classifier UI Handler -> Start Data Filtering"); featureVectorTransformer.configure(1, 1, true); featureVectorTransformer.setInputFormat(trainingData); filteredTrainingData = featureVectorTransformer.getTransformedArticles(trainingData, dataHandler.getFileName()); log.info(Thread.currentThread().getName() + " Classifier UI Handler -> End of Data Filtering"); } /** * perform grid search to find best cost and gamma values */ private void performGridSearch() throws InterruptedException { //check if interrupted checkInterruption(); log.info(Thread.currentThread().getName() + " Classifier UI Handler -> Start Grid Search"); gridSearch.gridSearch(svmClassifierHandler.getSvm(), filteredTrainingData); log.info(Thread.currentThread().getName() + " Classifier UI Handler -> End of Grid Search"); } /** * cross validate the svm model using * if we use different weights, we need to normalize data */ private void crossValidateModel() throws InterruptedException { //check if interrupted checkInterruption(); log.info(Thread.currentThread().getName() + " Classifier UI Handler -> Start cross validation"); // using parameters found during the grid search svmClassifierHandler.configure(8.0, 0.001953125, "10 1", true); svmClassifierHandler.crossValidateClassifier(filteredTrainingData, 10); log.info(Thread.currentThread().getName() + " Classifier UI Handler -> End of cross validation"); } /** * buildClassifier the model using training data and save model */ private void buildModel() throws InterruptedException { //check if interrupted checkInterruption(); log.info(Thread.currentThread().getName() + " Classifier UI Handler -> Start building model"); svmClassifierHandler.buildSVM(filteredTrainingData, true); log.info(Thread.currentThread().getName() + " Classifier UI Handler -> End of building model"); } /** * load training data, filter data and perform grid search, cross validate model,buildClassifier and save * model, this function is used by GUI */ private synchronized void buildClassifier(Class tableName) throws InterruptedException { log.info("\n--------------------------------------------------------------"); int progress = 0; //if interrupted checkInterruption(); if (!isModelBuild) { log.info(Thread.currentThread().getName() + " Classifier UI Handler -> Building Classifier"); loadTrainingData(); progress = 20; notify(new DatabaseConstants().classToTableName.get(tableName), progress); //check if interrupted checkInterruption(); //check whether feature vector transform is required if (dataHandler.isFeatureVectorTransformerRequired()) { filterData(); } else { filteredTrainingData = trainingData; } progress = 40; notify(new DatabaseConstants().classToTableName.get(tableName), progress); //check if interrupted checkInterruption(); //only perform this if we need to change cost and gamma values //performGridSearch(); crossValidateModel(); progress = 60; notify(new DatabaseConstants().classToTableName.get(tableName), progress); //check if interrupted checkInterruption(); buildModel(); progress = 80; notify(new DatabaseConstants().classToTableName.get(tableName), progress); //check if interrupted checkInterruption(); isModelBuild = true; //check if interrupted checkInterruption(); log.info(Thread.currentThread().getName() + " Classifier UI Handler -> End of Building Classifier"); } else { progress = 80; notify(new DatabaseConstants().classToTableName.get(tableName), progress); log.info(Thread.currentThread().getName() + " Classifier UI Handler -> Classifier is Already Existing"); } log.info("---------------------------------------------------------------------------"); } /** * classify news articles * * @param tableName */ private synchronized void classifyNewsArticles(Class tableName, Date endDate) throws InterruptedException { //check if interrupted checkInterruption(); log.info(Thread.currentThread().getName() + "------------------------------------------------------------------------"); log.info(Thread.currentThread().getName() + " Classifier UI Handler -> Start Classifying Articles"); log.info(Thread.currentThread().getName() + " Classifier UI Handler -> Start Loading Test Data"); //convert util.Date to sql.Date log.info(Thread.currentThread().getName() + " Classifier UI Handler ->" + endDate); java.sql.Date sqlEndDate = new java.sql.Date(endDate.getTime()); //get only unclassified data using weka loading log.info(Thread.currentThread().getName() + " Classifier UI Handler ->" + sqlEndDate); Instances testData = dataHandler.loadTestData(tableName, "WHERE label IS NULL and `created_date` <= '" + sqlEndDate + "'", true); //`created_date`<'2013-06-01' log.info(Thread.currentThread().getName() + " Classifier UI Handler -> Finish Loading Test Data"); log.info(Thread.currentThread().getName() + " Classifier UI Handler -> Size of Test Data= " + testData.numInstances()); //check if interrupted checkInterruption(); if (testData.numInstances() != 0) { List<Article> testDataArticles = dataHandler.fetchArticlesWithNullLabels(tableName, endDate); HashMap<Integer, Integer> articleIds = dataHandler.getArticleIds(); Instances filteredTestData = featureVectorTransformer.getTransformedArticles(testData); List<Integer> crimeArticleIdList = new ArrayList<Integer>(); for (int instNumber = 0; instNumber < filteredTestData.numInstances(); instNumber++) { double category = svmClassifierHandler.classifyInstance(filteredTestData.instance(instNumber)); if (category == 0) { // if crime crimeArticleIdList.add(articleIds.get(instNumber)); } } //check if interrupted checkInterruption(); log.info(Thread.currentThread().getName() + " Classifier UI Handler -> size of ID list = " + crimeArticleIdList.size()); ListIterator iter; List<Article> articles = dataHandler.fetchArticlesByIdList(tableName, crimeArticleIdList); log.info(Thread.currentThread().getName() + " Classifier UI Handler -> size of article list= " + articles.size()); log.info(Thread.currentThread().getName() + " Classifier UI Handler -> Article Titles"); log.info(" {"); iter = articles.listIterator(); while (iter.hasNext()) { Article a = (Article) iter.next(); log.info(" " + a.getTitle()); } log.info(" }"); //check if interrupted checkInterruption(); // to prepare them as crime articles List<CrimeArticle> crimeArticles = ArticleConverter.convertToCrimeArticle(articles, tableName); iter = crimeArticles.listIterator(); log.info(Thread.currentThread().getName() + " Classifier UI Handler -> size of crime article list= " + crimeArticles.size()); log.info(Thread.currentThread().getName() + " Classifier UI Handler -> Crime Article Titles"); log.info(" {"); while (iter.hasNext()) { Article a = (Article) iter.next(); log.info(" " + a.getTitle()); } log.info(" }"); checkInterruption(); //transaction log.info(Thread.currentThread().getName() + " Classifier UI Handler -> Transaction"); for (Article article : testDataArticles) { if (crimeArticleIdList.contains(article.getId())) { article.setLabel("crime"); for (CrimeArticle crimeArticle : crimeArticles) { if (article.getId() == crimeArticle.getNewspaperId()) { dataHandler.insertCrimeArticleAndUpdatePprArticle(crimeArticle, article); } } checkInterruption(); } else { article.setLabel("other"); dataHandler.updateArticle(article); checkInterruption(); } } } else { log.info(Thread.currentThread().getName() + " Classifier UI Handler -> No New Articles to Classify"); } log.info(Thread.currentThread().getName() + " Classifier UI Handler -> End of Classifying Articles"); int progress = 100; notify(new DatabaseConstants().classToTableName.get(tableName), progress); //to finish hybernate session and close database. other wise JVM will run continuously dataHandler.closeDatabase(); log.info(Thread.currentThread().getName() + "---------------------------------------------------------------------------------"); } /** * start classification process */ public void startClassification(Class tableName, Date endDate) { log.info(Thread.currentThread().getName() + "Classifier UI Handler -> Start Classification"); try { buildClassifier(tableName); classifyNewsArticles(tableName, endDate); } catch (InterruptedException e) { System.out.println("#############"); } } /** * stop classification process */ public void stopClassification() { log.info(Thread.currentThread().getName() + "Classifier UI Handler -> Stop Classification"); } /** * helper function to handle interruption * * @return */ private void checkInterruption() throws InterruptedException { if (Thread.currentThread().isInterrupted()) { log.info(Thread.currentThread().getName() + " Classifier UI Handler -> Interrupted "); dataHandler.closeDatabase(); throw new InterruptedException(); } } /** * helper function to notify observers * * @param name * @param progress */ private void notify(String name, int progress) throws InterruptedException { checkInterruption(); setChanged(); notifyObservers(name + " " + Integer.toString(progress)); } }