package org.signalml.app;
import static org.signalml.SignalMLAssert.assertArrayEquals;
import java.util.ArrayList;
import java.util.List;
import org.junit.Test;
import org.signalml.app.method.ep.EvokedPotentialApplicationData;
import org.signalml.app.method.ep.view.tags.TagStyleGroup;
import org.signalml.domain.signal.space.AbstractTagSegmentedTest;
import org.signalml.domain.tag.StyledTagSet;
import org.signalml.method.ep.EvokedPotentialMethod;
import org.signalml.method.ep.EvokedPotentialParameters;
import org.signalml.method.ep.EvokedPotentialResult;
import org.signalml.plugin.export.signal.Tag;
/**
 * Checks if the {@link EvokedPotentialMethod} works as expected.
 *
 * <p>Every test puts some tags into the tag document, runs the method and compares the first
 * set of averaged samples it returns with an average recomputed in this class straight from
 * the raw samples. The raw samples, the sampling frequency, the tag styles and the documents
 * are not declared here; they are presumably provided by {@link AbstractTagSegmentedTest}.
 *
 * @author Piotr Szachewicz
 */
public class EvokedPotentialMethodTest extends AbstractTagSegmentedTest {

    /** Input of the tested method; configured in the constructor and adjusted by each test. */
    EvokedPotentialApplicationData data = new EvokedPotentialApplicationData();

    /**
     * Prepares the data shared by all tests: averaging window of 2 s starting 1 s before a tag,
     * baseline window of 1 s starting 2 s before a tag, filtering disabled and a single averaged
     * tag style group built from the name of {@code tagStyles[0]}.
     *
     * @throws Exception if the signal or tag document cannot be obtained
     */
    public EvokedPotentialMethodTest() throws Exception {
        data.setSignalDocument(getSignalDocument());
        data.setTagDocument(getTagDocument());

        EvokedPotentialParameters parameters = data.getParameters();
        parameters.setAveragingStartTime(-1.0F);
        parameters.setAveragingTimeLength(2.0F);
        parameters.setBaselineTimeStart(-2.0F);
        parameters.setBaselineTimeLength(1.0F);
        parameters.setFilteringEnabled(false);

        List<TagStyleGroup> group = new ArrayList<TagStyleGroup>();
        group.add(new TagStyleGroup(tagStyles[0].getName()));
        parameters.setAveragedTagStyles(group);
    }

    /** Runs the method without adding any tag, baseline correction disabled. */
    @Test
    public void testZeroTags() throws Exception {
        data.getParameters().setBaselineCorrectionEnabled(false);
        performTest();
    }

    /** Runs the method for a single averaged tag at 1.0, baseline correction disabled. */
    @Test
    public void testOneTag() throws Exception {
        addAveragedTagsAt(1.0);
        data.getParameters().setBaselineCorrectionEnabled(false);
        performTest();
    }

    /** Runs the method for averaged tags at 2.0, 4.0 and 5.0, baseline correction disabled. */
    @Test
    public void testThreeTags() throws Exception {
        addAveragedTagsAt(2.0, 4.0, 5.0);
        data.getParameters().setBaselineCorrectionEnabled(false);
        performTest();
    }

    /** Runs the method for averaged tags at 3.0, 5.0 and 8.0 with baseline correction enabled. */
    @Test
    public void testThreeTagsWithBaseline() throws Exception {
        addAveragedTagsAt(3.0, 5.0, 8.0);
        data.getParameters().setBaselineCorrectionEnabled(true);
        performTest();
    }

    /** Runs the method for a single tag at a negative position (before the signal starts). */
    @Test
    public void testTagOutsideTheSignal() throws Exception {
        addAveragedTagsAt(-1.0);
        data.getParameters().setBaselineCorrectionEnabled(true);
        performTest();
    }

    /**
     * Runs the method for one tag at 10000.0 (presumably beyond the end of the fixture signal)
     * and one tag at 2.0.
     */
    @Test
    public void testTagOutsideTheSignalAndTheOtherInside() throws Exception {
        addAveragedTagsAt(10000.0, 2.0);
        data.getParameters().setBaselineCorrectionEnabled(true);
        performTest();
    }

