package com.spbsu.ml.methods.greedyRegion.cherry; import com.spbsu.commons.func.AdditiveStatistics; import com.spbsu.ml.BFGrid; import com.spbsu.ml.data.cherry.CherryLoss; import com.spbsu.ml.data.cherry.CherryPointsHolder; import com.spbsu.ml.loss.StatBasedLoss; import gnu.trove.set.hash.TIntHashSet; import static com.spbsu.ml.methods.greedyRegion.AdditiveStatisticsExtractors.weight; import static com.spbsu.ml.methods.greedyRegion.GreedyTDWeakRegionMTA.sum; public class OutLoss3<Subset extends CherryPointsHolder, Loss extends StatBasedLoss<AdditiveStatistics>> extends CherryLoss { private Subset subset; private Loss loss; private int complexity = 1; private int minBinSize = 50; private TIntHashSet used = new TIntHashSet(); OutLoss3(Subset subset, Loss loss) { this.subset = subset; this.loss = loss; } @Override public double score(BFGrid.BFRow feature, int start, int end, AdditiveStatistics added, AdditiveStatistics out) { if (start == 0 && end == feature.size()) return Double.NEGATIVE_INFINITY; int newsize = used.contains(feature.origFIndex) ? used.size() : used.size()+1; if (newsize > 7) return Double.NEGATIVE_INFINITY; AdditiveStatistics inside = subset.inside(); AdditiveStatistics total = subset.inside().append(added); final double R1 = -sum(total) * sum(total) / weight(total); total.append(out); final double R2 = Math.min(weight(inside) > 1 ? -sum(inside) * sum(inside) / weight(inside) : 0,-sum(total) * sum(total) / weight(total)) ; final int borders = borders(feature, start, end); final double score = (R2-R1) / (borders); return score >= 0 ? score : -1000000;//score(total, out, complexity + borders); } private int borders(BFGrid.BFRow feature, int start, int end) { return start != 0 && end != feature.size() ? 4 : 1; } private double score(AdditiveStatistics inside, AdditiveStatistics outside, int complexity) { final double wIn = weight(inside); if (used.size() > 6) return Double.NEGATIVE_INFINITY; if (wIn > 0 && wIn < minBinSize) return -1000000; final double wOut = weight(outside); if (wOut > 0 && wOut < minBinSize) return -1000000; double s = sum(inside) + sum(outside); double w = wIn + wOut; final double score = weight(inside) > 0 ? sum(inside) * sum(inside) / weight(inside) : 0; return score; } @Override public double score() { return score(subset.inside(), subset.outside(), complexity); } @Override public double insideIncrement() { return loss.bestIncrement(subset.inside()); } @Override public void endClause() { subset.endClause(); complexity ++; } public void addCondition(BFGrid.BFRow feature, int start, int end) { subset().addCondition(feature, start, end); complexity += borders(feature, start, end); used.add(feature.origFIndex); complexity ++; } @Override public CherryPointsHolder subset() { return subset; } }