package org.shanbo.feluca.classification.fmc; import java.util.Arrays; import java.util.Properties; import java.util.Random; import org.shanbo.feluca.classification.common.Evaluator; import org.shanbo.feluca.classification.lr.AbstractSGDLogisticRegression; import org.shanbo.feluca.classification.lr.SGDL2LR; import org.shanbo.feluca.data2.DataEntry; import org.shanbo.feluca.data2.Vector; import org.shanbo.feluca.paddle.common.Utilities; /** * only support 2-degree interaction of features. * Extending AbstractSGDLogisticRegression * * @author lgn * */ public class SGDFactorizeMachine extends SGDL2LR{ protected int dim = 5; protected float fWeightRange = 0.05f; protected float[][] factors ; Random rand = new Random(); protected void init(){ super.init(); factors = new float[dim][]; for(int k = 0 ; k < dim ; k ++){ factors[k] = new float[maxFeatureId + 1]; for(int i = 0 ; i < factors[k].length ; i++){ factors[k][i] = (float)Utilities.randomDouble(-1, 1) * fWeightRange; } } } /** * just for one-hot dataset */ // protected void _train(int fold, int remain) throws Exception{ //// if (true){ //// super._train(fold, remain); //// return; //// } // System.out.println("one hot"); // double avge = 99999.9; // double lastAVGE = Double.MAX_VALUE; // // double corrects = 0; // double lastCorrects = -1; // // // double multi = (biasWeightRound * minSamples + maxSamples)/(minSamples + maxSamples + 0.0); // // for(int l = 0 ; l < Math.max(10, loops) // && (l < Math.min(10, loops) // || (l < loops && (Math.abs(1- avge/ lastAVGE) > convergence ) // || Math.abs(1- corrects/ lastCorrects) > convergence * 0.01)); l++){ // lastAVGE = avge; // lastCorrects = corrects; // dataEntry.reOpen(); //start reading data // // long timeStart = System.currentTimeMillis(); // // int c =1; //for n-fold cv // double error = 0; // double sume = 0; // corrects = 0; // int cc = 0; // // for(Vector sample = dataEntry.getNextVector(); sample != null ; sample = dataEntry.getNextVector()){ // if (c % fold == remain){ // no train // ; // }else{ //train // if ( sample.getIntHeader() == this.biasLabel){ //bias; sequentially compute #(bias - 1) times // for(int bw = 1 ; bw < this.biasWeightRound; bw++){ //bias // this.gradientDescendOneHot(sample); // } // } // error = gradientDescendOneHot(sample); // if (Math.abs(error) < 0.45)//accuracy // if ( sample.getIntHeader() == this.biasLabel) // corrects += this.biasWeightRound; // else // corrects += 1; // cc += 1; // // sume += Math.abs(error); // if (error > 0){ // sume += - Math.log(1-error) / 0.69314718; // }else{ // sume+= - Math.log(1+error)/ 0.69314718; // } // } // c += 1; // } // // avge = sume / cc; // // long timeEnd = System.currentTimeMillis(); // double acc = corrects / (cc * multi) * 100; // // if (corrects < lastCorrects ){ // // if (!alphaSetted){ // this.alpha *= 0.5; // if (alpha < minAlpha) // alpha = minAlpha; // } // if (!lambdaSetted){ // this.lambda *= 0.9; // if (lambda < minLambda) // lambda = minLambda; // } // } // System.out.println(String.format("#%d loop%d\ttime:%d(ms)\tacc: %.3f(approx)\tavg_error:%.6f", cc, l, (timeEnd - timeStart), acc , avge)); // } // } /** * for one-hot */ public final double gradientDescend(Vector sample){ double wTx = w0; int innerLabel = outerLabelInfo[LABELRANGEBASE + sample.getIntHeader()][0]; double newlabel = transform( innerLabel); //{-1, +1} for(int i = 0 ; i < sample.getSize(); i++){//wTx wTx += featureWeights[sample.getFId(i)] * sample.getWeight(i); } double intersectionWeightSum = 0; double[] Sigmav2x2 = new double[dim]; double[] SigmaVX = new double[dim]; for(int f = 0; f < dim ; f ++){ for(int i = 0 ; i < sample.getSize(); i++){ SigmaVX[f] += factors[f][sample.getFId(i)] ; Sigmav2x2[f] += Math.pow(factors[f][sample.getFId(i)], 2) ; } intersectionWeightSum += ((Math.pow(SigmaVX[f], 2) - Sigmav2x2[f])); } wTx += (intersectionWeightSum * 0.5 ); double gradient = - newlabel * ( 1 - 1/(1 + Math.pow(Math.E, - newlabel * wTx))); if (w0Type == 2) w0 -= alpha * (gradient + 2 * lambda * w0); for(int i = 0 ; i < sample.getSize(); i++){ // w <- w + alpha * (error * partial_derivation - lambda * w) featureWeights[sample.getFId(i)] -= alpha * (gradient + 2 * lambda * featureWeights[sample.getFId(i)]) ; for(int f = 0 ; f < dim ; f++){ double step = (SigmaVX[f] - factors[f][sample.getFId(i)]) ; factors[f][sample.getFId(i)] -= alpha * (gradient * step + 2 * lambda * factors[f][sample.getFId(i)] ) ; } } double innerPrediction = 1/ (1+Math.pow(Math.E, - wTx)); return innerPrediction; } @Override public void setProperties(Properties prop) { super.setProperties(prop); this.dim = new Integer(prop.getProperty("dim", "5")); } @Override public Properties getProperties() { // TODO Auto-generated method stub return null; } @Override public void predict(Vector sample, double[] probabilities) throws Exception { double oneDegreeWeightSum = w0; for(int i = 0 ; i < sample.getSize(); i++){ oneDegreeWeightSum += featureWeights[sample.getFId(i)] * sample.getWeight(i); } double intersectionWeightSum = 0; double[] Sigmav2x2 = new double[dim]; double[] SigmaVX = new double[dim]; for(int f = 0; f < dim ; f ++){ for(int i = 0 ; i < sample.getSize(); i++){ SigmaVX[f] += factors[f][sample.getFId(i)] * sample.getWeight(i); Sigmav2x2[f] += Math.pow(factors[f][sample.getFId(i)], 2) * Math.pow(sample.getWeight(i), 2); } intersectionWeightSum += ((Math.pow(SigmaVX[f], 2) - Sigmav2x2[f])); } double probability = 1/(1+Math.pow(Math.E, -(oneDegreeWeightSum + (intersectionWeightSum * 0.5)))); probabilities[0] = 1- probability; probabilities[1] = probability; } @Override public void saveModel(String filePath) throws Exception { // TODO Auto-generated method stub } @Override public void loadModel(String modelPath, Properties statistic) throws Exception { // TODO Auto-generated method stub } @Override public void crossValidation(int fold, Evaluator... evaluators) throws Exception { // TODO Auto-generated method stub } @Override protected void estimateParameter() throws NullPointerException { // TODO Auto-generated method stub } @Override public int estimate(Properties dataStatus, Properties parameters) { // TODO Auto-generated method stub return 0; } }