package clear.experiment;
import clear.dep.DepNode;
import clear.dep.DepTree;
import clear.reader.SRLReader;
import clear.util.IOUtil;
import clear.util.cluster.Kmeans;
import clear.util.tuple.JIntDoubleTuple;
import com.carrotsearch.hppc.IntDoubleOpenHashMap;
import com.carrotsearch.hppc.IntOpenHashSet;
import com.carrotsearch.hppc.ObjectIntOpenHashMap;
import com.carrotsearch.hppc.cursors.ObjectCursor;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;
/**
 * Selects a subset of dependency trees from a corpus by clustering the trees and randomly
 * sampling a fixed portion of every cluster.
 *
 * <p>The whole pipeline is run by the constructor:
 * <ol>
 *   <li>{@link #addLexica}: count how often each feature occurs in the input.</li>
 *   <li>{@link #configureMap}: keep features whose count is at least {@code cutoff} and assign
 *       them 1-based indices.</li>
 *   <li>{@link #generateUnitClusters}: turn every tree into a sparse feature-count vector.</li>
 *   <li>{@link #getSubsetIndices}: cluster the vectors with {@link Kmeans} and sample from each
 *       cluster.</li>
 *   <li>{@link #printSubset}: write the selected trees to the output file.</li>
 * </ol>
 *
 * <p>The input file is read three times (steps 1, 3 and 5), so it is assumed not to change while
 * the pipeline runs.
 */
public class DepSelect {
    /**
     * Feature string to occurrence count after {@link #addLexica}, then feature string to 1-based
     * feature index after {@link #configureMap}. Index 0 is never assigned, so a lookup of an
     * absent key (which yields 0) can be told apart from a kept feature.
     */
    private ObjectIntOpenHashMap<String> m_ftr;
    /** One sparse feature vector per tree, in input order. */
    private JIntDoubleTuple[][] a_unit;
    /** Number of trees in the input file, as counted by {@link #addLexica}. */
    private int N;

    /**
     * Runs the complete selection pipeline.
     *
     * @param inputFile  file of dependency trees to select from (read with {@link SRLReader}).
     * @param outputFile file the selected trees are written to.
     * @param cutoff     feature cutoff (inclusive): features occurring fewer times are dropped.
     * @param K          number of clusters requested from {@link Kmeans}.
     * @param threshold  passed through to {@link Kmeans#cluster} (presumably a convergence
     *                   threshold; confirm against Kmeans).
     * @param portion    fraction of each cluster to select; the per-cluster sample size is
     *                   clamped to [0, cluster size].
     */
    public DepSelect(String inputFile, String outputFile, int cutoff, int K, double threshold, float portion) {
        addLexica(inputFile);
        configureMap(cutoff);
        generateUnitClusters(inputFile);
        printSubset(inputFile, outputFile, getSubsetIndices(K, threshold, portion));
    }

    /**
     * Counts every feature occurrence in {@code inputFile} into {@link #m_ftr} and stores the
     * number of trees in {@link #N}. Prints a dot every 10,000 trees as progress.
     *
     * @param inputFile file of dependency trees.
     */
    public void addLexica(String inputFile) {
        SRLReader reader = new SRLReader(inputFile, true);
        DepTree tree;
        String vb = "# of trees : ";

        System.out.print(vb);
        m_ftr = new ObjectIntOpenHashMap<>();

        try {
            for (N = 0; (tree = reader.nextTree()) != null; N++) {
                for (String key : getFeatures(tree)) {
                    // get() yields 0 for an absent key, so this is a plain increment.
                    m_ftr.put(key, m_ftr.get(key) + 1);
                }
                if (N % 10000 == 0) {
                    System.out.print(".");
                }
            }
        } finally {
            reader.close();
        }

        // '\r' rewinds over the progress dots so the final line shows the total.
        System.out.println("\r" + vb + N);
    }

    /**
     * Replaces the feature counts in {@link #m_ftr} with 1-based indices, keeping only features
     * whose count is at least {@code cutoff}.
     *
     * @param cutoff feature cutoff (inclusive).
     */
    public void configureMap(int cutoff) {
        ObjectIntOpenHashMap<String> map = new ObjectIntOpenHashMap<>();
        int value, count = 1; // start at 1: index 0 means "feature not kept"

        System.out.print("# of features: ");

        for (ObjectCursor<String> cur : m_ftr.keys()) {
            value = m_ftr.get(cur.value);
            if (value >= cutoff) {
                map.put(cur.value, count++);
            }
        }

        m_ftr = map;
        System.out.println(map.size());
    }

