package gdsc.smlm.filters;
import gdsc.core.utils.FloatEquality;
import gdsc.smlm.TestSettings;
import gdsc.smlm.filters.AverageFilter;
import java.util.ArrayList;
import java.util.Arrays;
import org.junit.Assert;
import org.junit.Test;
import org.junit.internal.ArrayComparisonFailure;
public class AverageFilterTest
{
// Random generator used to shuffle the generated test data. It is re-created with a fixed
// seed at the start of averageIsCorrect and of each speed test routine.
private gdsc.core.utils.Random rand;
// When true the speed tests also print the timing for every [width x height] combination.
private boolean debug = false;
// Number of data sets timed by the internal speed tests: InternalITER3 is used when the
// box size is exactly 1, InternalITER otherwise.
private int InternalITER3 = 500;
private int InternalITER = 50;
// Number of data sets timed by the non-internal speed tests: ITER3 is used when the box
// size is exactly 1, ITER otherwise.
private int ITER3 = 200;
private int ITER = 20;
// TODO - The test data should be representative of the final use case
// Image dimensions; every (width, height) pair drawn from this list is tested. The first
// entry is also the dimension of the images generated for the speed tests.
int[] primes = new int[] { 113, 97, 53, 29 };
//int[] primes = new int[] { 1024 };
// Box sizes passed to the filters. The interpolated tests additionally subtract a
// fractional offset from these values.
int[] boxSizes = new int[] { 15, 9, 5, 3, 2, 1 };
// Selects which entry point of a DataFilter is verified: filterInternal (true) or filter (false).
boolean[] checkInternal = new boolean[] { true, false };
/**
 * Reference mean filter: a deliberately simple implementation used to validate the
 * optimised filters.
 * <p>
 * Each pixel is replaced by the weighted mean of the square window of half-width
 * {@code ceil(boxSize)} centred on it. All weights are 1 except, when {@code boxSize} is
 * not a whole number, the outermost row/column of the window which is weighted by the
 * fractional part of the box size. Window positions that fall outside the image are read
 * from the nearest edge pixel.
 *
 * @param data
 *            the image in row-major order (index = y * maxx + x); overwritten with the
 *            filtered result
 * @param maxx
 *            the image width
 * @param maxy
 *            the image height
 * @param boxSize
 *            the half-width of the averaging window; nothing is done when this is not
 *            greater than zero
 */
public static void average(float[] data, int maxx, int maxy, float boxSize)
{
	if (boxSize <= 0)
		return;

	final int radius = (int) Math.ceil(boxSize);
	final int width = 2 * radius + 1;

	// 1D weights: uniform, with reduced weight on the two outer edges for fractional sizes
	final float[] w = new float[width];
	for (int i = 0; i < width; i++)
		w[i] = 1;
	if (boxSize != radius)
	{
		final float edge = boxSize - (radius - 1);
		w[width - 1] = edge;
		w[0] = edge;
	}

	// Normalisation is the reciprocal of the sum of the separable 2D weights
	float total = 0;
	for (float wy : w)
		for (float wx : w)
			total += wy * wx;
	final float norm = (float) (1.0 / total);

	final int lastRow = maxy - 1;
	final int lastCol = maxx - 1;
	final float[] filtered = new float[data.length];
	for (int y = 0; y < maxy; y++)
	{
		for (int x = 0; x < maxx; x++)
		{
			float sum = 0;
			for (int j = 0; j < width; j++)
			{
				// Clamp the row to the image
				final int rowStart = Math.min(Math.max(y + j - radius, 0), lastRow) * maxx;
				for (int i = 0; i < width; i++)
				{
					// Clamp the column to the image
					final int col = Math.min(Math.max(x + i - radius, 0), lastCol);
					sum += data[rowStart + col] * w[j] * w[i];
				}
			}
			filtered[y * maxx + x] = sum * norm;
		}
	}

	// Write the result back over the input
	System.arraycopy(filtered, 0, data, 0, filtered.length);
}
/**
 * Asserts that two images agree, pixel by pixel, away from the image border.
 * <p>
 * Pixels are compared with {@code FloatEquality.almostEqualRelativeOrAbsolute} using an
 * equality constructed with the values 1e-5f and 1e-10f. The first mismatch fails the test
 * with a message identifying the pixel and the two values.
 * <p>
 * NOTE(review): the scan stops at {@code max - border - 1}, so in addition to the border the
 * last remaining row and column are never compared (even when the border is 0). This looks
 * like an off-by-one but is preserved here - confirm before tightening the bound.
 *
 * @param message
 *            prefix for the failure message
 * @param data1
 *            the first image (row-major, width maxx)
 * @param data2
 *            the second image (row-major, width maxx)
 * @param maxx
 *            the image width
 * @param maxy
 *            the image height
 * @param boxSize
 *            the box size; its ceiling is the number of pixels skipped at each edge
 */
private void floatArrayEquals(String message, float[] data1, float[] data2, int maxx, int maxy, float boxSize)
{
	final FloatEquality equality = new FloatEquality(1e-5f, 1e-10f);

	// Debug: show the images
	//gdsc.core.ij.Utils.display("data1", new ij.process.FloatProcessor(maxx, maxy, data1));
	//gdsc.core.ij.Utils.display("data2", new ij.process.FloatProcessor(maxx, maxy, data2));

	// Ignore the border
	final int border = (int) Math.ceil(boxSize);
	final int yLimit = maxy - border - 1;
	final int xLimit = maxx - border - 1;
	for (int y = border; y < yLimit; y++)
	{
		final int rowStart = y * maxx;
		for (int x = border; x < xLimit; x++)
		{
			final int index = rowStart + x;
			if (equality.almostEqualRelativeOrAbsolute(data1[index], data2[index]))
				continue;
			Assert.fail(String.format("%s [%d,%d] %f != %f", message, x, y, data1[index], data2[index]));
		}
	}
}
/**
 * Wraps a pair of filter calls behind a common signature so that the same checking and
 * timing routines can be applied to different methods of {@link AverageFilter}.
 * Each instance owns its own {@link AverageFilter}.
 */
private abstract class DataFilter
{
	/** The filter instance that implementations delegate to. */
	AverageFilter f = new AverageFilter();

