package hu.u_szeged.ml.mallet;
import hu.u_szeged.ml.ClassificationResult;
import hu.u_szeged.ml.DataHandler;
import hu.u_szeged.ml.DataMiningException;
import hu.u_szeged.ml.Model;
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.logging.Level;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.classify.MaxEntOptimizableByLabelLikelihood;
import cc.mallet.classify.MaxEntTrainer;
import cc.mallet.types.Alphabet;
import cc.mallet.types.AugmentableFeatureVector;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.MalletProgressMessageLogger;
public class MalletDataHandler extends DataHandler implements Serializable {
/** Serialization version identifier of this {@link Serializable} class. */
private static final long serialVersionUID = 7555593484247132956L;
/** The instances of the dataset; {@code getInstance} appends to it and {@code createNewDataset} replaces it. */
public InstanceList data;
/** Maps every instance id to the position of its instance inside {@link #data}. */
public Map<String, Integer> instanceIds;
/** Label alphabet the instance targets are looked up in. */
protected LabelAlphabet labelAlphabet;
/** Feature sequence whose backing alphabet serves as the feature dictionary of the dataset. */
protected FeatureSequence featureAlphabet;
// protected FeatureSequence featureSequence;
/**
 * Value lists of the nominal features, keyed by feature name. The position of a value in its list is the number
 * that is stored in the feature vector for that value.
 */
protected Map<String, List<String>> nominalValues;
/** Trainer configured by {@code initClassifier}; stays {@code null} until that method has run. */
protected ClassifierTrainer<?> trainer = null;
// Sets the root logger reached through the MaxEnt optimizer's progress logger to WARNING level.
static {
((MalletLogger) MalletProgressMessageLogger.getLogger(MaxEntOptimizableByLabelLikelihood.class.getName())).getRootLogger().setLevel(Level.WARNING);
}
/**
 * Returns the instance registered under the given id, creating and registering it first when the id is unknown.
 * A newly created instance starts with an empty feature vector and the label looked up for {@code "false"}.
 *
 * @param instanceId id of the requested instance
 * @return the (possibly just created) instance stored in {@code data}
 */
private Instance getInstance(String instanceId) {
    if (instanceIds.containsKey(instanceId)) {
        return data.get(instanceIds.get(instanceId));
    }
    AugmentableFeatureVector emptyVector = new AugmentableFeatureVector(featureAlphabet.getAlphabet(), 0, false);
    Label defaultLabel = labelAlphabet.lookupLabel("false");
    Instance created = new Instance(emptyVector, defaultLabel, instanceId, instanceId);
    data.add(created);
    // the instance is unlocked so that its target and data can still be modified later
    created.unLock();
    int position = data.size() - 1;
    instanceIds.put(instanceId, position);
    return data.get(position);
}
/**
 * Gives access to the feature sequence held by this handler.
 *
 * @return the {@code featureAlphabet} field
 */
public FeatureSequence getFeatureSequence() {
    return this.featureAlphabet;
}
/**
 * Returns one of the alphabets backing this dataset.
 *
 * @param type {@code "feature"} or {@code "label"}; both are compared case-insensitively
 * @return the feature alphabet, the label alphabet, or {@code null} when {@code type} is {@code null} or names
 *         neither of them
 */
public Alphabet getAlphabet(String type) {
    // constant-first comparison keeps a null type from raising a NullPointerException
    if ("feature".equalsIgnoreCase(type)) {
        return featureAlphabet.getAlphabet();
    }
    if ("label".equalsIgnoreCase(type)) {
        return labelAlphabet;
    }
    return null;
}
/**
 * Returns the feature vector of an instance. Because the lookup goes through {@code getInstance}, an unknown id
 * results in a new instance being created.
 *
 * @param instanceId id of the instance
 * @return the data of the instance cast to {@link AugmentableFeatureVector}
 */
public AugmentableFeatureVector getInstanceData(String instanceId) {
    Instance instance = getInstance(instanceId);
    return (AugmentableFeatureVector) instance.getData();
}
/**
 * Reads the raw numeric value stored for a feature of an instance.
 *
 * @param instanceId  id of the instance
 * @param featureName name of the feature to read
 * @return the value held by the instance's feature vector at the feature's index
 * @throws DataMiningException if the feature alphabet does not contain {@code featureName}
 */
public double getDoubleValue(String instanceId, String featureName) throws DataMiningException {
    Alphabet features = featureAlphabet.getAlphabet();
    if (!features.contains(featureName)) {
        throw new DataMiningException("getter for unexisting feature: " + featureName);
    }
    AugmentableFeatureVector vector = getInstanceData(instanceId);
    return vector.value(featureAlphabet.getAlphabet().lookupIndex(featureName));
}
/**
 * Stores a raw numeric value for a feature of an instance, overwriting a value that is already present.
 * The instance is looked up (and created if unknown) before the feature is resolved.
 *
 * @param instanceId  id of the instance
 * @param featureName name of the feature
 * @param value       value to store
 */
protected void setDoubleValue(String instanceId, String featureName, double value) {
    AugmentableFeatureVector vector = getInstanceData(instanceId);
    int featureIndex = featureAlphabet.getAlphabet().lookupIndex(featureName);
    if (featureIndex < 0) {
        // it can occur when the feature alphabet has stopped growing and does not contain the feature
        return;
    }
    int location = vector.location(featureIndex);
    if (location >= 0) {
        vector.setValueAtLocation(location, value);
    } else {
        vector.add(featureIndex, value);
        // the feature sequence records the index as well
        featureAlphabet.add(featureIndex);
    }
}
/**
 * Resets this handler to an empty dataset.
 * <p>
 * Recognised parameters:
 * <ul>
 * <li>{@code "useFeatureSet"}: either a {@link MalletDataHandler}, whose feature and label alphabets are then
 * shared, or a {@link List} holding a {@link LabelAlphabet} and an {@link Alphabet}; in the list case
 * {@code stopGrowth()} is called on both alphabets. When the parameter is absent, fresh growing alphabets are
 * created.</li>
 * <li>{@code "classLabels"}: a {@code Set<String>} of labels used to fill an empty label alphabet. When absent,
 * the labels {@code "true"} and {@code "false"} are used.</li>
 * </ul>
 * An unsupported or incomplete {@code "useFeatureSet"} value prints a message to standard error and terminates
 * the JVM with exit status 1.
 *
 * @param parameters configuration map, may be {@code null}
 */
public void createNewDataset(Map<String, Object> parameters) {
    Object featureSet = parameters != null ? parameters.get("useFeatureSet") : null;
    if (featureSet != null) {
        if (featureSet instanceof MalletDataHandler) {
            MalletDataHandler other = (MalletDataHandler) featureSet;
            featureAlphabet = other.featureAlphabet;
            labelAlphabet = other.labelAlphabet;
        } else if (featureSet instanceof List<?>) {
            for (Object element : (List<?>) featureSet) {
                // LabelAlphabet has to be tested before the more general Alphabet
                if (element instanceof LabelAlphabet) {
                    labelAlphabet = (LabelAlphabet) element;
                } else if (element instanceof Alphabet) {
                    featureAlphabet = new FeatureSequence((Alphabet) element);
                }
            }
            if (featureAlphabet == null || labelAlphabet == null) {
                System.err.println("Uninitialized alphabet");
                System.exit(1);
            }
            featureAlphabet.getAlphabet().stopGrowth();
            labelAlphabet.stopGrowth();
        } else {
            System.err.println("Unsupported useFeatureSet parameter (neither List, nor a MalletDataHandler.)");
            System.exit(1);
        }
    } else {
        featureAlphabet = new FeatureSequence(new Alphabet());
        labelAlphabet = new LabelAlphabet();
        featureAlphabet.getAlphabet().startGrowth();
        labelAlphabet.startGrowth();
    }
    if (labelAlphabet.size() == 0) {
        @SuppressWarnings("unchecked")
        Set<String> classLabels = parameters == null ? null : (Set<String>) parameters.get("classLabels");
        if (classLabels == null) {
            // Binary default. Built only when needed, and without the deprecated Boolean constructor.
            // It stays a HashSet of the same two strings so the labels are registered in the same order as before.
            classLabels = new HashSet<String>(Arrays.asList(Boolean.TRUE.toString(), Boolean.FALSE.toString()));
            System.err.println("A binary classifier is being built now (which might take some time).");
        }
        for (String label : classLabels) {
            labelAlphabet.lookupIndex(label);
        }
        labelAlphabet.stopGrowth();
    }
    data = new InstanceList(featureAlphabet.getAlphabet(), labelAlphabet);
    instanceIds = new HashMap<String, Integer>();
    nominalValues = new HashMap<String, List<String>>();
}
/**
 * Builds a new handler containing only the selected instances and, for each of them, only the selected features.
 * The new handler shares the label alphabet of this handler but gets a fresh, growing feature alphabet.
 * <p>
 * NOTE(review): the {@code nominalValues} of this handler are not carried over to the result, and an id that is
 * unknown to this handler gets created here as a side effect of the lookup - confirm that both are intended.
 *
 * @param instancesSelected ids of the instances to copy
 * @param featuresSelected  names of the features to copy
 * @return the newly built handler
 * @throws DataMiningException declared by the signature
 */
public DataHandler createSubset(Collection<String> instancesSelected, Set<String> featuresSelected) throws DataMiningException {
    Map<String, Object> setup = new HashMap<String, Object>();
    setup.put("useFeatureSet", this);
    MalletDataHandler subset = new MalletDataHandler();
    subset.createNewDataset(setup);
    // replace the shared feature dictionary with an empty one that is allowed to grow
    subset.featureAlphabet = new FeatureSequence(new Alphabet());
    subset.featureAlphabet.getAlphabet().startGrowth();
    subset.data = new InstanceList(subset.featureAlphabet.getAlphabet(), subset.labelAlphabet);
    for (String instanceId : instancesSelected) {
        AugmentableFeatureVector source = this.getInstanceData(instanceId);
        int location = 0;
        while (location < source.numLocations()) {
            String name = (String) source.getAlphabet().lookupObject(source.getIndices()[location]);
            if (featuresSelected.contains(name)) {
                subset.setDoubleValue(instanceId, name, source.getValues()[location]);
            }
            location++;
        }
        subset.setLabel(instanceId, this.getLabel(instanceId));
    }
    return subset;
}
/**
 * Creates a subset of the dataset in which only the instances are filtered; every feature currently known to the
 * feature alphabet is kept.
 *
 * @param instancesSelected ids of the instances to copy
 * @return the newly built handler
 * @throws DataMiningException declared by the signature
 */
public DataHandler createSubset(Collection<String> instancesSelected) throws DataMiningException {
    Set<String> everyFeature = getFeatureNames();
    return createSubset(instancesSelected, everyFeature);
}
/**
 * Copies every instance of another Mallet based handler (feature values and label) into this dataset.
 * Values of instances that already exist here are overwritten feature by feature.
 *
 * @param dh handler to copy from; has to be a {@link MalletDataHandler}
 * @throws DataMiningException if {@code dh} is not a {@link MalletDataHandler}
 */
public void addDataHandler(DataHandler dh) throws DataMiningException {
    if (!(dh instanceof MalletDataHandler)) {
        throw new DataMiningException("MalletDataHandler can add just MalletDataHandlers");
    }
    MalletDataHandler other = (MalletDataHandler) dh;
    for (String instanceId : dh.getInstanceIds()) {
        AugmentableFeatureVector source = other.getInstanceData(instanceId);
        for (int location = 0; location < source.numLocations(); ++location) {
            String featureName = (String) source.getAlphabet().lookupObject(source.getIndices()[location]);
            double value = source.getValues()[location];
            this.setDoubleValue(instanceId, featureName, value);
        }
        this.setLabel(instanceId, dh.getLabel(instanceId));
    }
}
/**
 * Interprets the stored value of a feature as a boolean.
 *
 * @param instanceId  id of the instance
 * @param featureName name of the feature
 * @return {@code true} exactly when the stored value is greater than zero
 * @throws DataMiningException if the feature is unknown
 */
public Boolean getBinaryValue(String instanceId, String featureName) throws DataMiningException {
    double stored = getDoubleValue(instanceId, featureName);
    return stored > 0.0;
}
/**
 * Returns the number of features in the dataset.
 *
 * @return size of the alphabet returned by {@code data.getAlphabet()}
 */
public int getFeatureCount() {
    Alphabet dictionary = data.getAlphabet();
    return dictionary.size();
}
/**
 * Collects the names of all entries of the feature alphabet.
 *
 * @return a new set holding every entry of the feature alphabet cast to {@link String}
 */
public Set<String> getFeatureNames() {
    Object[] entries = featureAlphabet.getAlphabet().toArray();
    Set<String> names = new HashSet<String>();
    for (int i = 0; i < entries.length; i++) {
        names.add((String) entries[i]);
    }
    return names;
}
/**
 * Returns the value list registered for a nominal feature.
 *
 * @param featureName name of the feature
 * @return the value list, or {@code null} when the feature has no nominal values registered
 */
public List<String> getFeatureValues(String featureName) {
    // Map.get already yields null for an absent key, so no separate containsKey check is needed
    return nominalValues.get(featureName);
}
/**
 * Returns the number of instances in the dataset.
 *
 * @return size of {@code data}
 */
public int getInstanceCount() {
    return this.data.size();
}
/**
 * Returns the ids of all registered instances.
 *
 * @return the key set of {@code instanceIds}; it is a view of the map, not a copy
 */
public Set<String> getInstanceIds() {
    return this.instanceIds.keySet();
}
/**
 * Returns the label of an instance (the instance is created when the id is unknown).
 *
 * @param instanceId id of the instance
 * @return the entry of the instance's target label, cast unchecked to the type expected by the caller
 */
@SuppressWarnings("unchecked")
public <T extends Comparable<?>> T getLabel(String instanceId) {
    Label target = (Label) getInstance(instanceId).getTarget();
    return (T) target.getEntry();
}
/**
 * Returns the nominal value of a feature by translating the stored number back through the feature's value list.
 *
 * @param instanceId  id of the instance
 * @param featureName name of the nominal feature
 * @return the value found at the stored position of the feature's value list
 * @throws DataMiningException if the feature has no nominal values registered or is unknown to the alphabet
 */
public String getNominalValue(String instanceId, String featureName) throws DataMiningException {
    if (!nominalValues.containsKey(featureName)) {
        throw new DataMiningException(featureName + " is not a nominal feature");
    }
    List<String> values = nominalValues.get(featureName);
    int position = (int) getDoubleValue(instanceId, featureName);
    return values.get(position);
}
/**
 * Returns the stored value of a feature as a boxed double.
 *
 * @param instanceId  id of the instance
 * @param featureName name of the feature
 * @return the stored value
 * @throws DataMiningException if the feature is unknown
 */
public Double getNumericValue(String instanceId, String featureName) throws DataMiningException {
    return Double.valueOf(getDoubleValue(instanceId, featureName));
}
/**
 * Returns the stored value of a feature; the value is always delivered as a {@link Double}.
 *
 * @param instanceId  id of the instance
 * @param featureName name of the feature
 * @return the stored value boxed as {@link Double} and cast unchecked to the type expected by the caller
 * @throws DataMiningException if the feature is unknown
 */
@SuppressWarnings("unchecked")
public <T extends Comparable<?>> T getValue(String instanceId, String featureName) throws DataMiningException {
    Double boxed = getDoubleValue(instanceId, featureName);
    return (T) boxed;
}
/**
 * Instantiates the classifier trainer used by {@code trainClassifier}.
 * <p>
 * Recognised parameters:
 * <ul>
 * <li>{@code "classifier"}: trainer name; the class {@code cc.mallet.classify.<name>Trainer} is instantiated
 * through its no-argument constructor. Defaults to {@code "MaxEntL1"}.</li>
 * <li>{@code "gaussianPrior"}: a {@link Double}; applied as Gaussian prior variance, but only for the
 * {@code "MaxEntL1"} classifier.</li>
 * </ul>
 *
 * @param parameters configuration map, may be {@code null}
 * @throws DataMiningException if the trainer class cannot be found or instantiated
 */
public void initClassifier(Map<String, Object> parameters) throws DataMiningException {
    String classifierName = "MaxEntL1";
    if (parameters != null && parameters.containsKey("classifier")) {
        classifierName = (String) parameters.get("classifier");
    }
    Double gaussianPrior = null;
    // The prior is read whenever it is supplied. The guard used to test for the "classifier" key, which
    // silently ignored a prior given together with the default classifier.
    if ("MaxEntL1".equals(classifierName) && parameters != null && parameters.containsKey("gaussianPrior")) {
        gaussianPrior = (Double) parameters.get("gaussianPrior");
    }
    try {
        // Class.newInstance() is deprecated; go through the no-argument constructor instead
        trainer = (ClassifierTrainer<?>) Class.forName("cc.mallet.classify." + classifierName + "Trainer")
                .getDeclaredConstructor().newInstance();
        if (gaussianPrior != null) {
            ((MaxEntTrainer) trainer).setGaussianPriorVariance(gaussianPrior);
        }
    } catch (Exception e) {
        throw new DataMiningException("unknown classifier: " + classifierName, e);
    }
}
/**
 * Trains a classifier on the current dataset. When no trainer has been set up yet, {@code initClassifier(null)}
 * is called first.
 *
 * @return the trained classifier wrapped in a {@code MalletClassifier}
 * @throws DataMiningException if the trainer cannot be initialised
 */
public Model trainClassifier() throws DataMiningException {
    if (trainer == null) {
        initClassifier(null);
    }
    ClassifierTrainer<?> activeTrainer = trainer;
    return new MalletClassifier(activeTrainer.train(data));
}
/**
 * Classifies every instance of the dataset with the given model.
 *
 * @param model the model to apply; has to be a {@code MalletClassifier}
 * @return the classifications wrapped in a {@code MalletClassificationResult}
 * @throws DataMiningException if {@code model} is not a {@code MalletClassifier}
 */
public ClassificationResult classifyDataset(Model model) throws DataMiningException {
    if (model instanceof MalletClassifier) {
        MalletClassifier malletModel = (MalletClassifier) model;
        return new MalletClassificationResult(malletModel.getClassifier().classify(data), this);
    }
    throw new DataMiningException("MalletDataHandler can be used only by MALLET classifiers");
}
/**
 * Removes a single feature by delegating to the array based overload.
 *
 * @param featureName name of the feature to remove
 * @throws DataMiningException declared by the signature
 */
public void removeFeature(String featureName) throws DataMiningException {
    // an explicit array is required: passing the bare string would call this method again
    String[] single = { featureName };
    this.removeFeature(single);
}
/**
 * Removes several features by delegating to the set based overload.
 *
 * @param featureNames names of the features to remove
 * @throws DataMiningException declared by the signature
 */
public void removeFeature(String... featureNames) throws DataMiningException {
    Set<String> names = new HashSet<String>(Arrays.asList(featureNames));
    removeFeature(names);
}
/**
 * Excludes the given features from the dataset by installing a feature selection on {@code data} that is built
 * from the features left after pruning.
 * <p>
 * NOTE(review): this relies on {@code FeatureSequence.prune} keeping the entries whose count reaches the
 * threshold of 1 and registering them in the alphabet handed over - confirm against the Mallet documentation,
 * including what the call does to the dictionary of {@code featureAlphabet} itself.
 *
 * @param featureNames names of the features to remove
 * @throws DataMiningException declared by the signature
 */
public void removeFeature(Set<String> featureNames) throws DataMiningException {
    Alphabet alphabet = featureAlphabet.getAlphabet();
    FeatureSelection selection = new FeatureSelection(alphabet);
    int featureTotal = alphabet.size();
    // one count per feature index: 0 marks a feature to drop, 1 a feature to keep
    double[] keepCounts = new double[featureTotal];
    for (int index = 0; index < featureTotal; ++index) {
        Object name = alphabet.lookupObject(index);
        if (featureNames.contains(name)) {
            keepCounts[index] = 0;
        } else {
            keepCounts[index] = 1;
        }
    }
    Alphabet reducedAlphabet = new Alphabet();
    featureAlphabet.prune(keepCounts, reducedAlphabet, 1);
    for (Iterator<?> kept = reducedAlphabet.iterator(); kept.hasNext();) {
        selection.add(kept.next());
    }
    data.setFeatureSelection(selection);
}
// public Map<String, Map<String, Integer>> removeFeature(int threshold) throws DataMiningException {
// Map<String, Map<String, Integer>> stats = new TreeMap<String, Map<String, Integer>>();
// stats.put("classLabel", new HashMap<String, Integer>());
// double[] counts = new double[featureAlphabet.getAlphabet().size()];
// for (Entry<String, Integer> instance : instanceIds.entrySet()){
// Instance inst = data.get(instance.getValue());
// AugmentableFeatureVector fv = (AugmentableFeatureVector) inst.getData();
// for (int feat = 0; feat < counts.length; ++feat){
// String featureName = featureAlphabet.getAlphabet().lookupObject(feat).toString();
// Map<String, Integer> vals = stats.get(featureName);
// vals = vals == null ? new HashMap<String, Integer>() : vals;
// stats.put(featureName, vals);
// String val;
// try{
// // in case the feature was nominal one
// val = getNominalValue(instance.getKey(), featureName);
// counts[feat] = threshold;
// }catch (Exception e){
// double value = fv.value(feat);
// if (value > 0.0d)
// counts[feat]++;
// val = Double.toString(value);
// }
// Integer i = vals.get(val);
// if (i == null){
// vals.put(val, 1);
// }else{
// vals.put(val, ++i);
// }
// }
// Map<String, Integer> val = stats.get("classLabel");
// String label = ((Label) inst.getTarget()).getEntry().toString();
// Integer v = val.get(label);
// val.put(label, (v == null ? 1 : ++v));
// stats.put("classLabel", val);
// }
// FeatureSelection fs = new FeatureSelection(featureAlphabet.getAlphabet());
// Alphabet reducedAlphabet = new Alphabet();
// featureAlphabet.prune(counts, reducedAlphabet, threshold);
// Iterator<?> it = reducedAlphabet.iterator();
// while (it.hasNext()){
// fs.add(it.next());
// }
// data.setFeatureSelection(fs);
// return stats;
// }
/**
 * Removes one instance from the dataset and shifts the stored positions of all instances behind it.
 * Unknown ids are ignored.
 *
 * @param instanceId id of the instance to remove
 */
public void removeInstance(String instanceId) {
    Integer removedPosition = instanceIds.remove(instanceId);
    if (removedPosition != null) {
        data.remove(removedPosition.intValue());
        // every instance stored behind the removed one has moved one slot forward
        for (Entry<String, Integer> registration : instanceIds.entrySet()) {
            int position = registration.getValue();
            if (position > removedPosition) {
                registration.setValue(position - 1);
            }
        }
    }
}
/**
 * Removes all instances whose name is contained in the given collection and re-synchronises the positions
 * stored in {@code instanceIds}. A {@code null} or empty collection leaves the dataset untouched.
 * <p>
 * The work is done in one backward pass over {@code data} plus one pass to renumber the survivors, instead of
 * restarting the iteration and rescanning the whole id map after every single removal.
 *
 * @param instanceIdsToRemove ids of the instances to remove, may be {@code null}
 */
public void removeInstances(Collection<String> instanceIdsToRemove) {
    if (instanceIdsToRemove == null || instanceIdsToRemove.isEmpty()) {
        return;
    }
    boolean removedAny = false;
    // walking backwards keeps the positions that are still to be visited stable while removing
    for (int position = data.size() - 1; position >= 0; --position) {
        Object name = data.get(position).getName();
        if (instanceIdsToRemove.contains(name)) {
            data.remove(position);
            instanceIds.remove(name);
            removedAny = true;
        }
    }
    if (!removedAny) {
        return;
    }
    // renumber the survivors; only ids that are already registered are touched
    for (int position = 0; position < data.size(); ++position) {
        Object name = data.get(position).getName();
        if (instanceIds.containsKey(name)) {
            // a name equal to an existing String key is itself a String
            instanceIds.put((String) name, position);
        }
    }
}
/**
 * Replaces the current dataset with the content of a file in the format written by {@code saveDatasetMallet}:
 * one instance per line, {@code id@label} followed by whitespace separated {@code name:value} tokens.
 * Blank lines are skipped. Only the text before the first {@code ':'} of a token is used as the feature name
 * and the text between the first and second {@code ':'} as its value.
 *
 * @param source path of the file to read
 * @throws DataMiningException if the file cannot be read or a line does not follow the format
 */
public void loadDataset(String source) throws DataMiningException {
    try (BufferedReader file = new BufferedReader(new FileReader(source))) {
        createNewDataset(null);
        String line;
        while ((line = file.readLine()) != null) {
            String trimmed = line.trim();
            if (trimmed.isEmpty()) {
                continue;
            }
            // "\\s+" tolerates repeated separators, which used to produce empty tokens
            String[] tokens = trimmed.split("\\s+");
            String[] idAndLabel = tokens[0].split("@");
            if (idAndLabel.length < 2) {
                throw new DataMiningException("Corrupt input format. First token should be id@label but was: " + tokens[0]);
            }
            String id = idAndLabel[0];
            setLabel(id, idAndLabel[1]);
            for (int i = 1; i < tokens.length; ++i) {
                String[] nameAndValue = tokens[i].split(":");
                if (nameAndValue.length < 2) {
                    throw new DataMiningException("Corrupt input format. Feature token should be name:value but was: " + tokens[i]);
                }
                try {
                    setDoubleValue(id, nameAndValue[0], Double.parseDouble(nameAndValue[1]));
                } catch (NumberFormatException e) {
                    throw new DataMiningException("Corrupt input format. Non-numeric feature value in: " + tokens[i], e);
                }
            }
        }
    } catch (IOException e) {
        // report the failure to the caller instead of continuing with a missing or partially loaded dataset
        throw new DataMiningException("Cannot read dataset from " + source, e);
    }
}
/**
 * Saves the dataset. The target has the form {@code path} or {@code path|format}, where {@code format} is
 * {@code mallet} (also used when no format is given) or {@code weka}. An unknown format is reported on
 * standard error and nothing is written.
 *
 * @param target output path, optionally followed by {@code |} and the format name
 */
public void saveDataset(String target) {
    String[] parts = target.split("\\|");
    String path = parts.length > 0 ? parts[0] : "";
    // a missing or empty format part (e.g. "file|") falls back to the Mallet format instead of failing
    String format = parts.length > 1 ? parts[1] : "mallet";
    if (format.equals("mallet")) {
        saveDatasetMallet(path);
    } else if (format.equals("weka")) {
        saveDatasetWeka(path);
    } else {
        System.err.println("unknown output format " + format);
    }
}
/**
 * Writes the dataset in the line based format read by {@code loadDataset}: {@code id@label} followed by one
 * tab separated {@code name:value} pair per stored feature. When the file cannot be opened the stack trace is
 * printed and nothing is written.
 *
 * @param target path of the file to write
 */
public void saveDatasetMallet(String target) {
    // try-with-resources closes the writer even when writing an instance fails
    try (PrintWriter out = new PrintWriter(target)) {
        for (String id : instanceIds.keySet()) {
            out.print(id + "@" + getLabel(id));
            AugmentableFeatureVector fv = getInstanceData(id);
            for (int i = 0; i < fv.numLocations(); ++i) {
                out.print("\t" + featureAlphabet.getAlphabet().lookupObject(fv.getIndices()[i]) + ":" + fv.getValues()[i]);
            }
            out.println();
        }
    } catch (FileNotFoundException e) {
        e.printStackTrace();
    }
}
/**
 * Writes the dataset as a sparse ARFF file with numeric attributes and a binary {@code classlabel} attribute.
 * {@code .arff} is appended to the target when missing. The class attribute is written as {@code 1} for
 * instances whose label reads {@code true} and is left out otherwise. When the file cannot be opened the stack
 * trace is printed and nothing is written.
 *
 * @param target path of the file to write
 */
public void saveDatasetWeka(String target) {
    if (!target.endsWith(".arff")) {
        target += ".arff";
    }
    // try-with-resources closes the writer even when writing an instance fails
    try (PrintWriter out = new PrintWriter(target)) {
        out.println("@relation MalletData");
        for (Object f : data.getAlphabet().toArray()) {
            String name = f.toString().replaceAll("'", "");
            out.println("@attribute '" + name + "' numeric");
        }
        out.println("@attribute classlabel {0,1}");
        out.println("@data");
        for (String id : instanceIds.keySet()) {
            out.print("{");
            AugmentableFeatureVector fv = getInstanceData(id);
            int featureTotal = fv.numLocations();
            for (int i = 0; i < featureTotal; ++i) {
                if (i > 0) {
                    out.print(",");
                }
                out.print((fv.getIndices()[i]) + " " + fv.getValues()[i]);
            }
            // labels are registered as strings, so the entry is parsed rather than cast to boolean
            Object label = getLabel(id);
            if (Boolean.parseBoolean(String.valueOf(label))) {
                // separator only when feature entries precede the class entry
                if (featureTotal > 0) {
                    out.print(",");
                }
                out.print(data.getAlphabet().size() + " 1");
            }
            out.println("}");
        }
    } catch (FileNotFoundException e) {
        e.printStackTrace();
    }
}
/**
 * Stores a boolean feature value, encoded as {@code 1.0} for {@code true} and {@code 0.0} for {@code false}.
 *
 * @param instanceId  id of the instance
 * @param featureName name of the feature
 * @param value       value to store
 */
public void setBinaryValue(String instanceId, String featureName, Boolean value) {
    double encoded = value ? 1.0 : 0.0;
    setDoubleValue(instanceId, featureName, encoded);
}
/**
 * Ternal variant of the binary setter. It is not supported by this handler: the call only reports that on
 * standard error and stores nothing.
 *
 * @param instanceId  id of the instance (unused)
 * @param featureName name of the feature (unused)
 * @param value       value to store (unused)
 * @param ternal      ternal flag (unused)
 */
public void setBinaryValue(String instanceId, String featureName, Boolean value, boolean ternal) {
    final String notice = "ternal isn't implemented in MalletDataHandler";
    System.err.println(notice);
}
/**
 * Moves (or inserts) a value to the front of the value list of a nominal feature.
 *
 * @param featureName name of the nominal feature
 * @param value       value that becomes the first entry of the list
 * @throws DataMiningException if the feature has no nominal values registered
 */
public void setDefaultFeatureValue(String featureName, String value) throws DataMiningException {
    if (!nominalValues.containsKey(featureName)) {
        throw new DataMiningException("setDefaultFeatureValue is called for a feature which is not nominal");
    }
    List<String> values = nominalValues.get(featureName);
    // List.remove(Object) does nothing when the value is absent
    values.remove(value);
    values.add(0, value);
}
/**
 * Sets the label of an instance (the instance is created when the id is unknown). The label is looked up in
 * the label alphabet by its string form.
 *
 * @param instanceId id of the instance
 * @param label      the label; its {@code toString()} is used for the lookup
 */
public <T extends Comparable<?>> void setLabel(String instanceId, T label) {
    // the instance is resolved before the label, as the label lookup may fail
    Instance instance = getInstance(instanceId);
    Label target = labelAlphabet.lookupLabel(label.toString());
    instance.setTarget(target);
}
/**
 * Stores a nominal feature value. The value is registered in the value list of the feature when it is new, and
 * its position in that list is what gets stored in the feature vector. A newly created value list starts with
 * the entry {@code "MISSINGVALUE"}.
 *
 * @param instanceId  id of the instance
 * @param featureName name of the nominal feature
 * @param value       nominal value to store
 */
public void setNominalValue(String instanceId, String featureName, String value) {
    if (!nominalValues.containsKey(featureName)) {
        List<String> fresh = new LinkedList<String>();
        fresh.add("MISSINGVALUE");
        nominalValues.put(featureName, fresh);
    }
    List<String> values = nominalValues.get(featureName);
    int position = values.indexOf(value);
    if (position < 0) {
        values.add(value);
        position = values.size() - 1;
    }
    setDoubleValue(instanceId, featureName, (double) position);
}
/**
 * Stores a numeric feature value unchanged.
 *
 * @param instanceId  id of the instance
 * @param featureName name of the feature
 * @param value       value to store
 */
public void setNumericValue(String instanceId, String featureName, double value) {
    this.setDoubleValue(instanceId, featureName, value);
}
/**
 * Creates a copy of this handler by adding all of its instances to a freshly initialised dataset. A failure
 * while copying is printed and the copy built so far is returned.
 *
 * @return the new handler
 */
protected MalletDataHandler clone() {
    MalletDataHandler copy = new MalletDataHandler();
    copy.createNewDataset(null);
    try {
        copy.addDataHandler(this);
    } catch (DataMiningException copyFailure) {
        copyFailure.printStackTrace();
    }
    return copy;
}
/**
 * Stores a feature value, choosing the setter from the prefix of the feature name: {@code b_} binary,
 * {@code t_} ternal, {@code m_} nominal and {@code n_} numeric.
 *
 * @param instanceId  id of the instance
 * @param featureName prefixed name of the feature
 * @param value       value to store; has to be a {@link Boolean}, {@link String} or {@link Number} matching the
 *                    prefix
 * @throws DataMiningException if the feature name carries none of the known prefixes
 */
public <T extends Comparable<?>> void setValue(String instanceId, String featureName, T value) throws DataMiningException {
    if (featureName.startsWith("b_")) {
        setBinaryValue(instanceId, featureName, (Boolean) value);
    } else if (featureName.startsWith("t_")) {
        setBinaryValue(instanceId, featureName, (Boolean) value, true);
    } else if (featureName.startsWith("m_")) {
        setNominalValue(instanceId, featureName, (String) value);
    } else if (featureName.startsWith("n_")) {
        // This branch used to test "m_" a second time and could never be reached.
        // NOTE(review): "n_" is assumed to be the numeric prefix - confirm against the other DataHandler
        // implementations of the project.
        setNumericValue(instanceId, featureName, ((Number) value).doubleValue());
    } else {
        throw new DataMiningException("unknown featuretype " + featureName);
    }
}
}