package org.seqcode.motifs;
import java.util.*;
import java.io.*;
import java.sql.*;
import java.text.DecimalFormat;
import org.seqcode.data.connections.DatabaseException;
import org.seqcode.data.connections.DatabaseFactory;
import org.seqcode.data.connections.UnknownRoleException;
import org.seqcode.data.io.parsing.FASTAStream;
import org.seqcode.data.motifdb.*;
import org.seqcode.genome.Genome;
import org.seqcode.gseutils.*;
import org.seqcode.motifs.*;
import libsvm.*;
//import org.seqcode.math.probability.Binomial;
public class SVMCombinatorial extends CombinatorialEnrichment {
private double trainfrac = .3;
private double[] trainY, testY;
private svm_node[][] trainX, testX;
private double[] matrixMaxScores;
private List<String> trainKeys, testKeys;
private svm_parameter param; // set by parse_command_line
private svm_problem prob; // set by read_problem
private svm_model model;
private int traini, testi;
private PrintWriter saveCalls;
public SVMCombinatorial() {
super();
trainKeys = new ArrayList<String>();
testKeys = new ArrayList<String>();
}
public void parseArgs(String args[]) throws Exception {
super.parseArgs(args);
trainfrac = Args.parseDouble(args,"trainfrac",trainfrac);
String calls = Args.parseString(args,"savecalls",null);
if (calls == null) {
saveCalls = new PrintWriter("/dev/null");
} else {
saveCalls = new PrintWriter(calls);
}
}
private void fillsvm(Map<String, WMHit[]> hits, double val) {
for (String s : hits.keySet()) {
WMHit[] list = hits.get(s);
if (traini < trainY.length &&
(Math.random() < trainfrac || testi >= testY.length)) {
trainY[traini] = val;
for (int j = 0; j < list.length; j++) {
trainX[traini][j] = new svm_node();
trainX[traini][j].index = j;
trainX[traini][j].value = list[j] == null ? 0 : list[j].getScore()/matrixMaxScores[j];
}
trainKeys.add(s);
traini++;
} else {
testY[testi] = val;
for (int j = 0; j < list.length; j++) {
testX[testi][j] = new svm_node();
testX[testi][j].index = j;
testX[testi][j].value = list[j] == null ? 0 : list[j].getScore()/matrixMaxScores[j];
}
testKeys.add(s);
testi++;
}
}
}
public void setupSVM() {
int totalexamples = fghits.size() + bghits.size();
int trainsize = (int)(totalexamples * trainfrac);
int testsize = totalexamples - trainsize;
trainY = new double[trainsize];
testY = new double[testsize];
trainX = new svm_node[trainsize][matrices.size()];
testX = new svm_node[testsize][matrices.size()];
traini = 0;
testi = 0;
matrixMaxScores = new double[matrices.size()];
for (int i = 0; i < matrices.size(); i++) {
matrixMaxScores[i] = matrices.get(i).getMaxScore();
}
fillsvm(fghits,1.0);
System.err.println("Used " + traini + " of " + trainY.length + " from fg dataset for training");
fillsvm(bghits,-1.0);
param = new svm_parameter();
param.svm_type = svm_parameter.C_SVC;
param.kernel_type = svm_parameter.LINEAR;
param.degree = 1;
param.gamma = 0; // 1/num_features
param.coef0 = 0;
param.nu = 0.5;
param.cache_size = 100;
param.C = 1;
param.eps = 1e-3;
param.p = 0.1;
param.shrinking = 1;
param.probability = 0;
param.nr_weight = 0;
param.weight_label = new int[0];
param.weight = new double[0];
prob = new svm_problem();
prob.l = trainY.length;
prob.x = trainX;
prob.y = trainY;
}
public void trainSVM() {
model = svm.svm_train(prob,param);
}
public void testSVM() {
int pospos = 0, posneg = 0, negpos = 0, negneg = 0;
for (int i = 0; i < testY.length; i++) {
double pred = svm.svm_predict(model, testX[i]);
if (testY[i] > 0) {
if (pred > 0) {
saveCalls.println(testKeys.get(i) + " ++");
pospos++;
} else {
saveCalls.println(testKeys.get(i) + " +-");
posneg++;
}
} else {
if (pred > 0) {
saveCalls.println(testKeys.get(i) + " -+");
negpos++;
} else {
saveCalls.println(testKeys.get(i) + " --");
negneg++;
}
}
}
System.out.println(String.format("++ %d, +- %d, -+ %d, -- %d", pospos, posneg, negpos, negneg));
}
public void report() {
}
public static void main(String args[]) throws Exception {
SVMCombinatorial ce = new SVMCombinatorial();
ce.parseArgs(args);
System.err.println("Masking and saving");
ce.maskSequence();
ce.saveSequences();
System.err.println("Doing weight matrix scanning");
ce.doScans();
System.err.println("Translating to SVM Format");
ce.setupSVM();
System.err.println("Training");
ce.trainSVM();
System.err.println("Testing");
ce.testSVM();
ce.report();
}
}