/** * Copyright (c) 2013 Oculus Info Inc. * http://www.oculusinfo.com/ * * Released under the MIT License. * * Permission is hereby granted, free of charge, to any person obtaining a copy of * this software and associated documentation files (the "Software"), to deal in * the Software without restriction, including without limitation the rights to * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies * of the Software, and to permit persons to whom the Software is furnished to do * so, subject to the following conditions: * The above copyright notice and this permission notice shall be included in all * copies or substantial portions of the Software. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
*/
package spimedb.cluster.tracks;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import spimedb.cluster.DataSet;
import spimedb.cluster.Instance;
import spimedb.cluster.feature.spatial.TrackFeature;
import spimedb.cluster.feature.spatial.centroid.TrackCentroid;
import spimedb.cluster.feature.spatial.distance.TrackDistance;
import spimedb.cluster.unsupervised.cluster.Cluster;
import spimedb.cluster.unsupervised.cluster.ClusterResult;
import spimedb.cluster.unsupervised.cluster.kmeans.KMeans;
import spimedb.cluster.validation.unsupervised.external.BCubed;
import spimedb.util.geom.cartesian.CubicBSpline;
import spimedb.util.geom.geodesic.Position;
import spimedb.util.geom.geodesic.PositionCalculationParameters;
import spimedb.util.geom.geodesic.PositionCalculationType;
import spimedb.util.geom.geodesic.Track;
import spimedb.util.geom.geodesic.tracks.GeodeticTrack;
import spimedb.util.math.linearalgebra.Vector;

import java.awt.*;
import java.awt.geom.Rectangle2D;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * Manual test driver for track clustering.
 *
 * <p>Generates random base splines inside a fixed bounding box, samples noisy-length
 * tracks from each spline, clusters the tracks with {@link KMeans}, logs BCubed
 * precision/recall against the generating spline, and optionally displays the result.
 */
public class TestTrackCluster {

    protected static final Logger log = LoggerFactory.getLogger(TestTrackCluster.class);

    private static final PositionCalculationParameters GEODETIC_PARAMETERS =
            new PositionCalculationParameters(PositionCalculationType.Geodetic, 0.0001, 1E-8, false);

    /**
     * Lower bound on the raw sample count used by {@link #randomPoints}. The first and
     * last samples are dropped, so 4 guarantees at least two points per track (and
     * prevents the {@code times[N-1]} index from going negative).
     */
    private static final int MIN_SAMPLES_PER_TRACK = 4;

    /**
     * Builds random tracks around {@code N} random splines, clusters them into
     * {@code N} clusters, and logs BCubed validation scores.
     *
     * @param N       number of generating splines; also the number of KMeans clusters
     * @param T       number of tracks generated from each spline
     * @param P       nominal number of samples per track (randomly jittered per track)
     * @param visible if {@code true}, show the result frame and wait for it
     */
    public void clusterRandomTracks(int N, int T, int P, boolean visible) {
        TrackFrame frame = new TrackFrame();
        DataSet ds = new DataSet();

        log.info("Creating {} track splines as bases", N);
        // Create N track bases
        Rectangle2D bounds = new Rectangle2D.Double(12.5, 6.172, 1.521, 2.5);
        frame.setDrawingBounds(bounds);
        CubicBSpline[] trackBases = new CubicBSpline[N];
        for (int i = 0; i < N; ++i) {
            trackBases[i] = randomBoundedSpline(bounds);
            frame.addSpline(trackBases[i]);
        }

        log.info("Creating {} track per track bases", T);
        // Now create T random tracks from each base spline
        for (int j = 0; j < T; ++j) {
            for (int i = 0; i < N; ++i) {
                Track track = new GeodeticTrack(GEODETIC_PARAMETERS, randomPoints(trackBases[i], P));
                // add the track as an instance to the dataset
                Instance inst = new Instance();
                TrackFeature feature = new TrackFeature("track");
                feature.setValue(track);
                inst.addFeature(feature);
                // the generating spline index is the ground-truth label used by BCubed
                inst.setClassLabel(Integer.toString(i));
                ds.add(inst);
            }
        }

        log.info("Clustering {} tracks", ds.size());
        // Cluster the tracks in N clusters.
        // NOTE(review): the meaning of KMeans' 10/false and TrackDistance's 1.0 is defined elsewhere — confirm.
        KMeans clusterer = new KMeans(N, 10, false);
        clusterer.registerFeatureType("track", TrackCentroid.class, new TrackDistance(1.0));
        ClusterResult clusterResults = clusterer.doCluster(ds);
        frame.addClusters(clusterResults);

        log.info("Validating track clusters");
        BCubed validator = new BCubed();
        validator.validate(frame.getClusters());
        log.info("Precision: {} ", validator.getPrecision());
        log.info("Recall: {} ", validator.getRecall());
        log.info("BCubed: {} ", validator.getBCubed());

        if (visible) {
            frame.showAndWait();
        }
    }

