/*
* Copyright 2017 Red Hat, Inc. and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.optaplanner.examples.tennis.solver.drools.functions;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import org.kie.api.runtime.rule.AccumulateFunction;
/**
 * Drools accumulate function that counts how many times each {@code groupBy} key is accumulated
 * and maintains, incrementally, the sum of the squared per-key counts
 * (the squared deviation of each count from zero).
 * <p>
 * For a fixed total number of accumulated items, that sum is smallest when the items are spread
 * evenly over the keys, which is why it can serve as a load-balance measure.
 * <p>
 * Both {@link #accumulate} and {@link #reverse} update the sum in O(1) per call
 * (plus one hash map operation), so reverse is supported.
 */
public class LoadBalanceByCountAccumulateFunction implements AccumulateFunction<LoadBalanceByCountAccumulateFunction.LoadBalanceByCountData> {

    /**
     * Per-accumulation working state.
     */
    protected static class LoadBalanceByCountData implements Serializable {

        // number of times each groupBy key is currently accumulated; keys with count 0 are removed
        private Map<Object, Long> groupCountMap;
        // the sum of squared deviation from zero: sum over all keys of count²
        private long squaredSum;

    }

    @Override
    public LoadBalanceByCountData createContext() {
        return new LoadBalanceByCountData();
    }

    /**
     * Resets the context to an empty state: no groups and a squared sum of zero.
     */
    @Override
    public void init(LoadBalanceByCountData data) {
        data.groupCountMap = new HashMap<>();
        data.squaredSum = 0L;
    }

    /**
     * Increments the count of {@code groupBy} and updates the squared sum accordingly.
     *
     * @param data the accumulation context, previously initialized by {@link #init}
     * @param groupBy the group key; may be {@code null} (HashMap permits a null key)
     */
    @Override
    public void accumulate(LoadBalanceByCountData data, Object groupBy) {
        long count = data.groupCountMap.merge(groupBy, 1L, Long::sum);
        // squaredSum = squaredSum - (count - 1)² + count²
        // <=> squaredSum = squaredSum + (2 * count - 1)
        data.squaredSum += (2L * count - 1L);
    }

    /**
     * Always {@code true}: {@link #reverse} undoes one {@link #accumulate} call incrementally.
     */
    @Override
    public boolean supportsReverse() {
        return true;
    }

    /**
     * Decrements the count of {@code groupBy}, removing the key when its count reaches zero,
     * and updates the squared sum accordingly.
     *
     * @param data the accumulation context
     * @param groupBy a key previously passed to {@link #accumulate} and not yet fully reversed
     * @throws IllegalStateException if {@code groupBy} is not currently accumulated
     */
    @Override
    public void reverse(LoadBalanceByCountData data, Object groupBy) {
        Long oldCount = data.groupCountMap.get(groupBy);
        if (oldCount == null) {
            // Fail with context instead of an opaque NullPointerException; the data is left untouched.
            throw new IllegalStateException("Impossible state: the groupBy (" + groupBy
                    + ") is reversed but it is not currently accumulated.");
        }
        long newCount = oldCount - 1L;
        if (newCount == 0L) {
            data.groupCountMap.remove(groupBy);
        } else {
            data.groupCountMap.put(groupBy, newCount);
        }
        // squaredSum = squaredSum - oldCount² + (oldCount - 1)²
        // <=> squaredSum = squaredSum - (2 * oldCount - 1)
        data.squaredSum -= (2L * oldCount - 1L);
    }

    @Override
    public Class<LoadBalanceByCountResult> getResultType() {
        return LoadBalanceByCountResult.class;
    }

    /**
     * @return a new immutable snapshot of the current squared sum
     */
    @Override
    public LoadBalanceByCountResult getResult(LoadBalanceByCountData data) {
        return new LoadBalanceByCountResult(data.squaredSum);
    }

    // This function has no instance fields, so there is nothing to externalize.
    @Override
    public void writeExternal(ObjectOutput out) {
    }

    @Override
    public void readExternal(ObjectInput in) {
    }

    /**
     * Immutable result of {@link LoadBalanceByCountAccumulateFunction}: the sum of squared per-group
     * counts, plus convenience accessors for its (scaled) square root.
     */
    public static class LoadBalanceByCountResult implements Serializable {

        private final long squaredSum;

        public LoadBalanceByCountResult(long squaredSum) {
            this.squaredSum = squaredSum;
        }

        /**
         * @return the sum over all groups of (count - 0)²
         */
        public long getZeroDeviationSquaredSum() {
            return squaredSum;
        }

        /**
         * @return {@link #getZeroDeviationSquaredSumRoot(double)} multiplied by {@literal 1 000}
         */
        public long getZeroDeviationSquaredSumRootMillis() {
            return getZeroDeviationSquaredSumRoot(1_000.0);
        }

        /**
         * @return {@link #getZeroDeviationSquaredSumRoot(double)} multiplied by {@literal 1 000 000}
         */
        public long getZeroDeviationSquaredSumRootMicros() {
            return getZeroDeviationSquaredSumRoot(1_000_000.0);
        }

        /**
         * @param scaleMultiplier {@code > 0}
         * @return {@code >= 0}, {@code latexmath:[f(n) = \sqrt{\sum_{i=1}^{n} (x_i - 0)^2}]} multiplied by scaleMultiplier,
         * truncated toward zero by the cast to {@code long}
         */
        public long getZeroDeviationSquaredSumRoot(double scaleMultiplier) {
            return (long) (Math.sqrt((double) squaredSum) * scaleMultiplier);
        }

    }

}