package org.spin.gaitlib.gait;
import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.random.EmpiricalDistribution;
import org.apache.commons.math3.stat.correlation.PearsonsCorrelation;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
import org.spin.gaitlib.sensor.ThreeAxisSensorReading;
import org.spin.gaitlib.util.SpectralAnalyses;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import java.util.List;
/**
 * An implementation of {@link GaitClassifier}, using the algorithm by Oliver Schneider.
 * <p>
 * Accelerometer readings are turned into a fixed-length feature vector by
 * {@link #extractFeatures(float[][], float[])} and handed to the loaded Weka classifier.
 *
 * @author Oliver
 * @author Mike
 */
public class DefaultGaitClassifier extends GaitClassifier {

    private static final String DEFAULT_MODEL_LOCATION = "Android/data/org.spin.gaitlib/files/classification_models/personalized_model.model";
    private static final String DEFAULT_MODEL_XML_LOCATION = "Android/data/org.spin.gaitlib/files/classification_models/personalized.xml";

    /** Number of attributes the model's attribute list is required to contain. */
    private static final int NUMBER_OF_ATTRIBUTES = 112;

    /** Number of bins used for both the signal histogram and the spectral-power histogram. */
    private static final int HISTOGRAM_BIN_COUNT = 10;

    /**
     * Arguments forwarded to {@link SpectralAnalyses#fasperArray}; presumably the
     * high-frequency factor and oversampling factor of the periodogram - TODO confirm.
     */
    private static final float FASPER_HIFAC = 0.25f;
    private static final float FASPER_OFAC = 4f;

    /** Names of the accelerometer axes; only the count is used, to drive the per-axis loop. */
    private static final String[] AXIS_NAMES = {
            "x_accel", "y_accel", "z_accel"
    };

    /**
     * A default gait classifier with the default model file - Random Forest with 100 trees.
     */
    public DefaultGaitClassifier() {
        this(DEFAULT_MODEL_LOCATION, DEFAULT_MODEL_XML_LOCATION);
    }

    /**
     * A default gait classifier with a custom defined classification model.
     *
     * @param modelFileLocation the path to the model file. E.g., for root/gaitlib/classifier.model,
     *            use <code>gaitlib/classifier.model</code>
     * @param modelXMLFileLocation the path to the XML file that accompanies the model file
     */
    public DefaultGaitClassifier(String modelFileLocation, String modelXMLFileLocation) {
        super(modelFileLocation, modelXMLFileLocation);
    }

    /**
     * Classifies the accelerometer readings currently held by the signal listener and stores the
     * predicted class label via {@code setCurrentGait}.
     *
     * @exception ModelNotLoadedException if the model for the classifier is not loaded.
     * @exception Exception if the attribute list does not have the expected size, or if an error
     *                occurred during the prediction.
     */
    @Override
    public void classifyGait() throws Exception {
        if (getClassifier() == null) {
            throw new ModelNotLoadedException();
        }
        // NOTE(review): extractFeatures returns NUMBER_OF_ATTRIBUTES + 1 values (features plus a
        // class slot) - confirm whether getAttributes() is expected to include the class attribute.
        if (getAttributes().size() != NUMBER_OF_ATTRIBUTES) {
            throw new Exception(
                    "Error: list of attributes does not match between model file and xml file.");
        }

        // Unpack the readings into one array per axis plus a timestamp array.
        ThreeAxisSensorReading[] accelArr = getSignalListener().getAccelReadingsArray();
        float[] signalX_float = new float[accelArr.length];
        float[] signalY_float = new float[accelArr.length];
        float[] signalZ_float = new float[accelArr.length];
        float[] signalTime_float = new float[accelArr.length];
        for (int i = 0; i < accelArr.length; i++) {
            signalX_float[i] = accelArr[i].getX();
            signalY_float[i] = accelArr[i].getY();
            signalZ_float[i] = accelArr[i].getZ();
            signalTime_float[i] = accelArr[i].getTimeSinceStartInS();
        }
        float[][] signal_float = {
                signalX_float, signalY_float, signalZ_float
        };

        double[] features = extractFeatures(signal_float, signalTime_float);

        // Wrap the feature vector in a single-instance data set whose last attribute is the class.
        Instances instances = new Instances("GaitDataSet", getAttributes(), 1);
        instances.setClassIndex(instances.numAttributes() - 1);
        Instance thisInstance = new DenseInstance(1.0, features);
        instances.add(thisInstance);
        thisInstance.setDataset(instances);

        int index = (int) getClassifier().classifyInstance(thisInstance);
        setCurrentGait(instances.classAttribute().value(index));
    }

