package syntaxLearner; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.Map; import java.util.Set; import java.util.SortedSet; import java.util.TreeMap; import java.util.TreeSet; import syntaxLearner.UI.Console; import syntaxLearner.corpus.*;

/**
 * Drives the word-clustering algorithm: seeds {@code NUMBER_OF_CLUSTERS} clusters with the
 * most frequent words, then iteratively moves common words out of the "ground" cluster
 * (the pool of still-unclustered words) into the cluster whose context distribution is
 * closest in KL divergence, merging clusters that become nearly identical.
 *
 * NOTE(review): not thread-safe — all state is mutated freely with no synchronization.
 */
public class Learner {

    /* Parameters */

    // If the KL distance between two clusters falls below this, they are merged.
    final double IDENTITY_EPSILON;
    // The algorithm halts when this fraction of non-rare words has been clustered.
    final float HALTING_RATIO = 0.9f;
    // A word is "rare" if it appears fewer than this many times in the corpus.
    final int RARE_WORD_THRESHOLD;
    // Number of non-ground clusters seeded in prepareClusters().
    final int NUMBER_OF_CLUSTERS;

    /* Class Variables */

    // Next cluster ID to hand out; see newClusterID(). Wraps at Short.MAX_VALUE (unchecked).
    private short learnerID = 0;
    private Corpus corpus;
    // All clusters keyed by ID; the ground cluster is stored under key -1.
    private Map<Short, Cluster> clusters;
    // Holds every word not yet assigned to a real cluster.
    private Cluster groundCluster;
    // IDs of clusters whose cached distributions changed this iteration; wiped each iteration.
    private Set<Short> updatedClusters;
    private int numOfRareWords;
    // parents[wordIndex] = ID of the cluster currently owning that word (-1 = ground).
    // Sized to the whole vocabulary, so it is very large.
    protected short[] parents;
    private int iterationCounter = 0;
    private Recorder rec;
    // Helper iterable set of all (cluster, cluster) context pairs; saves time when
    // calculating distances. Built lazily by getClusterContexts().
    private Set<ClusterContext> clusterContexts;
    // Iterator over the frequency-sorted word hierarchy; shared state between
    // prepareClusters() and unifyCloseClusters() — it "remembers" the next most
    // frequent word that has not yet seeded a cluster.
    Iterator<Integer> backupIterator;

    /**
     * @param clusters  desired number of clusters; values <= 1 fall back to 50
     * @param threshold rare-word frequency threshold; values <= 1 fall back to 50
     * @param epsilon   merge distance; values <= 0 fall back to 0.001
     */
    public Learner(int clusters, int threshold, double epsilon){
        IDENTITY_EPSILON = (epsilon<=0 ? 0.001 : epsilon);
        NUMBER_OF_CLUSTERS = (clusters<=1 ? 50 : clusters);
        RARE_WORD_THRESHOLD = (threshold<=1?
            50: threshold);
    }

    /* Initializer */

    /**
     * Attaches a corpus and resets all clustering state. Must be called before learn().
     * The ground cluster is registered under the reserved ID -1.
     */
    public void setCorpus(Corpus c){
        this.corpus = c;
        clusters = new HashMap<Short,Cluster>();
        groundCluster = new Cluster(corpus.getVocabulary(), this, true);
        updatedClusters = new HashSet<Short>();
        clusters.put((short)(-1), groundCluster);
        clusterContexts = new TreeSet<ClusterContext>();
    }

    /**
     * Step 1 of the Algorithm: seed NUMBER_OF_CLUSTERS singleton clusters with the most
     * frequent words (skipping the start/end sentence symbols), dump every remaining word
     * into the ground cluster, and build the cluster-context helper set.
     */
    private void prepareClusters(){
        Vocabulary vocab = corpus.getVocabulary();
        numOfRareWords = vocab.countWordsBelowThreshold(RARE_WORD_THRESHOLD);
        Console.line("Words: "+vocab.getNumOfWords());
        Console.line("Rare words: "+numOfRareWords);
        Console.line("Words to scan: "+(vocab.getNumOfWords()-numOfRareWords));
        // Presumably sorted most-frequent-first — TODO confirm against Vocabulary.getWordHierarchy().
        SortedSet<Integer> wordHeirarchy = vocab.getWordHierarchy();
        // These two work in unison:
        Iterator<Integer> iter = wordHeirarchy.iterator();
        backupIterator = wordHeirarchy.iterator();
        for (int i=0;i<NUMBER_OF_CLUSTERS;i++){
            // Advance in unison, so backupIterator will remember the next word in the hierarchy.
            int l = iter.next();
            backupIterator.next();
            // Make sure we're not clustering a start / end symbol.
            Word w = vocab.getWord(l);
            if ((!(w.equals(vocab.START_SYMBOL))) && (!(w.equals(vocab.END_SYMBOL)))) {
                Cluster c = new Cluster(vocab, this);
                c.add(l);
                Console.line(String.format("Created cluster #%1$2s with stem: \"%2$-1s\"", c.ID,w.name));
                clusters.put(c.getID(), c);
            } else {
                // Otherwise, skip a step without changing the placemarker.
                // NOTE(review): assumes more than NUMBER_OF_CLUSTERS words exist, else
                // iter.next() above will throw NoSuchElementException.
                i--;
                continue;
            }
        }
        // Everything else starts out unclustered.
        while (iter.hasNext()){
            int l = iter.next();
            groundCluster.add(l);
        }
        // Associate START_SYMBOL and END_SYMBOL with the ground cluster permanently.
        groundCluster.add(vocab.START_SYMBOL.ID);
        groundCluster.add(vocab.END_SYMBOL.ID);
        // Now build a helper set of all possible coordinates (cluster-ID pairs).
        getClusterContexts();
    }

