/* This file is part of the Joshua Machine Translation System.
 *
 * Joshua is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free
 * Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
 * MA 02111-1307 USA
 */
package joshua.corpus.lexprob;

import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectInputStream;
import java.io.ObjectOutput;
import java.util.logging.Level;
import java.util.logging.Logger;

import joshua.corpus.Corpus;
import joshua.corpus.MatchedHierarchicalPhrases;
import joshua.corpus.ParallelCorpus;
import joshua.corpus.alignment.Alignments;
import joshua.corpus.suffix_array.HierarchicalPhrase;
import joshua.corpus.vocab.SymbolTable;
import joshua.util.Counts;
import joshua.util.Pair;

/**
 * Represents lexical probability distributions in both directions.
 * <p>
 * This class calculates the probabilities from sorted word pair
 * counts.
 * <p>
 * Throughout this class Java's <code>null</code> stands for the
 * special NULL token that unaligned words are treated as being
 * aligned to.
 *
 * @author Lane Schwartz
 * @version $LastChangedDate: 2010-02-10 09:59:38 -0600 (Wed, 10 Feb 2010) $
 */
public class LexProbs extends AbstractLexProbs {

	/** Logger for this class. */
	private static final Logger logger =
		Logger.getLogger(LexProbs.class.getName());

	/** Source language symbol table. */
	protected final SymbolTable sourceVocab;

	/** Target language symbol table. */
	protected final SymbolTable targetVocab;

	/** Aligned parallel corpus. */
	protected final ParallelCorpus parallelCorpus;

	/**
	 * The probability returned when no calculated lexical
	 * translation probability is known.
	 */
	protected float floorProbability;

	/** Word co-occurrence counts from the parallel corpus. */
	protected Counts<Integer,Integer> counts;

	/**
	 * Constructs lexical translation probabilities from a
	 * parallel corpus.
	 *
	 * @param parallelCorpus Aligned parallel corpus
	 * @param floorProbability The probability to use when no
	 *                         calculated lexical translation
	 *                         probability is known
	 */
	public LexProbs(ParallelCorpus parallelCorpus, float floorProbability) {

		logger.info("Calculating lexical translation probability table");

		this.counts = initializeCooccurrenceCounts(parallelCorpus, floorProbability);

		this.sourceVocab = parallelCorpus.getSourceCorpus().getVocabulary();
		this.targetVocab = parallelCorpus.getTargetCorpus().getVocabulary();
		this.parallelCorpus = parallelCorpus;
		this.floorProbability = floorProbability;

		// The original logged the "Calculating" message a second time here,
		// which made the log read as if the work had started twice.
		logger.info("Finished calculating lexical translation probability table");
	}

	/**
	 * Constructs lexical translation probabilities, reading
	 * previously saved co-occurrence counts from a file when
	 * possible.
	 * <p>
	 * If the file cannot be opened or read for any reason, the
	 * counts are calculated from the parallel corpus instead,
	 * using <code>Float.MIN_VALUE</code> as the floor probability.
	 *
	 * @param parallelCorpus Aligned parallel corpus
	 * @param lexCountsFileName Name of the file from which
	 *                          serialized counts are read
	 */
	public LexProbs(ParallelCorpus parallelCorpus, String lexCountsFileName) {

		this.sourceVocab = parallelCorpus.getSourceCorpus().getVocabulary();
		this.targetVocab = parallelCorpus.getTargetCorpus().getVocabulary();
		this.parallelCorpus = parallelCorpus;

		// readExternal populates this (initially empty) object in place
		this.counts = new Counts<Integer, Integer>();
		this.floorProbability = Float.MIN_VALUE;

		try {
			FileInputStream fileIn = new FileInputStream(lexCountsFileName);
			try {
				ObjectInput in = new ObjectInputStream(fileIn);
				logger.info("Reading lexical translation probability table");
				readExternal(in);
			} finally {
				// Release the file handle even when reading fails part way
				fileIn.close();
			}
		} catch (Exception e) {
			// Deliberately broad: any failure to load the saved counts
			// falls back to recomputing them. Keep the cause available
			// for debugging instead of discarding it.
			if (logger.isLoggable(Level.FINE)) {
				logger.log(Level.FINE,
						"Unable to read lexical counts from " + lexCountsFileName
						+ "; recalculating from the parallel corpus", e);
			}
			logger.info("Calculating lexical translation probability table");
			this.counts = initializeCooccurrenceCounts(parallelCorpus, floorProbability);
		}
	}

