/*
* Encog(tm) Core v3.4 - Java Version
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-core
* Copyright 2008-2016 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.util.normalize.segregate;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import org.encog.util.normalize.DataNormalization;
import org.encog.util.normalize.input.InputField;
/**
 * Balance based on an input value. This allows you to make sure that one input
 * class does not saturate the training data. To do this, you specify the input
 * value to check and the number of occurrences of each integer value of this
 * field to allow.
 * <p>
 * The current value of the target field is cast to an {@code int} to select the
 * group a row belongs to. The first {@code count} rows seen for each group
 * during a pass are included; later rows of that group are rejected until
 * {@link #passInit()} clears the running totals.
 */
public class IntegerBalanceSegregator implements Segregator {

	/**
	 * The normalization object to use.
	 */
	private DataNormalization normalization;

	/**
	 * The input field whose value selects the group of each row.
	 */
	private InputField target;

	/**
	 * The maximum number of rows to include for each integer value of the
	 * input field.
	 */
	private int count;

	/**
	 * The running totals: number of rows included so far in the current pass,
	 * keyed by the integer value of the input field.
	 */
	private final Map<Integer, Integer> runningCounts =
			new HashMap<Integer, Integer>();

	/**
	 * Construct an integer balance segregator.
	 *
	 * @param target
	 *            The input field to use.
	 * @param count
	 *            The number of each unique integer to allow.
	 */
	public IntegerBalanceSegregator(final InputField target, final int count) {
		this.target = target;
		this.count = count;
	}

	/**
	 * Default constructor. The target field is left unset and the count is 0,
	 * so {@link #shouldInclude()} cannot be used until a target is provided.
	 */
	public IntegerBalanceSegregator() {
	}

	/**
	 * Build a human-readable dump of the running totals, one line per group in
	 * the form {@code "<value> -> <count> count"}.
	 *
	 * @return A string that contains the counts for each group.
	 */
	public String dumpCounts() {
		final StringBuilder result = new StringBuilder();
		for (final Entry<Integer, Integer> entry : this.runningCounts
				.entrySet()) {
			result.append(entry.getKey());
			result.append(" -> ");
			result.append(entry.getValue());
			result.append(" count\n");
		}
		return result.toString();
	}

	/**
	 * @return The amount of data allowed by this segregator for each integer
	 *         value of the input field.
	 */
	public int getCount() {
		return this.count;
	}

	/**
	 * @return The normalization object used with this segregator.
	 */
	public DataNormalization getNormalization() {
		return this.normalization;
	}

	/**
	 * Note: this is the live internal map, not a copy; changes made through it
	 * affect subsequent calls to {@link #shouldInclude()}.
	 *
	 * @return The current count for each group.
	 */
	public Map<Integer, Integer> getRunningCounts() {
		return this.runningCounts;
	}

	/**
	 * @return The input field being used.
	 */
	public InputField getTarget() {
		return this.target;
	}

	/**
	 * Init the segregator with the owning normalization object.
	 *
	 * @param normalization
	 *            The data normalization object to use.
	 */
	public void init(final DataNormalization normalization) {
		this.normalization = normalization;
	}

	/**
	 * Init for a new pass: forget all running totals so every group may again
	 * contribute up to {@code count} rows.
	 */
	public void passInit() {
		this.runningCounts.clear();
	}

	/**
	 * Determine if the current row should be included. A row is included (and
	 * its group's running total incremented) while that group has contributed
	 * fewer than {@code count} rows in the current pass.
	 *
	 * @return True if the current row should be included.
	 * @throws IllegalStateException
	 *             If no target input field has been set.
	 */
	public boolean shouldInclude() {
		if (this.target == null) {
			throw new IllegalStateException(
					"No target input field has been set for this segregator.");
		}
		// The cast selects the group; for a floating-point value it truncates
		// toward zero (e.g. 2.9 -> 2).
		final int key = (int) this.target.getCurrentValue();

		// Single lookup; a missing key means no rows of this group seen yet.
		final Integer current = this.runningCounts.get(key);
		final int seen = (current == null) ? 0 : current.intValue();

		if (seen >= this.count) {
			return false;
		}
		this.runningCounts.put(key, seen + 1);
		return true;
	}
}