package se.kodapan.osm.city;

import se.kodapan.osm.city.*;
import weka.core.Instances;

import java.io.*;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

/**
 * {@link Classifier} implementation that delegates to a Weka classifier.
 * <p>
 * {@link #build()} serializes the training data to an ARFF file (one numeric
 * attribute per palette color plus a nominal {@code class} attribute), loads
 * that file back as Weka {@link Instances} and trains the configured Weka
 * classifier on it. {@link #classify(Instance)} then maps a histogram to one
 * of the class labels seen during training.
 * <p>
 * A Weka classifier must be supplied through
 * {@link #setClassifier(weka.classifiers.Classifier)} before calling
 * {@link #build()}.
 *
 * @author kalle
 * @since 2015-01-12 05:31
 */
public class WekaClassifier extends Classifier {

  /** Weka classifier that is trained by {@link #build()} and queried by {@link #classify(Instance)}. */
  private weka.classifiers.Classifier classifier;

  /** Data set loaded by the last {@link #build()}; also provides the attribute layout for classification. */
  private weka.core.Instances wekaTrainingData;

  /** Optional location for the generated ARFF file. When null a temporary file is used instead. */
  private File trainingDataArffFile;

  /**
   * Writes the training data as ARFF, loads it into Weka and trains the classifier.
   * <p>
   * If an ARFF file has been set via {@link #setTrainingDataArffFile(File)} it is
   * (over)written and kept. Otherwise a temporary file is created and removed
   * again once the data has been loaded.
   *
   * @throws IllegalStateException if no Weka classifier has been set
   * @throws Exception             if the ARFF file cannot be written or read, or if Weka fails to train
   */
  @Override
  public void build() throws Exception {
    if (classifier == null) {
      throw new IllegalStateException("No Weka classifier set, call setClassifier() before build()");
    }

    boolean temporary = trainingDataArffFile == null;
    // The suffix must include the dot, otherwise the file gets no ".arff" extension.
    File file = temporary ? File.createTempFile("training_data", ".arff") : trainingDataArffFile;
    try {
      Writer arff = new OutputStreamWriter(new FileOutputStream(file), "UTF8");
      try {
        writeTrainingDataAsArff(arff);
      } finally {
        // Always release the file handle, also when writing fails half way.
        arff.close();
      }

      Reader reader = new InputStreamReader(new FileInputStream(file), "UTF8");
      try {
        wekaTrainingData = new Instances(reader);
      } finally {
        reader.close();
      }
      wekaTrainingData.setClass(wekaTrainingData.attribute("class"));

      classifier.buildClassifier(wekaTrainingData);
    } finally {
      // Only clean up files created here; a file chosen by the caller is left in place.
      // The data set has already been read into memory at this point.
      if (temporary && !file.delete()) {
        file.deleteOnExit();
      }
    }
  }

  /**
   * Writes the ARFF header and data section for the current training data.
   * <p>
   * Attributes are named by offsetting {@code 'a'} with the color's attribute index.
   * Class labels are written inside double quotes.
   * NOTE(review): labels are not escaped, so a label containing a double quote
   * would produce an invalid file.
   *
   * @param arff destination; not closed by this method
   * @throws IOException if writing fails
   */
  private void writeTrainingDataAsArff(Writer arff) throws IOException {
    arff.append("@Relation color_distributions\n");

    for (Color color : getPalette().getColors()) {
      String name = String.valueOf((char) ('a' + color.getAttributeIndex()));
      arff.append("@Attribute ").append(name).append(" numeric\n");
    }

    // The nominal class attribute lists each distinct label once.
    arff.append("@Attribute class {");
    for (Iterator<String> iterator = new HashSet<String>(getTrainingData().values()).iterator(); iterator.hasNext(); ) {
      String classification = iterator.next();
      arff.append("\"").append(classification).append("\"");
      if (iterator.hasNext()) {
        arff.append(", ");
      }
    }
    arff.append("}\n");

    arff.append("@Data\n");
    for (Map.Entry<Instance, String> instance : getTrainingData().entrySet()) {
      for (double value : instance.getKey().getHistogramPercent()) {
        arff.append(String.valueOf(value)).append(",");
      }
      arff.append("\"").append(instance.getValue()).append("\"");
      arff.append("\n");
    }
  }

  /**
   * Classifies an instance with the trained Weka classifier.
   *
   * @param instance instance whose histogram values are copied, in order, to the
   *                 leading attributes of a new Weka instance
   * @return the class label selected by the classifier
   * @throws IllegalStateException if {@link #build()} has not completed loading the training data,
   *                               or if no Weka classifier is set
   * @throws Exception             if Weka fails to classify the instance
   */
  @Override
  public String classify(Instance instance) throws Exception {
    if (classifier == null || wekaTrainingData == null) {
      throw new IllegalStateException("Classifier has not been built, call build() before classify()");
    }

    weka.core.Attribute classAttribute = wekaTrainingData.attribute("class");

    weka.core.Instance wekaInstance = new weka.core.Instance(wekaTrainingData.numAttributes());
    wekaInstance.setDataset(wekaTrainingData);
    double[] histogramPercent = instance.getHistogramPercent();
    for (int i = 0; i < histogramPercent.length; i++) {
      wekaInstance.setValue(i, histogramPercent[i]);
    }
    // The class is what is being predicted, so it is marked as unknown.
    wekaInstance.setMissing(classAttribute);

    // Weka returns the index of the predicted label within the class attribute.
    double wekaClassification = classifier.classifyInstance(wekaInstance);
    return classAttribute.value((int) wekaClassification);
  }

  /**
   * @return the Weka classifier used for training and classification, or null if none is set
   */
  public weka.classifiers.Classifier getClassifier() {
    return classifier;
  }

  /**
   * @param classifier the Weka classifier to train in {@link #build()} and query in {@link #classify(Instance)}
   */
  public void setClassifier(weka.classifiers.Classifier classifier) {
    this.classifier = classifier;
  }

  /**
   * @return the file the ARFF training data is written to, or null if a temporary file is used
   */
  public File getTrainingDataArffFile() {
    return trainingDataArffFile;
  }

  /**
   * @param trainingDataArffFile file to write the ARFF training data to, or null to use a temporary file
   */
  public void setTrainingDataArffFile(File trainingDataArffFile) {
    this.trainingDataArffFile = trainingDataArffFile;
  }
}