	/**
	 * Gets co-occurrence counts from a parallel corpus.
	 *
	 * @param parallelCorpus Aligned parallel corpus
	 * @param floorProbability Floor probability handed to the
	 *                         newly created counts object
	 * @return Word co-occurrence counts from the parallel
	 *         corpus.
	 */
	private static Counts<Integer,Integer> initializeCooccurrenceCounts(ParallelCorpus parallelCorpus, float floorProbability) {

		if (logger.isLoggable(Level.FINE)) {
			logger.fine("Counting word co-occurrence from parallel corpus. Using floor probability " + floorProbability);
		}

		Alignments alignments = parallelCorpus.getAlignments();
		Corpus sourceCorpus = parallelCorpus.getSourceCorpus();
		Corpus targetCorpus = parallelCorpus.getTargetCorpus();
		int numSentences = parallelCorpus.getNumSentences();

		Counts<Integer,Integer> counts = new Counts<Integer,Integer>(floorProbability);

		// Iterate over each sentence
		for (int sentenceID=0; sentenceID<numSentences; sentenceID++) {

			int sourceStart = sourceCorpus.getSentencePosition(sentenceID);
			int sourceEnd = sourceCorpus.getSentenceEndPosition(sentenceID);

			int targetStart = targetCorpus.getSentencePosition(sentenceID);
			int targetEnd = targetCorpus.getSentenceEndPosition(sentenceID);

			// Iterate over each word in the source sentence
			for (int sourceIndex=sourceStart; sourceIndex<sourceEnd; sourceIndex++) {

				// Get the token for the current source word
				int sourceWord = sourceCorpus.getWordID(sourceIndex);

				// Get the target indices aligned to this source word
				int[] targetPoints = alignments.getAlignedTargetIndices(sourceIndex);

				// If the source word is unaligned,
				// then we treat it as being aligned to a special NULL token;
				// we use Java's null to represent the NULL token
				if (targetPoints==null) {

					counts.incrementCount(sourceWord, null);

				} else {

					// If the source word is aligned,
					// then we must iterate over each aligned target point
					for (int targetPoint : targetPoints) {
						int targetWord = targetCorpus.getWordID(targetPoint);
						counts.incrementCount(sourceWord, targetWord);
					}
				}
			}

			// Iterate over each word in the target sentence.
			// Aligned pairs were already counted in the source loop above,
			// so only unaligned target words are counted here.
			for (int targetIndex=targetStart; targetIndex<targetEnd; targetIndex++) {

				// Get the token for the current target word
				int targetWord = targetCorpus.getWordID(targetIndex);

				// Get the source indices aligned to this target word
				int[] sourcePoints = alignments.getAlignedSourceIndices(targetIndex);

				// If the target word is unaligned,
				// then we treat it as being aligned to a special NULL token;
				// we use Java's null to represent the NULL token
				if (sourcePoints==null) {
					counts.incrementCount(null, targetWord);
				}
			}
		}

		return counts;
	}

	/* See Javadoc for LexicalProbabilities#sourceGivenTarget(Integer,Integer). */
	public float sourceGivenTarget(Integer sourceWord, Integer targetWord) {
		return counts.getProbability(sourceWord, targetWord);
	}

	/* See Javadoc for LexicalProbabilities#targetGivenSource(Integer,Integer). */
	public float targetGivenSource(Integer targetWord, Integer sourceWord) {
		return counts.getReverseProbability(targetWord, sourceWord);
	}

	/* See Javadoc for LexicalProbabilities#sourceGivenTarget(String,String). */
	public float sourceGivenTarget(String sourceWord, String targetWord) {
		// A null word denotes the NULL token and maps to a null ID
		Integer targetID = (targetWord==null) ? null : targetVocab.getID(targetWord);
		Integer sourceID = (sourceWord==null) ? null : sourceVocab.getID(sourceWord);
		return sourceGivenTarget(sourceID, targetID);
	}

