/**
 * Copyright [2012-2014] PayPal Software Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.shifu.core.dtrain.nn;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;

import ml.shifu.guagua.io.Combinable;
import ml.shifu.guagua.io.HaltBytable;
import ml.shifu.shifu.core.dtrain.DTrainUtils;

/**
 * NNParams are used to save NN model info which can also be stored into ZooKeeper.
 *
 * <p>
 * {@link #weights} is used to set model weights which is used to transfer info from master to workers.
 *
 * <p>
 * {@link #gradients} is used to accumulate all workers' gradients together in master and then use the accumulated
 * gradients to update weights.
 */
public class NNParams extends HaltBytable implements Combinable<NNParams> {

    /**
     * Weights used for NN model
     */
    private double[] weights;

    /**
     * Gradients for NN model
     */
    private double[] gradients;

    /**
     * Current test error which can be sent to master
     */
    private double testError = 0;

    /**
     * Current train error which can be sent to master
     */
    private double trainError = 0;

    /**
     * Training size of each worker and master
     */
    private long trainSize = 0;

    /**
     * Total size of record
     */
    private long count = 0L;

    /**
     * Worker count for such iteration.
     */
    private int wrCount = 1;

    public double[] getWeights() {
        return weights;
    }

    public void setWeights(double[] weights) {
        this.weights = weights;
    }

    public double getTestError() {
        return testError;
    }

    public void setTestError(double testError) {
        this.testError = testError;
    }

    public double getTrainError() {
        return trainError;
    }

    public void setTrainError(double trainError) {
        this.trainError = trainError;
    }

    /**
     * Add the given per-worker gradients into the local accumulator, element by element.
     *
     * <p>
     * On first call the accumulator is lazily allocated (Java zero-initializes new arrays, so no explicit fill is
     * needed) and, if weights were never set, they are lazily randomized to the same length via
     * {@link DTrainUtils#randomize(int, double[])}.
     *
     * @param gradients
     *            the gradients to accumulate; its length must match previously accumulated gradients, otherwise an
     *            {@link ArrayIndexOutOfBoundsException} may be raised by the element-wise loop
     */
    public void accumulateGradients(double[] gradients) {
        if(this.gradients == null) {
            // new double[] is already all zeros; no Arrays.fill needed
            this.gradients = new double[gradients.length];
        }
        if(this.weights == null) {
            this.weights = new double[gradients.length];
            DTrainUtils.randomize(gradients.length, this.weights);
        }
        for(int i = 0; i < gradients.length; i++) {
            this.gradients[i] += gradients[i];
        }
    }

    /**
     * @return the gradients
     */
    public double[] getGradients() {
        return gradients;
    }

    /**
     * @param gradients
     *            the gradients to set
     */
    public void setGradients(double[] gradients) {
        this.gradients = gradients;
    }

    public long getTrainSize() {
        return trainSize;
    }

    public void setTrainSize(long trainSize) {
        this.trainSize = trainSize;
    }

    public void accumulateTrainSize(long size) {
        this.trainSize = this.getTrainSize() + size;
    }

    /**
     * Reset the per-iteration accumulators: train size back to 0 and gradients back to all zeros.
     *
     * <p>
     * NOTE(review): {@link #testError}, {@link #trainError}, {@link #count} and {@link #wrCount} are NOT reset here —
     * presumably callers overwrite them each iteration; confirm against the master/worker loop.
     */
    public void reset() {
        this.setTrainSize(0);
        if(this.gradients != null) {
            Arrays.fill(this.gradients, 0.0);
        }
    }

    @Override
    public void doWrite(DataOutput out) throws IOException {
        out.writeDouble(getTrainError());
        out.writeDouble(getTestError());
        out.writeLong(getTrainSize());
        // Treat a null array as empty so serializing a freshly constructed instance cannot NPE; for non-null
        // arrays the wire format (length followed by each double) is identical to before.
        double[] localWeights = this.weights == null ? new double[0] : this.weights;
        out.writeInt(localWeights.length);
        for(double weight: localWeights) {
            out.writeDouble(weight);
        }
        double[] localGradients = this.gradients == null ? new double[0] : this.gradients;
        out.writeInt(localGradients.length);
        for(double gradient: localGradients) {
            out.writeDouble(gradient);
        }
        out.writeLong(count);
        out.writeInt(this.wrCount);
    }

    @Override
    public void doReadFields(DataInput in) throws IOException {
        // Field order must mirror doWrite exactly.
        this.trainError = in.readDouble();
        this.testError = in.readDouble();
        this.trainSize = in.readLong();
        int len = in.readInt();
        double[] weights = new double[len];
        for(int i = 0; i < len; i++) {
            weights[i] = in.readDouble();
        }
        this.weights = weights;
        len = in.readInt();
        double[] gradients = new double[len];
        for(int i = 0; i < len; i++) {
            gradients[i] = in.readDouble();
        }
        this.gradients = gradients;
        this.count = in.readLong();
        this.wrCount = in.readInt();
    }

    /**
     * @return the count
     */
    public long getCount() {
        return count;
    }

    /**
     * @param count
     *            the count to set
     */
    public void setCount(long count) {
        this.count = count;
    }

    /*
     * (non-Javadoc)
     *
     * @see ml.shifu.guagua.io.Combinable#combine(ml.shifu.guagua.io.Bytable)
     */
    @Override
    public NNParams combine(NNParams from) {
        // NOTE(review): asserts are disabled unless the JVM runs with -ea; a null 'from' or null gradient array
        // would surface as an NPE below. Kept as-is to preserve existing behavior.
        assert from != null;
        this.count += from.count;
        this.trainSize += from.trainSize;
        this.trainError += from.trainError;
        this.testError += from.testError;
        assert this.gradients != null && from.gradients != null;
        for(int i = 0; i < this.gradients.length; i++) {
            this.gradients[i] += from.gradients[i];
        }
        this.setWrCount(this.getWrCount() + from.getWrCount());
        return this;
    }

    /**
     * @return the wrCount
     */
    public int getWrCount() {
        return wrCount;
    }

    /**
     * @param wrCount
     *            the wrCount to set
     */
    public void setWrCount(int wrCount) {
        this.wrCount = wrCount;
    }

    @Override
    public String toString() {
        return String.format("NNParams [testError=%s, trainError=%s, trainSize=%s, wrCount=%s, gSize=%s]",
                this.testError, this.trainError, this.trainSize, this.getWrCount(),
                this.gradients != null ? this.gradients.length : 0);
    }

}