/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.cloudera.knittingboar.utils;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Reader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.util.Version;
/**
 * Tool to convert the 20 Newsgroups dataset into the input format for Knitting
 * Boar: multiple directories of small files are merged into larger splits,
 * each containing many records, one record per line.
 *
 * 1. Download the canonical dataset from:
 *    http://people.csail.mit.edu/jrennie/20Newsgroups/20news-bydate.tar.gz
 * 2. Extract the dataset locally.
 * 3. Run the DatasetConverter process to merge the smaller files into larger
 *    input files.
 * 4. Edit "workDir" on line 44 in
 *    com.cloudera.knittingboar.sgd.TestRunPOLRMasterAndNWorkers to reflect the
 *    location of the input training data.
 * 5. Run the unit test:
 *    com.cloudera.knittingboar.sgd.TestRunPOLRMasterAndNWorkers
 *
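 * Each output line pairs a newsgroup label with the whitespace-joined tokens
 * of one message, separated by a tab (see ReadFullFile). A sketch of the
 * format; the tokens shown are illustrative only:
 *
 * <pre>
 * comp.graphics\tpolygon rendering raytracer image format ...
 * </pre>
 *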
* @author jpatterson
*
*/
public class DatasetConverter {
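  /**
   * Tokenizes the given character stream with the provided analyzer and adds
   * each emitted term to the supplied collection; handing in a Multiset turns
   * this into a word-frequency count.
   */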
private static void countWords(Analyzer analyzer, Collection<String> words,
Reader in) throws IOException {
    // use the provided analyzer to tokenize the input stream
    TokenStream ts = analyzer.tokenStream("text", in);
    CharTermAttribute term = ts.addAttribute(CharTermAttribute.class);
    // honor the TokenStream contract: reset before consuming
    ts.reset();
    // for each token in the stream (the analyzer drops non-word characters),
    // add the term to the collection
    while (ts.incrementToken()) {
      words.add(term.toString());
    }
    ts.end();
    ts.close();
}
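  /**
   * Reads one newsgroup message from disk and renders it as a single training
   * record: the newsgroup label, a tab, then the analyzer's tokens joined by
   * single spaces, terminated by a newline.
   */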
public static String ReadFullFile(Analyzer analyzer, String newsgroup_name,
String file) throws IOException {
    // build the record as "label \t token token ... \n" without the quadratic
    // cost of repeated String concatenation
    StringBuilder out = new StringBuilder(newsgroup_name).append('\t');
    BufferedReader reader = null;
    try {
      reader = new BufferedReader(new FileReader(file));
      TokenStream ts = analyzer.tokenStream("text", reader);
      CharTermAttribute term = ts.addAttribute(CharTermAttribute.class);
      // honor the TokenStream contract: reset before consuming
      ts.reset();
      // for each token in the stream (the analyzer drops non-word characters),
      // append the term to the record
      while (ts.incrementToken()) {
        out.append(term.toString()).append(' ');
      }
      ts.end();
      ts.close();
    } finally {
      if (reader != null) {
        reader.close();
      }
    }
    return out.append('\n').toString();
}
/**
   * Converts the 20 Newsgroups dataset from its standard layout (about 20,000
   * files across 20 directories) into N larger files better suited to
   * Knitting Boar.
   *
   * 1. Download the 20 Newsgroups dataset from:
   *    http://people.csail.mit.edu/jrennie/20Newsgroups/20news-bydate.tar.gz
   * 2. Extract the dataset to a local directory.
   * 3. Run the DatasetConverter process to merge the smaller files into larger
   *    input files; see "TestConvert20NewsTestDataset" in the unit tests.
*
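   * A minimal usage sketch (the paths are placeholders, not part of the
   * project):
   *
   * <pre>{@code
   * DatasetConverter.ConvertNewsgroupsFromSingleFiles(
   *     "/tmp/20news-bydate-train/", "/tmp/kboar/train/", 2000);
   * }</pre>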
*
   * @param inputBaseDir base directory of the extracted dataset, one
   *          subdirectory per newsgroup
   * @param outputBaseDir output directory for the shards; expected to end with
   *          a path separator
   * @param records_per_shard maximum number of records written to each shard
   * @return the number of input files converted
   * @throws IOException
*/
public static int ConvertNewsgroupsFromSingleFiles(String inputBaseDir,
String outputBaseDir, int records_per_shard) throws IOException {
Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_31);
File base = new File(inputBaseDir);
    // because OLR expects integer class IDs for the target variable during
    // training, the newsgroup name (written as each record's label here) must
    // later be mapped to an integer by a dictionary
List<File> files = new ArrayList<File>();
for (File newsgroup : base.listFiles()) {
files.addAll(Arrays.asList(newsgroup.listFiles()));
}
// mix up the files, helps training in OLR
Collections.shuffle(files);
System.out.printf("%d training files\n", files.size());
double step = 0.0;
int[] bumps = new int[] {1, 2, 5};
BufferedWriter shard_writer = null;
int shard_count = 0;
int input_file_count = 0;
Map<Integer,Integer> current_shard_rec_count = new HashMap<Integer,Integer>();
try {
File base_dir = new File(outputBaseDir);
if (!base_dir.exists()) {
base_dir.mkdirs();
}
File shard_file_0 = new File(outputBaseDir + "kboar-shard-" + shard_count
+ ".txt");
if (shard_file_0.exists()) {
shard_file_0.delete();
}
shard_file_0.createNewFile();
shard_writer = new BufferedWriter(new FileWriter(shard_file_0));
System.out.println("Starting: " + shard_file_0.toString());
// ----- "reading and tokenzing the data" ---------
for (File file : files) {
input_file_count++;
        // identify the newsgroup from the parent directory name; downstream
        // code converts this label to a unique class id
        String ng = file.getParentFile().getName();
        // file already carries its full path from listFiles(), so pass it
        // directly instead of rebuilding the path by string concatenation
        String file_contents = ReadFullFile(analyzer, ng, file.getPath());
shard_writer.write(file_contents);
        if (!current_shard_rec_count.containsKey(shard_count)) {
          System.out.println(".");
          current_shard_rec_count.put(shard_count, 1);
        } else {
          int c = current_shard_rec_count.get(shard_count);
          current_shard_rec_count.put(shard_count, c + 1);
        }
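        // when the current shard reaches capacity, close it and roll over to
        // a new shard file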
if (current_shard_rec_count.get(shard_count) >= records_per_shard) {
shard_writer.flush();
shard_writer.close();
shard_count++;
shard_file_0 = new File(outputBaseDir + "kboar-shard-" + shard_count
+ ".txt");
System.out.println("Starting shard: " + "kboar-shard-" + shard_count
+ ".txt");
if (shard_file_0.exists()) {
shard_file_0.delete();
}
shard_file_0.createNewFile();
shard_writer = new BufferedWriter(new FileWriter(shard_file_0));
}
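        // print a progress line on a roughly logarithmic cadence
        // (1, 2, 5, 10, 20, 50, ... files) so the log stays readable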
int bump = bumps[(int) Math.floor(step) % bumps.length];
int scale = (int) Math.pow(10, Math.floor(step / bumps.length));
if (input_file_count % (bump * scale) == 0) {
step += 0.25;
System.out.printf("Files Converted: %10d , %10d \n",
input_file_count, current_shard_rec_count.get(shard_count));
}
} // for
} finally {
if(shard_writer != null) {
shard_writer.flush();
shard_writer.close();
}
}
for (int x = 0; x < current_shard_rec_count.size(); x++) {
System.out.println("> Shard " + x + " record count: "
+ current_shard_rec_count.get(x));
}
System.out.printf("> Total Files Converted: %10d \n", input_file_count);
return input_file_count;
}
/**
* Conversion Tool to break up the RCV1 dataset into smaller chunks for
* various tests.
*
* RCV1 Dataset:
*
* https://github.com/JohnLangford/vowpal_wabbit/wiki/Rcv1-example
*
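   * A minimal usage sketch (the path and record counts are placeholders, not
   * part of the project):
   *
   * <pre>{@code
   * DatasetConverter.ExtractSubsetofRCV1V2ForTraining(
   *     "/tmp/rcv1.train.txt", "/tmp/rcv1/train/", 20000, 2000);
   * }</pre>
   *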
   * @param input_file path to the RCV1 input file, one record per line
   * @param outputBaseDir output directory for the shards; expected to end with
   *          a path separator
   * @param total_recs_to_extract total number of records to copy out
   * @param records_per_shard maximum number of records written to each shard
   * @return the number of input files read
   * @throws IOException
*/
public static int ExtractSubsetofRCV1V2ForTraining(String input_file,
String outputBaseDir, int total_recs_to_extract, int records_per_shard)
throws IOException {
double step = 0.0;
int[] bumps = new int[] {1, 2, 5};
System.out.println("> ExtractSubsetofRCV1V2ForTraining: " + input_file);
BufferedWriter shard_writer = null;
int shard_count = 0;
int input_file_count = 0;
int line_count = 0;
Map<Integer,Integer> current_shard_rec_count = new HashMap<Integer,Integer>();
try {
File base_dir = new File(outputBaseDir);
if (!base_dir.exists()) {
base_dir.mkdirs();
}
System.out.println(outputBaseDir + "rcv1-shard-" + shard_count + ".txt");
File shard_file_0 = new File(outputBaseDir + "rcv1-shard-" + shard_count
+ ".txt");
if (shard_file_0.exists()) {
shard_file_0.delete();
} else {
System.out.println("no output file, creating...");
}
boolean bCreate = shard_file_0.createNewFile();
System.out.println("file created: " + bCreate);
shard_writer = new BufferedWriter(new FileWriter(shard_file_0));
input_file_count++;
BufferedReader reader = null;
try {
System.out.println("opening file for reading: input_file");
reader = new BufferedReader(new FileReader(input_file));
String line = reader.readLine();
while (line != null && line.length() > 0) {
shard_writer.write(line + "\n");
line = reader.readLine();
          if (!current_shard_rec_count.containsKey(shard_count)) {
            current_shard_rec_count.put(shard_count, 1);
          } else {
            int c = current_shard_rec_count.get(shard_count);
            current_shard_rec_count.put(shard_count, c + 1);
          }
line_count++;
if (total_recs_to_extract <= line_count) {
break;
}
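          // when the current shard reaches capacity, close it and roll over
          // to a new shard file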
if (current_shard_rec_count.get(shard_count) >= records_per_shard) {
shard_writer.flush();
shard_writer.close();
shard_count++;
shard_file_0 = new File(outputBaseDir + "rcv1-shard-" + shard_count
+ ".txt");
System.out.println("Starting shard: " + "rcv1-shard-" + shard_count
+ ".txt");
if (shard_file_0.exists()) {
shard_file_0.delete();
}
shard_file_0.createNewFile();
shard_writer = new BufferedWriter(new FileWriter(shard_file_0));
}
          // print a progress line on a roughly logarithmic cadence; track
          // line_count here, since only one input file is ever read
          int bump = bumps[(int) Math.floor(step) % bumps.length];
          int scale = (int) Math.pow(10, Math.floor(step / bumps.length));
          if (line_count % (bump * scale) == 0) {
            step += 0.25;
            System.out.printf("Records Extracted: %10d , %10d \n", line_count,
                current_shard_rec_count.get(shard_count));
          }
}
      } finally {
        if (reader != null) {
          reader.close();
        }
      }
    } catch (Exception e) {
      e.printStackTrace();
    } finally {
      if (shard_writer != null) {
        shard_writer.flush();
        shard_writer.close();
      }
    }
for (int x = 0; x < current_shard_rec_count.size(); x++) {
System.out.println("> Shard " + x + " record count: "
+ current_shard_rec_count.get(x));
}
System.out.printf("> Total Files Converted: %10d \n", input_file_count);
return input_file_count;
}
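  /**
   * Convenience entry point; a minimal sketch that is not part of the original
   * tool, shown only to illustrate how the 20 Newsgroups converter is invoked.
   * The argument layout is an assumption, and the example paths in the usage
   * text are placeholders.
   */
  public static void main(String[] args) throws IOException {
    if (args.length != 3) {
      // hypothetical CLI layout for illustration only
      System.out.println("usage: DatasetConverter <inputBaseDir> <outputBaseDir> <recordsPerShard>");
      System.out.println("  e.g. DatasetConverter /tmp/20news-bydate-train/ /tmp/kboar/train/ 2000");
      return;
    }
    ConvertNewsgroupsFromSingleFiles(args[0], args[1], Integer.parseInt(args[2]));
  }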
}