package ml.humaning.algorithm; import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.File; import java.io.FileReader; import java.io.FileWriter; import java.io.IOException; import java.util.ArrayList; import ml.humaning.util.Point; import ml.humaning.util.Reader; import weka.core.Instances; import weka.core.converters.LibSVMLoader; import weka.filters.Filter; import weka.filters.unsupervised.attribute.NumericToNominal; public class SMO { Instances data; ArrayList <weka.classifiers.functions.SMO> smoList; public void train(String trainFile) throws Exception{ Point [] allData = Reader.readPoints(trainFile); int maxDimension = Reader.getMaxDimension(allData); for(int maskRegion = 1;maskRegion <= 4;maskRegion++){// 4 mask regions String maskTrainFile = "mask"+maskRegion+"_smo"; BufferedWriter trainFileWriter = new BufferedWriter(new FileWriter(maskTrainFile)); for(Point p : allData){ p.setMaskRegion(maskRegion); trainFileWriter.write(p.toLIBSVMString(maxDimension)+"\n"); } trainFileWriter.close(); } smoList = new ArrayList <weka.classifiers.functions.SMO>(); LibSVMLoader libsvmLoader = new LibSVMLoader(); libsvmLoader.setSource(new File(trainFile)); data = libsvmLoader.getDataSet(); NumericToNominal filter = new NumericToNominal(); filter.setInputFormat(data); data = Filter.useFilter(data, filter); for(int maskRegion = 1;maskRegion <= 4;maskRegion++){// 4 mask regions LibSVMLoader tempLibsvmLoader = new LibSVMLoader(); tempLibsvmLoader.setSource(new File("mask"+maskRegion+"_smo")); Instances tempData = tempLibsvmLoader.getDataSet(); NumericToNominal tempFilter = new NumericToNominal(); tempFilter.setInputFormat(tempData); tempData = Filter.useFilter(tempData, tempFilter); weka.classifiers.functions.SMO tempSmo = new weka.classifiers.functions.SMO(); tempSmo.buildClassifier(tempData); smoList.add(tempSmo); } } public void predict(String testFile, String outputFile) throws Exception{ BufferedReader testReader = new BufferedReader(new FileReader(testFile)); String testPointsWithMask1 = "mask1.in"; String testPointsWithMask2 = "mask2.in"; String testPointsWithMask3 = "mask3.in"; String testPointsWithMask4 = "mask4.in"; ArrayList <BufferedWriter> maskPointsWriter = new ArrayList <BufferedWriter>(); maskPointsWriter.add(new BufferedWriter(new FileWriter(testPointsWithMask1))); maskPointsWriter.add(new BufferedWriter(new FileWriter(testPointsWithMask2))); maskPointsWriter.add(new BufferedWriter(new FileWriter(testPointsWithMask3))); maskPointsWriter.add(new BufferedWriter(new FileWriter(testPointsWithMask4))); ArrayList <ArrayList <Integer> > lineMapping = new ArrayList <ArrayList <Integer> >(); lineMapping.add(new ArrayList <Integer>()); lineMapping.add(new ArrayList <Integer>()); lineMapping.add(new ArrayList <Integer>()); lineMapping.add(new ArrayList <Integer>()); int lineNumber = 0; String line = null; while((line = testReader.readLine()) != null){ Point p = new Point(line); int maskRegion = p.getEmptyRegion(); p.setMaskRegion(maskRegion); maskPointsWriter.get(maskRegion-1).write(line+"\n"); lineMapping.get(maskRegion-1).add(lineNumber); lineNumber++; } testReader.close(); for(BufferedWriter bw : maskPointsWriter){ bw.close(); } String mask1Output = "mask1.out"; String mask2Output = "mask2.out"; String mask3Output = "mask3.out"; String mask4Output = "mask4.out"; for(int maskRegion = 1;maskRegion <= 4; maskRegion++){ LibSVMLoader libsvmLoader = new LibSVMLoader(); libsvmLoader.setSource(new File("mask"+maskRegion+".in")); Instances test = libsvmLoader.getDataSet(); NumericToNominal filter = new NumericToNominal(); filter.setInputFormat(test); test = Filter.useFilter(test, filter); BufferedWriter bw = new BufferedWriter(new FileWriter("mask"+maskRegion+".out")); for (int i = 0; i < test.numInstances(); i++) { double pred = smoList.get(maskRegion-1).classifyInstance(test.instance(i)); bw.write(data.classAttribute().value((int)pred)+"\n"); } bw.close(); } merge(lineMapping, mask1Output, mask2Output, mask3Output, mask4Output, outputFile); } private void merge(ArrayList <ArrayList <Integer> > lineMapping,String mask1Input, String mask2Input, String mask3Input, String mask4Input,String outputFile) throws IOException{ BufferedWriter writer = new BufferedWriter(new FileWriter(outputFile)); BufferedReader mask1Reader = new BufferedReader(new FileReader(mask1Input)); BufferedReader mask2Reader = new BufferedReader(new FileReader(mask2Input)); BufferedReader mask3Reader = new BufferedReader(new FileReader(mask3Input)); BufferedReader mask4Reader = new BufferedReader(new FileReader(mask4Input)); int lineNumber = 0; int index1 = 0; int index2 = 0; int index3 = 0; int index4 = 0; while(index1 < lineMapping.get(0).size() || index2 < lineMapping.get(1).size() || index3 < lineMapping.get(2).size() || index4 < lineMapping.get(3).size()){ if(index1 < lineMapping.get(0).size() && lineMapping.get(0).get(index1) == lineNumber){ writer.write(mask1Reader.readLine()+"\n"); index1++; }else if(index2 < lineMapping.get(1).size() && lineMapping.get(1).get(index2) == lineNumber){ writer.write(mask2Reader.readLine()+"\n"); index2++; }else if(index3 < lineMapping.get(2).size() && lineMapping.get(2).get(index3) == lineNumber){ writer.write(mask3Reader.readLine()+"\n"); index3++; }else if(index4 < lineMapping.get(3).size() && lineMapping.get(3).get(index4) == lineNumber){ writer.write(mask4Reader.readLine()+"\n"); index4++; } lineNumber++; } writer.close(); mask1Reader.close(); mask2Reader.close(); mask3Reader.close(); mask4Reader.close(); } }