package fr.unistra.pelican.algorithms.segmentation.weka; import weka.clusterers.Clusterer; import weka.clusterers.EM; import weka.core.FastVector; import weka.core.Instance; import weka.core.Instances; import weka.filters.Filter; import weka.filters.unsupervised.instance.Resample; import fr.unistra.pelican.Algorithm; import fr.unistra.pelican.AlgorithmException; import fr.unistra.pelican.DoubleImage; import fr.unistra.pelican.Image; import fr.unistra.pelican.InvalidNumberOfParametersException; import fr.unistra.pelican.InvalidTypeOfParameterException; import fr.unistra.pelican.algorithms.histogram.ContrastStretch; import fr.unistra.pelican.algorithms.io.ImageLoader; import fr.unistra.pelican.algorithms.visualisation.Viewer2D; /** * Perform a soft segmentation using a Weka algorithm. Each band represents a * attribute. * @author Sébastien Derivaux */ public class WekaSoftSegmentation extends Algorithm { // Inputs parameters public Image inputImage; public Clusterer clusterer; // Outputs parameters public Image outputImage; public final int MAX_LEARNING = 5000; /** * Constructor * */ public WekaSoftSegmentation() { super(); super.inputs = "inputImage,clusterer"; super.outputs = "outputImage"; } /* * (non-Javadoc) * * @see fr.unistra.pelican.Algorithm#launch() */ public void launch() throws AlgorithmException { int xDim = inputImage.getXDim(); int yDim = inputImage.getYDim(); int bDim = inputImage.getBDim(); // Creation of the datas for Wek. // Create attributes. FastVector attributes = new FastVector(bDim); for (int i = 0; i < bDim; i++) attributes.addElement(new weka.core.Attribute("bande" + i)); Instances dataset = new Instances("dataset", attributes, 0); for (int x = 0; x < xDim; x++) for (int y = 0; y < yDim; y++) { Instance instance = new Instance(dataset.numAttributes()); for (int b = 0; b < bDim; b++) instance.setValue(b, inputImage.getPixelXYBDouble(x, y, b)); instance.setDataset(dataset); dataset.add(instance); } // Learn the classification try { Instances learningSet = dataset; if (learningSet.numInstances() > MAX_LEARNING) { Resample filter = new Resample(); filter.setRandomSeed((int) System.currentTimeMillis()); filter.setSampleSizePercent((double) MAX_LEARNING * 100.0 / (double) learningSet.numInstances()); filter.setInputFormat(learningSet); learningSet = Filter.useFilter(learningSet, filter); System.out .println("INFO : WekaSoftSegmentation : numInstances = " + learningSet.numInstances()); } clusterer.buildClusterer(learningSet); outputImage = new DoubleImage(inputImage.getXDim(), inputImage .getYDim(), 1, 1, clusterer.numberOfClusters()); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } for (int x = 0; x < xDim; x++) for (int y = 0; y < yDim; y++) { Instance instance = new Instance(dataset.numAttributes()); for (int b = 0; b < bDim; b++) instance.setValue(b, inputImage.getPixelXYBDouble(x, y, b)); instance.setDataset(dataset); double distrib[] = null; try { distrib = clusterer.distributionForInstance(instance); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } for (int b = 0; b < distrib.length; b++) outputImage.setPixelXYBDouble(x, y, b, distrib[b]); } } public static void main(String[] args) { String file = "samples/remotesensing1.png"; if (args.length > 0) file = args[0]; try { // Load the image Image source = (Image) new ImageLoader().process(file); new Viewer2D().process(source, "Image " + file); EM clusterer = new EM(); try { clusterer.setNumClusters(3); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } clusterer.setSeed((int) System.currentTimeMillis()); Image work = (Image) new WekaSoftSegmentation().process(source, clusterer); // View it new Viewer2D().process(new ContrastStretch().process(work), "Soft clusters from " + file); } catch (InvalidTypeOfParameterException e) { // TODO Auto-generated catch block e.printStackTrace(); } catch (AlgorithmException e) { // TODO Auto-generated catch block e.printStackTrace(); } catch (InvalidNumberOfParametersException e) { // TODO Auto-generated catch block e.printStackTrace(); } } }