package org.shanbo.feluca.classification.lr;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;

import java.util.Arrays;
import java.util.Locale;
import java.util.Map.Entry;
import java.util.Properties;

import org.shanbo.feluca.classification.common.Classifier;
import org.shanbo.feluca.classification.common.Evaluator;
import org.shanbo.feluca.data2.DataEntry;
import org.shanbo.feluca.data2.DataStatistic;
import org.shanbo.feluca.data2.Vector;
import org.shanbo.feluca.paddle.common.MemoryEstimater;
import org.shanbo.feluca.paddle.common.Utilities;
/**
 * Binary logistic regression trained with stochastic gradient descent;
 * minimizes the logistic (log) loss — see {@code logloss}.
 * @author lgn
 *
 */
public abstract class AbstractSGDLogisticRegression implements Classifier, MemoryEstimater{
// value every feature weight is reset to before training
protected final static double initWeight = 0;
// offset so original labels (which may be negative) can index outerLabelInfo directly
protected final static int LABELRANGEBASE = 32768;
// default relative-change threshold for the convergence test in _train()
public final static double DEFAULT_STOP = 0.001;
// default maximum number of training epochs
public final static int DEFAULT_LOOPS = 30;
// intercept term of the linear model
protected double w0;
protected int w0Type = 0;// 0 for no use; 1 for stay ; 2 for gradient
// weight per feature id, sized maxFeatureId + 1 by init()
public double[] featureWeights = null;
// training/testing data source
protected DataEntry dataEntry= null;
protected int[][] outerLabelInfo = null; //outer label -> info
protected int[][] innerLabelInfo = null; //inner label -> info
protected Double alpha = null; // learning speed
protected Double lambda = null;// regularization
// convergence threshold; falls back to DEFAULT_STOP when unset
protected Double convergence = null;
// true when alpha/lambda were set explicitly via setProperties() (disables auto-decay in _train)
protected boolean alphaSetted = false;
protected boolean lambdaSetted = false;
// lower bounds for the auto-decayed learning rate / regularization
protected double minAlpha = 0.001;
protected double minLambda = 0.01;
protected Integer loops = DEFAULT_LOOPS;
protected int fold = 5;
protected int samples = 0;
protected int maxFeatureId = -1;
protected int biasLabel = 0; // original label
// how many extra gradient steps a biasLabel sample receives; -1 = derive from class ratio
protected int biasWeightRound = -1;
// for accuracy stop
protected int minSamples = 0; // #
protected int maxSamples = 0; // #
/**
 * Attaches a data set and reads its statistics (max feature id and the
 * label information string). Only two-class data is accepted.
 *
 * @param data opened data entry to train/test on; must not be null
 * @throws Exception if data is null or describes more than two classes
 */
public void loadData(DataEntry data) throws Exception {
    this.dataEntry = data;
    if (this.dataEntry == null) {
        throw new RuntimeException("dataEntry must be set!");
    }
    maxFeatureId = Utilities.getIntFromProperties(dataEntry.getDataStatistic(), DataStatistic.MAX_FEATURE_ID);
    String labelInfo = Utilities.getStrFromProperties(dataEntry.getDataStatistic(), DataStatistic.LABEL_INFO);
    this.outerLabelInfo = new int[LABELRANGEBASE * 2][];
    // the label-info string holds one space-separated token per class
    if (labelInfo.split(" ").length > 2) {
        throw new RuntimeException("Data Set contains more than 2 classes");
    }
    _loadDataInfo(labelInfo);
}
/**
 * Parses the label statistics string, whose format is
 * "origLabel:innerLabel:count origLabel:innerLabel:count", fills the
 * outer-&gt;inner and inner-&gt;outer label tables, derives the class-imbalance
 * ratio (used as the default bias repeat count) and initializes w0.
 * Finally asks the subclass to estimate alpha/lambda; an NPE thrown there
 * signals that no data statistics exist (the model-loading path) and is
 * deliberately swallowed.
 */
private void _loadDataInfo(String infoString){
String[] ll = infoString.split("\\s+");
String[] classInfo1 = ll[0].split(":"); // orginal_label:converted_label:#num
String[] classInfo2 = ll[1].split(":");
int[] classInfo1Ints = new int[]{Integer.parseInt(classInfo1[0]), Integer.parseInt(classInfo1[1]), Integer.parseInt(classInfo1[2])};
int[] classInfo2Ints = new int[]{Integer.parseInt(classInfo2[0]), Integer.parseInt(classInfo2[1]), Integer.parseInt(classInfo2[2])};
this.outerLabelInfo = new int[LABELRANGEBASE * 2][]; // original_LABEL -> innerLabel, #sample
this.outerLabelInfo[LABELRANGEBASE + classInfo1Ints[0]] = new int[]{classInfo1Ints[1], classInfo1Ints[2]};
this.outerLabelInfo[LABELRANGEBASE + classInfo2Ints[0]] = new int[]{classInfo2Ints[1], classInfo2Ints[2]};
this.innerLabelInfo = new int[2][]; //innerLabel -> original_LABEL, #sample
this.innerLabelInfo[classInfo1Ints[1]] = new int[]{classInfo1Ints[0], classInfo1Ints[2]};
this.innerLabelInfo[classInfo2Ints[1]] = new int[]{classInfo2Ints[0], classInfo2Ints[2]};
// set bias automatically
// ratio = #majority / #minority; biasLabel is set to the minority class
float ratio = classInfo2Ints[2] /(classInfo1Ints[2] + 0.0f);
this.biasLabel = classInfo1Ints[0];
this.minSamples = classInfo1Ints[2];
this.maxSamples = classInfo2Ints[2];
if (classInfo1Ints[2] > classInfo2Ints[2]){ // #(label 0) > #(label 1)
this.biasLabel = classInfo2Ints[0];
ratio = classInfo1Ints[2] /(classInfo2Ints[2] + 0.0f);
this.minSamples = classInfo2Ints[2];
this.maxSamples = classInfo1Ints[2];
}else{ //default
;
}
// only derive the bias repeat count when the user did not set one via "-w<label>"
if (biasWeightRound == -1){
this.biasWeightRound = Math.round(ratio);
}
// w0 = log( biasWeightRound*minSamples / maxSamples ): the log-odds of the
// minority class after bias replication (cc/(b*min) - 1 == max/(b*min))
float cc = this.biasWeightRound * this.minSamples + this.maxSamples + 0.0f;
w0 = - Math.log(cc/(this.biasWeightRound * this.minSamples) -1 );
// w0 starts at 0 unless it is of the fixed ("stay") type
if (w0Type == 0 || w0Type == 2)
w0 = 0;
System.out.println(w0);
try{
this.estimateParameter();
if (this.convergence == null){
convergence = DEFAULT_STOP;
}
}catch(NullPointerException e){
// loading model
// NOTE(review): NPE used as control flow -- estimateParameter() presumably
// dereferences statistics that are absent when a saved model is loaded
System.out.println( " a model loading~");
}
}
/**
 * Derives alpha/lambda (and optionally convergence) from the data statistics.
 * May throw NullPointerException when statistics are absent -- the
 * model-loading path in _loadDataInfo() relies on catching that NPE.
 */
abstract protected void estimateParameter() throws NullPointerException;
/**
 * Trains on the full data set: resets all weights, then runs the SGD loop
 * with no held-out fold.
 */
public void train() throws Exception {
    init();
    // fold == Integer.MAX_VALUE means c % fold never equals 0, so every sample trains
    _train(Integer.MAX_VALUE, 0);
}
/**
 * Allocates the weight vector (one slot per feature id, indices
 * 0..maxFeatureId) and resets every weight to the initial value.
 */
protected void init(){
    featureWeights = new double[maxFeatureId + 1];
    for (int fid = 0; fid < featureWeights.length; fid++) {
        featureWeights[fid] = initWeight;
    }
}
/**
 * Log loss (base 2) of a single prediction.
 *
 * @param y  true inner label, 0 or 1
 * @param yy predicted probability of class 1
 * @return -log2(yy) when y == 1, -log2(1 - yy) when y == 0
 */
protected double logloss(int y, double yy) {
    // probability the model assigned to the true class
    double p = (y == 0) ? (1 - yy) : yy;
    // dividing by ln(2) converts the natural log to log base 2
    return -Math.log(p) / 0.69314718;
}
/**
 * Whether a prediction counts as correct at the 0.5 threshold.
 * A prediction of exactly 0.5 is correct for neither class.
 *
 * @param y  true inner label, 0 or 1
 * @param yy predicted probability of class 1
 */
protected boolean acc(int y , double yy) {
    return (y == 1 && yy > 0.5) || (y == 0 && yy < 0.5);
}
public void crossValidation(int fold, Evaluator... evaluators) throws Exception {
for(int i = 0 ; i < fold; i++){
this.init();
System.out.println("----cross validation loop " + i);
this._train(fold, i);
//-------------test-------
this.dataEntry.reOpen();
int c = 1;
double[] resultProbs = new double[2];
System.out.println("testing");
for(Vector sample = dataEntry.getNextVector(); sample != null ; sample = dataEntry.getNextVector()){
if (c % fold == i){
if (sample.getSize() == 0)
continue;
this.predict(sample, resultProbs);
for(Evaluator e : evaluators){
e.collect(outerLabelInfo[LABELRANGEBASE + sample.getIntHeader()][0], resultProbs);
}
}
}
}
}
/**
 * Reads hyper-parameters from a Properties object. Recognized keys:
 * "loops", "alpha", "lambda", "convergence", "w0type", and any
 * "-w&lt;label&gt;" entry, which sets the bias label and its repeat count.
 * Explicitly set alpha/lambda are pinned (no auto-decay during training).
 */
public void setProperties(Properties prop) {
    if (prop.getProperty("loops") != null) {
        this.loops = Utilities.getIntFromProperties(prop, "loops");
    }
    if (prop.getProperty("alpha") != null) {
        this.alpha = Utilities.getDoubleFromProperties(prop, "alpha");
        this.alphaSetted = true;
    }
    if (prop.getProperty("lambda") != null) {
        this.lambda = Utilities.getDoubleFromProperties(prop, "lambda");
        this.lambdaSetted = true;
    }
    if (prop.getProperty("convergence") != null) {
        this.convergence = Utilities.getDoubleFromProperties(prop, "convergence");
    }
    if (prop.getProperty("w0type") != null) {
        setW0Type(Utilities.getIntFromProperties(prop, "w0type"));
    }
    // "-w<label>" -> repeat count: extra gradient steps for that class (see _train)
    for (Entry<Object, Object> kv : prop.entrySet()) {
        String key = kv.getKey().toString();
        if (key.startsWith("-w")) {
            this.biasLabel = Integer.parseInt(key.substring(2));
            this.biasWeightRound = Integer.parseInt(kv.getValue().toString());
        }
    }
}
/**
 * Writes the model as text: first line is the intercept w0, then one
 * "featureId\tweight" line per weight that differs from the initial value
 * (keeps the file sparse).
 * <p>
 * Fixes: the writer is now closed via try-with-resources even when a write
 * fails, and the weight is formatted with Locale.US so the decimal
 * separator is always '.', matching loadModel()'s Double.parseDouble
 * (the default locale could emit ',' and break the round-trip).
 *
 * @param filePath destination path of the model file
 */
public void saveModel(String filePath) throws Exception {
    try (BufferedWriter bw = new BufferedWriter(new FileWriter(filePath))) {
        bw.write(w0 + "\n");
        for (int i = 0; i < this.featureWeights.length; i++) {
            if (this.featureWeights[i] != initWeight) {
                bw.write(String.format(Locale.US, "%d\t%.6f\n", i, this.featureWeights[i]));
            }
        }
    }
}
/**
 * Loads a model previously written by saveModel(): the first line is w0,
 * each following line is "featureId\tweight". Label tables and the weight
 * array size are rebuilt from the supplied statistics.
 * <p>
 * Fix: the reader is closed via try-with-resources; the original leaked it
 * whenever a line failed to parse.
 *
 * @param modelPath path of the model file
 * @param statistic data statistics holding MAX_FEATURE_ID and LABEL_INFO
 */
public void loadModel(String modelPath, Properties statistic) throws Exception {
    this.featureWeights = new double[Integer.parseInt(statistic.getProperty(DataStatistic.MAX_FEATURE_ID)) + 1];
    this._loadDataInfo(statistic.getProperty(DataStatistic.LABEL_INFO));
    try (BufferedReader br = new BufferedReader(new FileReader(modelPath))) {
        w0 = Double.parseDouble(br.readLine());
        for (String line = br.readLine(); line != null; line = br.readLine()) {
            String[] fidWeight = line.split("\t");
            this.featureWeights[Integer.parseInt(fidWeight[0])] = Double.parseDouble(fidWeight[1]);
        }
    }
}
/**
 * Scores one sample with the logistic model:
 * probabilities[0] = P(class 0), probabilities[1] = P(class 1).
 * <p>
 * Idiom fix: uses Math.exp(-x) instead of Math.pow(Math.E, -x) -- the
 * dedicated exponential is the standard (and more accurate) form.
 *
 * @param sample        feature vector; feature ids must be &lt;= maxFeatureId
 * @param probabilities output array of length &gt;= 2
 */
public void predict(Vector sample, double[] probabilities) throws Exception{
    double weightSum = w0; // start from the intercept
    for (int i = 0; i < sample.getSize(); i++) {
        weightSum += this.featureWeights[sample.getFId(i)] * sample.getWeight(i);
    }
    double probability = 1 / (1 + Math.exp(-weightSum)); // logistic sigmoid
    probabilities[0] = 1 - probability;
    probabilities[1] = probability;
}
/**
 * predict probability for data; the predict_Label will accord with training data;
 * Otherwise use {@link #predict(String, String, Evaluator...)} instead.
 * <p>
 * Fixes: (1) featureless samples were written TWICE -- the default-0.5 line
 * and then a regular prediction -- a {@code continue} now skips them, matching
 * crossValidation(); (2) the final evaluator printout dereferenced
 * {@code evaluators} without the null guard the collect phase already had;
 * (3) the writer is closed via try-with-resources.
 *
 * @param data       samples to score
 * @param resultPath output file: "testLabel\tpredictLabel\tconfidence" per sample
 * @param evaluators optional metric collectors; may be null/empty
 * @throws IOException if no model has been trained or loaded yet
 */
public void predict(DataEntry data, String resultPath, Evaluator... evaluators) throws Exception {
    if (this.featureWeights == null)
        throw new IOException("!Model haven't been initialized yet! :(");
    try (BufferedWriter bw = new BufferedWriter(new FileWriter(resultPath))) {
        bw.write("testLabel\tpredictLabel\tprobability(here means confidence)\n");
        double[] resultProbs = new double[2];
        int innerLabel = -1;
        data.reOpen();
        System.out.println(w0);
        for (Vector sample = data.getNextVector(); sample != null; sample = data.getNextVector()) {
            if (sample.getSize() == 0) { //how to predict without any features? A default probability = 0.5 should be moderate;
                bw.write(String.format("%d\t%d\t%.4f\n" , sample.getIntHeader(), innerLabelInfo[0][0], 0.5f));
                continue; // FIX: do not also emit a regular prediction for this sample
            }
            this.predict(sample, resultProbs);
            if (evaluators != null) {
                for (Evaluator e : evaluators) {
                    int testLabel = outerLabelInfo[LABELRANGEBASE + sample.getIntHeader()][0]; //innerLabel = [0 or 1]
                    e.collect(testLabel, resultProbs);
                }
            }
            innerLabel = resultProbs[0] > resultProbs[1] ? 0 : 1; // predict inner Label with probabilities;
            //output original label;
            bw.write(String.format("%d\t%d\t%.4f\n" , sample.getIntHeader(), innerLabelInfo[innerLabel][0], resultProbs[innerLabel]));
        }
    }
    if (evaluators != null) { // FIX: guard was missing here although collect() had one
        for (Evaluator e : evaluators) {
            System.out.println(e.resultString());
        }
    }
}
/** Human-readable summary of the hyper-parameters and the bias setting. */
@Override
public String toString(){
    String fmt = "alpha:%.6f, lambda:%.9f, loops: %d, bias:%d on %d times";
    return String.format(fmt, this.alpha, this.lambda, this.loops, biasLabel, biasWeightRound);
}
/**
 * Exports the current tunable hyper-parameters ("alpha", "lambda", "loops")
 * as a Properties object (boxed values, not strings).
 */
public Properties getProperties() {
    Properties hyperParams = new Properties();
    hyperParams.put("loops", this.loops);
    hyperParams.put("alpha", this.alpha);
    hyperParams.put("lambda", this.lambda);
    return hyperParams;
}
/**
 * Sets how the intercept w0 is treated during training.
 *
 * @param type 0 = unused, 1 = fixed after initialization, 2 = updated by gradient
 * @throws IllegalArgumentException if type is outside [0, 2]; this is a
 *         RuntimeException subclass, so existing callers catching
 *         RuntimeException are unaffected
 */
public void setW0Type(int type){
    if (type < 0 || type > 2) {
        throw new IllegalArgumentException("0 for no use; 1 for stay ; 2 for gradient");
    }
    this.w0Type = type;
}
/**
 * Core SGD loop. Trains on every sample whose ordinal c satisfies
 * c % fold != remain (crossValidation holds out one fold; train() passes
 * fold = Integer.MAX_VALUE so all samples train). Runs at least
 * min(5, loops) epochs and at most loops, stopping early once the relative
 * change of both the average log-loss and the correct count drops below
 * the convergence threshold. Samples of the bias label receive
 * biasWeightRound extra gradient steps. When an epoch's correct count
 * drops and alpha/lambda were not pinned by the user, they are decayed
 * (floored at minAlpha/minLambda).
 */
protected void _train(int fold, int remain) throws Exception{
// avge/corrects track the current epoch; lastAVGE/lastCorrects the previous one
double avge = 99999.9;
double lastAVGE = Double.MAX_VALUE;
double corrects = 0;
double lastCorrects = -1;
// NOTE(review): 'multi' is computed but never used below -- dead code?
double multi = (biasWeightRound * minSamples + maxSamples)/(minSamples + maxSamples + 0.0);
// loop while: l < max(5, loops) AND (still warming up (l < min(5, loops))
// OR not yet converged in either avg-loss or accuracy)
for(int l = 0 ; l < Math.max(5, loops)
&& (l < Math.min(5, loops)
|| (l < loops && (Math.abs(1- avge/ lastAVGE) > convergence )
|| Math.abs(1- corrects/ lastCorrects) > convergence * 0.01)); l++){
lastAVGE = avge;
lastCorrects = corrects;
dataEntry.reOpen(); //start reading data
long timeStart = System.currentTimeMillis();
int c =1; //for n-fold cv
double innerPredict = 0; //0~1
double sume = 0; // accumulated log-loss over trained samples
corrects = 0;
int cc = 0; // number of samples actually trained this epoch
int pp = 10; // next progress-print threshold (doubles each print)
for(Vector sample = dataEntry.getNextVector(); sample != null ; sample = dataEntry.getNextVector()){
if (c % fold == remain){ // no train
;
}else{ //train
//bias; sequentially compute #(bias - 1) times
innerPredict = this.gradientDescend(sample);
if (acc(outerLabelInfo[LABELRANGEBASE + sample.getIntHeader()][0], innerPredict)){
corrects += 1;
}
cc += 1;
sume += logloss(outerLabelInfo[LABELRANGEBASE + sample.getIntHeader()][0], innerPredict);
// bias-label samples get biasWeightRound additional updates
if ( sample.getIntHeader() == this.biasLabel){
for(int bw = 0 ; bw < this.biasWeightRound; bw++){ //bias
innerPredict = this.gradientDescend(sample);
}
}
}
c += 1;
// exponentially-spaced progress output: [avg_loss:accuracy%]
if (c% pp == 0){
System.out.print(String.format("[%.4f:%.1f]", sume / cc, corrects * 100 / cc));
pp *= 2;
}
}
avge = sume / cc;
long timeEnd = System.currentTimeMillis();
double acc = corrects / (cc ) * 100;
// accuracy got worse: decay the (auto-managed) learning rate and regularization
if (corrects < lastCorrects ){ //
if (!alphaSetted){
this.alpha *= 0.5;
if (alpha < minAlpha)
alpha = minAlpha;
}
if (!lambdaSetted){
this.lambda *= 0.9;
if (lambda < minLambda)
lambda = minLambda;
}
}
System.out.println(w0);
System.out.println(String.format("#%d loop%d\ttime:%d(ms)\tacc: %.3f(approx)\tavg_error:%.6f", cc, l, (timeEnd - timeStart), acc , avge));
}
}
/**
 * Performs one SGD update for the sample and returns the resulting
 * prediction in [0, 1] (interpreted by acc()/logloss() as P(class 1)).
 */
abstract protected double gradientDescend(Vector sample) ;
/**
 * MemoryEstimater hook: presumably estimates resource feasibility from the
 * data statistics and parameters -- contract not visible here; see the
 * MemoryEstimater interface.
 */
public abstract int estimate(Properties dataStatus, Properties parameters) ;
}