package fr.unistra.pelican.algorithms.segmentation; import java.awt.Point; import fr.unistra.pelican.Algorithm; import fr.unistra.pelican.AlgorithmException; import fr.unistra.pelican.BooleanImage; import fr.unistra.pelican.ByteImage; import fr.unistra.pelican.Image; import fr.unistra.pelican.IntegerImage; import fr.unistra.pelican.algorithms.arithmetic.Inversion; import fr.unistra.pelican.algorithms.conversion.GrayToPseudoColors; import fr.unistra.pelican.algorithms.segmentation.labels.DrawFrontiersOnImage; import fr.unistra.pelican.algorithms.segmentation.labels.FrontiersFromSegmentation; import fr.unistra.pelican.algorithms.segmentation.labels.LabelsToBinaryMasks; import fr.unistra.pelican.algorithms.spatial.TopographicTransform; import fr.unistra.pelican.algorithms.visualisation.Viewer2D; import fr.unistra.pelican.util.HierarchicalQueue; import fr.unistra.pelican.util.Memory; import fr.unistra.pelican.util.Point4D; /** * This class is a geodesic adaptation of the K-Means algorithm for iterative * image segmentation using distance transforms. * * @author Lefevre */ public class GeodesicKMeans extends Algorithm { /** * Input Image */ public Image inputImage; /** * Number of seeked clusters. */ public int clusters; /** * Output Image */ public Image outputImage; /** * Number of iterations, default 15. */ public int maxIterations = 15; /** * Convergence criterion (distance threshold), default 2 */ public int minDist = 2; /** * Flag to compute hue-based distance */ public boolean hue = false; private boolean DEBUG = true; /** * Constructor */ public GeodesicKMeans() { super.inputs = "inputImage,clusters"; super.options = "maxIterations,minDist,hue"; super.outputs = "outputImage"; } /** * A geodesic adaptation of the K-Means algorithm for iterative image * segmentation using distance transforms. * * @param inputImage * The input image * @param clusters * The number of seeked clusters * @return The output image */ public static Image exec(Image inputImage, int clusters) { return (Image) new GeodesicKMeans().process(inputImage, clusters); } public static Image exec(Image inputImage, int clusters, int iterations) { return (Image) new GeodesicKMeans().process(inputImage, clusters, iterations); } public static Image exec(Image inputImage, int clusters, int iterations, int minDist) { return (Image) new GeodesicKMeans().process(inputImage, clusters, iterations, minDist); } public static Image exec(Image inputImage, int clusters, int iterations, int minDist, boolean hue) { return (Image) new GeodesicKMeans().process(inputImage, clusters, iterations, minDist, hue); } /* * (non-Javadoc) * * @see fr.unistra.pelican.Algorithm#launch() */ public void launch() throws AlgorithmException { outputImage = new IntegerImage(inputImage.getXDim(), inputImage.getYDim(), inputImage.getZDim(), inputImage.getTDim(), 1); // inputImage = (Image) new AdditionConstantChecked().process(inputImage, // 1.0 / 255); Image labels = null, frontiers = null, distances = null, distances2 = null, globaldistances2 = null, draw = null; if (DEBUG) { labels = new IntegerImage(inputImage.getXDim(), inputImage.getYDim(), inputImage.getZDim(), maxIterations, 1); frontiers = new BooleanImage(inputImage.getXDim(), inputImage.getYDim(), inputImage.getZDim(), maxIterations, 1); distances = new ByteImage(inputImage.getXDim(), inputImage.getYDim(), inputImage.getZDim(), maxIterations, 3); distances.setColor(true); // FIXME : works only for 2D images ! distances2 = new ByteImage(inputImage.getXDim(), inputImage.getYDim(), clusters, 1, 3); distances2.setColor(true); globaldistances2 = new ByteImage(inputImage.getXDim(), inputImage .getYDim(), clusters, maxIterations, 3); draw = new ByteImage(inputImage.getXDim(), inputImage.getYDim(), inputImage.getZDim(), maxIterations, inputImage.getBDim()); draw.setColor(true); } int scale = Math.max(inputImage.getXDim(), inputImage.getYDim());// 500; double mem1 = Memory.totalUsedMemoryMB(); HierarchicalQueue queue = new HierarchicalQueue(scale * scale * 255); double mem2 = Memory.totalUsedMemoryMB(); System.out.println("Allocated memory for queue:" + (int) (mem2 - mem1) + " MB"); // Initialise cluster centers Point4D[] centers = new Point4D[clusters]; Point4D[] oldCenters = null, oldCenters2 = null; for (int c = 0; c < clusters; c++) centers[c] = new Point4D((int) (Math.random() * inputImage.getXDim()), (int) (Math.random() * inputImage.getYDim()), (int) Math.random() * inputImage.getZDim(), (int) (Math.random() * inputImage.getTDim())); // centers[0] = new Point4D(229,95, 0, 0); // centers[1] = new Point4D(341,135, 0, 0); // centers[2] = new Point4D(160,301, 0, 0); // centers[3] = new Point4D(121,269, 0, 0); // centers[4] = new Point4D(87,203, 0, 0); // centers[5] = new Point4D(101,65, 0, 0); boolean trueDistance = true; boolean borders = false; // Repeat the process until convergence boolean convergence = false; int i = 0; long t1 = System.currentTimeMillis(); for (i = 0; i < maxIterations && !convergence; i++) { // Set the centers on the input image Image work = inputImage.copyImage(true); /* * for (int c = 0; c < clusters; c++) for (int b = 0; b < * inputImage.getBDim(); b++) work.setPixelXYBByte(centers[c].x, * centers[c].y, b, 0); */ // Perform geodesic computation Image im = GeodesicDistanceBasedWatershed.exec(work, centers, trueDistance, hue, queue); outputImage = im.getImage4D(0, Image.B); if (DEBUG) { distances.setImage4D(DrawFrontiersOnImage.exec(GrayToPseudoColors .exec(im.getImage4D(1, Image.B)), FrontiersFromSegmentation.exec(im .getImage4D(0, Image.B))), i, Image.T); labels.setImage4D(im.getImage4D(0, Image.B), i, Image.T); frontiers.setImage4D((Image) new FrontiersFromSegmentation().process(im .getImage4D(0, Image.B)), i, Image.T); draw.setImage4D(DrawFrontiersOnImage.exec(inputImage, (BooleanImage) frontiers.getImage4D(i, Image.T)), i, Image.T); } // Compute the center of each region // oldCenters = centers.clone(); if (oldCenters != null) oldCenters2 = copy(oldCenters); oldCenters = copy(centers); // for (int cpt=0;cpt<oldCenters.length;cpt++) { // System.err.println(oldCenters[cpt]+" "+centers[cpt]); // } // centre de gravité standard (k-means) // centers = trim((Point[]) new RegionCenter().process(labels)); // recherche du central comme le max de la TD au fond // Image binary=LabelsToBinaryMasks.exec(outputImage); // for (int c=0;c<clusters;c++) { // Image cluster=Inversion.exec(binary.getImage4D(c,Image.B)); // cluster=DistanceTransform.exec(cluster,true); // //Viewer2D.exec(GrayToPseudoColors.exec(cluster)); // int max=0; // for(int x=0;x<cluster.getXDim();x++) // for(int y=0;y<cluster.getYDim();y++) // if (cluster.getPixelXYInt(x, y)>max) { // centers[c].x=x; // centers[c].y=y; // max=cluster.getPixelXYInt(x, y); // } // } Image binary = LabelsToBinaryMasks.exec(outputImage); if (binary.getBDim() < clusters) { clusters = binary.getBDim(); System.err.println("Less clusters : " + clusters); } if (DEBUG) distances2.fill(0); // Viewer2D.exec(binary,"clusters, # "+i); for (int c = 0; c < clusters; c++) { Image cluster = Inversion.exec(binary.getImage4D(c, Image.B)); cluster = TopographicTransform.exec(inputImage, (BooleanImage) cluster, trueDistance, borders, hue, queue); if (DEBUG) distances2.setImage4D(GrayToPseudoColors.exec(cluster), c, Image.Z); // Viewer2D.exec(GrayToPseudoColors.exec(cluster)); int max = 0; for (int x = 0; x < cluster.getXDim(); x++) for (int y = 0; y < cluster.getYDim(); y++) if (cluster.getPixelXYInt(x, y) > max) { centers[c].x = x; centers[c].y = y; max = cluster.getPixelXYInt(x, y); } } // System.err.println("==="); // for (int cpt=0;cpt<oldCenters.length;cpt++) // System.err.println(oldCenters[cpt]+" "+centers[cpt]); if (DEBUG) globaldistances2.setImage4D(distances2, i, Image.T); // Viewer2D.exec(distances2,"distance aux centres, # "+i); if (centers.length != clusters) System.err.println("Probleme :" + centers.length + " != " + clusters); // Evaluate modifications System.out.println(); for (int c = 0; c < clusters; c++) System.out.println(i + "/" + c + ":" + oldCenters[c].x + "," + oldCenters[c].y + " =>" + centers[c].x + "," + centers[c].y + " ==>" + centers[c].distance(oldCenters[c])); int dist = 0; for (int c = 0; c < clusters; c++) dist += centers[c].distance(oldCenters[c]); System.out.println(i + ": convergence ? " + dist + " <= " + (minDist * clusters)); if (dist <= minDist * clusters) convergence = true; // Additional check for oscillation if (oldCenters2 != null) { dist = 0; for (int c = 0; c < clusters; c++) dist += centers[c].distance(oldCenters2[c]); if (dist <= minDist * clusters) convergence = true; } } long t2 = System.currentTimeMillis(); System.out.println("GeodesicKMeans performed in " + (t2 - t1) + " ms (" + (t2 - t1) / (i) + " ms / iteration)"); // queue.clear(); // queue = null; // System.gc(); if (DEBUG) { Viewer2D.exec(distances, "distances"); globaldistances2.setColor(true); Viewer2D.exec(globaldistances2, "distances2"); // Viewer2D.exec(LabelsToRandomColors.exec(labels), "labels"); // Viewer2D.exec(frontiers, "frontiers"); // Viewer2D.exec(draw, "draw"); } } private Point4D[] copy(Point4D[] tab) { Point4D[] res = new Point4D[tab.length]; for (int t = 0; t < tab.length; t++) res[t] = new Point4D(tab[t]); return res; } private Point[] trim(Point[] tab) { int nb = 0; for (int t = 0; t < tab.length; t++) if (tab[t] != null) nb++; Point[] res = new Point[nb]; nb = 0; for (int t = 0; t < tab.length; t++) if (tab[t] != null) res[nb++] = tab[t]; return res; } }