package statalign.postprocess.plugins; import java.awt.BorderLayout; import java.awt.Color; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; import java.util.HashMap; import javax.swing.Icon; import javax.swing.ImageIcon; import javax.swing.JPanel; import javax.swing.JScrollPane; import org.apache.commons.math3.util.Pair; import statalign.base.CircularArray; import statalign.base.InputData; import statalign.base.State; import statalign.base.Utils; import statalign.postprocess.Postprocess; import statalign.postprocess.Track; import statalign.postprocess.gui.AlignmentGUI; public class MpdAlignment extends statalign.postprocess.Postprocess { public String title; public int frequency = 5; JPanel pan; AlignmentGUI gui; //private boolean sampling = true; CurrentAlignment curAlig; ColumnNetwork network; Column firstVector, lastVector; int sizeOfAlignments; int[] firstDescriptor; //String t[][]; String[] sequences; String[] viterbialignment; double[] decoding; String[] alignment; String[] sequenceNames; InputData input; public MpdAlignment(){ screenable = true; outputable = true; postprocessable = true; postprocessWrite = true; rnaAssociated = false; } @Override public JPanel getJPanel() { pan = new JPanel(new BorderLayout()); return pan; } @Override public Icon getIcon() { // return new ImageIcon("icons/MPD.gif"); return new ImageIcon(ClassLoader.getSystemResource("icons/MPD.gif")); } @Override public String getTabName() { return "Summary alignment"; } @Override public String getTip() { return "Summary (consensus) alignment"; } @Override public double getTabOrder() { return 6.0d; } @Override public String getFileExtension() { return "mpd.ali"; } @Override public ArrayList<String> getAdditionalFileExtensions() { ArrayList<String> result = new ArrayList<String>(); result.add("mpd.scores"); return result; } @Override public String[] getDependencies() { return new String[] { "statalign.postprocess.plugins.CurrentAlignment" }; } @Override public boolean createsMultipleOutputFiles() { return true; } @Override public void refToDependencies(Postprocess[] plugins) { curAlig = (CurrentAlignment) plugins[0]; curAlig.mpdAli = this; } static Comparator<String[]> compStringArr = new Comparator<String[]>() { @Override public int compare(String[] a1, String[] a2) { return a1[0].compareTo(a2[0]); }}; @Override public void beforeFirstSample(InputData input) { if(show) { pan.removeAll(); title = input.title; JScrollPane scroll = new JScrollPane(); scroll.setViewportView(gui = new AlignmentGUI(title,input.model,this));//, mcmc.tree.printedAlignment())); pan.add(scroll, BorderLayout.CENTER); //System.out.println("Mpd Alignment parent: " + pan.getParent()); pan.getParent().validate(); } this.input = input; sizeOfAlignments = input.seqs.size(); alignment = new String[sizeOfAlignments]; sequenceNames = new String[sizeOfAlignments]; if(show) { gui.alignment = alignment; gui.sequenceNames = sequenceNames; } //t = new String[sizeOfAlignments][]; sequences = null; viterbialignment = new String[sizeOfAlignments]; network = new ColumnNetwork(); firstDescriptor = new int[sizeOfAlignments]; Arrays.fill(firstDescriptor, -1); firstVector = network.add(firstDescriptor); lastVector = null; } @Override public void newSample(State state, int no, int total) { //System.out.println(curAlig); //System.out.println(curAlig.leafAlignment); if (state.isBurnin) return; for(int i = 0; i < curAlig.leafAlignment.length; i++){ if (curAlig == null || curAlig.leafAlignment[i] == null) { System.out.println(); } //t[i] = curAlig.leafAlignment[i].split("\t"); } //Arrays.sort(t, compStringArr); int[] previousDescriptor = firstDescriptor; int i, j, len = curAlig.leafAlignment[0].length(); for(j = 0; j < len; j++){ int[] nextDescriptor = new int[sizeOfAlignments]; boolean allGap = true; for(int k = 0; k < sizeOfAlignments; k++){ if(curAlig.leafAlignment[k].charAt(j) == '-') nextDescriptor[k] = ColumnKey.colNext(previousDescriptor[k]); else { nextDescriptor[k] = ColumnKey.colNext(previousDescriptor[k])+1; allGap = false; } } if(!allGap) network.add(nextDescriptor);//[j]); previousDescriptor = nextDescriptor; }//j (length of alignments) if(no == 0) { // add last vector once only int[] lastDescriptor = new int[sizeOfAlignments]; for(j = 0; j < sizeOfAlignments; j++){ lastDescriptor[j] = ColumnKey.colNext(previousDescriptor[j])+1; } lastVector = network.add(lastDescriptor); } if(no == 0 || (total-1-no) % frequency == 0) { network.updateViterbi(no+1); //System.out.println("sequences first: "+sequences); if(sequences == null) { sequences = new String[sizeOfAlignments]; for(i = 0; i < sizeOfAlignments; i++){ sequences[i] = ""; for(j = 0; j < len; j++){ if(curAlig.leafAlignment[i].charAt(j) != '-'){ sequences[i] += curAlig.leafAlignment[i].charAt(j); } } } } for(i = 0; i < sizeOfAlignments; i++) viterbialignment[i] = ""; Column actualVector = lastVector.viterbi; ArrayList<Integer> posteriorList = new ArrayList<Integer>(); while(!actualVector.equals(firstVector)){ int[] desc = actualVector.key.desc; posteriorList.add(new Integer(actualVector.count)); for(i = 0; i < desc.length; i++){ if((desc[i] & 1) == 0){ viterbialignment[i] = "-"+viterbialignment[i]; } else{ viterbialignment[i] = sequences[i].charAt(desc[i] >> 1) + viterbialignment[i]; } } actualVector = actualVector.viterbi; } decoding = new double[posteriorList.size()]; for(i = 0; i < decoding.length; i++){ decoding[i] = (double)(posteriorList.get(posteriorList.size() - i - 1)).intValue()/(no+1); } for(i = 0; i < viterbialignment.length; i++){ //alignment[i] = t[i][0]+"\t"+viterbialignment[i]; alignment[i] = viterbialignment[i]; sequenceNames[i] = curAlig.seqNames[i]; } // sort alignment lexicographically // TODO sort oder is parameter (alternatives: original, tree, lexico) //Arrays.sort(alignment); if(show) { gui.decoding = decoding; gui.alignment = alignment; gui.sequenceNames = sequenceNames; gui.repaint(); } } if(sampling){ try { String[] aln = Utils.alignmentTransformation(alignment, sequenceNames, alignmentType, input); for(i = 0; i < aln.length; i++){ file.write("Sample "+no+"\tMPD alignment:\t"+aln[i]+"\n"); } if(decoding != null){ for(i = 0; i < decoding.length; i++){ file.write("Sample "+no+"\tMPD alignment probabilities:\t"+decoding[i]+"\n"); } } else{ file.write("Sample "+no+"\tMPD alignment:\tNo posterior values so far\n"); } } catch (IOException e) { e.printStackTrace(); } } /*if(no == 0) { String [] [] inputAlignment = new String[input.seqs.sequences.size()][2]; for(int k = 0 ; k < inputAlignment.length ; k++) { inputAlignment[k][0] = input.seqs.seqNames.get(k); inputAlignment[k][1] = input.seqs.sequences.get(k); } Arrays.sort(inputAlignment, compStringArr); appendAlignment("reference", inputAlignment, new File(input.title+".samples"), false); appendAlignment(no+"", t, new File(input.title+".samples"), true); } else { appendAlignment(no+"", t, new File(input.title+".samples"), true); }*/ } @Override public void afterLastSample() { if (postprocessWrite && sequences != null) { try { String[] aln = Utils.alignmentTransformation(alignment, sequenceNames, alignmentType, input); for (int i = 0; i < aln.length; i++) { outputFile.write(aln[i] + "\n"); } outputFile.close(); //additionalOutputFiles.get(0).write("\n#scores\n\n"); if (decoding != null) { for (int i = 0; i < decoding.length; i++) { additionalOutputFiles.get(0).write(decoding[i]+""); for (Track track : tracks) { additionalOutputFiles.get(0).write("\t"+track.scores[i]); } additionalOutputFiles.get(0).write("\n"); } } else { additionalOutputFiles.get(0).write("No posterior values so far\n"); } additionalOutputFiles.get(0).close(); } catch (IOException e) { e.printStackTrace(); } /* appendAlignment("mpd", alignment, new File(input.title+".samples"), true); PPFold.saveToFile(Utils.alignmentTransformation(alignment, "Fasta", input), new File(input.title+".dat.res.mpd")); try { BufferedWriter buffer = new BufferedWriter(new FileWriter(new File(input.title+".samples"), true)); buffer.write("%posteriors\n"); for(int i = 0 ; i < decoding.length ; i++) { buffer.write(decoding[i]+"\n"); } buffer.close(); } catch(IOException ex) { ex.printStackTrace(); }*/ } } /* (non-Javadoc) * @see statalign.postprocess.Postprocess#setSampling(boolean) */ @Override public void setSampling(boolean enabled) { sampling = enabled; } double[] getPosteriorSplus() { return getPosterior(true); } double[] getPosterior() { return getPosterior(false); } double[] getPosterior(boolean useSplus) { int[] previousDescriptor = firstDescriptor; int j, len = curAlig.leafAlignment[0].length(); double[] scores = new double[len]; double max = 0.0; for(j = 0; j < len; j++){ int[] nextDescriptor = new int[sizeOfAlignments]; boolean allGap = true; for(int k = 0; k < sizeOfAlignments; k++){ if(curAlig.leafAlignment[k].charAt(j) == '-') nextDescriptor[k] = ColumnKey.colNext(previousDescriptor[k]); else { nextDescriptor[k] = ColumnKey.colNext(previousDescriptor[k])+1; allGap = false; } } scores[j] = 0; if(!allGap) { Column c = null; if (useSplus) c=network.splusMap.get(ColumnNetwork.splusKey(new ColumnKey(nextDescriptor))); else c=network.contMap.get(new ColumnKey(nextDescriptor)); if (c != null) scores[j] = (double) c.count; } if (scores[j] > max) max = scores[j]; previousDescriptor = nextDescriptor; } for (j=0; j<len; j++) scores[j] /= (double) max; return scores; } } class ColumnNetwork { HashMap<ColumnKey,Column> contMap = new HashMap<ColumnKey,Column>(); HashMap<ColumnKey,Column> splusMap = new HashMap<ColumnKey,Column>(); HashMap<ColumnKey,ArrayList<Column>> preMap = new HashMap<ColumnKey,ArrayList<Column>>(); HashMap<ColumnKey,ArrayList<Column>> postMap = new HashMap<ColumnKey,ArrayList<Column>>(); int numberOfEdges = 0; int numberOfNodes = 0; Column first; /** * Adds a new alignment column into the network. If already in the network, MyVector.count is incremented. * @param descriptor Alignment column represented by an array of signed integers * @return MyVector of alignment column or null if it was in network before */ Column add(int[] descriptor) { Column val; ColumnKey key = new ColumnKey(descriptor); ColumnKey splus = splusKey(key); if((val=splusMap.get(splus)) != null) { val.logcnt = Math.log(++val.count); } if((val=contMap.get(key)) != null) { val.logcnt = Math.log(++val.count); return null; } val = new Column(key); contMap.put(key, val); splusMap.put(splus,new Column(splus)); if(numberOfNodes == 0) first = val; numberOfNodes++; ArrayList<Column> arr; key = new ColumnKey(ColumnKey.pre(descriptor)); if((arr = preMap.get(key)) == null) preMap.put(key, arr = new ArrayList<Column>()); arr.add(val); if((arr = postMap.get(key)) != null) { for(Column postCol : arr) { postCol.inNum++; val.outgoing.add(postCol); numberOfEdges++; } } key = new ColumnKey(ColumnKey.post(descriptor)); if((arr = postMap.get(key)) == null) postMap.put(key, arr = new ArrayList<Column>()); arr.add(val); if((arr = preMap.get(key)) != null) { for(Column preCol : arr) { preCol.outgoing.add(val); val.inNum++; numberOfEdges++; } } return val; } public static ColumnKey splusKey(ColumnKey key) { int len = key.desc.length, d; int[] desc = key.desc, sdesc = new int[len]; for(int i = 0; i < len; i++) { d = desc[i]; sdesc[i] = ((d&1)==1)?d:0; } return new ColumnKey(sdesc); } void updateViterbi(int n) { double logN = Math.log(n); CircularArray<Column> calculable = new CircularArray<Column>(); calculable.push(first); first.score = logN; Column act; while((act = calculable.shift()) != null) { double myScore = act.score+act.logcnt-logN; for(Column outGoing : act.outgoing) { if(outGoing.score < myScore) { outGoing.score = myScore; outGoing.viterbi = act; } if(++outGoing.inReady == outGoing.inNum) calculable.push(outGoing); } act.inReady = 0; act.score = -1e300; } } } class Column { ArrayList<Column> outgoing = new ArrayList<Column>(); ColumnKey key; int inNum; int inReady; int count = 1; double logcnt = 0; double score = -1e300; Column viterbi; Column(ColumnKey _key) { key = _key; } } class ColumnKey { public int[] desc; ColumnKey(int[] arr) { desc = arr; } @Override public boolean equals(Object o) { return (o instanceof ColumnKey) && Arrays.equals(desc, ((ColumnKey)o).desc); } @Override public int hashCode() { return Arrays.hashCode(desc); } static int[] pre(int[] desc) { int[] ret = new int[desc.length]; for(int i = 0; i < desc.length; i++) ret[i] = (desc[i]+1) >> 1; return ret; } static int[] post(int[] desc) { int[] ret = new int[desc.length]; for(int i = 0; i < desc.length; i++) ret[i] = desc[i] >> 1; return ret; } static int colNext(int n) { return n + (n & 1); } }