package GeDBIT.index.algorithms; import java.util.List; //import cern.colt.matrix.DoubleMatrix2D; import GeDBIT.index.algorithms.Selection; import GeDBIT.dist.Metric; import GeDBIT.type.IndexObject; //import GeDBIT.util.LargeDenseDoubleMatrix2D; @SuppressWarnings("serial") public class SelectionOnFFT implements PivotSelectionMethod, java.io.Serializable { /* * 这个就是照着PCAOnFFT做的界面。 传的参数要完全符合预先设定的值,包括大小写,否则会出错 * 我感觉这样传参会比较麻烦,先这样吧,再想想有没有什么更好的方法 我给这个类加了3个变量:testKind 用什么标准(现在用的是F-test) * yMethod y怎么求得 selectAlgorithm 用的是forward还是backward * * testKind:在传参的时候,如果用F-test,就给testKind传"ftest" * yMethod:如果用平均值,就给yMethod传"average"。如果是标准差"standard" * selectAlgorithm:如果用forward selection,就传"forward" */ private int FFTScale; private String testKind; private String yMethod; private String selectAlgorithm; public SelectionOnFFT(int scale, String kind, String method, String alog) { FFTScale = scale; testKind = kind; yMethod = method; selectAlgorithm = alog; } public SelectionOnFFT(int scale) // 这个构造函数是作为备用的 { FFTScale = scale; testKind = "ftest"; yMethod = "standard"; selectAlgorithm = "forward"; } public int[] selectPivots(Metric metric, List<? extends IndexObject> data, int numPivots) { final int dataSize = data.size(); // for(int i=0; i<dataSize; i++) // System.out.println(data.get(i)+" "); if (numPivots >= dataSize) { int[] pivots = new int[dataSize]; for (int i = 0; i < dataSize; i++) pivots[i] = i; return IncrementalSelection.removeDuplicate(metric, data, pivots); } // System.out.println("datasize:"+dataSize); // run fft to get a candidate set int[] fftResult = PivotSelectionMethods.FFT.selectPivots(metric, data, numPivots * FFTScale); /* * System.out.println("fft length:"+fftResult.length); * * System.out.println("fft result: "); for(int i=0; i<fftResult.length; * i++) System.out.print(fftResult[i]+" "); System.out.println(); */ // compute the distance matrix if (fftResult.length <= Math.min(dataSize, numPivots)) return fftResult; // get x // x要空出x[0]这一列,为了和后面一致。 // x[0]这一列赋为1的工作就留到各个selection中的init函数来做 double[][] matrix = new double[dataSize][fftResult.length + 1]; // col是fft算出来的列 // row就是行,表示有row个观察值 for (int col = 0; col < fftResult.length; col++) for (int row = 0; row < dataSize; row++) matrix[row][col + 1] = metric.getDistance(data.get(row), data.get(fftResult[col])); /* * if(dataSize == fftResult.length) { * System.out.println("here comes matrix:"); for(int i=0; i<dataSize; * i++) { for(int j=1; j<=fftResult.length; j++) * System.out.print(matrix[i][j]+" "); System.out.println(); } } * * System.out.println("fft scale : "+FFTScale); * System.out.println("fft length : "+fftResult.length); * System.out.println("data size : "+dataSize); */ double[][] y; // get y if (yMethod.equalsIgnoreCase("average")) y = getYAvg(matrix, dataSize, fftResult.length); else if (yMethod.equalsIgnoreCase("standard")) y = getYStand(matrix, dataSize, fftResult.length); else throw new IllegalArgumentException("Invalid option " + yMethod); // X和Y都得到,下面进行计算 Selection select = new Selection(matrix, y, numPivots); // 设置检测方法。当selectAlgorithm为enumerate的时候就不用了 if (testKind.equalsIgnoreCase("ftest")) select.setTestSign(1); // 1 代表F检测 else if (testKind.equalsIgnoreCase("rss")) select.setTestSign(2); // 2 代表用rss else throw new IllegalArgumentException("Invalid option " + testKind); if (selectAlgorithm.equalsIgnoreCase("enumerate")) select.setTestSign(1); int[] result; // 选择使用何种算法来进行selection计算 if (selectAlgorithm.equalsIgnoreCase("forward")) result = select.forwardSelection(); else if (selectAlgorithm.equalsIgnoreCase("enumerate")) result = select.enumerateSelection(); else if (selectAlgorithm.equalsIgnoreCase("backward")) result = select.backwardSelection(); else throw new IllegalArgumentException("Invalid option " + selectAlgorithm); /* * System.out.println("result"); for(int i=0; i<result.length; i++) * System.out.print(result[i]+" "); System.out.println(); */ return result; } double[][] getYAvg(double[][] matrix, int n, int p) { double[][] matrixY = new double[n][1]; // double sum = 0; for (int i = 0; i < n; i++) { double sum = 0; // 计算每行的平均值 for (int j = 1; j <= p; j++) sum += matrix[i][j]; matrixY[i][0] = sum / p; } return matrixY; } double[][] getYStand(double[][] matrix, int n, int p) { double[][] matrixY = new double[n][1]; /* * //double sum = 0, sSum = 0; for(int i=0; i<n; i++){ double sum = 0, * sSum = 0, avg = 0; //计算每行的标准差 for(int j=1; j<=p; j++){ sum += * matrix[i][j]; //和的平方 sSum += matrix[i][j]*matrix[i][j]; } //均值的平方 avg * = sum/p; avg = avg * avg; //平方的均值 sSum = sSum / p; //相减,开根 * matrixY[i][0] = Math.sqrt(sSum - sum); } */ for (int i = 0; i < n; i++) // 计算n行的标准差 { double sum = 0, avg = 0, sSum = 0; for (int j = 1; j <= p; j++) sum += matrix[i][j]; avg = sum / p; for (int j = 1; j <= p; j++) sSum += (matrix[i][j] - avg) * (matrix[i][j] - avg); matrixY[i][0] = Math.sqrt(sSum); } /* * System.out.println("y matrix:"); for(int i=0; i<n; i++) { * System.out.println(matrixY[i][0]); } */ return matrixY; } public int[] selectPivots(Metric metric, List<? extends IndexObject> data, int first, int dataSize, int numPivots) { int[] result = selectPivots(metric, data.subList(first, dataSize), numPivots); for (int i = 0; i < result.length; i++) result[i] += first; return result; } }