/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.cloudera.knittingboar.messages; import java.io.DataInput; import java.io.DataInputStream; import java.io.DataOutput; import java.io.DataOutputStream; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; import org.apache.mahout.math.DenseMatrix; import org.apache.mahout.math.Matrix; import junit.framework.TestCase; import com.cloudera.knittingboar.messages.iterativereduce.ParameterVector; public class TestParameterVector extends TestCase { public static String msg_file = "/tmp/TestGradientUpdateMessageSerde.msg"; //public static String ip = "255.255.255.1"; public static int pass_count = 8; public void testSerde() throws IOException { int classes = 20; int features = 10000; //GradientBuffer g = new GradientBuffer( classes, features ); Matrix m = new DenseMatrix(classes, features); //m.set(0, 0, 0.1); //m.set(0, 1, 0.3); //g.numFeatures(); for (int c = 0; c < classes - 1; c++) { for (int f = 0; f < features; f++ ) { m.set(c, f, (double)((double)f / 10.0f) ); } } System.out.println( "matrix created..." ); ParameterVector vec_gradient = new ParameterVector(); vec_gradient.SrcWorkerPassCount = pass_count; vec_gradient.parameter_vector = m; vec_gradient.AvgLogLikelihood = -1.368f; vec_gradient.PercentCorrect = 72.68f; vec_gradient.TrainedRecords = 2500; assertEquals( 10000, vec_gradient.numFeatures() ); assertEquals( 10000, vec_gradient.parameter_vector.columnSize() ); assertEquals( 20, vec_gradient.numCategories() ); assertEquals( 20, vec_gradient.parameter_vector.rowSize() ); byte[] buf = vec_gradient.Serialize(); ParameterVector vec_gradient_deserialized = new ParameterVector(); vec_gradient_deserialized.Deserialize(buf); assertEquals( pass_count, vec_gradient_deserialized.SrcWorkerPassCount ); assertEquals( 0.1, vec_gradient_deserialized.parameter_vector.get(0, 1) ); assertEquals( 0.2, vec_gradient_deserialized.parameter_vector.get(0, 2) ); assertEquals( 0.3, vec_gradient_deserialized.parameter_vector.get(0, 3) ); assertEquals( 0.4, vec_gradient_deserialized.parameter_vector.get(0, 4) ); assertEquals( 0.5, vec_gradient_deserialized.parameter_vector.get(0, 5) ); assertEquals( -1.368f, vec_gradient_deserialized.AvgLogLikelihood ); assertEquals( 72.68f, vec_gradient_deserialized.PercentCorrect ); assertEquals( 2500, vec_gradient_deserialized.TrainedRecords ); } }