/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.cloudera.knittingboar.records;

import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.util.Version;
import org.apache.mahout.math.Vector;
import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
import org.apache.mahout.vectorizer.encoders.Dictionary;
import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;

import com.google.common.collect.ConcurrentHashMultiset;
import com.google.common.collect.Multiset;

/**
 * RecordFactory for the 20 Newsgroups dataset.
 * 
 * Adapted from:
 * https://github.com/tdunning/MiA/blob/master/src/main/java/mia/classifier/ch14/TrainNewsGroups.java
 * 
 * The class ids are hardcoded in this record factory because the set of
 * newsgroup labels is fixed for this dataset.
 * 
 * @author jpatterson
 */
public class TwentyNewsgroupsRecordFactory implements RecordFactory {
  
  public static final int FEATURES = 10000;
  
  Dictionary newsGroups = null;
  Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_31);
  
  String class_id_split_string = " ";
  
  public TwentyNewsgroupsRecordFactory(String strClassSeparator) {
    
    // Dictionary.intern() assigns the next sequential id to each new label,
    // so the order below fixes the class ids: "alt.atheism" -> 0 through
    // "talk.religion.misc" -> 19
    this.newsGroups = new Dictionary();
    newsGroups.intern("alt.atheism");
    newsGroups.intern("comp.graphics");
    newsGroups.intern("comp.os.ms-windows.misc");
    newsGroups.intern("comp.sys.ibm.pc.hardware");
    newsGroups.intern("comp.sys.mac.hardware");
    newsGroups.intern("comp.windows.x");
    newsGroups.intern("misc.forsale");
    newsGroups.intern("rec.autos");
    newsGroups.intern("rec.motorcycles");
    newsGroups.intern("rec.sport.baseball");
    newsGroups.intern("rec.sport.hockey");
    newsGroups.intern("sci.crypt");
    newsGroups.intern("sci.electronics");
    newsGroups.intern("sci.med");
    newsGroups.intern("sci.space");
    newsGroups.intern("soc.religion.christian");
    newsGroups.intern("talk.politics.guns");
    newsGroups.intern("talk.politics.mideast");
    newsGroups.intern("talk.politics.misc");
    newsGroups.intern("talk.religion.misc");
    
    this.class_id_split_string = strClassSeparator;
  }
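  /*
   * A minimal usage sketch (illustration only, not part of the original
   * class; the sample line below is made up):
   *
   *   TwentyNewsgroupsRecordFactory rec = new TwentyNewsgroupsRecordFactory(" ");
   *   Vector v = new RandomAccessSparseVector(TwentyNewsgroupsRecordFactory.FEATURES);
   *   int actual = rec.processLine("rec.autos the clutch is slipping again", v);
   *   System.out.println(rec.GetClassnameByID(actual)); // prints "rec.autos"
   */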
  @Override
  public List<String> getTargetCategories() {
    List<String> out = new ArrayList<String>();
    for (int x = 0; x < this.newsGroups.size(); x++) {
      out.add(this.newsGroups.values().get(x));
    }
    return out;
  }
  
  public int LookupIDForNewsgroupName(String name) {
    return this.newsGroups.values().indexOf(name);
  }
  
  public boolean ContainsIDForNewsgroupName(String name) {
    return this.newsGroups.values().contains(name);
  }
  
  public String GetNewsgroupNameByID(int id) {
    return this.newsGroups.values().get(id);
  }
  
  @Override
  public String GetClassnameByID(int id) {
    return this.newsGroups.values().get(id);
  }
  
  private static void countWords(Analyzer analyzer, Collection<String> words,
      Reader in) throws IOException {
    // use the provided analyzer to tokenize the input stream
    TokenStream ts = analyzer.tokenStream("text", in);
    ts.addAttribute(CharTermAttribute.class);
    ts.reset();
    // add every token the analyzer emits to the collection; the analyzer
    // already strips punctuation and other non-word content
    while (ts.incrementToken()) {
      words.add(ts.getAttribute(CharTermAttribute.class).toString());
    }
    ts.close();
  }
  
  /**
   * Processes a single line of input into:
   * - the target variable (newsgroup class id)
   * - a feature vector
   * 
   * @throws Exception if the line has no class label / message separator
   */
  public int processLine(String line, Vector v) throws Exception {
    
    // split into at most two parts so the message body may itself contain
    // the separator string
    String[] parts = line.split(this.class_id_split_string, 2);
    if (parts.length < 2) {
      throw new Exception("Malformed line, expected '<newsgroup>"
          + this.class_id_split_string + "<message>': " + line);
    }
    
    String newsgroup_name = parts[0];
    String msg = parts[1];
    
    // p.269 ---------------------------------------------------------
    Map<String,Set<Integer>> traceDictionary = new TreeMap<String,Set<Integer>>();
    
    // encodes the text content of the message body as hashed word features
    FeatureVectorEncoder encoder = new StaticWordValueEncoder("body");
    encoder.setProbes(2);
    encoder.setTraceDictionary(traceDictionary);
    
    // provides a constant offset that the model can use to encode the average
    // frequency of each class
    FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept");
    bias.setTraceDictionary(traceDictionary);
    
    // look up (or add) the class id for this label
    int actual = newsGroups.intern(newsgroup_name);
    
    // tokenize the message and count term frequencies
    Multiset<String> words = ConcurrentHashMultiset.create();
    StringReader in = new StringReader(msg);
    countWords(analyzer, words, in);
    
    // ----- p.271 -----------
    // the value passed in does nothing in a ConstantValueEncoder
    bias.addToVector("", 1, v);
    
    // scan through all the words and add each one to the feature vector,
    // weighted by log(1 + term frequency)
    for (String word : words.elementSet()) {
      encoder.addToVector(word, Math.log(1 + words.count(word)), v);
    }
    
    return actual;
  }
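  /*
   * Design note: processLine() rebuilds its encoders on every call. Since the
   * hashed encoding is deterministic for a given encoder name and probe
   * count, one could instead (a sketch, not the original design) hoist them
   * into fields initialized once in the constructor, with setProbes(2)
   * applied there:
   *
   *   private final FeatureVectorEncoder encoder = new StaticWordValueEncoder("body");
   *   private final FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept");
   *
   * which avoids the per-line allocation without changing the vectors produced.
   */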
System.out.println("DictionarySize: " + this.newsGroups.values().size()); } public void DebugDictionary() { for (int x = 0; x < this.newsGroups.size(); x++) { System.out.println(x + "" + this.newsGroups.values().get(x)); } } }