/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.cloudera.knittingboar.metrics;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapred.FileInputFormat;
import org.apache.hadoop.mapred.InputSplit;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.TextInputFormat;
import com.cloudera.iterativereduce.irunit.IRUnitDriver;
import com.cloudera.knittingboar.io.InputRecordsSplit;
import com.cloudera.knittingboar.sgd.iterativereduce.POLRMasterNode;
import com.cloudera.knittingboar.utils.DataUtils;
import com.cloudera.knittingboar.utils.DatasetConverter;
import junit.framework.TestCase;
/**
 * Runs a simulated parallel SGD (POLR) training job through {@link IRUnitDriver}, saves the
 * trained model to local disk, loads it back with {@link POLRModelTester}, and checks that the
 * model's parameters were deserialized correctly.
 *
 * <p>NOTE(review): the number of simulated workers is driven by
 * {@code src/test/resources/app_unit_test.properties}; the test name suggests four — confirm
 * against that file.
 *
 * @author jpatterson
 */
public class TestSaveLoadModel extends TestCase {

  /** Local path the trained model is written to and then read back from. */
  private static final String MODEL_PATH = "/tmp/TestSaveLoadModel.model";

  /** Tolerance for comparing the deserialized lambda against its expected value. */
  private static final double LAMBDA_TOLERANCE = 1.0e-12;

  private static final JobConf defaultConf = new JobConf();
  private static FileSystem localFs = null;

  static {
    try {
      defaultConf.set("fs.defaultFS", "file:///");
      localFs = FileSystem.getLocal(defaultConf);
    } catch (IOException e) {
      throw new RuntimeException("init failure", e);
    }
  }

  private static final Path workDir20NewsLocal = new Path(new Path("/tmp"), "TestSaveLoadModel");
  private static final File unzipDir = new File(workDir20NewsLocal + "/20news-bydate");
  /** Output directory for the KBoar-format training shards; also the job's input path. */
  private static final String strKBoarTrainDirInput = "" + unzipDir.toString() + "/KBoar-train/";

  /**
   * Builds a minimal configuration for {@link POLRModelTester}, standing in for the
   * configuration that would otherwise be supplied by the YARN runtime.
   *
   * @return a configuration with a 10000-wide feature vector, 20 categories and the
   *     20 Newsgroups record factory
   */
  public Configuration generateDebugConfigurationObject() {
    Configuration c = new Configuration();
    // feature vector size
    c.setInt("com.cloudera.knittingboar.setup.FeatureVectorSize", 10000);
    c.setInt("com.cloudera.knittingboar.setup.numCategories", 20);
    // setup 20newsgroups
    c.set("com.cloudera.knittingboar.setup.RecordFactoryClassname",
        "com.cloudera.knittingboar.records.TwentyNewsgroupsRecordFactory");
    return c;
  }

  /**
   * Trains a model via a simulated IterativeReduce run, saves it to {@link #MODEL_PATH}, reloads
   * it, and verifies lambda and the category count survived the round trip.
   */
  public void testRunMasterAndFourWorkers() throws Exception {
    DataUtils.getTwentyNewsGroupDir();

    // Convert the raw training data into KBoar-format shards.
    // NOTE(review): 3000 presumably controls records per shard — confirm in DatasetConverter.
    DatasetConverter.ConvertNewsgroupsFromSingleFiles(
        DataUtils.get20NewsgroupsLocalDataLocation() + "/20news-bydate-train/",
        strKBoarTrainDirInput, 3000);

    // Property keys handed to the driver alongside the properties file.
    String[] props = {
        "app.iteration.count",
        "com.cloudera.knittingboar.setup.FeatureVectorSize",
        "com.cloudera.knittingboar.setup.numCategories",
        "com.cloudera.knittingboar.setup.RecordFactoryClassname"
    };

    IRUnitDriver polr_driver =
        new IRUnitDriver("src/test/resources/app_unit_test.properties", props);
    polr_driver.SetProperty("app.input.path", strKBoarTrainDirInput);
    polr_driver.Setup();
    polr_driver.SimulateRun();

    System.out.println("\n\nComplete...");

    POLRMasterNode IR_Master = (POLRMasterNode) polr_driver.getMaster();

    Path out = new Path(MODEL_PATH);
    FileSystem fs = out.getFileSystem(defaultConf);
    FSDataOutputStream fos = fs.create(out);
    try {
      IR_Master.complete(fos);
      fos.flush();
    } finally {
      // Always release the stream, even if serializing the model fails.
      fos.close();
    }

    System.out.println("\n\nModel Saved: " + MODEL_PATH);
    System.out.println("\n\n> Loading Model for tests...");

    POLRModelTester tester = new POLRModelTester();
    // Generate the debug conf; normally this is set up by the YARN runtime.
    tester.setConf(this.generateDebugConfigurationObject());

    // Load the conf values into the tester's locally used vars.
    try {
      tester.LoadConfigVarsLocally();
    } catch (Exception e) {
      // Fail with the original cause attached rather than an opaque assertEquals(0, 1).
      throw new IllegalStateException("Conf load fail: shutting down.", e);
    }

    // Construct the machine-learning data structures required by the config, then load the model.
    tester.Setup();
    tester.Load(MODEL_PATH);

    // Compare doubles with an explicit tolerance instead of boxed exact equality.
    assertEquals(1.0e-4, tester.polr.getLambda(), LAMBDA_TOLERANCE);
    assertEquals(20, tester.polr.numCategories());
  }
}