package statalign.model.ext.plugins.structalign;
import java.util.List;
import java.util.ArrayList;
import statalign.base.Utils;
import statalign.mcmc.GammaPrior;
import statalign.mcmc.LogNormalPrior;
import statalign.mcmc.McmcMove;
import statalign.mcmc.ParameterInterface;
import statalign.mcmc.PriorDistribution;
import statalign.mcmc.ProposalDistribution;
import statalign.model.ext.plugins.StructAlign;
public class HierarchicalContinuousPositiveStructAlignMove extends ContinuousPositiveStructAlignMove {
private List<ContinuousPositiveStructAlignMove> children = new ArrayList<ContinuousPositiveStructAlignMove>();
public PriorDistribution<Double> hierarchicalPrior;
private boolean onlySampleIfAtLeastTwoChildrenNotFixed = false;
boolean allowSpikeSelection = false;
public void disallowSpikeSelection() { allowSpikeSelection = false; }
public void allowSpikeSelection() { allowSpikeSelection = true; }
double[] spikeParams;
public HierarchicalContinuousPositiveStructAlignMove (StructAlign s,
ParameterInterface p, PriorDistribution<Double> pr,
ProposalDistribution<Double> prop, String n) {
super(s,p,pr,prop,n);
//hierarchicalPrior = new GammaPrior(((StructAlign) owner).nu,((StructAlign) owner).nu / ((StructAlign) owner).sigma2Hier);
hierarchicalPrior = new LogNormalPrior(Math.log(((StructAlign) owner).sigma2Hier),((StructAlign) owner).nu);
//hierarchicalPrior = new GammaPrior(1,((StructAlign) owner).nu / ((StructAlign) owner).sigma2Hier);
// TODO Abstract this somewhat
}
public void addChildMove(ContinuousPositiveStructAlignMove child) {
children.add(child);
//child.addParent(this);
}
public double getLogChildDensity(ContinuousPositiveStructAlignMove child) {
return hierarchicalPrior.logDensity(child.getParam().get());
}
public void onlySampleIfAtLeastTwoChildrenNotFixed() {
onlySampleIfAtLeastTwoChildrenNotFixed = true;
}
public void alwaysSample() {
onlySampleIfAtLeastTwoChildrenNotFixed = false;
}
public void setSpikeParams(double[] params) {
spikeParams = params;
allowSpikeSelection = true;
for (ContinuousPositiveStructAlignMove child : children) {
//System.out.print(child.fixedToParent+" ");
//double fixProb = spikeParams[0]/(spikeParams[0]+spikeParams[1]);
//child.fixedToParent = (Utils.generator.nextDouble() < fixProb);
child.fixedToParent = true;
//System.out.print("("+child.fixedToParent+") ");
}
}
int countFixedToParent() {
int n = 0;
for (ContinuousPositiveStructAlignMove child : children) {
if (child.fixedToParent) n++;
//System.out.print(child.fixedToParent+" ");
}
//System.out.println();
return n;
}
int countChildren() {
return children.size();
}
@Override
public double proposal(Object externalState) {
if (onlySampleIfAtLeastTwoChildrenNotFixed && (countChildren() - countFixedToParent() < 2)){
if (children.get(0).parentPriors.get(0) != this) {
// Then this must be a nuMove
// and in this case we keep nu the same,
// because all (or all but one) of the local sigmas are fixed at
// the global.
autoTune = false;
proposalWidthControlVariable = 0.1;
return 0;
}
}
autoTune = true;
double logProposalDensity = super.proposal(externalState);
for (int i=0; i<children.size(); i++) {
if(i == tree.root.index) {
continue;
}
if (children.get(i).fixedToParent && children.get(i).parentPriors.get(0) == this) {
children.get(i).copyState(externalState);
logProposalDensity -= children.get(i).logPriorDensity(externalState);
}
else logProposalDensity -= hierarchicalPrior.logDensity(children.get(i).getParam().get());
}
//hierarchicalPrior = new GammaPrior(((StructAlign) owner).nu,((StructAlign) owner).nu / ((StructAlign) owner).sigma2Hier);
hierarchicalPrior = new LogNormalPrior(Math.log(((StructAlign) owner).sigma2Hier),((StructAlign) owner).nu);
//hierarchicalPrior = new GammaPrior(1,((StructAlign) owner).nu / ((StructAlign) owner).sigma2Hier);
// TODO Abstract this somewhat
for (int i=0; i<children.size(); i++) {
if(i == tree.root.index) {
continue;
}
if (children.get(i).fixedToParent && children.get(i).parentPriors.get(0) == this) {
children.get(i).setParam(param.get());
logProposalDensity += children.get(i).logPriorDensity(externalState);
}
else logProposalDensity += hierarchicalPrior.logDensity(children.get(i).getParam().get());
//children.get(i).setPlottable();
}
return logProposalDensity;
}
@Override
public void restoreState(Object externalState) {
super.restoreState(externalState);
for (int i=0; i<children.size(); i++) {
if(i == tree.root.index) {
continue;
}
if (children.get(i).fixedToParent && children.get(i).parentPriors.get(0) == this) {
children.get(i).restoreState(externalState);
}
}
//hierarchicalPrior = new GammaPrior(((StructAlign) owner).nu,((StructAlign) owner).nu / ((StructAlign) owner).sigma2Hier);
hierarchicalPrior = new LogNormalPrior(Math.log(((StructAlign) owner).sigma2Hier),((StructAlign) owner).nu);
//hierarchicalPrior = new GammaPrior(1,((StructAlign) owner).nu / ((StructAlign) owner).sigma2Hier);
// TODO Abstract this somewhat
}
}