/* * DensityMap.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.app.tools; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; import dr.evomodel.continuous.MultivariateDiffusionModel; import dr.inference.model.MatrixParameter; import dr.math.matrixAlgebra.Matrix; import dr.util.TIFFWriter; import java.io.DataOutputStream; import java.io.FileOutputStream; /** * @author Marc Suchard * Based on the class CalculateSplitRates in JEBL */ class DensityMap { private final String SEP = "\t"; private final String DBL = "%5.4f"; private int binX; private int binY; private int[][] data; private double[] average; private double[] testStatistic; private int[] counts; private int[] singleTreeCounts; private int count; private double startX; private double endX; private double startY; private double endY; private double scaleX; private double scaleY; private double minX = Double.POSITIVE_INFINITY; private double maxX = Double.NEGATIVE_INFINITY; private double minY = Double.POSITIVE_INFINITY; private double maxY = Double.NEGATIVE_INFINITY; private double upperX; private double lowerX; private double upperY; private double lowerY; private boolean jointDensity = false; private boolean isCalibrated = false; private int slice; public DensityMap(int slice, int binX, int binY, double upperX, double lowerX, double upperY, double lowerY, boolean logScale) { setUp(slice, binX, binY, upperX, lowerX, upperY, lowerY, logScale); } private void setUp(int slice, int binX, int binY, double upperX, double lowerX, double upperY, double lowerY, boolean logScale) { this.slice = slice; this.binX = binX; this.binY = binY; data = new int[binX][binY]; counts = new int[binX]; count = 0; this.upperX = upperX; this.lowerX = lowerX; this.upperY = upperY; this.lowerY = lowerY; this.logScale = logScale; } public DensityMap(int binX, int binY, double[] ptX, double[] ptY) { // double maxX = Double.MIN_VALUE; // double maxY = Double.MIN_VALUE; // double minX = Double.MAX_VALUE; // double minY = Double.MAX_VALUE; if (ptX.length != ptY.length) { throw new RuntimeException("Attempting to construct an unbalanced DensityMap"); } int len = ptX.length; for (int i = 0; i < len; i++) { if (ptX[i] < minX) minX = ptX[i]; if (ptY[i] < minY) minY = ptY[i]; if (ptX[i] > maxX) maxX = ptX[i]; if (ptY[i] > maxY) maxY = ptY[i]; } double scaleX = (maxX - minX) / (double) (binX - 1); // -1 necessary to ensure that maxValue falls in the last box double scaleY = (maxY - minY) / (double) (binY - 1); minX -= 1.5 * scaleX; // Ensures that all boarder cells have zero mass maxX += 0.5 * scaleX; minY -= 1.5 * scaleY; maxY += 0.5 * scaleY; setUp(0, binX, binY, maxX, minX, maxY, minY, false); // this.minX = checkCalibration(); for (int i = 0; i < len; i++) { addPoint(ptX[i], ptY[i]); // System.err.println("Added: "+ptX[i]+":"+ptY[i]); } } public double[] getXMidPoints() { if (!isCalibrated) throw new RuntimeException("Density map is not calibrated"); double[] pts = new double[binX]; pts[0] = startX + 0.5 * scaleX; for (int i = 1; i < binX; i++) pts[i] = pts[i - 1] + scaleX; return pts; } public double[] getYMidPoints() { if (!isCalibrated) throw new RuntimeException("Density map is not calibrated"); double[] pts = new double[binY]; pts[0] = startY + 0.5 * scaleY; for (int i = 1; i < binY; i++) pts[i] = pts[i - 1] + scaleY; return pts; } private boolean logScale = false; public void setLogScale(boolean logscale) { this.logScale = logscale; } private double transform(double d) { if (logScale) return Math.log(d); return d; } public void calibrate(Tree tree, String attributeName) { boolean foundAttribute = false; if (isCalibrated) { throw new RuntimeException("Already calibrated"); } if (jointDensity) { throw new RuntimeException("Already calibrated as a joint density map"); } double height = tree.getNodeHeight(tree.getRoot()); if (height > maxX) { maxX = height; } minX = 0.0; for (int i = 0; i < tree.getNodeCount(); i++) { NodeRef node = tree.getNode(i); if (node != tree.getRoot()) { Double value = (Double) tree.getNodeAttribute(node, attributeName); if (value != null) { value = transform(value); if (value < minY) minY = value; if (value > maxY) maxY = value; foundAttribute = true; } } } if (!foundAttribute) { throw new RuntimeException("Can't find any attributes, " + attributeName + ", in tree " + tree.getId()); } // System.out.println("Calibrated to:"); // System.out.println("\tminX = "+minX); // System.out.println("\tmaxX = "+maxX); // System.out.println("\tminY = "+minY); // System.out.println("\tmaxY = "+maxY); } public void calibrate(Tree tree, String attributeName1, String attributeName2) { boolean foundAttribute1 = false; boolean foundAttribute2 = false; jointDensity = true; if (isCalibrated) { throw new RuntimeException("Already calibrated"); } // double height = tree.getNodeHeight(tree.getRoot()); // if (height > maxX) { // maxX = height; // } for (int i = 0; i < tree.getNodeCount(); i++) { NodeRef node = tree.getNode(i); if (node != tree.getRoot()) { Double value = (Double) tree.getNodeAttribute(node, attributeName1); if (value != null) { value = transform(value); if (value < minX) minX = value; if (value > maxX) maxX = value; foundAttribute1 = true; } value = (Double) tree.getNodeAttribute(node, attributeName2); if (value != null) { value = transform(value); if (value < minY) minY = value; if (value > maxY) maxY = value; foundAttribute2 = true; } } } if (!foundAttribute1) { throw new RuntimeException("Can't find any attributes, " + attributeName1 + ", in tree " + tree.getId()); } if (!foundAttribute2) { throw new RuntimeException("Can't find any attributes, " + attributeName2 + ", in tree " + tree.getId()); } // System.err.printf("Calibrated: minY = %3.2f, maxY = %3.2f, minX = %3.2f, maxX = %3.2f\n",minY,maxY,minX,maxX); // System.exit(-1); } public void addTree(Tree tree, String attributeName) { checkCalibration(); for (int i = 0; i < tree.getNodeCount(); i++) { NodeRef node = tree.getNode(i); if (node != tree.getRoot()) { Double value = (Double) tree.getNodeAttribute(node, attributeName); if (value != null) { value = transform(value); addBranch(tree.getNodeHeight(node), tree.getNodeHeight(tree.getParent(node)), value); } } } } private void checkCalibration() { if (!isCalibrated) { startX = minX; if (lowerX != Double.NEGATIVE_INFINITY) { startX = lowerX; } endX = maxX; if (upperX != Double.POSITIVE_INFINITY) { endX = upperX; } startY = minY; if (lowerY != Double.NEGATIVE_INFINITY) { startY = lowerY; } endY = maxY; if (upperY != Double.POSITIVE_INFINITY) { endY = upperY; } scaleX = (endX - startX) / (double) (binX - 1); // -1 necessary to ensure that maxValue falls in the last box scaleY = (endY - startY) / (double) (binY - 1); isCalibrated = true; } } public void addTree(Tree tree, double sampleTime, String attributeName1, String attributeName2) { checkCalibration(); double[][] variance = null; Object[] obj = (Object[]) tree.getAttribute(MultivariateDiffusionModel.PRECISION_TREE_ATTRIBUTE); if (obj != null) { variance = new Matrix( MatrixParameter.parseFromSymmetricDoubleArray(obj).getParameterAsMatrix() ).inverse().toComponents(); } for (int i = 0; i < tree.getNodeCount(); i++) { NodeRef node = tree.getNode(i); if (node != tree.getRoot()) { NodeRef parent = tree.getParent(node); double t1 = tree.getNodeHeight(node); double t2 = tree.getNodeHeight(parent); if (t1 <= sampleTime && t2 >= sampleTime) { Double valueX1 = transform((Double) tree.getNodeAttribute(node, attributeName1)); Double valueY1 = transform((Double) tree.getNodeAttribute(node, attributeName2)); Double valueX2 = transform((Double) tree.getNodeAttribute(parent, attributeName1)); Double valueY2 = transform((Double) tree.getNodeAttribute(parent, attributeName2)); if (valueX1 != null && valueY1 != null && valueX2 != null && valueY2 != null) { addPoint(sampleTime, t1, t2, valueX1, valueY1, valueX2, valueY2, variance); } } } } } public int[][] getDensityMap() { return data; } private void addBranch(double start, double end, double y) { if (start >= endX || end <= startX) { // branch is outside bounds... return; } if (y > endY || y < startY) { // value is outside bounds... return; } // clip the branch to the bounds if (start < startX) { start = startX; } if (end > endX) { end = endX; } // determine bin for y int Y = (int) ((y - startY) / scaleY); // determine start and end bin for x int START = (int) ((start - startX) / scaleX); int END = (int) ((end - startX) / scaleX); // System.out.println(start+":"+end+" -> "+START+":"+END); for (int i = START; i <= END; i++) { data[i][Y] += 1; counts[i] += 1; // average[i] += y; } } private void addPoint(double x, double y) { if (x > endX || x < startX || y > endY || y < startY) return; int X = (int) ((x - startX) / scaleX); int Y = (int) ((y - startY) / scaleY); data[X][Y] += 1; count += 1; } private void addPoint(double t, double startTime, double endTime, double x0, double y0, double x1, double y1, double[][] variance) { double t0 = t - startTime; double t1 = endTime - t; double x, y; if (t0 == 0) { x = x0; y = y0; } else if (t1 == 0) { x = x1; y = y1; } else { x = ((x0 / t0) + (x1 / t1)) / ((1.0 / t0) + (1.0 / t1)); y = ((y0 / t0) + (y1 / t1)) / ((1.0 / t0) + (1.0 / t1)); if (variance != null) { // todo add stochastic noise } } if (x > endX || x < startX || y > endY || y < startY) { // point is outside bounds... return; } // determine bin for x int X = (int) ((x - startX) / scaleX); // determine bin for y int Y = (int) ((y - startY) / scaleY); data[X][Y] += 1; count += 1; } public String toString() { return toString(true); } public void writeAsTIFF(String fileName) { double[][] matrix = normalize(255); try { DataOutputStream tiffOut = new DataOutputStream(new FileOutputStream(fileName)); TIFFWriter.writeDoubleArray(tiffOut, matrix); tiffOut.close(); } catch (Exception e) { throw new RuntimeException(e.getMessage()); } } public double[][] getNormalizedDensity(double max) { if (!isCalibrated) throw new RuntimeException("Density map is not yet calibrated"); return normalize(max); } private double[][] normalize(double max) { double[][] matrix = new double[binX][binY]; double maxValue = 0; for (int i = 0; i < binY; i++) { for (int j = 0; j < binX; j++) { if (data[j][i] > maxValue) { maxValue = data[j][i]; } } } for (int i = 0; i < binY; i++) { for (int j = 0; j < binX; j++) { matrix[j][i] = ((double) data[j][i] / maxValue) * max; // double dblCount; // if (jointDensity) { // dblCount = (double) count; // } else { // dblCount = (double) counts[j]; // } // if (dblCount > 0) { // matrix[j][i] = (double) data[j][i] / dblCount * max; // } else { // matrix[j][i] = 0.0; // } } } return matrix; } public String toString(boolean printHeaders) { StringBuilder sb = new StringBuilder(); if (printHeaders) { sb.append(String.format("%7.5f", (double) slice)); // todo should return 3rd dimension coordinate for (int i = 0; i < binX; i++) { sb.append(SEP); sb.append(String.format("%7.5f", startX + scaleX * i)); } sb.append("\n"); } double[][] matrix = normalize(1.0); for (int i = 0; i < binY; i++) { if (printHeaders) sb.append(String.format("%7.5f", startY + scaleY * i)); for (int j = 0; j < binX; j++) { if (j > 0 || printHeaders) sb.append(SEP); // double dblCount; // if (jointDensity) { // dblCount = (double) count; // } else { // dblCount = (double) counts[j]; // } // if (dblCount > 0) { // sb.append(String.format(DBL, // (double) data[j][i] / dblCount // )); // } else { // sb.append(String.format(DBL, 0.0)); // } sb.append(String.format(DBL, matrix[j][i])); } sb.append("\n"); } return sb.toString(); } }