package com.spbsu.direct.gen; import com.spbsu.commons.func.Action; import com.spbsu.commons.io.StreamTools; import com.spbsu.commons.io.codec.seq.Dictionary; import com.spbsu.commons.math.io.Vec2CharSequenceConverter; import com.spbsu.commons.math.vectors.Vec; import com.spbsu.commons.math.vectors.VecTools; import com.spbsu.commons.math.vectors.impl.vectors.ArrayVec; import com.spbsu.commons.random.FastRandom; import com.spbsu.commons.seq.*; import com.spbsu.commons.util.ArrayTools; import com.spbsu.direct.BroadMatch; import gnu.trove.list.TIntList; import gnu.trove.list.array.TDoubleArrayList; import gnu.trove.map.hash.TObjectDoubleHashMap; import gnu.trove.procedure.TIntDoubleProcedure; import java.io.IOException; import java.io.Writer; import static com.spbsu.commons.math.vectors.VecTools.l1; import static java.lang.Double.max; import static java.lang.Math.exp; import static java.lang.Math.log; /** * User: solar * Date: 12.11.15 * Time: 11:33 */ public class SimpleGenerativeModel { public static final String EMPTY_ID = "##EMPTY##"; private final WordGenProbabilityProvider[] providers; private final Dictionary<CharSeq> dict; private final FastRandom rng = new FastRandom(0); public static final int GIBBS_COUNT = 10; public SimpleGenerativeModel(Dictionary<CharSeq> dict, TIntList freqsLA) { this.dict = dict; this.providers = new WordGenProbabilityProvider[dict.size() + 1]; this.freqs = freqsLA; this.totalFreq = freqsLA.sum(); } public void loadStatistics(String fileName) throws IOException { for (int i = 0; i < providers.length; i++) { providers[i] = new WordGenProbabilityProvider(dict.size(), i); } final Vec2CharSequenceConverter converter = new Vec2CharSequenceConverter(); CharSeqTools.processLines(StreamTools.openTextFile(fileName), (Action<CharSequence>) sequence -> { final CharSequence[] split = CharSeqTools.split(sequence, '\t'); final WordGenProbabilityProvider provider; if (!split[0].equals(EMPTY_ID)) { final CharSequence[] parts = CharSeqTools.split(split[0].subSequence(1, split[0].length() - 1), ", "); final SeqBuilder<CharSeq> builder = new ArraySeqBuilder<>(CharSeq.class); for (final CharSequence part : parts) { builder.add(CharSeq.create(part.toString())); } final int index1 = dict.parse(builder.build()).intAt(0); if (index1 < 0) return; provider = providers[index1]; } else provider = providers[providers.length - 1]; final Vec vec = converter.convertFrom(split[1]); provider.beta = VecTools.copySparse(vec); // optimize storage space }); double totalBigramFreq = 0; for(int i = 0; i < providers.length; i++) { totalBigramFreq += l1(providers[i].beta); } for(int i = 0; i < providers.length; i++) { providers[i].probab = (l1(providers[i].beta) + 1) / (totalBigramFreq + providers.length); } for(int i = 0; i < providers.length; i++) { providers[i].init(providers, dict); } } private int index = 0; private final TDoubleArrayList window = new TDoubleArrayList(1000); private double windowSum = 0; public double totalFreq = 0; public final TIntList freqs; public void processSeq(IntSeq prevQSeq) { for (int i = 0; i < prevQSeq.length(); i++) { int symbol = prevQSeq.intAt(i); if (freqs.size() < symbol) freqs.fill(freqs.size(), symbol + 1, 0); freqs.set(symbol, freqs.get(symbol) + 1); totalFreq++; } } public void processGeneration(IntSeq prevQSeq, IntSeq currentQSeq, double alpha) { if (prevQSeq.length() * currentQSeq.length() > 10) // too many variants of bipartite graph return; final int variantsCount = 1 << (prevQSeq.length() * currentQSeq.length()); final int mask = (1 << currentQSeq.length()) - 1; int bestVariant; double bestLogProBab; { // expectation final Vec weights = new ArrayVec(variantsCount); for (int p = 0; p < variantsCount; p++) { double variantLogProBab = 0; { int variant = p; int generated = 0; for (int i = 0; i < prevQSeq.length(); i++, variant >>= currentQSeq.length()) { final int fragment = variant & mask; generated |= fragment; final int index = prevQSeq.intAt(i); if (index < 0) continue; variantLogProBab += providers[index].logP(fragment, currentQSeq); } for (int i = 0; i < currentQSeq.length(); i++, generated >>= 1) { if ((generated & 1) == 1) continue; variantLogProBab += log(freqs.get(currentQSeq.intAt(i)) + 1.) - log(totalFreq + freqs.size()); } } // Gibbs weights.set(p, variantLogProBab); // { // EM // if (variantLogProBab > bestLogProBab) { // bestLogProBab = variantLogProBab; // bestVariant = p; // } // } } { // Gibbs double sum = 0; double normalizer = weights.get(0); for (int i = 0; i < variantsCount; i++) { weights.set(i, exp(weights.get(i) - normalizer)); sum += weights.get(i); } for (int i = 0; i < GIBBS_COUNT; i++) { bestVariant = rng.nextSimple(weights, sum); bestLogProBab = log(weights.get(bestVariant)) + normalizer; gradientStep(prevQSeq, currentQSeq, alpha / GIBBS_COUNT, mask, bestVariant, bestLogProBab); } } { // EM // gradientStep(prevQSeq, currentQSeq, alpha, mask, bestVariant, bestLogProBab); } } index++; } private void gradientStep(IntSeq prevQSeq, IntSeq currentQSeq, double alpha, int mask, int bestVariant, double bestLogProBab) { // maximization gradient descent step int generated = 0; windowSum += bestLogProBab; window.add(bestLogProBab); final double remove; if (window.size() > 100000) { remove = window.removeAt(0); windowSum -= remove; } boolean debug = BroadMatch.debug && (index % 100000 == 0); if (debug) System.out.print(windowSum / window.size() + "\t" + "\n"); // debug = false; if (debug) System.out.println(prevQSeq + " -> " + currentQSeq + " " + bestLogProBab); double newProb = 0; for (int i = 0; i < prevQSeq.length(); i++, bestVariant >>= currentQSeq.length()) { final int fragment = bestVariant & mask; generated |= fragment; final int windex = prevQSeq.intAt(i); if (windex < 0) continue; providers[windex].update(fragment, currentQSeq, alpha, dict, debug); newProb += providers[windex].logP(fragment, currentQSeq); } if (debug) System.out.print(EMPTY_ID + " ->"); for (int i = 0; i < currentQSeq.length(); i++, generated >>= 1) { if ((generated & 1) == 1) continue; final int windex = currentQSeq.intAt(i); if (debug) System.out.print(dict.get(windex)); newProb += log(freqs.get(windex) + 1.) - log(totalFreq + freqs.size()); } if (debug) System.out.println("\nNew probability: " + newProb); } public void print(Writer out, boolean limit) { for (int i = 0; i < providers.length; i++) { final WordGenProbabilityProvider provider = providers[i]; provider.print(dict, out, limit); } } public void load(String inputFile) throws IOException { CharSeqTools.processLines(StreamTools.openTextFile(inputFile), new Action<CharSequence>() { int index = 0; final StringBuilder builder = new StringBuilder(); public void invoke(CharSequence line) { if (line.equals("}")) { WordGenProbabilityProvider provider = new WordGenProbabilityProvider(builder.toString(), dict); providers[provider.aindex] = provider; builder.delete(0, builder.length()); } else builder.append(line); } }); } public String findTheBestExpansion(ArraySeq<CharSeq> arg) { final StringBuilder builder = new StringBuilder(); final TObjectDoubleHashMap<Seq<CharSeq>> expansionScores = new TObjectDoubleHashMap<>(); final double[] normalize = new double[1]; dict.visitVariants(arg, freqs, totalFreq, (seq, probab) -> { if (probab < -100) return true; for (int i = 0; i < seq.length(); i++) { if (i > 0) builder.append(" "); final int symIndex = seq.intAt(i); visitExpVariants(symIndex, (a, b) -> { // System.out.println(dict.get(a).toString() + " " + b); final double symProbab = b * exp(probab); // double logProbab = log(symProbab); // if (logProbab < 1e-20) // return false; normalize[0] = max(exp(probab), normalize[0]); expansionScores.adjustOrPutValue(dict.get(a), symProbab, symProbab); return true; }, 1.); // builder.append(dict.get(symIndex)); } // builder.append("\t").append(probab).append("\n"); return true; }); //noinspection unchecked final Seq<CharSeq>[] keys = expansionScores.keys(new Seq[expansionScores.size()]); final double[] scores = expansionScores.values(); final int[] order = ArrayTools.sequence(0, keys.length); ArrayTools.parallelSort(scores, order); for (int i = order.length - 1; i >= 0; i--) { final double prob = scores[i] / normalize[0]; if (prob < 1e-7) break; builder.append(keys[order[i]].toString()).append(" -> ").append(prob).append("\n"); } return builder.toString(); } private void visitExpVariants(final int index, TIntDoubleProcedure todo, double genProb) { if (genProb < 1e-10 || index < 0) return; WordGenProbabilityProvider provider = providers[index]; final Seq<CharSeq> phrase = dict.get(index); // System.out.println("Expanding: " + phrase); if (provider != null) { provider.visitVariants((symIndex, symProb) -> { final double currentGenProb = genProb * symProb; final WordGenProbabilityProvider symProvider = providers[symIndex]; if (symProvider != null && symProvider.isMeaningful(index)) { visitExpVariants(symIndex, todo, currentGenProb); todo.execute(symIndex, currentGenProb); } return true; }); } } }