    /**
     * Runs the method with baseline correction for a tag at 1.0. With the baseline starting 2 s
     * before the tag, the baseline window begins before the signal, so {@link #getSamples}
     * yields no baseline samples for it.
     */
    @Test
    public void testOneTagBaselineOutside() throws Exception {
        addAveragedTagsAt(1.0);
        data.getParameters().setBaselineCorrectionEnabled(true);
        performTest();
    }

    /** Runs the method with baseline correction for a single tag at 1000.0. */
    @Test
    public void testOneTagBaselineOutside2() throws Exception {
        addAveragedTagsAt(1000.0);
        data.getParameters().setBaselineCorrectionEnabled(true);
        performTest();
    }

    /**
     * Runs the method with a second averaged tag style group and a mix of tags of the averaged
     * styles, of another style and of the artifact style, with baseline correction enabled.
     */
    @Test
    public void testALotOfTags() throws Exception {
        data.getParameters().getAveragedTagStyles().add(new TagStyleGroup(AVERAGED_TAG_NAME_2));

        StyledTagSet tagSet = data.getTagDocument().getTagSet();
        tagSet.addTag(new Tag(averagedTagStyle, 10000.0, 0.0));
        tagSet.addTag(new Tag(averagedTagStyle, 2.0, 0.0));
        tagSet.addTag(new Tag(averagedTagStyle2, 30.0, 0.0));
        tagSet.addTag(new Tag(otherTagStyle, 4.0, 0.0));
        tagSet.addTag(new Tag(otherTagStyle, 7.0, 0.0));
        tagSet.addTag(new Tag(otherTagStyle, -1.0, 0.0));
        tagSet.addTag(new Tag(artifactTagStyle, 11.0, 0.0));

        data.getParameters().setBaselineCorrectionEnabled(true);
        performTest();
    }

    /**
     * Runs the {@link EvokedPotentialMethod} on {@link #data} and asserts, channel by channel,
     * that the first entry of the returned averaged samples equals the average computed by this
     * class (within a tolerance of 1e-3).
     *
     * <p>Only tags of {@code averagedTagStyle} or {@code averagedTagStyle2} whose position is
     * not negative and not past the sample count of channel 0 are taken into account.
     *
     * @throws Exception if the computation fails
     */
    public void performTest() throws Exception {
        List<Double> tagPositions = new ArrayList<Double>();
        for (Tag tag : data.getTagDocument().getTagSet().getTags()) {
            // Styles are compared by identity, the same instances are used to create the tags.
            boolean averagedStyle = tag.getStyle() == averagedTagStyle || tag.getStyle() == averagedTagStyle2;
            if (averagedStyle
                    && tag.getPosition() >= 0.0
                    && tag.getPosition() * samplingFrequency <= data.getSignalDocument().getSampleSource().getSampleCount(0)) {
                tagPositions.add(tag.getPosition());
            }
        }

        data.calculate();
        EvokedPotentialMethod method = new EvokedPotentialMethod();
        EvokedPotentialResult result = (EvokedPotentialResult) method.doComputation(data, new DummyMethodExecutionTracker());
        double[][] averagedSamples = result.getAverageSamples().get(0);

        // A marker at 1.0 is used only to learn how many samples one averaged segment has.
        int avgLength = getAveragedSamples(0, 1.0).length;

        for (int channel = 0; channel < CHANNEL_COUNT; channel++) {
            double[] expectedAveragedSamples = computeExpectedAverage(channel, tagPositions, avgLength);
            if (data.getParameters().isBaselineCorrectionEnabled()) {
                performBaselineCorrection(channel, tagPositions, expectedAveragedSamples);
            }
            assertArrayEquals(averagedSamples[channel], expectedAveragedSamples, 1e-3);
        }
    }

