/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.cloudera.knittingboar.sgd.iterativereduce; import java.io.DataOutputStream; import java.io.IOException; import java.util.ArrayList; import java.util.Collection; import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.util.ToolRunner; import org.apache.mahout.classifier.sgd.L1; import org.apache.mahout.classifier.sgd.UniformPrior; import org.apache.mahout.math.DenseMatrix; //import com.cloudera.knittingboar.messages.GlobalParameterVectorUpdateMessage; //import com.cloudera.knittingboar.messages.GradientUpdateMessage; import com.cloudera.knittingboar.messages.iterativereduce.ParameterVector; import com.cloudera.knittingboar.messages.iterativereduce.ParameterVectorUpdatable; import com.cloudera.knittingboar.records.CSVBasedDatasetRecordFactory; import com.cloudera.knittingboar.records.RCV1RecordFactory; import com.cloudera.knittingboar.records.RecordFactory; import com.cloudera.knittingboar.records.TwentyNewsgroupsRecordFactory; //import com.cloudera.knittingboar.sgd.MultinomialLogisticRegressionParameterVectors; import com.cloudera.knittingboar.sgd.POLRModelParameters; import com.cloudera.knittingboar.sgd.ParallelOnlineLogisticRegression; //import com.cloudera.knittingboar.yarn.appmaster.ApplicationMaster; //import com.cloudera.knittingboar.yarn.appmaster.ComputableMaster; import com.cloudera.iterativereduce.yarn.appmaster.ApplicationMaster; import com.cloudera.iterativereduce.ComputableMaster; //import com.cloudera.iterativereduce.yarn import com.google.common.collect.Lists; /** * Master node for the IR-KnittingBoar YARN process - coordinates the parallel * SGD process amongst many workers - gets the parameter vector updates from * many workers and averages them together, sending them back to the workers * * * @author jpatterson * */ public class POLRMasterNode extends POLRNodeBase implements ComputableMaster<ParameterVectorUpdatable> { private static final Log LOG = LogFactory.getLog(POLRMasterNode.class); ParameterVector global_parameter_vector = null; private int GlobalMaxPassCount = 0; private int Global_Min_IterationCount = 0; // these are only used for saving the model public ParallelOnlineLogisticRegression polr = null; public POLRModelParameters polr_modelparams; private RecordFactory VectorFactory = null; @Override public ParameterVectorUpdatable compute( Collection<ParameterVectorUpdatable> workerUpdates, Collection<ParameterVectorUpdatable> masterUpdates) { System.out.println("\nMaster Compute: SuperStep - Worker Info ----- "); int x = 0; // reset //this.Global_Min_IterationCount = this.NumberPasses; boolean iterationComplete = true; this.global_parameter_vector.parameter_vector = new DenseMatrix(this.num_categories - 1, this.FeatureVectorSize); for (ParameterVectorUpdatable i : workerUpdates) { // not sure we still need this --------------- if (i.get().SrcWorkerPassCount > this.GlobalMaxPassCount) { this.GlobalMaxPassCount = i.get().SrcWorkerPassCount; } // if any worker is not done with hte iteration, trip the flag if (i.get().IterationComplete == 0 ) { //this.Global_Min_IterationCount = i.get().IterationCount; iterationComplete = false; } System.out.println("[Master] WorkerReport[" + x + "]: I: " + i.get().CurrentIteration + ", IC: " + i.get().IterationComplete + " Trained Recs: " + i.get().TrainedRecords + " AvgLogLikelihood: " + i.get().AvgLogLikelihood + " PercentCorrect: " + i.get().PercentCorrect); if ( i.get().IterationComplete == 1) { System.out.println( "> worker " + x + " is done with current iteration" ); } x++; // accumulate gradient of parameter vectors //this.global_parameter_vector.AccumulateGradient(i.get().parameter_vector); this.global_parameter_vector.AccumulateParameterVector(i.get().parameter_vector); } // now average the parameter vectors together //this.global_parameter_vector.AverageAccumulations(workerUpdates.size()); this.global_parameter_vector.AverageParameterVectors(workerUpdates.size()); LOG.debug("Master node accumulating and averaging " + workerUpdates.size() + " worker updates."); ParameterVector gradient_msg = new ParameterVector(); gradient_msg.GlobalPassCount = this.GlobalMaxPassCount; /* if (iterationComplete) { gradient_msg.IterationComplete = 1; System.out.println( "> Master says: Iteration Complete" ); } else { gradient_msg.IterationComplete = 0; } */ gradient_msg.parameter_vector = this.global_parameter_vector.parameter_vector.clone(); ParameterVectorUpdatable return_msg = new ParameterVectorUpdatable(); return_msg.set(gradient_msg); // set the master copy! this.polr.SetBeta(this.global_parameter_vector.parameter_vector.clone()); // THIS NEEDS TO BE DONE, probably automated! workerUpdates.clear(); return return_msg; } @Override public ParameterVectorUpdatable getResults() { System.out.println(">>> getResults() - null!!!"); return null; } @Override public void setup(Configuration c) { this.conf = c; try { // this is hard set with LR to 2 classes this.num_categories = this.conf.getInt( "com.cloudera.knittingboar.setup.numCategories", 2); // feature vector size this.FeatureVectorSize = LoadIntConfVarOrException( "com.cloudera.knittingboar.setup.FeatureVectorSize", "Error loading config: could not load feature vector size"); // feature vector size // this.BatchSize = this.conf.getInt( // "com.cloudera.knittingboar.setup.BatchSize", 200); // this.NumberPasses = this.conf.getInt( // "com.cloudera.knittingboar.setup.NumberPasses", 1); this.NumberIterations = this.conf.getInt("app.iteration.count", 1); // protected double Lambda = 1.0e-4; this.Lambda = Double.parseDouble(this.conf.get( "com.cloudera.knittingboar.setup.Lambda", "1.0e-4")); // protected double LearningRate = 50; this.LearningRate = Double.parseDouble(this.conf.get( "com.cloudera.knittingboar.setup.LearningRate", "10")); // local input split path // this.LocalInputSplitPath = LoadStringConfVarOrException( // "com.cloudera.knittingboar.setup.LocalInputSplitPath", // "Error loading config: could not load local input split path"); // System.out.println("LoadConfig()"); // maps to either CSV, 20newsgroups, or RCV1 this.RecordFactoryClassname = LoadStringConfVarOrException( "com.cloudera.knittingboar.setup.RecordFactoryClassname", "Error loading config: could not load RecordFactory classname"); if (this.RecordFactoryClassname.equals(RecordFactory.CSV_RECORDFACTORY)) { // so load the CSV specific stuff ---------- System.out .println("----- Loading CSV RecordFactory Specific Stuff -------"); // predictor label names this.PredictorLabelNames = LoadStringConfVarOrException( "com.cloudera.knittingboar.setup.PredictorLabelNames", "Error loading config: could not load predictor label names"); // predictor var types this.PredictorVariableTypes = LoadStringConfVarOrException( "com.cloudera.knittingboar.setup.PredictorVariableTypes", "Error loading config: could not load predictor variable types"); // target variables this.TargetVariableName = LoadStringConfVarOrException( "com.cloudera.knittingboar.setup.TargetVariableName", "Error loading config: Target Variable Name"); // column header names this.ColumnHeaderNames = LoadStringConfVarOrException( "com.cloudera.knittingboar.setup.ColumnHeaderNames", "Error loading config: Column Header Names"); // System.out.println("LoadConfig(): " + this.ColumnHeaderNames); } } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); System.out.println(">> Error loading conf!"); } System.out.println( "-----------------------------------------" ); System.out.println( "# Master Conf #" ); System.out.println( "Number Iterations: " + this.NumberIterations ); System.out.println( "-----------------------------------------\n\n" ); this.SetupPOLR(); } // setup() public void SetupPOLR() { System.err.println("SetupOLR: " + this.num_categories + ", " + this.FeatureVectorSize); LOG.debug("SetupOLR: " + this.num_categories + ", " + this.FeatureVectorSize); this.global_parameter_vector = new ParameterVector(); //this.num_categories, //this.FeatureVectorSize); String[] predictor_label_names = this.PredictorLabelNames.split(","); String[] variable_types = this.PredictorVariableTypes.split(","); polr_modelparams = new POLRModelParameters(); polr_modelparams.setTargetVariable(this.TargetVariableName); // getStringArgument(cmdLine, // target)); polr_modelparams.setNumFeatures(this.FeatureVectorSize); polr_modelparams.setUseBias(true); // !getBooleanArgument(cmdLine, noBias)); List<String> typeList = Lists.newArrayList(); for (int x = 0; x < variable_types.length; x++) { typeList.add(variable_types[x]); } List<String> predictorList = Lists.newArrayList(); for (int x = 0; x < predictor_label_names.length; x++) { predictorList.add(predictor_label_names[x]); } polr_modelparams.setTypeMap(predictorList, typeList); polr_modelparams.setLambda(this.Lambda); // based on defaults - match // command line polr_modelparams.setLearningRate(this.LearningRate); // based on defaults - // match command line // setup record factory stuff here --------- if (RecordFactory.TWENTYNEWSGROUPS_RECORDFACTORY .equals(this.RecordFactoryClassname)) { this.VectorFactory = new TwentyNewsgroupsRecordFactory("\t"); } else if (RecordFactory.RCV1_RECORDFACTORY .equals(this.RecordFactoryClassname)) { this.VectorFactory = new RCV1RecordFactory(); } else { // need to rethink this this.VectorFactory = new CSVBasedDatasetRecordFactory( this.TargetVariableName, polr_modelparams.getTypeMap()); ((CSVBasedDatasetRecordFactory) this.VectorFactory) .firstLine(this.ColumnHeaderNames); } polr_modelparams.setTargetCategories(this.VectorFactory .getTargetCategories()); // ----- this normally is generated from the POLRModelParams ------ this.polr = new ParallelOnlineLogisticRegression(this.num_categories, this.FeatureVectorSize, new UniformPrior()).alpha(1).stepOffset(1000) .decayExponent(0.9).lambda(this.Lambda).learningRate(this.LearningRate); polr_modelparams.setPOLR(polr); // this.bSetup = true; } @Override public void complete(DataOutputStream out) throws IOException { // TODO Auto-generated method stub System.out.println("master::complete "); System.out.println("complete-ms:" + System.currentTimeMillis()); LOG.debug("Master complete, saving model."); try { this.polr_modelparams.saveTo(out); } catch (Exception ex) { throw new IOException("Unable to save model", ex); } } public static void main(String[] args) throws Exception { POLRMasterNode pmn = new POLRMasterNode(); ApplicationMaster<ParameterVectorUpdatable> am = new ApplicationMaster<ParameterVectorUpdatable>( pmn, ParameterVectorUpdatable.class); ToolRunner.run(am, args); } }