/* * Copyright 2017 Red Hat, Inc. and/or its affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.drools.core.base.accumulators; import java.io.Externalizable; import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; import java.io.Serializable; /** * An implementation of an accumulator capable of calculating variance. */ public class VarianceAccumulateFunction extends AbstractAccumulateFunction<VarianceAccumulateFunction.VarianceData> { @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { } @Override public void writeExternal(ObjectOutput out) throws IOException { } protected static class VarianceData implements Serializable { protected int count; protected double mean; protected double squaredSum; } @Override public VarianceData createContext() { return new VarianceData(); } @Override public void init(VarianceData data) { data.count = 0; data.mean = 0; data.squaredSum = 0.0; } @Override public void accumulate(VarianceData data, Object value) { double x = ((Number) value).doubleValue(); // Incremental algorithm to calculate variance: // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm data.count++; double lowerDelta = x - data.mean; data.mean += lowerDelta / data.count; double higherDelta = x - data.mean; data.squaredSum += lowerDelta * higherDelta; } @Override public void reverse(VarianceData data, Object value) { double x = ((Number) value).doubleValue(); double higherDelta = x - data.mean; data.mean = data.mean * data.count / (data.count - 1.0) - x / (data.count -1.0); double lowerDelta = x - data.mean; data.count--; data.squaredSum -= lowerDelta * higherDelta; } @Override public Double getResult(VarianceData data) { return data.squaredSum / data.count; } @Override public boolean supportsReverse() { return true; } @Override public Class<?> getResultType() { return Double.class; } }