/* * Copyright [2013-2014] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.guagua.example.lnr; import java.util.Arrays; import java.util.Random; import ml.shifu.guagua.master.AbstractMasterComputable; import ml.shifu.guagua.master.MasterContext; import ml.shifu.guagua.util.NumberFormatUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * {@link LinearRegressionMaster} defines the logic to update the global linear regression model. * * <p> * In the first iteration, the master builds a random model and sends it to all workers to start computing. This makes * all workers start from the same model. * * <p> * In later iterations, the master performs these steps: * <ul> * <li>1. Accumulate all gradients from workers.</li> * <li>2. Update the global model using the accumulated gradients.</li> * <li>3. 
Send new global model to workers by returning model parameters.</li> * </ul> */ public class LinearRegressionMaster extends AbstractMasterComputable<LinearRegressionParams, LinearRegressionParams> { private static final Logger LOG = LoggerFactory.getLogger(LinearRegressionMaster.class); private static final Random RANDOM = new Random(); private int inputNum; private double[] weights; private double learnRate; @Override public void init(MasterContext<LinearRegressionParams, LinearRegressionParams> context) { this.inputNum = NumberFormatUtils.getInt(LinearRegressionContants.LR_INPUT_NUM, LinearRegressionContants.LR_INPUT_DEFAULT_NUM); this.learnRate = NumberFormatUtils.getDouble(LinearRegressionContants.LR_LEARNING_RATE, LinearRegressionContants.LR_LEARNING_DEFAULT_RATE); // not initialized and not first iteration, should be fault tolerence, recover state in LogisticRegressionMaster if(!context.isFirstIteration()) { LinearRegressionParams lastMasterResult = context.getMasterResult(); if(lastMasterResult != null && lastMasterResult.getParameters() != null) { // recover state in current master computable and return to workers this.weights = lastMasterResult.getParameters(); } else { // no weights, restarted from the very beginning, this may not happen initWeights(); } } } @Override public LinearRegressionParams doCompute(MasterContext<LinearRegressionParams, LinearRegressionParams> context) { if(context.isFirstIteration()) { initWeights(); } else { double[] gradients = new double[this.inputNum + 1]; double sumError = 0.0d; int size = 0; for(LinearRegressionParams param: context.getWorkerResults()) { if(param != null) { for(int i = 0; i < gradients.length; i++) { gradients[i] += param.getParameters()[i]; } sumError += param.getError(); } size++; } for(int i = 0; i < weights.length; i++) { weights[i] -= learnRate * gradients[i]; } LOG.info("DEBUG: Weights: {}", Arrays.toString(this.weights)); LOG.info("Iteration {} with error {}", context.getCurrentIteration(), sumError 
/ size); } return new LinearRegressionParams(weights); } /** * */ private void initWeights() { weights = new double[this.inputNum + 1]; for(int i = 0; i < weights.length; i++) { weights[i] = RANDOM.nextDouble(); } } }