/*
* Copyright [2013-2015] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.dtrain.dt;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* {@link DTEarlyStopDecider} monitors the validation error during training. When it detects that training is
* overfitting, or that further training iterations are no longer worth the effort, {@link #add} returns true.
*
* <p>
* 1. Filter algorithm. Several well-known filter algorithms were tried, but none of them worked well enough. Plotting
* the error data (with Python) showed that the error oscillates from iteration to iteration in a pattern related to the
* tree depth. Since we are interested in the minimal error, the error data is divided into windows whose size is
* derived from the tree depth, and the minimum value of each window is taken as that window's representative value.
*
* <p>
* On top of these sampled values, a second filter is applied: recursive filtering. The average of the last
* queue-size values is used as the new filtered value, which is inserted into the queue for further processing.
*
* <p>
* 2. Iteration gain. The iteration gain is the reduction of the filtered error from one iteration to the next.
*
* <p>
* {@link #canStop()}: 1. Identifying overfitting. When the validation iteration gain is negative for 3 consecutive
* evaluations, the training is considered to be overfitting.
*
* <p>
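* 3. Identifying that more training iterations are not worthwhile. If the validation iteration gain stays below a
* near-zero threshold for 3 consecutive evaluations, the filters are restarted; after 3 such restarts, further
* iterations are considered not worthwhile and {@link #canStop()} returns true.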
*
* @author haifwu
*/
class DTEarlyStopDecider {

    static final Logger LOG = LoggerFactory.getLogger(DTEarlyStopDecider.class);

    /**
     * Number of consecutive stop signals required before the validation decider is restarted; also the number of
     * restarts required before {@link #canStop()} returns true.
     */
    private static final int MAGIC_NUMBER = 3;

    /**
     * Threshold on the filtered iteration gain: a gain below this value counts as a stop signal.
     */
    private static final double NEARLY_ZERO = 0.000001;

    /**
     * Filters the validation errors and reports, for every generated gain, whether it is below {@link #NEARLY_ZERO}.
     */
    private final MinAverageDecider validationErrorDecider;

    /**
     * Number of consecutive gains below {@link #NEARLY_ZERO}; reset by a larger gain or by a restart.
     */
    private int validationGainContinueNearZeroCount;

    /**
     * Number of times {@link #validationErrorDecider} has been restarted.
     */
    private int restartCount;

    /**
     * Averaging queue over the latest {@code treeDepth} validation errors, backing {@link #getCurrentAverageValue()}.
     */
    private final AverageQueue averageQueue;

    /**
     * Creates a decider whose filter windows are sized from the tree depth.
     *
     * @param treeDepth
     *            depth of the trees being trained; must be positive
     * @throws IllegalArgumentException
     *             if {@code treeDepth} is not positive
     */
    DTEarlyStopDecider(int treeDepth) {
        if(treeDepth <= 0) {
            throw new IllegalArgumentException("Tree depth should be greater than zero, but got " + treeDepth);
        }
        // The min-filter window spans MAGIC_NUMBER tree depths; the averaging window spans one tree depth.
        this.validationErrorDecider = new MinAverageDecider(treeDepth * MAGIC_NUMBER, treeDepth);
        this.averageQueue = new AverageQueue(treeDepth);
        this.restartCount = 0;
    }

    /**
     * Adds the validation error of a new iteration into the decider.
     *
     * @param validationError
     *            validation error of the current iteration
     * @return true if no more iterations are needed, else false
     */
    public boolean add(double validationError) {
        boolean validationDecideReady = this.validationErrorDecider.add(validationError);
        if(validationDecideReady) {
            if(this.validationErrorDecider.getDecide()) {
                this.validationGainContinueNearZeroCount += 1;
                LOG.warn("Continue {} positive sign for not worth more iteration!",
                        this.validationGainContinueNearZeroCount);
                if(this.validationGainContinueNearZeroCount >= MAGIC_NUMBER) {
                    // Enough consecutive stop signals: restart the filters and count one restart towards canStop().
                    this.validationErrorDecider.restart();
                    this.restartCount += 1;
                    this.validationGainContinueNearZeroCount = 0;
                    LOG.warn("Restart! Total restart times {}", this.restartCount);
                }
            } else {
                // A meaningful gain breaks the run of consecutive stop signals.
                this.validationGainContinueNearZeroCount = 0;
            }
        }
        // The raw validation errors are also fed to a separate queue so callers can query their recent average.
        this.averageQueue.add(validationError);
        return canStop();
    }

    /**
     * Gets the current average of the latest validation errors, as computed by the internal averaging queue.
     *
     * @return average validation error, or 0 if no validation error has been added yet
     */
    double getCurrentAverageValue() {
        return this.averageQueue.getAverage();
    }

    /**
     * Gets whether the decider is ready to stop training.
     *
     * @return true once the validation decider has been restarted at least {@link #MAGIC_NUMBER} times, else false
     */
    boolean canStop() {
        return this.restartCount >= MAGIC_NUMBER;
    }

    /**
     * Two-stage filter: a {@link MinQueue} takes the minimum of each window of raw values, then an
     * {@link AverageQueue} smooths those minima and yields the iteration gain.
     */
    static class MinAverageDecider {

        /**
         * Produces the minimal value of each window of raw values.
         */
        private final MinQueue minQueue;

        /**
         * Receives the window minima, applies the recursive average and yields the iteration gain.
         */
        private final AverageQueue averageQueue;

        /**
         * Latest gain; only updated when {@link #add(double)} returns true.
         */
        private double gain;

        /**
         * @param minQueueNum
         *            window size of the min filter; must be positive
         * @param averageQueueNum
         *            capacity of the averaging queue; must be positive
         */
        MinAverageDecider(int minQueueNum, int averageQueueNum) {
            this.minQueue = new MinQueue(minQueueNum);
            this.averageQueue = new AverageQueue(averageQueueNum);
        }

        /**
         * Adds a value into the decider.
         *
         * @param element
         *            the value to insert into the decider
         * @return true if a new gain was generated and {@link #getDecide()} is ready to be queried
         */
        public boolean add(double element) {
            if(!this.minQueue.add(element)) {
                return false;
            }
            double minValue = this.minQueue.getQueueMin();
            LOG.debug("MinQueue is full, get min value: {}", minValue);
            if(!this.averageQueue.add(minValue)) {
                return false;
            }
            this.gain = this.averageQueue.getGain();
            LOG.debug("Average Queue is full, get gain value: {}", this.gain);
            return true;
        }

        /**
         * @return true if the latest gain is below {@link #NEARLY_ZERO}; meaningful only after {@link #add(double)}
         *         has returned true
         */
        boolean getDecide() {
            return this.gain < NEARLY_ZERO;
        }

        /**
         * Discards the state of both filter stages so the next windows start from scratch.
         */
        void restart() {
            this.minQueue.restart();
            this.averageQueue.restart();
        }
    }

    /**
     * Generates the minimal value of each {@link #capacity} values.
     */
    private static class MinQueue {

        /**
         * Total number of elements in the current window.
         */
        private int size;

        /**
         * Minimal value in the current window.
         */
        private double min;

        /**
         * Number of elements per window.
         */
        private final int capacity;

        /**
         * @param capacity
         *            number of elements per window; must be positive
         * @throws IllegalArgumentException
         *             if {@code capacity} is not positive
         */
        MinQueue(int capacity) {
            if(capacity <= 0) {
                throw new IllegalArgumentException("MinQueue capacity should be greater than zero, but got " + capacity);
            }
            this.capacity = capacity;
            this.restart();
        }

        /**
         * Clears the current window.
         */
        void restart() {
            this.min = Double.MAX_VALUE;
            // Start at 0 so a window holds exactly `capacity` elements (it previously started at -1, i.e.
            // one element too many per window).
            this.size = 0;
        }

        /**
         * Adds an element to the current window.
         *
         * @param element
         *            the value to add
         * @return true if the window is full and ready to produce its minimal value
         */
        public boolean add(double element) {
            if(element < this.min) {
                this.min = element;
            }
            this.size += 1;
            return this.size >= this.capacity;
        }

        /**
         * Gets the minimal value of the current window and starts a new window. Should be called when the window is
         * full.
         *
         * @return the minimal value of the window
         */
        double getQueueMin() {
            double queueMin = this.min;
            this.restart();
            return queueMin;
        }
    }

    /**
     * Generates the recursive average value and its gain.
     */
    private static class AverageQueue {

        /**
         * The max capacity of this queue.
         */
        private final int capacity;

        /**
         * Ring buffer holding the averaged values of the queue.
         */
        private final double[] queueArray;

        /**
         * Total count of values added since construction or the last restart.
         */
        private long totalCount;

        /**
         * Running sum used to compute the next averaged value.
         */
        private double sum;

        /**
         * @param capacity
         *            max number of values in the queue; must be positive
         * @throws IllegalArgumentException
         *             if {@code capacity} is not positive
         */
        AverageQueue(int capacity) {
            if(capacity <= 0) {
                throw new IllegalArgumentException(
                        "AverageQueue capacity should be greater than zero, but got " + capacity);
            }
            this.capacity = capacity;
            this.queueArray = new double[this.capacity];
            this.restart();
        }

        /**
         * Resets the counters; stale array slots are overwritten before they are read again.
         */
        void restart() {
            this.totalCount = 0;
            this.sum = 0;
        }

        /**
         * Maps a non-negative insertion position onto a ring-buffer slot. The modulo is taken on the {@code long}
         * before narrowing, so the index cannot turn negative once the count exceeds {@link Integer#MAX_VALUE}.
         */
        private int indexOf(long position) {
            return (int) (position % this.capacity);
        }

        /**
         * Adds an element into the queue.
         *
         * @param element
         *            the element to insert into the queue
         * @return false if the queue has not reached {@link #capacity} yet, else true
         */
        public boolean add(double element) {
            int index = indexOf(this.totalCount);
            this.totalCount += 1;
            if(this.totalCount <= this.capacity) {
                // Before the queue is full, the sum grows by each new value and the slot stores the running average.
                this.sum += element;
                this.queueArray[index] = this.sum / this.totalCount;
                return false;
            }
            // After the queue is full, the new value replaces the oldest slot.
            // NOTE(review): before the queue is full `sum` accumulates raw elements, but here the value subtracted is
            // a stored *average*, so `sum` mixes raw and averaged values. Kept unchanged because the intended
            // recursive-filter formula is not clear from this file -- confirm the intended semantics.
            this.sum += element - this.queueArray[index];
            this.queueArray[index] = this.sum / this.capacity;
            return true;
        }

        /**
         * Gets the iteration gain of the latest inserted value: the previous averaged value minus the current one.
         *
         * @return the iteration gain, or 0 if fewer than two values have been added
         */
        double getGain() {
            if(this.totalCount < 2) {
                return 0d;
            }
            return this.queueArray[indexOf(this.totalCount - 2)] - this.queueArray[indexOf(this.totalCount - 1)];
        }

        /**
         * Gets the latest averaged value in the queue.
         *
         * @return the latest averaged value, or 0 if nothing has been added since construction or the last restart
         */
        public double getAverage() {
            if(this.totalCount == 0) {
                // Avoid indexing slot -1 when the queue is empty.
                return 0d;
            }
            return this.queueArray[indexOf(this.totalCount - 1)];
        }
    }
}