package fr.unistra.pelican.algorithms.segmentation.weka; import java.util.Random; import weka.classifiers.Classifier; import weka.classifiers.Evaluation; import weka.classifiers.functions.MultilayerPerceptron; import weka.core.FastVector; import weka.core.Instance; import weka.core.Instances; import weka.filters.Filter; import weka.filters.supervised.instance.Resample; import weka.filters.unsupervised.instance.Randomize; import fr.unistra.pelican.Algorithm; import fr.unistra.pelican.AlgorithmException; import fr.unistra.pelican.Image; import fr.unistra.pelican.IntegerImage; import fr.unistra.pelican.algorithms.io.ImageLoader; import fr.unistra.pelican.algorithms.io.SamplesLoader; import fr.unistra.pelican.algorithms.segmentation.labels.LabelsToColorByMeanValue; import fr.unistra.pelican.algorithms.visualisation.Viewer2D; /** * Perform a classification using a Weka algorithm. Each band from input * represents a attribute. Each band from samples represent a class exemples. * @author Sᅵbastien Derivaux */ public class WekaClassification extends Algorithm { // Inputs parameters public Image inputImage; public Classifier classifier; public Image samples; public boolean stats=false; // Outputs parameters public Image outputImage; /** * Constructor * */ public WekaClassification() { super.inputs = "inputImage,classifier,samples"; super.options="stats"; super.outputs = "outputImage"; } public static Image exec(Image inputImage, Classifier classifier,Image samples) { return (Image) new WekaClassification().process(inputImage, classifier,samples); } public static Image exec(Image inputImage, Classifier classifier,Image samples,boolean stats) { return (Image) new WekaClassification().process(inputImage, classifier,samples,stats); } /* * (non-Javadoc) * * @see fr.unistra.pelican.Algorithm#launch() */ public void launch() throws AlgorithmException { outputImage = new IntegerImage(inputImage.getXDim(), inputImage .getYDim(), 1, 1, 1); int xDim = inputImage.getXDim(); int yDim = inputImage.getYDim(); int bDim = inputImage.getBDim(); // Creation of the datas for Weka. // Create attributes. FastVector attributes = new FastVector(bDim); for (int i = 0; i < bDim; i++) attributes.addElement(new weka.core.Attribute("bande" + i)); // Add class attribute. FastVector classValues = new FastVector(10); for (int b = 0; b < samples.getBDim(); b++) classValues.addElement("class" + b); attributes.addElement(new weka.core.Attribute("label", classValues)); Instances dataset = new Instances("dataset", attributes, 0); dataset.setClassIndex(attributes.size() - 1); // Put learning samples in the dataset for (int x = 0; x < xDim; x++) for (int y = 0; y < yDim; y++) for (int c = 0; c < samples.getBDim(); c++) if (samples.getPixelXYBBoolean(x, y, c) == true) { Instance instance = new Instance(dataset .numAttributes()); for (int b = 0; b < bDim; b++) instance.setValue(b, inputImage.getPixelXYBDouble( x, y, b)); instance.setDataset(dataset); instance.setClassValue((double) c); dataset.add(instance); } // Filter a little the dataset try { // Radomise presentation Filter filter = new Randomize(); filter.setInputFormat(dataset); dataset = Filter.useFilter(dataset, filter); // Resample or not resample to uniformize class distribution?? Resample resample = new Resample(); resample.setBiasToUniformClass(1.0); resample.setSampleSizePercent(100.0); resample.setInputFormat(dataset); resample.setRandomSeed(123); dataset = Filter.useFilter(dataset, resample); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } // Learn the classification try { Evaluation eval = new Evaluation(dataset); eval.crossValidateModel(classifier, dataset, 10, new Random()); if(stats) { System.out.println(eval.toMatrixString()); System.out.println(eval.toClassDetailsString()); System.out.println(eval.toSummaryString()); } classifier.buildClassifier(dataset); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } for (int x = 0; x < xDim; x++) for (int y = 0; y < yDim; y++) { Instance instance = new Instance(dataset.numAttributes()); for (int b = 0; b < bDim; b++) instance.setValue(b, inputImage.getPixelXYBDouble(x, y, b)); instance.setDataset(dataset); int label = -1; try { label = (int) classifier.classifyInstance(instance); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } outputImage.setPixelXYInt(x, y, label); } } public static void main(String[] args) { String file = "samples/remotesensing1.png"; String samplesPath = "samples/remotesensing1"; if (args.length > 0) file = args[0]; Image source = (Image) new ImageLoader().process(file); new Viewer2D().process(source, "Image " + file); Image samples = (Image) new SamplesLoader().process(samplesPath); new Viewer2D().process(samples, "Samples of" + file); MultilayerPerceptron classifier = new MultilayerPerceptron(); Image work = (Image) new WekaClassification() .process(source, classifier, samples); // View It! new Viewer2D().process(new LabelsToColorByMeanValue().process(work, source), "Classification for " + file); } }