package com.spbsu.exp.multiclass.spoc;
import com.spbsu.commons.func.Computable;
import com.spbsu.commons.func.types.TypeConverter;
import com.spbsu.commons.math.io.Mx2CharSequenceConversionPack;
import com.spbsu.commons.math.vectors.Mx;
import com.spbsu.commons.util.ArrayTools;
import com.spbsu.commons.util.logging.Logger;
import com.spbsu.ml.methods.multiclass.spoc.AbstractCodingMatrixLearning;
import com.spbsu.ml.methods.multiclass.spoc.CMLHelper;
import com.spbsu.ml.methods.multiclass.spoc.impl.CodingMatrixLearning;
import org.apache.commons.cli.MissingArgumentException;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.Arrays;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
/**
 * Command-line grid search over the parameters of {@link CodingMatrixLearning}.
 *
 * <p>For every combination of (step, lambdaC, lambdaR, lambda1) it trains a coding matrix
 * from the matrix S and logs the parameter combinations whose trained matrix passes
 * {@code CMLHelper.checkConstraints}.
 *
 * <p>Usage: {@code SearchAvaliableMxMath <path to mx S> <l> [step ...]}
 *
 * User: qdeee
 * Date: 22.05.14
 */
public class SearchAvaliableMxMath {
  /**
   * Parses the command line, loads S from disk and runs {@link #findParameters}.
   *
   * @param args {@code args[0]} is the path to a text file holding S, in the format read by
   *             {@code Mx2CharSequenceConversionPack.CharSequence2MxConverter};
   *             {@code args[1]} is the integer {@code l} handed to the learner;
   *             any remaining arguments are step values parsed with {@link Double#valueOf(String)}
   * @throws MissingArgumentException if the path or {@code l} is missing
   */
  public static void main(String[] args) throws Exception {
    if (args.length < 1) {
      throw new MissingArgumentException("Enter the path to mx S");
    }
    if (args.length < 2) {
      // args[1] is parsed unconditionally below; fail with a clear message instead of an
      // ArrayIndexOutOfBoundsException.
      throw new MissingArgumentException("Enter the parameter l");
    }
    final String path = args[0];
    final int l = Integer.parseInt(args[1]);
    final String[] stepsStr = Arrays.copyOfRange(args, 2, args.length);
    final Double[] steps = ArrayTools.map(stepsStr, Double.class, new Computable<String, Double>() {
      @Override
      public Double compute(final String argument) {
        return Double.valueOf(argument);
      }
    });
    final Mx S = loadMxFromFile(path);
    findParameters(S, l, steps);
  }

  /**
   * Runs the parameter grid search on a thread pool.
   *
   * <p>Blocks until every task has finished, or until 24 hours have elapsed.
   *
   * <p>Let k = {@code S.rows()}. One task is submitted per (step, lambdaC) pair, with
   * lambdaC = 1, 2, ... while {@code lambdaC < 1.5 * k}. Each task sweeps
   * lambdaR = 0.5, 1.0, ..., 2.5 and lambda1 = 1, 2, ... while {@code lambda1 < 1.5 * k}.
   *
   * @param S     matrix passed to {@code trainCodingMatrix}; its row count is used as k
   * @param l     passed as the second constructor argument of {@link CodingMatrixLearning}
   * @param steps step values; they currently only label the log output (see NOTE below)
   */
  public static void findParameters(final Mx S, final int l, Double[] steps) throws Exception {
    final Logger logger = Logger.create(SearchAvaliableMxMath.class);
    final int k = S.rows();
    // Leave two cores free, but never go below one worker: ThreadPoolExecutor rejects a
    // non-positive pool size, which would otherwise happen on 1- and 2-core machines.
    final int units = Math.max(1, Runtime.getRuntime().availableProcessors() - 2);
    logger.info("Units: " + units);
    final ThreadPoolExecutor threadPoolExecutor =
        new ThreadPoolExecutor(units, units, 30, TimeUnit.MINUTES, new LinkedBlockingDeque<Runnable>());
    for (double step : steps) {
      final double stepCopy = step;
      for (double lambdaC = 1.0; lambdaC < 1.5 * k; lambdaC += 1.0) {
        final double lambdaCCopy = lambdaC;
        threadPoolExecutor.execute(new Runnable() {
          @Override
          public void run() {
            for (double lambdaR = 0.5; lambdaR < 3.0; lambdaR += 0.5) {
              for (double lambda1 = 1.0; lambda1 < 1.5 * k; lambda1 += 1.0) {
                // NOTE(review): stepCopy is never passed to the learner, so each step value
                // repeats identical training runs. Confirm whether the constant 0.5 below was
                // meant to be stepCopy.
                final AbstractCodingMatrixLearning cml =
                    new CodingMatrixLearning(k, l, 0.5, lambdaCCopy, lambdaR, lambda1);
                final Mx matrixB = cml.trainCodingMatrix(S);
                if (CMLHelper.checkConstraints(matrixB)) {
                  // Keep the multi-line matrix dump from interleaving with other workers' output.
                  synchronized (logger) {
                    logger.info(stepCopy + " " + lambdaCCopy + " " + lambdaR + " " + lambda1 + "\n" + matrixB.toString() + "\n");
                  }
                }
              }
            }
            // Each task covers a single (step, lambdaC) pair, so report both.
            synchronized (logger) {
              logger.info("step " + stepCopy + ", lambdaC " + lambdaCCopy + " is finished");
            }
          }
        });
      }
    }
    // Without shutdown() the pool never terminates. awaitTermination would then always block for
    // the full timeout, and the non-daemon worker threads would keep the JVM alive afterwards.
    threadPoolExecutor.shutdown();
    if (!threadPoolExecutor.awaitTermination(24, TimeUnit.HOURS)) {
      logger.info("Timed out waiting for the parameter search to finish");
    }
  }

  /**
   * Reads the whole file and converts its text to a matrix.
   *
   * <p>The conversion uses {@code Mx2CharSequenceConversionPack.CharSequence2MxConverter}.
   * Lines are re-joined with {@code '\n'}, so every line, including the last, is newline
   * terminated. The file is decoded with the platform default charset (via {@link FileReader}).
   *
   * @param filename path of the file to read
   * @return the converted matrix
   * @throws IOException if the file cannot be opened or read
   */
  private static Mx loadMxFromFile(final String filename) throws IOException {
    final TypeConverter<CharSequence, Mx> converter = new Mx2CharSequenceConversionPack.CharSequence2MxConverter();
    final BufferedReader reader = new BufferedReader(new FileReader(new File(filename)));
    final StringBuilder builder = new StringBuilder();
    try {
      String s;
      while ((s = reader.readLine()) != null) {
        builder.append(s);
        builder.append("\n");
      }
    } finally {
      // Release the file handle even if reading fails.
      reader.close();
    }
    return converter.convert(builder.toString());
  }
}