/*
* Copyright (c) 2011-2016, Peter Abeles. All Rights Reserved.
*
* This file is part of BoofCV (http://boofcv.org).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package boofcv.alg.tracker.tld;
import boofcv.alg.interpolate.InterpolatePixelS;
import boofcv.alg.misc.ImageMiscOps;
import boofcv.core.image.border.BorderType;
import boofcv.factory.interpolate.FactoryInterpolation;
import boofcv.struct.ImageRectangle;
import boofcv.struct.image.GrayU8;
import georegression.struct.point.Point2D_F32;
import org.junit.Test;
import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
 * Unit tests for {@link TldFernClassifier}: learning ferns (with and without noise), computing fern
 * descriptors (with and without noise), and renormalizing the positive/negative counts.
 *
 * @author Peter Abeles
 */
public class TestTldFernClassifier {

	/** Number of point pairs in the ferns created by the descriptor tests. */
	private static final int FERN_SIZE = 10;
	/** Number of noisy descriptor samples drawn in {@link #computeFernValueRand()}. */
	private static final int NUM_RAND_TRIALS = 20;

	int width = 60;
	int height = 80;

	int numFerns = 5;
	int numLearnRandom = 7;

	Random rand = new Random(234);

	GrayU8 input = new GrayU8(width,height);
	/** Reference interpolator over {@link #input}, independent of the one owned by the classifier. */
	InterpolatePixelS<GrayU8> interpolate = FactoryInterpolation.bilinearPixelS(
			GrayU8.class, BorderType.EXTENDED);

	public TestTldFernClassifier() {
		ImageMiscOps.fillUniform(input,rand,0,200);
		interpolate.setImage(input);
	}

	/**
	 * Learning one positive and then one negative example should add exactly one count of the
	 * matching type to every fern manager and make the corresponding maximum non-zero.
	 */
	@Test
	public void learnFern() {
		TldFernClassifier<GrayU8> alg = createAlg();
		alg.setImage(input);

		alg.learnFern(true, new ImageRectangle(10,12,30,45));

		for( int i = 0; i < alg.managers.length; i++ ) {
			assertEquals(1, countNum(true,alg.managers[i]));
			assertEquals(0, countNum(false,alg.managers[i]));
		}
		assertTrue(alg.getMaxP() > 0 );
		assertTrue(alg.getMaxN() == 0 );

		alg.learnFern(false, new ImageRectangle(10,12,30,45));

		for( int i = 0; i < alg.managers.length; i++ ) {
			assertEquals(1, countNum(true,alg.managers[i]));
			assertEquals(1, countNum(false,alg.managers[i]));
		}
		assertTrue(alg.getMaxP() > 0 );
		assertTrue(alg.getMaxN() > 0 );
	}

	/**
	 * Learning with noise should add {@code 1 + numLearnRandom} counts of the matching type to every
	 * fern manager (presumably the clean sample plus the noisy ones).
	 */
	@Test
	public void learnFernNoise() {
		TldFernClassifier<GrayU8> alg = createAlg();
		alg.setImage(input);

		alg.learnFernNoise(true, new ImageRectangle(10,12,30,45));

		for( int i = 0; i < alg.managers.length; i++ ) {
			assertEquals(1+numLearnRandom, countNum(true,alg.managers[i]));
			assertEquals(0, countNum(false,alg.managers[i]));
		}
		assertTrue(alg.getMaxP() > 0 );
		assertTrue(alg.getMaxN() == 0 );

		alg.learnFernNoise(false, new ImageRectangle(10,12,30,45));

		for( int i = 0; i < alg.managers.length; i++ ) {
			assertEquals(1+numLearnRandom, countNum(true,alg.managers[i]));
			assertEquals(1+numLearnRandom, countNum(false,alg.managers[i]));
		}
		assertTrue(alg.getMaxP() > 0 );
		assertTrue(alg.getMaxN() > 0 );
	}

	/**
	 * Without noise, the fern value must match a descriptor computed directly from the image. Pair i is
	 * stored in bit {@code FERN_SIZE-1-i}, which is set when the first point is darker than the second.
	 */
	@Test
	public void computeFernValue() {
		TldFernDescription fern = new TldFernDescription(rand,FERN_SIZE);

		ImageRectangle r = new ImageRectangle(2,20,12,28);
		float cx = r.x0 + (r.getWidth()-1)/2.0f;
		// the vertical centre is measured from y0, not x0
		float cy = r.y0 + (r.getHeight()-1)/2.0f;

		// the classifier is handed the full width/height; the reference scales offsets by (size-1)
		boolean[] expected = computeExpectedBits(fern,cx,cy,r.getWidth()-1,r.getHeight()-1);

		TldFernClassifier<GrayU8> alg = createAlg();
		alg.setImage(input);

		int found = alg.computeFernValue(cx,cy,r.getWidth(),r.getHeight(),fern);

		for( int i = 0; i < FERN_SIZE; i++ ) {
			assertEquals("bit "+i, expected[i], isBitSet(found,i));
		}
	}

