package beast.app.beauti;
import java.awt.Color;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.List;
import javax.swing.Box;
import javax.swing.JCheckBox;
import javax.swing.JComponent;
import javax.swing.JTextField;
import javax.swing.event.DocumentEvent;
import javax.swing.event.DocumentListener;
import beast.app.draw.BEASTObjectInputEditor;
import beast.app.draw.InputEditor;
import beast.app.draw.IntegerInputEditor;
import beast.app.draw.ParameterInputEditor;
import beast.app.draw.SmallLabel;
import beast.core.BEASTInterface;
import beast.core.Distribution;
import beast.core.Input;
import beast.core.MCMC;
import beast.core.Operator;
import beast.core.parameter.IntegerParameter;
import beast.core.parameter.RealParameter;
import beast.core.util.CompoundDistribution;
import beast.evolution.alignment.Alignment;
import beast.evolution.likelihood.GenericTreeLikelihood;
import beast.evolution.operators.DeltaExchangeOperator;
import beast.evolution.sitemodel.SiteModel;
import beast.evolution.sitemodel.SiteModelInterface;
public class SiteModelInputEditor extends BEASTObjectInputEditor {
/** Serialization version identifier. */
private static final long serialVersionUID = 1L;
/** Editor for the gamma category count input; assigned in createGammaCategoryCountEditor(). */
IntegerInputEditor categoryCountEditor;
/** Text entry obtained from categoryCountEditor; a document listener on it triggers processEntry2(). */
JTextField categoryCountEntry;
/** Editor for the gamma shape parameter; assigned in createShapeEditor() and made visible only when the category count is at least 2. */
InputEditor gammaShapeEditor;
/** Editor for the proportion-invariant parameter; assigned in createProportionInvariantEditor(). */
ParameterInputEditor inVarEditor;
// vars for dealing with mean-rate delta exchange operator
/** Check box labelled "Fix mean substitution rate"; created in init(). */
JCheckBox fixMeanRatesCheckBox;
/** Delta exchange operator looked up (or created and registered) in init() under the ID "FixMeanMutationRatesOperator". */
DeltaExchangeOperator operator;
/** Validation indicator placed in the same row as the check box; created in init() and updated by setUpOperator(). */
protected SmallLabel fixMeanRatesValidateLabel;
/**
 * Creates a site model editor for the given BEAUti document.
 *
 * @param doc the document, handed straight to the superclass constructor
 */
public SiteModelInputEditor(BeautiDoc doc) {
super(doc);
}
/**
 * Reports the input type this editor is registered for.
 *
 * @return {@code SiteModelInterface.Base.class}
 */
@Override
public Class<?> type() {
return SiteModelInterface.Base.class;
}
/**
 * Initialises the editor through the superclass and then adds a row holding the
 * "Fix mean substitution rate" check box and its validation label.
 *
 * <p>The delta exchange operator with ID {@code FixMeanMutationRatesOperator} is
 * taken from the document's plugin map; when absent, a new one is created and
 * registered with the document. The check box starts selected exactly when that
 * operator is already among the MCMC operators. The method ends by calling
 * {@link #setUpOperator()}.
 *
 * @param input          passed through to {@code super.init}
 * @param beastObject    passed through to {@code super.init}
 * @param itemNr         passed through to {@code super.init}
 * @param isExpandOption passed through to {@code super.init}
 * @param addButtons     passed through to {@code super.init}
 */
@Override
public void init(Input<?> input, BEASTInterface beastObject, int itemNr,
        ExpandOption isExpandOption, boolean addButtons) {
    // The check box is created and configured before delegating to the superclass.
    fixMeanRatesCheckBox = new JCheckBox("Fix mean substitution rate");
    fixMeanRatesCheckBox.setName("FixMeanMutationRate");
    fixMeanRatesCheckBox.setEnabled(!doc.autoUpdateFixMeanSubstRate);

    super.init(input, beastObject, itemNr, isExpandOption, addButtons);

    final List<Operator> mcmcOperators = ((MCMC) doc.mcmc.get()).operatorsInput.get();

    fixMeanRatesCheckBox.addActionListener(event -> {
        final JCheckBox source = (JCheckBox) event.getSource();
        doFixMeanRates(source.isSelected());
        if (source.isSelected()) {
            // set up relative weights
            setUpOperator();
        }
    });

    // Reuse the document-wide operator when there is one, otherwise create and register it.
    operator = (DeltaExchangeOperator) doc.pluginmap.get("FixMeanMutationRatesOperator");
    if (operator == null) {
        operator = new DeltaExchangeOperator();
        try {
            operator.setID("FixMeanMutationRatesOperator");
            operator.initByName("weight", 2.0, "delta", 0.75);
        } catch (Throwable ignored) {
            // ignore initAndValidate exception
        }
        doc.addPlugin(operator);
    }
    fixMeanRatesCheckBox.setSelected(mcmcOperators.contains(operator));

    fixMeanRatesValidateLabel = new SmallLabel("x", Color.GREEN);
    fixMeanRatesValidateLabel.setVisible(false);

    // Row layout: check box, stretchable glue, validation label.
    final Box checkBoxRow = Box.createHorizontalBox();
    checkBoxRow.add(fixMeanRatesCheckBox);
    checkBoxRow.add(Box.createHorizontalGlue());
    checkBoxRow.add(fixMeanRatesValidateLabel);

    if (doc.alignments.size() >= 1 && operator != null) {
        final JComponent firstChild = (JComponent) getComponents()[0];
        firstChild.add(checkBoxRow);
    }

    setUpOperator();
}
// @Override
// public Class<?> [] types() {
// Class<?>[] types = {SiteModel.class, SiteModel.Base.class};
// return types;
// }
/**
 * Adds the delta exchange operator to, or removes it from, the MCMC operator list.
 *
 * @param averageRates {@code true} to make sure the operator is in the list
 *                     (added at most once); {@code false} to remove it, hide
 *                     the validation label and repaint
 */
private void doFixMeanRates(boolean averageRates) {
    final List<Operator> mcmcOperators = ((MCMC) doc.mcmc.get()).operatorsInput.get();

    if (!averageRates) {
        // disconnect DeltaExchangeOperator
        mcmcOperators.remove(operator);
        fixMeanRatesValidateLabel.setVisible(false);
        repaint();
        return;
    }

    // connect DeltaExchangeOperator
    final boolean alreadyConnected = mcmcOperators.contains(operator);
    if (!alreadyConnected) {
        mcmcOperators.add(operator);
    }
}
/**
 * Creates the editor for the site model's {@code muParameterInput}.
 * The entry field is disabled while {@code doc.autoUpdateFixMeanSubstRate} is set.
 *
 * @return the initialised parameter editor
 */
public InputEditor createMutationRateEditor() {
    final SiteModel siteModel = (SiteModel) m_input.get();
    final Input<?> rateInput = siteModel.muParameterInput;

    final ParameterInputEditor rateEditor = new ParameterInputEditor(doc);
    rateEditor.init(rateInput, siteModel, -1, ExpandOption.FALSE, true);

    final boolean editable = !doc.autoUpdateFixMeanSubstRate;
    rateEditor.getEntry().setEnabled(editable);
    return rateEditor;
}
/**
 * Creates the editor for the site model's {@code gammaCategoryCount} input.
 *
 * <p>The editor's validation additionally shows an orange warning when the
 * category count is below 2 while the shape parameter is flagged as estimated.
 * Every change to the entry's document is forwarded to {@link #processEntry2()}.
 *
 * @return the initialised and validated category count editor
 */
public InputEditor createGammaCategoryCountEditor() {
    final SiteModel siteModel = (SiteModel) m_input.get();
    final Input<?> countInput = siteModel.gammaCategoryCount;

    categoryCountEditor = new IntegerInputEditor(doc) {
        private static final long serialVersionUID = 1L;

        @Override
        public void validateInput() {
            super.validateInput();
            final SiteModel edited = (SiteModel) m_beastObject;
            final boolean gammaOff = edited.gammaCategoryCount.get() < 2;
            if (gammaOff && edited.shapeParameterInput.get().isEstimatedInput.get()) {
                m_validateLabel.m_circleColor = Color.orange;
                m_validateLabel.setToolTipText("shape parameter is estimated, but not used");
                m_validateLabel.setVisible(true);
            }
        }
    };
    categoryCountEditor.init(countInput, siteModel, -1, ExpandOption.FALSE, true);

    categoryCountEntry = categoryCountEditor.getEntry();
    // Any kind of document change re-evaluates the entry.
    categoryCountEntry.getDocument().addDocumentListener(new DocumentListener() {
        @Override
        public void insertUpdate(DocumentEvent event) {
            processEntry2();
        }

        @Override
        public void removeUpdate(DocumentEvent event) {
            processEntry2();
        }

        @Override
        public void changedUpdate(DocumentEvent event) {
            processEntry2();
        }
    });

    categoryCountEditor.validateInput();
    return categoryCountEditor;
}
/**
 * Reacts to an edit of the gamma category count entry.
 *
 * <p>When the entry parses as an integer, the shape editor is shown for counts
 * of 2 or more and hidden otherwise. If that changes the visibility, the shape
 * parameter's estimate flag is set to match (estimated when shown, not
 * estimated when hidden) and, when the shape editor's component is a
 * {@code ParameterInputEditor}, its "estimate" check box is synchronised.
 *
 * <p>Nothing happens when the text is not an integer, or when the shape editor
 * has not been created yet.
 */
void processEntry2() {
    if (gammaShapeEditor == null) {
        // The shape editor is only assigned by createShapeEditor(); until then
        // there is nothing to show or hide.
        return;
    }

    final int categoryCount;
    try {
        categoryCount = Integer.parseInt(categoryCountEntry.getText());
    } catch (NumberFormatException ignored) {
        // Incomplete or non-numeric entry: leave everything as it is.
        return;
    }

    final RealParameter shapeParameter = ((SiteModel) m_input.get()).shapeParameterInput.get();
    final boolean useGamma = categoryCount >= 2;
    final boolean wasVisible = gammaShapeEditor.getComponent().isVisible();

    if (useGamma != wasVisible) {
        // we are flipping between no gamma and gamma heterogeneity across sites,
        // so make the estimate flag on the shape parameter follow
        shapeParameter.isEstimatedInput.setValue(useGamma, shapeParameter);
    }

    // getComponent() is available on InputEditor itself, so the editor is not
    // downcast; only the returned component is inspected.
    final Object component = gammaShapeEditor.getComponent();
    if (component instanceof ParameterInputEditor) {
        final ParameterInputEditor shapeEditor = (ParameterInputEditor) component;
        shapeEditor.m_isEstimatedBox.setSelected(shapeParameter.isEstimatedInput.get());
    }

    gammaShapeEditor.getComponent().setVisible(useGamma);
    repaint();
}
/**
 * Creates the editor for the site model's {@code shapeParameterInput} through
 * the document's input editor factory and stores it in {@code gammaShapeEditor}.
 * The editor's component is visible only when the gamma category count is at least 2.
 *
 * @return the shape parameter editor
 * @throws NoSuchMethodException     declared for the factory call
 * @throws SecurityException         declared for the factory call
 * @throws ClassNotFoundException    declared for the factory call
 * @throws InstantiationException    declared for the factory call
 * @throws IllegalAccessException    declared for the factory call
 * @throws IllegalArgumentException  declared for the factory call
 * @throws InvocationTargetException declared for the factory call
 */
public InputEditor createShapeEditor() throws NoSuchMethodException, SecurityException, ClassNotFoundException, InstantiationException, IllegalAccessException, IllegalArgumentException, InvocationTargetException {
    final Input<?> shapeInput = ((SiteModel) m_input.get()).shapeParameterInput;
    final BEASTInterface owner = (BEASTInterface) m_input.get();

    gammaShapeEditor = doc.getInputEditorFactory().createInputEditor(shapeInput, owner, doc);

    final boolean gammaInUse = ((SiteModel) m_input.get()).gammaCategoryCount.get() >= 2;
    gammaShapeEditor.getComponent().setVisible(gammaInUse);
    return gammaShapeEditor;
}
/**
 * Creates the editor for the site model's {@code invarParameterInput}.
 *
 * <p>Validation flags two conditions before falling back to the superclass
 * check: an estimated parameter whose first value is not above zero, and a
 * first value outside the range 0 (inclusive) to 1 (exclusive). This editor is
 * registered as a validation listener on the new editor.
 *
 * @return the initialised proportion-invariant editor
 */
public InputEditor createProportionInvariantEditor() {
    final Input<?> invarInput = ((SiteModel) m_input.get()).invarParameterInput;

    inVarEditor = new ParameterInputEditor(doc) {
        private static final long serialVersionUID = 1L;

        @Override
        public void validateInput() {
            // m_input here is the inner editor's own input, i.e. the parameter being edited.
            final RealParameter proportion = (RealParameter) m_input.get();

            final boolean estimatedButNotPositive =
                    proportion.isEstimatedInput.get() && proportion.valuesInput.get().get(0) <= 0.0;
            if (estimatedButNotPositive) {
                m_validateLabel.setVisible(true);
                m_validateLabel.setToolTipText("<html><p>Proportion invariant should be non-zero when estimating</p></html>");
                return;
            }

            final boolean outOfRange =
                    proportion.valuesInput.get().get(0) < 0.0 || proportion.valuesInput.get().get(0) >= 1.0;
            if (outOfRange) {
                m_validateLabel.setVisible(true);
                m_validateLabel.setToolTipText("<html><p>Proportion invariant should be from 0 to 1 (exclusive 1)</p></html>");
                return;
            }

            super.validateInput();
        }
    };

    final BEASTInterface owner = (BEASTInterface) m_input.get();
    inVarEditor.init(invarInput, owner, -1, ExpandOption.FALSE, true);
    inVarEditor.addValidationListener(this);
    return inVarEditor;
}
/**
 * Rebuilds the parameter list and weights of the document's
 * {@code FixMeanMutationRatesOperator} from the current likelihood.
 *
 * <p>Every tree likelihood whose site model is a {@link SiteModel} with an
 * estimated {@code muParameter} contributes that parameter to the operator.
 * Its weight is the alignment's site count, minus the excluded pattern count
 * when the alignment is ascertained. Partitions sharing a rate (same parameter
 * ID) have their weights summed. Distributions in the likelihood that are not
 * tree likelihoods are skipped rather than aborting the whole scan.
 *
 * <p>The weights are stored in a non-estimated {@code IntegerParameter} set on
 * the operator's {@code parameterWeightsInput}.
 *
 * @param doc the BEAUti document holding the operator and the likelihood
 * @return {@code true} if at least one estimated substitution rate was found;
 *         {@code false} otherwise, including when the operator or likelihood is
 *         missing or an exception occurs while inspecting the model
 */
public static boolean customConnector(BeautiDoc doc) {
    try {
        final DeltaExchangeOperator operator =
                (DeltaExchangeOperator) doc.pluginmap.get("FixMeanMutationRatesOperator");
        if (operator == null) {
            return false;
        }

        final List<RealParameter> parameters = operator.parameterInput.get();
        parameters.clear();

        final CompoundDistribution likelihood = (CompoundDistribution) doc.pluginmap.get("likelihood");
        if (likelihood == null) {
            // Nothing to scan; the operator is left with an empty parameter list.
            return false;
        }

        boolean hasOneEstimatedRate = false;
        final List<String> rateIDs = new ArrayList<>();
        final List<Integer> weights = new ArrayList<>();

        for (Distribution d : likelihood.pDistributions.get()) {
            if (!(d instanceof GenericTreeLikelihood)) {
                // Not a tree likelihood: it carries no site model to inspect.
                continue;
            }
            final GenericTreeLikelihood treelikelihood = (GenericTreeLikelihood) d;

            final Alignment data = treelikelihood.dataInput.get();
            int weight = data.getSiteCount();
            if (data.isAscertained) {
                weight -= data.getExcludedPatternCount();
            }

            if (!(treelikelihood.siteModelInput.get() instanceof SiteModel)) {
                continue;
            }
            final SiteModel siteModel = (SiteModel) treelikelihood.siteModelInput.get();
            final RealParameter mutationRate = siteModel.muParameterInput.get();
            if (!mutationRate.isEstimatedInput.get()) {
                continue;
            }

            hasOneEstimatedRate = true;
            final int k = rateIDs.indexOf(mutationRate.getID());
            if (k == -1) {
                // First partition using this rate.
                parameters.add(mutationRate);
                weights.add(weight);
                rateIDs.add(mutationRate.getID());
            } else {
                // Rate shared with an earlier partition: accumulate its weight.
                weights.set(k, weights.get(k) + weight);
            }
        }

        final IntegerParameter weightParameter;
        if (weights.isEmpty()) {
            weightParameter = new IntegerParameter();
        } else {
            // Space-separated values, each followed by a space, as in the original format.
            final StringBuilder weightString = new StringBuilder();
            for (int k : weights) {
                weightString.append(k).append(' ');
            }
            weightParameter = new IntegerParameter(weightString.toString());
            weightParameter.setID("weightparameter");
        }
        weightParameter.isEstimatedInput.setValue(false, weightParameter);
        operator.parameterWeightsInput.setValue(weightParameter, operator);
        return hasOneEstimatedRate;
    } catch (Exception ignored) {
        // Best effort, as before: a model that cannot be inspected is reported
        // as having no estimated rate instead of failing the caller.
    }
    return false;
}
/**
 * Sets up relative weights and parameter input for the operator, then
 * refreshes the validation label next to the check box.
 *
 * <p>Steps: run {@link #customConnector(BeautiDoc)}; when
 * {@code doc.autoUpdateFixMeanSubstRate} is set, make the check box and the
 * operator list follow its result; compare the first value of every estimated
 * substitution rate; finally show a red label when no rate is estimated, an
 * orange label for one of three warnings, or hide the label.
 * Any exception is printed and otherwise ignored.
 **/
public void setUpOperator() {
    boolean ratesAllEqual = true;
    try {
        final boolean anyRateEstimated = customConnector(doc);
        if (doc.autoUpdateFixMeanSubstRate) {
            fixMeanRatesCheckBox.setSelected(anyRateEstimated);
            doFixMeanRates(anyRateEstimated);
        }

        try {
            // A negative value means "no estimated rate seen yet".
            double referenceRate = -1;
            final CompoundDistribution likelihood = (CompoundDistribution) doc.pluginmap.get("likelihood");
            for (Distribution distribution : likelihood.pDistributions.get()) {
                final GenericTreeLikelihood treeLikelihood = (GenericTreeLikelihood) distribution;
                if (!(treeLikelihood.siteModelInput.get() instanceof SiteModel)) {
                    continue;
                }
                final SiteModel siteModel = (SiteModel) treeLikelihood.siteModelInput.get();
                final RealParameter rate = siteModel.muParameterInput.get();
                if (!rate.isEstimatedInput.get()) {
                    continue;
                }
                if (referenceRate < 0) {
                    referenceRate = rate.valuesInput.get().get(0);
                } else if (Math.abs(referenceRate - rate.valuesInput.get().get(0)) > 1e-10) {
                    ratesAllEqual = false;
                }
            }
        } catch (Exception ignored) {
            // The equality check is best effort; keep whatever was determined so far.
        }

        final List<RealParameter> estimatedRates = operator.parameterInput.get();

        if (!fixMeanRatesCheckBox.isSelected()) {
            fixMeanRatesValidateLabel.setVisible(false);
            repaint();
            return;
        }

        if (estimatedRates.size() == 0) {
            fixMeanRatesValidateLabel.setVisible(true);
            fixMeanRatesValidateLabel.m_circleColor = Color.red;
            fixMeanRatesValidateLabel.setToolTipText("The model is invalid: At least one substitution rate should be estimated.");
            repaint();
            return;
        }

        // Pick the first applicable warning, if any.
        final String warning;
        if (!ratesAllEqual) {
            warning = "Not all substitution rates are equal. Are you sure this is what you want?";
        } else if (estimatedRates.size() == 1) {
            warning = "At least 2 clock models should have their rate estimated";
        } else if (estimatedRates.size() < doc.getPartitions("SiteModel").size()) {
            warning = "Not all partitions have their rate estimated";
        } else {
            warning = null;
        }

        if (warning == null) {
            fixMeanRatesValidateLabel.setVisible(false);
        } else {
            fixMeanRatesValidateLabel.setVisible(true);
            fixMeanRatesValidateLabel.m_circleColor = Color.orange;
            fixMeanRatesValidateLabel.setToolTipText(warning);
        }
        repaint();
    } catch (Exception e) {
        e.printStackTrace();
    }
}
}