package com.spbsu.exp.multiclass.spoc;
import com.spbsu.commons.func.Computable;
import com.spbsu.commons.io.StreamTools;
import com.spbsu.commons.math.Func;
import com.spbsu.commons.math.MathTools;
import com.spbsu.commons.math.Trans;
import com.spbsu.commons.math.vectors.Mx;
import com.spbsu.commons.math.vectors.Vec;
import com.spbsu.commons.math.vectors.VecTools;
import com.spbsu.commons.seq.IntSeq;
import com.spbsu.commons.util.ArrayTools;
import com.spbsu.ml.*;
import com.spbsu.ml.data.set.VecDataSet;
import com.spbsu.ml.data.tools.DataTools;
import com.spbsu.ml.data.tools.MCTools;
import com.spbsu.ml.data.tools.Pool;
import com.spbsu.ml.func.Ensemble;
import com.spbsu.ml.func.FuncEnsemble;
import com.spbsu.ml.loss.L2;
import com.spbsu.ml.loss.LLLogit;
import com.spbsu.ml.loss.SatL2;
import com.spbsu.ml.loss.blockwise.BlockwiseMLLLogit;
import com.spbsu.ml.loss.multiclass.util.ConfusionMatrix;
import com.spbsu.ml.meta.FeatureMeta;
import com.spbsu.ml.meta.impl.fake.FakeTargetMeta;
import com.spbsu.ml.meta.items.QURLItem;
import com.spbsu.ml.methods.GradientBoosting;
import com.spbsu.ml.methods.VecOptimization;
import com.spbsu.ml.methods.multiclass.spoc.SPOCMethodClassic;
import com.spbsu.ml.methods.multiclass.spoc.impl.CodingMatrixLearningGreedy;
import com.spbsu.ml.methods.trees.GreedyObliviousTree;
import com.spbsu.ml.models.multiclass.MCModel;
import gnu.trove.list.TDoubleList;
import gnu.trove.list.array.TDoubleArrayList;
import org.apache.commons.cli.MissingArgumentException;
import java.io.*;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Properties;
import java.util.concurrent.*;
/**
* User: qdeee
* Date: 18.08.14
*/
public class SearchMCGreedyAllParams {
private final static int UNITS = Runtime.getRuntime().availableProcessors();
/**
 * Runs a brute-force search over SPOC / gradient-boosting hyper-parameters and logs one line per
 * parameter combination.
 *
 * <p>The single command-line argument is the path to a properties file with the keys
 * {@code sim_mx_path}, {@code borders}, {@code learn_path}, {@code test_path},
 * {@code baseline_scores}, {@code iters_from}, {@code iters_to}, {@code iters_delta},
 * {@code step_from}, {@code step_to}, {@code step_delta} and {@code log_filename}.
 * {@code borders} and {@code baseline_scores} are semicolon-separated lists of doubles.
 *
 * <p>Every combination is wrapped into a {@link Task}; tasks are shuffled, executed on a fixed
 * pool of {@code UNITS} threads, and each returned line is appended to the log file as soon as
 * the task completes. The elapsed wall-clock time in seconds is printed to stdout at the end.
 *
 * @param args {@code args[0]} is the path to the properties file
 * @throws MissingArgumentException if the properties path or a required property is absent
 * @throws IOException if any of the input files cannot be read or the log file cannot be opened
 * @throws IllegalArgumentException if a delta is not positive while its range is non-empty
 */
public static void main(String[] args) throws MissingArgumentException, IOException {
  if (args.length < 1) {
    throw new MissingArgumentException("expected the path to a properties file as the first argument");
  }
  final Properties properties = new Properties();
  try (InputStream propertiesStream = new FileInputStream(new File(args[0]))) {
    properties.load(propertiesStream);
  }

  // Read and validate the cheap scalar settings first so that a broken config fails
  // before the (potentially slow) data loading starts.
  final TDoubleList borders = parseDoubles(requiredProperty(properties, "borders"));
  final TDoubleList baselineScores = parseDoubles(requiredProperty(properties, "baseline_scores"));
  final int itersFrom = Integer.parseInt(requiredProperty(properties, "iters_from"));
  final int itersTo = Integer.parseInt(requiredProperty(properties, "iters_to"));
  final int itersDelta = Integer.parseInt(requiredProperty(properties, "iters_delta"));
  final double stepFrom = Double.parseDouble(requiredProperty(properties, "step_from"));
  final double stepTo = Double.parseDouble(requiredProperty(properties, "step_to"));
  final double stepDelta = Double.parseDouble(requiredProperty(properties, "step_delta"));
  final String mxPath = requiredProperty(properties, "sim_mx_path");
  final String learnPath = requiredProperty(properties, "learn_path");
  final String testPath = requiredProperty(properties, "test_path");
  final String logFilename = requiredProperty(properties, "log_filename");

  // A non-positive delta over a non-empty range would make the task generation loop endless.
  if (itersFrom < itersTo && itersDelta <= 0) {
    throw new IllegalArgumentException("iters_delta must be positive, got " + itersDelta);
  }
  if (stepFrom < stepTo && stepDelta <= 0) {
    throw new IllegalArgumentException("step_delta must be positive, got " + stepDelta);
  }

  final CharSequence mxStr;
  try (BufferedInputStream mxStream = new BufferedInputStream(new FileInputStream(mxPath))) {
    mxStr = StreamTools.readStream(mxStream);
  }
  final Mx S = MathTools.CONVERSION.convert(mxStr, Mx.class);

  // Both pools get an additional integer target derived from the L2 target and the borders.
  final Pool<QURLItem> learn = DataTools.loadFromFeaturesTxt(learnPath);
  final IntSeq learnTarget = MCTools.transformRegressionToMC(learn.target(L2.class).target, borders.size(), borders);
  learn.addTarget(new FakeTargetMeta(learn.vecData(), FeatureMeta.ValueType.INTS), learnTarget);

  final Pool<QURLItem> test = DataTools.loadFromFeaturesTxt(testPath);
  final IntSeq testTarget = MCTools.transformRegressionToMC(test.target(L2.class).target, borders.size(), borders);
  test.addTarget(new FakeTargetMeta(test.vecData(), FeatureMeta.ValueType.INTS), testTarget);

  BaselineComparator.init(baselineScores.toArray());

  try (PrintWriter printWriter = new PrintWriter(new FileWriter(logFilename), true)) {
    final BFGrid grid = GridTools.medianGrid(learn.vecData(), 32);
    // Binarize once up front; the result is kept in the data set cache shared by all tasks.
    learn.vecData().cache().cache(Binarize.class, VecDataSet.class).binarize(grid);

    final List<Callable<String>> tasks = new LinkedList<>();
    final int k = borders.size();
    for (double lambdaC = 1.0; lambdaC < 1.5 * k; lambdaC += 1.0) {
      for (double lambdaR = 0.5; lambdaR < 3.0; lambdaR += 0.5) {
        for (double lambda1 = 1.0; lambda1 < 1.5 * k; lambda1 += 1.0) {
          for (int iters = itersFrom; iters < itersTo; iters += itersDelta) {
            for (double step = stepFrom; step < stepTo; step += stepDelta) {
              tasks.add(new Task(learn, test, grid, S, k, iters, step, lambdaC, lambdaR, lambda1));
            }
          }
        }
      }
    }
    Collections.shuffle(tasks);

    final ExecutorService pool = Executors.newFixedThreadPool(UNITS);
    final ExecutorCompletionService<String> completionService = new ExecutorCompletionService<>(pool);
    final long startMillis = System.currentTimeMillis();
    try {
      for (Callable<String> task : tasks) {
        completionService.submit(task);
      }
      pool.shutdown();

      // Take exactly one future per submitted task. Polling pool.isTerminated() instead is racy:
      // take() may block forever after the last result, or results may be left unread.
      for (int finished = 0; finished < tasks.size(); finished++) {
        try {
          printWriter.append(completionService.take().get());
          // append() does not trigger autoflush, so flush explicitly to keep the log current.
          printWriter.flush();
        } catch (ExecutionException e) {
          // A failed combination must not abort the whole search.
          e.printStackTrace();
        } catch (InterruptedException e) {
          Thread.currentThread().interrupt();
          break;
        }
      }
    } finally {
      // No-op after a normal run; cancels outstanding work after an interrupt or failure.
      pool.shutdownNow();
    }
    System.out.println((System.currentTimeMillis() - startMillis) / 1000.0);
  }
}

/**
 * Returns the value of a mandatory property.
 *
 * @throws MissingArgumentException if the property is not present
 */
private static String requiredProperty(final Properties properties, final String key) throws MissingArgumentException {
  final String value = properties.getProperty(key);
  if (value == null) {
    throw new MissingArgumentException("missing required property: " + key);
  }
  return value;
}

/** Parses a semicolon-separated list of doubles, preserving the order of the entries. */
private static TDoubleList parseDoubles(final String semicolonSeparated) {
  final TDoubleList values = new TDoubleArrayList();
  for (String token : semicolonSeparated.split(";")) {
    values.add(Double.parseDouble(token));
  }
  return values;
}
/**
 * A single hyper-parameter combination to evaluate.
 *
 * <p>{@link #call()} learns a coding matrix from the similarity matrix, fits a
 * {@code SPOCMethodClassic} model on the learn pool, scores it on the test pool and returns one
 * log line describing the outcome.
 */
private static final class Task implements Callable<String> {
  private final Pool<?> learn;
  private final Pool<?> test;
  private final BFGrid grid;
  private final Mx S;
  private final int classCount;
  private final int iters;
  private final double step;
  private final double lac;
  private final double lar;
  private final double la1;