	/* See Javadoc for LexicalProbabilities#targetGivenSource(String,String). */
	public float targetGivenSource(String targetWord, String sourceWord) {
		// These must be Integer, not int: a null word denotes the NULL
		// token, and assigning the null ID to a primitive int would throw
		// a NullPointerException during unboxing.
		Integer targetID = (targetWord==null) ? null : targetVocab.getID(targetWord);
		Integer sourceID = (sourceWord==null) ? null : sourceVocab.getID(sourceWord);
		return targetGivenSource(targetID, sourceID);
	}

	/*
	 * See Javadoc for LexicalProbabilities#lexProbSourceGivenTarget(MatchedHierarchicalPhrases,int,HierarchicalPhrase).
	 *
	 * For every source word in the phrase's terminal sequences, the
	 * probabilities of that word given each aligned target word are
	 * averaged (or the probability given NULL is used when the word is
	 * unaligned); the result is the product of those per-word values.
	 * The targetPhrase argument is not read by this implementation.
	 */
	public float lexProbSourceGivenTarget(MatchedHierarchicalPhrases sourcePhrases, int sourcePhraseIndex, HierarchicalPhrase targetPhrase) {

		float sourceGivenTarget = 1.0f;

		Corpus sourceCorpus = parallelCorpus.getSourceCorpus();
		Corpus targetCorpus = parallelCorpus.getTargetCorpus();
		Alignments alignments = parallelCorpus.getAlignments();

		// Iterate over each terminal sequence in the source phrase
		for (int seq=0; seq<sourcePhrases.getNumberOfTerminalSequences(); seq++) {

			// Iterate over each source index in the current terminal sequence
			for (int sourceWordIndex=sourcePhrases.getTerminalSequenceStartIndex(sourcePhraseIndex, seq),
					end=sourcePhrases.getTerminalSequenceEndIndex(sourcePhraseIndex, seq);
					sourceWordIndex<end;
					sourceWordIndex++) {

				int sourceWord = sourceCorpus.getWordID(sourceWordIndex);
				int[] targetIndices = alignments.getAlignedTargetIndices(sourceWordIndex);

				float sum = 0.0f;
				float average;

				if (targetIndices==null) {

					// Unaligned: use the probability given the NULL token
					sum += this.sourceGivenTarget(sourceWord, null);
					average = sum;

				} else {

					for (int targetIndex : targetIndices) {
						int targetWord = targetCorpus.getWordID(targetIndex);
						sum += sourceGivenTarget(sourceWord, targetWord);
					}
					average = sum / targetIndices.length;
				}

				sourceGivenTarget *= average;
			}
		}

		return sourceGivenTarget;
	}

