/*
* Copyright 2016
* Ubiquitous Knowledge Processing (UKP) Lab
* Technische Universität Darmstadt
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package de.tudarmstadt.ukp.dkpro.core.api.datasets.internal;
import static java.util.Arrays.asList;

import java.io.File;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.apache.commons.io.FileUtils;
import org.apache.commons.io.filefilter.TrueFileFilter;

import de.tudarmstadt.ukp.dkpro.core.api.datasets.Dataset;
import de.tudarmstadt.ukp.dkpro.core.api.datasets.DatasetDescription;
import de.tudarmstadt.ukp.dkpro.core.api.datasets.DatasetFactory;
import de.tudarmstadt.ukp.dkpro.core.api.datasets.FileRole;
import de.tudarmstadt.ukp.dkpro.core.api.datasets.Split;
import de.tudarmstadt.ukp.dkpro.core.api.datasets.internal.util.AntFileFilter;
/**
 * A {@link Dataset} whose files are located by matching the patterns registered per role in a
 * {@link DatasetDescription} against the base directory that a {@link DatasetFactory} resolves
 * for that description.
 */
public class LoadedDataset
    implements Dataset
{
    /**
     * Orders files by file name, falling back to the full path so that files with the same name in
     * different directories are still ordered deterministically.
     */
    private static final Comparator<File> BY_NAME_THEN_PATH = Comparator
            .comparing(File::getName)
            .thenComparing(File::getPath);

    /** Orders files by their full path. */
    private static final Comparator<File> BY_PATH = Comparator.comparing(File::getPath);

    private final DatasetFactory factory;
    private final DatasetDescription description;

    /** Split built from the training/development/testing roles; {@code null} if all are empty. */
    private final Split defaultSplit;

    /**
     * Creates a dataset view and eagerly determines its default split.
     *
     * @param aFactory
     *            factory used to resolve the base directory of the dataset.
     * @param aDescription
     *            description providing the dataset metadata and the file patterns per role.
     */
    public LoadedDataset(DatasetFactory aFactory, DatasetDescription aDescription)
    {
        super();
        factory = aFactory;
        description = aDescription;

        File[] train = getFiles(FileRole.TRAINING);
        File[] dev = getFiles(FileRole.DEVELOPMENT);
        File[] test = getFiles(FileRole.TESTING);

        // Only offer a default split if at least one of the split roles matched any files.
        // Note the SplitImpl argument order: training, test, development.
        if (train.length > 0 || dev.length > 0 || test.length > 0) {
            defaultSplit = new SplitImpl(train, test, dev);
        }
        else {
            defaultSplit = null;
        }
    }

    /**
     * @return the ID of the underlying dataset description.
     */
    @Override
    public String getName()
    {
        return description.getId();
    }

    /**
     * @return the language declared in the dataset description.
     */
    @Override
    public String getLanguage()
    {
        return description.getLanguage();
    }

    /**
     * @return the encoding declared in the dataset description.
     */
    @Override
    public String getEncoding()
    {
        return description.getEncoding();
    }

    /**
     * Returns all data files of the dataset, sorted by path.
     * <p>
     * If no files are registered under {@link FileRole#DATA}, this falls back to the union of the
     * training, test and development files of the default split (if there is one). Each file is
     * reported at most once.
     *
     * @return the data files; an empty array if none were found.
     */
    @Override
    public File[] getDataFiles()
    {
        // A set so that files appearing in several roles are only reported once
        Set<File> all = new HashSet<>(asList(getFiles(FileRole.DATA)));

        // If no files are marked as data files, try aggregating over test/dev/train sets
        if (all.isEmpty()) {
            Split split = getDefaultSplit();
            if (split != null) {
                all.addAll(asList(split.getTrainingFiles()));
                all.addAll(asList(split.getTestFiles()));
                all.addAll(asList(split.getDevelopmentFiles()));
            }
        }

        // Sort to ensure a stable order independent of the set's iteration order
        File[] result = all.toArray(new File[all.size()]);
        Arrays.sort(result, BY_PATH);
        return result;
    }

    /**
     * @return the files registered under {@link FileRole#LICENSE}.
     */
    @Override
    public File[] getLicenseFiles()
    {
        return getFiles(FileRole.LICENSE);
    }

    /**
     * @return the split derived from the training/development/testing roles, or {@code null} if
     *         none of these roles matched any files.
     */
    @Override
    public Split getDefaultSplit()
    {
        return defaultSplit;
    }

    /**
     * Collects the files that match any of the patterns registered for the given role.
     * <p>
     * Each pattern is applied with an {@link AntFileFilter} to the base directory resolved by the
     * factory. All subdirectories are searched because {@link TrueFileFilter#TRUE} is used as the
     * directory filter. A file matched by several patterns is returned only once.
     *
     * @param aRole
     *            the role key, e.g. {@link FileRole#TRAINING}.
     * @return the matching files, sorted by name and then by path; an empty array if the role has
     *         no patterns.
     */
    private File[] getFiles(String aRole)
    {
        List<String> patterns = description.getRoles().get(aRole);
        if (patterns == null || patterns.isEmpty()) {
            return new File[0];
        }

        // The base directory does not depend on the pattern, so resolve it only once
        Path baseDir = factory.resolve(description);
        File baseDirFile = baseDir.toFile();

        // Use a set so that overlapping patterns do not yield the same file twice
        Set<File> files = new HashSet<>();
        for (String pattern : patterns) {
            Collection<File> matchedFiles = FileUtils.listFiles(baseDirFile,
                    new AntFileFilter(baseDir, asList(pattern), null), TrueFileFilter.TRUE);
            files.addAll(matchedFiles);
        }

        File[] all = files.toArray(new File[files.size()]);
        Arrays.sort(all, BY_NAME_THEN_PATH);
        return all;
    }
}