/*
* File Frequencies.java
*
* Copyright (C) 2010 Remco Bouckaert remco@cs.auckland.ac.nz
*
* This file is part of BEAST2.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package beast.evolution.substitutionmodel;
import java.util.Arrays;
import beast.core.CalculationNode;
import beast.core.Description;
import beast.core.Input;
import beast.core.Input.Validate;
import beast.core.parameter.RealParameter;
import beast.core.util.Log;
import beast.evolution.alignment.Alignment;
import beast.evolution.datatype.DataType;
// RRB: TODO: make this an interface?
@Description("Represents character frequencies typically used as distribution of the root of the tree. " +
        "Calculates empirical frequencies of characters in sequence data, or simply assumes a uniform " +
        "distribution if the estimate flag is set to false.")
public class Frequencies extends CalculationNode {
    public final Input<Alignment> dataInput = new Input<>("data", "Sequence data for which frequencies are calculated");
    public final Input<Boolean> estimateInput = new Input<>("estimate", "Whether to estimate the frequencies from data (true=default) or assume a uniform distribution over characters (false)", true);
    public final Input<RealParameter> frequenciesInput = new Input<>("frequencies", "A set of frequencies specified as space separated values summing to 1", Validate.XOR, dataInput);

    /**
     * Tolerance used when two frequencies are found to be exactly equal:
     * one is raised and the other lowered by this amount.
     */
    private static final double MINFDIFF = 1.0E-10;

    /**
     * Lower limit applied to every estimated frequency.
     */
    private static final double MINFREQ = 1.0E-10;

    /**
     * Cached frequency distribution; callers receive a copy through {@link #getFreqs()}.
     */
    protected double[] freqs;

    /**
     * Flag indicating that {@link #freqs} is stale and must be recomputed
     * before it is handed out again.
     */
    boolean needsUpdate;

    /**
     * Computes the initial frequencies and verifies that they sum to 1
     * (within a tolerance of 1e-6).
     *
     * @throws IllegalArgumentException if the frequencies do not sum to 1,
     *                                  including the case where the sum is NaN
     */
    @Override
    public void initAndValidate() {
        update();
        double sum = getSumOfFrequencies(getFreqs());
        // Sanity check. Written as a negated "<=" so that a NaN sum (which
        // estimateFrequencies() yields when the accumulated total is 0) is
        // rejected as well; "NaN > tolerance" would silently evaluate to false.
        if (!(Math.abs(sum - 1.0) <= 1e-6)) {
            throw new IllegalArgumentException("Frequencies do not add up to 1");
        }
    }

    /**
     * Returns up to date frequencies, recomputing them first if they were
     * flagged as stale.
     *
     * @return a fresh copy of the frequency array; modifying it does not
     *         affect this object
     */
    public double[] getFreqs() {
        synchronized (this) {
            if (needsUpdate) {
                update();
            }
        }
        return freqs.clone();
    }

    /**
     * Recalculates the frequencies and clears the stale flag.
     * <p>
     * The source is chosen in this order: the user supplied frequencies
     * parameter if present; otherwise an estimate from the alignment when the
     * estimate flag is set; otherwise a uniform distribution over the
     * alignment's states.
     */
    void update() {
        RealParameter userFrequencies = frequenciesInput.get();
        if (userFrequencies != null) {
            // user specified: copy the values out of the parameter
            freqs = new double[userFrequencies.getDimension()];
            for (int i = 0; i < freqs.length; i++) {
                freqs[i] = userFrequencies.getValue(i);
            }
        } else if (estimateInput.get()) {
            // not user specified: estimate from data, then nudge degenerate values
            estimateFrequencies();
            checkFrequencies();
        } else {
            // uniformly distributed
            int states = dataInput.get().getMaxStateCount();
            freqs = new double[states];
            Arrays.fill(freqs, 1.0 / states);
        }
        needsUpdate = false;
    }

    /**
     * Estimate from sequence alignment.
     * This version matches the implementation in Beast 1 and PAUP.
     * <p>
     * Starts from a uniform distribution and iterates: every observed code
     * distributes its pattern weight over the states it maps to, in proportion
     * to the current frequencies of those states. Iteration stops once the
     * summed absolute change drops to 1E-8 or below, or after 1000 rounds.
     * The result is written to {@link #freqs} and logged.
     */
    void estimateFrequencies() {
        Alignment alignment = dataInput.get();
        DataType dataType = alignment.getDataType();
        int stateCount = alignment.getMaxStateCount();

        freqs = new double[stateCount];
        Arrays.fill(freqs, 1.0 / stateCount);

        int attempts = 0;
        double difference;
        do {
            double[] tmpFreq = new double[stateCount];
            double total = 0.0;
            for (int i = 0; i < alignment.getPatternCount(); i++) {
                int[] pattern = alignment.getPattern(i);
                double weight = alignment.getPatternWeight(i);
                for (int value : pattern) {
                    int[] codes = dataType.getStatesForCode(value);

                    // current probability mass of all states this code maps to
                    double sum = 0.0;
                    for (int codeIndex : codes) {
                        sum += freqs[codeIndex];
                    }

                    // share the pattern weight among those states proportionally
                    for (int codeIndex : codes) {
                        double tmp = (freqs[codeIndex] * weight) / sum;
                        tmpFreq[codeIndex] += tmp;
                        total += tmp;
                    }
                }
            }

            // normalise and measure how far the estimate moved this round
            difference = 0.0;
            for (int i = 0; i < stateCount; i++) {
                difference += Math.abs((tmpFreq[i] / total) - freqs[i]);
                freqs[i] = tmpFreq[i] / total;
            }
            attempts++;
        } while (difference > 1E-8 && attempts < 1000);

        Log.info.println("Starting frequencies: " + Arrays.toString(freqs));
    }

    /**
     * Ensures that frequencies are not smaller than MINFREQ, re-balances the
     * largest entry so the total is 1 again, and separates any pair of
     * exactly equal frequencies by moving them MINFDIFF apart in opposite
     * directions (a resulting gap of 2*MINFDIFF).
     * This avoids potential problems later when eigenvalues
     * are computed.
     */
    private void checkFrequencies() {
        int maxi = 0;
        double sum = 0.0;
        double maxfreq = 0.0;
        for (int i = 0; i < freqs.length; i++) {
            double freq = freqs[i];
            if (freq < MINFREQ) {
                freqs[i] = MINFREQ;
            }
            if (freq > maxfreq) {
                maxfreq = freq;
                maxi = i;
            }
            sum += freqs[i];
        }

        // absorb whatever the clamping added (or rounding lost) in the largest entry
        double diff = 1.0 - sum;
        freqs[maxi] += diff;

        // NOTE(review): MINFDIFF equals MINFREQ, so when two entries were both
        // clamped to MINFREQ the second of them ends up at 0 after this step.
        // Left unchanged to preserve existing numerical results - confirm intent.
        for (int i = 0; i < freqs.length - 1; i++) {
            for (int j = i + 1; j < freqs.length; j++) {
                if (freqs[i] == freqs[j]) {
                    freqs[i] += MINFDIFF;
                    freqs[j] -= MINFDIFF;
                }
            }
        }
    }

    /**
     * CalculationNode implementation.
     * <p>
     * Flags the cached frequencies as stale when the user supplied
     * frequencies parameter is dirty. When no such parameter is configured
     * (frequencies come from the alignment instead) nothing is flagged.
     *
     * @return true if the frequencies parameter is present and dirty
     */
    @Override
    protected boolean requiresRecalculation() {
        RealParameter userFrequencies = frequenciesInput.get();
        // The frequencies input is XOR with the data input, so it may be absent.
        if (userFrequencies != null && userFrequencies.somethingIsDirty()) {
            needsUpdate = true;
            return true;
        }
        return false;
    }

    /**
     * @param frequencies the frequencies
     * @return return the sum of frequencies
     */
    private double getSumOfFrequencies(double[] frequencies) {
        double total = 0.0;
        for (double frequency : frequencies) {
            total += frequency;
        }
        return total;
    }

    /**
     * Marks the cached frequencies as stale, so the next call to
     * {@link #getFreqs()} recomputes them, and then delegates to the
     * superclass.
     */
    @Override
    public void restore() {
        needsUpdate = true;
        super.restore();
    }
} // class Frequencies