/* EvokedPotentialMethod.java created 2008-01-12 * */ package org.signalml.method.ep; import static org.signalml.app.util.i18n.SvarogI18n._; import org.apache.log4j.Logger; import org.signalml.domain.montage.filter.TimeDomainSampleFilter; import org.signalml.domain.montage.system.IChannelFunction; import org.signalml.domain.signal.filter.iir.OfflineIIRSinglechannelSampleFilter; import org.signalml.domain.signal.samplesource.ChannelSelectorSampleSource; import org.signalml.domain.signal.samplesource.DoubleArraySampleSource; import org.signalml.domain.signal.samplesource.MultichannelSegmentedSampleSource; import org.signalml.domain.signal.space.MarkerSegmentedSampleSource; import org.signalml.math.iirdesigner.BadFilterParametersException; import org.signalml.math.iirdesigner.FilterCoefficients; import org.signalml.math.iirdesigner.IIRDesigner; import org.signalml.method.AbstractMethod; import org.signalml.method.ComputationException; import org.signalml.method.MethodExecutionTracker; import org.signalml.method.TrackableMethod; import org.signalml.plugin.export.SignalMLException; import org.signalml.plugin.export.method.BaseMethodData; import org.springframework.validation.Errors; /** EvokedPotentialMethod * * * @author Michal Dobaczewski © 2007-2008 CC Otwarte Systemy Komputerowe Sp. z o.o. */ public class EvokedPotentialMethod extends AbstractMethod implements TrackableMethod { protected static final Logger logger = Logger.getLogger(EvokedPotentialMethod.class); private static final String UID = "561691f8-bd14-486f-989b-a09c0bd57455"; private static final String NAME = "evokedPotential"; private static final int[] VERSION = new int[] {1,0}; public EvokedPotentialMethod() throws SignalMLException { super(); } @Override public Object doComputation(Object dataObj, final MethodExecutionTracker tracker) throws ComputationException { logger.debug("Beginning computation of EP"); EvokedPotentialData data = (EvokedPotentialData) dataObj; tracker.setMessage(_("Preparing")); MarkerSegmentedSampleSource sampleSource = data.getSampleSources().get(0); int sampleCount = sampleSource.getSegmentLengthInSamples(); int segmentCount = sampleSource.getSegmentCount(); int channelCount = sampleSource.getChannelCount(); float samplingFrequency = sampleSource.getSamplingFrequency(); String[] labels = new String[channelCount]; for (int segment=0; segment<channelCount; segment++) { labels[segment] = sampleSource.getLabel(segment); IChannelFunction channelFunction = sampleSource.getChannelFunction(segment); if (channelFunction != null) { String unit = channelFunction.getUnitOfMeasurementSymbol(); labels[segment] += " [" + unit + "]"; } } EvokedPotentialResult result = new EvokedPotentialResult(data); EvokedPotentialParameters parameters = data.getParameters(); result.setStartTime(parameters.getAveragingStartTime()); result.setSegmentLength(parameters.getAveragingTimeLength()); tracker.setMessage(_("Summing")); tracker.setTickerLimit(0, segmentCount); for (MultichannelSegmentedSampleSource segmentedSampleSource: data.getSampleSources()) { double[][] averageSamples = average(segmentedSampleSource, tracker); if (averageSamples == null) return null; result.addAverageSamples(averageSamples); } if (data.getParameters().isBaselineCorrectionEnabled()) performBaselineCorrection(result, data); if (data.getParameters().isFilteringEnabled()) try { performLowPassFiltering(result, data); } catch (BadFilterParametersException exception) { logger.error("", exception); throw new ComputationException(_("An error occured while designing the signal filter.")); } result.setLabels(labels); result.setSampleCount(sampleCount); result.setChannelCount(channelCount); result.setSamplingFrequency(samplingFrequency); for (MarkerSegmentedSampleSource segmentedSampleSource: data.getSampleSources()) { result.getAveragedSegmentsCount().add(segmentedSampleSource.getSegmentCount()); result.getUnusableSegmentsCount().add(segmentedSampleSource.getUnusableSegmentCount()); result.getArtifactRejectedSegmentsCount().add(segmentedSampleSource.getArtifactRejectedSegmentsCount()); } tracker.setMessage(_("Finished")); return result; } protected void performBaselineCorrection(EvokedPotentialResult result, EvokedPotentialData data) { int sampleSourceNumber = 0; for (MultichannelSegmentedSampleSource segmentedSampleSource: data.getBaselineSampleSources()) { double[] baselineSamples = new double[segmentedSampleSource.getSegmentLengthInSamples()]; for (int channel = 0; channel < segmentedSampleSource.getChannelCount(); channel++) { double sum = 0.0; for (int segment = 0; segment < segmentedSampleSource.getSegmentCount(); segment++) { segmentedSampleSource.getSegmentSamples(channel, baselineSamples, segment); for (double sample: baselineSamples) { sum += sample; } } if (segmentedSampleSource.getSegmentCount() > 0) { double baseline = sum / (segmentedSampleSource.getSegmentCount() * segmentedSampleSource.getSegmentLengthInSamples()); double[] samples = result.getAverageSamples().get(sampleSourceNumber)[channel]; for (int i = 0; i < samples.length; i++) samples[i] = samples[i] - baseline; } } sampleSourceNumber++; } } protected void performLowPassFiltering(EvokedPotentialResult result, EvokedPotentialData data) throws BadFilterParametersException { TimeDomainSampleFilter filter = data.getParameters().getTimeDomainSampleFilter(); FilterCoefficients filterCoefficients = IIRDesigner.designDigitalFilter(filter); for (double[][] samples: result.getAverageSamples()) { DoubleArraySampleSource multichannelSampleSource = new DoubleArraySampleSource(samples); for (int channel = 0; channel < samples.length; channel++) { ChannelSelectorSampleSource channelSampleSource = new ChannelSelectorSampleSource(multichannelSampleSource, channel); OfflineIIRSinglechannelSampleFilter filterEngine = new OfflineIIRSinglechannelSampleFilter(channelSampleSource, filterCoefficients); filterEngine.setFiltfiltEnabled(true); filterEngine.getSamples(samples[channel], 0, samples[channel].length, 0); } } } protected double[][] average(MultichannelSegmentedSampleSource sampleSource, MethodExecutionTracker tracker) { int sampleCount = sampleSource.getSegmentLengthInSamples(); int segmentCount = sampleSource.getSegmentCount(); int channelCount = sampleSource.getChannelCount(); double[] samples = new double[sampleCount]; double[][] averageSamples = new double[channelCount][sampleCount]; for (int segment=0; segment<segmentCount; segment++) { for (int channel=0; channel<channelCount; channel++) { if (tracker.isRequestingAbort()) { return null; } sampleSource.getSegmentSamples(channel, samples, segment); for (int j=0; j<sampleCount; j++) { averageSamples[channel][j] += samples[j]; } } if (segment % 10 == 0) { tracker.tick(0, 10); } } tracker.setMessage(_("Averaging")); tracker.setTicker(0, 0); tracker.setTickerLimit(0, channelCount); // markers have been summed, now divide to get the average if (segmentCount >0 ) { for (int channel=0; channel<channelCount; channel++) { for (int j=0; j<sampleCount; j++) { averageSamples[channel][j] /= segmentCount; } tracker.tick(0); } } return averageSamples; } @Override public void validate(Object dataObj, Errors errors) { super.validate(dataObj, errors); if (!errors.hasErrors()) { EvokedPotentialData data = (EvokedPotentialData) dataObj; data.validate(errors); } } @Override public int getTickerCount() { return 1; } @Override public String getTickerLabel(int ticker) { if (0 == ticker) return _("Processing markers"); else throw new IndexOutOfBoundsException(); } @Override public String getUID() { return UID; } @Override public String getName() { return NAME; } @Override public int[] getVersion() { return VERSION; } @Override public BaseMethodData createData() { return new EvokedPotentialData(); } @Override public Class<?> getResultClass() { return EvokedPotentialResult.class; } @Override public boolean supportsDataClass(Class<?> clazz) { return EvokedPotentialData.class.isAssignableFrom(clazz); } }