/*******************************************************************************
* Copyright 2012 University of Southern California
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* This code was developed by the Information Integration Group as part
* of the Karma project at the Information Sciences Institute of the
* University of Southern California. For more information, publications,
* and related projects, please see: http://www.isi.edu/integration
******************************************************************************/
package edu.isi.karma.cleaning.features;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Vector;
import org.apache.mahout.classifier.sgd.CsvRecordFactory;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.classifier.sgd.RecordFactory;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
import edu.isi.karma.cleaning.PartitionClassifierType;
public class RecordClassifier2 implements PartitionClassifierType {
    // Training examples grouped by class label (label -> values).
    HashMap<String, Vector<String>> trainData = new HashMap<String, Vector<String>>();
    // Feature extractor shared by training and classification.
    RecordFeatureSet rf = new RecordFeatureSet();
    // Trained model; null until learnClassifer() has run successfully.
    OnlineLogisticRegression cf;
    // Target category names, in the index order used by the model.
    List<String> labels = new ArrayList<String>();
    // CSV/model configuration built during train(); reused by Classify().
    LogisticModelParameters lmp;

    public RecordClassifier2() {
    }

    /**
     * Trains an online logistic-regression model on the given labeled data.
     * The data is first written to a CSV file, which is then re-read once per
     * SGD pass (100 passes).
     *
     * @param traindata training examples grouped by class label
     * @return the trained model
     * @throws Exception if the CSV training file cannot be written or read
     */
    public OnlineLogisticRegression train(
            HashMap<String, Vector<String>> traindata) throws Exception {
        String csvTrainFile = "./target/tmp/csvtrain.csv";
        Data2Features.Traindata2CSV(traindata, csvTrainFile, rf);
        lmp = new LogisticModelParameters();
        lmp.setTargetVariable("label");
        lmp.setMaxTargetCategories(rf.labels.size());
        lmp.setNumFeatures(rf.getFeatureNames().size());
        List<String> typeList = Lists.newArrayList();
        typeList.add("numeric");
        List<String> predictorList = Lists.newArrayList();
        for (String attr : rf.getFeatureNames()) {
            // Exclude the target column from the predictors. Fixed: this used
            // to compare against the misspelled "lable", so the "label"
            // column was never filtered out.
            if (attr.compareTo("label") != 0) {
                predictorList.add(attr);
            }
        }
        lmp.setTypeMap(predictorList, typeList);
        lmp.setLambda(1e-4);
        lmp.setLearningRate(50);
        int passes = 100;
        CsvRecordFactory csv = lmp.getCsvRecordFactory();
        OnlineLogisticRegression lr = lmp.createRegression();
        for (int pass = 0; pass < passes; pass++) {
            // try-with-resources replaces the deprecated
            // Closeables.closeQuietly() idiom.
            try (BufferedReader in = new BufferedReader(new FileReader(
                    new File(csvTrainFile)))) {
                // The first line holds the variable names.
                csv.firstLine(in.readLine());
                String line = in.readLine();
                while (line != null) {
                    // For each new line, get target and predictors.
                    RandomAccessSparseVector input = new RandomAccessSparseVector(
                            lmp.getNumFeatures());
                    int targetValue = csv.processLine(line, input);
                    // Update the model. (A per-line classifyFull() whose
                    // result was never used has been removed here.)
                    lr.train(targetValue, input);
                    line = in.readLine();
                }
            }
        }
        labels = csv.getTargetCategories();
        return lr;
    }

    /**
     * Sums the model weights of all columns that trace back to the given
     * predictor, for the given output row (class index).
     */
    private static double predictorWeight(OnlineLogisticRegression lr, int row,
            RecordFactory csv, String predictor) {
        double weight = 0;
        for (Integer column : csv.getTraceDictionary().get(predictor)) {
            weight += lr.getBeta().get(row, column);
        }
        return weight;
    }

    /**
     * Classifies a single string value with the trained model.
     * Must be called after learnClassifer()/train() — it reads {@code lmp},
     * {@code cf} and {@code labels} set up there.
     *
     * @param instance the value to classify
     * @return the predicted class label
     */
    public String Classify(String instance) {
        Collection<Feature> cfeat = rf.computeFeatures(instance, "");
        Feature[] x = cfeat.toArray(new Feature[cfeat.size()]);
        RandomAccessSparseVector row = new RandomAccessSparseVector(x.length);
        // Build a CSV line of feature scores; StringBuilder avoids the
        // quadratic cost of += concatenation in a loop.
        StringBuilder line = new StringBuilder();
        for (int k = 0; k < x.length; k++) {
            line.append(x[k].getScore()).append(',');
        }
        line.append("label"); // dummy class label for testing
        CsvRecordFactory csv = lmp.getCsvRecordFactory();
        csv.processLine(line.toString(), row);
        DenseVector dvec = (DenseVector) this.cf.classifyFull(row);
        String label = labels.get(dvec.maxValueIndex());
        return label;
    }

    /** Adds one training example under the given class label. */
    @Override
    public void addTrainingData(String value, String label) {
        if (trainData.containsKey(label)) {
            trainData.get(label).add(value);
        } else {
            Vector<String> vsStrings = new Vector<String>();
            vsStrings.add(value);
            trainData.put(label, vsStrings);
        }
    }

    /**
     * Trains the classifier on all data added so far.
     *
     * @return a string description of the trained model, or "" if training
     *         failed (previously the exception was silently swallowed and the
     *         subsequent cf.toString() threw a NullPointerException)
     */
    @Override
    public String learnClassifer() {
        try {
            this.cf = this.train(trainData);
        } catch (Exception e) {
            // Surface the failure instead of swallowing it.
            e.printStackTrace();
            return "";
        }
        return this.cf.toString();
    }

    /**
     * Classifies a value, mapping any failure or empty result to the
     * sentinel "null_in_classification".
     */
    @Override
    public String getLabel(String value) {
        try {
            String label = this.Classify(value);
            if (label.length() > 0) {
                return label;
            } else {
                return "null_in_classification";
            }
        } catch (Exception e) {
            return "null_in_classification";
        }
    }

    /** Small manual smoke test: trains on a few addresses and classifies one. */
    public static void main(String[] args) {
        try {
            HashMap<String, Vector<String>> trainData = new HashMap<String, Vector<String>>();
            Vector<String> test = new Vector<String>();
            Vector<String> par1 = new Vector<String>();
            par1.add("1286 adams blvd");
            par1.add("3711 catalina st");
            Vector<String> par2 = new Vector<String>();
            par2.add("1142 37st");
            Vector<String> par3 = new Vector<String>();
            par3.add("710 27");
            trainData.put("c1", par1);
            trainData.put("c2", par2);
            trainData.put("c3", par3);
            test.add("2353 portland st");
            RecordClassifier2 rc = new RecordClassifier2();
            for (String key : trainData.keySet()) {
                for (String value : trainData.get(key)) {
                    rc.addTrainingData(value, key);
                }
            }
            rc.learnClassifer();
            System.out.println(rc.Classify(test.get(0)));
        } catch (Exception ex) {
            ex.printStackTrace();
        }
    }
}