    /**
     * Builds the feature vector for one window of accelerometer data.
     * <p>
     * For each of the three axes, in order: min, max, mean, variance, skewness, kurtosis, the
     * 25th/50th/75th percentiles, the fraction of samples in each histogram bin, three
     * frequencies from the periodogram (strongest, second strongest and weakest), the weighted
     * average frequency, the variance of the periodogram frequencies, the spectral entropy, and
     * the fraction of periodogram powers in each histogram bin. These are followed by the three
     * pairwise Pearson correlations (xy, xz, yz), their p-values, the signal magnitude area, and a
     * trailing 0 placeholder for the class attribute.
     *
     * @param signal_float a 2D array containing the accelerometer readings in x, y and z
     *            coordinate.
     * @param signalTime_float an array containing the timestamp, with unit of seconds, for each
     *            accelerometer reading
     * @return an array of features
     */
    public static double[] extractFeatures(float[][] signal_float, float[] signalTime_float) {
        int signalLength = signal_float[0].length;
        double[] signalX_double = new double[signalLength];
        double[] signalY_double = new double[signalLength];
        double[] signalZ_double = new double[signalLength];
        double[][] signal_double = {
                signalX_double, signalY_double, signalZ_double
        };
        // Java zero-initialises arrays, so this is already the all-zero array fasperArray is given.
        float[] zeroes_float = new float[signalLength];

        double signal_magnitude_area = 0;
        for (int i = 0; i < signalLength; i++) {
            signalX_double[i] = signal_float[0][i];
            signalY_double[i] = signal_float[1][i];
            signalZ_double[i] = signal_float[2][i];
            signal_magnitude_area += Math.abs(signalX_double[i]) + Math.abs(signalY_double[i])
                    + Math.abs(signalZ_double[i]);
        }

        double[] features = new double[NUMBER_OF_ATTRIBUTES + 1];
        int feature_i = 0;

        for (int i = 0; i < AXIS_NAMES.length; i++) {
            double[] curSignal = signal_double[i];
            float[] curSignal_float = signal_float[i];

            // time-domain summary statistics
            DescriptiveStatistics curSignalStats = new DescriptiveStatistics(curSignal);
            features[feature_i++] = curSignalStats.getMin();
            features[feature_i++] = curSignalStats.getMax();
            features[feature_i++] = curSignalStats.getMean();
            features[feature_i++] = curSignalStats.getVariance();
            features[feature_i++] = curSignalStats.getSkewness();
            features[feature_i++] = curSignalStats.getKurtosis();
            features[feature_i++] = curSignalStats.getPercentile(25);
            features[feature_i++] = curSignalStats.getPercentile(50);
            features[feature_i++] = curSignalStats.getPercentile(75);

            // histogram
            EmpiricalDistribution curSignalDistribution = new EmpiricalDistribution(
                    HISTOGRAM_BIN_COUNT);
            curSignalDistribution.load(curSignal);
            feature_i = appendBinFractions(curSignalDistribution, curSignalStats.getN(),
                    features, feature_i);

            float[][] fasper_results = SpectralAnalyses.fasperArray(FASPER_HIFAC, FASPER_OFAC,
                    curSignal_float, zeroes_float, zeroes_float, signalTime_float, signalLength);

            // strongest, second strongest and weakest frequencies
            float[] minMaxFreq = SpectralAnalyses.fasperResultsMaxMinFreq(fasper_results);
            features[feature_i++] = minMaxFreq[0];
            features[feature_i++] = minMaxFreq[1];
            features[feature_i++] = minMaxFreq[2];

            // weighted average frequency
            features[feature_i++] = SpectralAnalyses
                    .fasperResultsWeightedAverageFreq(fasper_results);

            // frequency variance
            double[] fasperResultsFrequencies = SpectralAnalyses
                    .fasperResultsFrequenciesAsDoubles(fasper_results);
            features[feature_i++] = (new DescriptiveStatistics(fasperResultsFrequencies))
                    .getVariance();

            double[] fasperResultsPowers = SpectralAnalyses
                    .fasperResultsPowersAsDoubles(fasper_results);
            DescriptiveStatistics fasperPowersStats = new DescriptiveStatistics(
                    fasperResultsPowers);

            // spectral entropy
            features[feature_i++] = spectralEntropy(fasperResultsPowers,
                    fasperPowersStats.getSum());

            // spectral histogram
            // The bins must come from the distribution loaded with the periodogram powers;
            // previously the time-domain signal's bins were read here by mistake.
            // NOTE(review): models trained on feature vectors produced before this fix may need
            // to be re-trained - verify against the training pipeline.
            EmpiricalDistribution fasperPowersDistribution = new EmpiricalDistribution(
                    HISTOGRAM_BIN_COUNT);
            fasperPowersDistribution.load(fasperResultsPowers);
            feature_i = appendBinFractions(fasperPowersDistribution, fasperPowersStats.getN(),
                    features, feature_i);
        }

        // Pearson correlation values
        RealMatrix data_for_correlation = new BlockRealMatrix(signal_double).transpose();
        PearsonsCorrelation pearsonCorr = new PearsonsCorrelation(data_for_correlation);
        // this may be inefficient, the constructor above may already have calculated these
        features[feature_i++] = pearsonCorr.correlation(signalX_double, signalY_double);
        features[feature_i++] = pearsonCorr.correlation(signalX_double, signalZ_double);
        features[feature_i++] = pearsonCorr.correlation(signalY_double, signalZ_double);

        // pearson P values
        RealMatrix pvalues = pearsonCorr.getCorrelationPValues();
        features[feature_i++] = pvalues.getEntry(0, 1);
        features[feature_i++] = pvalues.getEntry(0, 2);
        features[feature_i++] = pvalues.getEntry(1, 2);

        // signal_magnitude_area
        features[feature_i++] = signal_magnitude_area;

        features[feature_i++] = 0; // class attribute
        return features;
    }

    /**
     * Writes, for every bin of {@code distribution}, the fraction of samples that fell into that
     * bin into {@code features}, starting at {@code offset}.
     *
     * @return the index just past the last value written
     */
    private static int appendBinFractions(EmpiricalDistribution distribution, double sampleCount,
            double[] features, int offset) {
        List<SummaryStatistics> binStats = distribution.getBinStats();
        for (SummaryStatistics binStat : binStats) {
            features[offset++] = binStat.getN() / sampleCount;
        }
        return offset;
    }

    /**
     * Computes the entropy, in bits, of {@code powers} normalised by {@code powerSum}.
     * Terms whose probability is exactly zero are skipped (0 * log 0 is taken as 0) so that a
     * single empty spectral bin does not turn the whole feature into NaN.
     */
    private static double spectralEntropy(double[] powers, double powerSum) {
        double log2 = Math.log(2);
        double entropy = 0;
        for (double power : powers) {
            double probability = power / powerSum;
            if (probability == 0) {
                continue;
            }
            entropy = entropy + probability * (Math.log(probability) / log2);
        }
        return -entropy;
    }
}