package ml.humaning.algorithm;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import ml.humaning.util.Command;
import ml.humaning.util.Point;
import ml.humaning.util.Reader;
import ml.humaning.util.Svm_train;
public class SVM {
private Point [] allData;
private ArrayList <BufferedWriter> maskPointsWriter;
private ArrayList <ArrayList <Integer> > lineMapping;
private String trainFile;
private Executor pool;
public SVM(String trainFile) throws IOException{
allData = Reader.readPoints(trainFile);
this.trainFile = trainFile;
}
public void predict(int svmType, int kernelType, String testFile, String outputFile) throws IOException, InterruptedException{
BufferedReader testReader = new BufferedReader(new FileReader(testFile));
String testPointsWithMask1 = "mask1.in";
String testPointsWithMask2 = "mask2.in";
String testPointsWithMask3 = "mask3.in";
String testPointsWithMask4 = "mask4.in";
maskPointsWriter = new ArrayList <BufferedWriter>();
maskPointsWriter.add(new BufferedWriter(new FileWriter(testPointsWithMask1)));
maskPointsWriter.add(new BufferedWriter(new FileWriter(testPointsWithMask2)));
maskPointsWriter.add(new BufferedWriter(new FileWriter(testPointsWithMask3)));
maskPointsWriter.add(new BufferedWriter(new FileWriter(testPointsWithMask4)));
lineMapping = new ArrayList <ArrayList <Integer> >();
lineMapping.add(new ArrayList <Integer>());
lineMapping.add(new ArrayList <Integer>());
lineMapping.add(new ArrayList <Integer>());
lineMapping.add(new ArrayList <Integer>());
int lineNumber = 0;
String line = null;
while((line = testReader.readLine()) != null){
Point p = new Point(line);
int maskRegion = p.getEmptyRegion();
p.setMaskRegion(maskRegion);
maskPointsWriter.get(maskRegion-1).write(line+"\n");
lineMapping.get(maskRegion-1).add(lineNumber);
lineNumber++;
}
testReader.close();
for(BufferedWriter bw : maskPointsWriter){
bw.flush();
bw.close();
}
String mask1Output = "mask1.out";
String mask2Output = "mask2.out";
String mask3Output = "mask3.out";
String mask4Output = "mask4.out";
Command command = new Command();
command.call("svm-predict " + testPointsWithMask1+ " mask1"+"_svm"+svmType+"_kernel"+kernelType+".model "+ mask1Output);
command.call("svm-predict " + testPointsWithMask2+ " mask2"+"_svm"+svmType+"_kernel"+kernelType+".model "+ mask2Output);
command.call("svm-predict " + testPointsWithMask3+ " mask3"+"_svm"+svmType+"_kernel"+kernelType+".model "+ mask3Output);
command.call("svm-predict " + testPointsWithMask4+ " mask4"+"_svm"+svmType+"_kernel"+kernelType+".model "+ mask4Output);
merge(mask1Output, mask2Output, mask3Output, mask4Output, outputFile);
}
private void merge(String mask1Input, String mask2Input, String mask3Input, String mask4Input,String outputFile) throws IOException{
BufferedWriter writer = new BufferedWriter(new FileWriter(outputFile));
BufferedReader mask1Reader = new BufferedReader(new FileReader(mask1Input));
BufferedReader mask2Reader = new BufferedReader(new FileReader(mask2Input));
BufferedReader mask3Reader = new BufferedReader(new FileReader(mask3Input));
BufferedReader mask4Reader = new BufferedReader(new FileReader(mask4Input));
int lineNumber = 0;
int index1 = 0;
int index2 = 0;
int index3 = 0;
int index4 = 0;
while(index1 < lineMapping.get(0).size() || index2 < lineMapping.get(1).size()
|| index3 < lineMapping.get(2).size() || index4 < lineMapping.get(3).size()){
if(index1 < lineMapping.get(0).size() && lineMapping.get(0).get(index1) == lineNumber){
writer.write(mask1Reader.readLine()+"\n");
index1++;
}else if(index2 < lineMapping.get(1).size() && lineMapping.get(1).get(index2) == lineNumber){
writer.write(mask2Reader.readLine()+"\n");
index2++;
}else if(index3 < lineMapping.get(2).size() && lineMapping.get(2).get(index3) == lineNumber){
writer.write(mask3Reader.readLine()+"\n");
index3++;
}else if(index4 < lineMapping.get(3).size() && lineMapping.get(3).get(index4) == lineNumber){
writer.write(mask4Reader.readLine()+"\n");
index4++;
}
lineNumber++;
}
writer.flush();
writer.close();
mask1Reader.close();
mask2Reader.close();
mask3Reader.close();
mask4Reader.close();
}
private String processTrainCommand(int svmType, int kernelType, int degree, double gamma, double coef, double cost , double nu, double epsilon){
String commandString = "";
if(svmType == 0){// C-SVC
commandString +="-c ";// cost
commandString +=String.valueOf(cost)+" ";
}else if(svmType == 1){// nu-SVC
commandString +="-n ";// nu
commandString +=String.valueOf(nu)+" ";
}else if(svmType == 2){// one-class SVM
commandString +="-n ";// nu
commandString +=String.valueOf(nu)+" ";
}else if(svmType == 3){// epsilon-SVR
commandString +="-c ";// cost
commandString +=String.valueOf(cost)+" ";
commandString +="-p ";// epsilon
commandString +=String.valueOf(epsilon)+" ";
}else if(svmType == 4){// nu-SVR
commandString +="-c ";// cost
commandString +=String.valueOf(cost)+" ";
commandString +="-n ";// nu
commandString +=String.valueOf(nu)+" ";
}
if(kernelType == 0){// linear
}else if(kernelType == 1){// polynomial
commandString +="-d ";// degree
commandString +=String.valueOf(degree)+" ";
commandString +="-g ";// gamma
commandString +=String.valueOf(gamma)+" ";
commandString +="-r ";// coef0
commandString +=String.valueOf(coef)+" ";
}else if(kernelType == 2){// radial basis
commandString +="-g ";// gamma
commandString +=String.valueOf(gamma)+" ";
}else if(kernelType == 3){// sigmoid
commandString +="-g ";// gamma
commandString +=String.valueOf(gamma)+" ";
commandString +="-r ";// coef0
commandString +=String.valueOf(coef)+" ";
}
commandString += ("-s " + svmType + " -t "+ kernelType + " ");
return commandString;
}
private class Trainer implements Runnable{
private SVM svm;
private String command;
private String accuracyRecord;
public Trainer(SVM svm, String command, String ar){
this.svm = svm;
this.command = command;
this.accuracyRecord = ar;
}
@Override
public void run() {
Svm_train trainer = new Svm_train();
try {
double accuracy = trainer.run(command.split("\\s+"));
svm.updateAccuracyRecord(accuracyRecord, accuracy, command);
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
}
public synchronized void updateAccuracyRecord(String accuracyRecordPath, double accuracy, String command) throws IOException{
File accuracyRecord = new File(accuracyRecordPath);
if(accuracyRecord.exists()) {
BufferedReader accuracyReader = new BufferedReader(new FileReader(accuracyRecord));
double recordAccuracy = Double.parseDouble(accuracyReader.readLine());
if(accuracy > recordAccuracy){
accuracyReader.close();
accuracyRecord.delete();
BufferedWriter accuracyWriter = new BufferedWriter(new FileWriter(accuracyRecord));
accuracyWriter.write(String.valueOf(accuracy)+"\n");
accuracyWriter.write(command+"\n");
accuracyWriter.close();
}
}else {
BufferedWriter accuracyWriter = new BufferedWriter(new FileWriter(accuracyRecord));
accuracyWriter.write(String.valueOf(accuracy)+"\n");
accuracyWriter.write(command+"\n");
accuracyWriter.close();
}
}
public void parallelCrossValidationSVM(int svmType, int kernelType) throws IOException, InterruptedException{
String configurationFile = "svm"+svmType+"_kernel"+kernelType;
int threadNumber = 10;
pool = Executors.newFixedThreadPool(threadNumber);
for(double c = 4.5 ;c <= 5.5; c += 0.1){
for(double gamma = 0.000165;gamma <= 0.000175;gamma += 0.000001){
String commandString = processTrainCommand(svmType, kernelType, 3, gamma, -1, c, -1, -1);
pool.execute(new Trainer(this, commandString+" -v 5 "+trainFile, configurationFile+".record"));
}
}
}
public void train(int svmType, int kernelType, int degree, double gamma, double coef, double cost , double nu, double epsilon) throws IOException, InterruptedException{
String commandString = processTrainCommand(svmType, kernelType, degree, gamma, coef, cost, nu, epsilon);
Command command = new Command();
for(int maskRegion = 1;maskRegion <= 4;maskRegion++){// 4 mask regions
String maskTrainFile = "mask"+maskRegion+"_svm"+svmType+"_kernel"+kernelType;
BufferedWriter trainFileWriter = new BufferedWriter(new FileWriter(maskTrainFile));
for(Point p : allData){
p.setMaskRegion(maskRegion);
trainFileWriter.write(p.toLIBSVMString()+"\n");
}
command.call("svm-train "+commandString+maskTrainFile);
trainFileWriter.close();
}
}
}