package com.spbsu.exp.multiclass.weak;
import com.spbsu.commons.math.Func;
import com.spbsu.commons.math.vectors.Vec;
import com.spbsu.commons.math.vectors.VecTools;
import com.spbsu.commons.seq.IntSeq;
import com.spbsu.commons.math.Trans;
import com.spbsu.ml.data.set.VecDataSet;
import com.spbsu.ml.data.tools.MCTools;
import com.spbsu.ml.loss.LLLogit;
import com.spbsu.ml.loss.blockwise.BlockwiseMLLLogit;
import com.spbsu.ml.methods.VecOptimization;
import com.spbsu.ml.models.MultiClassModel;
import gnu.trove.map.hash.TIntIntHashMap;
/**
* User: qdeee
* Date: 16.08.14
*/
/**
 * Binary learner that delegates training to {@link CustomWeakMultiClass}.
 *
 * <p>The labels of the {@link LLLogit} target are converted to an integer sequence, re-indexed
 * with {@link MCTools#normalizeTarget}, and fitted as a multiclass problem under
 * {@link BlockwiseMLLLogit}. The returned function reports component 0 of the trained model's
 * internal transform.
 */
public class CustomWeakBinClass extends VecOptimization.Stub<LLLogit> {
  /** Forwarded unchanged to the {@link CustomWeakMultiClass} constructor (presumably an iteration count). */
  private final int iters;
  /** Forwarded unchanged to the {@link CustomWeakMultiClass} constructor (presumably a step size). */
  private final double step;

  /**
   * Creates the adapter.
   *
   * @param iters value passed through to {@link CustomWeakMultiClass}
   * @param step value passed through to {@link CustomWeakMultiClass}
   */
  public CustomWeakBinClass(final int iters, final double step) {
    this.iters = iters;
    this.step = step;
  }

  /**
   * Trains a {@link CustomWeakMultiClass} on the re-indexed labels and exposes its first output.
   *
   * @param learn training feature vectors
   * @param targetFunc binary target whose {@code labels()} supply the class labels
   * @return a scalar function whose value is component 0 of the trained model's internal
   *     transform; which original label that component scores depends on how
   *     {@code MCTools.normalizeTarget} orders classes (NOTE(review): confirm)
   * @throws ClassCastException if the delegate does not return a {@link MultiClassModel}
   */
  @Override
  public Trans fit(final VecDataSet learn, final LLLogit targetFunc) {
    final IntSeq rawLabels = VecTools.toIntSeq(targetFunc.labels());
    // The map handed to normalizeTarget is not kept afterwards; only the re-indexed labels are used.
    final IntSeq classIndices = MCTools.normalizeTarget(rawLabels, new TIntIntHashMap());

    final CustomWeakMultiClass delegate = new CustomWeakMultiClass(iters, step);
    final BlockwiseMLLLogit multiClassLoss = new BlockwiseMLLLogit(classIndices, learn);
    final MultiClassModel trained = (MultiClassModel) delegate.fit(learn, multiClassLoss);

    return new FirstComponentFunc(trained);
  }

  /**
   * Scalar function that evaluates a multiclass model's internal transform and returns its first
   * component. Static so it does not retain a reference to the enclosing learner.
   */
  private static final class FirstComponentFunc extends Func.Stub {
    private final MultiClassModel model;

    FirstComponentFunc(final MultiClassModel model) {
      this.model = model;
    }

    @Override
    public double value(final Vec x) {
      return model.getInternModel().trans(x).get(0);
    }

    @Override
    public int dim() {
      return model.getInternModel().xdim();
    }
  }
}