  private Task(final Pool<?> learn, final Pool<?> test, final BFGrid grid, final Mx s, final int classCount,
               final int iters, final double step, final double lac, final double lar, final double la1) {
    this.learn = learn;
    this.test = test;
    this.grid = grid;
    this.S = s;
    this.classCount = classCount;
    this.iters = iters;
    this.step = step;
    this.lac = lac;
    this.lar = lar;
    this.la1 = la1;
  }

  /**
   * Trains and evaluates the model for this combination.
   *
   * @return a line with the parameters followed by micro precision, macro precision, macro recall
   *         and macro F1 when the comparator reports the scores as better than the baseline,
   *         otherwise a line with the parameters followed by {@code fail}
   */
  @Override
  public String call() throws Exception {
    final VecOptimization<LLLogit> binaryLearner = createBinaryLearner();

    final CodingMatrixLearningGreedy matrixLearner = new CodingMatrixLearningGreedy(classCount, 5, lac, lar, la1);
    final Mx codingMatrix = matrixLearner.findMatrixB(S);

    final SPOCMethodClassic spoc = new SPOCMethodClassic(codingMatrix, binaryLearner);
    final MCModel fitted = (MCModel) spoc.fit(learn.vecData(), learn.target(BlockwiseMLLLogit.class));

    final Vec predictedClasses = fitted.bestClassAll(test.vecData().data());
    final ConfusionMatrix confusion = new ConfusionMatrix(
        test.target(BlockwiseMLLLogit.class).labels(),
        VecTools.toIntSeq(predictedClasses));

    final double[] scores = {
        confusion.getMicroPrecision(),
        confusion.getMacroPrecision(),
        confusion.getMacroRecall(),
        confusion.getMacroF1Measure()
    };

    if (!BaselineComparator.getInstance().isBetterThanBaseline(scores)) {
      return String.format("i=%d, s=%f, lac=%f, lar=%f, la1=%f : fail\n",
          iters, step, lac, lar, la1);
    }
    return String.format("i=%d, s=%f, lac=%f, lar=%f, la1=%f : %.6f | %.6f | %.6f | %.6f\n",
        iters, step, lac, lar, la1, scores[0], scores[1], scores[2], scores[3]);
  }

