package ml.humaning.algorithm; import java.io.BufferedWriter; import java.io.FileWriter; import java.io.IOException; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.Random; import ml.humaning.util.Point; import ml.humaning.util.Reader; public class KNN { private Point [] allData; public KNN(String trainFile) throws IOException { allData = Reader.readPoints(trainFile); BufferedWriter wr = new BufferedWriter(new FileWriter("train_binary.dat")); for(int i = 0;i < allData.length;i++){ wr.write(allData[i].toBinaryString()+"\n"); } wr.close(); allData = Reader.readPoints("train_binary.dat"); } private HashSet <Integer> getValidationPoints(int numberOfValidationPoints) { HashSet<Integer> validationPoints = new HashSet<Integer>(); Random generator = new Random(System.currentTimeMillis()); while (validationPoints.size() < numberOfValidationPoints) validationPoints.add(generator.nextInt(allData.length)); return validationPoints; } private int classify(Point [] trainData, int k) { HashMap <Integer, Integer> zodiacToFrequency = new HashMap <Integer, Integer>(); for(int i = 0;i < trainData.length && i < k;i++){ if(zodiacToFrequency.get(trainData[i].getZodiac() ) == null){ zodiacToFrequency.put(trainData[i].getZodiac(), 1); }else { zodiacToFrequency.put(trainData[i].getZodiac(), zodiacToFrequency.get(trainData[i].getZodiac())+1); } } int maxFrequency = 0; int maxZodiac = 0; for(Integer zodiac : zodiacToFrequency.keySet()){ if(zodiacToFrequency.get(zodiac) > maxFrequency){ maxFrequency = zodiacToFrequency.get(zodiac); maxZodiac = zodiac; } } return maxZodiac; } public double getCVError(int k, int numberOfFold){ if(numberOfFold < 2)return 0.0; double crossValidationError = 0.0; for(int i = 1;i <= numberOfFold;i++){ crossValidationError += getValidationError(k, numberOfFold, i); } return crossValidationError/numberOfFold; } public void predict(int k, String testFile, String outputFile) throws IOException{ Point [] testData = Reader.readPoints(testFile); BufferedWriter writer = new BufferedWriter(new FileWriter(outputFile)); for(Point testP : testData){ int emptyRegion = testP.getEmptyRegion(); testP.setMaskRegion(emptyRegion); for(int trainIndex = 0; trainIndex < allData.length;trainIndex++){ allData[trainIndex].setMaskRegion(emptyRegion); allData[trainIndex].setDistanceToReference(-1*testP.cosineSimilarity(allData[trainIndex])); //allData[trainIndex].setDistanceToReference(testP.distance(allData[trainIndex], true)); } Arrays.sort(allData); writer.write(classify(allData, k)+"\n"); // System.out.println(classify(allData, k)+""); } writer.close(); } // fold is 1-based, [1, numberOfFold] private double getValidationError(int k, int numberOfFold, int fold){ if(numberOfFold == 0)return 0.0; int interval = allData.length/numberOfFold; int numberOfValidationPoints = (fold == numberOfFold)? allData.length - (fold-1)*interval : interval; HashSet <Integer> validationPoints = getValidationPoints(numberOfValidationPoints); Point [] trainData = new Point[allData.length - numberOfValidationPoints]; double error = 0.0; for (Integer validationPointIndex : validationPoints) { int trainIndex = 0; for (int i = 0;i < allData.length;i++){ if (!validationPoints.contains(i)) {// it's a train point trainData[trainIndex] = allData[i]; //trainData[trainIndex].setDistanceToReference(allData[validationPointIndex].distance(allData[i]) ); //trainData[trainIndex].setDistanceToReference(-1*allData[validationPointIndex].innerProduct(allData[i]) ); trainData[trainIndex].setDistanceToReference(-1*allData[validationPointIndex].cosineSimilarity(allData[i])); trainIndex++; } } Arrays.sort(trainData); if (classify(trainData, k) != allData[validationPointIndex].getZodiac()) error++; } return error/numberOfValidationPoints; } }