package org.shanbo.feluca.distribute.classification.lr;
import java.util.ArrayList;
import java.util.Map.Entry;
import java.util.concurrent.ExecutionException;
import org.shanbo.feluca.data2.DataStatistic;
import org.shanbo.feluca.data2.Vector;
import org.shanbo.feluca.distribute.launch.LoopingBase;
import org.shanbo.feluca.paddle.GlobalConfig;
import org.shanbo.feluca.util.JSONUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.alibaba.fastjson.JSONObject;
/**
 * Two-class logistic regression trained with stochastic gradient descent and
 * L2 regularization, driven by the looping framework of {@link LoopingBase}.
 *
 * <p>Each pass over the data is processed in batches of
 * {@link #BATCH_COMPUTE_SIZE} vectors: the local partial dot products are
 * computed, summed through {@code reducerClient}, and the merged score is then
 * used to update the local weights one sample at a time.
 *
 * <p>{@code conf}, {@code loops}, {@code looping}, {@code modelClient},
 * {@code reducerClient} and {@code dataEntry} are not declared in this file;
 * they are presumably inherited from {@link LoopingBase}.
 */
public class SGDL2LR extends LoopingBase{

    /** Offset added to an original label so that negative labels index into {@link #dataInfo}. */
    final static int LABELRANGEBASE = 32768;
    /** Convergence threshold used when the algorithm conf has no "convergence" entry. */
    public final static double DEFAULT_STOP = 0.001;
    public final static int DEFAULT_LOOPS = 30;
    /** Number of vectors whose scores are reduced together in one call. */
    final static int BATCH_COMPUTE_SIZE = 100;

    static Logger log = LoggerFactory.getLogger(SGDL2LR.class);

    /** Indexed by {@code LABELRANGEBASE + original_label}; each entry is {converted_label, #samples}. */
    protected int[][] dataInfo = null;
    float[] featureWeights = null; //ref of float[] in ModelLocal
    protected Double alpha = null; // learning speed
    protected Double lambda = null;// regularization
    protected Double convergence = null;
    /** True when alpha came from the algorithm conf; it is then never decayed. */
    protected boolean alphaSetted = false;
    /** True when lambda came from the algorithm conf; it is then never decayed. */
    protected boolean lambdaSetted = false;
    protected double minAlpha = 0.001;
    protected double minLambda = 0.01;
    protected int fold = Integer.MAX_VALUE;
    protected int samples = 0;
    protected int maxFeatureId = -1;
    protected int biasLabel = 0; // original label
    protected int biasWeightRound = 1;

    // for accuracy stop
    protected int minSamples = 0; // # samples of the smaller class
    protected int maxSamples = 0; // # samples of the larger class
    double lastCorrects = -1;
    double avge = 999999999;
    double lastAVGE = avge;
    double error = 0;
    double sume = 0.0, corrects = 0;
    int cc = 0;
    int vcount = 0; //for cross validation
    int remain = 0; // for cross validation
    long tStart, tEnd;
    double multi ;

    /**
     * Reads class statistics and user settings, then estimates any
     * hyper-parameter the user did not supply.
     *
     * @param conf global configuration handed to {@link LoopingBase}
     * @throws Exception whatever the super constructor or the parameter setup throws
     */
    public SGDL2LR(GlobalConfig conf) throws Exception {
        super(conf);
        initParams();
        estimateParameter();
    }

    /**
     * Parses the label statistics ("original_label:converted_label:#num" for
     * two classes, whitespace separated), fills {@link #dataInfo}, derives the
     * automatic bias weight for the smaller class and finally applies the
     * user's algorithm settings (which may override the bias).
     *
     * @throws IllegalArgumentException if the label info is missing, does not
     *         describe two classes, or is malformed
     */
    private void initParams(){
        String infoString = conf.getDataStatistic().getString(DataStatistic.LABEL_INFO);
        if (infoString == null || infoString.trim().isEmpty()){
            throw new IllegalArgumentException("label info is missing from the data statistic");
        }
        // trim first: a leading blank would otherwise produce an empty first token
        String[] ll = infoString.trim().split("\\s+");
        if (ll.length < 2){
            throw new IllegalArgumentException(
                    "label info must describe two classes, got: '" + infoString + "'");
        }
        int[] classInfo1Ints = parseClassInfo(ll[0]);
        int[] classInfo2Ints = parseClassInfo(ll[1]);

        this.dataInfo = new int[LABELRANGEBASE * 2][]; // original_LABEL -> index, #sample
        this.dataInfo[LABELRANGEBASE + classInfo1Ints[0]] = new int[]{classInfo1Ints[1], classInfo1Ints[2]};
        this.dataInfo[LABELRANGEBASE + classInfo2Ints[0]] = new int[]{classInfo2Ints[1], classInfo2Ints[2]};

        // set bias automatically: the class with fewer samples gets the extra weight
        if (classInfo1Ints[2] > classInfo2Ints[2]){ // #(first label) > #(second label)
            this.biasLabel = classInfo2Ints[0];
            this.minSamples = classInfo2Ints[2];
            this.maxSamples = classInfo1Ints[2];
        }else{ //default
            this.biasLabel = classInfo1Ints[0];
            this.minSamples = classInfo1Ints[2];
            this.maxSamples = classInfo2Ints[2];
        }
        // An empty minority class would give an infinite ratio, which rounds to
        // Integer.MAX_VALUE and overflows later arithmetic; fall back to no bias.
        this.biasWeightRound = (this.minSamples > 0)
                ? Math.round(this.maxSamples / (this.minSamples + 0.0f))
                : 1;

        this.setProperties(conf.getAlgorithmConf());
    }