    /**
     * Check to see if any distance between two clusters falls below epsilon
     * (defined above), then update parenthood, add everything to c1,
     * and reset c2, re-seeding c2 with the next unclustered common word.
     *
     * @return false to signal the main loop to halt (no non-rare word is left to
     *         re-seed a merged-away cluster); true otherwise.
     */
    private boolean unifyCloseClusters() {
        boolean unionOccured = false;
        for (Cluster c1 : clusters.values()){
            for (Cluster c2 : clusters.values()){
                double dist;
                // Skip self-pairs and the ground cluster (ID -1); merge only when the
                // distance is under epsilon and c2 was not itself just (re)created this pass.
                if ((c1.getID() != c2.getID())&& (c1.getID()!=-1) && (c2.getID()!=-1)
                        && ((dist = distance(c1,c2))<IDENTITY_EPSILON) && (!c2.isNew())){
                    // NOTE: 'add(cluster)' also transfers parenthood.
                    // Find the next word in the hierarchy that hasn't been clustered yet;
                    // if none remains, return false to halt.
                    boolean isNextClustered = true;
                    int next=-1; // fail-fast sentinel
                    while (backupIterator.hasNext() && (isNextClustered)){
                        next = backupIterator.next();
                        isNextClustered = (parents[next] != -1);
                    }
                    // NOTE(review): if backupIterator is already exhausted, next stays -1
                    // and getWord(-1) is called — verify Vocabulary tolerates that.
                    Word w = corpus.getVocabulary().getWord(next);
                    if (w.frequency < RARE_WORD_THRESHOLD) {
                        // The word is rare; every remaining word is rarer still — halt.
                        return false;
                    } else {
                        c1.add(c2);
                        c2.reset();
                        c2.add(next);
                        c2.setNew(true); // protect the reborn cluster from re-merging this pass
                        groundCluster.remove(next);
                        unionOccured=true;
                        String message = String.format("Cluster #%1$2s merged into #%2$2s at distance [%3$-8g]. " +
                            "Cluster #%1$2s recreated with stem: \"%4$1s\"", c2.ID, c1.ID, dist, w.name);
                        Console.line(message);
                        //updatedClusters.clear(); TODO See if necessary. Now only doing this at the end.
                        //corpus.getVocabulary().purgeUpdatedWords();
                    }
                }
            }
            // OPTIMIZATION: if something was pushed into c1, invalidate cached distances now,
            // so later pairs in this pass see fresh distributions. TODO check if this works.
            if (unionOccured) {
                updatedClusters.clear();
                unionOccured = false;
            }
        }
        // Clear the "new" protection flag for the next unification pass.
        for (Cluster c : clusters.values()){
            c.setNew(false);
        }
        return true;
    }

