/* * GPSkytrackAnalysis.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.coalescent; import dr.evolution.io.Importer; import dr.inference.model.Parameter; //import dr.evolution.io.NexusImporter; //import dr.evolution.io.TreeImporter; //import dr.evolution.tree.Tree; //import dr.inference.operators.CoercionMode; //import dr.inference.trace.AbstractTraceList; import dr.inference.trace.LogFileTraces; import dr.inference.trace.TraceException; import dr.evomodel.coalescent.operators.GaussianProcessSkytrackBlockUpdateOperator; //import dr.inference.trace.TraceFactory; import dr.stats.DiscreteStatistics; //import dr.util.FileHelpers; //import dr.util.HeapSort; import dr.util.TabularData; //import no.uib.cipr.matrix.SymmTridiagEVD; //import no.uib.cipr.matrix.*; import java.io.*; //import java.io.PrintWriter; import java.util.Arrays; import java.util.StringTokenizer; /** * @author Joseph Heled */ public class GPSkytrackAnalysis extends TabularData { // TabularData private final double[] xPoints; private final double[] means; private final double[] medians; private final double[] hpdLower; private final double[] hpdHigh; private final double [][] gValues; private final double [][] tValues; private final double [][] newGvalues; private final double [][] popValues; // final File gvalues = FileHelpers.getFile("gvalues.txt"); // final File locations = FileHelpers.getFile("locations.txt"); // private final double[] HPDLevels; private Parameter numGridPoints; // each bin covers xPoints[-1]/coalBins.length // private int[] coalBins; // private final boolean quantiles; // GaussianProcessSkytrackLikelihood gpLikelihood = (GaussianProcessSkytrackLikelihood) xo.getChild(GaussianProcessSkytrackLikelihood.class); // return new GaussianProcessSkytrackBlockUpdateOperator(gpLikelihood, weight, mode, scaleFactor, // maxIterations, stopValue); // TODO: Error in loadTraces() because a String {..} is being converted to "real/double" // To make my life more miserable I will not use logFileTraces class and do it by hand public GPSkytrackAnalysis(File log, double burnIn, Parameter numGridPoints) throws IOException, Importer.ImportException, TraceException { GaussianProcessSkytrackBlockUpdateOperator GPOperator=new GaussianProcessSkytrackBlockUpdateOperator(); this.numGridPoints=numGridPoints; LogFileTraces ltraces = new LogFileTraces(log.getCanonicalPath(), log); // ltraces.changeTraceType(1, TraceFactory.TraceType.STRING); ltraces.loadTraces(); // System.exit(-1); ltraces.setBurnIn(0); final int runLengthIncludingBurnin = ltraces.getStateCount(); int intBurnIn = (int) Math.floor(burnIn < 1 ? runLengthIncludingBurnin * burnIn : burnIn); final int nStates = runLengthIncludingBurnin - intBurnIn; ltraces.setBurnIn(intBurnIn * ltraces.getStepSize()); assert ltraces.getStateCount() == nStates; xPoints = new double[(int) numGridPoints.getParameterValue(0)+1]; means = new double[(int) numGridPoints.getParameterValue(0)+1]; medians = new double[(int) numGridPoints.getParameterValue(0)+1]; hpdHigh = new double[(int) numGridPoints.getParameterValue(0)+1]; hpdLower = new double[(int) numGridPoints.getParameterValue(0)+1]; int numbPointsColumn = -1; int gvaluesColumn=-1; int xvaluesColumn=-1; int lambdaColumn = -1; int precColumn = -1; int tmrcaColumn=-1; for (int n = 0; n < ltraces.getTraceCount(); ++n) { final String traceName = ltraces.getTraceName(n); System.err.println(traceName); if (traceName.equals("skyride.points")) { numbPointsColumn = n; } else if (traceName.equals("skyride.lambda_bound")) { lambdaColumn = n; } else if (traceName.equals("skyride.precision")) { precColumn = n; } else if (traceName.equals("skyride.tmrca")) { tmrcaColumn = n; } else if (traceName.equals("changePoints")){ xvaluesColumn=n; } else if (traceName.equals("Gvalues")){ gvaluesColumn=n; } } // System.err.println("columns"+tmrcaColumn+" tmrca"+xvaluesColumn+" and"+gvaluesColumn); if (numbPointsColumn < 0 || lambdaColumn < 0 || precColumn<0 || tmrcaColumn<0 || xvaluesColumn<0 || gvaluesColumn<0) { throw new TraceException("incorrect trace column names: unable to find correct columns for summary"); } // TODO: Check if it is ok to define the grid from 0 to max(TMRCA) always double binSize = 0; // double hSum = -0; // System.err.println("states"+nStates); int [] numPoints = new int[nStates]; double[] lambda = new double[nStates]; double[] kappa= new double[nStates]; double tmrca=0; // double binSize=0; double tempTmrca=0.0; int maxpts=0; for (int ns = 0; ns < nStates; ++ns) { lambda[ns]= (Double) ltraces.getTrace(lambdaColumn).getValue(ns); numPoints[ns]=(int)Math.round((Double) ltraces.getTrace(numbPointsColumn).getValue(ns)); kappa[ns]=(Double) ltraces.getTrace(precColumn).getValue(ns); tempTmrca=(Double) ltraces.getTrace(tmrcaColumn).getValue(ns); // System.err.println(tempTmrca); System.exit(-1); if (tempTmrca>tmrca){tmrca=tempTmrca;} if (numPoints[ns]>maxpts) {maxpts=numPoints[ns];} } binSize = tmrca / numGridPoints.getParameterValue(0); xPoints[0]=0.0; for (int np=1;np<xPoints.length;np++){ xPoints[np]=xPoints[np-1]+binSize; } gValues=new double[nStates][]; tValues=new double[nStates][]; newGvalues=new double[nStates][]; popValues=new double[(int) numGridPoints.getParameterValue(0)+1][]; readChain(gValues,"gvalues.txt"); readChain(tValues,"locations.txt"); for (int i=0;i<=numGridPoints.getParameterValue(0);i++){ popValues[i]=new double[nStates-1] ; } // for (int j=0;j<nStates-1;j++){ // newGvalues[j]=new double[numPoints[j]]; newGvalues[j]=GPOperator.getGPvaluesS(tValues[j], gValues[j], xPoints, kappa[j]); // popValues[j]=new double[nStates]; for (int i=0;i<=numGridPoints.getParameterValue(0);i++){ popValues[i][j]=(1+Math.exp(-newGvalues[j][i]))/lambda[j]; } } // //// // hpdLower = new double[HPDLevels.length][]; // hpdHigh = new double[HPDLevels.length][]; // for (int nx = 0; nx < xPoints.length; ++nx) { means[nx] = DiscreteStatistics.mean(popValues[nx]); medians[nx]=DiscreteStatistics.median(popValues[nx]); hpdLower[nx]=DiscreteStatistics.quantile(0.025,popValues[nx]); hpdHigh[nx]=DiscreteStatistics.quantile(0.975,popValues[nx]); } // // for (int i = 0; i < HPDLevels.length; ++i) { // if (quantiles) { // hpdLower[i][nx] = DiscreteStatistics.quantile((1 - HPDLevels[i]) / 2, popValues, indices); // hpdHigh[i][nx] = DiscreteStatistics.quantile((1 + HPDLevels[i]) / 2, popValues, indices); // } else { // final double[] hpd = DiscreteStatistics.HPDInterval(HPDLevels[i], popValues, indices); // hpdLower[i][nx] = hpd[0]; // hpdHigh[i][nx] = hpd[1]; // } // } // medians[nx] = DiscreteStatistics.median(popValues, indices); // } // // if( allDemoWriter != null ) { // for(double xPoint : xPoints) { // allDemoWriter.print(xPoint); // allDemoWriter.append(' '); // } // // for (int ns = 0; ns < nDataPoints; ++ns) { // allDemoWriter.println(); // for(double xPoint : xPoints) { // allDemoWriter.print(allDemog[ns].getDemographic(xPoint)); // allDemoWriter.append(' '); // } // } // allDemoWriter.close(); // } } public void readChain(double [][] current,String fileName){ try { BufferedReader br = new BufferedReader(new FileReader(fileName)); String line=null; int i=0; // System.err.println("will read line1"); while ((line=br.readLine())!=null){ String[] parts=line.split(" "); // System.err.println(i+"with cols:"+parts.length); current[i]=new double[parts.length]; for (int j=0;j<parts.length;j++){ current[i][j]=Double.parseDouble(parts[j]); } i++; } br.close(); } catch(java.io.IOException ioe){ System.err.println("IOException:"+ ioe.getMessage()); } } private final String[] columnNames = {"time", "mean", "median","lower","upper"}; public int nColumns() { return 5; } public String columnName(int nColumn) { // final int fixed = columnNames.length; // if (nColumn < fixed) { return columnNames[nColumn]; } // nColumn -= fixed; // if (nColumn < 2 * HPDLevels.length) { // final double p = HPDLevels[nColumn / 2]; // final String s = (nColumn % 2 == 0) ? "lower" : "upper"; // return (quantiles ? "cpd " : "hpd ") + s + " " + Math.round(p * 100); // } // assert (nColumn - 2 * HPDLevels.length) == 0; // return "bins"; // } public int nRows() { return (int) numGridPoints.getParameterValue(0)+1; } public Object data(int nRow, int nColumn) { switch (nColumn) { case 0: { if (nRow < xPoints.length) { return xPoints[nRow]; } break; } case 1: { if (nRow < means.length) { return means[nRow]; } break; } case 2: { if (nRow < medians.length) { return medians[nRow]; } break; } case 3: { if (nRow < hpdLower.length) { return hpdLower[nRow]; } break; } case 4: { if (nRow < hpdHigh.length) { return hpdHigh[nRow]; } break; } // default: { // final int j = nColumn - columnNames.length; // if (j < 2 * HPDLevels.length) { // if (nRow < xPoints.length) { // final int k = j / 2; // if (0 <= k && k < HPDLevels.length) { // if (j % 2 == 0) { // return hpdLower[k][nRow]; // } else { // return hpdHigh[k][nRow]; // } // } // } // } else { // if (nRow < coalBins.length) { // return coalBins[nRow]; // } // } // break; // } } return ""; } }