package statalign.model.ext.plugins;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import javax.swing.ImageIcon;
import javax.swing.JComponent;
import javax.swing.JToggleButton;
import org.apache.commons.math3.geometry.euclidean.threed.Rotation;
import org.apache.commons.math3.geometry.euclidean.threed.Vector3D;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.util.MathArrays;
import statalign.base.InputData;
import statalign.base.Tree;
import statalign.base.Utils;
import statalign.base.Vertex;
import statalign.base.hmm.Hmm;
import statalign.io.DataType;
import statalign.io.ProteinSkeletons;
import statalign.io.RawSequences;
import statalign.mcmc.GammaPrior;
import statalign.mcmc.GammaProposal;
import statalign.mcmc.GaussianProposal;
import statalign.mcmc.HyperbolicPrior;
import statalign.mcmc.InverseGammaPrior;
import statalign.mcmc.McmcCombinationMove;
import statalign.mcmc.McmcMove;
import statalign.mcmc.ParameterInterface;
import statalign.mcmc.PriorDistribution;
import statalign.mcmc.UniformPrior;
import statalign.model.ext.ModelExtension;
import statalign.model.ext.plugins.structalign.AlignmentMove;
import statalign.model.ext.plugins.structalign.ContinuousPositiveStructAlignMove;
import statalign.model.ext.plugins.structalign.Funcs;
import statalign.model.ext.plugins.structalign.HierarchicalContinuousPositiveStructAlignMove;
import statalign.model.ext.plugins.structalign.LibraryMove;
import statalign.model.ext.plugins.structalign.LinearLink;
import statalign.model.ext.plugins.structalign.MultiNormCholesky;
import statalign.model.ext.plugins.structalign.QuadraticLink;
import statalign.model.ext.plugins.structalign.RotationMove;
import statalign.model.ext.plugins.structalign.RotationProposal;
import statalign.model.ext.plugins.structalign.StructAlignParameterInterface;
import statalign.model.ext.plugins.structalign.TranslationMove;
import statalign.model.subst.SubstitutionModel;
import statalign.postprocess.plugins.TreeNode;
import statalign.postprocess.plugins.structalign.RmsdTrace;
import statalign.postprocess.plugins.structalign.StructTrace;
import statalign.postprocess.plugins.structalign.StructTreeVisualizer;
import statalign.utils.LinkFunction;
public class StructAlign extends ModelExtension implements ActionListener {
/**
 * Identifier used to select this plugin on the command line, as in
 * {@code -plugin:structal[...]}.
 */
private final String pluginID = "structal";
@Override
public String getPluginID() {
    return this.pluginID;
}
/** Toolbar toggle created in getToolBarItems(); its state drives setActive(). */
JToggleButton myButton;
/**
 * When true a single sigma2 is shared by all branches (sigma2 has length 1 in
 * initRun()); cleared by the localSigma and sigmaSpike options.
 */
public boolean globalSigma = true;
/** Enables the rotation-library proposal and library moves. */
public boolean useLibrary = false;
/** True when epsilon is fixed via the epsilon option rather than sampled. */
public boolean fixedEpsilon = false;
/** True when sigma2 is fixed via the sigma2 option rather than sampled. */
public boolean fixedSigma2 = false;
/**
* If globalSigma = false then this switches on a spike prior at sigma2Hier.
* This can also be switched on via a command-line option.
* */
public boolean globalSigmaSpike = false;
/**
 * Parameters handed to the sigma2Hier move's spike prior; the usage text
 * describes them as a Beta prior on the spike probability. Overridable with
 * sigmaSpike={a_b}.
 */
double[] globalSigmaSpikeParams = {1.35,1.1};
/** Scale epsilon per site using normalised B-factors (localEpsilon option). */
public boolean localEpsilon = false;
/** Multiplier applied to the summed column log-likelihoods in calcAllColumnContrib(). */
double structTemp = 1;
/** Value reported by useInAlignmentProposals(). */
private boolean USE_IN_ALIGNMENT_PROPOSALS = true;
/**
 * Returns the {@code USE_IN_ALIGNMENT_PROPOSALS} flag, which is initialised
 * to {@code true} and not modified in the code shown here.
 *
 * @return whether this extension takes part in alignment proposals
 */
@Override
public boolean useInAlignmentProposals() {
    final boolean flag = USE_IN_ALIGNMENT_PROPOSALS;
    return flag;
}
/**
 * Alpha-carbon coordinates, indexed [sequence][residue][coordinate]. Filled in
 * initRun() and centred on each structure's mean; the entry for a sequence
 * without a structure stays null.
 */
public double[][][] coords;
/**
 * Crystallographic temperature factors, indexed [sequence][residue], used to
 * weight epsilon per site in localEpsilon mode. initRun() rescales each
 * structure's values to mean 1 (or sets them all to 1 if the mean is 0).
 */
public double[][] bFactors;
/** Alpha-C atomic coordinates under the current set of rotations/translations */
public double[][][] rotCoords;
/** Axis of rotation for each sequence; initRun() sets the reference's to (1,0,0). */
public double[][] axes;
/** Rotation angle for each protein along the rotation axis; 0 for the reference. */
public double[] angles;
/** Translation vector for each protein; the zero vector for the reference. */
public double[][] xlats;
/**
 * The structure used as the reference for rotations. initRun() advances it
 * from its current value to the first sequence that has coordinates.
 */
private int refIndex = 0;
/**
 * Parameters of structural drift. sigma2 has one entry when globalSigma is
 * set, otherwise 2*coords.length - 1 entries (one per tree branch); see
 * initRun(). sigma2Hier and nu are only sampled with per-branch sigma2.
 */
public double[] sigma2;
public double sigma2Hier;
public double nu;
public double tau;
/** Starts at 50 (or fixedEpsilonValue); when sampled it is bounded below by MIN_EPSILON. */
public double epsilon;
// TODO Allow starting values to be specified at command line/GUI
/** Pairwise distances implied by current tree topology */
public double[][] distanceMatrix;
/** Covariance matrix implied by current tree topology */
public double[][] fullCovar;
/** Current alignment between all leaf sequences */
public String[] curAlign;
// Saved copies of the fields above (and of the log-likelihood), presumably
// restored when a proposed move is rejected -- TODO confirm against callers.
public double[][] oldCovar;
public double[][] oldDist;
public String[] oldAlign;
public double oldLogLi;
/**
 * Caches of Cholesky-factored column distributions built in columnContrib():
 * multiNorms is keyed by a bitmask of the ungapped rows of a column;
 * multiNormsLocal (localEpsilon mode) is keyed by the full column.
 */
public HashMap<Integer, MultiNormCholesky> multiNorms;
public HashMap<Column, MultiNormCholesky> multiNormsLocal;
private HashMap<Integer, MultiNormCholesky> oldMultiNorms;
public HashMap<Column, MultiNormCholesky> oldMultiNormsLocal;
/** Relates the structural and sequence evolutionary timescales */
private LinkFunction<Double> linkFunction;
/** "linear" (default) or "quadratic"; set by the link option, resolved in initRun(). */
public String linkType = "linear";
// TODO change the above public variables to package visible and put
// StructAlign.java in statalign.model.ext.plugins.structalign ?
/* Priors */
// NOTE: these initialisers run in declaration order -- e.g. tauPrior reads
// tauPriorShape/tauPriorRate -- so do not reorder the declarations below.
// NOTE(review): sigma2PriorShape/sigma2PriorRate appear to be used only by a
// commented-out alternative in initRun(); confirm before removing.
private double sigma2PriorShape = 0.001;
private double sigma2PriorRate = 0.001;
/**
 * Prior on sigma2. Set by the sigma2Prior option; otherwise, unless sigma2 is
 * fixed, initRun() installs a Gamma prior: Gamma(2,2) for a global sigma with
 * fixed epsilon, Gamma(1,1) in all other cases.
 */
public PriorDistribution<Double> sigma2Prior;
/** True once sigma2Prior has been set, by an option or by initRun(). */
boolean sigma2PriorInitialised = false;
/** Shape/rate of the default Gamma prior on epsilon (used if no epsilonPrior option is given). */
private double epsilonPriorShape = 2;//10; //1; //2;
private double epsilonPriorRate = 2;//50; //5; //2;
/** Prior on epsilon; set by the epsilonPrior option or defaulted in initRun(). */
public PriorDistribution<Double> epsilonPrior;
/** True once epsilonPrior has been set, by an option or by initRun(). */
boolean epsilonPriorInitialised = false;
/** Independence rotation proposal distribution; only created when useLibrary is set. */
public RotationProposal rotProp;
/** Shape/rate of the inverse-gamma prior on tau. */
private double tauPriorShape = 0.001;
private double tauPriorRate = 0.001;
public InverseGammaPrior tauPrior = new InverseGammaPrior(tauPriorShape,tauPriorRate);
/** Shape/rate of the Gamma prior on sigma2Hier. */
private double sigma2HPriorShape = 1;
private double sigma2HPriorRate = 1;
// public InverseGammaPrior sigma2HPrior = new InverseGammaPrior(sigma2HPriorShape,sigma2HPriorRate);
/** Move for sigma2Hier; initMcmc() creates it only when sigma2 is sampled per branch. */
HierarchicalContinuousPositiveStructAlignMove sigma2HMove = null;
public GammaPrior sigma2HPrior = new GammaPrior(sigma2HPriorShape,sigma2HPriorRate);
// public HyperbolicPrior sigma2HPrior = new HyperbolicPrior();
/** Shape/rate of the Gamma prior on nu. */
private double nuPriorShape = 1;
private double nuPriorRate = 6;
public GammaPrior nuPrior = new GammaPrior(nuPriorShape,nuPriorRate);
/** Move for nu; created by initMcmc() alongside sigma2HMove. */
HierarchicalContinuousPositiveStructAlignMove nuMove = null;
// priors for rotation and translation are uniform
// so do not need to be included in M-H ratio
/** Default proposal weights in this order:
* align, topology, edge, indel param, subst param, modelext param
* { 35, 20, 25, 15, 10, 0 };
* NOTE(review): pluginProposalWeight is not referenced in the code shown;
* presumably it is read elsewhere in this class -- confirm.
*/
private final int pluginProposalWeight = 50;
// Relative weights passed to addMcmcMove() in initMcmc(). initRun() adds
// sigmaEpsilonWeight to sigma2Weight when epsilon is fixed, and to
// epsilonWeight when sigma2 is fixed (using +=, so the fields are modified).
//int sigma2Weight = 5; //15;
int sigma2Weight = 18; //
int tauWeight = 10;
int sigma2HierWeight = 10; // ORIGINAL
//int sigma2HierWeight = 0;
// nuWeight/nuWeightIncrement are passed as (weight, increment) for the nu
// move; without the spike prior, initMcmc() folds the increment into the weight.
int nuWeight = 0; // ORIGINAL
int nuWeightIncrement = 10; // ORIGINAL
//int nuWeight = 0;
//int epsilonWeight = 2;//10;
int epsilonWeight = 13; //
int rotationWeight = 2;
int translationWeight = 2;
int libraryWeight = 2;
int alignmentWeight = 2;
int alignmentWeightIncrement = 0;
/* Weights for combination moves */
int alignmentRotationWeight = 8; // ORIGINAL
int alignmentTranslationWeight = 6; // ORIGINAL
//int alignmentRotationWeight = 15;
//int alignmentTranslationWeight = 10;
int alignmentLibraryWeight = 6;
int sigmaEpsilonWeight = 4; //
// This is reallocated to sigma2Weight if epsilon is being fixed,
// and to epsilonWeight if sigma2 is being fixed (see initRun()).
/** Starting value for 1 / rotation proposal tuning parameter. */
public double angleP = 1000;
/** Starting value for translation proposal tuning parameter. */
public double xlatP = .1;
/** Value to fix sigma at if we're not estimating it (set by the sigma2 option). */
public double fixedSigma2Value = 0.0;
/**
 * Minimum value for epsilon, to prevent numerical errors. Despite the
 * constant-style name this is mutable: the minEpsilon option overrides it.
 */
public double MIN_EPSILON = 0.01;
/** Value to fix epsilon at if we're not estimating it (set by the epsilon option). */
public double fixedEpsilonValue = 0.0;
// reference to the postprocessing plugin
private StructTrace structTrace;
private RmsdTrace rmsdTrace;
private StructTreeVisualizer structTree;
/** Set by the printRmsd option. */
public boolean printRmsd = false;
/** Set by the printTree option; only meaningful together with localSigma. */
public boolean showStructTree = false;
/**
 * Creates the plugin in an unselectable state: it should only become
 * selectable once coordinate data has been read in.
 */
public StructAlign() {
    super();
    this.selectable = false;
}
/**
 * Builds the (initially disabled, unselected) toolbar toggle that switches
 * structural alignment mode on and off.
 *
 * @return a single-element list holding the toggle button
 */
@Override
public List<JComponent> getToolBarItems() {
    ImageIcon proteinIcon = new ImageIcon(ClassLoader.getSystemResource("icons/protein.png"));
    myButton = new JToggleButton(proteinIcon);
    JToggleButton button = myButton;
    button.setToolTipText("Structural alignment mode");
    button.addActionListener(this);
    button.setEnabled(false);
    button.setSelected(false);
    return Arrays.<JComponent>asList(button);
}
/**
 * Mirrors the toolbar toggle's state onto this plugin and then activates
 * the associated plugins.
 *
 * @param e the toggle event (unused)
 */
@Override
public void actionPerformed(ActionEvent e) {
    boolean selected = myButton.isSelected();
    setActive(selected);
    activateAssociatedPlugins();
}
/**
 * Returns the command-line help text for this plugin.
 *
 * <p>The defaults quoted here must agree with the code: the spike prior
 * defaults to (1.35,1.1) ({@code globalSigmaSpikeParams}), and when no
 * {@code sigma2Prior}/{@code epsilonPrior} option is given, {@code initRun}
 * installs Gamma priors, not the hyperbolic prior.
 *
 * @return multi-line usage description
 */
@Override
public String getUsageInfo() {
    StringBuilder usage = new StringBuilder();
    usage.append("___________________________\n\n");
    usage.append(" StructAlign version 1.1\n\n");
    usage.append("^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n");
    usage.append("java -jar statalign.jar -plugin:structal[OPTION1,OPTION2,...]\n");
    usage.append("OPTIONS: \n");
    usage.append("\tprintRmsd=true\t\t(Prints RMSD and sequence identity for each sample.)\n");
    usage.append("\tprintTree=true\t\t(Prints tree with edge lengths replaced by structural diffusivity.)\n");
    usage.append("\tsigma2=X\t\t(Fixes sigma2 at X)\n");
    usage.append("\tepsilon=X\t\t(Fixes epsilon at X)\n");
    usage.append("\tminEpsilon=X\t\t(Sets minimum value for epsilon to X) [default 0.01]\n");
    usage.append("\tlocalEpsilon\t\t(Uses B-factor information [if available] to scale epsilon per site.)\n");
    usage.append("\tlocalSigma\t\t(Allows each branch to have its own sigma parameter)\n");
    usage.append("\tuseLibrary\t\t(Allows rotation library moves to be used)\n");
    usage.append("\tsigmaSpike\t(Activates spike mixture prior for local sigmas)\n");
    // Must match the initialiser of globalSigmaSpikeParams.
    usage.append("\tsigmaSpike={a_b}\t(Specifies parameters for the Beta prior on spike probability) [default (1.35,1.1) ]\n");
    usage.append("\tsigma2Prior=PRIOR\t(Sets the prior and hyperparameters for sigma2)\n");
    usage.append("\tepsilonPrior=PRIOR\t(Sets the prior and hyperparameters for epsilon)\n");
    usage.append("\tPRIOR can be one of:\n");
    usage.append("\t\thyp\t\tUses a hyperbolic prior\n");
    // initRun() falls back to Gamma priors when no PRIOR is given.
    usage.append("\t\tg{a_b}\t\tUses a Gamma(a,b) prior (a Gamma prior is the default)\n");
    usage.append("\t\tinvg{a_b}\tUses an InverseGamma(a,b) prior\n");
    usage.append("\t\tunif{a_b}\tUses a Uniform(a,b) prior\n");
    usage.append("\tlink=LINK_FUNCTION\tSets link function between sequence and structure time\n");
    usage.append("\tLINK_FUNCTION can be one of:\n");
    usage.append("\t\tlinear (default)\n");
    usage.append("\t\tquadratic\n");
    usage.append("\nNote that the above syntax is designed to work in bash shells. " +
            "Other shells such as csh may require square brackets to be preceded by a backslash.");
    return usage.toString();
}
/**
 * Enables or disables the plugin, announcing the change on standard output.
 * Does nothing if the requested state equals the current one.
 *
 * @param active the desired activation state
 */
@Override
public void setActive(boolean active) {
    if (this.active == active) {
        return;
    }
    super.setActive(active);
    System.out.println("StructAlign plugin is now "+(active?"enabled":"disabled"));
}
/**
 * Applies a string-valued plugin option of the form {@code name=value}.
 *
 * <p>Recognised names are {@code epsilon}, {@code minEpsilon}, {@code sigma2},
 * {@code sigmaSpike} (value {@code {a_b}}), {@code sigma2Prior},
 * {@code epsilonPrior} and {@code link}. Any other name is passed to the
 * superclass.
 *
 * @param paramName option name
 * @param paramValue option value
 * @throws IllegalArgumentException if a numeric value, prior specification or
 *         sigmaSpike parameter block is malformed
 */
@Override
public void setParam(String paramName, String paramValue) {
    if (paramName.equals("epsilon")) {
        fixedEpsilon = true;
        fixedEpsilonValue = Double.parseDouble(paramValue);
        addToFilenameExtension("eps_"+fixedEpsilonValue);
        System.out.println("Fixing epsilon to "+fixedEpsilonValue+".");
    }
    else if (paramName.equals("minEpsilon")) {
        MIN_EPSILON = Double.parseDouble(paramValue);
        addToFilenameExtension("minEps_"+MIN_EPSILON);
        System.out.println("Minimum value for epsilon is now "+MIN_EPSILON+".");
    }
    else if (paramName.equals("sigma2")) {
        fixedSigma2 = true;
        globalSigma = true;
        fixedSigma2Value = Double.parseDouble(paramValue);
        addToFilenameExtension("sigma2_"+fixedSigma2Value);
        System.out.println("Fixing sigma2 to "+fixedSigma2Value+".");
    }
    else if (paramName.equals("sigmaSpike")) {
        // Validate the "{a_b}" block before touching any state. Previously a
        // value without '{' crashed with ArrayIndexOutOfBoundsException and a
        // missing closing '}' was silently ignored.
        String[] argString = paramValue.split("\\{",2);
        String[] args = null;
        if (argString.length == 2 && argString[1].endsWith("}")) {
            args = argString[1].substring(0,argString[1].length()-1).split("_",2);
        }
        if (args == null || args.length != 2) {
            throw new IllegalArgumentException(
                    "Spike parameters must be specified in the form\n-plugin:structal[sigmaSpike={a_b}]\n");
        }
        double a = Double.parseDouble(args[0]);
        double b = Double.parseDouble(args[1]);
        globalSigmaSpike = true;
        globalSigma = false;
        addToFilenameExtension("spike");
        // The usage text documents these as Beta parameters for the spike probability.
        System.out.println("Using Beta("+a+","+b+") prior for the spike probability.");
        globalSigmaSpikeParams = new double[] {a, b};
        addToFilenameExtension(globalSigmaSpikeParams[0]+"_"+globalSigmaSpikeParams[1]);
        System.out.println("Activating spike prior for local sigmas.");
    }
    else if (paramName.equals("sigma2Prior")) {
        sigma2Prior = setPrior(paramName,paramValue);
        sigma2PriorInitialised = (sigma2Prior != null);
    }
    else if (paramName.equals("epsilonPrior")) {
        epsilonPrior = setPrior(paramName,paramValue);
        epsilonPriorInitialised = (epsilonPrior != null);
    }
    else if (paramName.equals("link")) {
        // Validated later, in initRun().
        linkType = paramValue;
        addToFilenameExtension("link_"+paramValue);
        System.out.println("Setting link function to "+paramValue+".");
    }
    else {
        super.setParam(paramName,paramValue);
    }
}
/**
 * Applies a numeric plugin option. Only {@code epsilon} is handled here;
 * other names are passed to the superclass.
 *
 * @param paramName option name
 * @param paramValue numeric value of any {@link Number} subtype
 */
@Override
public void setParam(String paramName, Number paramValue) {
    if (paramName.equals("epsilon")) {
        fixedEpsilon = true;
        // doubleValue() rather than a (Double) cast, so Integer/Long/Float
        // values do not throw ClassCastException.
        fixedEpsilonValue = paramValue.doubleValue();
        addToFilenameExtension("eps_"+fixedEpsilonValue);
        System.out.println("Fixing epsilon to "+fixedEpsilonValue+".");
    }
    else {
        super.setParam(paramName,paramValue);
    }
}
/**
 * Applies a boolean plugin option (bare flags or {@code name=true/false}).
 *
 * <p>Simple flags ({@code localEpsilon}, {@code printRmsd}, {@code printTree},
 * {@code useLibrary}) take the supplied value, so {@code name=false} no longer
 * switches the feature on. {@code localSigma} and {@code sigmaSpike} only act
 * when the value is true. Other names are passed to the superclass.
 *
 * @param paramName option name
 * @param paramValue option value
 */
@Override
public void setParam(String paramName, boolean paramValue) {
    if (paramName.equals("localSigma")) {
        if (paramValue) {
            globalSigma = false;
        }
    }
    else if (paramName.equals("localEpsilon")) {
        localEpsilon = paramValue;
        if (paramValue) {
            System.out.println("Using B-factor information to scale epsilon.");
        }
    }
    else if (paramName.equals("printRmsd")) {
        printRmsd = paramValue;
    }
    else if (paramName.equals("printTree")) {
        showStructTree = paramValue;
    }
    else if (paramName.equals("useLibrary")) {
        useLibrary = paramValue;
    }
    else if (paramName.equals("sigmaSpike")) {
        if (paramValue) {
            globalSigmaSpike = true;
            globalSigma = false;
            addToFilenameExtension("spike");
            addToFilenameExtension(globalSigmaSpikeParams[0]+"_"+globalSigmaSpikeParams[1]);
            System.out.println("Activating spike prior for local sigmas.");
        }
    }
    else {
        super.setParam(paramName,paramValue);
    }
}
/**
 * Parses a prior specification and records it in the output file name.
 *
 * @param paramName option name (e.g. {@code sigma2Prior}), used in messages
 *        and in the file-name extension
 * @param paramValue one of {@code hyp...}, {@code unif{a_b}}, {@code g{a_b}}
 *        or {@code invg{a_b}}
 * @return the corresponding prior (never {@code null})
 * @throws IllegalArgumentException if the specification is unrecognised or
 *         malformed (including non-numeric parameters)
 */
private PriorDistribution<Double> setPrior(String paramName,
        String paramValue) {
    if (paramValue.startsWith("hyp")) {
        addToFilenameExtension(paramName+"_hyp");
        System.out.println("Using hyperbolic prior for "+paramName+".");
        return new HyperbolicPrior();
    }
    if (paramValue.startsWith("unif{")) {
        String[] args = splitPriorArguments(paramName, paramValue, "unif");
        double lower = Double.parseDouble(args[0]);
        double upper = Double.parseDouble(args[1]);
        addToFilenameExtension(paramName+"_u_"+args[0]+"_"+args[1]);
        System.out.println("Using Unif("+lower+","+upper+") prior for "+paramName+".");
        return new UniformPrior(lower, upper);
    }
    if (paramValue.startsWith("g{")) {
        String[] args = splitPriorArguments(paramName, paramValue, "g");
        double shape = Double.parseDouble(args[0]);
        double rate = Double.parseDouble(args[1]);
        addToFilenameExtension(paramName+"_g_"+args[0]+"_"+args[1]);
        System.out.println("Using Gamma("+shape+","+rate+") prior for "+paramName+".");
        return new GammaPrior(shape, rate);
    }
    if (paramValue.startsWith("invg{")) {
        String[] args = splitPriorArguments(paramName, paramValue, "invg");
        double shape = Double.parseDouble(args[0]);
        double rate = Double.parseDouble(args[1]);
        addToFilenameExtension(paramName+"_invg_"+args[0]+"_"+args[1]);
        System.out.println("Using InvGamma("+shape+","+rate+") prior for "+paramName+".");
        return new InverseGammaPrior(shape, rate);
    }
    throw new IllegalArgumentException("Unrecognised prior specification "+paramName+"="+paramValue+".");
}

/**
 * Extracts the two raw tokens {@code a} and {@code b} from a specification of
 * the form {@code tag{a_b}}. The tokens are returned unparsed because the raw
 * text is also used in the output file name.
 *
 * @throws IllegalArgumentException naming the offending option and expected form
 */
private static String[] splitPriorArguments(String paramName, String paramValue, String tag) {
    String[] argString = paramValue.split("\\{", 2);
    if (argString.length == 2 && argString[1].endsWith("}")) {
        String body = argString[1].substring(0, argString[1].length() - 1);
        String[] args = body.split("_", 2);
        if (args.length == 2) {
            return args;
        }
    }
    throw new IllegalArgumentException(
            "Prior parameters must be specified in the form\n-plugin:structal["
            + paramName + "=" + tag + "{a_b}]\n");
}
/**
 * No one-off initialisation is needed here; per-run setup happens in
 * {@code initRun} and {@code initMcmc}.
 */
@Override
public void init() {
    // intentionally empty
}
/**
 * Clears the alignment, covariance, distance and saved log-likelihood state
 * held from a previous run, after letting the superclass reset its own data.
 */
@Override
protected void resetData() {
    super.resetData();
    curAlign = null;
    oldAlign = null;
    fullCovar = null;
    oldCovar = null;
    distanceMatrix = null;
    oldDist = null;
    oldLogLi = Double.NEGATIVE_INFINITY;
}
// @Override
// public ModelExtension reset() {
// return new StructAlign();
// }
/**
 * Loads the structural data and initialises the state for a new run.
 *
 * <p>Reads alpha-carbon coordinates (and, in localEpsilon mode, B-factors)
 * from every {@link ProteinSkeletons} entry in {@code inputData.auxData}.
 * Structures are matched to sequences by case-insensitive name, and each
 * structure is centred on its mean. Afterwards it:
 * <ul>
 * <li>picks the reference structure;</li>
 * <li>resolves the link function, which forces a global sigma for non-linear
 * links;</li>
 * <li>allocates rotations, translations, caches, starting parameter values and
 * default priors.</li>
 * </ul>
 *
 * @param inputData sequences and auxiliary data for this run
 * @throws IllegalArgumentException if a structure has no matching sequence or
 *         is a duplicate, if a structure's length differs from its ungapped
 *         sequence, or if the link type is not recognised
 * @throws RuntimeException if localEpsilon is requested but no B-factors are
 *         available, or if fewer than two structures were found
 */
@Override
public void initRun(InputData inputData) throws IllegalArgumentException {
    super.initRun(inputData);
    HashMap<String, Integer> seqMap = new HashMap<String, Integer>();
    int i = 0;
    for (String name : inputData.seqs.getSeqnames())
        seqMap.put(name.toUpperCase(), i++);
    coords = new double[inputData.seqs.size()][][];
    if (localEpsilon) bFactors = new double[inputData.seqs.size()][];
    for (DataType data : inputData.auxData) {
        if (!(data instanceof ProteinSkeletons))
            continue;
        ProteinSkeletons ps = (ProteinSkeletons) data;
        if (localEpsilon && ps.bFactors.size() == 0) {
            throw new RuntimeException("No B-factor data available: cannot use localEpsilon mode.");
        }
        for (i = 0; i < ps.names.size(); i++) {
            String name = ps.names.get(i).toUpperCase();
            if (!seqMap.containsKey(name))
                throw new IllegalArgumentException("structalign: missing sequence or duplicate structure for "+name);
            int ind = seqMap.get(name);
            int len = inputData.seqs.getSequence(ind).replaceAll("-", "").length();
            List<double[]> cl = ps.coords.get(i);
            // Only read the B-factor list when it is needed: it can be empty
            // (see the size() check above), and indexing it unconditionally
            // failed even when localEpsilon was off.
            List<Double> bF = null;
            if (localEpsilon) {
                bF = ps.bFactors.get(i);
                if (bF.size() == 0) {
                    localEpsilon = false;
                    System.out.println("No B-factor data available for "+name+": cannot use localEpsilon mode.");
                }
            }
            if (len != cl.size())
                throw new IllegalArgumentException("structalign: sequence length mismatch with structure file for seq "+name);
            coords[ind] = new double[len][];
            if (localEpsilon) bFactors[ind] = new double[len];
            double bFactorMean = 0;
            for (int j = 0; j < len; j++) {
                coords[ind][j] = Utils.copyOf(cl.get(j));
                if (localEpsilon) {
                    bFactors[ind][j] = bF.get(j);
                    bFactorMean += bFactors[ind][j]/len;
                }
            }
            if (localEpsilon) {
                // Rescale to mean 1 so the B-factors act as relative per-site weights.
                for (int j = 0; j < len; j++) {
                    if (bFactorMean == 0) {
                        bFactors[ind][j] = 1;
                    }
                    else {
                        bFactors[ind][j] /= bFactorMean;
                    }
                }
            }
            // center all coordinates to mean zero so that rotations are around center of gravity
            RealMatrix temp = new Array2DRowRealMatrix(coords[ind]);
            RealVector mean = Funcs.meanVector(temp);
            for (int j = 0; j < len; j++)
                coords[ind][j] = temp.getRowVector(j).subtract(mean).toArray();
            seqMap.remove(name);
        }
    }

    // Count structures before looking for the reference, so that missing data
    // produces this message instead of running off the end of coords.
    int nStructures = 0;
    for (double[][] structure : coords) {
        if (structure != null) ++nStructures;
    }
    if (nStructures < 2) {
        throw new RuntimeException("Cannot run StructAlign with fewer than two structures.");
    }
    // The reference is the first sequence that has a structure. Restart the
    // search from 0: refIndex is not reset anywhere else, so a value left over
    // from an earlier run could point beyond a smaller input or at the wrong
    // structure.
    refIndex = 0;
    while (coords[refIndex] == null) ++refIndex;

//		if(seqMap.size() > 0)
//			throw new IllegalArgumentException("structalign: missing structure for sequence "+seqMap.keySet().iterator().next());

    // Resolve the link function before anything that depends on globalSigma.
    // For now hierarchical sigma is only allowed with a linear link, so the
    // override must happen before sigma2 is sized and its prior chosen.
    // Previously it ran last, leaving a per-branch sigma2 array and prior
    // paired with globalSigma == true.
    if (linkType.equals("linear")) {
        linkFunction = new LinearLink();
    }
    else if (linkType.equals("quadratic")) {
        linkFunction = new QuadraticLink();
    }
    else {
        throw new IllegalArgumentException("Invalid link function selected.");
    }
    if (!linkType.equals("linear")) {
        globalSigma = true;
    }

    if (useLibrary) {
        rotProp = new RotationProposal(this);
    }
    if (globalSigma && showStructTree) {
        System.out.println("printTree option can only be used in conjunction with localSigma. ");
    }

    rotCoords = new double[coords.length][][];
    axes = new double[coords.length][];
    angles = new double[coords.length];
    xlats = new double[coords.length][];
    // The reference structure is never rotated or translated.
    axes[refIndex] = new double[] { 1, 0, 0 };
    angles[refIndex] = 0;
    xlats[refIndex] = new double[] { 0, 0, 0 };

    multiNorms = new HashMap<Integer, MultiNormCholesky>();
    multiNormsLocal = new HashMap<Column, MultiNormCholesky>();
    oldMultiNorms = new HashMap<Integer, MultiNormCholesky>();
    oldMultiNormsLocal = new HashMap<Column, MultiNormCholesky>();

    // MCMC parameters
    sigma2Hier = 1;
    nu = 1;
    tau = 50;
    // NOTE(review): the weight updates below use +=, so they accumulate if
    // initRun() is called more than once on the same instance.
    if (fixedEpsilon) {
        epsilon = fixedEpsilonValue;
        sigma2Weight += sigmaEpsilonWeight;
    }
    else {
        epsilon = 50;
    }

    // number of branches in the tree is 2*leaves - 1
    if (globalSigma) {
        sigma2 = new double[1];
    }
    else {
        sigma2 = new double[2*coords.length - 1];
    }
    for (i = 0; i < sigma2.length; i++) {
        sigma2[i] = 1;
    }
    if (fixedSigma2) {
        sigma2[0] = fixedSigma2Value;
        epsilonWeight += sigmaEpsilonWeight;
    }

    if (!sigma2PriorInitialised && !fixedSigma2) {
        if (globalSigma) {
            if (fixedEpsilon) {
                sigma2Prior = new GammaPrior(2,2);
            }
            else {
                sigma2Prior = new GammaPrior(1,1);
            }
        }
        else {
            sigma2Prior = new GammaPrior(1,1);
        }
        sigma2PriorInitialised = true;
    }
    if (!epsilonPriorInitialised && !fixedEpsilon) {
        epsilonPrior = new GammaPrior(epsilonPriorShape,epsilonPriorRate);
        epsilonPriorInitialised = true;
    }
}
/**
 * Registers this plugin's MCMC moves and their proposal weights.
 *
 * <p>Always adds rotation and translation moves, plus library moves when
 * {@code useLibrary} is set. Unless the alignment is fixed, it also adds an
 * alignment move and combined alignment+rotation, alignment+translation (and,
 * with {@code useLibrary}, alignment+library) moves. Then it adds moves for tau,
 * for epsilon (unless fixed) and for sigma2 (unless fixed). With per-branch
 * sigma2, each branch move is attached to hierarchical moves for sigma2Hier
 * and nu.
 *
 * @param inputData run input; only {@code pars.fixAlign} is read here
 */
@Override
protected void initMcmc(InputData inputData) {
    /* Add alignment and rotation/translation moves */
    RotationMove rotationMove = new RotationMove(this,"rotation");
    addMcmcMove(rotationMove,rotationWeight);
    TranslationMove translationMove = new TranslationMove(this,"translation");
    addMcmcMove(translationMove,translationWeight);
    LibraryMove libraryMove = null;
    if (useLibrary) {
        libraryMove = new LibraryMove(this,"library");
        addMcmcMove(libraryMove,libraryWeight);
    }
    if (!inputData.pars.fixAlign) {
        AlignmentMove alignmentMove = new AlignmentMove(this,"alignment");
        addMcmcMove(alignmentMove,alignmentWeight,alignmentWeightIncrement);

        /* Combination moves */
        ArrayList<McmcMove> alignmentRotation = new ArrayList<McmcMove>();
        alignmentRotation.add(alignmentMove);
        rotationMove.shareSubtreeRoot(alignmentMove);
        alignmentRotation.add(rotationMove);
        McmcCombinationMove alignmentRotationMove =
                new McmcCombinationMove(alignmentRotation);
        alignmentRotationMove.autoTune = false;
        addMcmcMove(alignmentRotationMove,alignmentRotationWeight);

        ArrayList<McmcMove> alignmentTranslation = new ArrayList<McmcMove>();
        alignmentTranslation.add(alignmentMove);
        translationMove.shareSubtreeRoot(alignmentMove);
        alignmentTranslation.add(translationMove);
        McmcCombinationMove alignmentTranslationMove =
                new McmcCombinationMove(alignmentTranslation);
        alignmentTranslationMove.autoTune = false;
        addMcmcMove(alignmentTranslationMove,alignmentTranslationWeight);

        if (useLibrary) {
            ArrayList<McmcMove> alignmentLibrary = new ArrayList<McmcMove>();
            alignmentLibrary.add(alignmentMove);
            alignmentLibrary.add(libraryMove);
            McmcCombinationMove alignmentLibraryMove =
                    new McmcCombinationMove(alignmentLibrary);
            addMcmcMove(alignmentLibraryMove,alignmentLibraryWeight);
        }
    }

    /* Add moves for scalar parameters */
    StructAlignParameterInterface paramInterfaceGenerator = new StructAlignParameterInterface(this);
    ParameterInterface tauInterface = paramInterfaceGenerator.new TauInterface();
    ContinuousPositiveStructAlignMove tauMove =
            new ContinuousPositiveStructAlignMove(this,tauInterface,tauPrior,new GammaProposal(0.001,0.001),"tau");
    tauMove.moveParams.setPlottable();
    tauMove.moveParams.setPlotSide(1);
    addMcmcMove(tauMove,tauWeight);

    ContinuousPositiveStructAlignMove epsilonMove = null;
    if (!fixedEpsilon) {
        ParameterInterface epsilonInterface = paramInterfaceGenerator.new EpsilonInterface();
        epsilonMove =
                new ContinuousPositiveStructAlignMove(this,epsilonInterface,epsilonPrior,new GaussianProposal(),"eps");
        epsilonMove.setMinValue(MIN_EPSILON);
        epsilonMove.moveParams.setPlottable();
        epsilonMove.moveParams.setPlotSide(1);
        addMcmcMove(epsilonMove,epsilonWeight);
    }

    if (!fixedSigma2) {
        if (!globalSigma) {
            ParameterInterface sigma2HInterface = paramInterfaceGenerator.new Sigma2HInterface();
            sigma2HMove = new HierarchicalContinuousPositiveStructAlignMove(this,sigma2HInterface,sigma2HPrior,new GammaProposal(0.001,0.001),"s2_g");
            sigma2HMove.moveParams.setPlottable();
            sigma2HMove.moveParams.setPlotSide(0);
            addMcmcMove(sigma2HMove,sigma2HierWeight);

            ParameterInterface nuInterface = paramInterfaceGenerator.new NuInterface();
            nuMove = new HierarchicalContinuousPositiveStructAlignMove(this,nuInterface,nuPrior,new GammaProposal(0.001,0.001),"nu");
            nuMove.moveParams.setPlottable();
            nuMove.moveParams.setPlotSide(1);
            nuMove.onlySampleIfAtLeastTwoChildrenNotFixed();
            nuMove.setMinValue(0.1);
            // Work on copies so repeated calls do not keep rewriting the fields.
            int nuStartWeight = nuWeight;
            int nuIncrement = nuWeightIncrement;
            if (!globalSigmaSpike) {
                // Without the spike prior, nu gets its full weight from the start.
                nuStartWeight += nuIncrement;
                nuIncrement = 0;
            }
            addMcmcMove(nuMove,nuStartWeight,nuIncrement);
        }
        for (int j = 0; j < sigma2.length; j++) {
            // With per-branch sigma2, the first entries are leaf branches; skip
            // those whose sequence has no structure. A single (global) sigma2
            // must never be skipped. Previously it was dropped whenever the
            // first sequence lacked a structure, because 0 <= 1/2.
            if (sigma2.length > 1 && j <= sigma2.length/2 && coords[j] == null) continue;
            // i.e. don't add the last one if we have more than one
            if (!globalSigma && j == sigma2.length - 1) continue;

            String sigmaName = (sigma2.length == 1) ? "s2" : "s2_"+j;
            ParameterInterface sigma2Interface = paramInterfaceGenerator.new Sigma2Interface(j);
            ContinuousPositiveStructAlignMove m = new ContinuousPositiveStructAlignMove(
                    this,sigma2Interface,
                    sigma2Prior,new GaussianProposal(),sigmaName);
            if (j <= sigma2.length/2) {
                // plot only for tips
                m.moveParams.setPlottable();
                m.moveParams.setPlotSide(0);
            }
            addMcmcMove(m,sigma2Weight);
            if (!globalSigma) {
                sigma2HMove.addChildMove(m);
                if (globalSigmaSpike) {
                    sigma2HMove.setSpikeParams(globalSigmaSpikeParams);
                    sigma2HMove.disallowSpikeSelection();
                }
                m.addParent(sigma2HMove);
                nuMove.addChildMove(m);
                // Don't add nuMove as a parent, because otherwise
                // we'll double count the prior.
            }
            if (sigma2.length == 1 && !fixedEpsilon) {
                ArrayList<McmcMove> sigmaEpsilon = new ArrayList<McmcMove>();
                sigmaEpsilon.add(m);
                sigmaEpsilon.add(epsilonMove);
                McmcCombinationMove sigmaEpsilonMove =
                        new McmcCombinationMove(sigmaEpsilon);
                addMcmcMove(sigmaEpsilonMove,sigmaEpsilonWeight);
            }
        }
    }
}
/**
 * Prepares the chain before sampling starts.
 *
 * <p>Resamples the alignment until every leaf shares at least one column in
 * which both it and the reference protein are ungapped. This is presumably
 * required by the least-squares superposition that follows (TODO confirm).
 * It then computes the initial rotations and the log-likelihood.
 *
 * @param tree the current tree/alignment state
 */
@Override
public void beforeSampling(Tree tree) {
    // Termination depends only on whether a resample was needed. The old
    // loop stopped only when it reached index align.length - 1, so it spun
    // forever if that index was the reference or there was a single leaf.
    boolean resample = true;
    while (resample) {
        String[] align = tree.getState().getLeafAlign();
        String ref = align[refIndex];
        resample = false;
        for (int i = 0; i < align.length && !resample; i++) {
            if (i == refIndex) continue;
            String other = align[i];
            boolean anyAligned = false;
            for (int j = 0; j < ref.length() && !anyAligned; j++) {
                anyAligned = ref.charAt(j) != '-' && other.charAt(j) != '-';
            }
            if (!anyAligned) {
                resample = true;
            }
        }
        if (resample) {
            tree.root.selectAndResampleAlignment();
        }
    }
    Funcs.initLSRotations(tree,coords,xlats,axes,angles);
    computeLogLikeFactor(tree);
}
/**
 * Halfway through burn-in, enables spike selection for the sigma2Hier move
 * (spike prior mode only) and zeroes the move counters.
 */
@Override
public void afterFirstHalfBurnin() {
    // sigma2HMove is only created by initMcmc() when sigma2 is sampled per
    // branch, so it is null if sigma2 was fixed while the spike option is on.
    if (!globalSigma && globalSigmaSpike && sigma2HMove != null) {
        sigma2HMove.allowSpikeSelection();
        zeroAllMoveCounts();
    }
}
/**
 * After burn-in, makes the nu move (if one exists) sample unconditionally
 * and relaxes its lower bound to 0.001.
 */
public void afterBurnin() {
    if (nuMove == null) {
        return;
    }
    nuMove.alwaysSample();
    nuMove.setMinValue(0.001);
}
/**
 * Recomputes the structural log-likelihood from scratch for the tree's
 * current alignment and topology, storing it in {@code curLogLike}.
 * In debug mode each intermediate is first checked against the cached value.
 *
 * @param tree the current tree/alignment state
 * @return the freshly computed log-likelihood
 */
public double computeLogLikeFactor(Tree tree) {
    String[] leafAlign = tree.getState().getLeafAlign();
    checkConsAlign(leafAlign);
    curAlign = leafAlign;

    double[][] treeCovar = calcFullCovar(tree);
    checkConsCovar(treeCovar);
    fullCovar = treeCovar;

    // checkConsRots() must run first: in debug mode it recomputes the rotations itself.
    boolean rotationsChecked = checkConsRots();
    if (!rotationsChecked && rotCoords[refIndex] == null) {
        calcAllRotations();
    }

    double freshLogLike = calcAllColumnContrib();
    checkConsLogLike(freshLogLike);
    curLogLike = freshLogLike;
    return curLogLike;
}
/**
 * Returns the cached structural log-likelihood. It is computed first if the
 * cache holds 0; in debug mode it is recomputed and cross-checked.
 *
 * @param tree the current tree/alignment state
 * @return the structural log-likelihood
 * @throws RuntimeException in debug mode if the cached and recomputed values
 *         differ by more than 1e-8
 */
@Override
public double logLikeFactor(Tree tree) {
    if (curLogLike == 0) {
        return computeLogLikeFactor(tree);
    }
    if (Utils.DEBUG) {
        double cached = curLogLike;
        double recomputed = computeLogLikeFactor(tree);
        if (Math.abs(cached - recomputed) > 1e-8) {
            throw new RuntimeException("Inconsistency in logLikeFactor: "+
                    cached +" != "+curLogLike);
        }
    }
    return curLogLike;
}
/**
 * Sums {@code columnContrib} over every column of {@code curAlign} and scales
 * the total by {@code structTemp}.
 *
 * <p>Each column is encoded as the residue index of every row, or -1 where
 * the row has a gap.
 *
 * @return the tempered structural log-likelihood of the whole alignment
 */
public double calcAllColumnContrib() {
    final String[] rows = curAlign;
    final int width = rows[refIndex].length();
    int[] nextResidue = new int[rows.length];
    int[] column = new int[rows.length];
    double total = 0;
    for (int c = 0; c < width; c++) {
        for (int s = 0; s < rows.length; s++) {
            if (rows[s].charAt(c) == '-') {
                column[s] = -1;
            } else {
                column[s] = nextResidue[s]++;
            }
        }
        total += columnContrib(column);
    }
    return structTemp * total;
}
// TODO Change visibility of this to package, after moving
// StructAlign.java to statalign.model.ext.plugins.structalign
/**
 * Debug-only check that {@code align} equals the cached {@code curAlign}.
 *
 * @return true if the check ran, false if skipped (not debugging, or no cache)
 * @throws Error if the alignments differ
 */
private boolean checkConsAlign(String[] align) {
    if (Utils.DEBUG && curAlign != null) {
        if (align.length != curAlign.length)
            throw new Error("Inconsistency in StructAlign, alignment length: "+align.length+", "+curAlign.length);
        for (int row = 0; row < align.length; row++) {
            if (!align[row].equals(curAlign[row]))
                throw new Error("Inconsistency in StructAlign, alignment: "+align[row]+", "+curAlign[row]);
        }
        return true;
    }
    return false;
}
/**
 * Debug-only check that {@code covar} matches the cached {@code fullCovar}
 * entry by entry, within 1e-5.
 *
 * @return true if the check ran, false if skipped (not debugging, or no cache)
 * @throws Error on any shape or value mismatch
 */
private boolean checkConsCovar(double[][] covar) {
    if (Utils.DEBUG && fullCovar != null) {
        if (covar.length != fullCovar.length)
            throw new Error("Inconsistency in StructAlign, covar matrix length: "+covar.length+", "+fullCovar.length);
        for (int r = 0; r < covar.length; r++) {
            double[] fresh = covar[r];
            double[] cached = fullCovar[r];
            if (fresh.length != cached.length)
                throw new Error("Inconsistency in StructAlign, covar matrix "+r+" length: "+fresh.length+", "+cached.length);
            for (int c = 0; c < fresh.length; c++) {
                if (Math.abs(fresh[c]-cached[c]) > 1e-5)
                    throw new Error("Inconsistency in StructAlign, covar matrix "+r+","+c+" value: "+fresh[c]+", "+cached[c]+", "+tau+", "+epsilon);
            }
        }
        return true;
    }
    return false;
}
/**
 * Debug-only check: snapshots {@code rotCoords}, recomputes all rotations and
 * verifies that every coordinate matches the snapshot within 1e-5.
 *
 * <p>Side effect: when the check runs, {@code calcAllRotations()} has
 * refreshed {@code rotCoords}.
 *
 * @return true if the check ran, false if skipped (not debugging, or no
 *         rotated coordinates yet)
 * @throws Error on any mismatch
 */
private boolean checkConsRots() {
    // Test the reference structure's slot rather than slot 0. rotCoords
    // entries can be null (see the null checks below), and slot 0 stays null
    // whenever the first sequence has no structure, which silently disabled
    // this check.
    if(!Utils.DEBUG || rotCoords[refIndex] == null)
        return false;
    double[][][] rots = new double[rotCoords.length][][];
    for(int i = 0; i < rots.length; i++) {
        if (rotCoords[i]==null) continue;
        rots[i] = new double[rotCoords[i].length][];
        for(int j = 0; j < rots[i].length; j++)
            rots[i][j] = MathArrays.copyOf(rotCoords[i][j]);
    }
    calcAllRotations();
    for(int i = 0; i < rots.length; i++) {
        if (rots[i]==null) continue;
        for(int j = 0; j < rots[i].length; j++)
            for(int k = 0; k < rots[i][j].length; k++)
                if(Math.abs(rots[i][j][k]-rotCoords[i][j][k]) > 1e-5)
                    throw new Error("Inconsistency in StructAlign, rotation "+i+","+j+","+k+": "+rots[i][j][k]+" vs "+rotCoords[i][j][k]);
    }
    return true;
}
/**
 * Debug-only check that {@code logli} matches the cached {@code curLogLike}
 * within 1e-5.
 *
 * @return true if the check ran, false if skipped (not debugging, or cache is 0)
 * @throws Error if the values differ
 */
private boolean checkConsLogLike(double logli) {
    boolean checking = Utils.DEBUG && curLogLike != 0;
    if (checking && Math.abs(logli-curLogLike) > 1e-5) {
        throw new Error("Inconsistency in StructAlign, log-likelihood "+logli+" vs "+curLogLike);
    }
    return checking;
}
/**
 * Calculates the structural log-likelihood contribution of a single alignment column.
 * <p>
 * Sequences without a structure ({@code coords[i] == null}) are treated as gapped.
 * For each of the three coordinates, the rotated coordinates of the ungapped residues
 * are scored under a zero-mean multivariate normal. Its covariance is the corresponding
 * sub-matrix of {@code fullCovar}; in localEpsilon mode a B-factor-scaled term is also
 * added to the diagonal. The Cholesky factorisations are cached per gap pattern, or per
 * column in localEpsilon mode.
 *
 * @param _col the column: residue index for each sequence, or -1 if gapped in the column
 *             (not modified)
 * @return the log-likelihood contribution; 0 if no sequence with a structure is ungapped
 */
public double columnContrib(int[] _col) {
	// work on a copy so sequences without a structure can be masked as gaps
	int[] col = _col.clone();
	int numMatch = 0;
	for (int i = 0; i < col.length; i++) {
		if (coords[i] == null) col[i] = -1;
		if (col[i] != -1)
			numMatch++;
	}
	if (numMatch == 0)
		return 0;
	// collect indices of ungapped positions and encode the gap pattern as a bitmask
	int[] notgap = new int[numMatch];
	int columnCode = 0;
	int j = 0;
	for (int i = 0; i < col.length; i++) {
		if (col[i] != -1) {
			notgap[j++] = i;
			columnCode |= (1 << i);
		}
	}
	// Java masks an int shift distance to its low 5 bits, so (1 << i) wraps around for
	// i >= 32 and distinct gap patterns would share a key. Only use the bitmask-keyed
	// cache when every sequence index gets its own bit.
	boolean codeIsUnique = col.length <= Integer.SIZE;
	/*
	 * Under localEpsilon mode, the covariance depends on the column,
	 * not just the indel pattern of the column, but we can still
	 * cache the Cholesky decompositions to be re-used for columns
	 * that do not change (since most of the alignment columns do
	 * not change during an alignment move, this could still yield
	 * a significant speedup).
	 */
	MultiNormCholesky multiNorm = null;
	if (localEpsilon) multiNorm = multiNormsLocal.get(new Column(col));
	else if (codeIsUnique) multiNorm = multiNorms.get(columnCode);
	// in debug mode, an uncached reference distribution is used to validate the cache
	MultiNormCholesky multiNorm2 = Utils.DEBUG ? buildColumnMultiNorm(notgap, col) : null;
	if (multiNorm == null) {
		multiNorm = buildColumnMultiNorm(notgap, col);
		if (localEpsilon) multiNormsLocal.put(new Column(col), multiNorm);
		else if (codeIsUnique) multiNorms.put(columnCode, multiNorm);
	}
	double logli = 0;
	double[] vals = new double[numMatch];
	// loop over all 3 coordinates
	for (int dim = 0; dim < 3; dim++) {
		for (int i = 0; i < numMatch; i++)
			vals[i] = rotCoords[notgap[i]][col[notgap[i]]][dim];
		double contrib = multiNorm.logDensity(vals);
		if (Utils.DEBUG) {
			double reference = multiNorm2.logDensity(vals);
			if (contrib != reference) {
				System.out.print("col = [");
				for (int k = 0; k < col.length; k++) System.out.print(col[k]+",");
				System.out.println("] ("+columnCode+")");
				// multiNorms may never have been initialised in localEpsilon mode
				if (multiNorms != null) {
					for (int key : multiNorms.keySet()) {
						if (multiNorms.get(key).getMeans().length == vals.length) {
							System.out.println(key+" "+multiNorms.get(key).logDensity(vals));
						}
					}
				}
				throw new RuntimeException(
						"Inconsistency: "+contrib+" != "
						+reference);
			}
		}
		logli += contrib;
	}
	return logli;
}
/**
 * Builds the zero-mean multivariate normal for the ungapped positions of a column.
 * Its covariance is the {@code notgap} rows/columns of {@code fullCovar}; in
 * localEpsilon mode the B-factor-scaled diagonal term is added as well.
 */
private MultiNormCholesky buildColumnMultiNorm(int[] notgap, int[] col) {
	double[][] subCovar = Funcs.getSubMatrix(fullCovar, notgap, notgap);
	if (localEpsilon) addLocalEpsilonToDiagonal(subCovar, notgap, col);
	return new MultiNormCholesky(new double[notgap.length], subCovar);
}
/**
 * Adds the localEpsilon noise term to the diagonal of a column's sub-covariance matrix.
 * The term for each ungapped residue is its squared B-factor times {@code epsilon},
 * divided by the number of ungapped positions.
 *
 * @param subCovar sub-covariance matrix for the ungapped positions; modified in place
 * @param notgap   sequence indices of the ungapped positions, in sub-matrix order
 * @param col      the column: residue index for each sequence
 */
private void addLocalEpsilonToDiagonal(double[][] subCovar, int[] notgap, int[] col) {
	final int size = notgap.length;
	for (int d = 0; d < size; d++) {
		int seq = notgap[d];
		double bFactor = bFactors[seq][col[seq]];
		subCovar[d][d] += Math.pow(bFactor, 2) * epsilon / size;
	}
}
private class Column {
public int[] col;
Column(int[] x) {
col = x.clone();
}
@Override
public boolean equals(Object o) {
if (o == this) return true;
if (o == null || o.getClass() != this.getClass()) return false;
Column x = (Column) o;
if (x.col.length != col.length) return false;
for (int i=0; i<x.col.length; i++) {
if (x.col[i] != col[i]) return false;
}
return true;
}
@Override
public int hashCode() {
return Arrays.hashCode(col);
}
}
/**
 * Recomputes the rotated and translated coordinates ({@code rotCoords}) of every
 * sequence that has a structure, by calling {@link #calcRotation(int)} for each index.
 * Sequences whose {@code coords} entry is null are skipped.
 */
private void calcAllRotations() {
	for (int seq = 0; seq < coords.length; seq++) {
		if (coords[seq] != null) {
			calcRotation(seq);
		}
	}
}
/**
 * Recomputes {@code rotCoords[ind]} from {@code coords[ind]}. Each point is rotated
 * by {@code angles[ind]} about {@code axes[ind]} and then translated by {@code xlats[ind]}.
 * The outer array of {@code rotCoords[ind]} is allocated on first use; each point's
 * entry is replaced with a fresh array.
 *
 * @param ind index of the sequence whose structure is transformed
 */
public void calcRotation(int ind) {
	double[][] source = coords[ind];
	if (rotCoords[ind] == null) {
		rotCoords[ind] = new double[source.length][];
	}
	double[][] target = rotCoords[ind];
	Rotation rotation = new Rotation(new Vector3D(axes[ind]), angles[ind]);
	for (int p = 0; p < source.length; p++) {
		Vector3D rotated = rotation.applyTo(new Vector3D(source[p]));
		target[p] = rotated.add(new Vector3D(xlats[ind])).toArray();
	}
}
// TODO Change visibility of this to package, after moving
// StructAlign.java into statalign.model.ext.plugins.structalign.
/**
 * Builds the full covariance matrix for the tree topology and branch lengths.
 * <p>
 * Side effects:
 * <ul>
 * <li>recomputes {@code distanceMatrix};</li>
 * <li>recomputes {@code rmsdTrace.distanceMatrix} (unweighted) when {@code printRmsd} is set;</li>
 * <li>discards the cached Cholesky factorisations (column-keyed ones in localEpsilon mode,
 *     gap-pattern-keyed ones otherwise).</li>
 * </ul>
 *
 * @param tree the tree whose distances define the covariance
 * @return a new symmetric {@code tree.names.length} x {@code tree.names.length} matrix
 */
public double[][] calcFullCovar(Tree tree) {
	// tree.names.length is equal to the number of vertices
	final int size = tree.names.length;
	distanceMatrix = new double[size][size];
	double[][] covar = new double[size][size];
	calcDistanceMatrix(tree.root, distanceMatrix);
	if (printRmsd) {
		rmsdTrace.distanceMatrix = new double[size][size];
		calcUnweightedDistanceMatrix(tree.root, rmsdTrace.distanceMatrix);
	}
	if (globalSigma) {
		// single sigma: scale the linked distance by sigma^2 / (2 tau)
		for (int r = 0; r < size; r++) {
			for (int c = r; c < size; c++) {
				double value = tau * Math.exp(-linkFunction.f(distanceMatrix[r][c]) * sigma2[0] / (2*tau));
				covar[r][c] = value;
				covar[c][r] = value;
			}
		}
	} else {
		// hierarchical sigma: the distance matrix already incorporates
		// multiplication by theta_i = sigma_i^2 / (2 tau)
		for (int r = 0; r < size; r++) {
			for (int c = r; c < size; c++) {
				double value = tau * Math.exp(-distanceMatrix[r][c]);
				covar[r][c] = value;
				covar[c][r] = value;
			}
		}
	}
	if (!localEpsilon) {
		// in localEpsilon mode the noise term is added per column instead
		for (int d = 0; d < size; d++) {
			covar[d][d] += epsilon;
		}
	}
	// cached factorisations depend on the covariance, so start afresh
	if (localEpsilon) {
		multiNormsLocal = new HashMap<Column, MultiNormCholesky>();
	} else {
		multiNorms = new HashMap<Integer, MultiNormCholesky>();
	}
	return covar;
}
/**
 * Prints each vertex of the subtree rooted at {@code v} with its edge length, one per
 * line, in pre-order. Each line is labelled by its path from the starting label:
 * "l" is appended for a left child and "r" for a right child.
 *
 * @param v     root of the subtree to print
 * @param vname label of {@code v}
 */
public void printTree(Vertex v, String vname){
	System.out.println(vname +"-" + v.name + ": " + v.edgeLength);
	if (v.left == null)
		return;
	printTree(v.left, vname + "l");
	printTree(v.right, vname + "r");
}
/**
 * Accumulates raw (unweighted) edge-length distances between leaves into {@code distMat}.
 *
 * @return leaf indices of the subtree rooted at {@code vertex}, padded with -1
 */
public int[] calcUnweightedDistanceMatrix(Vertex vertex, double[][] distMat){
	final boolean unweighted = true;
	return calcDistanceMatrix(vertex, distMat, unweighted);
}
/**
 * Accumulates leaf distances into {@code distMat}. Edge lengths are weighted by
 * sigma unless {@code globalSigma} is set.
 *
 * @return leaf indices of the subtree rooted at {@code vertex}, padded with -1
 */
public int[] calcDistanceMatrix(Vertex vertex, double[][] distMat){
	final boolean unweighted = false;
	return calcDistanceMatrix(vertex, distMat, unweighted);
}
/**
 * Recursively walks the subtree rooted at {@code vertex} (left child first). It adds
 * each edge's length to the distance between the leaves below that edge and all
 * other entries of {@code distMat}.
 * <p>
 * When neither {@code globalSigma} nor {@code unweighted} is set, the edge length is
 * multiplied by {@code sigma2[vertex.index] / (2*tau)}.
 *
 * @param vertex     root of the subtree to process
 * @param distMat    distance matrix, updated in place
 * @param unweighted if true, raw edge lengths are used
 * @return leaf indices of the subtree, followed by -1 padding
 *         (array length {@code distMat.length + 1})
 */
public int[] calcDistanceMatrix(Vertex vertex, double[][] distMat, boolean unweighted){
	final int size = distMat.length + 1;
	int[] leaves = new int[size];
	if (vertex.left == null) {
		// leaf vertex: the subtree is just this vertex
		leaves[0] = vertex.index;
		Arrays.fill(leaves, 1, size, -1);
	} else {
		// either both left and right are null or neither is
		int[] fromLeft = calcDistanceMatrix(vertex.left, distMat, unweighted);
		int[] fromRight = calcDistanceMatrix(vertex.right, distMat, unweighted);
		int filled = 0;
		for (; fromLeft[filled] > -1; filled++)
			leaves[filled] = fromLeft[filled];
		System.arraycopy(fromRight, 0, leaves, filled, size - filled);
	}
	double increment = (globalSigma || unweighted)
			? vertex.edgeLength
			: vertex.edgeLength * sigma2[vertex.index] / (2*tau);
	addEdgeLength(distMat, leaves, increment);
	return leaves;
}
/**
 * Adds {@code edgeLength} to the distance between every leaf of a subtree and every
 * index outside it, i.e. the contribution of the edge above that subtree. Distances
 * between two leaves of the subtree (and the diagonal) are left unchanged.
 *
 * @param distMat    symmetric distance matrix, updated in place
 * @param subTree    leaf indices of the subtree, terminated by a negative sentinel
 * @param edgeLength amount to add
 */
public void addEdgeLength(double[][] distMat, int[] subTree, double edgeLength){
	// pass 1: add to every row and column belonging to a subtree leaf
	for (int a = 0; subTree[a] > -1; a++) {
		int leaf = subTree[a];
		for (int other = 0; other < distMat.length; other++) {
			distMat[leaf][other] += edgeLength;
			distMat[other][leaf] += edgeLength;
		}
	}
	// pass 2: pairs inside the subtree got the length added twice; remove it again
	for (int a = 0; subTree[a] > -1; a++) {
		for (int b = 0; subTree[b] > -1; b++) {
			distMat[subTree[a]][subTree[b]] -= edgeLength;
			distMat[subTree[b]][subTree[a]] -= edgeLength;
		}
	}
}
/** Returns the configured proposal weight for this plugin's parameter moves. */
@Override
public int getParamChangeWeight() {
	// TODO test converge and tune value
	final int weight = this.pluginProposalWeight;
	return weight;
}
/**
 * Returns the cached log-likelihood. It is kept precomputed regardless of whether
 * {@code ext} is this plugin.
 */
@Override
public double logLikeModExtParamChange(Tree tree, ModelExtension ext) {
	return this.curLogLike;
}
/** Saves the current alignment and log-likelihood so a rejected move can be undone. */
@Override
public void beforeAlignChange(Tree tree, Vertex selectRoot) {
	this.oldLogLi = this.curLogLike;
	this.oldAlign = this.curAlign;
}
/**
 * Picks up the proposed leaf alignment from the tree state and recomputes the
 * structural log-likelihood over all columns.
 *
 * @return the new log-likelihood, also stored in {@code curLogLike}
 */
@Override
public double logLikeAlignChange(Tree tree, Vertex selectRoot) {
	this.curAlign = tree.getState().getLeafAlign();
	this.curLogLike = calcAllColumnContrib();
	return this.curLogLike;
}
/** On rejection, restores the alignment and log-likelihood saved before the move. */
@Override
public void afterAlignChange(Tree tree, Vertex selectRoot, boolean accepted) {
	if (!accepted) {
		this.curAlign = this.oldAlign;
		this.curLogLike = this.oldLogLi;
	}
}
/**
 * Saves the distance matrix, covariance, alignment, log-likelihood and the active
 * factorisation cache, so a rejected tree move can be rolled back.
 */
@Override
public void beforeTreeChange(Tree tree, Vertex nephew) {
	this.oldLogLi = this.curLogLike;
	this.oldAlign = this.curAlign;
	this.oldCovar = this.fullCovar;
	this.oldDist = this.distanceMatrix;
	if (localEpsilon) {
		this.oldMultiNormsLocal = this.multiNormsLocal;
	} else {
		this.oldMultiNorms = this.multiNorms;
	}
}
/**
 * Rebuilds the covariance for the proposed tree, picks up the leaf alignment and
 * recomputes the structural log-likelihood.
 *
 * @return the new log-likelihood, also stored in {@code curLogLike}
 */
@Override
public double logLikeTreeChange(Tree tree, Vertex nephew) {
	this.fullCovar = calcFullCovar(tree);
	this.curAlign = tree.getState().getLeafAlign();
	double logLike = calcAllColumnContrib();
	this.curLogLike = logLike;
	return logLike;
}
/**
 * Refreshes the structure-tree view if it is shown, whether or not the move was
 * accepted. On rejection, it also restores the state saved by {@code beforeTreeChange}.
 */
@Override
public void afterTreeChange(Tree tree, Vertex nephew, boolean accepted) {
	if (showStructTree)
		structTree.updateStructTree(tree.root, sigma2);
	if (!accepted) {
		this.distanceMatrix = this.oldDist;
		this.fullCovar = this.oldCovar;
		this.curAlign = this.oldAlign;
		this.curLogLike = this.oldLogLi;
		if (localEpsilon) {
			this.multiNormsLocal = this.oldMultiNormsLocal;
		} else {
			this.multiNorms = this.oldMultiNorms;
		}
	}
}
/**
 * Saves the active factorisation cache before a continuous-parameter move.
 * NOTE(review): fullCovar and curLogLike are not saved here (that code was commented
 * out); presumably the calling move saves them itself. Confirm.
 */
public void beforeContinuousParamChange(Tree tree) {
	if (!localEpsilon) {
		this.oldMultiNorms = this.multiNorms;
	} else {
		this.oldMultiNormsLocal = this.multiNormsLocal;
	}
}
/**
 * Rebuilds the covariance under the proposed parameter values and recomputes the
 * structural log-likelihood.
 *
 * @return the new log-likelihood, also stored in {@code curLogLike}
 */
public double logLikeContinuousParamChange(Tree tree) {
	this.fullCovar = calcFullCovar(tree);
	double logLike = calcAllColumnContrib();
	this.curLogLike = logLike;
	return logLike;
}
/**
 * On rejection, restores the factorisation cache saved before the move.
 * NOTE(review): fullCovar and curLogLike are NOT restored here (the restore code was
 * commented out); presumably the calling move restores them. Confirm, otherwise a
 * rejected move leaves the proposed covariance in place.
 */
public void afterContinuousParamChange(Tree tree, boolean accepted) {
	if (!accepted) {
		if (localEpsilon) {
			this.multiNormsLocal = this.oldMultiNormsLocal;
		} else {
			this.multiNorms = this.oldMultiNorms;
		}
	}
}
/** Edge-length changes are handled exactly like topology changes. */
@Override
public double logLikeEdgeLenChange(Tree tree, Vertex vertex) {
	return this.logLikeTreeChange(tree, vertex);
}
/** Edge-length changes save state exactly like topology changes. */
@Override
public void beforeEdgeLenChange(Tree tree, Vertex vertex) {
	this.beforeTreeChange(tree, vertex);
}
/** Edge-length changes are finalised or rolled back exactly like topology changes. */
@Override
public void afterEdgeLenChange(Tree tree, Vertex vertex, boolean accepted) {
	this.afterTreeChange(tree, vertex, accepted);
}
/** Indel parameters do not affect the structural likelihood; returns the cached value. */
@Override
public double logLikeIndelParamChange(Tree tree, Hmm hmm, McmcMove m) {
	return this.curLogLike;
}
/** Substitution parameters do not affect the structural likelihood; returns the cached value. */
@Override
public double logLikeSubstParamChange(Tree tree, SubstitutionModel model,
		int ind) {
	return this.curLogLike;
}
/**
 * Emission log-likelihood of one aligned column; delegates to {@link #columnContrib(int[])}.
 */
@Override
public double calcLogEm(int[] aligned) {
	final double contribution = columnContrib(aligned);
	return contribution;
}
/** Attaches the structure-trace plugin used by this extension. */
public void connectStructTrace(StructTrace structTrace) {
	this.structTrace = structTrace;
}
/** Attaches the structure-tree visualiser used by this extension. */
public void connectStructTree(StructTreeVisualizer structTree) {
	this.structTree = structTree;
}
/** Attaches the RMSD-trace plugin and activates it iff {@code printRmsd} is set. */
public void connectRmsdTrace(RmsdTrace _rmsdTrace) {
	this.rmsdTrace = _rmsdTrace;
	this.rmsdTrace.active = this.printRmsd;
}
/**
 * This method is called when structures are added in the GUI.
 * <p>
 * For {@link ProteinSkeletons} data it:
 * <ul>
 * <li>enables and selects the toggle button and activates the plugin;</li>
 * <li>switches to local-sigma mode and activates the associated trace/tree plugins;</li>
 * <li>turns on localEpsilon if B-factors are present.</li>
 * </ul>
 * Other data types are ignored.
 */
@Override
public void dataAdded(File file, DataType data) {
	if (!(data instanceof ProteinSkeletons))
		return;
	ProteinSkeletons skeletons = (ProteinSkeletons) data;
	myButton.setEnabled(true);
	myButton.setSelected(true);
	setActive(true);
	selectable = true;
	// local-sigma mode is the GUI default
	globalSigma = false;
	// To hard-code a fixed sigma in GUI mode instead:
	//   fixedSigma2 = true; globalSigma = true; fixedSigma2Value = 0.0;
	//   addToFilenameExtension("sigma2_"+fixedSigma2Value);
	// must run after globalSigma is set, since it derives showStructTree from it
	activateAssociatedPlugins();
	// B-factor information in the .coor file switches localEpsilon on by default
	if (skeletons.bFactors != null)
		localEpsilon = true;
}
/**
 * Called when this plugin is activated in GUI mode, either by toggling the
 * StructAlign button or by reading in protein structure(s) from file.
 * It mirrors this plugin's {@code active} flag onto:
 * <ul>
 * <li>the structure trace (selected, active and written in postprocessing);</li>
 * <li>the structure tree, only when sigma is not global;</li>
 * <li>the RMSD trace, which is shown but never written to file.</li>
 * </ul>
 */
private void activateAssociatedPlugins() {
	structTrace.active = active;
	structTrace.postprocessWrite = active;
	structTrace.setSelected(active);
	// the structure tree is only meaningful with per-branch sigmas
	showStructTree = !globalSigma;
	structTree.active = active && showStructTree;
	structTree.setSelected(active && showStructTree);
	structTree.postprocessWrite = active && showStructTree;
	// RMSD trace is on by default in the GUI but is not written to file
	rmsdTrace.active = active;
	rmsdTrace.setSelected(active);
	rmsdTrace.postprocessWrite = false;
}
// </StructAlign>
}