    /**
     * MAIN ALGORITHM METHOD.
     * Loops until HALTING_RATIO of the non-rare vocabulary is clustered, no insertion
     * happens, or unifyCloseClusters() signals that only rare words remain. Each
     * iteration: merge near-identical clusters, recompute every unclustered common
     * word's distance to every cluster, then move the best candidates (a throttled
     * number per cluster) into their closest cluster.
     */
    private void clusterCommonWords(){
        /* Safety Assertions */
        assert (HALTING_RATIO > 0 && HALTING_RATIO < 1);
        assert (NUMBER_OF_CLUSTERS > 0);
        /* Step 1 */
        corpus.buildDB();
        parents = new short[corpus.getVocabulary().getNumOfWords()+1];
        prepareClusters();
        int size = corpus.getVocabulary().getNumOfWords();
        /*
         * Per-cluster candidate lists, sorted by distance:
         * clusterID -> TreeMap<distance, wordIndex>.
         * NOTE(review): a TreeMap keyed on distance silently drops a candidate when two
         * words tie at exactly the same distance; and the lists are never cleared between
         * iterations, so stale candidates can carry over — confirm both are intended.
         */
        TreeMap<Integer,TreeMap<Double,Integer>> candidateLists = new TreeMap<Integer,TreeMap<Double,Integer>>();
        for (int i : clusters.keySet()){
            TreeMap<Double,Integer> closestValues = new TreeMap<Double,Integer>();
            candidateLists.put(i, closestValues);
        }
        Console.line("Algorithm launched.");
        rec.recordCorpusData(corpus, this);
        double percentageTracker; // fraction of non-rare words clustered so far
        /* Here be iterations */
        mainLoop:
        while ((percentageTracker= (1.0*size-groundCluster.wordCount) /(size-numOfRareWords)) < HALTING_RATIO){
            iterationCounter++;
            // Unification is done before iterations begin, not after they end.
            // NOTE(review): iterationCounter>0 is always true here since it was just incremented.
            if (iterationCounter>0) {
                Console.line("Unifying clusters.");
                if(!unifyCloseClusters()){
                    // If there are no words left, we're done.
                    break mainLoop;
                }
            }
            // Automatically deletes the update list if necessary.
            updatedClusters.clear();
            corpus.getVocabulary().purgeUpdatedWords();
            // List sortable words. TODO use this list for the algorithm itself, instead of just for the recording.
            StringBuilder sb = new StringBuilder();
            sb.append("[ ");
            for (int index: groundCluster.words){
                if (corpus.getVocabulary().getWord(index).frequency >= RARE_WORD_THRESHOLD){
                    sb.append(Integer.toString(index)+", ");
                }
            }
            // Drop the trailing comma.
            // NOTE(review): if no common word remains this deletes the "[" instead — confirm unreachable.
            sb.deleteCharAt(sb.length()-2);
            sb.append("]");
            // Update cluster records.
            rec.recordNewIteration(sb.toString());
            for (Cluster c : clusters.values()){
                if (c.ID>=0){
                    rec.recordClusterInfo(c);
                }
            }
            Console.line("Calculating distances");
            //int displayCounter = 0;
            for (int index: groundCluster.words){
                Word w = corpus.getVocabulary().getWord(index);
                // Data structure: <distance, cluster ID>, sorted so firstEntry() is the closest.
                TreeMap <Double, Short> distances = new TreeMap<Double, Short>();
                if (w.frequency >= RARE_WORD_THRESHOLD){
                    for (Cluster c : clusters.values()){
                        // If not the ground cluster, calculate distance and store.
                        if (c.getID()!=-1) {
                            distances.put(distance(w,c),c.getID());
                        }
                    }
                }
                if (!distances.isEmpty()){
                    // Build the word's context-distribution vector, for the record only.
                    StringBuilder distributionVector = new StringBuilder();
                    distributionVector.append("[");
                    Iterator<ClusterContext> iter = clusterContexts.iterator();
                    while (iter.hasNext()){
                        distributionVector.append(String.format("%1$.5f", w.clusterDistribution(iter.next())));
                        if (iter.hasNext()){
                            distributionVector.append(", ");
                        }
                    }
                    distributionVector.append("]");
                    rec.recordWordInfo(w, distances, distributionVector);
                    // Queue this word as a candidate for its closest cluster.
                    int closestCluster = distances.firstEntry().getValue();
                    double distanceToClosestCluster = distances.firstEntry().getKey();
                    candidateLists.get(closestCluster).put(distanceToClosestCluster,w.ID);
                    // Peek at the runner-up to report the confidence gap.
                    // NOTE(review): assumes at least two clusters, else firstEntry() is null after the remove.
                    distances.remove(distances.firstEntry().getKey());
                    double distanceGap = distances.firstEntry().getKey()- distanceToClosestCluster;
                    String message = String.format("\"%1$-15s\" -> %2$-2s [%3$-8g] Count: %7$-8s ; Next: %4$2s [%5$-8g] ; Gap: [%6$-8g]",
                        w.name, closestCluster, distanceToClosestCluster,
                        distances.firstEntry().getValue(), distances.firstEntry().getKey(), distanceGap, w.frequency);
                    Console.line(message);
                    //Console.text((++displayCounter)+" Calculated\r");
                }
            }
            Console.line(" ");
            // Throttles the insertion. TODO find some insertion strategy better than this.
            // NOTE(review): insertionLimit is computed but never used — the loop below
            // uses (5 + iterationCounter*2) instead; likely a leftover.
            int insertionLimit = iterationCounter<10?
                iterationCounter : 10;
            boolean insertionOccured=false;
            // Move the best few candidates of each cluster out of their current parent
            // (normally the ground cluster) and into that cluster.
            addToClusterByDistance:
            for (Map.Entry<Integer, TreeMap<Double,Integer>> candidateList : candidateLists.entrySet()){
                Iterator<Integer> iter = candidateList.getValue().values().iterator();
                for (int i=0;i<(5+iterationCounter*2);i++){
                    if (!iter.hasNext()) {
                        continue addToClusterByDistance;
                    } else {
                        int next = iter.next();
                        Word w = corpus.getVocabulary().getWord(next);
                        Cluster parent = w.getParent();
                        parent.remove((int)w.ID);
                        int clusterKey = candidateList.getKey();
                        clusters.get((short)clusterKey).add(w.ID);
                        insertionOccured=true;
                    }
                }
            }
            Console.line(" ");
            //TODO Complete Algorithm
            // Report cluster contents.
            Vocabulary vocab = corpus.getVocabulary();
            for (Cluster c : clusters.values()){
                if (c.ID!=-1){
                    Console.line("Cluster "+c.ID+" :\n********************");
                    for(int l : c.words){
                        Console.text(vocab.getWord(l).name+", ");
                    }
                    Console.line("");
                }
            }
            double newPercentage = 1.0*(size-groundCluster.wordCount)/(size-numOfRareWords);
            Console.text("Iteration "+iterationCounter+" complete. ");
            Console.text(Double.toString(100.0*newPercentage));
            Console.text("% clustered.");
            // If no change, halt.
            if (!insertionOccured){ break mainLoop;}
        }
    }

