package statalign.postprocess.plugins.structalign;
import java.awt.BorderLayout;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import javax.swing.Icon;
import javax.swing.ImageIcon;
import javax.swing.JPanel;
import org.apache.commons.math3.geometry.euclidean.threed.Rotation;
import org.apache.commons.math3.geometry.euclidean.threed.Vector3D;
import statalign.base.InputData;
import statalign.base.McmcStep;
import statalign.base.State;
import statalign.base.Utils;
import statalign.io.input.plugins.PDBReader;
import statalign.mcmc.McmcMove;
import statalign.model.ext.ModelExtManager;
import statalign.model.ext.ModelExtension;
import statalign.model.ext.plugins.StructAlign;
import statalign.postprocess.Postprocess;
import statalign.postprocess.gui.StructAlignTraceGUI;
/**
 * Postprocessing plugin that traces the parameters of the {@link StructAlign}
 * model extension.
 *
 * <p>It writes one tab-separated row of parameter values per sample to
 * {@code outputFile}, keeps a bounded history of parameter snapshots for the
 * {@link StructAlignTraceGUI} panel, and remembers the coordinates belonging to
 * the post-burn-in state with the largest {@code logLike} so that they can be
 * written out as a PDB file after the last sample.
 */
public class StructTrace extends Postprocess {

    /** Parameter snapshots recorded by {@link #newStep(McmcStep)}, oldest first. */
    List<StructAlignTraceParameters> parameterHistory;

    /** The StructAlign extension being traced; {@code null} until {@link #init} finds one. */
    public StructAlign structAlign;

    /** Burn-in length, copied from {@code inputData.pars.burnIn} in {@link #beforeFirstSample}. */
    public int burninLength;

    /**
     * Threshold used to bound {@link #parameterHistory}. It is lowered to the
     * burn-in length in {@link #beforeFirstSample} when the burn-in is shorter.
     */
    public int MAX_HISTORY_SIZE = 1000;

    /** A snapshot is recorded every {@code refreshRate} calls of {@link #newStep(McmcStep)}. */
    public int refreshRate;

    /** Largest {@code state.logLike} seen so far outside the burn-in. */
    double maxLikelihood = Double.NEGATIVE_INFINITY;

    /** Sample number at which {@link #maxLikelihood} was observed. */
    int sampleNumberMLE;

    /** Private copy of {@code structAlign.rotCoords} taken when {@link #maxLikelihood} was last updated. */
    double[][][] coorMLE;

    /** {@code state.seq} and {@code state.name}, captured the first time the maximum is updated. */
    String[] seqs, names;

    /** Panel handed to the GUI framework by {@link #getJPanel()}. */
    JPanel pan;

    int current;

    /** Number of {@link #newStep(McmcStep)} calls since {@link #beforeFirstSample}. */
    private int count;

    private StructAlignTraceGUI gui;

    public StructTrace() {
        screenable = true;
        outputable = true;
        postprocessable = true;
        postprocessWrite = false;
        selected = false;
        active = false;
    }

    /** @return the recorded parameter snapshots (the live list, not a copy) */
    public List<StructAlignTraceParameters> getParameterHistory() {
        return parameterHistory;
    }

    /** @return the number of steps counted since {@link #beforeFirstSample} */
    public int getCount() {
        return count;
    }

    /**
     * Looks up the {@link StructAlign} extension among the registered plugins,
     * connects this trace to it, and activates this plugin only if that
     * extension is active.
     *
     * @param modelExtMan manager whose plugin list is searched
     */
    @Override
    public void init(ModelExtManager modelExtMan) {
        for (ModelExtension modExt : modelExtMan.getPluginList()) {
            if (modExt instanceof StructAlign) {
                structAlign = (StructAlign) modExt;
                structAlign.connectStructTrace(this);
            }
        }
        // Without a StructAlign extension there is nothing to trace: stay
        // inactive instead of failing with a NullPointerException.
        active = structAlign != null && structAlign.isActive();
        postprocessWrite = active;
    }