  /**
   * Builds the weak learner handed to SPOC: gradient boosting of oblivious trees run for
   * {@code iters} iterations with step {@code step}, with the resulting ensemble repackaged
   * as a {@link FuncEnsemble}.
   */
  private VecOptimization<LLLogit> createBinaryLearner() {
    return new VecOptimization<LLLogit>() {
      @Override
      public Func fit(final VecDataSet dataSet, final LLLogit loss) {
        // NOTE(review): the literal 5 is presumably the tree depth; confirm against GreedyObliviousTree.
        final GradientBoosting<LLLogit> boosting = new GradientBoosting<>(new GreedyObliviousTree<L2>(grid, 5), SatL2.class, iters, step);
        final Ensemble boosted = boosting.fit(dataSet, loss);
        // Every model of the ensemble is cast to Func so it fits FuncEnsemble.
        final Func[] members = ArrayTools.map(boosted.models, Func.class, new Computable<Trans, Func>() {
          @Override
          public Func compute(final Trans model) {
            return (Func) model;
          }
        });
        return new FuncEnsemble(members, boosted.weights);
      }
    };
  }
}
/**
 * Holder of the baseline quality scores that candidate models are compared against.
 *
 * <p>A single shared instance is installed with {@link #init(double[])} and fetched with
 * {@link #getInstance()}. Only the first four entries of a score array are used.
 */
private static class BaselineComparator {
  /** Number of leading entries compared by {@link #isBetterThanBaseline(double[])}. */
  private static final int COMPARED_SCORES = 4;

  private static BaselineComparator instance = null;

  private final double[] baselineScores;

  private BaselineComparator(double[] scores) {
    this.baselineScores = scores;
  }

  /**
   * Installs the shared comparator, replacing any previously installed one.
   *
   * @param baselineScores baseline values; must contain at least four entries
   * @throws IllegalArgumentException if {@code baselineScores} is null or has fewer than four entries
   */
  public static void init(final double[] baselineScores) {
    if (baselineScores == null || baselineScores.length < COMPARED_SCORES) {
      throw new IllegalArgumentException(
          "at least " + COMPARED_SCORES + " baseline scores are required, got "
              + (baselineScores == null ? "null" : String.valueOf(baselineScores.length)));
    }
    // Copy so that later changes to the caller's array cannot alter the baseline.
    instance = new BaselineComparator(baselineScores.clone());
  }

  /**
   * Returns the shared comparator.
   *
   * @throws IllegalStateException if {@link #init(double[])} has not been called yet
   */
  public static BaselineComparator getInstance() {
    if (instance == null) {
      throw new IllegalStateException("BaselineComparator.init(double[]) must be called before getInstance()");
    }
    return instance;
  }

  /**
   * Tells whether the given scores beat the baseline in at least one of the first four positions.
   *
   * @param scores candidate scores, at least four entries, in the same order as the baseline
   * @return {@code true} if any of {@code scores[0..3]} is strictly greater than the baseline
   *         value at the same index
   */
  public boolean isBetterThanBaseline(final double[] scores) {
    return scores[0] > baselineScores[0] ||
        scores[1] > baselineScores[1] ||
        scores[2] > baselineScores[2] ||
        scores[3] > baselineScores[3];
  }
}
}