/******************************************************************************* * Copyright (c) 2010 Haifeng Li * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ package smile.imputation; import smile.sort.QuickSort; /** * Missing value imputation by k-nearest neighbors. The KNN-based method * selects instances similar to the instance of interest to impute * missing values. If we consider instance A that has one missing value on * attribute i, this method would find K other instances, which have a value * present on attribute 1, with values most similar (in term of some distance, * e.g. Euclidean distance) to A on other attributes without missing values. * The average of values on attribute i from the K nearest * neighbors is then used as an estimate for the missing value in instance A. * In the weighted average, the contribution of each instance is weighted by * similarity between it and instance A. * * @author Haifeng Li */ public class KNNImputation implements MissingValueImputation { /** * The number of neighbors used for imputation. */ private int k; /** * Constructor. * @param k the number of neighbors used for imputation. */ public KNNImputation(int k) { if (k < 1) { throw new IllegalArgumentException("Invalid number of nearest neighbors for imputation: " + k); } this.k = k; } @Override public void impute(double[][] data) throws MissingValueImputationException { int[] count = new int[data[0].length]; for (int i = 0; i < data.length; i++) { int n = 0; for (int j = 0; j < data[i].length; j++) { if (Double.isNaN(data[i][j])) { n++; count[j]++; } } if (n == data[i].length) { throw new MissingValueImputationException("The whole row " + i + " is missing"); } } for (int i = 0; i < data[0].length; i++) { if (count[i] == data.length) { throw new MissingValueImputationException("The whole column " + i + " is missing"); } } double[] dist = new double[data.length]; for (int i = 0; i < data.length; i++) { double[] x = data[i]; int missing = 0; for (int j = 0; j < x.length; j++) { if (Double.isNaN(x[j])) { missing++; } } if (missing == 0) continue; for (int j = 0; j < data.length; j++) { double[] y = data[j]; int n = 0; dist[j] = 0; for (int m = 0; m < x.length; m++) { if (!Double.isNaN(x[m]) && !Double.isNaN(y[m])) { n++; double d = x[m] - y[m]; dist[j] += d * d; } } if (n > (x.length-missing) / 2) { dist[j] = x.length * dist[j] / n; } else { dist[j] = Double.MAX_VALUE; } } double[][] dat = new double[data.length][]; System.arraycopy(data, 0, dat, 0, data.length); QuickSort.sort(dist, dat); for (int j = 0; j < data[i].length; j++) { if (Double.isNaN(x[j])) { x[j] = 0; int n = 0; for (int m = 0; n < k && m < dat.length; m++) { if (!Double.isNaN(dat[m][j])) { x[j] += dat[m][j]; n++; } } x[j] /= n; } } } } }