    /**
     * Builds one sparse feature vector per tree of {@code inputFile} into {@link #a_unit}.
     * Must be called after {@link #addLexica} (which sets {@link #N}) and
     * {@link #configureMap}.
     *
     * @param inputFile the same file previously passed to {@link #addLexica}.
     * @throws IllegalStateException if the file now contains more trees than were counted.
     */
    public void generateUnitClusters(String inputFile) {
        SRLReader reader = new SRLReader(inputFile, true);
        DepTree tree;
        int i;

        System.out.print("Generating unit vectors: ");
        a_unit = new JIntDoubleTuple[N][];

        try {
            for (i = 0; (tree = reader.nextTree()) != null; i++) {
                if (i >= N) {
                    throw new IllegalStateException(inputFile
                            + " contains more trees than counted by addLexica (" + N + ")");
                }
                a_unit[i] = generateUnitCluster(tree);
                if (i % 10000 == 0) {
                    System.out.print(".");
                }
            }
        } finally {
            reader.close();
        }

        System.out.println();
    }

    /**
     * Converts one tree into a sparse vector of (0-based feature index, occurrence count)
     * tuples sorted by feature index. Features removed by {@link #configureMap} are ignored.
     */
    private JIntDoubleTuple[] generateUnitCluster(DepTree tree) {
        IntDoubleOpenHashMap map = new IntDoubleOpenHashMap();
        ArrayList<String> lsFtr = getFeatures(tree);
        int index;

        for (String key : lsFtr) {
            // Absent features look up as 0 and therefore map to -1, which is skipped.
            if ((index = m_ftr.get(key) - 1) >= 0) {
                map.put(index, map.get(index) + 1);
            }
        }

        int[] indices = map.keys().toArray();
        JIntDoubleTuple[] tup = new JIntDoubleTuple[map.size()];
        Arrays.sort(indices);

        index = 0;
        for (int i : indices) {
            tup[index++] = new JIntDoubleTuple(i, map.get(i));
        }

        return tup;
    }

    /**
     * Clusters the tree vectors and randomly samples {@code portion} of each cluster.
     *
     * <p>Sampling uses a fixed seed, so the result is reproducible for the same input and
     * parameters.
     *
     * @param K         number of clusters requested from {@link Kmeans}.
     * @param threshold passed through to {@link Kmeans#cluster}.
     * @param portion   fraction of each cluster to select; the per-cluster sample size is
     *                  {@code round(portion * clusterSize)} clamped to [0, clusterSize].
     * @return the selected tree indices. Assumes each tuple's {@code i} is the index of a tree
     *         in {@link #a_unit}, as consumed by {@link #printSubset}.
     */
    public IntOpenHashSet getSubsetIndices(int K, double threshold, float portion) {
        Kmeans km = new Kmeans(a_unit, m_ftr.size());
        ArrayList<ArrayList<JIntDoubleTuple>> cluster = km.cluster(K, threshold);
        IntOpenHashSet sAll = new IntOpenHashSet(), sSub;
        ArrayList<JIntDoubleTuple> ck;
        Random rand = new Random(0);
        int k, nk, nSub;
        // Do not index past the clusters actually returned.
        int nClusters = Math.min(K, cluster.size());

        for (k = 0; k < nClusters; k++) {
            ck = cluster.get(k);
            nk = ck.size();
            // Never request more indices than the cluster can supply; otherwise (e.g. portion > 1)
            // the sampling loop below could not terminate.
            nSub = Math.max(0, Math.min(nk, Math.round(portion * nk)));
            sSub = new IntOpenHashSet(nSub);

            // Rejection sampling without replacement. NOTE(review): termination also assumes the
            // tuples within one cluster carry distinct i values.
            while (sSub.size() < nSub) {
                sSub.add(ck.get(rand.nextInt(nk)).i);
            }

            sAll.addAll(sSub);
        }

        return sAll;
    }

    /**
     * Writes every tree of {@code inputFile} whose 0-based position is in {@code set} to
     * {@code outputFile}, each followed by a blank line. Reports to stderr if the output stream
     * recorded a write error.
     *
     * @param inputFile  file of dependency trees.
     * @param outputFile destination file.
     * @param set        0-based indices of the trees to write.
     */
    public void printSubset(String inputFile, String outputFile, IntOpenHashSet set) {
        SRLReader reader = new SRLReader(inputFile, true);
        DepTree tree;

        System.out.print("Printing: ");

        try (PrintStream fout = IOUtil.createPrintFileStream(outputFile)) {
            for (int i = 0; (tree = reader.nextTree()) != null; i++) {
                if (set.contains(i)) {
                    fout.println(tree + "\n");
                }
                if (i % 10000 == 0) {
                    System.out.print(".");
                }
            }
            // PrintStream swallows IOExceptions; surface them instead of failing silently.
            if (fout.checkError()) {
                System.err.println("Error while writing: " + outputFile);
            }
        } finally {
            reader.close();
        }

        System.out.println();
    }