    @Override
    public String getTabName() {
        return "Structural parameters";
    }

    @Override
    public double getTabOrder() {
        return 5.0d;
    }

    @Override
    public Icon getIcon() {
        return new ImageIcon(ClassLoader.getSystemResource("icons/loglikelihood1.gif"));
    }

    /** Creates a fresh, empty panel; it is populated in {@link #beforeFirstSample}. */
    @Override
    public JPanel getJPanel() {
        pan = new JPanel(new BorderLayout());
        return pan;
    }

    @Override
    public String getTip() {
        return "StructAlign parameter values";
    }

    @Override
    public String getFileExtension() {
        return "struct.params";
    }

    /** Intentionally a no-op. */
    @Override
    public void setSampling(boolean enabled) {
    }

    /**
     * Writes the header row of the parameter file, (re)builds the GUI panel
     * when the plugin is shown, and resets the history bookkeeping.
     *
     * @param inputData supplies the burn-in length via {@code pars.burnIn}
     */
    @Override
    public void beforeFirstSample(InputData inputData) {
        if (!active) {
            return;
        }
        try {
            // One column per move that exposes a parameter, in move order.
            for (McmcMove mcmcMove : structAlign.getMcmcMoves()) {
                if (mcmcMove.getParam() != null) {
                    outputFile.write(mcmcMove.name + "\t");
                }
            }
            outputFile.write("\n");
        } catch (IOException e) {
            // Report the failure like the other I/O handlers in this class
            // rather than silently producing a file without a header.
            e.printStackTrace();
        }
        if (show) {
            pan.removeAll();
            gui = new StructAlignTraceGUI(pan, this);
            pan.add(gui);
            pan.getParent().getParent().getParent().validate();
        }
        parameterHistory = new ArrayList<StructAlignTraceParameters>();
        burninLength = inputData.pars.burnIn;
        // Only shrink the bound for a positive burn-in; a zero burn-in used to
        // set it to 0 and make the division below throw ArithmeticException.
        if (burninLength > 0) {
            MAX_HISTORY_SIZE = Math.min(burninLength, MAX_HISTORY_SIZE);
        }
        current = 0;
        // The whole burn-in fits in one window; after that the window shifts.
        // Clamped to at least 1 because newStep() uses it as a modulus.
        refreshRate = (MAX_HISTORY_SIZE > 0)
                ? Math.max(1, burninLength / MAX_HISTORY_SIZE)
                : 1;
        count = 0;
    }

    /**
     * Remembers the given state if it is outside the burn-in and has a larger
     * {@code logLike} than any state seen so far.
     */
    private void doUpdate(State state, int sampleNumber) {
        if (!state.isBurnin && state.logLike > maxLikelihood) {
            maxLikelihood = state.logLike;
            sampleNumberMLE = sampleNumber;
            // Array clone() is shallow: copy every level so that later updates
            // of the model's coordinates cannot alter the stored snapshot.
            coorMLE = deepCopy(structAlign.rotCoords);
            if (seqs == null) {
                seqs = state.seq;
                names = state.name;
            }
        }
    }

    /** Returns an independent copy of a three-level array; {@code null} entries are preserved. */
    private static double[][][] deepCopy(double[][][] source) {
        if (source == null) {
            return null;
        }
        double[][][] copy = new double[source.length][][];
        for (int i = 0; i < source.length; i++) {
            if (source[i] == null) {
                continue;
            }
            copy[i] = new double[source[i].length][];
            for (int j = 0; j < source[i].length; j++) {
                if (source[i][j] != null) {
                    copy[i][j] = source[i][j].clone();
                }
            }
        }
        return copy;
    }

    /** Updates the best-state bookkeeping with sample number 0. */
    @Override
    public void newPeek(State state) {
        if (!active) {
            return;
        }
        doUpdate(state, 0);
    }

