package edu.usc.cssl.tacit.classify.naivebayes.services; import java.io.File; import java.io.IOException; import java.util.ArrayList; import java.util.Date; import java.util.HashMap; import java.util.List; import org.apache.commons.io.FileUtils; import org.eclipse.core.runtime.IProgressMonitor; import bsh.EvalError; import edu.usc.cssl.tacit.common.ui.views.ConsoleView; public class CrossValidator { public HashMap<Integer, String> doCross(NaiveBayesClassifier nbc, HashMap<String, List<String>> classPaths, int kValue, IProgressMonitor monitor, String outputDir, Date dateObj) throws IOException, EvalError{ HashMap<Integer, String> performance = new HashMap<Integer, String>(); int[] index = new int[classPaths.size()]; String tmpLocation = nbc.getTmpLocation(); if(!new File(tmpLocation).exists()) { new File(tmpLocation).mkdir(); } String trainDir = tmpLocation + File.separator + "Train"; if(!new File(trainDir).exists()) { new File(trainDir).mkdir(); } String testDir = tmpLocation + File.separator + "Test"; if(!new File(testDir).exists()) { new File(testDir).mkdir(); } ConsoleView.printlInConsoleln("---------- Cross Validation Starts ------------"); for (int i=1; i<=kValue; i++) { if(monitor.isCanceled()) { monitor.subTask("Cancelling.. "); return null; } int count = 0; ConsoleView.printlInConsoleln ("------ Fold "+ i +"------"); ArrayList<String> trainingDataPaths = new ArrayList<String>(); ArrayList<String> testingDataPaths = new ArrayList<String>(); for(String path : classPaths.keySet()) { List<String> selectedFiles = classPaths.get(path); int numFiles = classPaths.get(path).size(); int trainingSetSize = (int)Math.floor(0.90 * numFiles); int testingSetSize = numFiles - trainingSetSize; File[] trainFiles = new File[trainingSetSize]; File[] testFiles = new File[testingSetSize]; int currIndex = index[count]; String className = new File(path).getName(); String tempTrainDir = trainDir + File.separator + className; System.out.println("Training data dir :"+ tempTrainDir); if(new File(tempTrainDir).exists()) { nbc.purgeDirectory(new File(tempTrainDir)); } new File(tempTrainDir).mkdir(); for (int num = 0; num < trainingSetSize; num++) { trainFiles[num]= new File(selectedFiles.get(currIndex)); //new File(trainFiles[num].getAbsolutePath(), new File(tempTrainDir + File.separator + trainFiles[num].getName()).getAbsolutePath()); FileUtils.copyFileToDirectory(trainFiles[num], new File(tempTrainDir)); //Files.copy(trainFiles[num].toPath(), new File(tempTrainDir + File.separator + trainFiles[num].getName()).toPath(), new CopyOption[] { REPLACE_EXISTING }); currIndex++; if(currIndex >= numFiles) currIndex = 0; } String tempTestDir = testDir + File.separator + className; System.out.println("Testing data dir :"+ tempTestDir); if(new File(tempTestDir).exists()) { nbc.purgeDirectory(new File(tempTestDir)); } new File(tempTestDir).mkdir(); for (int num = 0; num < testingSetSize; num++) { testFiles[num] = new File(selectedFiles.get(currIndex)); //new File(testFiles[num].getAbsolutePath(), new File(tempTestDir + File.separator + testFiles[num].getName()).getAbsolutePath()); FileUtils.copyFileToDirectory(testFiles[num], new File(tempTestDir)); //Files.copy(testFiles[num].toPath(), new File(tempTestDir+ File.separator + testFiles[num].getName()).toPath(), new CopyOption[] { REPLACE_EXISTING }); currIndex++; if(currIndex >= numFiles) currIndex=0; } //set training and testing paths trainingDataPaths.add(tempTrainDir); testingDataPaths.add(tempTestDir); // Clear required globals like dfmap? index[count] = index[count] + numFiles-trainingSetSize; if (index[count] >= numFiles){ index[count] = index[count] - numFiles; } count++; } System.out.println("Training data paths .."); for(String s : trainingDataPaths) System.out.println(s); System.out.println("Testing data paths .."); for(String s : testingDataPaths) System.out.println(s); if(monitor.isCanceled()) { monitor.subTask("Cancelling.. "); return null; } // Perform classification String result = nbc.predict(trainingDataPaths, testingDataPaths, outputDir, false, false,dateObj); performance.put(i, result); monitor.worked(7); // for each trial if(monitor.isCanceled()) { monitor.subTask("Cancelling.. "); return null; } } ConsoleView.printlInConsoleln("---------- Cross Validation Finished ------------"); nbc.purgeDirectory(new File(trainDir)); nbc.purgeDirectory(new File(testDir)); return performance; } }