package beast.app.seqgen;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.List;
import beast.core.BEASTInterface;
import beast.core.Description;
import beast.core.Input;
import beast.core.Input.Validate;
import beast.evolution.alignment.Alignment;
import beast.evolution.alignment.Sequence;
import beast.evolution.branchratemodel.BranchRateModel;
import beast.evolution.datatype.DataType;
import beast.evolution.likelihood.TreeLikelihood;
import beast.evolution.sitemodel.SiteModel;
import beast.evolution.tree.Node;
import beast.evolution.tree.Tree;
import beast.util.Randomizer;
import beast.util.XMLParser;
import beast.util.XMLParserException;
import beast.util.XMLProducer;
/**
* @author remco@cs.waikato.ac.nz
*/
@Description("Performs random sequence generation for a given site model. " +
"Sequences for the leave nodes in the tree are returned as an alignment.")
public class SequenceSimulator extends beast.core.Runnable {
// Declared inputs. Their values are read in initAndValidate() and run().
// NOTE(review): main() calls init(data, tree, siteModel, branchRateModel, replications) with
// positional arguments; presumably these are matched to the inputs in declaration order --
// confirm before reordering any of the declarations below.

// Template alignment: supplies the data type, the state count and the taxon names used for the output.
final public Input<Alignment> m_data = new Input<>("data", "alignment data which specifies datatype and taxa of the beast.tree", Validate.REQUIRED);
// Tree along which sequences are evolved, starting from its root.
final public Input<Tree> m_treeInput = new Input<>("tree", "phylogenetic beast.tree with sequence data in the leafs", Validate.REQUIRED);
// Site model: provides category proportions, per-category rates and the substitution model.
final public Input<SiteModel.Base> m_pSiteModelInput = new Input<>("siteModel", "site model for leafs in the beast.tree", Validate.REQUIRED);
// Optional branch rate model; when absent a rate of 1.0 is used for every branch.
final public Input<BranchRateModel.Base> m_pBranchRateModelInput = new Input<>("branchRateModel",
"A model describing the rates on the branches of the beast.tree.");
// Number of sites in each simulated sequence (defaults to 1000).
final public Input<Integer> m_sequenceLengthInput = new Input<>("sequencelength", "nr of samples to generate (default 1000).", 1000);
// Output file name; when not set, run() prints the alignment XML to System.out.
final public Input<String> m_outputFileNameInput = new Input<>(
"outputFileName",
"If provided, simulated alignment is written to this file rather "
+ "than to standard out.");
// Merge templates; run() calls process(alignment, iteration) on each one after every simulation.
final public Input<List<MergeDataWith>> mergeListInput = new Input<>("merge", "specifies template used to merge the generated alignment with", new ArrayList<>());
// Number of times run() repeats the simulate/write/merge cycle (defaults to 1).
final public Input<Integer> iterationsInput = new Input<>("iterations","number of times the data is generated", 1);
/**
* Number of sites to generate per sequence; copied from m_sequenceLengthInput in initAndValidate().
*/
protected int m_sequenceLength;
/**
* Tree used for generating samples; copied from m_treeInput in initAndValidate().
*/
protected Tree m_tree;
/**
* Site model used for generating samples; copied from m_pSiteModelInput in initAndValidate().
*/
protected SiteModel.Base m_siteModel;
/**
* Branch rate model used for generating samples; may be null, in which case
* getTransitionProbabilities() uses a branch rate of 1.0.
*/
protected BranchRateModel m_branchRateModel;
/**
* Number of rate categories in the site model (m_siteModel.getCategoryCount()).
*/
int m_categoryCount;
/**
* Number of states, taken from the template alignment (m_data.get().getMaxStateCount()).
*/
int m_stateCount;
/**
* Name of the output file, or null to write to standard output.
*/
String m_outputFileName;
/**
* Scratch buffer for transition probabilities: one flattened array of
* m_stateCount * m_stateCount entries per rate category. traverse() reads the entries
* for a parent state starting at offset parentState * m_stateCount.
*/
protected double[][] m_probabilities;
/**
 * Copies the input values into the working fields and allocates the
 * transition-probability scratch buffer (one flattened
 * stateCount x stateCount array per rate category).
 */
@Override
public void initAndValidate() {
    // Model components and scalar settings taken straight from the inputs.
    this.m_tree = m_treeInput.get();
    this.m_siteModel = m_pSiteModelInput.get();
    this.m_branchRateModel = m_pBranchRateModelInput.get();
    this.m_sequenceLength = m_sequenceLengthInput.get();

    // Dimensions: states come from the template alignment, categories from the site model.
    final int stateCount = m_data.get().getMaxStateCount();
    final int categoryCount = this.m_siteModel.getCategoryCount();
    this.m_stateCount = stateCount;
    this.m_categoryCount = categoryCount;

    final int flattenedMatrixSize = stateCount * stateCount;
    this.m_probabilities = new double[categoryCount][flattenedMatrixSize];

    this.m_outputFileName = m_outputFileNameInput.get();
}
/**
 * Runs the simulation {@code iterations} times. After each simulation the alignment is
 * written as raw XML to the output file (or to standard output when no file name was
 * given) and then handed to every configured {@code MergeDataWith} template together
 * with the iteration index.
 *
 * NOTE(review): the output file is re-opened (and therefore truncated) on every
 * iteration, so with more than one iteration only the last alignment remains in the
 * file. This matches the previous behaviour -- confirm whether appending is intended.
 *
 * @throws IOException if the output file cannot be opened or written
 */
@Override
public void run() throws IllegalArgumentException, IllegalAccessException, IOException, XMLParserException {
    final int iterations = iterationsInput.get();
    for (int i = 0; i < iterations; i++) {
        Alignment alignment = simulate();
        String alignmentXml = new XMLProducer().toRawXML(alignment);

        // Write output to stdout or file
        if (m_outputFileName == null) {
            // System.out is shared with the rest of the program and must not be closed.
            System.out.println(alignmentXml);
        } else {
            // Close the file stream once written; previously it leaked a handle per iteration.
            try (PrintStream fileStream = new PrintStream(m_outputFileName)) {
                fileStream.println(alignmentXml);
                // PrintStream swallows I/O errors, so surface them explicitly.
                if (fileStream.checkError()) {
                    throw new IOException("Failed to write alignment to " + m_outputFileName);
                }
            }
        }

        for (MergeDataWith merge : mergeListInput.get()) {
            merge.process(alignment, i);
        }
    }
}
/**
 * Converts an integer-coded sequence into a {@link Sequence} labelled with a taxon name.
 * The states are rendered with the data type of the template alignment, and the taxon
 * name is taken from the template sequence whose list index equals the node number.
 *
 * @param seq  integer representation of the sequence
 * @param node node whose number selects the taxon for the sequence
 * @return the sequence in string form, tagged with the taxon name
 */
Sequence intArray2Sequence(int[] seq, Node node) {
    final Alignment template = m_data.get();
    final String encodedStates = template.getDataType().state2string(seq);
    final String taxonName = template.sequenceInput.get().get(node.getNr()).taxonInput.get();
    return new Sequence(taxonName, encodedStates);
} // intArray2Sequence
/**
 * Performs the actual sequence generation: draws a rate category for every site,
 * draws the root state for every site from the substitution model's frequencies,
 * and then evolves the root sequence down the tree.
 *
 * @return alignment containing the randomly generated sequences for the leaf nodes
 *         of the tree
 */
public Alignment simulate() {
    final Node rootNode = m_tree.getRoot();

    // Rate categories are drawn first, root states second; the order fixes the
    // sequence of random numbers consumed.
    final int[] siteCategories = drawPerSite(m_siteModel.getCategoryProportions(rootNode));
    final int[] rootStates = drawPerSite(m_siteModel.getSubstitutionModel().getFrequencies());

    final Alignment simulated = new Alignment();
    simulated.userDataTypeInput.setValue(m_data.get().getDataType(), simulated);
    simulated.setID("SequenceSimulator");

    traverse(rootNode, rootStates, siteCategories, simulated);
    return simulated;
} // simulate

/**
 * Draws one index per site from the given discrete distribution.
 */
private int[] drawPerSite(double[] pdf) {
    final int[] draws = new int[m_sequenceLength];
    for (int site = 0; site < draws.length; site++) {
        draws[site] = Randomizer.randomChoicePDF(pdf);
    }
    return draws;
}
/**
 * Recursively walks the tree top down and adds a sequence to the alignment whenever
 * a leaf node is reached.
 *
 * For each of the two children (left, then right) the per-category transition
 * probabilities of the branch above the child are computed into the shared
 * {@code m_probabilities} buffer, a child state is sampled for every site, and the
 * child is then either stored (leaf) or descended into (internal node). The child
 * sequence is fully sampled before recursing, so overwriting the shared buffer in the
 * recursive call is harmless.
 *
 * NOTE(review): the node is assumed to have both a left and a right child; neither
 * is checked for null.
 *
 * @param node           current node, whose children are visited
 * @param parentSequence randomly generated sequence of {@code node}
 * @param category       rate category of each site
 * @param alignment      alignment that collects the leaf sequences
 */
void traverse(Node node, int[] parentSequence, int[] category, Alignment alignment) {
    for (int side = 0; side < 2; side++) {
        final Node child;
        if (side == 0) {
            child = node.getLeft();
        } else {
            child = node.getRight();
        }

        // Refresh the transition matrices for the branch leading to this child.
        for (int cat = 0; cat < m_categoryCount; cat++) {
            getTransitionProbabilities(m_tree, child, cat, m_probabilities[cat]);
        }

        final int[] childStates = new int[m_sequenceLength];
        final double[] rowPdf = new double[m_stateCount];
        for (int site = 0; site < m_sequenceLength; site++) {
            // Select the matrix row that belongs to the parent's state at this site.
            final double[] matrix = m_probabilities[category[site]];
            final int rowStart = parentSequence[site] * m_stateCount;
            System.arraycopy(matrix, rowStart, rowPdf, 0, m_stateCount);
            childStates[site] = Randomizer.randomChoicePDF(rowPdf);
        }

        if (!child.isLeaf()) {
            traverse(child, childStates, category, alignment);
        } else {
            alignment.sequenceInput.setValue(intArray2Sequence(childStates, child), alignment);
        }
    }
} // traverse
/**
 * Fills {@code probs} with the transition probabilities for the branch between
 * {@code node} and its parent under the given rate category.
 *
 * The rate handed to the substitution model is the branch rate (1.0 when no branch
 * rate model is set) multiplied by the site model's rate for the category.
 *
 * @param tree         not used by this method
 * @param node         node at the lower end of the branch
 * @param rateCategory index of the rate category
 * @param probs        output array that receives the transition probabilities
 */
void getTransitionProbabilities(Tree tree, Node node, int rateCategory, double[] probs) {
    final Node parent = node.getParent();

    final double clockRate;
    if (m_branchRateModel == null) {
        clockRate = 1.0;
    } else {
        clockRate = m_branchRateModel.getRateForBranch(node);
    }
    final double effectiveRate = clockRate * m_siteModel.getRateForCategory(rateCategory, node);

    m_siteModel.getSubstitutionModel().getTransitionProbabilities(
            node, parent.getHeight(), node.getHeight(), effectiveRate, probs);
} // getTransitionProbabilities
/**
 * Searches the active BEAST objects below {@code beastObject} depth first and returns
 * the first {@link TreeLikelihood} encountered.
 *
 * @param beastObject object whose active BEAST objects are inspected recursively
 * @return the first tree likelihood found, or null if there is none
 */
static TreeLikelihood getTreeLikelihood(BEASTInterface beastObject) {
    for (BEASTInterface candidate : beastObject.listActiveBEASTObjects()) {
        // Either the candidate itself is the likelihood, or we look inside it.
        final TreeLikelihood found = (candidate instanceof TreeLikelihood)
                ? (TreeLikelihood) candidate
                : getTreeLikelihood(candidate);
        if (found != null) {
            return found;
        }
    }
    return null;
}
/**
 * Prints the command line usage to standard output and terminates the JVM with
 * exit status 0.
 */
public static void printUsageAndExit() {
    final String[] usageLines = {
        "Usage: java " + SequenceSimulator.class.getName() + " <beast file> <nr of instantiations> [<output file>]",
        "simulates from a treelikelihood specified in the beast file.",
        "<beast file> is name of the path beast file containing the treelikelihood.",
        "<nr of instantiations> is the number of instantiations to be replicated.",
        "<output file> optional name of the file to write the sequence to. By default, the sequence is written to standard output."
    };
    for (String usageLine : usageLines) {
        System.out.println(usageLine);
    }
    System.exit(0);
} // printUsageAndExit
/**
 * Command line entry point. Reads a BEAST XML file, locates the first tree likelihood
 * in it, simulates one alignment from that likelihood's data, tree, site model and
 * branch rate model, and prints the alignment wrapped in a {@code <beast>} element.
 *
 * Any exception is reported by printing its stack trace.
 *
 * @param args {@code <beast file> <nr of instantiations> [<output file>]}; with fewer
 *             than two arguments the usage is printed and the JVM exits. Output goes to
 *             the named file only when exactly three arguments are given.
 */
@SuppressWarnings("unchecked")
public static void main(String[] args) {
    PrintStream out = System.out;
    try {
        // parse arguments
        if (args.length < 2) {
            printUsageAndExit();
        }
        String fileName = args[0];
        int replications = Integer.parseInt(args[1]);
        if (args.length == 3) {
            File file = new File(args[2]);
            out = new PrintStream(file);
        }

        // grab the file; the reader is closed even when reading fails
        StringBuilder xmlBuilder = new StringBuilder();
        try (BufferedReader fin = new BufferedReader(new FileReader(fileName))) {
            String line;
            while ((line = fin.readLine()) != null) {
                // readLine() strips the line terminator; put a line break back so tokens
                // on adjacent lines are not fused together before parsing
                xmlBuilder.append(line).append('\n');
            }
        }
        String xml = xmlBuilder.toString();

        // parse the xml
        XMLParser parser = new XMLParser();
        BEASTInterface beastObject = parser.parseFragment(xml, true);

        // find relevant objects from the model
        TreeLikelihood treeLikelihood = getTreeLikelihood(beastObject);
        if (treeLikelihood == null) {
            throw new IllegalArgumentException("No treelikelihood found in file. Giving up now.");
        }
        Alignment data = ((Input<Alignment>) treeLikelihood.getInput("data")).get();
        Tree tree = ((Input<Tree>) treeLikelihood.getInput("tree")).get();
        SiteModel pSiteModel = ((Input<SiteModel>) treeLikelihood.getInput("siteModel")).get();
        BranchRateModel pBranchRateModel = ((Input<BranchRateModel>) treeLikelihood.getInput("branchRateModel")).get();

        // feed to sequence simulator and generate leaves
        SequenceSimulator treeSimulator = new SequenceSimulator();
        treeSimulator.init(data, tree, pSiteModel, pBranchRateModel, replications);
        XMLProducer producer = new XMLProducer();
        Alignment alignment = treeSimulator.simulate();
        xml = producer.toRawXML(alignment);
        out.println("<beast version='2.0'>");
        out.println(xml);
        out.println("</beast>");
    } catch (Exception e) {
        e.printStackTrace();
    } finally {
        // Release the output file if one was opened; never close System.out.
        if (out != System.out) {
            out.close();
        }
    }
} // main
} // class SequenceSimulator