package test.dr.evomodel.substmodel;
import dr.evolution.datatype.DataType;
import dr.oldevomodel.substmodel.ComplexSubstitutionModel;
import dr.oldevomodel.substmodel.FrequencyModel;
import dr.oldevomodel.substmodel.SVSComplexSubstitutionModel;
import dr.inference.model.Parameter;
import junit.framework.TestCase;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Random;
/**
*
*/
public class TimeIrreversibleTest extends TestCase {
private static final double time = 0.01;
private static NumberFormat formatter = new DecimalFormat("###0.000000");
private static ArrayList<Double> ratioSummary = new ArrayList<Double> ();
static class Original {
public double[] getRates() {
return new double[]{
38.505, 23.573, 2.708, 3.35, 11.641, 0.189, 0.127, 0.511, 0.272, 1.788E-2,
0.214, 1.322E-2, 3.015E-2, 0.449, 0.177, 0.305, 1.517E-2, 3.924E-2, 0.18, 0.14,
1.273E-2, 0.265, 1.422E-2, 1.474E-2, 0.911, 0.17, 0.217, 4.078, 0.206, 3.309E-2,
0.657, 1.874E-2, 3.141E-2, 0.403, 2.003E-2, 0.582, 0.732, 0.106, 9.147E-2, 0.248,
1.516E-2, 0.524, 0.1, 1.986, 0.819, 0.146, 7.519E-2, 1.35, 0.166, 0.204,
1.753E-2, 0.59, 0.691, 3.308, 0.377, 1.785E-2
}; // 56
}
public double[] getIndicators() {
return new double[]{
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0
}; // 56
}
public double[] getFrequencies() {
// return new double[]{0.141, 7.525E-2, 0.117, 0.283, 7.743E-2, 0.156, 1.555E-2, 0.135}; // sum to 1.00023
return new double[]{0.141, 0.075, 0.117, 0.283, 0.077, 0.156, 0.016, 0.135};
}
public DataType getDataType() {
return new DataType() {
public String getDescription() {
return null;
}
public int getType() {
return 0; // NUCLEOTIDES = 0;
}
@Override
public char[] getValidChars() {
return null;
}
public int getStateCount() {
return getFrequencies().length; // = frequency
}
};
}
public String toString() {
return "original data test";
}
}
class Test extends Original {
private final double x;
public Test(double x) {
this.x = x;
}
public double[] getRates(int id) {
double[] originalRates = super.getRates();
System.out.println("original rates:");
printRateMatrix(originalRates, getDataType().getStateCount());
double[] newRates = new double[originalRates.length];
double[] uniform = new double[originalRates.length];
for (int r = 0; r < originalRates.length; r++) {
if (r == id) {
uniform[r] = (new Random()).nextDouble() * ((1 / x) - x) + x;
newRates[r] = originalRates[r] * uniform[r];
} else {
newRates[r] = originalRates[r];
}
}
System.out.println("random ratio:");
printRateMatrix(uniform, getDataType().getStateCount());
System.out.println("new rates:");
printRateMatrix(newRates, getDataType().getStateCount());
return newRates;
}
public String toString() {
return "test using random number : " + x;
}
}
public void tests() {
Original originalTest = new Original();
double[] csm_orig = testComplexSubstitutionModel(originalTest, originalTest.getRates());
double[] svs_orig = testSVSComplexSubstitutionModel(originalTest, originalTest.getRates());
Test test = new Test(0.8);
for (int r = 0; r < test.getRates().length; r++) {
System.out.println("==================== changing index = " + r + " (start from 0) ====================");
double[] newRate = test.getRates(r);
double[] csm_test = testComplexSubstitutionModel(test, newRate);
reportMatrix(csm_orig, csm_test);
double[] svs_test = testSVSComplexSubstitutionModel(test, newRate);
reportMatrix(svs_orig, svs_test);
}
System.out.println("==================== Biggest Ratio Summary ====================\n");
int i = 1;
double bigget = 0;
int biggetId = 0;
for (Double r : ratioSummary) {
if (i % 2 != 0) {
System.out.print(i/2 + " ");
}
System.out.print(formatter.format(r) + ", ");
if (bigget < r) {
bigget = r;
biggetId = i;
}
if (i % 2 == 0) {
System.out.println("");
}
i++;
}
System.out.println("bigget = " + formatter.format(bigget) + ", where index is " + biggetId/2);
}
private double[] testComplexSubstitutionModel(Original test, double[] rates) {
System.out.println("\n*** Complex Substitution Model Test: " + test + " ***");
Parameter ratesP = new Parameter.Default(rates);
DataType dataType = test.getDataType();
FrequencyModel freqModel = new FrequencyModel(dataType, new Parameter.Default(test.getFrequencies()));
ComplexSubstitutionModel substModel = new ComplexSubstitutionModel("Complex Substitution Model Test",
dataType,
freqModel,
ratesP);
double logL = substModel.getLogLikelihood();
System.out.println("Prior = " + logL);
double[] finiteTimeProbs = null;
if (!Double.isInfinite(logL)) {
finiteTimeProbs = new double[substModel.getDataType().getStateCount() * substModel.getDataType().getStateCount()];
substModel.getTransitionProbabilities(time, finiteTimeProbs);
System.out.println("Probs = ");
printRateMatrix(finiteTimeProbs, substModel.getDataType().getStateCount());
}
// assertEquals(1, 1, 1e-10);
return finiteTimeProbs;
}
private double[] testSVSComplexSubstitutionModel(Original test, double[] rates) {
System.out.println("\n*** SVS Complex Substitution Model Test: " + test + " ***");
double[] indicators = test.getIndicators();
Parameter ratesP = new Parameter.Default(rates);
Parameter indicatorsP = new Parameter.Default(indicators);
DataType dataType = test.getDataType();
FrequencyModel freqModel = new FrequencyModel(dataType, new Parameter.Default(test.getFrequencies()));
SVSComplexSubstitutionModel substModel = new SVSComplexSubstitutionModel("SVS Complex Substitution Model Test",
dataType,
freqModel,
ratesP, indicatorsP);
double logL = substModel.getLogLikelihood();
System.out.println("Prior = " + logL);
double[] finiteTimeProbs = null;
if (!Double.isInfinite(logL)) {
finiteTimeProbs = new double[substModel.getDataType().getStateCount() * substModel.getDataType().getStateCount()];
substModel.getTransitionProbabilities(time, finiteTimeProbs);
System.out.println("Probs = ");
printRateMatrix(finiteTimeProbs, substModel.getDataType().getStateCount());
}
// assertEquals(1, 1, 1e-10);
return finiteTimeProbs;
}
public static void printRateMatrix(double[] m, int a) {
int id = 0;
for (int i = 0; i < a; i++) {
if (i == 0) {
System.out.print("/ ");
} else if (i == a - 1) {
System.out.print("\\ ");
} else {
System.out.print("| ");
}
for (int j = 0; j < a; j++) {
if (i == j) {
System.out.print("null");
for (int n = 0; n < formatter.getMaximumFractionDigits(); n++) {
System.out.print(" ");
}
} else {
System.out.print(formatter.format(m[id]) + " ");
id++;
}
}
if (i == 0) {
System.out.print("\\");
} else if (i == a - 1) {
System.out.print("/");
} else {
System.out.print("|");
}
System.out.println();
}
System.out.println("\n");
}
public static void reportMatrix(double[] orig, double[] test) {
double bigRatio = 0;
double ratio;
int index = -1;
if (orig.length != test.length)
System.err.println("Error : 2 matrix should have same length ! " + orig.length + " " + test.length);
for (int i = 0; i < orig.length; i++) {
ratio = Math.abs(orig[i] / test[i]);
if (bigRatio < ratio) {
bigRatio = ratio;
index = i;
}
}
ratioSummary.add(bigRatio);
System.out.println("Biggest Ratio = " + formatter.format(bigRatio) + ", between " + formatter.format(orig[index])
+ " and " + formatter.format(test[index]));
System.out.println("index = " + index + " (start from 0)");
System.out.println("\n");
}
}