package com.spbsu.exp.multiclass.spoc;
import com.spbsu.commons.func.Action;
import com.spbsu.commons.func.Computable;
import com.spbsu.commons.io.StreamTools;
import com.spbsu.commons.math.MathTools;
import com.spbsu.commons.math.vectors.Mx;
import com.spbsu.commons.math.vectors.Vec;
import com.spbsu.commons.math.vectors.VecTools;
import com.spbsu.commons.math.vectors.impl.mx.VecBasedMx;
import com.spbsu.commons.util.ArrayTools;
import com.spbsu.commons.util.Pair;
import com.spbsu.ml.BFGrid;
import com.spbsu.commons.math.Func;
import com.spbsu.ml.GridTools;
import com.spbsu.commons.math.Trans;
import com.spbsu.ml.cli.builders.data.impl.DataBuilderClassic;
import com.spbsu.ml.data.set.VecDataSet;
import com.spbsu.ml.data.tools.MCTools;
import com.spbsu.ml.data.tools.Pool;
import com.spbsu.ml.func.Ensemble;
import com.spbsu.ml.func.FuncEnsemble;
import com.spbsu.ml.loss.L2;
import com.spbsu.ml.loss.LLLogit;
import com.spbsu.ml.loss.blockwise.BlockwiseMLLLogit;
import com.spbsu.ml.methods.GradientBoosting;
import com.spbsu.ml.methods.VecOptimization;
import com.spbsu.ml.methods.multiclass.spoc.ECOCCombo;
import com.spbsu.ml.methods.trees.GreedyObliviousTree;
import com.spbsu.ml.models.multiclass.MCModel;
import com.spbsu.ml.models.multiclass.MulticlassCodingMatrixModel;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.Properties;
/**
 * Command-line runner for the {@link ECOCCombo} multiclass experiment.
 *
 * <p>Takes a single argument: the path to a {@code .properties} file. Recognised keys:
 * {@code is_json}, {@code learn_path}, {@code test_path}, {@code sim_mx_path}, {@code L},
 * {@code lac}, {@code lar}, {@code la1}, {@code iters}, {@code step}, {@code update_prior},
 * {@code target_based_update}, {@code first_column_for_update}, {@code laprior}.
 * Evaluation results are printed to standard output.
 *
 * User: qdeee
 * Date: 24.11.14
 */
public class RunnerECOC {
  /**
   * Loads the configuration, builds the learn/test pools, fits an {@link ECOCCombo} model and
   * prints evaluation results for both pools.
   *
   * @param args {@code args[0]} is the path to the properties file
   * @throws IOException if the properties file or the similarity matrix file cannot be read
   * @throws IllegalArgumentException if no properties file path was supplied
   */
  public static void main(String[] args) throws IOException {
    if (args.length < 1) {
      throw new IllegalArgumentException("Usage: RunnerECOC <path to .properties file>");
    }
    final Properties properties = loadProperties(args[0]);

    // Boolean flags default to "false" when absent instead of failing with an NPE.
    final boolean isJsonFormat = Boolean.parseBoolean(properties.getProperty("is_json", "false"));
    final String learnPath = properties.getProperty("learn_path");
    final String testPath = properties.getProperty("test_path");
    final String mxPath = properties.getProperty("sim_mx_path");
    final int l = Integer.parseInt(properties.getProperty("L", String.valueOf(5)));
    final double lambdaC = Double.parseDouble(properties.getProperty("lac", String.valueOf(5.0)));
    final double lambdaR = Double.parseDouble(properties.getProperty("lar", String.valueOf(2.5)));
    final double lambda1 = Double.parseDouble(properties.getProperty("la1", String.valueOf(3.0)));
    final int iters = Integer.parseInt(properties.getProperty("iters", String.valueOf(100)));
    final double step = Double.parseDouble(properties.getProperty("step", String.valueOf(0.3)));
    // Previously this compared against "false", which enabled the prior update exactly when
    // the configuration asked for it to be disabled.
    final boolean updatePrior = Boolean.parseBoolean(properties.getProperty("update_prior", "false"));
    final boolean targetBasedUpdate = Boolean.parseBoolean(properties.getProperty("target_based_update", "false"));
    final int firstColumnForUpdate = Integer.parseInt(properties.getProperty("first_column_for_update", "5"));
    final double lambdaPrior = Double.parseDouble(properties.getProperty("laprior", String.valueOf(0.8)));

    // Echo the explicitly configured values (defaults are not part of the stored output).
    properties.store(System.out, "[PROPERTIES VALUES] ");

    final Mx S = readSimilarityMatrix(mxPath);

    final DataBuilderClassic dataBuilder = new DataBuilderClassic();
    dataBuilder.setJsonFormat(isJsonFormat);
    dataBuilder.setLearnPath(learnPath);
    dataBuilder.setTestPath(testPath);
    final Pair<Pool, Pool> poolsPair = dataBuilder.create();
    final Pool<?> learn = poolsPair.getFirst();
    final Pool<?> test = poolsPair.getSecond();

    final VecDataSet vecDataSet = learn.vecData();
    final BFGrid grid = GridTools.medianGrid(vecDataSet, 32);
    final BlockwiseMLLLogit mllLogit = learn.target(BlockwiseMLLLogit.class);
    final int k = MCTools.countClasses(mllLogit.labels());

    final ECOCCombo ecocComboMethod = new ECOCCombo(k, l, lambdaC, lambdaR, lambda1, S, createWeak(grid, iters, step));

    // NOTE(review): presumably invoked by ECOCCombo with each intermediate model - confirm.
    final Action<MulticlassCodingMatrixModel> listener = new Action<MulticlassCodingMatrixModel>() {
      @Override
      public void invoke(final MulticlassCodingMatrixModel model) {
        if (updatePrior && model.getCodingMatrix().columns() >= firstColumnForUpdate) {
          // In-place blend: S <- lambdaPrior * S + (1 - lambdaPrior) * pairwiseInteractions.
          final Mx mx = getPairwiseInteractions(model, learn, targetBasedUpdate);
          VecTools.scale(S, lambdaPrior);
          VecTools.scale(mx, 1 - lambdaPrior);
          VecTools.append(S, mx);
        }
        System.out.println("L == " + model.getInternalModel().ydim());
        System.out.println(MCTools.evalModel(model, learn, "[LEARN] ", true));
        System.out.println(MCTools.evalModel(model, test, "[TEST] ", true));
        System.out.println();
      }
    };
    ecocComboMethod.addListener(listener);

    final MulticlassCodingMatrixModel model = (MulticlassCodingMatrixModel) ecocComboMethod.fit(vecDataSet, mllLogit);
    System.out.println("\n\n\n");
    System.out.println(MCTools.evalModel(model, learn, "[LEARN] ", false));
    System.out.println(MCTools.evalModel(model, test, "[TEST] ", false));
  }

