package hu.u_szeged.ml.mallet; import hu.u_szeged.ml.ClassificationResult; import java.util.ArrayList; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.TreeMap; import cc.mallet.classify.Classification; import cc.mallet.types.LabelVector; public class MalletClassificationResult extends ClassificationResult { protected ArrayList<Classification> prediction; protected MalletDataHandler data; public MalletClassificationResult(ArrayList<Classification> pred, MalletDataHandler data) { super(data); this.data = data; prediction = pred; } @SuppressWarnings("unchecked") public <T extends Comparable<?>> T getPredictedLabel(String instanceId) { return (T) (prediction.get(data.instanceIds.get(instanceId)).getLabeling().getBestLabel().getEntry()); } @SuppressWarnings("unchecked") @Override public <T extends Comparable<?>> Map<T, Double> getPredictionProbabilities(String instanceId) { Map<T, Double> res = new TreeMap<T, Double>(); Integer in = data.instanceIds.get(instanceId); LabelVector vec = prediction.get(in).getLabelVector(); for (int i = 0; i < vec.numLocations(); ++i) res.put((T) vec.labelAtLocation(i).getEntry(), vec.valueAtLocation(i)); return res; } @SuppressWarnings("unchecked") @Override public <T extends Comparable<?>> List<T> getPredictions() { List<T> pred = new LinkedList<T>(); for (int i = 0; i < getInstanceCount(); ++i) pred.add((T) prediction.get(i).getLabeling().getBestLabel().getEntry()); return pred; } @Override public void loadPredictions(String source) { } @Override public void savePredictions(String target) { } }