    /**
     * Fits a cubic B-spline through 3 to 6 random control points that lie inside
     * {@code bounds}, with randomly spaced knot times normalized to [0, 1].
     */
    private CubicBSpline randomBoundedSpline(Rectangle2D bounds) {
        double width = bounds.getWidth();
        double height = bounds.getHeight();
        double minX = bounds.getMinX();
        double minY = bounds.getMinY();
        int N = (int) Math.round(3 + Math.random() * 3);

        // Non-decreasing times starting at 0; normalized below so the last one is 1.
        double[] times = new double[N];
        for (int i = 1; i < N; ++i) {
            times[i] = times[i - 1] + Math.random();
        }
        double totalTime = times[N - 1];

        List<Vector> points = new ArrayList<Vector>(N);
        for (int i = 0; i < N; ++i) {
            times[i] = times[i] / totalTime;
            points.add(new Vector(minX + Math.random() * width, minY + Math.random() * height));
        }
        return CubicBSpline.fit(times, points);
    }

    /**
     * Samples a track along {@code basis} at randomly spaced parameter values.
     *
     * <p>The sample count is {@code N} jittered by up to roughly +/-50% and clamped to
     * {@link #MIN_SAMPLES_PER_TRACK}. The end samples (t = 0 and t = 1) are skipped,
     * so the returned list holds two fewer points than the sample count.
     *
     * @param basis spline to sample
     * @param N     nominal sample count
     * @return positions built from the spline's (x, y) as (longitude, latitude)
     */
    private List<Position> randomPoints(CubicBSpline basis, int N) {
        // (r1 - r2) lies in (-1, 1), so the jitter is within +/- N/2.
        int samples = (int) Math.round(N + N * (Math.random() - Math.random()) / 2.0);
        samples = Math.max(samples, MIN_SAMPLES_PER_TRACK);

        double[] times = new double[samples];
        for (int i = 1; i < samples; ++i) {
            times[i] = times[i - 1] + Math.random();
        }
        double totalTime = times[samples - 1];

        List<Position> points = new ArrayList<Position>(samples);
        for (int i = 1; i < samples - 1; ++i) {
            Vector point = basis.getPoint(times[i] / totalTime);
            points.add(new Position(point.coord(0), point.coord(1)));
        }
        return points;
    }

    /**
     * Frame that draws the ground-truth splines (thick) and the clustered tracks
     * (thin, colored by assigned cluster). Spline {@code i} and cluster {@code i}
     * share a color. Colors are generated once, so repaints are stable.
     */
    private static class TrackFrame extends TestFrame {
        private static final long serialVersionUID = 1L;

        /** Samples per spline when the drawing bounds must be computed from the splines. */
        private static final int BOUNDS_SAMPLES = 100;

        private final List<CubicBSpline> _splines;
        private final List<Cluster> _clusters;
        /** Lazily grown palette indexed by spline/cluster index. */
        private final List<Color> _colors;
        private final Random _rnd;
        private Rectangle2D _drawingBounds;

        public TrackFrame() {
            _splines = new ArrayList<CubicBSpline>();
            _clusters = new ArrayList<Cluster>();
            _colors = new ArrayList<Color>();
            _rnd = new Random();
            _drawingBounds = null;
        }

        /** Fixes the data-space region to draw; {@code null} means derive it from the splines. */
        public void setDrawingBounds(Rectangle2D bounds) {
            _drawingBounds = bounds;
        }

        public void addSpline(CubicBSpline spline) {
            _splines.add(spline);
        }