  /**
   * Reads a properties file, closing the underlying stream afterwards.
   *
   * @param path path to the {@code .properties} file
   * @return the loaded properties
   * @throws IOException if the file cannot be opened or read
   */
  private static Properties loadProperties(final String path) throws IOException {
    final Properties properties = new Properties();
    try (FileInputStream input = new FileInputStream(path)) {
      properties.load(input);
    }
    return properties;
  }

  /**
   * Reads the textual matrix stored at {@code path} and converts it to an {@link Mx},
   * closing the underlying stream afterwards.
   *
   * @param path path to the matrix file
   * @return the parsed matrix
   * @throws IOException if the file cannot be opened or read
   */
  private static Mx readSimilarityMatrix(final String path) throws IOException {
    try (FileInputStream input = new FileInputStream(path)) {
      final CharSequence mxStr = StreamTools.readStream(input);
      return MathTools.CONVERSION.convert(mxStr, Mx.class);
    }
  }

  /**
   * Creates the weak learner passed to {@link ECOCCombo}: gradient boosting over
   * {@link GreedyObliviousTree} (constructed with {@code grid} and {@code 5}), with the
   * resulting ensemble repacked as a {@link FuncEnsemble}.
   *
   * @param grid  grid passed to the tree learner
   * @param iters iteration count passed to {@link GradientBoosting}
   * @param step  step passed to {@link GradientBoosting}
   * @return an optimization that fits a boosted ensemble for an {@link LLLogit} target
   */
  private static VecOptimization<LLLogit> createWeak(final BFGrid grid, final int iters, final double step) {
    return new VecOptimization<LLLogit>() {
      @Override
      public Trans fit(final VecDataSet learn, final LLLogit llLogit) {
        final GradientBoosting<LLLogit> boosting = new GradientBoosting<>(new GreedyObliviousTree<L2>(grid, 5), iters, step);
        final Ensemble ensemble = boosting.fit(learn, llLogit);
        // Each ensemble member is cast to Func so the result can be a FuncEnsemble.
        return new FuncEnsemble(
            ArrayTools.map(ensemble.models, Func.class, new Computable<Trans, Func>() {
              @Override
              public Func compute(final Trans argument) {
                return (Func) argument;
              }
            }),
            ensemble.weights
        );
      }
    };
  }

  /**
   * Builds a symmetric classes-by-classes matrix from the model's predicted probabilities.
   *
   * <p>Each sample is assigned to a class: its target label when {@code targetBasedUpdate} is
   * set, otherwise the argmax of its predicted probabilities. Row {@code c} is the mean
   * probability vector over the samples assigned to class {@code c}. Off-diagonal entries are
   * then replaced by the average of {@code (c1, c2)} and {@code (c2, c1)}; the diagonal is left
   * as computed. Rows of classes with no assigned samples stay zero.
   *
   * @param model             model whose {@code probs} is evaluated on every sample
   * @param pool              pool providing the features and the {@link BlockwiseMLLLogit} target
   * @param targetBasedUpdate whether to group samples by true label instead of predicted class
   * @return the symmetrized matrix of averaged probabilities
   */
  private static Mx getPairwiseInteractions(final MCModel model, final Pool<?> pool, final boolean targetBasedUpdate) {
    final BlockwiseMLLLogit mllLogit = pool.target(BlockwiseMLLLogit.class);
    final VecDataSet ds = pool.vecData();
    final int classesCount = mllLogit.classesCount();
    final Mx result = new VecBasedMx(classesCount, classesCount);
    final Mx features = ds.data();
    // Indexed by class id, so it must be sized by the number of classes, not samples.
    final int[] counts = new int[classesCount];

    for (int i = 0; i < ds.length(); i++) {
      final Vec probs = model.probs(features.row(i));
      final int bestClass = targetBasedUpdate ? mllLogit.label(i) : VecTools.argmax(probs);
      VecTools.append(result.row(bestClass), probs);
      counts[bestClass]++;
    }

    for (int c = 0; c < result.rows(); c++) {
      // Skip empty classes: scaling a zero row by 1/0 would fill it with NaN.
      if (counts[c] > 0) {
        VecTools.scale(result.row(c), 1.0 / counts[c]);
      }
    }

    for (int c1 = 0; c1 < result.rows(); c1++) {
      for (int c2 = c1 + 1; c2 < result.columns(); c2++) {
        final double val = 0.5 * (result.get(c1, c2) + result.get(c2, c1));
        result.set(c1, c2, val);
        result.set(c2, c1, val);
      }
    }
    return result;
  }
}