    /**
     * Extracts n-gram and dependency features for every node of {@code tree} except node 0
     * (presumably the artificial root). Each feature string is prefixed with its template id
     * (0-14), which restarts at 0 for every node. Duplicates are kept, so the list reflects
     * occurrence counts.
     *
     * <p>Calls {@code tree.setSubcat()} before extraction; presumably this updates
     * subcategorization information on the tree — confirm against DepTree.
     */
    ArrayList<String> getFeatures(DepTree tree) {
        ArrayList<String> lsFtr = new ArrayList<>();
        int[] iFtr = new int[1]; // mutable template counter shared with the helpers
        DepNode curr;

        tree.setSubcat();

        for (int i = 1; i < tree.size(); i++) {
            iFtr[0] = 0;
            curr = tree.get(i);
            getNgramFeatures(lsFtr, iFtr, tree, curr);
            getDepFeatures(lsFtr, iFtr, tree, curr);
        }

        return lsFtr;
    }

    /**
     * Adds the n-gram features of {@code curr} using template ids 0-6: two unigrams, four
     * bigrams with the previous node and one POS trigram. When a context node does not exist
     * (id 0 or below), the corresponding ids are skipped so template ids stay aligned.
     */
    void getNgramFeatures(ArrayList<String> lsFtr, int[] iFtr, DepTree tree, DepNode curr) {
        // 1-gram
        lsFtr.add(getFeature(iFtr, curr.form));
        lsFtr.add(getFeature(iFtr, curr.pos, curr.lemma));

        // 2-gram
        DepNode prev1 = null;

        if (curr.id - 1 > 0) {
            prev1 = tree.get(curr.id - 1);
            lsFtr.add(getFeature(iFtr, prev1.pos, curr.pos));
            lsFtr.add(getFeature(iFtr, prev1.lemma, curr.pos));
            lsFtr.add(getFeature(iFtr, prev1.pos, curr.lemma));
            lsFtr.add(getFeature(iFtr, prev1.lemma, curr.lemma));
        } else {
            iFtr[0] += 4; // skip the four bigram template ids
        }

        // 3-gram
        DepNode prev2;

        if (curr.id - 2 > 0) {
            // curr.id - 2 > 0 implies curr.id - 1 > 0, so prev1 was set above.
            prev2 = tree.get(curr.id - 2);
            lsFtr.add(getFeature(iFtr, prev2.pos, prev1.pos, curr.pos));
        } else {
            iFtr[0] += 1; // skip the trigram template id
        }
    }

    /**
     * Adds eight features pairing {@code curr} with its head: four conditioned on direction
     * ("&lt;" when the node precedes its head, "&gt;" otherwise) and four conditioned on the
     * dependency label, each over POS/lemma combinations.
     */
    void getDepFeatures(ArrayList<String> lsFtr, int[] iFtr, DepTree tree, DepNode curr) {
        DepNode head = tree.get(curr.headId);
        String dir = (curr.id < head.id) ? "<" : ">";

        lsFtr.add(getFeature(iFtr, dir, curr.pos, head.pos));
        lsFtr.add(getFeature(iFtr, dir, curr.lemma, head.pos));
        lsFtr.add(getFeature(iFtr, dir, curr.pos, head.lemma));
        lsFtr.add(getFeature(iFtr, dir, curr.lemma, head.lemma));

        lsFtr.add(getFeature(iFtr, curr.deprel, curr.pos, head.pos));
        lsFtr.add(getFeature(iFtr, curr.deprel, curr.lemma, head.pos));
        lsFtr.add(getFeature(iFtr, curr.deprel, curr.pos, head.lemma));
        lsFtr.add(getFeature(iFtr, curr.deprel, curr.lemma, head.lemma));
    }

    /**
     * Builds a feature string {@code "<id>_<f1>_<f2>..."} where {@code id} is the current value
     * of {@code iFtr[0]}, which is then incremented.
     *
     * @param iFtr single-element template counter (mutated).
     * @param ftr  feature values to join.
     * @return the feature string.
     */
    String getFeature(int[] iFtr, String... ftr) {
        StringBuilder build = new StringBuilder();
        build.append(iFtr[0]++);

        for (String s : ftr) {
            build.append("_");
            build.append(s);
        }

        return build.toString();
    }

    /**
     * Command-line entry point.
     *
     * @param args inputFile outputFile cutoff K threshold portion
     */
    public static void main(String[] args) {
        if (args.length < 6) {
            System.err.println("Usage: DepSelect <inputFile> <outputFile> <cutoff> <K> <threshold> <portion>");
            System.exit(1);
        }

        String inputFile = args[0];
        String outputFile = args[1];
        int cutoff = Integer.parseInt(args[2]);
        int K = Integer.parseInt(args[3]);
        double threshold = Double.parseDouble(args[4]);
        float portion = Float.parseFloat(args[5]);

        // The constructor runs the whole pipeline; the instance itself is not needed afterwards.
        new DepSelect(inputFile, outputFile, cutoff, K, threshold, portion);
    }
}