package quickml.supervised.tree.regressionTree.valueCounters;
import org.javatuples.Pair;
import quickml.data.instances.RegressionInstance;
import quickml.supervised.tree.summaryStatistics.ValueCounter;
import java.io.Serializable;
/**
 * Accumulates weighted statistics over regression labels.
 *
 * <p>It tracks three sums:
 * <ul>
 *   <li>the total weight,</li>
 *   <li>the weighted sum of values ({@code sum(value * weight)}),</li>
 *   <li>the weighted sum of squared values ({@code sum(value * value * weight)}).</li>
 * </ul>
 *
 * <p>{@link #update(double, double)} mutates this instance in place.
 * {@link #add(MeanValueCounter)} and {@link #subtract(MeanValueCounter)} leave both operands
 * untouched and return a new instance.
 */
public class MeanValueCounter extends ValueCounter<MeanValueCounter> implements Serializable {
    private static final long serialVersionUID = -6821237234748044623L;

    /** Sum of {@code value * weight} over all updates. */
    private double accumulatedValue = 0;
    /** Sum of {@code value * value * weight} over all updates. */
    private double accumulatedSquares = 0;
    /** Sum of {@code weight} over all updates. */
    private double accumulatedWeight = 0;
    /**
     * Caller-controlled flag. It defaults to {@code true} and is not recomputed by any statistics
     * method in this class.
     */
    private boolean hasSufficientData = true;

    /** Sets the caller-controlled "sufficient data" flag. */
    public void setHasSufficientData(boolean hasSufficientData) {
        this.hasSufficientData = hasSufficientData;
    }

    /** Returns the weighted sum of values, {@code sum(value * weight)}. */
    public double getAccumulatedValue() {
        return accumulatedValue;
    }

    /** Returns the weighted sum of squared values, {@code sum(value * value * weight)}. */
    public double getAccumulatedSquares() {
        return accumulatedSquares;
    }

    /**
     * Builds a counter by feeding every instance's label and weight through {@link #update}.
     *
     * @param instances regression instances to accumulate
     * @return a new counter with no attribute value holding the accumulated sums
     */
    public static MeanValueCounter accumulateAll(final Iterable<? extends RegressionInstance> instances) {
        final MeanValueCounter result = new MeanValueCounter();
        for (RegressionInstance instance : instances) {
            result.update(instance.getLabel(), instance.getWeight());
        }
        return result;
    }

    /**
     * Adds one weighted observation to this counter, mutating it in place.
     *
     * @param value  the observed label value
     * @param weight the observation's weight
     */
    public void update(double value, double weight) {
        this.accumulatedValue += value * weight;
        this.accumulatedSquares += value * value * weight;
        this.accumulatedWeight += weight;
    }

    /** Returns the caller-controlled flag set via {@link #setHasSufficientData(boolean)}. */
    public boolean hasSufficientData() {
        return hasSufficientData;
    }

    /** Creates an empty counter with no attribute value. */
    public MeanValueCounter() {}

    /** Creates an empty counter associated with {@code attrVal}. */
    public MeanValueCounter(Serializable attrVal) {
        super(attrVal);
    }

    /**
     * Creates a counter for {@code attrVal} pre-populated with the given sums.
     * {@code hasSufficientData} keeps its default of {@code true}.
     */
    public MeanValueCounter(Serializable attrVal, double accumulatedWeight, double accumulatedValue, double accumulatedSquares) {
        this(attrVal);
        this.accumulatedWeight = accumulatedWeight;
        this.accumulatedSquares = accumulatedSquares;
        this.accumulatedValue = accumulatedValue;
    }

    /**
     * Returns {@code true} when the accumulated weight is exactly zero.
     * NOTE(review): after {@link #subtract}, floating-point rounding may leave a tiny non-zero
     * weight; confirm whether callers need a tolerance here.
     */
    public boolean isEmpty() {
        return accumulatedWeight == 0;
    }

    /**
     * Copy constructor. It copies the attribute value, all three sums, and the
     * {@code hasSufficientData} flag. The flag was previously dropped, so every copy reset it to
     * {@code true}.
     */
    public MeanValueCounter(MeanValueCounter meanValueCounter) {
        super(meanValueCounter.attrVal);
        this.accumulatedWeight = meanValueCounter.accumulatedWeight;
        this.accumulatedValue = meanValueCounter.accumulatedValue;
        this.accumulatedSquares = meanValueCounter.accumulatedSquares;
        this.hasSufficientData = meanValueCounter.hasSufficientData;
    }

    /**
     * Returns a new counter whose sums are this counter's plus {@code other}'s.
     * The result takes this counter's attribute value. Its {@code hasSufficientData} is the
     * default {@code true}, not derived from either operand.
     */
    @Override
    public MeanValueCounter add(final MeanValueCounter other) {
        double weightedNumValues = this.accumulatedWeight + other.accumulatedWeight;
        double accumulatedValue = this.accumulatedValue + other.accumulatedValue;
        double accumulatedSquares = this.accumulatedSquares + other.accumulatedSquares;
        return new MeanValueCounter(this.attrVal, weightedNumValues, accumulatedValue, accumulatedSquares);
    }

    /**
     * Returns a new counter whose sums are this counter's minus {@code other}'s.
     * The result takes this counter's attribute value, and its {@code hasSufficientData} is the
     * default {@code true}. No check is made that the resulting sums are non-negative.
     */
    public MeanValueCounter subtract(final MeanValueCounter other) {
        double weightedNumValues = this.accumulatedWeight - other.accumulatedWeight;
        double accumulatedValue = this.accumulatedValue - other.accumulatedValue;
        double accumulatedSquares = this.accumulatedSquares - other.accumulatedSquares;
        return new MeanValueCounter(this.attrVal, weightedNumValues, accumulatedValue, accumulatedSquares);
    }

    /** Returns the total accumulated weight. */
    @Override
    public double getTotal() {
        return accumulatedWeight;
    }

    /**
     * Two counters are equal when their accumulated weight and accumulated value match.
     * NOTE(review): {@code accumulatedSquares}, {@code attrVal} and {@code hasSufficientData} are
     * not compared; confirm this is intended.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        MeanValueCounter that = (MeanValueCounter) o;
        // Double.compare (not ==) matches the bit-pattern semantics of Double.hashCode used below,
        // so 0.0 vs -0.0 and NaN cannot break the equals/hashCode contract.
        return Double.compare(this.accumulatedValue, that.accumulatedValue) == 0
                && Double.compare(this.accumulatedWeight, that.accumulatedWeight) == 0;
    }

    /** Hashes the same two fields that {@link #equals(Object)} compares. */
    @Override
    public int hashCode() {
        return new Pair<Double, Double>(accumulatedWeight, accumulatedValue).hashCode();
    }

    @Override
    public String toString() {
        return "accumulatedWeight: " + accumulatedWeight + ", accumulatedValue: " + accumulatedValue
                + ", accumulatedSquares: " + accumulatedSquares;
    }
}