/*
* Copyright (c) 2011-2016, Peter Abeles. All Rights Reserved.
*
* This file is part of BoofCV (http://boofcv.org).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package boofcv.alg.bow;
import boofcv.struct.learning.ClassificationHistogram;
import boofcv.struct.learning.Confusion;
import java.io.File;
import java.util.*;
/**
 * Abstract class which provides a framework for learning a scene classifier from a set of images.
 *
 * <p>This class only does the bookkeeping around the image sets. Images are discovered on disk with
 * {@link #findImages(File)}, where each sub-directory of a root directory is one scene (category).
 * The training, cross-validation and test sets are then either loaded from separate directories
 * ({@link #loadSets}) or carved out of a single directory ({@link #loadThenSplit}). A set can be scored
 * with {@link #evaluate}, which calls {@link #classify(String)} on every image and accumulates the
 * results into a confusion matrix. The classifier itself is supplied by the child class through
 * {@link #classify(String)}.</p>
 *
 * @author Peter Abeles
 */
public abstract class LearnSceneFromFiles {

	// Random number generator used to shuffle images in loadThenSplit(). This class never assigns it,
	// so it must be set elsewhere (presumably by the child class) before loadThenSplit() is called.
	protected Random rand;

	// Names of all known scenes. The index of a scene in this list is the row used for it in evaluate()
	protected List<String> scenes = new ArrayList<>();

	// The minimum number of images in each type of set
	int minimumTrain;
	int minimumCross;
	int minimumTest;

	// how to divide the input set up
	double fractionTrain;
	double fractionCross;

	// maps for each set of images. Key = scene name, value = paths to images in that scene
	protected Map<String,List<String>> train;
	protected Map<String,List<String>> cross;
	protected Map<String,List<String>> test;

	/**
	 * Evaluates the classifier on the test set.
	 *
	 * @return Confusion matrix computed from the test set
	 * @throws IllegalStateException if no test set has been loaded yet
	 */
	public Confusion evaluateTest() {
		if( test == null )
			throw new IllegalStateException("Test set has not been loaded. Call loadSets() or loadThenSplit() first.");
		return evaluate(test);
	}

	/**
	 * Given a set of images with known classification, predict which scene each one belongs in and compute
	 * a confusion matrix for the results. A scene which is in {@link #scenes} but has no entry in the
	 * provided set is treated as having zero images.
	 *
	 * @param set Set of classified images
	 * @return Confusion matrix
	 */
	protected Confusion evaluate( Map<String,List<String>> set ) {
		ClassificationHistogram histogram = new ClassificationHistogram(scenes.size());

		int total = 0;
		for( String scene : scenes ) {
			total += imagesFor(set, scene).size();
		}
		System.out.println("total images "+total);

		for (int i = 0; i < scenes.size(); i++) {
			String scene = scenes.get(i);
			List<String> images = imagesFor(set, scene);
			System.out.println(" "+scene+" "+images.size());
			for (String image : images) {
				// row = actual scene, column = scene predicted by the classifier
				int predicted = classify(image);
				histogram.increment(i, predicted);
			}
		}

		return histogram.createConfusion();
	}

	/**
	 * Looks up the images belonging to a scene, returning an empty list when the set has no entry for it.
	 * The list of scenes can be the union of several sets, so a scene is not guaranteed to be in every set.
	 */
	private static List<String> imagesFor( Map<String,List<String>> set , String scene ) {
		List<String> images = set.get(scene);
		if( images == null )
			return Collections.emptyList();
		return images;
	}

	/**
	 * Given an image compute which scene it belongs to
	 * @param path Path to input image
	 * @return integer corresponding to the scene
	 */
	protected abstract int classify( String path );

	/**
	 * Loads the training, cross-validation and test sets from separate directories. Scene names found in
	 * the training and test sets are added to the list of known scenes.
	 *
	 * @param dirTraining Root directory of the training set
	 * @param dirCross Root directory of the cross-validation set. If null the cross set is not modified.
	 *                 If it can't be listed the cross set will be null.
	 * @param dirTest Root directory of the test set
	 * @throws IllegalArgumentException if the training or test directory is null or can't be listed
	 */
	public void loadSets( File dirTraining, File dirCross , File dirTest ) {
		train = requireImages(dirTraining, "training");
		if( dirCross != null )
			cross = findImages(dirCross);
		test = requireImages(dirTest, "test");

		extractKeys(train);
		extractKeys(test);
	}

