package beast.app.seqgen;
import java.io.FileNotFoundException;
import java.io.PrintStream;
import beast.core.Description;
import beast.core.Input;
import beast.core.Input.Validate;
import beast.evolution.alignment.Alignment;
import beast.evolution.alignment.Sequence;
import beast.evolution.branchratemodel.BranchRateModel;
import beast.evolution.datatype.DataType;
import beast.evolution.sitemodel.SiteModel;
import beast.evolution.tree.Node;
import beast.evolution.tree.Tree;
import beast.util.Randomizer;
import beast.util.XMLProducer;
/**
* @author remco@cs.waikato.ac.nz
*/
@Description("An alignment containing sequences randomly generated using a"
+ "given site model down a given tree.")
public class SimulatedAlignment extends Alignment {
final public Input<Alignment> m_data = new Input<>("data", "alignment data which specifies datatype and taxa of the beast.tree", Validate.REQUIRED);
final public Input<Tree> m_treeInput = new Input<>("tree", "phylogenetic beast.tree with sequence data in the leafs", Validate.REQUIRED);
final public Input<SiteModel.Base> m_pSiteModelInput = new Input<>("siteModel", "site model for leafs in the beast.tree", Validate.REQUIRED);
final public Input<BranchRateModel.Base> m_pBranchRateModelInput = new Input<>("branchRateModel",
"A model describing the rates on the branches of the beast.tree.");
final public Input<Integer> m_sequenceLengthInput = new Input<>("sequencelength", "nr of samples to generate (default 1000).", 1000);
final public Input<String> m_outputFileNameInput = new Input<>(
"outputFileName",
"If provided, simulated alignment is additionally written to this file.");
/**
* nr of samples to generate *
*/
protected int m_sequenceLength;
/**
* tree used for generating samples *
*/
protected Tree m_tree;
/**
* site model used for generating samples *
*/
protected SiteModel.Base m_siteModel;
/**
* branch rate model used for generating samples *
*/
protected BranchRateModel m_branchRateModel;
/**
* nr of categories in site model *
*/
int m_categoryCount;
/**
* nr of states in site model *
*/
int m_stateCount;
/**
* name of output file *
*/
String m_outputFileName;
/**
* an array used to transfer transition probabilities
*/
protected double[][] m_probabilities;
public SimulatedAlignment() {
// Override the sequence input requirement.
sequenceInput.setRule(Validate.OPTIONAL);
}
@Override
public void initAndValidate() {
m_tree = m_treeInput.get();
m_siteModel = m_pSiteModelInput.get();
m_branchRateModel = m_pBranchRateModelInput.get();
m_sequenceLength = m_sequenceLengthInput.get();
m_stateCount = m_data.get().getMaxStateCount();
m_categoryCount = m_siteModel.getCategoryCount();
m_probabilities = new double[m_categoryCount][m_stateCount * m_stateCount];
m_outputFileName = m_outputFileNameInput.get();
sequenceInput.get().clear();
simulate();
// Write simulated alignment to disk if requested:
if (m_outputFileName != null) {
PrintStream pstream;
try {
pstream = new PrintStream(m_outputFileName);
pstream.println(new XMLProducer().toRawXML(this));
pstream.close();
} catch (FileNotFoundException e) {
throw new IllegalArgumentException(e.getMessage());
}
}
super.initAndValidate();
}
/**
* Convert integer representation of sequence into a Sequence
*
* @param seq integer representation of the sequence
* @param node used to determine taxon for sequence
* @return Sequence
*/
Sequence intArray2Sequence(int[] seq, Node node) {
DataType dataType = m_data.get().getDataType();
String seqString = dataType.state2string(seq);
// StringBuilder seq = new StringBuilder();
// String map = m_data.get().getMap();
// if (map != null) {
// for (int i = 0; i < m_sequenceLength; i++) {
// seq.append(map.charAt(seq[i]));
// }
// } else {
// for (int i = 0; i < m_sequenceLength-1; i++) {
// seq.append(seq[i] + ",");
// }
// seq.append(seq[m_sequenceLength-1] + "");
// }
String taxon = m_data.get().getTaxaNames().get(node.getNr());
return new Sequence(taxon, seqString);
} // intArray2Sequence
/**
* perform the actual sequence generation
*
* @return alignment containing randomly generated sequences for the nodes in the
* leaves of the tree
*/
public void simulate() {
Node root = m_tree.getRoot();
double[] categoryProbs = m_siteModel.getCategoryProportions(root);
int[] category = new int[m_sequenceLength];
for (int i = 0; i < m_sequenceLength; i++) {
category[i] = Randomizer.randomChoicePDF(categoryProbs);
}
double[] frequencies = m_siteModel.getSubstitutionModel().getFrequencies();
int[] seq = new int[m_sequenceLength];
for (int i = 0; i < m_sequenceLength; i++) {
seq[i] = Randomizer.randomChoicePDF(frequencies);
}
//alignment.setDataType(m_siteModel.getFrequencyModel().getDataType());
traverse(root, seq, category);
} // simulate
/**
* recursively walk through the tree top down, and add sequence to alignment whenever
* a leave node is reached.
*
* @param node reference to the current node, for which we visit all children
* @param parentSequence randomly generated sequence of the parent node
* @param category array of categories for each of the sites
* @param alignment
*/
void traverse(Node node, int[] parentSequence, int[] category) {
for (int childIndex = 0; childIndex < 2; childIndex++) {
Node child = (childIndex == 0 ? node.getLeft() : node.getRight());
for (int i = 0; i < m_categoryCount; i++) {
getTransitionProbabilities(m_tree, child, i, m_probabilities[i]);
}
int[] seq = new int[m_sequenceLength];
double[] cProb = new double[m_stateCount];
for (int i = 0; i < m_sequenceLength; i++) {
System.arraycopy(m_probabilities[category[i]], parentSequence[i] * m_stateCount, cProb, 0, m_stateCount);
seq[i] = Randomizer.randomChoicePDF(cProb);
}
if (child.isLeaf()) {
sequenceInput.setValue(intArray2Sequence(seq, child), this);
} else {
traverse(child, seq, category);
}
}
} // traverse
/**
* get transition probability matrix for particular rate category *
*/
void getTransitionProbabilities(Tree tree, Node node, int rateCategory, double[] probs) {
Node parent = node.getParent();
double branchRate = (m_branchRateModel == null ? 1.0 : m_branchRateModel.getRateForBranch(node));
branchRate *= m_siteModel.getRateForCategory(rateCategory, node);
// Get the operational time of the branch
//final double branchTime = branchRate * (parent.getHeight() - node.getHeight());
//if (branchTime < 0.0) {
// throw new RuntimeException("Negative branch length: " + branchTime);
//}
//double branchLength = m_siteModel.getRateForCategory(rateCategory) * branchTime;
// // TODO Hack until SiteRateModel issue is resolved
// if (m_siteModel.getSubstitutionModel() instanceof SubstitutionEpochModel) {
// ((SubstitutionEpochModel)m_siteModel.getSubstitutionModel()).getTransitionProbabilities(tree.getNodeHeight(node),
// tree.getNodeHeight(parent),branchLength, probs);
// return;
// }
//m_siteModel.getSubstitutionModel().getTransitionProbabilities(branchLength, probs);
m_siteModel.getSubstitutionModel().getTransitionProbabilities(node, parent.getHeight(), node.getHeight(), branchRate, probs);
} // getTransitionProbabilities
} // class SequenceAlignment