/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.cloudera.knittingboar.sgd;
import java.util.ArrayList;
import junit.framework.TestCase;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import com.cloudera.knittingboar.utils.Utils;
/**
* Mostly temporary tests used to debug components as we developed the system
*
* @author jpatterson
*
*/
public class TestParallelOnlineLogisticRegression extends TestCase {

  /** Number of target categories passed to every classifier built in this class. */
  private static final int CATEGORIES = 2;

  /** Number of features passed to every classifier and input vector built in this class. */
  private static final int NUM_FEATURES = 5;

  /** Lambda value configured on every classifier built in this class. */
  private static final double LAMBDA = 1.0e-4;

  /**
   * Builds a classifier with an {@link L1} prior, the shared category and
   * feature counts, the given lambda and learning rate, and an alpha of
   * {@code 1 - 1.0e-3}.
   *
   * @param lambda value handed to {@code lambda(...)}
   * @param learningRate value handed to {@code learningRate(...)}
   * @return the configured classifier
   */
  private static ParallelOnlineLogisticRegression newRegression(double lambda, double learningRate) {
    return new ParallelOnlineLogisticRegression(CATEGORIES, NUM_FEATURES, new L1())
        .lambda(lambda)
        .learningRate(learningRate)
        .alpha(1 - 1.0e-3);
  }

  /**
   * Builds a sparse vector of the given cardinality in which element
   * {@code i} is set to the value {@code i}.
   *
   * @param numFeatures cardinality of the vector
   * @return the populated vector
   */
  private static Vector buildInput(int numFeatures) {
    Vector input = new RandomAccessSparseVector(numFeatures);
    for (int i = 0; i < numFeatures; i++) {
      input.set(i, i);
    }
    return input;
  }

  /**
   * Builds a list holding {@code examples} vectors, each produced by
   * {@link #buildInput(int)}.
   *
   * @param examples number of vectors to create
   * @param numFeatures cardinality of each vector
   * @return the list of vectors
   */
  private static ArrayList<Vector> buildTrainingSet(int examples, int numFeatures) {
    ArrayList<Vector> trainingSet = new ArrayList<Vector>();
    for (int s = 0; s < examples; s++) {
      trainingSet.add(buildInput(numFeatures));
    }
    return trainingSet;
  }

  /**
   * Checks that the lambda configured through the fluent setter is the value
   * reported by {@code getLambda()}.
   */
  public void testCreateLR() {
    ParallelOnlineLogisticRegression plr = newRegression(LAMBDA, 50);

    // Expected value goes first so a failure message reports the values the
    // right way round; the explicit delta selects the double overload instead
    // of boxing both operands. A delta of 0.0 keeps the comparison exact.
    assertEquals(LAMBDA, plr.getLambda(), 0.0);
  }

  /**
   * Trains the classifier three times on the same example with target 0.
   * There are no assertions: the test passes as long as training does not throw.
   */
  public void testTrainMechanics() {
    ParallelOnlineLogisticRegression plr = newRegression(LAMBDA, 10);
    Vector input = buildInput(NUM_FEATURES);

    plr.train(0, input);
    plr.train(0, input);
    plr.train(0, input);
  }

  /**
   * Prints row 0 of the beta matrix before and after training on a single
   * example. The output is for manual inspection only; nothing is asserted.
   */
  public void testPOLRInternalBuffers() {
    System.out.println( "testPOLRInternalBuffers --------------" );

    ArrayList<Vector> trainingSet_0 = buildTrainingSet(1, NUM_FEATURES);
    ParallelOnlineLogisticRegression plr_agent_0 = newRegression(LAMBDA, 10);

    System.out.println( "Beta: " );
    Utils.PrintVectorNonZero(plr_agent_0.getBeta().viewRow(0));

    plr_agent_0.train(0, trainingSet_0.get(0));

    // After training, beta is read through noReallyGetBeta() rather than getBeta().
    System.out.println( "Beta: " );
    Utils.PrintVectorNonZero(plr_agent_0.noReallyGetBeta().viewRow(0));

    // NOTE: printing of the gamma buffer was disabled here in the original test.
  }

  /**
   * Trains on a single example. The gamma flush and the check that every gamma
   * entry is 0.0 afterwards were disabled in the original test, so this test
   * currently passes as long as training does not throw.
   */
  public void testLocalGradientFlush() {
    System.out.println( "\n\n\ntestLocalGradientFlush --------------" );

    ArrayList<Vector> trainingSet_0 = buildTrainingSet(1, NUM_FEATURES);
    ParallelOnlineLogisticRegression plr_agent_0 = newRegression(LAMBDA, 10);

    plr_agent_0.train(0, trainingSet_0.get(0));

    // TODO: restore the FlushGamma() call and the per-feature
    // assertEquals(0.0, gamma value) checks once gamma is accessible again.
  }
}