/*******************************************************************************
* Copyright 2012
* Ubiquitous Knowledge Processing (UKP) Lab
* Technische Universität Darmstadt
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
******************************************************************************/
package org.dkpro.lab.task.impl;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import org.dkpro.lab.task.Dimension;
/**
 * Dimension bundle that distributes the values of another dimension over a fixed number of
 * buckets ("folds") and then iterates over these folds. In every step exactly one bucket is the
 * validation bucket and the contents of all other buckets form the training data.
 * <p>
 * Each step yields a map with two entries, keyed by {@code getName() + "_training"} and
 * {@code getName() + "_validation"}.
 * <p>
 * The buckets are (re)built from the folded dimension on every call to {@link #rewind()}, so
 * {@link #rewind()} has to be called before {@link #hasNext()}, {@link #next()} or
 * {@link #current()} are used.
 *
 * @param <T>
 *            type of the values provided by the folded dimension.
 */
public class FoldDimensionBundle<T> extends DimensionBundle<Collection<T>> implements DynamicDimension
{
    private final Dimension<T> foldedDimension;
    private final int folds;
    private final Comparator<T> comparator;

    // One list per fold; created by init() which is triggered through rewind().
    private List<T>[] buckets;
    // Index of the bucket currently used for validation; -1 means "before the first fold".
    private int validationBucket = -1;

    /**
     * Creates a fold bundle in which the given comparator controls which values must share a
     * fold: a value for which the comparator returns {@code 0} against a value that was already
     * placed is put into the same bucket as that value. Values without such a match are put into
     * the currently smallest bucket.
     *
     * @param aName
     *            name of the bundle; used as prefix of the keys in the produced maps.
     * @param aFoldedDimension
     *            dimension providing the values to distribute over the folds.
     * @param aFolds
     *            number of folds; must be at least 1.
     * @param aComparator
     *            comparator deciding which values belong together, or {@code null} to distribute
     *            the values round-robin.
     * @throws IllegalArgumentException
     *             if {@code aFolds} is smaller than 1.
     */
    public FoldDimensionBundle(String aName, Dimension<T> aFoldedDimension, int aFolds, Comparator<T> aComparator)
    {
        super(aName, new Object[0] );
        if (aFolds < 1) {
            // Fail early with a clear message instead of running into a negative array size,
            // a division by zero or an index error when the buckets are built later.
            throw new IllegalArgumentException(
                    "Number of folds must be at least 1, but was [" + aFolds + "]");
        }
        foldedDimension = aFoldedDimension;
        folds = aFolds;
        comparator = aComparator;
    }

    /**
     * Creates a fold bundle that distributes the values round-robin over the folds.
     *
     * @param aName
     *            name of the bundle; used as prefix of the keys in the produced maps.
     * @param aFoldedDimension
     *            dimension providing the values to distribute over the folds.
     * @param aFolds
     *            number of folds; must be at least 1.
     * @throws IllegalArgumentException
     *             if {@code aFolds} is smaller than 1.
     */
    public FoldDimensionBundle(String aName, Dimension<T> aFoldedDimension, int aFolds)
    {
        this(aName, aFoldedDimension, aFolds, null);
    }

    /**
     * Rewinds the folded dimension and captures all of its values into the buckets, one bucket
     * per fold.
     *
     * @throws IllegalStateException
     *             if the folded dimension provides fewer values than folds (round-robin mode) or
     *             if any bucket is empty after the distribution.
     */
    private void init()
    {
        buckets = createBuckets(folds);

        // Capture all data from the dimension into buckets, one per fold
        foldedDimension.rewind();

        if (comparator != null) {
            // User controls instances across folds
            while (foldedDimension.hasNext()) {
                T newItem = foldedDimension.next();
                int target = findBucketWithEqualItem(newItem);
                if (target < 0) {
                    // No bucket claims the item, so keep the folds balanced.
                    target = indexOfSmallestBucket();
                }
                buckets[target].add(newItem);
            }
        }
        else {
            // Default instance division across folds: round-robin
            int i = 0;
            while (foldedDimension.hasNext()) {
                buckets[i % folds].add(foldedDimension.next());
                i++;
            }

            if (i < folds) {
                throw new IllegalStateException("Requested [" + folds + "] folds, but only got [" + i
                        + "] values. There must be at least as many values as folds.");
            }
        }

        // An empty fold would yield an empty validation set, so reject it. The message lists the
        // sizes of all folds inspected so far, including the empty one.
        StringBuilder foldsAndSizes = new StringBuilder();
        for (int bucket = 0; bucket < buckets.length; bucket++) {
            int size = buckets[bucket].size();
            foldsAndSizes.append(" fold ").append(bucket).append(": size ").append(size).append(". ");
            if (size == 0) {
                throw new IllegalStateException("Detected an empty fold: " + bucket + ". " +
                        "Maybe your fold control is causing all of your instances to be put in very few buckets? " +
                        "Previous folds and buckets: " + foldsAndSizes);
            }
        }
    }

