/* This file is part of the Joshua Machine Translation System.
 * 
 * Joshua is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free
 * Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
 * MA 02111-1307 USA
 */
package joshua.corpus.lexprob;

import java.io.File;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;

import joshua.corpus.Corpus;
import joshua.corpus.CorpusArray;
import joshua.corpus.MatchedHierarchicalPhrases;
import joshua.corpus.alignment.Alignments;
import joshua.corpus.suffix_array.BasicPhrase;
import joshua.corpus.suffix_array.HierarchicalPhrase;
import joshua.corpus.suffix_array.SuffixArray;
import joshua.corpus.suffix_array.SuffixArrayFactory;
import joshua.corpus.suffix_array.Suffixes;
import joshua.corpus.vocab.SymbolTable;
import joshua.util.Cache;
import joshua.util.Pair;

/**
 * Represents lexical probability distributions in both directions.
 * <p>
 * This class calculates the probabilities by sampling directly
 * from a parallel corpus.
 * 
 * @author Lane Schwartz
 * @version $LastChangedDate:2008-11-13 13:13:31 -0600 (Thu, 13 Nov 2008) $
 * @deprecated
 */
public class SampledLexProbs extends AbstractLexProbs {

    /** Logger for this class. */
    private static final Logger logger =
        Logger.getLogger(SampledLexProbs.class.getName());

    private final Cache<Integer,Map<Integer,Float>> sourceGivenTarget;
    private final Cache<Integer,Map<Integer,Float>> targetGivenSource;

    private final Suffixes sourceSuffixArray;
    private final Suffixes targetSuffixArray;

    /** Corpus array representing the target language corpus. */
    final Corpus targetCorpus;

    /**
     * Represents alignments between words in the source corpus
     * and the target corpus.
     */
    private final Alignments alignments;

    private final SymbolTable sourceVocab;
    private final SymbolTable targetVocab;

    private final float floorProbability;

    /**
     * When calculating probabilities, if a probability is less
     * than this value, do not explicitly store it.
     */
    private final float thresholdProbability;

    private final int sampleSize;
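    // Note on the two constants above: thresholdProbability and
    // floorProbability are currently assigned the same value,
    // 1/(100 * sampleSize), in the constructor below. As a made-up
    // worked example (not a value from any actual configuration):
    // with sampleSize == 300, both come out to 1/30000, roughly
    // 3.3e-5, so any word pair whose sampled relative frequency is
    // at or below that value is dropped from the cached map, and
    // later queries for it report the floor instead.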
    public SampledLexProbs(int sampleSize, Suffixes sourceSuffixArray,
            Suffixes targetSuffixArray, Alignments alignments,
            int cacheCapacity, boolean precalculate) {

        this.sampleSize = sampleSize;
        this.sourceSuffixArray = sourceSuffixArray;
        this.targetSuffixArray = targetSuffixArray;
        this.targetCorpus = targetSuffixArray.getCorpus();
        this.alignments = alignments;
        this.sourceVocab = sourceSuffixArray.getVocabulary();
        this.targetVocab = targetSuffixArray.getVocabulary();

        // Use float arithmetic here: the int expression (sampleSize*100)
        // overflows when sampleSize is large (e.g. Integer.MAX_VALUE, as
        // passed by getSampledLexProbs), yielding a negative threshold
        // and floor.
        //TODO come up with a good value for this
        this.thresholdProbability = 1.0f / (100.0f * sampleSize);
        this.floorProbability = 1.0f / (100.0f * sampleSize);

        this.sourceGivenTarget = new Cache<Integer,Map<Integer,Float>>(cacheCapacity);
        this.targetGivenSource = new Cache<Integer,Map<Integer,Float>>(cacheCapacity);

        if (precalculate) {
            for (int sourceWord : sourceVocab.getAllIDs()) {
                calculateTargetGivenSource(sourceWord);
            }
            for (int targetWord : targetVocab.getAllIDs()) {
                calculateSourceGivenTarget(targetWord);
            }
        }
    }

    @Override
    public String toString() {
        StringBuilder s = new StringBuilder();

        s.append("SampledLexProbs size information:");
        s.append('\n');

        s.append(sourceGivenTarget.size() + " target sides in sourceGivenTarget");
        s.append('\n');

        int count = 0;
        for (Map<Integer,Float> entry : sourceGivenTarget.values()) {
            count += entry.size();
        }
        s.append(count + " source-target pairs in sourceGivenTarget");
        s.append('\n');

        s.append(targetGivenSource.size() + " source sides in targetGivenSource");
        s.append('\n');

        count = 0;
        for (Map<Integer,Float> entry : targetGivenSource.values()) {
            count += entry.size();
        }
        s.append(count + " target-source pairs in targetGivenSource");
        s.append('\n');

        return s.toString();
    }

    /**
     * Constructs a SampledLexProbs instance directly from corpus
     * and alignment strings, by way of temporary files.
     * <p>
     * This method is intended for unit testing.
     * 
     * @param sourceCorpusString source language corpus, one sentence per line
     * @param targetCorpusString target language corpus, one sentence per line
     * @param alignmentString    word alignments, one line per sentence pair
     * @return a SampledLexProbs instance backed by the given data
     * @throws IOException if a temporary file cannot be written
     */
    public static SampledLexProbs getSampledLexProbs(String sourceCorpusString,
            String targetCorpusString, String alignmentString) throws IOException {

        String sourceFileName;
        {
            File sourceFile = File.createTempFile("source", new Date().toString());
            PrintStream sourcePrintStream = new PrintStream(sourceFile);
            sourcePrintStream.println(sourceCorpusString);
            sourcePrintStream.close();
            sourceFileName = sourceFile.getAbsolutePath();
        }

        String targetFileName;
        {
            File targetFile = File.createTempFile("target", new Date().toString());
            PrintStream targetPrintStream = new PrintStream(targetFile);
            targetPrintStream.println(targetCorpusString);
            targetPrintStream.close();
            targetFileName = targetFile.getAbsolutePath();
        }

        String alignmentFileName;
        {
            File alignmentFile = File.createTempFile("alignment", new Date().toString());
            PrintStream alignmentPrintStream = new PrintStream(alignmentFile);
            alignmentPrintStream.println(alignmentString);
            alignmentPrintStream.close();
            alignmentFileName = alignmentFile.getAbsolutePath();
        }

        CorpusArray sourceCorpusArray =
            SuffixArrayFactory.createCorpusArray(sourceFileName);
        SuffixArray sourceSuffixArray =
            SuffixArrayFactory.createSuffixArray(sourceCorpusArray, SuffixArray.DEFAULT_CACHE_CAPACITY);

        CorpusArray targetCorpusArray =
            SuffixArrayFactory.createCorpusArray(targetFileName);
        SuffixArray targetSuffixArray =
            SuffixArrayFactory.createSuffixArray(targetCorpusArray, SuffixArray.DEFAULT_CACHE_CAPACITY);

        Alignments alignmentArray =
            SuffixArrayFactory.createAlignments(alignmentFileName, sourceSuffixArray, targetSuffixArray);

        return new SampledLexProbs(Integer.MAX_VALUE, sourceSuffixArray,
            targetSuffixArray, alignmentArray, Cache.DEFAULT_CAPACITY, false);
    }
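    // Illustrative usage of the factory method above, roughly as it
    // might appear in a unit test. The corpus and alignment strings
    // here are invented for this example (not taken from Joshua's
    // test suite), and the enclosing test method would need to
    // declare IOException:
    //
    //   SampledLexProbs lexProbs = SampledLexProbs.getSampledLexProbs(
    //       "der hund", "the dog", "0-0 1-1");
    //   float pSourceGivenTarget = lexProbs.sourceGivenTarget("der", "the");
    //   float pTargetGivenSource = lexProbs.targetGivenSource("dog", "hund");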
    /**
     * Calculates the lexical probability of a source word given
     * a target word.
     * <p>
     * If this information has not previously been stored, this
     * method calculates it.
     * 
     * @param sourceWord source language word ID
     * @param targetWord target language word ID
     * @return the lexical probability of the source word given
     *         the target word
     */
    public float sourceGivenTarget(Integer sourceWord, Integer targetWord) {

        if (logger.isLoggable(Level.FINE)) logger.fine(
            "Need to get source given target lexprob p(" + sourceVocab.getWord(sourceWord) +
            " | " + targetVocab.getWord(targetWord) +
            "); sourceWord ID == " + sourceWord + "; targetWord ID == " + targetWord);

        if (!sourceGivenTarget.containsKey(targetWord)) {
            calculateSourceGivenTarget(targetWord);
        }

        Map<Integer,Float> map = sourceGivenTarget.get(targetWord);

        if (map.containsKey(sourceWord)) {
            return map.get(sourceWord);
        } else {
            if (logger.isLoggable(Level.FINE)) logger.fine(
                "No source given target lexprob found for p(" + sourceVocab.getWord(sourceWord) +
                " | " + targetVocab.getWord(targetWord) +
                "); returning FLOOR_PROBABILITY " + floorProbability);
            return floorProbability;
        }
    }

    /**
     * Calculates the lexical probability of a target word given
     * a source word.
     * <p>
     * If this information has not previously been stored, this
     * method calculates it.
     * 
     * @param targetWord target language word ID
     * @param sourceWord source language word ID
     * @return the lexical probability of the target word given
     *         the source word
     */
    public float targetGivenSource(Integer targetWord, Integer sourceWord) {

        if (logger.isLoggable(Level.FINE)) logger.fine(
            "Need to get target given source lexprob p(" + targetVocab.getWord(targetWord) +
            " | " + sourceVocab.getWord(sourceWord) +
            "); sourceWord ID == " + sourceWord + "; targetWord ID == " + targetWord);

        if (!targetGivenSource.containsKey(sourceWord)) {
            calculateTargetGivenSource(sourceWord);
        }

        Map<Integer,Float> map = targetGivenSource.get(sourceWord);

        if (map.containsKey(targetWord)) {
            return map.get(targetWord);
        } else {
            if (logger.isLoggable(Level.FINE)) logger.fine(
                "No target given source lexprob found for p(" + targetVocab.getWord(targetWord) +
                " | " + sourceVocab.getWord(sourceWord) +
                "); returning FLOOR_PROBABILITY " + floorProbability +
                "; sourceWord ID == " + sourceWord + "; targetWord ID == " + targetWord);
            return floorProbability;
        }
    }

    /**
     * Calculates the lexical probability of a source word given
     * a target word, using the words' string representations.
     * 
     * @param sourceWord source language word
     * @param targetWord target language word
     * @return the lexical probability of the source word given
     *         the target word
     */
    public float sourceGivenTarget(String sourceWord, String targetWord) {
        int targetID = targetVocab.getID(targetWord);
        int sourceID = sourceVocab.getID(sourceWord);
        return sourceGivenTarget(sourceID, targetID);
    }

    /**
     * Calculates the lexical probability of a target word given
     * a source word, using the words' string representations.
     * 
     * @param targetWord target language word
     * @param sourceWord source language word
     * @return the lexical probability of the target word given
     *         the source word
     */
    public float targetGivenSource(String targetWord, String sourceWord) {
        if (logger.isLoggable(Level.FINER)) logger.finer(
            "Need to get target given source lexprob p(" + targetWord + " | " + sourceWord +
            "); sourceID==" + sourceVocab.getID(sourceWord) +
            "; targetID==" + targetVocab.getID(targetWord));
        int targetID = targetVocab.getID(targetWord);
        int sourceID = sourceVocab.getID(sourceWord);
        return targetGivenSource(targetID, sourceID);
    }
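    // Example of the lookup semantics above (the word IDs here are
    // hypothetical): suppose the sampled distribution for target ID
    // 75 stored an entry only for source IDs whose estimates exceeded
    // thresholdProbability, and source ID 14 was among them while
    // source ID 99 was not. Then:
    //
    //   float p1 = sourceGivenTarget(14, 75); // stored: returns the estimate
    //   float p2 = sourceGivenTarget(99, 75); // not stored: returns floorProbability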
    /**
     * Calculates the lexical translation probabilities (in
     * both directions) for a specific instance of a source
     * phrase in the corpus.
     * <p>
     * This method does NOT currently handle NULL-aligned points
     * according to Koehn et al. (2003). This may change in
     * future releases.
     * <p>
     * The problem arises when we need to calculate the
     * word-to-word lexical weights using the sourceGivenTarget
     * and targetGivenSource methods (the actual calculations
     * occur in calculateSourceGivenTarget and
     * calculateTargetGivenSource).
     * <p>
     * Say we want to calculate P(s14 | t75), where s14 is a
     * source word and t75 is a target word. We call
     * sourceGivenTarget and see that we haven't calculated the
     * map for P(? | t75), so we call calculateSourceGivenTarget(t75).
     * <p>
     * The calculateSourceGivenTarget method looks up all
     * instances of t75 in the target suffix array. It then
     * samples some of those instances and looks up the aligned
     * source word(s) for each sampled target word. Based on
     * that, probabilities are calculated and stored.
     * <p>
     * Now, what happens if instead of t75 we have NULL?
     * <p>
     * The calculateSourceGivenTarget method cannot look up all
     * instances of NULL in the target suffix array. This is a
     * problem.
     * <p>
     * We have access to all the information we need to calculate
     * null lexical translation probabilities, but this would
     * probably be best done as a pre-process.
     * <p>
     * One possible solution would be a pre-process that steps
     * through each line in the alignment array, finds the null
     * alignment points, and calculates null probabilities at
     * those points.
     * 
     * @param sourcePhrases     matched instances of a source phrase
     * @param sourcePhraseIndex index of the specific source phrase
     *                          instance to use
     * @param targetPhrase      the aligned target phrase
     * @return the lexical probability and reverse lexical
     *         probability
     */
    public Pair<Float,Float> calculateLexProbs(MatchedHierarchicalPhrases sourcePhrases,
            int sourcePhraseIndex, HierarchicalPhrase targetPhrase) {

        float sourceGivenTargetProb = 1.0f;

        Map<Integer,List<Integer>> reverseAlignmentPoints =
            new HashMap<Integer,List<Integer>>();

        // Iterate over each terminal sequence in the source phrase
        for (int seq = 0; seq < sourcePhrases.getNumberOfTerminalSequences(); seq++) {

            // Iterate over each source index in the current terminal sequence
            for (int sourceWordIndex = sourcePhrases.getTerminalSequenceStartIndex(sourcePhraseIndex, seq),
                    end = sourcePhrases.getTerminalSequenceEndIndex(sourcePhraseIndex, seq);
                    sourceWordIndex < end;
                    sourceWordIndex++) {

                float sum = 0.0f;

                int sourceWord = sourceSuffixArray.getCorpus().getWordID(sourceWordIndex);
                int[] targetIndices = alignments.getAlignedTargetIndices(sourceWordIndex);

                if (targetIndices==null) {
                    //XXX We are not handling NULL aligned points according to Koehn et al (2003)
                } else {
                    // Iterate over each target index aligned to the current source word
                    for (int targetIndex : targetIndices) {

                        int targetWord = targetCorpus.getWordID(targetIndex);
                        sum += sourceGivenTarget(sourceWord, targetWord);

                        // Keep track of the reverse alignment points.
                        // (We need this convoluted step because we don't
                        // actually have a HierarchicalPhrase for the target side.)
                        if (!reverseAlignmentPoints.containsKey(targetIndex)) {
                            reverseAlignmentPoints.put(targetIndex, new ArrayList<Integer>());
                        }
                        reverseAlignmentPoints.get(targetIndex).add(sourceWord);
                    }

                    float average = sum / targetIndices.length;
                    sourceGivenTargetProb *= average;
                }
            }
        }

        float targetGivenSourceProb = 1.0f;

        // Now calculate the reverse lexical translation probabilities
        for (Map.Entry<Integer,List<Integer>> entry : reverseAlignmentPoints.entrySet()) {

            int targetWord = targetCorpus.getWordID(entry.getKey());
            float sum = 0.0f;

            List<Integer> alignedSourceWords = entry.getValue();
            for (int sourceWord : alignedSourceWords) {
                sum += targetGivenSource(targetWord, sourceWord);
            }

            float average = sum / alignedSourceWords.size();
            targetGivenSourceProb *= average;
        }

        return new Pair<Float,Float>(sourceGivenTargetProb, targetGivenSourceProb);
    }
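    // Worked example of the computation above, with invented
    // probabilities (following the lexical weighting of Koehn et al.
    // (2003), modulo the missing NULL handling noted in the javadoc).
    // Suppose a two-word source phrase "s1 s2", where s1 aligns to
    // target words t1 and t2, and s2 aligns only to t2, with
    //
    //   p(s1|t1) = 0.5, p(s1|t2) = 0.3, p(s2|t2) = 0.8
    //
    // The source-given-target score is the product, over source
    // words, of the average over their aligned target words:
    //
    //   [ (0.5 + 0.3) / 2 ] * [ 0.8 / 1 ] = 0.4 * 0.8 = 0.32
    //
    // The reverse score is computed the same way from
    // reverseAlignmentPoints, averaging p(t|s) over the source words
    // aligned to each target word.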
    /**
     * Calculates the lexical probabilities for a target word
     * by sampling occurrences of that word from the target
     * suffix array, and stores the resulting distribution in
     * the cache.
     * 
     * @param targetWord target language word ID
     */
    private void calculateSourceGivenTarget(Integer targetWord) {

        Map<Integer,Integer> counts = new HashMap<Integer,Integer>();

        int[] targetSuffixArrayBounds =
            targetSuffixArray.findPhrase(new BasicPhrase(targetVocab, targetWord));
        int step = (targetSuffixArrayBounds[1]-targetSuffixArrayBounds[0] < sampleSize)
            ? 1
            : (targetSuffixArrayBounds[1]-targetSuffixArrayBounds[0]) / sampleSize;

        float total = 0;

        for (int targetSuffixArrayIndex = targetSuffixArrayBounds[0], samples = 0;
                targetSuffixArrayIndex <= targetSuffixArrayBounds[1] && samples < sampleSize;
                targetSuffixArrayIndex += step, samples++) {

            int targetCorpusIndex = targetSuffixArray.getCorpusIndex(targetSuffixArrayIndex);
            int[] alignedSourceIndices = alignments.getAlignedSourceIndices(targetCorpusIndex);

            if (alignedSourceIndices==null) {
                // A null key records occurrences of the target word
                // that are aligned to no source word.
                if (!counts.containsKey(null)) {
                    counts.put(null, 1);
                } else {
                    counts.put(null, counts.get(null) + 1);
                }
                total++;
            } else {
                for (int sourceIndex : alignedSourceIndices) {
                    int sourceWord = sourceSuffixArray.getCorpus().getWordID(sourceIndex);
                    if (!counts.containsKey(sourceWord)) {
                        counts.put(sourceWord, 1);
                    } else {
                        counts.put(sourceWord, counts.get(sourceWord) + 1);
                    }
                    total++;
                }
            }
        }

        Map<Integer,Float> sourceProbs = new HashMap<Integer,Float>();
        for (Map.Entry<Integer,Integer> entry : counts.entrySet()) {
            // entry.getKey() is the source word (or null for unaligned);
            // entry.getValue() is the number of times we have seen this source/target word pair;
            // total is the number of times we saw this target word with any source word.
            float prob = entry.getValue() / total;
            if (prob > thresholdProbability) {
                sourceProbs.put(entry.getKey(), prob);
            }
            // Otherwise, don't explicitly store a probability for this
            // source-target pair; queries for it will return the floor value.
        }
        sourceGivenTarget.put(targetWord, sourceProbs);
    }
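    // Sampling arithmetic used by calculateSourceGivenTarget above
    // and calculateTargetGivenSource below (numbers invented for
    // illustration): if a word occupies 10,000 positions in the
    // suffix array and sampleSize is 100, then step == 10000/100 ==
    // 100, so roughly every hundredth occurrence is inspected. If the
    // word occurs fewer than sampleSize times, step == 1 and every
    // occurrence is inspected.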
    /**
     * Calculates the lexical probabilities for a source word
     * by sampling occurrences of that word from the source
     * suffix array, and stores the resulting distribution in
     * the cache.
     * 
     * @param sourceWord source language word ID
     */
    private void calculateTargetGivenSource(int sourceWord) {

        if (logger.isLoggable(Level.FINE)) logger.fine(
            "Calculating lexprob distribution P( TARGET | " + sourceVocab.getWord(sourceWord) +
            "); sourceWord ID == " + sourceWord);

        Map<Integer,Integer> counts = new HashMap<Integer,Integer>();

        int[] sourceSuffixArrayBounds =
            sourceSuffixArray.findPhrase(new BasicPhrase(sourceVocab, sourceWord));
        int step = (sourceSuffixArrayBounds[1]-sourceSuffixArrayBounds[0] < sampleSize)
            ? 1
            : (sourceSuffixArrayBounds[1]-sourceSuffixArrayBounds[0]) / sampleSize;

        float total = 0;

        for (int sourceSuffixArrayIndex = sourceSuffixArrayBounds[0], samples = 0;
                sourceSuffixArrayIndex <= sourceSuffixArrayBounds[1] && samples < sampleSize;
                sourceSuffixArrayIndex += step, samples++) {

            int sourceCorpusIndex = sourceSuffixArray.getCorpusIndex(sourceSuffixArrayIndex);
            int[] alignedTargetIndices = alignments.getAlignedTargetIndices(sourceCorpusIndex);

            if (alignedTargetIndices==null) {
                // A null key records occurrences of the source word
                // that are aligned to no target word.
                if (!counts.containsKey(null)) {
                    if (logger.isLoggable(Level.FINEST)) logger.finest(
                        "Setting count(null | " + sourceVocab.getWord(sourceWord) + ") = 1");
                    counts.put(null, 1);
                } else {
                    counts.put(null, counts.get(null) + 1);
                }
                total++;
            } else {
                for (int targetIndex : alignedTargetIndices) {
                    int targetWord = targetSuffixArray.getCorpus().getWordID(targetIndex);
                    if (!counts.containsKey(targetWord)) {
                        if (logger.isLoggable(Level.FINEST)) logger.finest(
                            "Setting count(" + targetVocab.getWord(targetWord) +
                            " | " + sourceVocab.getWord(sourceWord) + ") = 1" +
                            "; sourceWord ID == " + sourceWord +
                            "; targetWord ID == " + targetWord);
                        counts.put(targetWord, 1);
                    } else {
                        int incrementedCount = counts.get(targetWord) + 1;
                        if (logger.isLoggable(Level.FINEST)) logger.finest(
                            "Setting count(" + targetVocab.getWord(targetWord) +
                            " | " + sourceVocab.getWord(sourceWord) + ") = " + incrementedCount +
                            "; sourceWord ID == " + sourceWord +
                            "; targetWord ID == " + targetWord);
                        counts.put(targetWord, incrementedCount);
                    }
                    total++;
                }
            }
        }

        Map<Integer,Float> targetProbs = new HashMap<Integer,Float>();
        for (Map.Entry<Integer,Integer> entry : counts.entrySet()) {
            // entry.getKey() is the target word (or null for unaligned);
            // entry.getValue() is the number of times we have seen this target/source word pair;
            // total is the number of times we saw this source word with any target word.
            Integer targetWord = entry.getKey();
            float prob = ((float) entry.getValue()) / total;
            if (prob > thresholdProbability) {
                if (logger.isLoggable(Level.FINEST)) logger.finest(
                    "Setting p(" + targetVocab.getWord(entry.getKey()) +
                    " | " + sourceVocab.getWord(sourceWord) + ") = " + prob +
                    "; sourceWord ID == " + sourceWord + "; targetWord ID == " + targetWord);
                targetProbs.put(targetWord, prob);
            }
            // Otherwise, don't explicitly store a probability for this
            // source-target pair; queries for it will return the floor value.
        }

        if (logger.isLoggable(Level.FINER)) logger.finer(
            "Storing " + targetProbs.size() +
            " probabilities for lexprob distribution P( TARGET | " +
            sourceVocab.getWord(sourceWord) + ")");

        targetGivenSource.put(sourceWord, targetProbs);
    }

    public float lexProbSourceGivenTarget(
            MatchedHierarchicalPhrases sourcePhrases,
            int sourcePhraseIndex, HierarchicalPhrase targetPhrase) {
        // TODO Auto-generated method stub
        throw new RuntimeException("Not yet implemented");
    }

    public float lexProbTargetGivenSource(
            MatchedHierarchicalPhrases sourcePhrases,
            int sourcePhraseIndex, HierarchicalPhrase targetPhrase) {
        // TODO Auto-generated method stub
        throw new RuntimeException("Not yet implemented");
    }

    public float getFloorProbability() {
        return floorProbability;
    }

    public SymbolTable getSourceVocab() {
        return sourceVocab;
    }

    public SymbolTable getTargetVocab() {
        return targetVocab;
    }

    public void readExternal(ObjectInput in)
            throws IOException, ClassNotFoundException {
        // TODO Auto-generated method stub
    }

    public void writeExternal(ObjectOutput out) throws IOException {
        // TODO Auto-generated method stub
    }

}