/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.cloudera.knittingboar.messages.iterativereduce; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.DataInput; import java.io.DataInputStream; import java.io.DataOutput; import java.io.DataOutputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; import org.apache.mahout.math.Matrix; import org.apache.mahout.math.MatrixWritable; public class ParameterVector { // worker stuff to send out public int SrcWorkerPassCount = 0; public Matrix parameter_vector = null; public int GlobalPassCount = 0; // what pass should the worker dealing with? public int IterationComplete = 0; // 0 = no, 1 = yes public int CurrentIteration = 0; public int TrainedRecords = 0; public float AvgLogLikelihood = 0; public float PercentCorrect = 0; public byte[] Serialize() throws IOException { // DataOutput d ByteArrayOutputStream out = new ByteArrayOutputStream(); DataOutput d = new DataOutputStream(out); // d.writeUTF(src_host); d.writeInt(this.SrcWorkerPassCount); d.writeInt(this.GlobalPassCount); d.writeInt(this.IterationComplete); d.writeInt(this.CurrentIteration); d.writeInt(this.TrainedRecords); d.writeFloat(this.AvgLogLikelihood); d.writeFloat(this.PercentCorrect); // buf.write // MatrixWritable.writeMatrix(d, this.worker_gradient.getMatrix()); MatrixWritable.writeMatrix(d, this.parameter_vector); // MatrixWritable. return out.toByteArray(); } public void Deserialize(byte[] bytes) throws IOException { // DataInput in) throws IOException { ByteArrayInputStream b = new ByteArrayInputStream(bytes); DataInput in = new DataInputStream(b); // this.src_host = in.readUTF(); this.SrcWorkerPassCount = in.readInt(); this.GlobalPassCount = in.readInt(); this.IterationComplete = in.readInt(); this.CurrentIteration = in.readInt(); this.TrainedRecords = in.readInt(); // d.writeInt(this.TrainedRecords); this.AvgLogLikelihood = in.readFloat(); // d.writeFloat(this.AvgLogLikelihood); this.PercentCorrect = in.readFloat(); // d.writeFloat(this.PercentCorrect); this.parameter_vector = MatrixWritable.readMatrix(in); } public int numFeatures() { return this.parameter_vector.numCols(); } public int numCategories() { return this.parameter_vector.numRows(); } /** * TODO: fix loop * * @param other_gamma */ public void AccumulateParameterVector(Matrix other_gamma) { // this.gamma.plus(arg0) for (int row = 0; row < this.parameter_vector.rowSize(); row++) { for (int col = 0; col < this.parameter_vector.columnSize(); col++) { double old_this_val = this.parameter_vector.get(row, col); double other_val = other_gamma.get(row, col); // System.out.println( "Accumulate: " + old_this_val + ", " + other_val // ); this.parameter_vector.set(row, col, old_this_val + other_val); // System.out.println( "new value: " + this.gamma.get(row, col) ); } } // this.AccumulatedGradientsCount++; } /* public void Accumulate(GradientBuffer other_gamma) { for (int row = 0; row < this.gamma.rowSize(); row++) { for (int col = 0; col < this.gamma.columnSize(); col++) { double old_this_val = this.gamma.get(row, col); double other_val = other_gamma.getCell(row, col); this.gamma.set(row, col, old_this_val + other_val); } } this.AccumulatedGradientsCount++; } */ /** * TODO: Need to take a look at built in matrix ops here * */ public void AverageParameterVectors(int denominator) { for (int row = 0; row < this.parameter_vector.rowSize(); row++) { for (int col = 0; col < this.parameter_vector.columnSize(); col++) { double old_this_val = this.parameter_vector.get(row, col); // double other_val = other_gamma.getCell(row, col); this.parameter_vector.set(row, col, old_this_val / denominator); } } } }