package beast.core.util;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.List;
import beast.core.BEASTObject;
import beast.core.CalculationNode;
import beast.core.Description;
import beast.core.Function;
import beast.core.Input;
import beast.core.Input.Validate;
import beast.core.Loggable;
import beast.core.parameter.BooleanParameter;
import beast.core.parameter.IntegerParameter;
@Description("calculates sum of a valuable")
public class Sum extends CalculationNode implements Function, Loggable {
final public Input<List<Function>> functionInput = new Input<>("arg", "argument to be summed", new ArrayList<>(), Validate.REQUIRED);
enum Mode {integer_mode, double_mode}
Mode mode;
boolean needsRecompute = true;
double sum = 0;
double storedSum = 0;
@Override
public void initAndValidate() {
List<Function> valuable = functionInput.get();
mode = Mode.integer_mode;
for (Function v : valuable) {
if (!(v instanceof IntegerParameter || v instanceof BooleanParameter)) {
mode = Mode.double_mode;
}
}
}
@Override
public int getDimension() {
return 1;
}
@Override
public double getArrayValue() {
if (needsRecompute) {
compute();
}
return sum;
}
/**
* do the actual work, and reset flag *
*/
void compute() {
sum = 0;
for (Function v : functionInput.get()) {
for (int i = 0; i < v.getDimension(); i++) {
sum += v.getArrayValue(i);
}
}
needsRecompute = false;
}
@Override
public double getArrayValue(int dim) {
if (dim == 0) {
return getArrayValue();
}
return Double.NaN;
}
/**
* CalculationNode methods *
*/
@Override
public void store() {
storedSum = sum;
super.store();
}
@Override
public void restore() {
sum = storedSum;
super.restore();
}
@Override
public boolean requiresRecalculation() {
needsRecompute = true;
return true;
}
/**
* Loggable interface implementation follows
*/
@Override
public void init(PrintStream out) {
out.print("sum(" + ((BEASTObject) functionInput.get().get(0)).getID() + ")\t");
}
@Override
public void log(int sampleNr, PrintStream out) {
double sum = 0;
for (Function v : functionInput.get()) {
for (int i = 0; i < v.getDimension(); i++) {
sum += v.getArrayValue(i);
}
}
if (mode == Mode.integer_mode) {
out.print((int) sum + "\t");
} else {
out.print(sum + "\t");
}
}
@Override
public void close(PrintStream out) {
// nothing to do
}
} // class Sum