/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.cloudera.knittingboar.sgd;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Iterator;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.sgd.AbstractOnlineLogisticRegression;
import org.apache.mahout.classifier.sgd.DefaultGradient;
import org.apache.mahout.classifier.sgd.Gradient;
import org.apache.mahout.classifier.sgd.PolymorphicWritable;
import org.apache.mahout.classifier.sgd.PriorFunction;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixWritable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import com.cloudera.knittingboar.utils.Utils;
/**
 * Parallel Online Logistic Regression
*
 * Based loosely on Mahout's OnlineLogisticRegression:
*
* http://svn.apache.org/repos/asf/mahout/trunk/core/src/main/java/org/apache/
* mahout/classifier/sgd/OnlineLogisticRegression.java
*
*
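 * A typical configuration chain, shown here with Mahout's L1 prior; the
 * hyper-parameter values below are illustrative only:
 * 
 * <pre>{@code
 * ParallelOnlineLogisticRegression polr =
 *     new ParallelOnlineLogisticRegression(2, 10000, new L1())
 *         .lambda(1.0e-4)
 *         .learningRate(50)
 *         .alpha(1 - 1.0e-3)
 *         .stepOffset(1000)
 *         .decayExponent(0.9);
 * }</pre>
 * 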
* @author jpatterson
*
*/
public class ParallelOnlineLogisticRegression extends
AbstractOnlineLogisticRegression implements Writable {
public static final int WRITABLE_VERSION = 1;
  // these next two control the decayFactor^steps exponential type of
  // annealing: the initial learning rate and the decay factor
private double learningRate = 1;
private double decayFactor = 1 - 1.0e-3;
// these next two control 1/steps^forget type annealing
private int stepOffset = 10;
// -1 equals even weighting of all examples, 0 means only use exponential
// annealing
private double forgettingExponent = -0.5;
// controls how per term annealing works
private int perTermAnnealingOffset = 20;
  // had to add this because it's private in the base class
private Gradient default_gradient = new DefaultGradient();
  // ####### This is NEW ######################
  // gamma is the saved, accumulated gradient that we merge at the superstep;
  // its shape is (numCategories - 1) x numFeatures
  // protected MultinomialLogisticRegressionParameterVectors gamma;
public ParallelOnlineLogisticRegression() {
    // no-argument constructor available for serialization, but not normal use
}
  /**
   * Main constructor
   * 
   * @param numCategories
   *          number of target categories to classify into
   * @param numFeatures
   *          size of the feature vectors
   * @param prior
   *          prior (regularization) function applied to the coefficients
   */
public ParallelOnlineLogisticRegression(int numCategories, int numFeatures,
PriorFunction prior) {
this.numCategories = numCategories;
this.prior = prior;
updateSteps = new DenseVector(numFeatures);
updateCounts = new DenseVector(numFeatures).assign(perTermAnnealingOffset);
beta = new DenseMatrix(numCategories - 1, numFeatures);
// brand new factor for parallelization
// this.gamma = new MultinomialLogisticRegressionParameterVectors(numCategories, numFeatures);
}
/**
* Chainable configuration option.
*
* @param alpha
* New value of decayFactor, the exponential decay rate for the
* learning rate.
* @return This, so other configurations can be chained.
*/
public ParallelOnlineLogisticRegression alpha(double alpha) {
this.decayFactor = alpha;
return this;
}
@Override
public ParallelOnlineLogisticRegression lambda(double lambda) {
    // we only override this to provide a more restrictive return type
super.lambda(lambda);
return this;
}
/**
* Chainable configuration option.
*
* @param learningRate
* New value of initial learning rate.
* @return This, so other configurations can be chained.
*/
public ParallelOnlineLogisticRegression learningRate(double learningRate) {
this.learningRate = learningRate;
return this;
}
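  /**
   * Chainable configuration option.
   * 
   * @param stepOffset
   *          New offset added to the step count in the 1/steps^forget style of
   *          annealing.
   * @return This, so other configurations can be chained.
   */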
public ParallelOnlineLogisticRegression stepOffset(int stepOffset) {
this.stepOffset = stepOffset;
return this;
}
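  /**
   * Chainable configuration option.
   * 
   * @param decayExponent
   *          Exponent for the 1/steps^forget style of annealing; positive
   *          values are negated, so larger magnitudes mean faster forgetting.
   * @return This, so other configurations can be chained.
   */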
public ParallelOnlineLogisticRegression decayExponent(double decayExponent) {
if (decayExponent > 0) {
decayExponent = -decayExponent;
}
this.forgettingExponent = decayExponent;
return this;
}
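  /**
   * Per-term annealing: the rate for feature j decays as
   * sqrt(perTermAnnealingOffset / updateCounts[j]), so features that have
   * been updated many times learn more slowly.
   */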
@Override
public double perTermLearningRate(int j) {
return Math.sqrt(perTermAnnealingOffset / updateCounts.get(j));
}
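  /**
   * Combines both annealing schedules:
   * learningRate * decayFactor^step * (step + stepOffset)^forgettingExponent.
   */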
@Override
public double currentLearningRate() {
return learningRate * Math.pow(decayFactor, getStep())
* Math.pow(getStep() + stepOffset, forgettingExponent);
}
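  /**
   * Copy the annealing configuration (learning rate, decay, offsets) from
   * another POLR instance, in addition to the state copied by the base class.
   */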
public void copyFrom(ParallelOnlineLogisticRegression other) {
super.copyFrom(other);
learningRate = other.learningRate;
decayFactor = other.decayFactor;
stepOffset = other.stepOffset;
forgettingExponent = other.forgettingExponent;
perTermAnnealingOffset = other.perTermAnnealingOffset;
}
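  /**
   * Make a deep copy of this model; close() is called on this instance before
   * copying, following Mahout's OnlineLogisticRegression.
   */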
public ParallelOnlineLogisticRegression copy() {
close();
ParallelOnlineLogisticRegression r = new ParallelOnlineLogisticRegression(
numCategories(), numFeatures(), prior);
r.copyFrom(this);
return r;
}
/**
* TODO - add something in to write the gamma to the output stream -- do we
* need to save gamma?
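   * 
   * A round-trip sketch using the standard Writable contract (the stream and
   * variable names are illustrative):
   * 
   * <pre>{@code
   * ByteArrayOutputStream buf = new ByteArrayOutputStream();
   * polr.write(new DataOutputStream(buf));
   * 
   * ParallelOnlineLogisticRegression copy = new ParallelOnlineLogisticRegression();
   * copy.readFields(new DataInputStream(new ByteArrayInputStream(buf.toByteArray())));
   * }</pre>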
*/
@Override
public void write(DataOutput out) throws IOException {
out.writeInt(WRITABLE_VERSION);
out.writeDouble(learningRate);
out.writeDouble(decayFactor);
out.writeInt(stepOffset);
out.writeInt(step);
out.writeDouble(forgettingExponent);
out.writeInt(perTermAnnealingOffset);
out.writeInt(numCategories);
MatrixWritable.writeMatrix(out, beta);
PolymorphicWritable.write(out, prior);
VectorWritable.writeVector(out, updateCounts);
VectorWritable.writeVector(out, updateSteps);
}
@Override
public void readFields(DataInput in) throws IOException {
int version = in.readInt();
if (version == WRITABLE_VERSION) {
learningRate = in.readDouble();
decayFactor = in.readDouble();
stepOffset = in.readInt();
step = in.readInt();
forgettingExponent = in.readDouble();
perTermAnnealingOffset = in.readInt();
numCategories = in.readInt();
beta = MatrixWritable.readMatrix(in);
prior = PolymorphicWritable.read(in, PriorFunction.class);
updateCounts = VectorWritable.readVector(in);
updateSteps = VectorWritable.readVector(in);
} else {
throw new IOException("Incorrect object version, wanted "
+ WRITABLE_VERSION + " got " + version);
}
}
/**
 * Custom training for POLR, based around accumulating the gradient locally so
 * it can be sent to the master process at the next superstep.
*
*
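 * A worker typically calls this once per training example; a sketch (the
 * shard iteration and label lookup here are hypothetical):
 * 
 * <pre>{@code
 * long trackingKey = 0;
 * for (Vector instance : shard) {
 *   polr.train(trackingKey++, null, labelFor(instance), instance);
 * }
 * }</pre>
 * 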
*/
@Override
public void train(long trackingKey, String groupKey, int actual,
Vector instance) {
unseal();
double learningRate = currentLearningRate();
// push coefficients back to zero based on the prior
regularize(instance);
// basically this only gets the results for each classification
// update each row of coefficients according to result
Vector gradient = this.default_gradient.apply(groupKey, actual, instance,
this);
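    // the per-coefficient SGD update applied below is:
    //   beta[i][j] += gradient[i] * learningRate * perTermLearningRate(j) * x[j]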
for (int i = 0; i < numCategories - 1; i++) {
double gradientBase = gradient.get(i);
// we're only going to look at the non-zero elements of the vector
// then we apply the gradientBase to the resulting element.
Iterator<Vector.Element> nonZeros = instance.iterateNonZero();
while (nonZeros.hasNext()) {
Vector.Element updateLocation = nonZeros.next();
int j = updateLocation.index();
        double gradient_to_add = gradientBase * learningRate
            * perTermLearningRate(j) * instance.get(j);
        double newValue = beta.getQuick(i, j) + gradient_to_add;
        beta.setQuick(i, j, newValue);
        // now update gamma --- we only want the gradient accumulated since
        // the last flush
        /*
        double old_gamma = gamma.getCell(i, j);
        double new_gamma = old_gamma + gradient_to_add;
        gamma.setCell(i, j, new_gamma);
        */
}
}
// remember that these elements got updated
Iterator<Vector.Element> i = instance.iterateNonZero();
while (i.hasNext()) {
Vector.Element element = i.next();
int j = element.index();
updateSteps.setQuick(j, getStep());
updateCounts.setQuick(j, updateCounts.getQuick(j) + 1);
}
nextStep();
}
  /**
   * Get the current parameter matrix
   * 
   * @return the beta coefficient Matrix
   */
public Matrix noReallyGetBeta() {
return this.beta;
}
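  /**
   * Replace the local beta parameter matrix with a copy of the beta merged by
   * the master process.
   */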
public void SetBeta(Matrix beta_mstr_cpy) {
this.beta = beta_mstr_cpy.clone();
}
  /**
   * Print a section of the current Beta (parameter) matrix for debugging;
   * printing Gamma (the gradient buffer since the last flush) is disabled
   * while gamma is commented out
   * 
   */
public void Debug_PrintGamma() {
System.out.println("# Debug_PrintGamma > Beta: ");
Utils.PrintVectorSectionNonZero(this.noReallyGetBeta().viewRow(0), 10);
}
  /*
   * Reset all values in Gamma (gradient buffer) back to zero
   * 
  public void FlushGamma() {
    this.gamma.Reset();
  }
  
  public MultinomialLogisticRegressionParameterVectors getGamma() {
    return this.gamma;
  }
  */
}