	/**
	 * With noise added, the fern value should differ from the noise-free descriptor in some bits but
	 * never in all of them. Several samples are drawn so the "some difference" check does not hinge on
	 * a single random draw.
	 */
	@Test
	public void computeFernValueRand() {
		TldFernDescription fern = new TldFernDescription(rand,FERN_SIZE);

		ImageRectangle r = new ImageRectangle(2,20,12,28);
		float cx = r.x0 + (r.getWidth()-1)/2.0f;
		float cy = r.y0 + (r.getHeight()-1)/2.0f;

		// Same noise-free reference geometry as computeFernValue(). This presumes computeFernValueRand()
		// scales offsets the same way, so that differences come from the noise only.
		boolean[] expected = computeExpectedBits(fern,cx,cy,r.getWidth()-1,r.getHeight()-1);

		TldFernClassifier<GrayU8> alg = createAlg();
		alg.setImage(input);

		int totalDiff = 0;
		for( int trial = 0; trial < NUM_RAND_TRIALS; trial++ ) {
			int found = alg.computeFernValueRand(cx,cy,r.getWidth(),r.getHeight(),fern);
			int numDiff = countBitDifferences(expected,found);
			// noise should perturb the descriptor, not turn it into its complement
			assertTrue("trial "+trial, numDiff < FERN_SIZE );
			totalDiff += numDiff;
		}
		// across all trials the noise must have flipped at least one comparison
		assertTrue(totalDiff > 0 );
	}

	/**
	 * Renormalizing with maxP = 1000 should scale a positive count of 600 down to 600/20.
	 */
	@Test
	public void renormalizeP() {
		TldFernClassifier<GrayU8> alg = createAlg();

		alg.maxP = 1000;
		alg.managers[2].table[1] = new TldFernFeature();
		alg.managers[2].table[1].numP = 600;

		alg.renormalizeP();

		// NOTE(review): the factor of 20 encodes the classifier's renormalization rule; confirm there
		int expected = 600/20;
		assertEquals(expected,alg.managers[2].table[1].numP);
	}

	/**
	 * Renormalizing with maxN = 1000 should scale a negative count of 600 down to 600/20.
	 */
	@Test
	public void renormalizeN() {
		TldFernClassifier<GrayU8> alg = createAlg();

		alg.maxN = 1000;
		alg.managers[2].table[1] = new TldFernFeature();
		alg.managers[2].table[1].numN = 600;

		alg.renormalizeN();

		// NOTE(review): the factor of 20 encodes the classifier's renormalization rule; confirm there
		int expected = 600/20;
		assertEquals(expected,alg.managers[2].table[1].numN);
	}

	/**
	 * Creates the classifier under test with its own bilinear interpolator (EXTENDED border). It shares
	 * this test's random number generator.
	 */
	private TldFernClassifier<GrayU8> createAlg() {
		InterpolatePixelS<GrayU8> algInterp = FactoryInterpolation.bilinearPixelS(
				GrayU8.class, BorderType.EXTENDED);
		return new TldFernClassifier<>(rand,numFerns,8,numLearnRandom,10,algInterp);
	}

	/**
	 * Computes the noise-free descriptor bits for {@code fern} directly from {@link #interpolate}.
	 *
	 * @param w horizontal scale applied to the pair offsets
	 * @param h vertical scale applied to the pair offsets
	 * @return array where entry {@code FERN_SIZE-1-i} holds the comparison result for pair i
	 */
	private boolean[] computeExpectedBits( TldFernDescription fern , float cx , float cy , float w , float h ) {
		boolean[] expected = new boolean[FERN_SIZE];
		for( int i = 0; i < FERN_SIZE; i++ ) {
			Point2D_F32 a = fern.pairs[i].a;
			Point2D_F32 b = fern.pairs[i].b;

			float valA = interpolate.get(cx + a.x*w, cy + a.y*h);
			float valB = interpolate.get(cx + b.x*w, cy + b.y*h);

			// the first pair ends up in the most significant bit
			expected[FERN_SIZE-1-i] = valA < valB;
		}
		return expected;
	}

	/** Returns true if bit {@code bit} of {@code value} is set. */
	private static boolean isBitSet( int value , int bit ) {
		return ((value >> bit) & 1) == 1;
	}

	/** Counts the bits of {@code found} that disagree with {@code expected}, where {@code expected[i]} is bit i. */
	private static int countBitDifferences( boolean[] expected , int found ) {
		int numDiff = 0;
		for( int i = 0; i < expected.length; i++ ) {
			if( expected[i] != isBitSet(found,i) )
				numDiff++;
		}
		return numDiff;
	}

	/**
	 * Sums the positive (or negative) counts stored across every non-null entry of the manager's table.
	 */
	private int countNum( boolean positive , TldFernManager manager ) {
		int total = 0;
		for( TldFernFeature f : manager.table ) {
			if( f == null )
				continue;
			total += positive ? f.numP : f.numN;
		}
		return total;
	}
}