package org.shanbo.feluca.classification.lr; import java.util.Properties; import org.shanbo.feluca.data2.Vector; import org.shanbo.feluca.data2.DataStatistic; import org.shanbo.feluca.paddle.common.Utilities; public class SGDL2LR extends AbstractSGDLogisticRegression{ /** * 1->1 ; 0-> -1; 0.5->0 * @param y * @return */ protected double transform(double y ){ return y * 2 - 1 ; } protected double transform(int y ){ return (y * 2) - 1 ; } protected double gradientDescend(Vector sample){ double wTx = w0; int innerLabel = outerLabelInfo[LABELRANGEBASE + sample.getIntHeader()][0]; double newlabel = transform( innerLabel); //{-1, +1} for(int i = 0 ; i < sample.getSize(); i++){//wTx wTx += featureWeights[sample.getFId(i)] * sample.getWeight(i); } double gradient = - newlabel * ( 1 - 1/(1 + Math.pow(Math.E, - newlabel * wTx))); if (w0Type == 2) w0 -= alpha * (gradient + 2 * lambda * w0); for(int i = 0 ; i < sample.getSize(); i++){ // w <- w + alpha * (error * partial_derivation - lambda * w) featureWeights[sample.getFId(i)] -= alpha * (gradient * sample.getWeight(i) + 2 * lambda * featureWeights[sample.getFId(i)]) ; } double innerPrediction = 1/ (1+Math.pow(Math.E, - wTx)); return innerPrediction; } @Override public int estimate(Properties dataStatus, Properties parameters) { // TODO Auto-generated method stub int maxFeatureId = Utilities.getIntFromProperties(dataStatus, DataStatistic.MAX_FEATURE_ID); int maxVectorSize = Utilities.getIntFromProperties(dataStatus, DataStatistic.MAX_VECTORSIZE); int numberLines = Utilities.getIntFromProperties(dataStatus, DataStatistic.NUM_VECTORS); int numberFeatures = Utilities.getIntFromProperties(dataStatus, DataStatistic.TOTAL_FEATURES); int modelSize = maxFeatureId * 4 / 1024 ; modelSize += maxFeatureId * 4 / 1024 ; int dataSetKb = 0; if (parameters.containsKey("inRam")){ // use file data dataSetKb += 30 * 1024; // VectorStorage.FileStorage approximately cost // dataSetKb += VectorPool.RAMEstimate( maxVectorSize); }else{ // dataSetKb += VectorStorage.RAMCompactStorage.RAMEstimate(numberLines, numberFeatures, vectorStatusPara); } return dataSetKb + modelSize; } @Override protected void estimateParameter(){ this.samples = Utilities.getIntFromProperties(dataEntry.getDataStatistic(), DataStatistic.NUM_VECTORS); double rate = Math.log(2 + samples /((1 + biasWeightRound)/(biasWeightRound * 2.0)) /( this.maxFeatureId + 0.0)); if (rate < 0.5) rate = 0.5; if (alpha == null){ alpha = 0.5 / rate; minAlpha = alpha / Math.pow(1 + rate, 1.8); System.out.println("guessing alpha:" + alpha); } if (this.lambda == null){ lambda = 0.02 / rate; // minLambda = lambda / Math.pow(1 + rate, 1.8); minLambda = 0.001; System.out.println("guessing lambda:" + lambda); } } }