/*-
* Copyright 2016 Diamond Light Source Ltd.
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* which accompanies this distribution, and is available at
* http://www.eclipse.org/legal/epl-v10.html
*/
package uk.ac.diamond.scisoft.analysis.dataset.function;
import java.util.Arrays;
import java.util.stream.Stream;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.OpenMapRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.eclipse.january.dataset.Dataset;
import org.eclipse.january.dataset.DatasetFactory;
import org.eclipse.january.dataset.DatasetUtils;
import org.eclipse.january.dataset.DoubleDataset;
import org.eclipse.january.dataset.IDataset;
import org.eclipse.january.dataset.LinearAlgebra;
import org.eclipse.january.dataset.Maths;
import org.eclipse.january.dataset.Stats;
/**
 * Implementation of the Savitzky Golay filter algorithm.
 * <p>
 * The filter is expressed as an n-by-n matrix (n being the number of points along the filtered axis)
 * that the data is multiplied with. A band of least-squares convolution weights handles the interior
 * points, while two dense blocks evaluate the fitted polynomial for the points near both ends of the data.
 *
 * @see <a href="https://en.wikipedia.org/wiki/Savitzky%E2%80%93Golay_filter">Wikipedia page on the Savitzky Golay filter</a>
 */
public class SavitzkyGolay {

    /**
     * Smooths the data without differentiating it. Equivalent to calling
     * {@link #filter(IDataset, int, int, int)} with a derivative order of zero.
     *
     * @param data Dataset that will be filtered. Must be one- or two-dimensional.
     * @param window The length of the filter window (i.e. the number of coefficients). Must be an odd integer of at least 3.
     * @param poly The order of the polynomial used to fit the samples. Must be at least 1 and less than window.
     * @return Filtered dataset, with the same shape as data
     * @throws Exception if any of the arguments is invalid
     */
    public static Dataset filter(IDataset data, int window, int poly) throws Exception {
        return filter(data, window, poly, 0);
    }

    /**
     * Applies a Savitzky Golay filter, optionally computing a derivative of the fitted polynomial.
     * One-dimensional data is filtered along its only axis; for two-dimensional data every row is
     * filtered along the second axis.
     *
     * @param data Dataset that will be filtered. Must be one- or two-dimensional.
     * @param window The length of the filter window (i.e. the number of coefficients). Must be an odd integer of at least 3
     *               and must not exceed the number of points along the filtered axis.
     * @param poly The order of the polynomial used to fit the samples. Must be at least 1 and less than window.
     * @param deriv The order of the derivative to compute. Must be a non-negative integer not greater than poly. Use zero if no derivative is required.
     * @return Filtered dataset, with the same shape as data
     * @throws Exception if the rank of data is not 1 or 2, or if window, poly or deriv violate the constraints above
     */
    public static Dataset filter(IDataset data, int window, int poly, int deriv) throws Exception {
        final int rank = data.getRank();
        final int n; // number of points along the filtered axis
        if (rank == 1) {
            n = data.getShape()[0];
        } else if (rank == 2) {
            n = data.getShape()[1];
        } else {
            throw new Exception("data must be either one- or two-dimensional");
        }

        if (window < 3 || window % 2 == 0)
            throw new Exception("window must be a positive odd integer");
        if (poly < 1 || poly >= window)
            throw new Exception("poly must be less than window");
        if (deriv < 0)
            throw new Exception("deriv must greater than or equal to zero");
        // a polynomial of order poly has no coefficient for derivatives beyond poly
        if (deriv > poly)
            throw new Exception("deriv must be less than or equal to poly");
        // the edge blocks of the filter matrix span window rows, so they need at least that many points
        if (n < window)
            throw new Exception("data must contain at least window points along the filtered axis");

        OpenMapRealMatrix filterMatrix = buildFilterMatrix(n, window, poly, deriv);

        // preMultiply(M) computes M x filterMatrix, i.e. every row of the data is filtered independently
        RealMatrix filtered = filterMatrix.preMultiply(toRealMatrix(data));
        double[] flat = Stream.of(filtered.getData()).flatMapToDouble(Arrays::stream).toArray();
        return DatasetFactory.createFromObject(flat, data.getShape());
    }

