package clear.experiment; import clear.dep.DepNode; import clear.dep.DepTree; import clear.reader.DepReader; import clear.util.IOUtil; import clear.util.tuple.JIntDoubleTuple; import com.carrotsearch.hppc.IntOpenHashSet; import com.carrotsearch.hppc.ObjectIntOpenHashMap; import java.io.PrintStream; import java.util.Arrays; import java.util.HashSet; import java.util.Random; public class DepFuzzyCluster { final int MAX_GAP = 5; final int MAX_ITER = 20; final double STOP = 0.1; ObjectIntOpenHashMap<String> m_lexica; int[][] a_vectors; // i_N x * double[][] d_centroids; // i_D x i_K double[] d_scalars; // i_K double[][] d_u; // i_N x i_K double d_m, d_a; int i_K, i_D, i_N; public DepFuzzyCluster(String trainFile, String testFile, int K, double m) { init(trainFile, K, m); train(); split(trainFile, testFile); } private void init(String trainFile, int K, double m) { i_K = K; d_m = m; d_a = 1d / i_K; initLexica(trainFile); initVectors(trainFile); initPriors(); } private void initLexica(String trainFile) { DepReader reader = new DepReader(trainFile, true); DepTree tree; int d = 1; System.out.print("Initializing lexica : "); m_lexica = new ObjectIntOpenHashMap<>(); for (i_N = 0; (tree = reader.nextTree()) != null; i_N++) { for (String key : getLexica(tree, true)) { if (!m_lexica.containsKey(key)) { m_lexica.put(key, d++); } } } reader.close(); System.out.println((i_D = m_lexica.size())); } private void initVectors(String filename) { DepReader reader = new DepReader(filename, true); DepTree tree; System.out.print("Initializing vectors: "); a_vectors = new int[i_N][]; for (int n = 0; (tree = reader.nextTree()) != null; n++) { a_vectors[n] = getVector(tree, true); } reader.close(); System.out.println(i_N); } private void initPriors() { double u = 1d / i_K; d_u = new double[i_N][i_K]; System.out.print("Initializing priors : "); for (int n = 0; n < i_N; n++) { Arrays.fill(d_u[n], u); } System.out.println(i_N + " x " + i_K); } private HashSet<String> getLexica(DepTree tree, boolean isTrain) { HashSet<String> set = new HashSet<>(); addNgramLexica(tree, set); if (isTrain) { addDepTrnLexica(tree, set); } else { addDepTstLexica(tree, set); } return set; } private int[] getVector(DepTree tree, boolean isTrain) { IntOpenHashSet set = new IntOpenHashSet(); int d; for (String key : getLexica(tree, isTrain)) { if ((d = m_lexica.get(key)) > 0) { set.add(d - 1); } } int[] vector = set.toArray(); Arrays.sort(vector); return vector; } private void addNgramLexica(DepTree tree, HashSet<String> set) { DepNode node, prev; for (int i = 1; i < tree.size(); i++) { node = tree.get(i); set.add(node.lemma); set.add(node.pos); if (i > 1) { prev = tree.get(i - 1); add2gramLexica(set, 0, prev, node); } } } private void addDepTrnLexica(DepTree tree, HashSet<String> set) { DepNode node, prev, next; int dist; for (int i = 1; i < tree.size(); i++) { node = tree.get(i); dist = Math.abs(node.id - node.headId); if (dist > MAX_GAP || node.headId == 0) { continue; } if (node.id < node.headId) { prev = node; next = tree.get(node.headId); } else { prev = tree.get(node.headId); next = node; } add2gramLexica(set, dist, prev, next); } } private void addDepTstLexica(DepTree tree, HashSet<String> set) { DepNode prev, next; int i, dist, size = tree.size(); for (i = 1; i < size; i++) { prev = tree.get(i); for (dist = 1; dist <= MAX_GAP && i + dist < size; dist++) { next = tree.get(i + dist); add2gramLexica(set, dist, prev, next); } } } private void add2gramLexica(HashSet<String> set, int dist, DepNode prev, DepNode next) { String prefix = (dist <= 0) ? "" : dist + "_"; set.add(prefix + prev.lemma + "_" + next.lemma); set.add(prefix + prev.lemma + "_" + next.pos); set.add(prefix + prev.pos + "_" + next.lemma); set.add(prefix + prev.pos + "_" + next.pos); } private void train() { double prevScore, currScore = 0; for (int i = 0; i < MAX_ITER; i++) { System.out.println("\nIteration: " + i); prevScore = currScore; System.out.println("- E-step"); expectation(); System.out.println("- M-step"); currScore = maximization(); System.out.println("- Score : " + currScore); if (Math.abs(currScore - prevScore) < STOP) { break; } } } private void expectation() { if (d_centroids == null) { initCentroids(); return; } double[] den = new double[i_K]; double u; int[] vector; int k, n; for (n = 0; n < i_N; n++) { vector = a_vectors[n]; for (k = 0; k < i_K; k++) { u = getPrior(n, k); den[k] += u; for (int d : vector) { d_centroids[d][k] += u; } } } Arrays.fill(d_scalars, 0); for (n = 0; n < i_D; n++) { for (k = 0; k < i_K; k++) { u = d_centroids[n][k] / den[k]; d_centroids[n][k] = u; d_scalars[k] += u * u; } } for (k = 0; k < i_K; k++) { d_scalars[k] = Math.sqrt(d_scalars[k]); } } private void initCentroids() { IntOpenHashSet set = new IntOpenHashSet(); Random rand = new Random(0); int[] vector; int n, k = 0; d_centroids = new double[i_D][i_K]; d_scalars = new double[i_K]; while (set.size() < i_K) { n = rand.nextInt(i_N); if (!set.contains(n)) { vector = a_vectors[n]; set.add(n); for (int d : vector) { d_centroids[d][k] = 1; } d_scalars[k++] = Math.sqrt(vector.length); } } } private double getPrior(int n, int k) { return Math.pow(d_u[n][k], d_m); } private double maximization() { double[] dist; double num, den, sum, m = 1d / (d_m - 1), score = 0; int[] vector; int n, k, i; for (n = 0; n < i_N; n++) { vector = a_vectors[n]; dist = getCosineDistances(vector); for (k = 0; k < i_K; k++) { num = dist[k]; sum = 0; for (i = 0; i < i_K; i++) { if (i == k) { sum += 1; } else { den = dist[i]; sum += Math.pow(num / den, m); } } d_u[n][k] = 1d / sum; score += Math.pow(d_u[n][k], d_m) * dist[k]; } } return score; } private double[] getCosineDistances(int[] vector) { double[] dots = new double[i_K]; double[] centroid; double scalar = Math.sqrt(vector.length); int k; for (int d : vector) { centroid = d_centroids[d]; for (k = 0; k < i_K; k++) { dots[k] += centroid[k]; } } for (k = 0; k < i_K; k++) { dots[k] = 1 - (dots[k] / (d_scalars[k] * scalar)); } return dots; } private void split(String trainFile, String testFile) { splitTrainFile(trainFile); splitTestFile(testFile); } private void splitTrainFile(String trainFile) { PrintStream[] fout = getPrintStreams(trainFile); DepReader reader = new DepReader(trainFile, true); DepTree tree; int[] count = new int[i_K]; int n, k; System.out.println("\nSplitting: " + trainFile); for (n = 0; (tree = reader.nextTree()) != null; n++) { for (k = 0; k < i_K; k++) { if (d_u[n][k] >= d_a) { fout[k].println(tree + "\n"); count[k]++; } } } for (k = 0; k < i_K; k++) { System.out.println(k + ": " + count[k]); } closePrintStreams(fout); } private void splitTestFile(String testFile) { PrintStream[] fout = getPrintStreams(testFile); DepReader reader = new DepReader(testFile, true); DepTree tree; int[] count = new int[i_K]; int n, k; double[] dist; JIntDoubleTuple max = new JIntDoubleTuple(-1, -1); System.out.println("\nSplitting: " + testFile); for (n = 0; (tree = reader.nextTree()) != null; n++) { dist = getCosineDistances(getVector(tree, false)); max.set(-1, Double.MAX_VALUE); for (k = 0; k < i_K; k++) { if (dist[k] < max.d) { max.set(k, dist[k]); } } fout[max.i].println(tree + "\n"); count[max.i]++; } for (k = 0; k < i_K; k++) { System.out.println(k + ": " + count[k]); } closePrintStreams(fout); } private PrintStream[] getPrintStreams(String filename) { PrintStream[] fout = new PrintStream[i_K]; for (int k = 0; k < i_K; k++) { fout[k] = IOUtil.createPrintFileStream(filename + ".k" + i_K + ".m" + d_m + "." + k); } return fout; } private void closePrintStreams(PrintStream[] fout) { for (PrintStream f : fout) { f.close(); } } static public void main(String[] args) { String trainFile = args[0]; String testFile = args[1]; int K = Integer.parseInt(args[2]); double m = Double.parseDouble(args[3]); DepFuzzyCluster depFuzzyCluster = new DepFuzzyCluster(trainFile, testFile, K, m); } }