    /**
     * Allocates the given number of empty buckets.
     */
    private List<T>[] createBuckets(int aCount)
    {
        // Java cannot create generic arrays directly. The cast is safe because the array never
        // leaves this class and only ever receives List<T> instances.
        @SuppressWarnings({ "unchecked", "rawtypes" })
        List<T>[] created = new List[aCount];
        for (int bucket = 0; bucket < created.length; bucket++) {
            created[bucket] = new ArrayList<T>();
        }
        return created;
    }

    /**
     * Returns the index of the first bucket (in index order) containing an item for which the
     * comparator returns {@code 0} against the given item, or {@code -1} if there is none.
     */
    private int findBucketWithEqualItem(T aItem)
    {
        for (int bucket = 0; bucket < buckets.length; bucket++) {
            for (T item : buckets[bucket]) {
                if (comparator.compare(item, aItem) == 0) {
                    return bucket;
                }
            }
        }
        return -1;
    }

    /**
     * Returns the index of the bucket with the fewest items; on ties the lowest index wins.
     */
    private int indexOfSmallestBucket()
    {
        int smallest = 0;
        for (int bucket = 1; bucket < buckets.length; bucket++) {
            if (buckets[bucket].size() < buckets[smallest].size()) {
                smallest = bucket;
            }
        }
        return smallest;
    }

    /**
     * Ensures that the buckets were built, i.e. that {@link #rewind()} has been called.
     */
    private void requireInitialized()
    {
        if (buckets == null) {
            throw new IllegalStateException(
                    "Folds have not been built yet. rewind() must be called first.");
        }
    }

    /**
     * @return {@code true} if there is at least one more fold that has not been used as
     *         validation fold yet.
     * @throws IllegalStateException
     *             if {@link #rewind()} has not been called before.
     */
    @Override
    public boolean hasNext()
    {
        requireInitialized();
        return validationBucket < buckets.length-1;
    }

    /**
     * Rebuilds the buckets from the folded dimension and resets the iteration to the position
     * before the first fold.
     */
    @Override
    public void rewind()
    {
        init();
        validationBucket = -1;
    }

    /**
     * Advances to the next fold and returns its training/validation split.
     *
     * @return the split as described for {@link #current()}.
     * @throws NoSuchElementException
     *             if all folds have already been visited.
     * @throws IllegalStateException
     *             if {@link #rewind()} has not been called before.
     */
    @Override
    public Map<String, Collection<T>> next()
    {
        // Check before moving the cursor so that a failed call does not corrupt the position.
        if (!hasNext()) {
            throw new NoSuchElementException("All [" + buckets.length + "] folds have been visited.");
        }
        validationBucket++;
        return current();
    }

    /**
     * Returns the split for the current fold: the validation data is the content of the current
     * validation bucket, the training data is the concatenation of all other buckets in bucket
     * order. Both collections are fresh lists, so modifying them does not affect the folds.
     *
     * @return map with the entries {@code getName() + "_training"} and
     *         {@code getName() + "_validation"}.
     * @throws IllegalStateException
     *             if {@link #rewind()} has not been called or {@link #next()} has not been called
     *             since the last rewind.
     */
    @Override
    public Map<String, Collection<T>> current()
    {
        requireInitialized();
        if (validationBucket < 0) {
            throw new IllegalStateException(
                    "No current fold. next() must be called before current().");
        }

        List<T> trainingData = new ArrayList<T>();
        for (int i = 0; i < buckets.length; i++) {
            if (i != validationBucket) {
                trainingData.addAll(buckets[i]);
            }
        }

        Map<String, Collection<T>> data = new HashMap<String, Collection<T>>();
        data.put(getName()+"_training", trainingData);
        // Hand out a copy so that callers cannot modify the internal bucket.
        data.put(getName()+"_validation", new ArrayList<T>(buckets[validationBucket]));
        return data;
    }

    /**
     * Forwards the configuration to the folded dimension if that dimension is itself a
     * {@link DynamicDimension}; otherwise does nothing.
     */
    @Override
    public void setConfiguration(Map<String, Object> aConfig)
    {
        if (foldedDimension instanceof DynamicDimension) {
            ((DynamicDimension) foldedDimension).setConfiguration(aConfig);
        }
    }
}