	/**
	 * Same as {@link #findImages(File)} but fails with a descriptive message instead of returning null.
	 */
	private static Map<String,List<String>> requireImages( File rootDir , String description ) {
		if( rootDir == null )
			throw new IllegalArgumentException("The "+description+" directory is null");
		Map<String,List<String>> found = findImages(rootDir);
		if( found == null )
			throw new IllegalArgumentException(
					"Can't list the "+description+" directory. Is it a directory? "+rootDir.getPath());
		return found;
	}

	/**
	 * Adds every scene name in the set to the list of known scenes, skipping the ones already in the list.
	 */
	private void extractKeys( Map<String,List<String>> images ) {
		for( String key : images.keySet() ) {
			if( scenes.contains(key) )
				continue;
			scenes.add(key);
		}
	}

	/**
	 * Loads all the images in a directory then splits each scene into training, cross-validation and
	 * test sets. Images are shuffled using {@link #rand} before being split. The size of each set is
	 * determined by the fraction and minimum size parameters. A cross-validation set is only created
	 * here if fractionCross is not zero.
	 *
	 * @param directory Root directory containing one sub-directory for each scene
	 * @throws IllegalArgumentException if the directory is null or can't be listed
	 * @throws RuntimeException if a scene doesn't have enough images to create the test set
	 */
	public void loadThenSplit( File directory ) {
		Map<String,List<String>> all = requireImages(directory, "input");

		train = new HashMap<>();
		if( fractionCross != 0 )
			cross = new HashMap<>();
		test = new HashMap<>();

		for( Map.Entry<String,List<String>> entry : all.entrySet() ) {
			String key = entry.getKey();
			List<String> allImages = entry.getValue();
			int numImages = allImages.size();

			// randomize the ordering to remove bias
			Collections.shuffle(allImages,rand);

			int numTrain = Math.max(minimumTrain, (int)(numImages*fractionTrain));

			// Only reserve images for the cross set if there is a cross set to put them in. Otherwise
			// they would end up in none of the sets
			int numCross = 0;
			if( cross != null )
				numCross = Math.max(minimumCross, (int)(numImages*fractionCross));

			// A negative count means train + cross requested more images than there are
			int numTest = numImages-numTrain-numCross;
			if( numTest < Math.max(0,minimumTest) )
				throw new RuntimeException("Not enough images to create test set. "+key+" total = "+allImages.size());

			createSubSet(key, allImages, train, 0, numTrain);
			if( cross != null ) {
				createSubSet(key, allImages, cross , numTrain, numCross+numTrain);
			}
			createSubSet(key, allImages, test, numCross+numTrain, numImages);
		}

		// avoids adding the same scene twice if images were loaded before
		extractKeys(all);
	}

	/**
	 * Copies the images with index start (inclusive) to end (exclusive) into a new list and saves it
	 * in the subset under the specified key.
	 */
	private void createSubSet(String key, List<String> allImages, Map<String,List<String>> subset ,
							  int start , int end ) {
		List<String> selected = new ArrayList<>(Math.max(0,end-start));
		for (int i = start; i < end; i++) {
			selected.add(allImages.get(i));
		}
		subset.put(key, selected);
	}

	/**
	 * Loads the paths to image files contained in subdirectories of the root directory. Each sub directory
	 * is assumed to be a different category of images. The category's name is the directory's name in
	 * lower case. If two directories only differ by case their images are put into the same category.
	 * Hidden files, directories, and files ending with ".txt" are skipped.
	 *
	 * @param rootDir Directory containing one sub-directory for each category
	 * @return Map from category name to image paths, or null if the contents of rootDir can't be listed,
	 * e.g. it is not a directory.
	 */
	public static Map<String,List<String>> findImages( File rootDir ) {
		File[] children = rootDir.listFiles();
		if( children == null )
			return null;

		Map<String,List<String>> out = new HashMap<>();
		for( File d : children ) {
			if( !d.isDirectory() )
				continue;

			File[] files = d.listFiles();
			if( files == null )
				throw new RuntimeException("Should be a directory!");

			List<String> images = new ArrayList<>();
			for( File f : files ) {
				if( f.isHidden() || f.isDirectory() || f.getName().endsWith(".txt") ) {
					continue;
				}
				images.add( f.getPath() );
			}

			// Locale.ROOT so that the category name doesn't depend on the computer's default locale
			String key = d.getName().toLowerCase(Locale.ROOT);
			List<String> existing = out.get(key);
			if( existing == null )
				out.put(key,images);
			else
				existing.addAll(images);
		}
		return out;
	}

	/**
	 * Returns the list of known scenes. This is the internal list and not a copy.
	 */
	public List<String> getScenes() {
		return scenes;
	}
}