	/*
	 * See Javadoc for LexicalProbabilities#lexProbTargetGivenSource(MatchedHierarchicalPhrases,int,HierarchicalPhrase).
	 *
	 * Mirror image of lexProbSourceGivenTarget, walking the terminal
	 * sequences of the target phrase. The sourcePhrases argument is only
	 * used to build the FINEST-level log message, and sourcePhraseIndex
	 * is not read by this implementation.
	 */
	public float lexProbTargetGivenSource(MatchedHierarchicalPhrases sourcePhrases, int sourcePhraseIndex, HierarchicalPhrase targetPhrase) {

		// Checked once so the log string is only built when it will be emitted
		final boolean loggingFinest = logger.isLoggable(Level.FINEST);

		Corpus sourceCorpus = parallelCorpus.getSourceCorpus();
		Corpus targetCorpus = parallelCorpus.getTargetCorpus();
		Alignments alignments = parallelCorpus.getAlignments();

		StringBuilder s;
		if (loggingFinest) {
			s = new StringBuilder();
			s.append("lexProb( ");
			s.append(sourcePhrases.getPattern().toString());
			s.append(" | ");
			s.append(targetPhrase.toString());
			s.append(" ) = 1.0");
		} else {
			s = null;
		}

		float targetGivenSource = 1.0f;

		// Iterate over each terminal sequence in the target phrase
		for (int seq=0; seq<targetPhrase.getNumberOfTerminalSequences(); seq++) {

			// Iterate over each target index in the current terminal sequence
			for (int targetWordIndex=targetPhrase.getTerminalSequenceStartIndex(seq),
					end=targetPhrase.getTerminalSequenceEndIndex(seq);
					targetWordIndex<end;
					targetWordIndex++) {

				int targetWord = targetCorpus.getWordID(targetWordIndex);
				int[] sourceIndices = alignments.getAlignedSourceIndices(targetWordIndex);

				float sum = 0.0f;
				float average;

				if (loggingFinest) {
					s.append(" * (");
				}

				if (sourceIndices==null) {

					// Unaligned: use the probability given the NULL token
					sum += targetGivenSource(targetWord, null);
					average = sum;
					if (loggingFinest) {
						s.append(sum);
					}

				} else {

					for (int sourceIndex : sourceIndices) {
						int sourceWord = sourceCorpus.getWordID(sourceIndex);
						float value = targetGivenSource(targetWord, sourceWord);
						sum += value;
						if (loggingFinest) {
							s.append('+');
							s.append(value);
						}
					}
					average = sum / sourceIndices.length;
				}

				if (loggingFinest) {
					s.append(')');
				}

				targetGivenSource *= average;
			}
		}

		if (loggingFinest) {
			logger.finest(s.toString());
		}

		return targetGivenSource;
	}

	/* See Javadoc for LexicalProbabilities#getFloorProbability. */
	public float getFloorProbability() {
		return floorProbability;
	}

	/**
	 * Gets a string representation of the lexical probabilities.
	 * <p>
	 * The returned string will have one line per word pair.
	 * The pairs are not guaranteed to be returned in any particular order.
	 * <p>
	 * Each line holds the source word, the target word, the
	 * target-given-source probability and the source-given-target
	 * probability, separated by single spaces. A null word ID is
	 * printed as <code>NULL</code>.
	 *
	 * @return a string representation of the lexical probabilities
	 */
	@Override
	public String toString() {

		StringBuilder s = new StringBuilder();

		for (Pair<Integer,Integer> pair : counts) {

			Integer sourceID = pair.first;
			Integer targetID = pair.second;

			if (sourceID==null) {
				s.append("NULL");
			} else {
				s.append(sourceVocab.getWord(sourceID));
			}
			s.append(' ');

			if (targetID==null) {
				s.append("NULL");
			} else {
				s.append(targetVocab.getWord(targetID));
			}
			s.append(' ');

			s.append(targetGivenSource(targetID,sourceID));
			s.append(' ');
			s.append(sourceGivenTarget(sourceID,targetID));
			s.append('\n');
		}

		return s.toString();
	}

	/**
	 * Writes the word co-occurrence counts to the given output.
	 *
	 * @param out Output to which the counts are written
	 * @throws IOException if writing the counts fails
	 */
	public void writeExternal(ObjectOutput out) throws IOException {
		counts.writeExternal(out);
	}

	/**
	 * Reads word co-occurrence counts from the given input into
	 * this object's existing counts, then takes the floor
	 * probability from the counts that were read.
	 *
	 * @param in Input from which the counts are read
	 * @throws IOException if reading the counts fails
	 * @throws ClassNotFoundException if a class needed while
	 *         reading the counts cannot be found
	 */
	public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
		this.counts.readExternal(in);
		floorProbability = counts.getFloorProbability();
	}

	/**
	 * Gets the source language symbol table.
	 *
	 * @return the source language symbol table
	 */
	public SymbolTable getSourceVocab() {
		return sourceVocab;
	}

	/**
	 * Gets the target language symbol table.
	 *
	 * @return the target language symbol table
	 */
	public SymbolTable getTargetVocab() {
		return targetVocab;
	}

	/**
	 * Gets the word co-occurrence counts for this object.
	 *
	 * @return the word co-occurrence counts for this object.
	 */
	protected Counts<Integer,Integer> getCounts() {
		return this.counts;
	}
}