package org.shanbo.feluca.classification.lr;
import java.util.Arrays;
import java.util.Properties;
import org.shanbo.feluca.data2.Vector;
import org.shanbo.feluca.data2.DataStatistic;
import org.shanbo.feluca.paddle.common.Utilities;
/**
 * Logistic regression trained by stochastic gradient descent with an L1 penalty.
 *
 * <p>Each step first performs the same gradient update as written here for the L2 case
 * (log-loss gradient plus a {@code 2 * lambda * w} shrinkage term) and then clips the touched
 * weights with the cumulative-penalty scheme of Tsuruoka, Tsujii and Ananiadou (ACL 2009):
 * {@link #u} is the total L1 penalty any weight could have received so far, {@link #qWeights}
 * is the penalty each weight has actually received, and a weight touched by the current sample
 * is moved towards zero by the difference without being allowed to cross zero.
 *
 * <p>Not thread-safe: {@link #gradientDescend(Vector)} mutates the model in place.
 */
public final class SGDL1LR extends SGDL2LR{

    /** Per-feature L1 penalty actually applied so far (q_i in the paper); indexed by feature id. */
    private double[] qWeights = null;

    /** Total L1 penalty each weight could have received so far (u in the paper). */
    private double u = 0.0;

    /**
     * Resets the L1 bookkeeping and then lets the superclass initialise the model.
     *
     * <p>No penalty has been applied before training starts, so both the cumulative penalty and
     * the per-feature applied penalty start at zero (a fresh {@code double[]} is zero-filled).
     */
    @Override
    protected void init(){
        qWeights = new double[maxFeatureId + 1];
        u = 0.0;
        super.init();
    }

    /**
     * Performs one stochastic gradient step on a single sample.
     *
     * @param sample the training vector; its int header is mapped to the inner label through
     *               {@code outerLabelInfo}
     * @return the sigmoid of the linear score computed <em>before</em> this step's update
     */
    public double gradientDescend(Vector sample){
        int innerLabel = outerLabelInfo[LABELRANGEBASE + sample.getIntHeader()][0];
        double newlabel = transform(innerLabel); //{-1, +1}

        // Linear score w0 + w.x with the weights as they are before the update.
        double wTx = w0;
        for (int i = 0; i < sample.getSize(); i++) {
            wTx += featureWeights[sample.getFId(i)] * sample.getWeight(i);
        }

        // Derivative of log(1 + exp(-y * wTx)) with respect to wTx.
        double gradient = -newlabel * (1 - 1 / (1 + Math.exp(-newlabel * wTx)));

        // The bias only gets the gradient/shrinkage update; it is not L1-clipped.
        w0 -= alpha * (gradient + 2 * lambda * w0);

        // Accumulate the L1 penalty made available by this step. Without this, u stays 0 and
        // applyPenalty() never shrinks anything, i.e. the L1 regulariser is silently disabled.
        // NOTE(review): lambda is used per sample here, mirroring the shrinkage term above;
        // divide by the sample count instead if lambda is meant as a whole-dataset strength.
        u += alpha * lambda;

        for (int i = 0; i < sample.getSize(); i++) {
            int fid = sample.getFId(i);
            // w <- w - alpha * (gradient * x_i + 2 * lambda * w)
            featureWeights[fid] -=
                    alpha * (gradient * sample.getWeight(i) + 2 * lambda * featureWeights[fid]);
            applyPenalty(fid);
        }

        return 1 / (1 + Math.exp(-wTx));
    }

    /**
     * Applies the outstanding L1 penalty to one weight, clipping at zero.
     *
     * <p>The weight is moved towards zero by the penalty it has not yet received
     * ({@code u} relative to {@code qWeights[fid]}) but never past zero; the amount actually
     * applied is then recorded in {@code qWeights[fid]}.
     *
     * @param fid feature id of the weight to penalise
     */
    private void applyPenalty(int fid){
        double before = featureWeights[fid];
        if (before > 0) {
            featureWeights[fid] = Math.max(0, before - (u + qWeights[fid]));
        } else if (before < 0) {
            featureWeights[fid] = Math.min(0, before + (u - qWeights[fid]));
        }
        qWeights[fid] += featureWeights[fid] - before;
    }

    /**
     * Derives default values for the learning rate and regularisation strength from the data
     * statistics when they have not been set explicitly.
     *
     * <p>Also stores the number of vectors in {@code samples}. {@code alpha}/{@code minAlpha}
     * are only assigned when {@code alpha} is {@code null}; {@code lambda}/{@code minLambda}
     * only when {@code lambda} is {@code null}.
     */
    protected void estimateParameter() throws NullPointerException{
        this.samples = Utilities.getIntFromProperties(dataEntry.getDataStatistic(), DataStatistic.NUM_VECTORS);
        double rate = Math.log(2 + samples /((1 + biasWeightRound)/(biasWeightRound * 2.0)) /( this.maxFeatureId + 0.0));
        if (rate < 0.5) {
            rate = 0.5;
        }
        if (alpha == null) {
            alpha = 0.5 / rate;
            minAlpha = alpha / Math.pow(1 + rate, 1.8);
        }
        if (this.lambda == null) {
            lambda = 0.5 / rate;
            minLambda = 0.1;
        }
    }

    /**
     * Estimates the memory needed for training.
     *
     * @param dataStatus statistics of the data set (max feature id, vector counts, ...)
     * @param parameters algorithm parameters; only the presence of the key {@code "inRam"} is
     *                   inspected
     * @return the estimate, computed as data-set cost plus model cost
     *         (NOTE(review): the {@code / 1024} and the {@code dataSetKb} name suggest KB — confirm)
     */
    @Override
    public int estimate(Properties dataStatus, Properties parameters) {
        int maxFeatureId = Utilities.getIntFromProperties(dataStatus, DataStatistic.MAX_FEATURE_ID);
        // The next three statistics are read but not used by the current formula; the reads are
        // kept so that whatever the helper does for a missing entry still happens here.
        int maxVectorSize = Utilities.getIntFromProperties(dataStatus, DataStatistic.MAX_VECTORSIZE);
        int numberLines = Utilities.getIntFromProperties(dataStatus, DataStatistic.NUM_VECTORS);
        int numberFeatures = Utilities.getIntFromProperties(dataStatus, DataStatistic.TOTAL_FEATURES);

        // Two per-feature arrays (weights and applied penalties). The product is formed in long
        // so that a large maxFeatureId cannot overflow int before the division.
        // NOTE(review): 4 bytes per entry is kept from the original, but qWeights is a double[]
        // (8 bytes per entry) with maxFeatureId + 1 entries — this likely underestimates.
        int perArrayKb = (int) (maxFeatureId * 4L / 1024);
        int modelSize = perArrayKb + perArrayKb;

        int dataSetKb = 0;
        if (parameters.containsKey("inRam")) {
            // NOTE(review): the original comment in this branch said "use file data", which
            // reads as if the condition were inverted — confirm which branch is file-backed.
            dataSetKb += 30 * 1024; // VectorStorage.FileStorage approximately cost
        }
        // No cost is currently estimated for the other storage mode.
        return dataSetKb + modelSize;
    }
}