/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.cloudera.knittingboar.sgd.olr; import java.io.File; import java.io.FileInputStream; import java.io.FileNotFoundException; import java.io.IOException; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapred.FileInputFormat; import org.apache.hadoop.mapred.InputSplit; import org.apache.hadoop.mapred.JobConf; import org.apache.hadoop.mapred.TextInputFormat; import org.apache.mahout.classifier.sgd.ModelSerializer; import org.apache.mahout.classifier.sgd.OnlineLogisticRegression; import org.apache.mahout.math.DenseVector; import org.apache.mahout.math.RandomAccessSparseVector; import org.apache.mahout.math.Vector; import com.cloudera.knittingboar.io.InputRecordsSplit; import com.cloudera.knittingboar.metrics.POLRMetrics; import com.cloudera.knittingboar.metrics.POLRModelTester; import com.cloudera.knittingboar.records.RecordFactory; import com.cloudera.knittingboar.records.TwentyNewsgroupsRecordFactory; import com.cloudera.knittingboar.utils.DataUtils; import com.cloudera.knittingboar.utils.DatasetConverter; import junit.framework.TestCase; /** * Mainly just a demo to show how we'd test the 20Newsgroups model generated * with OLR * * @author jpatterson * */ public class TestBaseOLRTest20Newsgroups extends TestCase { //private static Path testData20News = new Path(System.getProperty("test.build.data", "/Users/jpatterson/Downloads/datasets/20news-kboar/test/kboar-shard-0.txt")); //private static Path model20News = new Path( "/Users/jpatterson/Downloads/datasets/20news-kboar/models/model_10_31pm.model" ); private static Path model20News = new Path( "/tmp/olr-news-group.model" ); //private static Path testData20News = new Path(System.getProperty("test.build.data", "/Users/jpatterson/Downloads/datasets/20news-kboar/test/")); private static final int FEATURES = 10000; private static JobConf defaultConf = new JobConf(); private static FileSystem localFs = null; static { try { defaultConf.set("fs.defaultFS", "file:///"); localFs = FileSystem.getLocal(defaultConf); } catch (IOException e) { throw new RuntimeException("init failure", e); } } POLRMetrics metrics = new POLRMetrics(); //double averageLL = 0.0; //double averageCorrect = 0.0; double averageLineCount = 0.0; int k = 0; double step = 0.0; int[] bumps = new int[]{1, 2, 5}; double lineCount = 0; private static Path workDir20NewsLocal = new Path(new Path("/tmp"), "Dataset20Newsgroups"); private static File unzipDir = new File( workDir20NewsLocal + "/20news-bydate"); private static String strKBoarTestDirInput = "" + unzipDir.toString() + "/KBoar-test/"; public Configuration generateDebugConfigurationObject() { Configuration c = new Configuration(); // feature vector size c.setInt( "com.cloudera.knittingboar.setup.FeatureVectorSize", 10000 ); c.setInt( "com.cloudera.knittingboar.setup.numCategories", 20); // setup 20newsgroups c.set( "com.cloudera.knittingboar.setup.RecordFactoryClassname", RecordFactory.TWENTYNEWSGROUPS_RECORDFACTORY); return c; } public InputSplit[] generateDebugSplits( Path input_path, JobConf job ) { long block_size = localFs.getDefaultBlockSize(); System.out.println("default block size: " + (block_size / 1024 / 1024) + "MB"); // ---- set where we'll read the input files from ------------- //FileInputFormat.setInputPaths(job, workDir); FileInputFormat.setInputPaths(job, input_path); // try splitting the file in a variety of sizes TextInputFormat format = new TextInputFormat(); format.configure(job); //LongWritable key = new LongWritable(); //Text value = new Text(); int numSplits = 1; InputSplit[] splits = null; try { splits = format.getSplits(job, numSplits); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } return splits; } public void testResults() throws Exception { File file20News = DataUtils.getTwentyNewsGroupDir(); DatasetConverter.ConvertNewsgroupsFromSingleFiles( DataUtils.get20NewsgroupsLocalDataLocation() + "/20news-bydate-test/", strKBoarTestDirInput, 6000); // File base = new File( file20News + "/20news-bydate-train/" ); System.out.println( "Testing on: " + strKBoarTestDirInput ); OnlineLogisticRegression classifier = ModelSerializer.readBinary(new FileInputStream(model20News.toString()), OnlineLogisticRegression.class); Text value = new Text(); long batch_vec_factory_time = 0; int k = 0; int num_correct = 0; // ---- this all needs to be done in JobConf job = new JobConf(defaultConf); // TODO: work on this, splits are generating for everything in dir // InputSplit[] splits = generateDebugSplits(inputDir, job); Path strKBoarTestDirInputPath = new Path( strKBoarTestDirInput ); //fullRCV1Dir InputSplit[] splits = generateDebugSplits(strKBoarTestDirInputPath, job); System.out.println( "split count: " + splits.length ); InputRecordsSplit custom_reader_0 = new InputRecordsSplit(job, splits[0]); TwentyNewsgroupsRecordFactory VectorFactory = new TwentyNewsgroupsRecordFactory("\t"); for (int x = 0; x < 8000; x++ ) { if ( custom_reader_0.next(value)) { long startTime = System.currentTimeMillis(); Vector v = new RandomAccessSparseVector(FEATURES); int actual = VectorFactory.processLine(value.toString(), v); long endTime = System.currentTimeMillis(); //System.out.println("That took " + (endTime - startTime) + " milliseconds"); batch_vec_factory_time += (endTime - startTime); String ng = VectorFactory.GetClassnameByID(actual); //.GetNewsgroupNameByID( actual ); // calc stats --------- double mu = Math.min(k + 1, 200); double ll = classifier.logLikelihood(actual, v); //averageLL = averageLL + (ll - averageLL) / mu; metrics.AvgLogLikelihood = metrics.AvgLogLikelihood + (ll - metrics.AvgLogLikelihood) / mu; Vector p = new DenseVector(20); classifier.classifyFull(p, v); int estimated = p.maxValueIndex(); int correct = (estimated == actual? 1 : 0); if (estimated == actual) { num_correct++; } //averageCorrect = averageCorrect + (correct - averageCorrect) / mu; metrics.AvgCorrect = metrics.AvgCorrect + (correct - metrics.AvgCorrect) / mu; //this.polr.train(actual, v); k++; // if (x == this.BatchSize - 1) { int bump = bumps[(int) Math.floor(step) % bumps.length]; int scale = (int) Math.pow(10, Math.floor(step / bumps.length)); if (k % (bump * scale) == 0) { step += 0.25; System.out.printf("Worker %s:\t Tested Recs: %10d, numCorrect: %d, AvgLL: %10.3f, Percent Correct: %10.2f, VF: %d\n", "OLR-standard-test", k, num_correct, metrics.AvgLogLikelihood, metrics.AvgCorrect * 100, batch_vec_factory_time); } classifier.close(); } else { // nothing else to process in split! break; } // if } // for the number of passes in the run } }