/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.cloudera.knittingboar.metrics; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.util.List; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.io.Text; import org.apache.mahout.classifier.sgd.L1; import org.apache.mahout.classifier.sgd.ModelDissector; import org.apache.mahout.math.DenseVector; import org.apache.mahout.math.RandomAccessSparseVector; import org.apache.mahout.math.Vector; import com.cloudera.knittingboar.io.InputRecordsSplit; import com.cloudera.knittingboar.records.CSVBasedDatasetRecordFactory; import com.cloudera.knittingboar.records.RCV1RecordFactory; import com.cloudera.knittingboar.records.RecordFactory; import com.cloudera.knittingboar.records.TwentyNewsgroupsRecordFactory; //import com.cloudera.knittingboar.sgd.POLRBaseDriver; import com.cloudera.knittingboar.sgd.POLRModelParameters; import com.cloudera.knittingboar.sgd.ParallelOnlineLogisticRegression; import com.google.common.collect.Lists; import com.google.common.io.Closeables; public class POLRModelTester { boolean bConfLoaded = false; boolean bSetup = false; boolean bRunning = false; private Configuration conf = null; public ParallelOnlineLogisticRegression polr = null; // lmp.createRegression(); public POLRModelParameters polr_modelparams; // private static int passes; private static boolean scores = false; public String internalID = "TEST"; private RecordFactory VectorFactory = null; InputRecordsSplit input_split = null; // TODO: dissect, use this ModelDissector md = new ModelDissector(); // basic stats tracking POLRMetrics metrics = new POLRMetrics(); // double averageLL = 0.0; // double averageCorrect = 0.0; double averageLineCount = 0.0; int k = 0; double step = 0.0; int[] bumps = new int[] {1, 2, 5}; double lineCount = 0; protected int num_categories = 2; protected int FeatureVectorSize = -1; // protected int BatchSize = 200; protected double Lambda = 1.0e-4; protected double LearningRate = 10; String LocalInputSplitPath = ""; String PredictorLabelNames = ""; String PredictorVariableTypes = ""; protected String TargetVariableName = ""; protected String ColumnHeaderNames = ""; protected int NumberPasses = 1; // protected int LocalPassCount = 0; // protected int GlobalPassCount = 0; protected String RecordFactoryClassname = ""; /** * used with unit tests to pre-set a Configuration * * * @param c */ public void setConf(Configuration c) { this.conf = c; } public Configuration getConf() { return this.conf; } /** * Loads the config from [HDFS / JobConf] * * * NOTES - the mechanics of Configuration may be different in this context - * where does Configuration typically get its info from onload? * * @throws Exception * */ public void LoadConfigVarsLocally() throws Exception { // figure out how many features we need this.bConfLoaded = false; // this is hard set with LR to 2 classes this.num_categories = this.conf.getInt( "com.cloudera.knittingboar.setup.numCategories", 2); // feature vector size this.FeatureVectorSize = LoadIntConfVarOrException( "com.cloudera.knittingboar.setup.FeatureVectorSize", "Error loading config: could not load feature vector size"); // feature vector size // this.BatchSize = this.conf.getInt( // "com.cloudera.knittingboar.setup.BatchSize", 200); this.NumberPasses = this.conf.getInt( "com.cloudera.knittingboar.setup.NumberPasses", 1); // protected double Lambda = 1.0e-4; this.Lambda = Double.parseDouble(this.conf.get( "com.cloudera.knittingboar.setup.Lambda", "1.0e-4")); // protected double LearningRate = 50; this.LearningRate = Double.parseDouble(this.conf.get( "com.cloudera.knittingboar.setup.LearningRate", "10")); // local input split path // this.LocalInputSplitPath = LoadStringConfVarOrException( // "com.cloudera.knittingboar.setup.LocalInputSplitPath", // "Error loading config: could not load local input split path"); // System.out.println("LoadConfig()"); // maps to either CSV, 20newsgroups, or RCV1 this.RecordFactoryClassname = LoadStringConfVarOrException( "com.cloudera.knittingboar.setup.RecordFactoryClassname", "Error loading config: could not load RecordFactory classname"); if (this.RecordFactoryClassname.equals(RecordFactory.CSV_RECORDFACTORY)) { // so load the CSV specific stuff ---------- // predictor label names this.PredictorLabelNames = LoadStringConfVarOrException( "com.cloudera.knittingboar.setup.PredictorLabelNames", "Error loading config: could not load predictor label names"); // predictor var types this.PredictorVariableTypes = LoadStringConfVarOrException( "com.cloudera.knittingboar.setup.PredictorVariableTypes", "Error loading config: could not load predictor variable types"); // target variables this.TargetVariableName = LoadStringConfVarOrException( "com.cloudera.knittingboar.setup.TargetVariableName", "Error loading config: Target Variable Name"); // column header names this.ColumnHeaderNames = LoadStringConfVarOrException( "com.cloudera.knittingboar.setup.ColumnHeaderNames", "Error loading config: Column Header Names"); // System.out.println("LoadConfig(): " + this.ColumnHeaderNames); } this.bConfLoaded = true; } /* public int GetCurrentLocalPassCount() { return this.LocalPassCount; } public void IncGlobalPassCount() { this.GlobalPassCount++; } */ private String LoadStringConfVarOrException(String ConfVarName, String ExcepMsg) throws Exception { if (null == this.conf.get(ConfVarName)) { throw new Exception(ExcepMsg); } else { return this.conf.get(ConfVarName); } } private int LoadIntConfVarOrException(String ConfVarName, String ExcepMsg) throws Exception { if (null == this.conf.get(ConfVarName)) { throw new Exception(ExcepMsg); } else { return this.conf.getInt(ConfVarName, 0); } } public void SetCore(ParallelOnlineLogisticRegression plr, POLRModelParameters params, RecordFactory fac) { this.polr = plr; this.polr_modelparams = params; this.VectorFactory = fac; this.num_categories = this.polr_modelparams.getMaxTargetCategories(); this.FeatureVectorSize = this.polr_modelparams.getNumFeatures(); //this.BatchSize = 10000; } public void Setup() { // do splitting strings into arrays here... this.num_categories = -1; // polr_modelparams.getMaxTargetCategories(); this.FeatureVectorSize = -1; // polr_modelparams.getNumFeatures(); // this.BatchSize = 10000; // setup record factory stuff here --------- if (RecordFactory.TWENTYNEWSGROUPS_RECORDFACTORY .equals(this.RecordFactoryClassname)) { this.VectorFactory = new TwentyNewsgroupsRecordFactory("\t"); // this.VectorFactory.setClassSplitString("\t"); // System.out.println( // "POLRModelTester: TwentyNewsgroupsRecordFactory\n\n" ); } else if (RecordFactory.RCV1_RECORDFACTORY .equals(this.RecordFactoryClassname)) { this.VectorFactory = new RCV1RecordFactory(); } else { System.out.println("POLRModelTester: CSV is broken!!\n\n\n"); this.VectorFactory = new CSVBasedDatasetRecordFactory( this.TargetVariableName, polr_modelparams.getTypeMap()); ((CSVBasedDatasetRecordFactory) this.VectorFactory) .firstLine(this.ColumnHeaderNames); } // this.bSetup = true; } /** * Runs the next training batch to prep the gamma buffer to send to the * mstr_node * * TODO: need to provide stats, group measurements into struct * * @throws Exception * @throws IOException */ public void RunThroughTestRecords() throws IOException, Exception { Text value = new Text(); long batch_vec_factory_time = 0; k = 0; int num_correct = 0; // for (int x = 0; x < this.BatchSize; x++) { while (true) { if (this.input_split.next(value)) { long startTime = System.currentTimeMillis(); Vector v = new RandomAccessSparseVector(this.FeatureVectorSize); int actual = this.VectorFactory.processLine(value.toString(), v); long endTime = System.currentTimeMillis(); // System.out.println("That took " + (endTime - startTime) + // " milliseconds"); batch_vec_factory_time += (endTime - startTime); String ng = this.VectorFactory.GetClassnameByID(actual); // .GetNewsgroupNameByID( // actual ); // calc stats --------- double mu = Math.min(k + 1, 200); double ll = this.polr.logLikelihood(actual, v); if (Double.isNaN(ll)) { /* * System.out.println(" --------- NaN -----------"); * * System.out.println( "k: " + k ); System.out.println( "ll: " + ll ); * System.out.println( "mu: " + mu ); */ // return; } else { metrics.AvgLogLikelihood = metrics.AvgLogLikelihood + (ll - metrics.AvgLogLikelihood) / mu; } Vector p = new DenseVector(20); this.polr.classifyFull(p, v); int estimated = p.maxValueIndex(); int correct = (estimated == actual ? 1 : 0); if (estimated == actual) { num_correct++; } // averageCorrect = averageCorrect + (correct - averageCorrect) / mu; metrics.AvgCorrect = metrics.AvgCorrect + (correct - metrics.AvgCorrect) / mu; // this.polr.train(actual, v); k++; // if (x == this.BatchSize - 1) { int bump = bumps[(int) Math.floor(step) % bumps.length]; int scale = (int) Math.pow(10, Math.floor(step / bumps.length)); if (k % (bump * scale) == 0) { step += 0.25; System.out .printf( "Worker %s:\t Trained Recs: %10d, numCorrect: %d, AvgLL: %10.3f, Percent Correct: %10.2f, VF: %d\n", this.internalID, k, num_correct, metrics.AvgLogLikelihood, metrics.AvgCorrect * 100, batch_vec_factory_time); } this.polr.close(); } else { // nothing else to process in split! break; } // if } // for the number of passes in the run } /** * NOTE: This should only be used for durability purposes in checkpointing the * workers * * @param path * @throws IOException */ public void Load(String path) throws IOException { InputStream in = new FileInputStream(path); try { polr_modelparams = POLRModelParameters.loadFrom(in); System.out.println("> tester: model loaded"); } finally { Closeables.closeQuietly(in); } // System.out.println( "POLRModelTester > num categories is hardcoded to 2" // ); this.num_categories = polr_modelparams.getMaxTargetCategories(); this.FeatureVectorSize = polr_modelparams.getNumFeatures(); /* * this.polr = new ParallelOnlineLogisticRegression(this.num_categories, * this.FeatureVectorSize, new L1()) .alpha(1).stepOffset(1000) * .decayExponent(0.9) .lambda(3.0e-5) .learningRate(20); */ this.polr = polr_modelparams.getPOLR(); // System.out.println(")))))))))) Learning rate: " + this.Lambda); } public void setupInputSplit(InputRecordsSplit split) { this.input_split = split; } public void Debug() throws IOException { System.out.println("POLRModelTester --------------------------- "); System.out.println("> Num Categories: " + this.num_categories); System.out.println("> FeatureVecSize: " + this.FeatureVectorSize); this.polr_modelparams.Debug(); } }