    /**
     * Appends the current parameter values as one row of the output file (when
     * {@code postprocessWrite} is set) and updates the best-state bookkeeping.
     *
     * @param state current state
     * @param no    sample number, stored if this state becomes the new maximum
     * @param total not used by this plugin
     */
    @Override
    public void newSample(State state, int no, int total) {
        if (!active) {
            return;
        }
        if (postprocessWrite) {
            try {
                for (McmcMove mcmcMove : structAlign.getMcmcMoves()) {
                    if (mcmcMove.getParam() != null) {
                        outputFile.write(mcmcMove.getParam().get() + "\t");
                    }
                }
                outputFile.write("\n");
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        doUpdate(state, no);
    }

    /**
     * Closes the parameter file, writes the stored best-state coordinates to
     * {@code <base file name>mle.super.pdb} if any were recorded, and prints
     * diagnostic information when {@code Utils.DEBUG} is set.
     */
    @Override
    public void afterLastSample() {
        if (!active) {
            return;
        }
        try {
            outputFile.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
        if (coorMLE != null) {
            FileWriter mle = null;
            try {
                mle = new FileWriter(getBaseFileName() + "mle.super.pdb");
                PDBReader.writePDB(coorMLE, seqs, names, mle);
            } catch (IOException e) {
                e.printStackTrace();
            } finally {
                // Close in finally so the file handle is released even when
                // writing the PDB fails part-way through.
                if (mle != null) {
                    try {
                        mle.close();
                    } catch (IOException e) {
                        e.printStackTrace();
                    }
                }
            }
        }
        if (Utils.DEBUG) {
            System.out.println("final rotation matrices:");
            // Starts at index 1; index 0 is skipped (presumably the reference
            // structure, which is not rotated -- confirm against StructAlign).
            for (int i = 1; i < structAlign.xlats.length; i++) {
                Rotation rot = new Rotation(new Vector3D(structAlign.axes[i]), structAlign.angles[i]);
                printMatrix(rot.getMatrix());
            }
            System.out.println("final translations:");
            for (int i = 0; i < structAlign.xlats.length; i++) {
                System.out.println(Arrays.toString(structAlign.xlats[i]));
            }
            System.out.println();
            System.out.println("Acceptance rates:");
            for (McmcMove mcmcMove : structAlign.getMcmcMoves()) {
                System.out.println(mcmcMove.name + "\t" + mcmcMove.acceptanceRate());
            }
        }
    }

    /**
     * Prints each row of the matrix on its own line, followed by a blank line.
     *
     * @param m matrix to print to standard output
     */
    public static void printMatrix(double[][] m) {
        for (int i = 0; i < m.length; i++) {
            System.out.println(Arrays.toString(m[i]));
        }
        System.out.println();
    }

    /**
     * Records a parameter snapshot every {@link #refreshRate} steps and
     * repaints the GUI when the plugin is shown.
     *
     * @param mcmcStep the step just performed; its {@code burnIn} flag is
     *                 passed on to the snapshot
     */
    @Override
    public void newStep(McmcStep mcmcStep) {
        if (!active) {
            return;
        }
        // A non-positive refresh rate would make the modulus below throw;
        // fall back to recording every step.
        int interval = (refreshRate > 0) ? refreshRate : 1;
        if (screenable && (count % interval == 0)) {
            StructAlignTraceParameters currentParameters =
                    new StructAlignTraceParameters(this, mcmcStep.burnIn);
            currentParameters.globalSigma = structAlign.globalSigma;
            if (count > 0 && !parameterHistory.isEmpty()) {
                currentParameters.setProposalFlags(
                        parameterHistory.get(parameterHistory.size() - 1));
            }
            // NOTE(review): the "<=" lets the history reach MAX_HISTORY_SIZE + 1
            // entries before the oldest one is dropped; kept as is because the
            // GUI that reads the history may rely on it.
            if (parameterHistory.size() <= MAX_HISTORY_SIZE) {
                parameterHistory.add(currentParameters);
            } else {
                parameterHistory.remove(0);
                parameterHistory.add(currentParameters);
            }
            if (show) {
                gui.repaint();
            }
        }
        ++count;
    }
}