    /** Hands out the next cluster ID (post-incremented; no overflow check). */
    public short newClusterID() {
        return learnerID++;
    }

    /** @return whether c's cached distribution was already marked fresh this iteration. */
    protected boolean isClusterUpdated(Cluster c){
        return updatedClusters.contains(c.getID());
    }

    /**
     * Marks c's cached distribution as fresh. Updates are wiped every iteration,
     * so updating these is important.
     * @param c the cluster whose distribution was just recomputed
     */
    protected void registerClusterUpdate(Cluster c){
        updatedClusters.add(c.getID());
    }

    /**
     * @param a A cluster
     * @param b Another cluster
     * @return the KLD: D(a||b)
     * NOTE(review): Math.abs on a KLD sum is unusual (true KLD is non-negative), and a
     * zero probability in either distribution yields NaN/Infinity here — confirm the
     * distributions are smoothed upstream. The +1 offset maps cluster ID -1 (ground) to
     * array index 0.
     */
    private double distance (Cluster a, Cluster b){
        double sum = 0;
        for (ClusterContext cc : clusterContexts){
            double aDist = a.clusterDistribution()[(int) cc.type1+1][(int) cc.type2+1];
            double bDist = b.clusterDistribution()[(int) cc.type1+1][(int) cc.type2+1];
            sum += aDist * Math.log( aDist / bDist );
        }
        sum = Math.abs(sum);
        return sum;
    }

    /**
     * @param w A word
     * @param c A cluster
     * @return The KLD: D(w||c)
     * NOTE(review): same smoothing/abs caveats as distance(Cluster, Cluster).
     */
    private double distance (Word w, Cluster c){
        double sum = 0;
        for (ClusterContext cc : clusterContexts){
            double wDist = w.clusterDistribution(cc);
            double cDist = c.clusterDistribution()[(int) cc.type1+1][(int) cc.type2+1];
            sum +=wDist * Math.log( wDist / cDist );
        }
        sum = Math.abs(sum);
        return sum;
    }

    public int getNumOfClusters() {
        return NUMBER_OF_CLUSTERS;
    }

    /** @return the cluster with this ID; any negative index maps to the ground cluster. */
    public Cluster getCluster(short index){
        if (index < 0){
            return groundCluster;
        } else {
            return clusters.get(index);
        }
    }

    /**
     * Lazily builds (on first call) and returns the set of all ordered cluster-ID pairs,
     * used as iteration coordinates by the distance methods. Built once from the cluster
     * IDs present at call time — clusters created later are not added.
     */
    public Set<ClusterContext> getClusterContexts() {
        if (clusterContexts.isEmpty()){
            for (short i : clusters.keySet()) {
                for (short j : clusters.keySet()){
                    clusterContexts.add(new ClusterContext(i, j));
                }
            }
        }
        return clusterContexts;
    }

    /** Public entry point; requires setCorpus() and setRecorder() to have been called. */
    public void learn(){
        clusterCommonWords();
    }

    /** Records that word i now belongs to cluster id (-1 = ground). */
    public void setParent(int i, short id) {
        this.parents[i]=id;
    }

    public short getParent(int i){
        return this.parents[i];
    }

    public int getIterationCount(){
        return iterationCounter;
    }

    public void setRecorder (Recorder rec){
        this.rec = rec;
    }
}