package dr.evomodel.sitemodel; import dr.inference.model.*; import dr.inference.distribution.ParametricDistributionModel; /** * @author Chieh-Hsi Wu * * This class models the relative loci rates using the given discritized parametric distribution. */ public class DiscretizedLociRates extends AbstractModel { private CompoundParameter lociRates; private Parameter rateCategoryParameter; private ParametricDistributionModel distrModel; private double normalizeRateTo; private double[] rates; private boolean normalize; private int categoryCount; private double scaleFactor; private boolean completeSetup; public DiscretizedLociRates( CompoundParameter lociRates, Parameter rateCategoryParameter, ParametricDistributionModel model, boolean normalize, double normalizeLociRateTo, int categoryCount) { super("DiscretizedLociRatesModel"); this.lociRates = lociRates; this.rateCategoryParameter = rateCategoryParameter; //Force the boundaries of rateCategoryParameter to match the category count Parameter.DefaultBounds bound = new Parameter.DefaultBounds(categoryCount - 1, 0, rateCategoryParameter.getDimension()); this.rateCategoryParameter.addBounds(bound); this.distrModel = model; this.normalizeRateTo = normalizeLociRateTo; this.normalize = normalize; this.categoryCount = categoryCount; rates = new double[categoryCount]; completeSetup = true; setupRates(); addModel(distrModel); addVariable(this.rateCategoryParameter); } private void setupRates(){ if(completeSetup){ double categoryIntervalSize = 1.0/categoryCount; for(int i = 0; i < categoryCount; i++){ rates[i]= distrModel.quantile((i+0.5)*categoryIntervalSize); } } if(normalize){ computeFactor(); } completeSetup = false; int lociCount = rateCategoryParameter.getDimension(); for(int i = 0; i < lociCount; i ++){ lociRates.setParameterValue(i,rates[(int)rateCategoryParameter.getParameterValue(i)]*scaleFactor); } } public void handleModelChangedEvent(Model model, Object object, int index) { if (model == distrModel) { completeSetup = true; setupRates(); //System.out.println("speed investigation 1"); fireModelChanged(); }else if (model == rateCategoryParameter) { //System.out.println("speed investigation 2"); setupRates(); fireModelChanged(null, index); } } protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { //System.out.println("speed investigation 3"); setupRates(); fireModelChanged(null, index); } protected void storeState() { } protected void acceptState() { } protected void restoreState() { //setupRates(); } private void computeFactor(){ double sumRates = 0.0; int lociCount = rateCategoryParameter.getDimension(); for(int i = 0; i < lociCount; i++){ sumRates += rates[(int)rateCategoryParameter.getParameterValue(i)]; } scaleFactor = normalizeRateTo/(sumRates/lociCount); } }