    /**
     * Builds the n-by-n matrix that applies the filter when data rows are multiplied with it.
     *
     * @param n number of points along the filtered axis (at least window)
     * @param window odd window length
     * @param poly polynomial order
     * @param deriv derivative order (0 for plain smoothing)
     * @return sparse filter matrix
     * @throws Exception if one of the dataset operations fails
     */
    private static OpenMapRealMatrix buildFilterMatrix(int n, int window, int poly, int deriv) throws Exception {
        final int p = (window - 1) / 2;

        // design[i][j] = (i - p)^j : the window offsets -p..p raised to every power 0..poly
        Dataset offsets = LinearAlgebra.outerProduct(DatasetFactory.createRange(-p, p + 1, 1, Dataset.FLOAT64), DatasetFactory.ones(new int[]{1, 1 + poly}, Dataset.FLOAT64)).squeeze();
        Dataset powers = LinearAlgebra.outerProduct(DatasetFactory.ones(new int[]{1, window}, Dataset.FLOAT64), DatasetFactory.createRange(poly + 1, Dataset.FLOAT64)).squeeze();
        Dataset design = Maths.power(offsets, powers);

        // solveSVD seems to correspond to numpy's lstsq
        Dataset weights = LinearAlgebra.solveSVD(design, DatasetUtils.eye(window, window, 0, Dataset.FLOAT64));

        // coeff[k] is the factor picked up by the x^(k + deriv) term after differentiating deriv times,
        // i.e. (k + 1) * (k + 2) * ... * (k + deriv); all ones when no derivative is requested.
        Dataset coeff;
        if (deriv > 0) {
            final int terms = poly + 1 - deriv;
            // Both operands are built with shape (deriv, terms) so the sum is element-wise:
            // entry [j][k] = (k + 1) + j. The product over axis 0 then yields one value per k.
            Dataset base = LinearAlgebra.outerProduct(DatasetFactory.ones(new int[]{deriv}, Dataset.FLOAT64), DatasetFactory.createRange(1, poly + 2 - deriv, 1, Dataset.FLOAT64));
            Dataset shift = LinearAlgebra.outerProduct(DatasetFactory.createRange(deriv, Dataset.FLOAT64), DatasetFactory.ones(new int[]{terms}, Dataset.FLOAT64));
            coeff = Stats.product(Maths.add(base, shift), 0).squeeze();
        } else {
            coeff = DatasetFactory.ones(new int[]{poly + 1}, Dataset.FLOAT64);
        }

        OpenMapRealMatrix filterMatrix = new OpenMapRealMatrix(n, n);

        // Fill the band of the sparse matrix. Basically do what scipy's spdiags does:
        // window weight i goes on diagonal (p - i), scaled by the derivative factor of the lowest term.
        final double centreScale = coeff.getDouble(0);
        for (int i = 0; i < window; i++) {
            final int diag = p - i;
            final double value = weights.getDouble(deriv, i) * centreScale;
            // restrict rows so that both row and row + diag stay inside [0, n)
            final int firstRow = Math.max(0, -diag);
            final int lastRow = Math.min(n, n - diag);
            for (int row = firstRow; row < lastRow; row++) {
                filterMatrix.setEntry(row, row + diag, value);
            }
        }

        // Rows deriv..poly of the weights, each scaled by its derivative factor.
        Dataset tail = LinearAlgebra.dotProduct(DatasetUtils.diag(coeff, 0), weights.getSlice(new int[]{deriv, 0}, new int[]{poly + 1, weights.getShapeRef()[1]}, null).squeeze());

        // Leading edge: the first p + 1 outputs are evaluated from the first window inputs.
        Dataset leadingRows = design.getSlice(null, new int[]{p + 1, poly - deriv + 1}, null).squeeze();
        filterMatrix.setSubMatrix(edgeBlock(leadingRows, tail), 0, 0);

        // Trailing edge: the last p + 1 outputs are evaluated from the last window inputs. The block is
        // window rows by p + 1 columns and must end exactly on the last row and column of the matrix.
        Dataset trailingRows = design.getSlice(new int[]{p, 0}, new int[]{window, poly - deriv + 1}, null).squeeze();
        filterMatrix.setSubMatrix(edgeBlock(trailingRows, tail), n - window, n - p - 1);

        return filterMatrix;
    }

    /**
     * Computes one dense edge block of the filter matrix: the transpose of designRows x tail.
     *
     * @param designRows rows of the design matrix belonging to the edge points
     * @param tail derivative-scaled least-squares weights
     * @return the block as a primitive 2D array, ready for setSubMatrix
     * @throws Exception if one of the dataset operations fails
     */
    private static double[][] edgeBlock(Dataset designRows, Dataset tail) throws Exception {
        // getSlice() on the transposed view produces a copy, so the backing array can be read row by row
        Dataset block = LinearAlgebra.dotProduct(designRows, tail).getTransposedView().getSlice();
        return convertDoubleDataset2DtoPrimitive(DatasetUtils.cast(DoubleDataset.class, block));
    }

    /**
     * Converts the input data to a matrix with one signal per row.
     *
     * @param data one- or two-dimensional dataset
     * @return matrix view of the data; one-dimensional data becomes a single row
     */
    private static RealMatrix toRealMatrix(IDataset data) {
        DoubleDataset values = DatasetUtils.cast(DoubleDataset.class, data);
        if (data.getRank() == 1) {
            // the double[] constructor builds a column matrix, hence the transpose
            return new Array2DRowRealMatrix(values.getData()).transpose();
        }
        return new Array2DRowRealMatrix(convertDoubleDataset2DtoPrimitive(values), false);
    }

    /**
     * Copies a two-dimensional dataset into a primitive 2D array, one array per row.
     * <p>
     * NOTE(review): the rows are read straight from the backing buffer at offset row * columns, so this
     * presumably requires the dataset to be a contiguous, non-sliced array - confirm for sliced input.
     *
     * @param dataset dataset to convert, must be 2D
     * @return newly allocated array holding the dataset values
     * @throws IllegalArgumentException if the dataset is not two-dimensional
     */
    private static double[][] convertDoubleDataset2DtoPrimitive(DoubleDataset dataset) {
        if (dataset.getRank() != 2)
            throw new IllegalArgumentException("dataset Shape must be 2D");
        final int[] shape = dataset.getShape();
        final int rows = shape[0];
        final int columns = shape[1];
        final double[] buffer = dataset.getData();
        double[][] rv = new double[rows][columns];
        for (int row = 0; row < rows; row++) {
            System.arraycopy(buffer, row * columns, rv[row], 0, columns);
        }
        return rv;
    }
}