    /**
     * Subtracts the mean of all baseline samples collected for the given tags from every element
     * of {@code samples}. When no baseline sample is available the array is left unchanged.
     *
     * @param channel the channel the baseline is taken from
     * @param tagPositions positions of the tags for which the baseline windows are read
     * @param samples the averaged samples, corrected in place
     */
    protected void performBaselineCorrection(int channel, List<Double> tagPositions, double[] samples) {
        double baseline = 0.0;
        int number = 0;
        for (double tagPosition : tagPositions) {
            for (double sample : getBaselineSamples(channel, tagPosition)) {
                baseline += sample;
                number++;
            }
        }
        if (number > 0) {
            baseline /= number;
        }

        for (int i = 0; i < samples.length; i++) {
            samples[i] -= baseline;
        }
    }

    /**
     * Copies a chunk of the raw samples of one channel.
     *
     * @param channel the channel to read
     * @param markerPosition position of the marker (multiplied by the sampling frequency to get
     *        a sample index)
     * @param startTime start of the chunk relative to the marker
     * @param lengthInSeconds length of the chunk
     * @return the requested samples, or an empty array when the chunk would start before the
     *         first sample or end after the last sample of the channel
     */
    public double[] getSamples(int channel, double markerPosition, double startTime, double lengthInSeconds) {
        // The casts truncate toward zero, so a start lying less than one sample before the
        // beginning of the signal is treated as sample 0.
        int startSample = (int) ((markerPosition + startTime) * samplingFrequency);
        int numberOfSamples = (int) (lengthInSeconds * samplingFrequency);

        // Validate against the channel that is actually read below.
        if (startSample < 0 || startSample + numberOfSamples > samples[channel].length) {
            numberOfSamples = 0;
        }

        double[] sampleChunk = new double[numberOfSamples];
        for (int i = 0; i < sampleChunk.length; i++) {
            sampleChunk[i] = samples[channel][startSample + i];
        }
        return sampleChunk;
    }

    /**
     * Returns the samples of the averaging window (as configured in the parameters of
     * {@link #data}) around the given marker.
     *
     * @param channel the channel to read
     * @param markerPosition position of the marker
     * @return the samples of the window, empty if the window does not fit in the signal
     */
    public double[] getAveragedSamples(int channel, double markerPosition) {
        EvokedPotentialParameters parameters = data.getParameters();
        return getSamples(channel, markerPosition, parameters.getAveragingStartTime(), parameters.getAveragingTimeLength());
    }

    /**
     * Returns the samples of the baseline window (as configured in the parameters of
     * {@link #data}) around the given marker.
     *
     * @param channel the channel to read
     * @param startPosition position of the marker
     * @return the samples of the window, empty if the window does not fit in the signal
     */
    public double[] getBaselineSamples(int channel, double startPosition) {
        EvokedPotentialParameters parameters = data.getParameters();
        return getSamples(channel, startPosition, parameters.getBaselineTimeStart(), parameters.getBaselineTimeLength());
    }

    /**
     * Adds one tag of {@code averagedTagStyle} with zero length at each of the given positions,
     * in the given order.
     */
    private void addAveragedTagsAt(double... positions) throws Exception {
        StyledTagSet tagSet = data.getTagDocument().getTagSet();
        for (double position : positions) {
            tagSet.addTag(new Tag(averagedTagStyle, position, 0.0));
        }
    }

    /**
     * Averages, sample by sample, the averaging windows of all given tags for one channel.
     *
     * <p>A window that does not have {@code segmentLength} samples (i.e. the empty array returned
     * by {@link #getSamples} when the window leaves the signal) is skipped instead of causing an
     * {@link ArrayIndexOutOfBoundsException}.
     * NOTE(review): this assumes the tested method ignores such tags as well - confirm.
     *
     * @return an array of {@code segmentLength} averages; all zeros when no window was usable
     */
    private double[] computeExpectedAverage(int channel, List<Double> tagPositions, int segmentLength) {
        double[] average = new double[segmentLength];
        int usedSegments = 0;

        for (double tagPosition : tagPositions) {
            double[] segment = getAveragedSamples(channel, tagPosition);
            if (segment.length != segmentLength) {
                continue;
            }
            for (int i = 0; i < segmentLength; i++) {
                average[i] += segment[i];
            }
            usedSegments++;
        }

        if (usedSegments > 0) {
            for (int i = 0; i < segmentLength; i++) {
                average[i] /= usedSegments;
            }
        }
        return average;
    }
}