	/** Label used in the speed test output. */
	final String name;

	/**
	 * When true the correctness check also exercises non-integer box sizes, and the speed
	 * tests apply a fractional offset to the box size if both filters are interpolated.
	 */
	final boolean isInterpolated;

	/**
	 * @param name
	 *            label for the filter
	 * @param isInterpolated
	 *            flag stored in {@link #isInterpolated}
	 */
	public DataFilter(String name, boolean isInterpolated)
	{
		this.isInterpolated = isInterpolated;
		this.name = name;
	}

	/**
	 * Apply the filter to the data.
	 *
	 * @param data
	 *            the image data
	 * @param width
	 *            the image width
	 * @param height
	 *            the image height
	 * @param boxSize
	 *            the box size
	 */
	public abstract void filter(float[] data, int width, int height, float boxSize);

	/**
	 * Apply the "internal" variant of the filter to the data.
	 *
	 * @param data
	 *            the image data
	 * @param width
	 *            the image width
	 * @param height
	 *            the image height
	 * @param boxSize
	 *            the box size
	 */
	public abstract void filterInternal(float[] data, int width, int height, float boxSize);
}
/**
 * Checks one filter call against the reference {@link #average(float[], int, int, float)}
 * on freshly generated data.
 * <p>
 * The random generator is re-seeded with a fixed value so every invocation sees the same
 * shuffle for a given image size.
 *
 * @param width
 *            the image width
 * @param height
 *            the image height
 * @param boxSize
 *            the box size
 * @param internal
 *            true to test {@code filterInternal} (the comparison then skips a border of
 *            ceil(boxSize) pixels); false to test {@code filter} (no border is requested)
 * @param filter
 *            the filter under test
 */
private void averageIsCorrect(int width, int height, float boxSize, boolean internal, DataFilter filter)
		throws ArrayComparisonFailure
{
	rand = new gdsc.core.utils.Random(-30051976);

	final float[] expected = createData(width, height);
	final float[] actual = expected.clone();

	AverageFilterTest.average(expected, width, height, boxSize);

	final String format;
	final float border;
	if (internal)
	{
		filter.filterInternal(actual, width, height, boxSize);
		format = "Internal arrays do not match: [%dx%d] @ %.1f";
		border = boxSize;
	}
	else
	{
		filter.filter(actual, width, height, boxSize);
		format = "Arrays do not match: [%dx%d] @ %.1f";
		border = 0;
	}

	floatArrayEquals(String.format(format, width, height, boxSize), expected, actual, width, height, border);
}
/**
 * Runs {@code averageIsCorrect} for every combination of width and height from
 * {@code primes}, box size from {@code boxSizes} and flag from {@code checkInternal}.
 * For an interpolated filter each box size is additionally tested reduced by 0.3 and by 0.6.
 *
 * @param filter
 *            the filter under test
 */
private void checkIsCorrect(DataFilter filter)
{
	// Amounts subtracted from each integer box size; 0 leaves the size unchanged
	final float[] reductions = filter.isInterpolated ? new float[] { 0f, 0.3f, 0.6f } : new float[] { 0f };

	for (int width : primes)
	{
		for (int height : primes)
		{
			for (float boxSize : boxSizes)
			{
				for (boolean internal : checkInternal)
				{
					for (float reduction : reductions)
					{
						averageIsCorrect(width, height, boxSize - reduction, internal, filter);
					}
				}
			}
		}
	}
}
/**
 * Verifies {@code blockAverage} and {@code blockAverageInternal} against the reference
 * implementation, including non-integer box sizes.
 */
@Test
public void blockAverageIsCorrect()
{
	checkIsCorrect(new DataFilter("block", true)
	{
		@Override
		public void filter(float[] data, int width, int height, float boxSize)
		{
			f.blockAverage(data, width, height, boxSize);
		}

		@Override
		public void filterInternal(float[] data, int width, int height, float boxSize)
		{
			f.blockAverageInternal(data, width, height, boxSize);
		}
	});
}
/**
 * Verifies {@code stripedBlockAverage} and {@code stripedBlockAverageInternal} against the
 * reference implementation, including non-integer box sizes.
 */
@Test
public void stripedBlockAverageIsCorrect()
{
	checkIsCorrect(new DataFilter("stripedBlock", true)
	{
		@Override
		public void filter(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverage(data, width, height, boxSize);
		}

		@Override
		public void filterInternal(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverageInternal(data, width, height, boxSize);
		}
	});
}
/**
 * Verifies {@code rollingBlockAverage} and {@code rollingBlockAverageInternal} against the
 * reference implementation. The box size is truncated to an int before the call and only
 * integer box sizes are tested.
 */
@Test
public void rollingBlockAverageIsCorrect()
{
	checkIsCorrect(new DataFilter("rollingBlock", false)
	{
		@Override
		public void filter(float[] data, int width, int height, float boxSize)
		{
			f.rollingBlockAverage(data, width, height, (int) boxSize);
		}

		@Override
		public void filterInternal(float[] data, int width, int height, float boxSize)
		{
			f.rollingBlockAverageInternal(data, width, height, (int) boxSize);
		}
	});
}
/**
 * Computes how many times longer the slow run took than the fast run.
 *
 * @param slowTotal
 *            the time of the slow method
 * @param fastTotal
 *            the time of the fast method
 * @return slowTotal / fastTotal using floating-point division
 */
private double speedUpFactor(long slowTotal, long fastTotal)
{
	final double slow = slowTotal;
	return slow / fastTotal;
}
/**
 * Creates an independent copy of the array.
 *
 * @param data1
 *            the array to copy
 * @return a new array with the same length and contents
 */
private float[] floatClone(float[] data1)
{
	return data1.clone();
}
/**
 * Builds a test image containing the values 0 .. (width * height - 1), each exactly once,
 * reordered by {@code rand.shuffle}.
 *
 * @param width
 *            the image width
 * @param height
 *            the image height
 * @return the shuffled data of length width * height
 */
private float[] createData(int width, int height)
{
	final int size = width * height;
	final float[] data = new float[size];
	for (int i = 0; i < size; i++)
	{
		data[i] = i;
	}
	rand.shuffle(data);
	return data;
}
/**
 * Generates the images used by the speed tests. Every image is square with the side length
 * taken from the first entry of {@code primes}.
 *
 * @param iter
 *            the number of images to create
 * @return the list of generated images
 */
private ArrayList<float[]> floatCreateSpeedData(int iter)
{
	final int side = primes[0];
	final ArrayList<float[]> images = new ArrayList<float[]>(iter);
	for (int count = 0; count < iter; count++)
	{
		images.add(createData(side, side));
	}
	return images;
}
/** Cache of generated speed-test images, shared by all instances; grown on demand by getSpeedData. */
static ArrayList<float[]> dataSet = null;

/**
 * Returns fresh copies of the first {@code iter} cached speed-test images, so callers may
 * filter them in place without affecting the cache. The cache is (re)generated when it is
 * missing or holds fewer than {@code iter} images.
 *
 * @param iter
 *            the number of images required
 * @return a new list of cloned images
 */
private ArrayList<float[]> getSpeedData(int iter)
{
	final boolean cacheTooSmall = dataSet == null || dataSet.size() < iter;
	if (cacheTooSmall)
		dataSet = floatCreateSpeedData(iter);

	final ArrayList<float[]> copies = new ArrayList<float[]>(iter);
	for (float[] cached : dataSet.subList(0, iter))
	{
		copies.add(cached.clone());
	}
	return copies;
}
/**
 * Times {@code fast.filter} against {@code slow.filter} over the default set of box sizes
 * declared in the {@code boxSizes} field.
 *
 * @param fast
 *            the filter expected to be faster
 * @param slow
 *            the filter expected to be slower
 */
private void speedTest(DataFilter fast, DataFilter slow)
{
	final int[] defaultBoxSizes = this.boxSizes;
	speedTest(fast, slow, defaultBoxSizes);
}
/**
 * Times {@code fast.filter} against {@code slow.filter} for each box size over every
 * (width, height) pair from {@code primes} and reports the speed-up.
 * <p>
 * The routine starts with a JUnit assumption on {@code TestSettings.RUN_SPEED_TESTS}. All
 * fast timings are collected first, then the slow timings are collected in the same order
 * and paired with them. When {@code TestSettings.ASSERT_SPEED_TESTS} is set the fast filter
 * must be strictly quicker for each box size and in total.
 * <p>
 * Each report line prints the slow filter's name next to the slow time and the fast
 * filter's name next to the fast time.
 *
 * @param fast
 *            the filter expected to be faster
 * @param slow
 *            the filter expected to be slower
 * @param testBoxSizes
 *            the box sizes to test; reduced by 0.3 when both filters are interpolated
 */
private void speedTest(DataFilter fast, DataFilter slow, int[] testBoxSizes)
{
	org.junit.Assume.assumeTrue(TestSettings.RUN_SPEED_TESTS);

	rand = new gdsc.core.utils.Random(-300519);

	ArrayList<float[]> images = getSpeedData(ITER3);

	final ArrayList<Long> fastTimes = new ArrayList<Long>();

	final float[] sizes = new float[testBoxSizes.length];
	final float offset = (fast.isInterpolated && slow.isInterpolated) ? 0.3f : 0;
	for (int i = 0; i < sizes.length; i++)
		sizes[i] = testBoxSizes[i] - offset;

	// Initialise: run each filter once per box size before anything is timed
	for (float boxSize : sizes)
	{
		fast.filter(images.get(0).clone(), primes[0], primes[0], boxSize);
		slow.filter(images.get(0).clone(), primes[0], primes[0], boxSize);
	}

	// Time the fast filter
	for (float boxSize : sizes)
	{
		final int iter = (boxSize == 1) ? ITER3 : ITER;
		for (int width : primes)
			for (int height : primes)
			{
				images = getSpeedData(iter);

				long time = System.nanoTime();
				for (float[] data : images)
					fast.filter(data, width, height, boxSize);
				time = System.nanoTime() - time;
				fastTimes.add(time);
			}
	}

	// Time the slow filter in the same order and compare with the stored fast timings
	long slowTotal = 0, fastTotal = 0;
	int index = 0;
	for (float boxSize : sizes)
	{
		final int iter = (boxSize == 1) ? ITER3 : ITER;
		long boxSlowTotal = 0, boxFastTotal = 0;
		for (int width : primes)
			for (int height : primes)
			{
				images = getSpeedData(iter);

				long time = System.nanoTime();
				for (float[] data : images)
					slow.filter(data, width, height, boxSize);
				time = System.nanoTime() - time;

				final long fastTime = fastTimes.get(index++);
				slowTotal += time;
				fastTotal += fastTime;
				boxSlowTotal += time;
				boxFastTotal += fastTime;
				if (debug)
					System.out.printf("%s [%dx%d] @ %.1f : %d => %s %d = %.2fx\n", slow.name, width, height,
							boxSize, time, fast.name, fastTime, speedUpFactor(time, fastTime));
			}
		//if (debug)
		System.out.printf("%s %.1f : %d => %s %d = %.2fx\n", slow.name, boxSize, boxSlowTotal, fast.name,
				boxFastTotal, speedUpFactor(boxSlowTotal, boxFastTotal));
		if (TestSettings.ASSERT_SPEED_TESTS)
			Assert.assertTrue(
					String.format("Not faster: Block %.1f : %d > %d", boxSize, boxFastTotal, boxSlowTotal),
					boxFastTotal < boxSlowTotal);
	}
	System.out.printf("%s %d => %s %d = %.2fx\n", slow.name, slowTotal, fast.name, fastTotal,
			speedUpFactor(slowTotal, fastTotal));
	if (TestSettings.ASSERT_SPEED_TESTS)
		Assert.assertTrue(String.format("Not faster: %d > %d", fastTotal, slowTotal), fastTotal < slowTotal);
}
/**
 * Times {@code fast.filterInternal} against {@code slow.filterInternal} over the default
 * set of box sizes declared in the {@code boxSizes} field.
 *
 * @param fast
 *            the filter expected to be faster
 * @param slow
 *            the filter expected to be slower
 */
private void speedTestInternal(DataFilter fast, DataFilter slow)
{
	final int[] defaultBoxSizes = this.boxSizes;
	speedTestInternal(fast, slow, defaultBoxSizes);
}
/**
 * Times {@code fast.filterInternal} against {@code slow.filterInternal} for each box size
 * over every (width, height) pair from {@code primes} and reports the speed-up.
 * <p>
 * The routine starts with a JUnit assumption on {@code TestSettings.RUN_SPEED_TESTS}. All
 * fast timings are collected first, then the slow timings are collected in the same order
 * and paired with them. When {@code TestSettings.ASSERT_SPEED_TESTS} is set the fast filter
 * must be strictly quicker for each box size and in total.
 * <p>
 * Each report line prints the slow filter's name next to the slow time and the fast
 * filter's name next to the fast time.
 *
 * @param fast
 *            the filter expected to be faster
 * @param slow
 *            the filter expected to be slower
 * @param testBoxSizes
 *            the box sizes to test; reduced by 0.3 when both filters are interpolated
 */
private void speedTestInternal(DataFilter fast, DataFilter slow, int[] testBoxSizes)
{
	org.junit.Assume.assumeTrue(TestSettings.RUN_SPEED_TESTS);

	rand = new gdsc.core.utils.Random(-300519);

	ArrayList<float[]> images = getSpeedData(InternalITER3);

	final ArrayList<Long> fastTimes = new ArrayList<Long>();

	final float[] sizes = new float[testBoxSizes.length];
	final float offset = (fast.isInterpolated && slow.isInterpolated) ? 0.3f : 0;
	for (int i = 0; i < sizes.length; i++)
		sizes[i] = testBoxSizes[i] - offset;

	// Initialise: run each filter once per box size before anything is timed
	for (float boxSize : sizes)
	{
		fast.filterInternal(floatClone(images.get(0)), primes[0], primes[0], boxSize);
		slow.filterInternal(floatClone(images.get(0)), primes[0], primes[0], boxSize);
	}

	// Time the fast filter
	for (float boxSize : sizes)
	{
		final int iter = (boxSize == 1) ? InternalITER3 : InternalITER;
		for (int width : primes)
			for (int height : primes)
			{
				images = getSpeedData(iter);

				long time = System.nanoTime();
				for (float[] data : images)
					fast.filterInternal(data, width, height, boxSize);
				time = System.nanoTime() - time;
				fastTimes.add(time);
			}
	}

	// Time the slow filter in the same order and compare with the stored fast timings
	long slowTotal = 0, fastTotal = 0;
	int index = 0;
	for (float boxSize : sizes)
	{
		final int iter = (boxSize == 1) ? InternalITER3 : InternalITER;
		long boxSlowTotal = 0, boxFastTotal = 0;
		for (int width : primes)
			for (int height : primes)
			{
				images = getSpeedData(iter);

				long time = System.nanoTime();
				for (float[] data : images)
					slow.filterInternal(data, width, height, boxSize);
				time = System.nanoTime() - time;

				final long fastTime = fastTimes.get(index++);
				slowTotal += time;
				fastTotal += fastTime;
				boxSlowTotal += time;
				boxFastTotal += fastTime;
				if (debug)
					System.out.printf("Internal %s [%dx%d] @ %.1f : %d => %s %d = %.2fx\n", slow.name, width,
							height, boxSize, time, fast.name, fastTime, speedUpFactor(time, fastTime));
			}
		//if (debug)
		System.out.printf("Internal %s %.1f : %d => %s %d = %.2fx\n", slow.name, boxSize, boxSlowTotal, fast.name,
				boxFastTotal, speedUpFactor(boxSlowTotal, boxFastTotal));
		if (TestSettings.ASSERT_SPEED_TESTS)
			Assert.assertTrue(
					String.format("Not faster: Block %.1f : %d > %d", boxSize, boxFastTotal, boxSlowTotal),
					boxFastTotal < boxSlowTotal);
	}
	System.out.printf("Internal %s %d => %s %d = %.2fx\n", slow.name, slowTotal, fast.name, fastTotal,
			speedUpFactor(slowTotal, fastTotal));
	if (TestSettings.ASSERT_SPEED_TESTS)
		Assert.assertTrue(String.format("Not faster: %d > %d", fastTotal, slowTotal), fastTotal < slowTotal);
}
/**
 * Speed comparison of the striped block average against the block average, using integer
 * box sizes, for both the standard and internal entry points.
 */
@Test
public void stripedBlockIsFasterThanBlock()
{
	final DataFilter baseline = new DataFilter("block", false)
	{
		@Override
		public void filter(float[] data, int width, int height, float boxSize)
		{
			f.blockAverage(data, width, height, (int) boxSize);
		}

		@Override
		public void filterInternal(float[] data, int width, int height, float boxSize)
		{
			f.blockAverageInternal(data, width, height, (int) boxSize);
		}
	};
	final DataFilter candidate = new DataFilter("stripedBlock", false)
	{
		@Override
		public void filter(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverage(data, width, height, (int) boxSize);
		}

		@Override
		public void filterInternal(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverageInternal(data, width, height, (int) boxSize);
		}
	};

	speedTest(candidate, baseline);
	speedTestInternal(candidate, baseline);
}
/**
 * Speed comparison of the striped block average against the block average, passing the
 * float box size through unchanged, for both the standard and internal entry points.
 */
@Test
public void interpolatedStripedBlockIsFasterThanBlock()
{
	final DataFilter baseline = new DataFilter("block", true)
	{
		@Override
		public void filter(float[] data, int width, int height, float boxSize)
		{
			f.blockAverage(data, width, height, boxSize);
		}

		@Override
		public void filterInternal(float[] data, int width, int height, float boxSize)
		{
			f.blockAverageInternal(data, width, height, boxSize);
		}
	};
	final DataFilter candidate = new DataFilter("stripedBlock", true)
	{
		@Override
		public void filter(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverage(data, width, height, boxSize);
		}

		@Override
		public void filterInternal(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverageInternal(data, width, height, boxSize);
		}
	};

	speedTest(candidate, baseline);
	speedTestInternal(candidate, baseline);
}
/**
 * Speed comparison of the rolling block average against the block average, using integer
 * box sizes, for both the standard and internal entry points.
 */
@Test
public void rollingBlockIsFasterThanBlock()
{
	final DataFilter baseline = new DataFilter("block", false)
	{
		@Override
		public void filter(float[] data, int width, int height, float boxSize)
		{
			f.blockAverage(data, width, height, (int) boxSize);
		}

		@Override
		public void filterInternal(float[] data, int width, int height, float boxSize)
		{
			f.blockAverageInternal(data, width, height, (int) boxSize);
		}
	};
	final DataFilter candidate = new DataFilter("rollingBlock", false)
	{
		@Override
		public void filter(float[] data, int width, int height, float boxSize)
		{
			f.rollingBlockAverage(data, width, height, (int) boxSize);
		}

		@Override
		public void filterInternal(float[] data, int width, int height, float boxSize)
		{
			f.rollingBlockAverageInternal(data, width, height, (int) boxSize);
		}
	};

	speedTest(candidate, baseline);
	speedTestInternal(candidate, baseline);
}
/**
 * Speed comparison of the rolling block average against the striped block average, using
 * integer box sizes, for both the standard and internal entry points.
 */
@Test
public void rollingBlockIsFasterThanStripedBlock()
{
	final DataFilter baseline = new DataFilter("stripedBlock", false)
	{
		@Override
		public void filter(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverage(data, width, height, (int) boxSize);
		}

		@Override
		public void filterInternal(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverageInternal(data, width, height, (int) boxSize);
		}
	};
	final DataFilter candidate = new DataFilter("rollingBlock", false)
	{
		@Override
		public void filter(float[] data, int width, int height, float boxSize)
		{
			f.rollingBlockAverage(data, width, height, (int) boxSize);
		}

		@Override
		public void filterInternal(float[] data, int width, int height, float boxSize)
		{
			f.rollingBlockAverageInternal(data, width, height, (int) boxSize);
		}
	};

	speedTest(candidate, baseline);
	speedTestInternal(candidate, baseline);
}
/**
 * Speed comparison of the dedicated 3x3 striped block average (called without a box size)
 * against the general NxN version at box size 1, for both the standard and internal entry
 * points.
 */
@Test
public void stripedBlock3x3IsFasterThanStripedBlockNxN()
{
	final DataFilter baseline = new DataFilter("stripedBlockNxN", false)
	{
		@Override
		public void filter(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverageNxN(data, width, height, (int) boxSize);
		}

		@Override
		public void filterInternal(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverageNxNInternal(data, width, height, (int) boxSize);
		}
	};
	final DataFilter candidate = new DataFilter("stripedBlock3x3", false)
	{
		@Override
		public void filter(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverage3x3(data, width, height);
		}

		@Override
		public void filterInternal(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverage3x3Internal(data, width, height);
		}
	};

	final int[] sizes = { 1 };
	speedTest(candidate, baseline, sizes);
	speedTestInternal(candidate, baseline, sizes);
}
/**
 * Speed comparison of the dedicated 3x3 striped block average against the general NxN
 * version, passing the float box size to both, starting from box size 1, for both the
 * standard and internal entry points.
 */
@Test
public void interpolatedStripedBlock3x3IsFasterThanStripedBlockNxN()
{
	final DataFilter baseline = new DataFilter("stripedBlockNxN", true)
	{
		@Override
		public void filter(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverageNxN(data, width, height, boxSize);
		}

		@Override
		public void filterInternal(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverageNxNInternal(data, width, height, boxSize);
		}
	};
	final DataFilter candidate = new DataFilter("stripedBlock3x3", true)
	{
		@Override
		public void filter(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverage3x3(data, width, height, boxSize);
		}

		@Override
		public void filterInternal(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverage3x3Internal(data, width, height, boxSize);
		}
	};

	final int[] sizes = { 1 };
	speedTest(candidate, baseline, sizes);
	speedTestInternal(candidate, baseline, sizes);
}
/**
 * Speed comparison of the dedicated 5x5 striped block average (called without a box size)
 * against the general NxN version at box size 2, for both the standard and internal entry
 * points.
 */
@Test
public void stripedBlock5x5IsFasterThanStripedBlockNxN()
{
	final DataFilter baseline = new DataFilter("stripedBlockNxN", false)
	{
		@Override
		public void filter(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverageNxN(data, width, height, (int) boxSize);
		}

		@Override
		public void filterInternal(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverageNxNInternal(data, width, height, (int) boxSize);
		}
	};
	final DataFilter candidate = new DataFilter("stripedBlock5x5", false)
	{
		@Override
		public void filter(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverage5x5(data, width, height);
		}

		@Override
		public void filterInternal(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverage5x5Internal(data, width, height);
		}
	};

	final int[] sizes = { 2 };
	speedTest(candidate, baseline, sizes);
	speedTestInternal(candidate, baseline, sizes);
}
/**
 * Speed comparison of the dedicated 5x5 striped block average against the general NxN
 * version, passing the float box size to both, starting from box size 2, for both the
 * standard and internal entry points.
 */
@Test
public void interpolatedStripedBlock5x5IsFasterThanStripedBlockNxN()
{
	final DataFilter baseline = new DataFilter("stripedBlockNxN", true)
	{
		@Override
		public void filter(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverageNxN(data, width, height, boxSize);
		}

		@Override
		public void filterInternal(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverageNxNInternal(data, width, height, boxSize);
		}
	};
	final DataFilter candidate = new DataFilter("stripedBlock5x5", true)
	{
		@Override
		public void filter(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverage5x5(data, width, height, boxSize);
		}

		@Override
		public void filterInternal(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverage5x5Internal(data, width, height, boxSize);
		}
	};

	final int[] sizes = { 2 };
	speedTest(candidate, baseline, sizes);
	speedTestInternal(candidate, baseline, sizes);
}
/**
 * Speed comparison of the dedicated 7x7 striped block average (called without a box size)
 * against the general NxN version at box size 3, for both the standard and internal entry
 * points.
 */
@Test
public void stripedBlock7x7IsFasterThanStripedBlockNxN()
{
	final DataFilter baseline = new DataFilter("stripedBlockNxN", false)
	{
		@Override
		public void filter(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverageNxN(data, width, height, (int) boxSize);
		}

		@Override
		public void filterInternal(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverageNxNInternal(data, width, height, (int) boxSize);
		}
	};
	final DataFilter candidate = new DataFilter("stripedBlock7x7", false)
	{
		@Override
		public void filter(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverage7x7(data, width, height);
		}

		@Override
		public void filterInternal(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverage7x7Internal(data, width, height);
		}
	};

	final int[] sizes = { 3 };
	speedTest(candidate, baseline, sizes);
	speedTestInternal(candidate, baseline, sizes);
}
/**
 * Speed comparison of the dedicated 7x7 striped block average against the general NxN
 * version, passing the float box size to both, starting from box size 3, for both the
 * standard and internal entry points.
 */
@Test
public void interpolatedStripedBlock7x7IsFasterThanStripedBlockNxN()
{
	final DataFilter baseline = new DataFilter("stripedBlockNxN", true)
	{
		@Override
		public void filter(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverageNxN(data, width, height, boxSize);
		}

		@Override
		public void filterInternal(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverageNxNInternal(data, width, height, boxSize);
		}
	};
	final DataFilter candidate = new DataFilter("stripedBlock7x7", true)
	{
		@Override
		public void filter(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverage7x7(data, width, height, boxSize);
		}

		@Override
		public void filterInternal(float[] data, int width, int height, float boxSize)
		{
			f.stripedBlockAverage7x7Internal(data, width, height, boxSize);
		}
	};

	final int[] sizes = { 3 };
	speedTest(candidate, baseline, sizes);
	speedTestInternal(candidate, baseline, sizes);
}
}