    /**
     * Parses one "original_label:converted_label:#num" token.
     *
     * @param token a single class description
     * @return {original_label, converted_label, #num}
     * @throws IllegalArgumentException if the token has fewer than three fields,
     *         a field is not an integer, or the original label cannot be
     *         indexed into {@link #dataInfo}
     */
    private static int[] parseClassInfo(String token){
        String[] fields = token.split(":");
        if (fields.length < 3){
            throw new IllegalArgumentException(
                    "class info must look like original_label:converted_label:#num, got: '" + token + "'");
        }
        int originalLabel = Integer.parseInt(fields[0]);
        if (originalLabel < -LABELRANGEBASE || originalLabel >= LABELRANGEBASE){
            throw new IllegalArgumentException("label out of supported range: " + originalLabel);
        }
        return new int[]{originalLabel, Integer.parseInt(fields[1]), Integer.parseInt(fields[2])};
    }

    /**
     * Estimates alpha and lambda (and minAlpha) from the ratio of weighted
     * sample count to feature count, for every value the user did not set.
     */
    protected void estimateParameter() throws NullPointerException{
        this.samples = conf.getAlgorithmConf().getIntValue(DataStatistic.NUM_VECTORS);
        if (this.samples <= 0){
            // NOTE(review): the other DataStatistic keys are read from the data
            // statistic, so fall back to it when the algorithm conf has no count.
            this.samples = JSONUtil.getConf(conf.getDataStatistic(), DataStatistic.NUM_VECTORS, 0);
        }
        if (this.maxFeatureId < 0){
            // maxFeatureId is never assigned elsewhere in this class; without this
            // the formula below divides by -1 and takes the log of a negative number.
            this.maxFeatureId = JSONUtil.getConf(conf.getDataStatistic(), DataStatistic.MAX_FEATURE_ID, 1);
        }
        double classBalance = (1 + biasWeightRound) / (biasWeightRound * 2.0);
        double featureCount = Math.max(1, this.maxFeatureId);
        double rate = Math.log(2 + samples / classBalance / featureCount);
        // NaN compares false with everything, so it has to be tested explicitly
        if (Double.isNaN(rate) || Double.isInfinite(rate) || rate < 0.5){
            rate = 0.5;
        }
        if (alpha == null){
            alpha = 0.5 / rate;
            minAlpha = alpha / Math.pow(1 + rate, 1.8);
        }
        if (this.lambda == null){
            lambda = 0.002 / rate;
            minLambda = 0.01;
        }
    }

    /**
     * Applies user settings: "loops", "alpha", "lambda", "convergence" and any
     * key of the form "-w&lt;label&gt;" whose value is the bias weight for that
     * original label.
     *
     * @param algoConf algorithm section of the configuration
     */
    private void setProperties(JSONObject algoConf) {
        if (algoConf.containsKey("loops")){
            loops = algoConf.getInteger("loops");
        }
        if (algoConf.containsKey("alpha")){
            alpha = algoConf.getDouble("alpha");
            alphaSetted = true;
        }
        if (algoConf.containsKey("lambda")){
            this.lambda = algoConf.getDouble("lambda");
            lambdaSetted = true;
        }
        if (algoConf.containsKey("convergence")){
            convergence = algoConf.getDouble("convergence");
        }
        if (this.convergence == null){
            convergence = DEFAULT_STOP;
        }
        for(Entry<String, Object> entry : algoConf.entrySet()){
            String key = entry.getKey();
            if (key.startsWith("-w")){
                // overrides the bias derived automatically from the class sizes
                biasLabel = Integer.parseInt(key.substring(2));
                biasWeightRound = Integer.parseInt(entry.getValue().toString());
            }
        }
    }

    /** Creates the "weights" vector (max feature id + 1 entries) and keeps a reference to it. */
    protected void startup() {
        modelClient.createVector("weights", JSONUtil.getConf(conf.getDataStatistic(), DataStatistic.MAX_FEATURE_ID, 1) + 1, 0, 0);
        featureWeights = modelClient.getVector("weights");
    }

    /** Remembers the previous pass' results and resets the per-pass counters. */
    protected void computeLoopBegin(){
        lastAVGE = avge;
        lastCorrects = corrects;
        tStart = System.currentTimeMillis();
        error = 0;
        sume = 0;
        corrects = 0;
        cc = 0;
        vcount = 0; //for n-fold cv
        // average weight of one sample, used to normalise the weighted 'corrects' count
        multi = (biasWeightRound * minSamples + maxSamples)/(minSamples + maxSamples + 0.0);
    }