        public void addClusters(Iterable<Cluster> clusters) {
            for (Cluster cluster : clusters) {
                _clusters.add(cluster);
            }
        }

        public List<Cluster> getClusters() {
            return _clusters;
        }

        /**
         * Returns the explicit drawing bounds if set. Otherwise returns the bounding box of
         * all splines sampled at {@link #BOUNDS_SAMPLES}+1 evenly spaced parameters in [0, 1],
         * or a unit box when there are no splines.
         */
        private Rectangle2D getDrawingBounds() {
            if (null != _drawingBounds) {
                return _drawingBounds;
            }
            if (_splines.isEmpty()) {
                return new Rectangle2D.Double(0.0, 0.0, 1.0, 1.0);
            }

            double minX = Double.MAX_VALUE;
            double maxX = -Double.MAX_VALUE;
            double minY = Double.MAX_VALUE;
            double maxY = -Double.MAX_VALUE;
            for (CubicBSpline s : _splines) {
                // Integer stepping so t == 1.0 is sampled exactly (t += 0.01 overshoots).
                for (int k = 0; k <= BOUNDS_SAMPLES; ++k) {
                    Vector v = s.getPoint(k / (double) BOUNDS_SAMPLES);
                    minX = Math.min(minX, v.coord(0));
                    maxX = Math.max(maxX, v.coord(0));
                    minY = Math.min(minY, v.coord(1));
                    maxY = Math.max(maxY, v.coord(1));
                }
            }
            return new Rectangle2D.Double(minX, minY, maxX - minX, maxY - minY);
        }

        /** Returns the stable color for spline/cluster {@code index}, generating new ones on demand. */
        private Color colorFor(int index) {
            while (_colors.size() <= index) {
                _colors.add(new Color(_rnd.nextInt(256), _rnd.nextInt(256), _rnd.nextInt(256)));
            }
            return _colors.get(index);
        }

        @Override
        public void paint(Graphics g) {
            Graphics2D g2 = (Graphics2D) g;
            Dimension size = getSize();

            // Pad the data bounds by 10% on each side.
            Rectangle2D data = getDrawingBounds();
            Rectangle2D bounds = new Rectangle2D.Double(
                    data.getMinX() - data.getWidth() / 10.0,
                    data.getMinY() - data.getHeight() / 10.0,
                    data.getWidth() * 1.2,
                    data.getHeight() * 1.2);
            double pixelsPerUnit = Math.min(size.width / bounds.getWidth(), size.height / bounds.getHeight());
            // Screen y grows downward, so y is measured from the top (max) edge.
            double zeroX = bounds.getMinX();
            double zeroY = bounds.getMaxY();

            g2.setColor(Color.WHITE);
            g2.fillRect(0, 0, size.width, size.height);

            // draw the ground truth splines for reference
            for (int i = 0; i < _splines.size(); ++i) {
                g2.setColor(colorFor(i));
                g2.setStroke(new BasicStroke(4.0f));
                _splines.get(i).draw(g2, zeroX, zeroY, pixelsPerUnit, 100);
            }

            // draw the actual tracks - color by cluster they are assigned to
            for (int c = 0; c < _clusters.size(); ++c) {
                g2.setColor(colorFor(c).darker());
                g2.setStroke(new BasicStroke(1.0f));
                for (Instance inst : _clusters.get(c).getMembers()) {
                    Track track = ((TrackFeature) inst.getFeature("track")).getValue();
                    drawTrack(g2, track, zeroX, zeroY, pixelsPerUnit);
                }
            }
        }

        /** Draws {@code track} as a polyline using the current color and stroke of {@code g2}. */
        private static void drawTrack(Graphics2D g2, Track track, double zeroX, double zeroY, double pixelsPerUnit) {
            Point last = null;
            for (Position p : track.getPoints()) {
                Point pt = new Point(
                        (int) Math.round((p.getLongitude() - zeroX) * pixelsPerUnit),
                        (int) Math.round((zeroY - p.getLatitude()) * pixelsPerUnit));
                if (null != last) {
                    g2.drawLine(last.x, last.y, pt.x, pt.y);
                }
                last = pt;
            }
        }
    }

    public static void main(String[] args) {
        TestTrackCluster test = new TestTrackCluster();
        test.clusterRandomTracks(5, 10, 30, true);
    }
}