/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.cloudera.knittingboar.records;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.mahout.classifier.sgd.CsvRecordFactory;
import org.apache.mahout.math.Vector;
import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
import org.apache.mahout.vectorizer.encoders.ContinuousValueEncoder;
import org.apache.mahout.vectorizer.encoders.Dictionary;
import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
import org.apache.mahout.vectorizer.encoders.TextValueEncoder;
import com.google.common.base.CharMatcher;
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.base.Splitter;
import com.google.common.collect.Collections2;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
/**
* Based directly on:
* https://builds.apache.org/job/Mahout-Quality/javadoc/org/apache
* /mahout/classifier/sgd/CsvRecordFactory.html
*
* Adapted to implement modified RecordFactory interface
*
* @author jpatterson
*
*/
public class CSVBasedDatasetRecordFactory implements RecordFactory {
private static final String INTERCEPT_TERM = "Intercept Term";
// crude CSV value splitter. This will fail if any double quoted strings have
// commas inside. Also, escaped quotes will not be unescaped. Good enough for
// now.
private static final Splitter COMMA = Splitter.on(',').trimResults(
CharMatcher.is('"'));
private static final Map<String,Class<? extends FeatureVectorEncoder>> TYPE_DICTIONARY = ImmutableMap
.<String,Class<? extends FeatureVectorEncoder>> builder().put(
"continuous", ContinuousValueEncoder.class).put("numeric",
ContinuousValueEncoder.class).put("n", ContinuousValueEncoder.class)
.put("word", StaticWordValueEncoder.class).put("w",
StaticWordValueEncoder.class).put("text", TextValueEncoder.class)
.put("t", TextValueEncoder.class).build();
private final Map<String,Set<Integer>> traceDictionary = Maps.newTreeMap();
private int target;
private final Dictionary targetDictionary;
// Which column is used for identify a CSV file line
private String idName;
private int id = -1;
private List<Integer> predictors;
private Map<Integer,FeatureVectorEncoder> predictorEncoders;
private int maxTargetValue = Integer.MAX_VALUE;
private final String targetName;
private final Map<String,String> typeMap;
private List<String> variableNames;
private boolean includeBiasTerm;
private static final String CANNOT_CONSTRUCT_CONVERTER = "Unable to construct type converter... shouldn't be possible";
/**
* Construct a parser for CSV lines that encodes the parsed data in vector
* form.
*
* @param targetName
* The name of the target variable.
* @param typeMap
* A map describing the types of the predictor variables.
*/
public CSVBasedDatasetRecordFactory(String targetName,
Map<String,String> typeMap) {
this.targetName = targetName;
this.typeMap = typeMap;
targetDictionary = new Dictionary();
}
public CSVBasedDatasetRecordFactory(String targetName, String idName,
Map<String,String> typeMap) {
this(targetName, typeMap);
this.idName = idName;
}
@Override
public String GetClassnameByID(int id) {
// TODO Auto-generated method stub
return null;
}
@Override
public int processLine(String line, Vector featureVector) throws Exception {
// TODO Auto-generated method stub
// List<String> values = Lists.newArrayList(COMMA.split(line));
List<String> values = Lists.newArrayList(line.split(","));
// System.out.println( line + " //values.size(): " + values.size() );
int targetValue = targetDictionary.intern(values.get(target));
if (targetValue >= maxTargetValue) {
targetValue = maxTargetValue - 1;
}
for (Integer predictor : predictors) {
String value;
if (predictor >= 0) {
value = values.get(predictor);
} else {
value = null;
}
predictorEncoders.get(predictor).addToVector(value, featureVector);
}
return targetValue;
}
public void Setup(String PredictorLabelNamesList,
String PredictorVariableTypesList) {
String[] predictor_label_names = PredictorLabelNamesList.split(",");
String[] variable_types = PredictorVariableTypesList.split(",");
// ------ move to CSVFactory ------
List<String> typeList = Lists.newArrayList();
for (int x = 0; x < variable_types.length; x++) {
typeList.add(variable_types[x]);
}
List<String> predictorList = Lists.newArrayList();
for (int x = 0; x < predictor_label_names.length; x++) {
predictorList.add(predictor_label_names[x]);
}
List<String> arTargetCats = Lists.newArrayList();
arTargetCats.add("2");
arTargetCats.add("1");
// polr_modelparams.setTargetCategories(arTargetCats);
// ------ move to CSVFactory ------
}
/**
* Defines the values and thus the encoding of values of the target variables.
* Note that any values of the target variable not present in this list will
* be given the value of the last member of the list.
*
* @param values
* The values the target variable can have.
*/
public void defineTargetCategories(List<String> values) {
Preconditions.checkArgument(values.size() <= maxTargetValue,
"Must have less than or equal to " + maxTargetValue
+ " categories for target variable, but found " + values.size());
if (maxTargetValue == Integer.MAX_VALUE) {
maxTargetValue = values.size();
}
for (String value : values) {
targetDictionary.intern(value);
}
}
/**
* Defines the number of target variable categories, but allows this parser to
* pick encodings for them as they appear.
*
* @param max
* The number of categories that will be excpeted. Once this many
* have been seen, all others will get the encoding max-1.
*/
public CSVBasedDatasetRecordFactory maxTargetValue(int max) {
maxTargetValue = max;
return this;
}
public boolean usesFirstLineAsSchema() {
return true;
}
/**
* Processes the first line of a file (which should contain the variable
* names). The target and predictor column numbers are set from the names on
* this line.
*
* @param line
* Header line for the file.
*/
public void firstLine(String line) {
// System.out.println("> firstline: " + line);
// read variable names, build map of name -> column
final Map<String,Integer> vars = Maps.newHashMap();
variableNames = Lists.newArrayList(COMMA.split(line));
int column = 0;
for (String var : variableNames) {
vars.put(var, column++);
}
// record target column and establish dictionary for decoding target
target = vars.get(targetName);
// record id column
if (idName != null) {
id = vars.get(idName);
}
// create list of predictor column numbers
predictors = Lists.newArrayList(Collections2.transform(typeMap.keySet(),
new Function<String,Integer>() {
@Override
public Integer apply(String from) {
Integer r = vars.get(from);
Preconditions.checkArgument(r != null,
"Can't find variable %s, only know about %s", from, vars);
return r;
}
}));
if (includeBiasTerm) {
predictors.add(-1);
}
Collections.sort(predictors);
// and map from column number to type encoder for each column that is a
// predictor
predictorEncoders = Maps.newHashMap();
for (Integer predictor : predictors) {
String name;
Class<? extends FeatureVectorEncoder> c;
if (predictor == -1) {
name = INTERCEPT_TERM;
c = ConstantValueEncoder.class;
} else {
name = variableNames.get(predictor);
c = TYPE_DICTIONARY.get(typeMap.get(name));
}
try {
Preconditions.checkArgument(c != null,
"Invalid type of variable %s, wanted one of %s",
typeMap.get(name), TYPE_DICTIONARY.keySet());
Constructor<? extends FeatureVectorEncoder> constructor = c
.getConstructor(String.class);
Preconditions.checkArgument(constructor != null,
"Can't find correct constructor for %s", typeMap.get(name));
FeatureVectorEncoder encoder = constructor.newInstance(name);
predictorEncoders.put(predictor, encoder);
encoder.setTraceDictionary(traceDictionary);
} catch (InstantiationException e) {
throw new IllegalStateException(CANNOT_CONSTRUCT_CONVERTER, e);
} catch (IllegalAccessException e) {
throw new IllegalStateException(CANNOT_CONSTRUCT_CONVERTER, e);
} catch (InvocationTargetException e) {
throw new IllegalStateException(CANNOT_CONSTRUCT_CONVERTER, e);
} catch (NoSuchMethodException e) {
throw new IllegalStateException(CANNOT_CONSTRUCT_CONVERTER, e);
}
}
}
/**
* Decodes a single line of csv data and records the target and predictor
* variables in a record. As a side effect, features are added into the
* featureVector. Returns the value of the target variable.
*
* @param line
* The raw data.
* @param featureVector
* Where to fill in the features. Should be zeroed before calling
* processLine.
* @return The value of the target variable.
*/
/*
* @Override public int processLine(String line, Vector featureVector) {
* List<String> values = Lists.newArrayList(COMMA.split(line));
*
* int targetValue = targetDictionary.intern(values.get(target)); if
* (targetValue >= maxTargetValue) { targetValue = maxTargetValue - 1; }
*
* for (Integer predictor : predictors) { String value; if (predictor >= 0) {
* value = values.get(predictor); } else { value = null; }
* predictorEncoders.get(predictor).addToVector(value, featureVector); }
* return targetValue; }
*/
/***
* Decodes a single line of csv data and records the target(if retrunTarget is
* true) and predictor variables in a record. As a side effect, features are
* added into the featureVector. Returns the value of the target variable.
* When used during classify against production data without target value, the
* method will be called with returnTarget = false.
*
* @param line
* The raw data.
* @param featureVector
* Where to fill in the features. Should be zeroed before calling
* processLine.
* @param returnTarget
* whether process and return target value, -1 will be returned if
* false.
* @return The value of the target variable.
*/
public int processLine(CharSequence line, Vector featureVector,
boolean returnTarget) {
List<String> values = Lists.newArrayList(COMMA.split(line));
int targetValue = -1;
if (returnTarget) {
targetValue = targetDictionary.intern(values.get(target));
if (targetValue >= maxTargetValue) {
targetValue = maxTargetValue - 1;
}
}
for (Integer predictor : predictors) {
String value = predictor >= 0 ? values.get(predictor) : null;
predictorEncoders.get(predictor).addToVector(value, featureVector);
}
return targetValue;
}
/***
* Extract the raw target string from a line read from a CSV file.
*
* @param line
* the line of content read from CSV file
* @return the raw target value in the corresponding column of CSV line
*/
public String getTargetString(CharSequence line) {
List<String> values = Lists.newArrayList(COMMA.split(line));
return values.get(target);
}
/***
* Extract the corresponding raw target label according to a code
*
* @param code
* the integer code encoded during training process
* @return the raw target label
*/
public String getTargetLabel(int code) {
for (String key : targetDictionary.values()) {
if (targetDictionary.intern(key) == code) {
return key;
}
}
return null;
}
/***
* Extract the id column value from the CSV record
*
* @param line
* the line of content read from CSV file
* @return the id value of the CSV record
*/
public String getIdString(CharSequence line) {
List<String> values = Lists.newArrayList(COMMA.split(line));
return values.get(id);
}
/**
* Returns a list of the names of the predictor variables.
*
* @return A list of variable names.
*/
public Iterable<String> getPredictors() {
return Lists.transform(predictors, new Function<Integer,String>() {
@Override
public String apply(Integer v) {
if (v >= 0) {
return variableNames.get(v);
} else {
return INTERCEPT_TERM;
}
}
});
}
public Map<String,Set<Integer>> getTraceDictionary() {
return traceDictionary;
}
public CSVBasedDatasetRecordFactory includeBiasTerm(boolean useBias) {
includeBiasTerm = useBias;
return this;
}
public List<String> getTargetCategories() {
List<String> r = targetDictionary.values();
if (r.size() > maxTargetValue) {
r.subList(maxTargetValue, r.size()).clear();
}
return r;
}
public String getIdName() {
return idName;
}
public void setIdName(String idName) {
this.idName = idName;
}
}