/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.cloudera.knittingboar.sgd.olr; import junit.framework.TestCase; import java.io.BufferedReader; import java.io.File; import java.io.FileNotFoundException; import java.io.FileOutputStream; import java.io.FileReader; import java.io.IOException; import java.io.OutputStream; import java.io.Reader; import java.io.StringReader; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; import java.util.TreeMap; import org.apache.hadoop.fs.Path; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.standard.StandardAnalyzer; import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; import org.apache.lucene.util.Version; import org.apache.mahout.classifier.sgd.L1; import org.apache.mahout.classifier.sgd.ModelSerializer; import org.apache.mahout.classifier.sgd.OnlineLogisticRegression; import org.apache.mahout.math.DenseVector; import org.apache.mahout.math.RandomAccessSparseVector; import org.apache.mahout.math.Vector; import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder; import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder; import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder; import org.apache.mahout.vectorizer.encoders.Dictionary; import com.cloudera.knittingboar.utils.DataUtils; import com.cloudera.knittingboar.utils.Utils; import com.google.common.base.Splitter; import com.google.common.collect.ConcurrentHashMultiset; import com.google.common.collect.HashMultiset; import com.google.common.collect.Iterables; import com.google.common.collect.Multiset; import com.google.common.io.Closeables; /** * Run basic OLR on the 20 newsgroups dataset * - collect stats on run * -- accuracy * -- wall clock time * * Credit: adapted from: * * https://github.com/tdunning/MiA/blob/master/src/main/java/mia/classifier/ch14/TrainNewsGroups.java * * @author jpatterson * */ public class TestBaseOLR_Train20Newsgroups extends TestCase { private static final int FEATURES = 10000; private static Multiset<String> overallCounts; private static Path workDir20NewsLocal = new Path(new Path("/tmp"), "Dataset20Newsgroups"); private static File unzipDir = new File( workDir20NewsLocal + "/20news-bydate"); /** * * Counts words * * @param analyzer * @param words * @param in * @throws IOException */ private static void countWords(Analyzer analyzer, Collection<String> words, Reader in) throws IOException { //System.out.println( "> ----- countWords ------" ); // use the provided analyzer to tokenize the input stream TokenStream ts = analyzer.tokenStream("text", in); ts.addAttribute(CharTermAttribute.class); // for each word in the stream, minus non-word stuff, add word to collection while (ts.incrementToken()) { String s = ts.getAttribute(CharTermAttribute.class).toString(); //System.out.print( " " + s ); words.add(s); } //System.out.println( "\n<" ); /*overallCounts.addAll(words);*/ } public void testTrainNewsGroups() throws IOException { File file20News = DataUtils.getTwentyNewsGroupDir(); File base = new File( file20News + "/20news-bydate-train/" ); //"/Users/jpatterson/Downloads/datasets/20news-bydate/20news-bydate-train/"); overallCounts = HashMultiset.create(); long startTime = System.currentTimeMillis(); // p.269 --------------------------------------------------------- Map<String, Set<Integer>> traceDictionary = new TreeMap<String, Set<Integer>>(); // encodes the text content in both the subject and the body of the email FeatureVectorEncoder encoder = new StaticWordValueEncoder("body"); encoder.setProbes(2); encoder.setTraceDictionary(traceDictionary); // provides a constant offset that the model can use to encode the average frequency // of each class FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept"); bias.setTraceDictionary(traceDictionary); // used to encode the number of lines in a message FeatureVectorEncoder lines = new ConstantValueEncoder("Lines"); lines.setTraceDictionary(traceDictionary); FeatureVectorEncoder logLines = new ConstantValueEncoder("LogLines"); logLines.setTraceDictionary(traceDictionary); Dictionary newsGroups = new Dictionary(); // matches the OLR setup on p.269 --------------- // stepOffset, decay, and alpha --- describe how the learning rate decreases // lambda: amount of regularization // learningRate: amount of initial learning rate OnlineLogisticRegression learningAlgorithm = new OnlineLogisticRegression( 20, FEATURES, new L1()) .alpha(1).stepOffset(1000) .decayExponent(0.9) .lambda(3.0e-5) .learningRate(20); // bottom of p.269 ------------------------------ // because OLR expects to get integer class IDs for the target variable during training // we need a dictionary to convert the target variable (the newsgroup name) // to an integer, which is the newsGroup object List<File> files = new ArrayList<File>(); for (File newsgroup : base.listFiles()) { newsGroups.intern(newsgroup.getName()); files.addAll(Arrays.asList(newsgroup.listFiles())); } // mix up the files, helps training in OLR Collections.shuffle(files); System.out.printf("%d training files\n", files.size()); // p.270 ----- metrics to track lucene's parsing mechanics, progress, performance of OLR ------------ double averageLL = 0.0; double averageCorrect = 0.0; double averageLineCount = 0.0; int k = 0; double step = 0.0; int[] bumps = new int[]{1, 2, 5}; double lineCount = 0; // last line on p.269 Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_31); Splitter onColon = Splitter.on(":").trimResults(); int input_file_count = 0; // ----- p.270 ------------ "reading and tokenzing the data" --------- for (File file : files) { BufferedReader reader = new BufferedReader(new FileReader(file)); input_file_count++; // identify newsgroup ---------------- // convert newsgroup name to unique id // ----------------------------------- String ng = file.getParentFile().getName(); int actual = newsGroups.intern(ng); Multiset<String> words = ConcurrentHashMultiset.create(); // check for line count header ------- String line = reader.readLine(); while (line != null && line.length() > 0) { // if this is a line that has a line count, let's pull that value out ------ if (line.startsWith("Lines:")) { String count = Iterables.get(onColon.split(line), 1); try { lineCount = Integer.parseInt(count); averageLineCount += (lineCount - averageLineCount) / Math.min(k + 1, 1000); } catch (NumberFormatException e) { // if anything goes wrong in parse: just use the avg count lineCount = averageLineCount; } } boolean countHeader = ( line.startsWith("From:") || line.startsWith("Subject:")|| line.startsWith("Keywords:")|| line.startsWith("Summary:")); // loop through the lines in the file, while the line starts with: " " do { // get a reader for this specific string ------ StringReader in = new StringReader(line); // ---- count words in header --------- if (countHeader) { countWords(analyzer, words, in); } // iterate to the next string ---- line = reader.readLine(); } while (line.startsWith(" ")); } // while (lines in header) { // -------- count words in body ---------- countWords(analyzer, words, reader); reader.close(); // ----- p.271 ----------- Vector v = new RandomAccessSparseVector(FEATURES); // original value does nothing in a ContantValueEncoder bias.addToVector("", 1, v); // original value does nothing in a ContantValueEncoder lines.addToVector("", lineCount / 30, v); // original value does nothing in a ContantValueEncoder logLines.addToVector("", Math.log(lineCount + 1), v); // now scan through all the words and add them for (String word : words.elementSet()) { encoder.addToVector(word, Math.log(1 + words.count(word)), v); } //Utils.PrintVectorNonZero(v); // calc stats --------- double mu = Math.min(k + 1, 200); double ll = learningAlgorithm.logLikelihood(actual, v); averageLL = averageLL + (ll - averageLL) / mu; Vector p = new DenseVector(20); learningAlgorithm.classifyFull(p, v); int estimated = p.maxValueIndex(); int correct = (estimated == actual? 1 : 0); averageCorrect = averageCorrect + (correct - averageCorrect) / mu; learningAlgorithm.train(actual, v); k++; int bump = bumps[(int) Math.floor(step) % bumps.length]; int scale = (int) Math.pow(10, Math.floor(step / bumps.length)); if (k % (bump * scale) == 0) { step += 0.25; System.out.printf("%10d %10.3f %10.3f %10.2f %s %s\n", k, ll, averageLL, averageCorrect * 100, ng, newsGroups.values().get(estimated)); } learningAlgorithm.close(); /* if (k>4) { break; } */ } Utils.PrintVectorSection( learningAlgorithm.getBeta().viewRow(0), 3 ); long endTime = System.currentTimeMillis(); //System.out.println("That took " + (endTime - startTime) + " milliseconds"); long duration = (endTime - startTime); System.out.println( "Processed Input Files: " + input_file_count +", time: " + duration + "ms" ); ModelSerializer.writeBinary("/tmp/olr-news-group.model", learningAlgorithm); // learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0)); } }