/*
* Copyright (c) 2011-2016, Peter Abeles. All Rights Reserved.
*
* This file is part of BoofCV (http://boofcv.org).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package boofcv.alg.tracker.tld;
import boofcv.struct.ImageRectangle;
import boofcv.struct.image.GrayU8;
import georegression.struct.shapes.Rectangle2D_F64;
import org.ddogleg.struct.FastQueue;
import org.junit.Test;
import java.util.Random;
import static org.junit.Assert.assertEquals;
/**
 * Unit tests for {@link TldLearning}.
 *
 * <p>The variance filter, fern classifier, template matcher and detector handed to the learner are
 * all replaced with stubs. The stubs only count how often their hooks are invoked, so each test
 * checks which learning steps ran and how many times, not what was learned.</p>
 *
 * @author Peter Abeles
 */
public class TestTldLearning {

	Random rand = new Random(234);
	TldParameters config = new TldParameters();

	/**
	 * Checks the expected stub calls after {@code initialLearning} on three candidate regions:
	 * <ul>
	 *   <li>the variance threshold is selected once;</li>
	 *   <li>one positive example is learned;</li>
	 *   <li>three negative fern examples are learned;</li>
	 *   <li>the detection cascade runs once.</li>
	 * </ul>
	 */
	@Test
	public void initialLearning() {
		DummyVariance varianceStub = new DummyVariance();
		DummyFern fernStub = new DummyFern();
		DummyTemplate templateStub = new DummyTemplate();
		DummyDetection detectionStub = new DummyDetection();

		TldLearning learner = new TldLearning(
				rand, config, templateStub, varianceStub, fernStub, detectionStub);

		// three default-constructed candidate regions
		FastQueue<ImageRectangle> candidates = new FastQueue<>(ImageRectangle.class, true);
		for (int n = 0; n < 3; n++) {
			candidates.grow();
		}

		learner.initialLearning(new Rectangle2D_F64(10, 20, 30, 40), candidates);

		// initial learning is expected to choose the variance threshold exactly once
		assertEquals(1, varianceStub.calledSelect);
		// a single positive example reaches both the ferns and the templates
		assertEquals(1, fernStub.calledP);
		assertEquals(1, templateStub.calledP);
		// negative fern examples: one count per candidate region supplied above
		assertEquals(3, fernStub.calledN);
		// negative templates are only expected for ambiguous detections; the stub detector has none
		assertEquals(0, templateStub.calledN);
		assertEquals(1, detectionStub.calledDetection);
	}

	/**
	 * Checks the expected stub calls after {@code updateLearning}:
	 * <ul>
	 *   <li>the variance threshold is not selected again;</li>
	 *   <li>the detection cascade does not run;</li>
	 *   <li>one positive example is learned;</li>
	 *   <li>ten negative fern examples are learned.</li>
	 * </ul>
	 */
	@Test
	public void updateLearning() {
		DummyVariance varianceStub = new DummyVariance();
		DummyFern fernStub = new DummyFern();
		DummyTemplate templateStub = new DummyTemplate();
		DummyDetection detectionStub = new DummyDetection();

		TldLearning learner = new TldLearning(
				rand, config, templateStub, varianceStub, fernStub, detectionStub);

		learner.updateLearning(new Rectangle2D_F64(10, 20, 30, 40));

		// an update must leave the variance threshold alone
		assertEquals(0, varianceStub.calledSelect);
		// a single positive example reaches both the ferns and the templates
		assertEquals(1, fernStub.calledP);
		assertEquals(1, templateStub.calledP);
		// NOTE(review): presumably one negative per region held by DummyDetection (10) - confirm
		assertEquals(10, fernStub.calledN);
		// negative templates are only expected for ambiguous detections; the stub detector has none
		assertEquals(0, templateStub.calledN);
		assertEquals(0, detectionStub.calledDetection);
	}

	/**
	 * Fern classifier stub. It holds five randomly generated ferns, and its learning methods only
	 * count positive and negative calls.
	 */
	protected static class DummyFern extends TldFernClassifier {
		int calledP = 0;
		int calledN = 0;

		public DummyFern() {
			final int fernCount = 5;
			Random fernRand = new Random(234);

			ferns = new TldFernDescription[fernCount];
			managers = new TldFernManager[fernCount];

			// build the random ferns in index order so the seeded sequence is consumed deterministically
			for (int f = 0; f < fernCount; f++) {
				ferns[f] = new TldFernDescription(fernRand, 10);
				managers[f] = new TldFernManager(10);
			}
		}

		@Override
		public void learnFern(boolean positive, ImageRectangle r) {
			countCall(positive);
		}

		@Override
		public void learnFernNoise(boolean positive, ImageRectangle r) {
			countCall(positive);
		}

		/** Increments the positive or negative call counter. */
		private void countCall(boolean positive) {
			if (positive) {
				calledP++;
			} else {
				calledN++;
			}
		}
	}

	/**
	 * Variance filter stub. Every region passes the variance check, and threshold selection is
	 * only counted.
	 */
	protected static class DummyVariance extends TldVarianceFilter {
		int calledSelect = 0;

		@Override
		public void selectThreshold(ImageRectangle r) {
			calledSelect += 1;
		}

		@Override
		public boolean checkVariance(ImageRectangle r) {
			// accept unconditionally
			return true;
		}
	}

	/** Template matcher stub that counts positive and negative descriptor additions. */
	protected static class DummyTemplate extends TldTemplateMatching {
		int calledP = 0;
		int calledN = 0;

		@Override
		public void addDescriptor(boolean positive, float x0, float y0, float x1, float y1) {
			if (!positive) {
				calledN++;
				return;
			}
			calledP++;
		}
	}

	/**
	 * Detection stub with the following setup:
	 * <ul>
	 *   <li>it is marked as not ambiguous;</li>
	 *   <li>it is pre-filled with ten fern-info entries, each holding a default-constructed
	 *       rectangle;</li>
	 *   <li>running the detection cascade only increments a counter.</li>
	 * </ul>
	 */
	protected static class DummyDetection extends TldDetection<GrayU8> {
		int calledDetection = 0;

		public DummyDetection() {
			ambiguous = false;

			for (int idx = 0; idx < 10; idx++) {
				fernInfo.grow();
				fernInfo.get(idx).r = new ImageRectangle();
			}
		}

		@Override
		protected void detectionCascade(FastQueue<ImageRectangle> cascadeRegions) {
			calledDetection += 1;
		}
	}
}