package clear.experiment;
import clear.dep.DepNode;
import clear.dep.DepTree;
import clear.reader.DepReader;
import com.carrotsearch.hppc.IntOpenHashSet;
import com.carrotsearch.hppc.ObjectIntOpenHashMap;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;
public class DepGMCluster {
final int MAX_ITER = 20;
final double SQRT_2PI = Math.sqrt(2 * Math.PI);
final double STOP = 0.01;
ObjectIntOpenHashMap<String> m_lexica;
int[][] a_vectors; // i_N x *
double[][] d_m; // i_D x i_K
double[] d_s; // i_K
double[] p_k; // i_K
double[][] p_kn; // i_N x i_K
double[] d_mScalars; // i_K
int i_K, i_D, i_N;
public DepGMCluster(String trainFile, int K) {
init(trainFile, K);
cluster();
}
public void init(String trainFile, int K) {
i_K = K;
initLexica(trainFile);
initVectors(trainFile);
initDistributions();
}
private void initLexica(String trainFile) {
DepReader reader = new DepReader(trainFile, true);
DepTree tree;
int d = 1;
System.out.print("Initializing lexica : ");
m_lexica = new ObjectIntOpenHashMap<>();
for (i_N = 0; (tree = reader.nextTree()) != null; i_N++) {
for (String key : getLexica(tree)) {
if (!m_lexica.containsKey(key)) {
m_lexica.put(key, d++);
}
}
}
reader.close();
System.out.println((i_D = m_lexica.size()));
}
private void initVectors(String filename) {
DepReader reader = new DepReader(filename, true);
DepTree tree;
System.out.print("Initializing vectors: ");
a_vectors = new int[i_N][];
for (int n = 0; (tree = reader.nextTree()) != null; n++) {
a_vectors[n] = getVector(tree);
}
reader.close();
System.out.println(i_N);
}
private void initDistributions() {
System.out.print("Initializing priors : ");
initPriors();
System.out.println(i_K + " x " + i_N);
System.out.print("Initializing means : ");
initMeans();
System.out.println(i_K);
System.out.print("Initializing stdevs : ");
initStdevs();
System.out.println(i_K);
}
private void initPriors() {
double p = 1d / i_K;
p_k = new double[i_K];
p_kn = new double[i_N][i_K];
Arrays.fill(p_k, p);
}
private void initMeans() {
IntOpenHashSet set = new IntOpenHashSet();
Random rand = new Random(0);
int[] vector;
int n, k = 0;
d_m = new double[i_D][i_K];
d_mScalars = new double[i_K];
while (set.size() < i_K) {
n = rand.nextInt(i_N);
if (!set.contains(n)) {
vector = a_vectors[n];
set.add(n);
for (int d : vector) {
d_m[d][k] = 1;
}
d_mScalars[k++] = Math.sqrt(vector.length);
}
}
}
private void initStdevs() {
double[] dist;
int n, k;
d_s = new double[i_K];
for (n = 0; n < i_N; n++) {
dist = getCosineDistances(a_vectors[n]);
for (k = 0; k < i_K; k++) {
d_s[k] += dist[k];
}
}
for (k = 0; k < i_K; k++) {
d_s[k] = Math.sqrt(d_s[k] / i_N);
}
}
private HashSet<String> getLexica(DepTree tree) {
HashSet<String> set = new HashSet<>();
addDepLexica(tree, set);
return set;
}
protected void addDepLexica(DepTree tree, HashSet<String> set) {
DepNode node, head;
for (int i = 1; i < tree.size(); i++) {
node = tree.get(i);
if (node.headId < 0) {
continue;
}
addDep1gramLexica(set, node, "<", node.deprel);
if (node.headId == 0) {
continue;
}
head = tree.get(node.headId);
addDep1gramLexica(set, head, ">", node.deprel);
if (node.id < head.id) {
addDep2gramLexica(set, node, head, "<", "");
addDep2gramLexica(set, node, head, "<", node.deprel);
} else {
addDep2gramLexica(set, head, node, ">", "");
addDep2gramLexica(set, head, node, ">", node.deprel);
}
if (head.headId < 0) {
continue;
}
addDep3gramLexica(set, node, head, tree.get(head.headId));
}
}
private void addDep1gramLexica(HashSet<String> set, DepNode node, String dir, String deprel) {
set.add(node.lemma + dir + deprel);
set.add(node.pos + dir + deprel);
}
private void addDep2gramLexica(HashSet<String> set, DepNode prev, DepNode next, String dir, String deprel) {
String label = dir + deprel;
set.add(prev.lemma + "_" + next.lemma + label);
set.add(prev.lemma + "_" + next.pos + label);
set.add(prev.pos + "_" + next.lemma + label);
set.add(prev.pos + "_" + next.pos + label);
}
private void addDep3gramLexica(HashSet<String> set, DepNode node, DepNode head, DepNode grandHead) {
set.add(node.pos + "_" + head.pos + "_" + grandHead.pos);
}
private int[] getVector(DepTree tree) {
IntOpenHashSet set = new IntOpenHashSet();
int d;
for (String key : getLexica(tree)) {
if ((d = m_lexica.get(key)) > 0) {
set.add(d - 1);
}
}
int[] vector = set.toArray();
Arrays.sort(vector);
return vector;
}
public void cluster() {
double prevScore = 0, currScore = expectation();
System.out.println("- Score : " + currScore);
for (int i = 0; i < MAX_ITER && Math.abs(currScore - prevScore) > STOP; i++) {
System.out.println("\nIteration: " + i);
prevScore = currScore;
System.out.println("- M-step");
maximization();
System.out.println("- E-step");
currScore = expectation();
System.out.println("- Score : " + currScore);
}
Random rand = new Random();
for (int i = 0; i < 10; i++) {
System.out.println(Arrays.toString(p_kn[rand.nextInt(i_N)]));
}
}
private double expectation() {
double[] memberships;
double sum, score = 0;
int k, n;
for (n = 0; n < i_N; n++) {
memberships = getMemberships(a_vectors[n]);
sum = 0;
for (k = 0; k < i_K; k++) {
sum += memberships[k];
}
for (k = 0; k < i_K; k++) {
p_kn[n][k] = memberships[k] / sum;
}
score += Math.log(sum);
}
return score;
}
private double[] getMemberships(int[] vector) {
double[] memberships = new double[i_K];
double[] dist = getCosineDistances(vector);
double stdev;
for (int k = 0; k < i_K; k++) {
stdev = d_s[k];
memberships[k] = p_k[k] * (1 / (SQRT_2PI * stdev)) * Math.exp(-0.5 * dist[k] * (stdev * stdev));
}
return memberships;
}
private void maximization() {
maximizePriors();
maximizeMeans();
maximizeStdDevs();
for (int k = 0; k < i_K; k++) {
p_k[k] /= i_N;
}
}
private void maximizePriors() {
int n, k;
Arrays.fill(p_k, 0);
for (n = 0; n < i_N; n++) {
for (k = 0; k < i_K; k++) {
p_k[k] += p_kn[n][k];
}
}
}
private void maximizeMeans() {
int n, k, d;
double m;
for (d = 0; d < i_D; d++) {
Arrays.fill(d_m[d], 0);
}
for (n = 0; n < i_N; n++) {
for (int idx : a_vectors[n]) {
for (k = 0; k < i_K; k++) {
d_m[idx][k] += p_kn[n][k];
}
}
}
Arrays.fill(d_mScalars, 0);
for (d = 0; d < i_D; d++) {
for (k = 0; k < i_K; k++) {
m = d_m[d][k] / p_k[k];
d_m[d][k] = m;
d_mScalars[k] += m * m;
}
}
for (k = 0; k < i_K; k++) {
d_mScalars[k] = Math.sqrt(d_mScalars[k]);
}
}
private void maximizeStdDevs() {
double[] dist;
int n, k;
Arrays.fill(d_s, 0);
for (n = 0; n < i_N; n++) {
dist = getCosineDistances(a_vectors[n]);
for (k = 0; k < i_K; k++) {
d_s[k] += p_kn[n][k] * dist[k];
}
}
for (k = 0; k < i_K; k++) {
d_s[k] = Math.sqrt(d_s[k] / p_k[k]);
}
}
private double[] getCosineDistances(int[] vector) {
double[] dots = new double[i_K];
double[] mean;
double scalar = Math.sqrt(vector.length);
int k;
for (int d : vector) {
mean = d_m[d];
for (k = 0; k < i_K; k++) {
dots[k] += mean[k];
}
}
for (k = 0; k < i_K; k++) {
dots[k] = 1 - (dots[k] / (d_mScalars[k] * scalar));
}
return dots;
}
static public void main(String[] args) {
String trainFile = args[0];
int K = Integer.parseInt(args[1]);
DepGMCluster depGMCluster = new DepGMCluster(trainFile, K);
}
}