package statalign.postprocess.plugins.structalign; import java.awt.BorderLayout; import java.awt.Color; import java.io.FileWriter; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import javax.swing.Icon; import javax.swing.ImageIcon; import javax.swing.JPanel; import org.apache.commons.math3.geometry.euclidean.threed.Rotation; import org.apache.commons.math3.geometry.euclidean.threed.Vector3D; import org.apache.commons.math3.util.Pair; import statalign.base.InputData; import statalign.base.McmcStep; import statalign.base.State; import statalign.base.Utils; import statalign.mcmc.McmcMove; import statalign.model.ext.ModelExtManager; import statalign.model.ext.ModelExtension; import statalign.model.ext.plugins.StructAlign; import statalign.postprocess.Postprocess; import statalign.postprocess.Track; import statalign.postprocess.gui.StructAlignTraceGUI; import statalign.postprocess.plugins.CurrentAlignment; import statalign.postprocess.plugins.MpdAlignment; public class RmsdTrace extends Postprocess { public StructAlign structAlign; /** For adding to the MPD alignment panel in the GUI. */ Track rmsdTrack = new Track(Color.RED, new double[1]); Track bFactorTrack = new Track(Color.GREEN, new double[1]); double maxLikelihood = Double.NEGATIVE_INFINITY; int sampleNumberMLE; String[] alignMLE, alignMLENames; double[] rmsdMLE, bFactorMLE; double epsilonMLE; public double[][] distanceMatrix; /** Determines scaling for RMSD annotation above sequence in GUI */ double SCALE_FACTOR = 2.5; public String[] fullAlign; InputData input; public RmsdTrace() { screenable = false; outputable = true; postprocessable = true; postprocessWrite = false; selected = false; active = false; } @Override public String getFileExtension() { return "rmsd"; } @Override public void setSampling(boolean enabled) { } @Override public void init(ModelExtManager modelExtMan) { for(ModelExtension modExt : modelExtMan.getPluginList()) { if(modExt instanceof StructAlign) { structAlign = (StructAlign) modExt; structAlign.connectRmsdTrace(this); } } // Either this is active already because we're in GUI mode, // or it is not activated yet, and we only activate if printRmsd // is true. active |= structAlign.isActive() & structAlign.printRmsd; postprocessWrite = active & structAlign.printRmsd; } @Override public String[] getDependencies() { return new String[] { "statalign.postprocess.plugins.CurrentAlignment", "statalign.postprocess.plugins.MpdAlignment"}; } CurrentAlignment curAli; MpdAlignment mpdAli; @Override public void refToDependencies(Postprocess[] plugins) { curAli = (CurrentAlignment) plugins[0]; mpdAli = (MpdAlignment) plugins[1]; } @Override public void beforeFirstSample(InputData inputData) { // for(ModelExtension modExt : getModExtPlugins()) { // if(modExt instanceof StructAlign) { // structAlign = (StructAlign) modExt; // } // } if(!active) return; input = inputData; //double[] rad = calcGyration(); if(postprocessWrite) { int leaves = structAlign.coords.length; try { outputFile.write("# Each row contains pairwise mean-square deviations (msd_ij)\n"); outputFile.write("# branch length distances (t_ij), and sequence identities (seqID_ij)\n"); outputFile.write("# for each MCMC sample.\n"); for(int i = 0; i < leaves-1; i++) for(int j = i+1; j < leaves; j++) outputFile.write("msd" + i + "_" + j + "\t"); for(int i = 0; i < leaves-1; i++) for(int j = i+1; j < leaves; j++) outputFile.write("t" + i + "_" + j + "\t"); for(int i = 0; i < leaves-1; i++) for(int j = i+1; j < leaves; j++) outputFile.write("seqID" + i + "_" + j + "\t"); // for(int i = 0; i < rad.length; i++) // outputFile.write(mcmc.tree.names[i] + "\t"); // outputFile.write("\n"); // for(int i = 0; i < rad.length; i++) // outputFile.write(rad[i] + "\t"); outputFile.write("\n"); outputFile.flush(); } catch (IOException e){} } if (show) { //mpdAli.addTrack(rmsdTrack); curAli.addTrack(rmsdTrack); if (structAlign.localEpsilon) curAli.addTrack(bFactorTrack); } } @Override public void newPeek(State state) { if (!active) return; //if (show) { doUpdate(state,0); //} } private void doUpdate(State state, int sampleNumber) { updateTracks(curAli.showFullAlignment ? state.getFullAlign() : state.getLeafAlign()); if (!state.isBurnin && state.logLike > maxLikelihood) { maxLikelihood = state.logLike; sampleNumberMLE = sampleNumber; alignMLE = curAli.showFullAlignment ? state.getFullAlign().clone() : state.getLeafAlign().clone(); alignMLENames = state.name.clone(); rmsdMLE = rmsdTrack.scores.clone(); bFactorMLE = bFactorTrack.scores.clone(); epsilonMLE = structAlign.epsilon; } } @Override public void newSample(State state, int no, int total) { if(!active) return; if(postprocessWrite) { double[][] msd = calcMSD(); double[][] seqID = calcSeqID(); try { for(int i = 0; i < msd.length-1; i++) for(int j = i+1; j < msd.length; j++) outputFile.write(msd[i][j] + "\t"); for(int i = 0; i < msd.length-1; i++) for(int j = i+1; j < msd.length; j++) outputFile.write(distanceMatrix[i][j] + "\t"); for(int i = 0; i < msd.length-1; i++) for(int j = i+1; j < msd.length; j++) outputFile.write(seqID[i][j] + "\t"); outputFile.write("\n"); } catch (IOException e){ e.printStackTrace(); } } //if (show) { doUpdate(state,no); //} } @Override public void afterLastSample() { if(!active) return; if(postprocessWrite) { try { outputFile.close(); } catch (IOException e) { e.printStackTrace(); } try { if (rmsdMLE != null) { FileWriter mle = new FileWriter(getBaseFileName()+"mle." + getFileExtension()); mle.write("# Maximum likelihood = "+maxLikelihood+" at sample "+sampleNumberMLE+"\n"); mle.write("# RMSD\tAverage B-factor\n"); mle.flush(); for (int i=0; i<rmsdMLE.length; i++) { boolean allGap = true; for (int j=0; j<structAlign.rotCoords.length; j++) { // Assume first sequences are non-internals if (alignMLE[j].charAt(i) != '-') allGap = false; } if (!allGap) { mle.write(rmsdMLE[i]+""); if (structAlign.localEpsilon) mle.write("\t"+3*epsilonMLE*bFactorMLE[i]); mle.write("\n"); } } mle.close(); FileWriter aliMLE = new FileWriter(getBaseFileName()+"mle.ali"); // Form an array from the leaves of the alignment String[] aln = Utils.alignmentTransformation(alignMLE, alignMLENames, "Fasta", input); for (int i = 0; i < aln.length; i++) { aliMLE.write(aln[i] + "\n"); } aliMLE.close(); } } catch (IOException e) { e.printStackTrace(); } } } public static void printMatrix(double[][] m) { for(int i = 0; i < m.length; i++) System.out.println(Arrays.toString(m[i])); System.out.println(); } public double[][] calcMSD(){ double[][][] coor = structAlign.rotCoords; String[] align = structAlign.curAlign; int leaves = coor.length; boolean[] hasStructure = new boolean[leaves]; boolean igap, jgap; double[][] msd = new double[leaves][leaves]; for (int i=0; i<leaves; i++) { hasStructure[i] = (structAlign.coords[i]!=null); } for(int i = 0; i < leaves-1; i++){ for(int j = i+1; j < leaves; j++){ int ii = 0, jj = 0, n = 0; for(int k = 0; k < align[0].length(); k++){ igap = align[i].charAt(k) == '-'; jgap = align[j].charAt(k) == '-'; if(!igap & !jgap & hasStructure[i] & hasStructure[j]){ msd[i][j] += sqDistance(coor[i][ii], coor[j][jj]); n++; } ii += igap ? 0 : 1; jj += jgap ? 0 : 1; } msd[i][j] /= n; } } return msd; } public void updateTracks(String[] align){ double[][][] coor = structAlign.rotCoords; int leaves = coor.length; boolean igap, jgap; //int n = leaves * (leaves-1) / 2; // Number of pairwise comparisons int alignmentLength = align[0].length(); int leafAlignmentLength = align[0].length(); boolean[] allGapped = new boolean[alignmentLength]; boolean[] allGappedLeaf = new boolean[alignmentLength]; for(int k = 0; k < align[0].length(); k++){ allGapped[k] = true; for(int i = 0; i < align.length; i++){ allGapped[k] &= (align[i].charAt(k) == '-'); if (i < leaves) allGappedLeaf[k] &= (align[i].charAt(k) == '-'); } if (allGapped[k]) alignmentLength--; if (allGappedLeaf[k]) leafAlignmentLength--; } rmsdTrack.scores = new double[alignmentLength]; rmsdTrack.max = 0.0; rmsdTrack.min = Double.POSITIVE_INFINITY; rmsdTrack.mean = 0.0; if (structAlign.localEpsilon) { bFactorTrack.scores = new double[alignmentLength]; bFactorTrack.max = 0.0; bFactorTrack.min = Double.POSITIVE_INFINITY; bFactorTrack.mean = 0.0; } int[] index = new int[leaves]; for(int k = 0, kk=0; k < align[0].length(); k++){ int n=0; if (allGapped[k]) continue; for(int i = 0; i < leaves-1; i++){ for(int j = i+1; j < leaves; j++){ igap = align[i].charAt(k) == '-'; jgap = align[j].charAt(k) == '-'; if(!igap & !jgap & !(coor[i]==null) & !(coor[j]==null)){ rmsdTrack.scores[kk] += sqDistance(coor[i][index[i]], coor[j][index[j]]); ++n; } } } if (n > 0) rmsdTrack.scores[kk] /= n; if (n==0) rmsdTrack.scores[kk] = Double.NaN; if (!Double.isNaN(rmsdTrack.scores[kk])) { if (rmsdTrack.scores[kk] > 0) rmsdTrack.scores[kk] = Math.sqrt(rmsdTrack.scores[kk]); if (rmsdTrack.scores[kk] > rmsdTrack.max) rmsdTrack.max = rmsdTrack.scores[kk]; if (rmsdTrack.scores[kk] > 0 && rmsdTrack.scores[kk] < rmsdTrack.min) rmsdTrack.min = rmsdTrack.scores[kk]; rmsdTrack.mean += rmsdTrack.scores[kk] / leafAlignmentLength; } int nBfactor = 0; for(int i = 0; i < leaves; i++) { if (align[i].charAt(k) != '-') { if(structAlign.coords[i]==null) continue; if (structAlign.localEpsilon) bFactorTrack.scores[kk] += Math.pow(structAlign.bFactors[i][index[i]],2) * structAlign.epsilon; nBfactor++; index[i]++; } } if (structAlign.localEpsilon) { bFactorTrack.scores[kk] /= nBfactor; bFactorTrack.scores[kk] = Math.sqrt(bFactorTrack.scores[kk]); if (bFactorTrack.scores[kk] == 0) bFactorTrack.scores[kk] = Double.NaN; if (!Double.isNaN(bFactorTrack.scores[kk])) { if (bFactorTrack.scores[kk] > bFactorTrack.max) bFactorTrack.max = bFactorTrack.scores[kk]; if (bFactorTrack.scores[kk] > 0 && bFactorTrack.scores[kk] < bFactorTrack.min) bFactorTrack.min = bFactorTrack.scores[kk]; bFactorTrack.mean += bFactorTrack.scores[kk] / leafAlignmentLength; } } kk++; } } public double sqDistance(double[] x, double[] y){ double d = 0; for(int i = 0; i < x.length; i++) d += Math.pow(x[i] - y[i], 2.0); return d; } public double[] calcGyration(){ double[][][] coor = structAlign.coords; int leaves = coor.length; double[] radii = new double[leaves]; for(int i = 0; i < leaves; i++){ radii[i] = 0; // coordinates are centered in StructAlign.initRun() for(int j = 0; j < coor[i].length; j++) for(int k = 0; k < coor[i][0].length; k++) radii[i] += Math.pow(coor[i][j][k], 2.0); radii[i] /= coor[0].length; radii[i] = Math.pow(radii[i], 0.5); } return radii; } public double[][] calcSeqID(){ String[] align = structAlign.curAlign; int leaves = align.length; boolean igap, jgap; double[][] seqID = new double[leaves][leaves]; for(int i = 0; i < leaves-1; i++){ for(int j = i+1; j < leaves; j++){ double match = 0, id = 0; for(int k = 0; k < align[0].length(); k++){ igap = align[i].charAt(k) == '-'; jgap = align[j].charAt(k) == '-'; if(!igap & !jgap){ id += align[i].charAt(k) == align[j].charAt(k) ? 1 : 0; match++; } } seqID[i][j] = id / match; } } return seqID; } @Override public String getTabName() { // TODO Auto-generated method stub return "RMSD trace"; } @Override public Icon getIcon() { // TODO Auto-generated method stub return null; } @Override public JPanel getJPanel() { // TODO Auto-generated method stub return null; } @Override public String getTip() { // TODO Auto-generated method stub return ""; } }