/*
* Copyright 2016
* Ubiquitous Knowledge Processing (UKP) Lab
* Technische Universität Darmstadt
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package de.tudarmstadt.ukp.dkpro.core.api.datasets;
import java.io.File;
import java.util.Arrays;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import de.tudarmstadt.ukp.dkpro.core.api.datasets.internal.SplitImpl;
/**
 * A named collection of data files with associated license files, which can be partitioned into
 * a training and a test portion.
 */
public interface Dataset
{
    /**
     * @return the name of this dataset.
     */
    String getName();

    /**
     * @return the language of this dataset (format of the value is implementation-defined; presumably
     *         a language code — confirm against implementations).
     */
    String getLanguage();

    /**
     * @return the encoding of this dataset (presumably the character encoding of the data files —
     *         confirm against implementations).
     */
    String getEncoding();

    /**
     * @return the data files of this dataset. Must not be {@code null} when used by
     *         {@link #getSplit(double, double)}.
     */
    File[] getDataFiles();

    /**
     * @return the license files of this dataset.
     */
    File[] getLicenseFiles();

    /**
     * @return the split provided by the implementation (how it is defined is implementation-specific).
     */
    Split getDefaultSplit();

    /**
     * Splits the data files into a training portion and a test portion, where the test ratio is
     * {@code 1.0 - aTrainRatio}. See {@link #getSplit(double, double)} for details.
     *
     * @param aTrainRatio
     *            fraction of the data files to assign to the training set.
     * @return the split.
     */
    default Split getSplit(double aTrainRatio)
    {
        return getSplit(aTrainRatio, 1.0 - aTrainRatio);
    }

    /**
     * Splits the data files into a training portion and a test portion.
     * <p>
     * The files returned by {@link #getDataFiles()} are sorted by file name. The first
     * {@code round(n * aTrainRatio)} files form the training set, and the next
     * {@code round(n * aTestRatio)} files form the test set. Both counts are clamped so that neither
     * portion extends beyond the available files. Files that are not assigned to either portion are
     * reported in the log. The third argument of the returned split is {@code null} (presumably no
     * development set — confirm against {@code SplitImpl}).
     *
     * @param aTrainRatio
     *            fraction of the data files to assign to the training set.
     * @param aTestRatio
     *            fraction of the data files to assign to the test set.
     * @return the split.
     */
    default Split getSplit(double aTrainRatio, double aTestRatio)
    {
        Log log = LogFactory.getLog(getClass());

        // Sort a copy so that the array owned by the implementation is not reordered as a side effect.
        File[] all = getDataFiles().clone();
        Arrays.sort(all, (a, b) -> a.getName().compareTo(b.getName()));
        log.info("Found " + all.length + " files");

        // Math.round rounds .5 up, so the two counts can add up to more than all.length
        // (e.g. 3 files at 0.5/0.5 -> 2 + 2). Clamp each count to the available files so the
        // pivots always satisfy 0 <= trainPivot <= testPivot <= all.length.
        int trainPivot = (int) Math.max(0,
                Math.min(all.length, Math.round(all.length * aTrainRatio)));
        int testCount = (int) Math.max(0,
                Math.min(all.length, Math.round(all.length * aTestRatio)));
        int testPivot = Math.min(all.length, trainPivot + testCount);

        File[] train = (File[]) ArrayUtils.subarray(all, 0, trainPivot);
        File[] test = (File[]) ArrayUtils.subarray(all, trainPivot, testPivot);

        log.debug("Assigned " + train.length + " files to training set");
        log.debug("Assigned " + test.length + " files to test set");

        int missing = all.length - testPivot;
        if (missing > 0) {
            log.info("Files missing from split: [" + missing + "]");
        }

        return new SplitImpl(train, test, null);
    }
}