    /**
     * Computes the pass' average error and approximate accuracy, decays
     * non-user-set alpha/lambda when the (weighted) number of correct
     * predictions dropped, and prints a progress line.
     */
    protected void computeLoopEnd(){
        // if no vector was processed (cc == 0) this is NaN, as in the original code
        avge = sume / cc;
        tEnd = System.currentTimeMillis();
        double acc = corrects / (cc * multi) * 100;
        if (corrects < lastCorrects){
            if (!alphaSetted){
                this.alpha *= 0.5;
                if (alpha < minAlpha){
                    alpha = minAlpha;
                }
            }
            if (!lambdaSetted){
                // NOTE(review): estimateParameter leaves minLambda at 0.01 while the
                // estimated lambda is at most 0.004, so this clamp raises lambda
                // instead of bounding its decay; confirm the intent.
                this.lambda *= 0.9;
                if (lambda < minLambda){
                    lambda = minLambda;
                }
            }
        }
        System.out.println(String.format("#%d loop%d\ttime:%d(ms)\tacc: %.3f(approx)\tavg_error:%.6f", cc, looping, (tEnd - tStart), acc , avge));
    }

    /** Runs one full pass over the data in batches of {@link #BATCH_COMPUTE_SIZE}. */
    @Override
    protected void computeLoop() throws Exception {
        computeLoopBegin();
        ArrayList<Vector> batchVectors = new ArrayList<Vector>();
        for(Vector v = dataEntry.getNextVector(); v != null; v = dataEntry.getNextVector()){
            batchVectors.add(v);
            if (batchVectors.size() >= BATCH_COMPUTE_SIZE){
                doCompute(batchVectors);
                batchVectors.clear();
            }
        }
        // flush the last, partially filled batch
        if (batchVectors.size() > 0){
            doCompute(batchVectors);
        }
        computeLoopEnd();
    }

    /**
     * Computes the local dot product of every vector in the batch, sums the
     * values through {@code reducerClient}, then performs one gradient step per
     * vector and updates the accuracy/error counters.
     *
     * @param batchVectors vectors of the current batch
     */
    protected void doCompute(ArrayList<Vector> batchVectors) throws InterruptedException, ExecutionException{
        int batchSize = batchVectors.size();
        float[] weightSums = new float[batchSize];
        for(int i = 0 ; i < batchSize; i++){
            Vector vector = batchVectors.get(i);
            float weightSum = 0;
            for(int f = 0 ; f < vector.getSize(); f++){
                weightSum += vector.getWeight(f) * featureWeights[vector.getFId(f)];
            }
            weightSums[i] = weightSum;
        }
        //--------
        float[] merged = reducerClient.sum(weightSums);
        for(int i = 0 ; i < batchSize; i++){
            Vector v = batchVectors.get(i);
            error = gradientDescend(v, merged[i]);
            double absError = Math.abs(error);
            if (absError < 0.49){ //accuracy: counted as a correct prediction
                if (v.getIntHeader() == this.biasLabel){
                    corrects += this.biasWeightRound;
                }else{
                    corrects += 1;
                }
            }
            cc += 1;
            sume += absError;
        }
    }

    /**
     * Performs one SGD step on the local weights for a single sample.
     *
     * @param v sample; its int header is the original label
     * @param weightSum merged score (dot product) of the sample
     * @return converted label minus the sigmoid of {@code weightSum}
     */
    private double gradientDescend(Vector v, float weightSum){
        int label = v.getIntHeader();
        double tmp = Math.exp(-weightSum); //e^-sigma(x)
        double error = dataInfo[LABELRANGEBASE + label][0] - (1 / (1 + tmp));
        // Derivative of the sigmoid: tmp / (1 + tmp)^2. When tmp overflows to
        // infinity the quotient would be Inf/Inf = NaN and poison every weight
        // it touches; the true limit there is 0.
        double partialDerivation = Double.isInfinite(tmp) ? 0.0 : tmp / (tmp * tmp + 2 * tmp + 1);
        // unbox once instead of on every feature
        final double learningRate = alpha;
        final double regularization = lambda;
        for(int i = 0 ; i < v.getSize(); i++){
            int fid = v.getFId(i);
            // w <- w + alpha * (error * partial_derivation - lambda * w)
            featureWeights[fid] +=
                    learningRate * (error * v.getWeight(i) * partialDerivation - regularization * featureWeights[fid]);
        }
        return error;
    }

    /**
     * Decides whether training can stop: both the relative change of the
     * average error and the relative change of the correct count must be
     * within their thresholds. Prints both ratios.
     *
     * @return true when neither ratio exceeds its threshold
     */
    @Override
    public boolean earlyStop() {
        double errorRatio = 1 - avge / lastAVGE;
        double accRatio = 1 - corrects / lastCorrects;
        System.out.print(String.format("errorRatio:[%.4f] accRatio:[%.4f] ", errorRatio, accRatio));
        boolean stillChanging = Math.abs(errorRatio) > convergence
                || Math.abs(accRatio) > convergence * 0.01;
        return !stillChanging;
    }
}