package clear.experiment; import clear.dep.DepNode; import clear.dep.DepTree; import clear.reader.DepReader; import clear.util.IOUtil; import com.carrotsearch.hppc.IntOpenHashSet; import com.carrotsearch.hppc.ObjectIntOpenHashMap; import java.io.PrintStream; import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; public class DepAdapt { final int MAX_GAP = 5; final int MAX_ITER = 20; final double STOP = 0.1; ObjectIntOpenHashMap<String> m_lexica; ArrayList<IntOpenHashSet> a_vectors; public DepAdapt(String trainFile, String testFile, String outputFile, double m) { initLexica(trainFile); initVectors(testFile); select(trainFile, outputFile, m); } private void initLexica(String trainFile) { DepReader reader = new DepReader(trainFile, true); DepTree tree; int d = 1; System.out.print("Initializing lexica : "); m_lexica = new ObjectIntOpenHashMap<>(); while ((tree = reader.nextTree()) != null) { for (String key : getLexica(tree, true)) { if (!m_lexica.containsKey(key)) { m_lexica.put(key, d++); } } } reader.close(); System.out.println(m_lexica.size()); } private void initVectors(String testFile) { DepReader reader = new DepReader(testFile, true); DepTree tree; int n; System.out.print("Initializing vectors: "); a_vectors = new ArrayList<>(); for (n = 0; (tree = reader.nextTree()) != null; n++) { a_vectors.add(getVectorSet(tree, false)); } reader.close(); a_vectors.trimToSize(); System.out.println(n); } private IntOpenHashSet getVectorSet(DepTree tree, boolean isTrain) { IntOpenHashSet set = new IntOpenHashSet(); int d; for (String key : getLexica(tree, isTrain)) { if ((d = m_lexica.get(key)) > 0) { set.add(d - 1); } } return set; } private int[] getVectorArray(DepTree tree, boolean isTrain) { IntOpenHashSet set = getVectorSet(tree, isTrain); int[] vector = set.toArray(); Arrays.sort(vector); return vector; } private HashSet<String> getLexica(DepTree tree, boolean isTrain) { HashSet<String> set = new HashSet<>(); addNgramLexica(tree, set); if (isTrain) { addDepTrnLexica(tree, set); } else { addDepTstLexica(tree, set); } return set; } private void addNgramLexica(DepTree tree, HashSet<String> set) { DepNode node, prev; for (int i = 1; i < tree.size(); i++) { node = tree.get(i); set.add(node.lemma); set.add(node.pos); if (i > 1) { prev = tree.get(i - 1); add2gramLexica(set, 0, prev, node); } } } private void addDepTrnLexica(DepTree tree, HashSet<String> set) { DepNode node, prev, next; int dist; for (int i = 1; i < tree.size(); i++) { node = tree.get(i); dist = Math.abs(node.id - node.headId); if (dist > MAX_GAP || node.headId == 0) { continue; } if (node.id < node.headId) { prev = node; next = tree.get(node.headId); } else { prev = tree.get(node.headId); next = node; } add2gramLexica(set, dist, prev, next); } } private void addDepTstLexica(DepTree tree, HashSet<String> set) { DepNode prev, next; int i, dist, size = tree.size(); for (i = 1; i < size; i++) { prev = tree.get(i); for (dist = 1; dist <= MAX_GAP && i + dist < size; dist++) { next = tree.get(i + dist); add2gramLexica(set, dist, prev, next); } } } private void add2gramLexica(HashSet<String> set, int dist, DepNode prev, DepNode next) { String prefix = (dist <= 0) ? "" : dist + "_"; set.add(prefix + prev.lemma + "_" + next.lemma); set.add(prefix + prev.lemma + "_" + next.pos); set.add(prefix + prev.pos + "_" + next.lemma); set.add(prefix + prev.pos + "_" + next.pos); } private void select(String trainFile, String outputFile, double m) { int count = 0; try (PrintStream fout = IOUtil.createPrintFileStream(outputFile)) { DepReader reader = new DepReader(trainFile, true); DepTree tree; int[] vector1; int size = a_vectors.size(); double sim; System.out.print("Selecting trees: "); for (int i = 1; (tree = reader.nextTree()) != null; i++) { vector1 = getVectorArray(tree, true); sim = 0; for (IntOpenHashSet vector2 : a_vectors) { sim += getCosineSimilarity(vector1, vector2); } if (sim / size > m) { fout.println(tree + "\n"); count++; } if (i % 1000 == 0) { System.out.print("."); } } reader.close(); } System.out.println(); System.out.println(count); } private double getCosineSimilarity(int[] vector1, IntOpenHashSet vector2) { double dot = 0; for (int idx : vector1) { if (vector2.contains(idx)) { dot++; } } return dot / (Math.sqrt(vector1.length) * Math.sqrt(vector2.size())); } static public void main(String[] args) { String trainFile = args[0]; String testFile = args[1]; String outputFile = args[2]; double m = Double.parseDouble(args[3]); DepAdapt depAdapt = new DepAdapt(trainFile, testFile, outputFile, m); } }