/*******************************************************************************
* Copyright 2015-2016 - CNRS (Centre National de Recherche Scientifique)
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*******************************************************************************/
package eu.project.ttc.engines;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.base.Joiner;
import com.google.common.base.MoreObjects;
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import com.google.common.collect.ComparisonChain;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.MinMaxPriorityQueue;
import com.google.common.collect.Sets;
import com.google.common.primitives.Ints;
import eu.project.ttc.metrics.ExplainedValue;
import eu.project.ttc.metrics.Explanation;
import eu.project.ttc.metrics.IExplanation;
import eu.project.ttc.metrics.SimilarityDistance;
import eu.project.ttc.metrics.TextExplanation;
import eu.project.ttc.models.ContextVector;
import eu.project.ttc.models.Term;
import eu.project.ttc.models.TermIndex;
import eu.project.ttc.models.index.CustomTermIndex;
import eu.project.ttc.models.index.TermIndexes;
import eu.project.ttc.models.index.TermMeasure;
import eu.project.ttc.models.index.TermValueProviders;
import eu.project.ttc.resources.BilingualDictionary;
import eu.project.ttc.utils.AlignerUtils;
import eu.project.ttc.utils.IteratorUtils;
import eu.project.ttc.utils.TermSuiteConstants;
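/**
 * Aligns terms of a source terminology with candidate terms of a target
 * terminology. Candidates are produced from a bilingual dictionary, from
 * context-vector similarity, and from combinations of both (compositional
 * and semi-distributional methods).
 *
 * @author Damien Cram
 *
 */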
public class BilingualAligner {
/** Logger of this class. */
private static final Logger LOGGER = LoggerFactory.getLogger(BilingualAligner.class);
/** Message of the exception raised by {@code checkNotNull(Term)} when the source term is null. */
private static final String MSG_TERM_NOT_NULL = "Source term must not be null";
/** Message template of {@link RequiresSize2Exception}: the term, then its joined single-word terms. */
private static final String MSG_REQUIRES_SIZE_2_LEMMAS = "The term %s must have exactly two single-word terms (single-word terms: %s)";
/** SLF4J warning template used by {@code alignDistributional}: number of terms with, then without, a context vector. */
private static final String MSG_SEVERAL_VECTORS_NOT_COMPUTED = "Several terms have no context vectors in target terminology (nb terms with vector: {}, nb terms without vector: {})";
/** Error template used when the source term has no computed context vector; the argument is the term's grouping key. */
private static final String ERR_VECTOR_NOT_SET = "Cannot align on term %s. Cause: context vector no set.";
/**
 * The bonus factor applied to dictionary candidates when they are
 * merged with distributional candidates.
 *
 * NOTE(review): this constant is not referenced anywhere else in this
 * file; the merge performed here uses a specificity-based bonus instead.
 */
public static final double DICO_CANDIDATE_BONUS_FACTOR = 30;
/** Bilingual dictionary used to translate lemmas and context vectors. */
private BilingualDictionary dico;
/** Terminology the source terms belong to; used to decompose a source term into single-lemma terms. */
private TermIndex sourceTermino;
/** Terminology in which translation candidates are searched. */
private TermIndex targetTermino;
/** Similarity measure applied between a translated source vector and a target context vector. */
private SimilarityDistance distance;
/**
 * Creates an aligner over the given resources.
 *
 * @param dico
 *            the bilingual dictionary used for translations
 * @param sourceTermino
 *            the terminology containing the terms to align
 * @param targetTermino
 *            the terminology in which candidates are searched
 * @param distance
 *            the similarity measure between context vectors
 */
public BilingualAligner(BilingualDictionary dico, TermIndex sourceTermino, TermIndex targetTermino, SimilarityDistance distance) {
    this.dico = dico;
    this.sourceTermino = sourceTermino;
    this.targetTermino = targetTermino;
    this.distance = distance;
}
/**
 * Replaces the similarity distance used by this aligner to compare
 * context vectors.
 *
 * @param distance
 *            the similarity distance implementation to use from now on
 */
public void setDistance(SimilarityDistance distance) {
    this.distance = distance;
}
/**
 * Aligns a source term by merging its dictionary translations with the
 * distributional candidates found in the target terminology.
 *
 * Dictionary candidates are sorted, truncated to <code>nbCandidates</code>,
 * normalized and then rescaled by a specificity bonus before being merged
 * with the distributional candidates. The merged list is finally sorted,
 * ranked, truncated to <code>nbCandidates</code> and normalized.
 *
 * <code>sourceTerm</code>'s context vector must be computed and normalized,
 * as well as all terms' context vectors in the target term index.
 *
 * @param sourceTerm
 *            the term to align with target term index
 * @param nbCandidates
 *            the maximum number of {@link TranslationCandidate} in the returned list
 * @param minCandidateFrequency
 *            the minimum frequency of a distributional target candidate
 * @return
 *            a list of at most <code>nbCandidates</code> {@link TranslationCandidate},
 *            sorted by descending score. Each {@link TranslationCandidate} is a
 *            container for a target term and its translation score.
 * @throws NullPointerException
 *            if <code>sourceTerm</code> is null
 * @throws IllegalArgumentException
 *            if the context vector of <code>sourceTerm</code> is not computed
 */
public List<TranslationCandidate> alignDicoThenDistributional(Term sourceTerm, int nbCandidates, int minCandidateFrequency) {
    checkNotNull(sourceTerm);
    Preconditions.checkArgument(sourceTerm.isContextVectorComputed(), ERR_VECTOR_NOT_SET, sourceTerm.getGroupingKey());

    /*
     * 1- find direct translations of the term in the dictionary
     */
    List<TranslationCandidate> mergedCandidates = Lists.newArrayList(
            sortTruncateNormalize(targetTermino, nbCandidates, alignDico(sourceTerm, Integer.MAX_VALUE)));
    applySpecificityBonus(targetTermino, mergedCandidates);

    /*
     * 2- align against all single-word terms of the target terminology
     */
    mergedCandidates.addAll(alignDistributional(sourceTerm, nbCandidates, minCandidateFrequency));

    /*
     * 3- sort, rank, truncate and normalize. No explicit sort is needed
     * here: sortTruncateNormalize sorts its own copy of the candidates.
     */
    return sortTruncateNormalize(targetTermino, nbCandidates, mergedCandidates);
}
/**
 * Aligns a source term against every single-word term of the target
 * terminology by comparing context vectors.
 *
 * The source context vector is first translated with the dictionary, then
 * compared with the context vector of each target single-word term whose
 * frequency is at least <code>minCandidateFrequency</code>. Only the
 * <code>nbCandidates</code> best-scored candidates are kept.
 *
 * @param sourceTerm
 *            the term to align
 * @param nbCandidates
 *            the maximum number of candidates to keep
 * @param minCandidateFrequency
 *            the minimum frequency of a target candidate
 * @return
 *            the kept candidates, with scores normalized so that they sum
 *            to 1 (when the sum is positive), sorted by descending score
 */
public List<TranslationCandidate> alignDistributional(Term sourceTerm, int nbCandidates,
        int minCandidateFrequency) {
    // Bounded queue: once full, adding evicts the greatest element, i.e. the lowest score
    Queue<TranslationCandidate> alignedCandidateQueue = MinMaxPriorityQueue.maximumSize(nbCandidates).create();
    ContextVector translatedSourceVector = AlignerUtils.translateVector(
            sourceTerm.getContextVector(),
            dico,
            AlignerUtils.TRANSLATION_STRATEGY_MOST_SPECIFIC,
            targetTermino);
    int nbVectorsNotComputed = 0;
    int nbVectorsComputed = 0;
    for (Term targetTerm : IteratorUtils.toIterable(targetTermino.singleWordTermIterator())) {
        if (targetTerm.getFrequency() < minCandidateFrequency) {
            continue;
        }
        if (!targetTerm.isContextVectorComputed()) {
            // Counted so that the warning below can actually be emitted
            nbVectorsNotComputed++;
            continue;
        }
        nbVectorsComputed++;
        ExplainedValue v = distance.getExplainedValue(translatedSourceVector, targetTerm.getContextVector());
        alignedCandidateQueue.add(new TranslationCandidate(
                targetTerm,
                AlignmentMethod.DISTRIBUTIONAL,
                v.getValue(),
                v.getExplanation()));
    }
    if (nbVectorsNotComputed > 0) {
        LOGGER.warn(MSG_SEVERAL_VECTORS_NOT_COMPUTED, nbVectorsComputed, nbVectorsNotComputed);
    }

    // The queue iterates in no particular order: copy, normalize, then sort explicitly
    List<TranslationCandidate> alignedCandidates = Lists.newArrayList(alignedCandidateQueue);
    normalizeCandidateScores(alignedCandidates);
    Collections.sort(alignedCandidates);
    return alignedCandidates;
}
private static final String ERR_MSG_BAD_SOURCE_LEMMA_SET_SIZE = "Unexpected size for a source lemma set: %s. Expected size: 2";
/**
 * Aligns a source term with the target terminology, choosing the
 * alignment method from the size of each of its single-lemma term sets.
 *
 * A set of size 1 is aligned with the dictionary-then-distributional
 * method. A set of size 2 is aligned with the compositional method and,
 * only when no candidate has been collected so far, with the
 * semi-distributional method. Candidates are then deduplicated on their
 * term, sorted, ranked, truncated and normalized.
 *
 * @param sourceTerm
 *            the term to align
 * @param nbCandidates
 *            the maximum number of candidates to return
 * @param minCandidateFrequency
 *            the minimum frequency of a target candidate
 * @return
 *            at most <code>nbCandidates</code> candidates, sorted by descending score
 * @throws NullPointerException
 *            if <code>sourceTerm</code> is null
 * @throws IllegalStateException
 *            if a single-lemma term set has a size other than 1 or 2
 */
public List<TranslationCandidate> align(Term sourceTerm, int nbCandidates, int minCandidateFrequency) {
    // Validate before any dereference of sourceTerm
    checkNotNull(sourceTerm);
    List<TranslationCandidate> mergedCandidates = Lists.newArrayList();
    List<List<Term>> sourceLemmaSets = AlignerUtils.getSingleLemmaTerms(sourceTermino, sourceTerm);
    for (List<Term> sourceLemmaSet : sourceLemmaSets) {
        Preconditions.checkState(sourceLemmaSet.size() == 1 || sourceLemmaSet.size() == 2,
                ERR_MSG_BAD_SOURCE_LEMMA_SET_SIZE, sourceLemmaSet);
        if (sourceLemmaSet.size() == 1) {
            // Saturated product: a plain 3*nbCandidates can overflow to a negative size
            int enlargedNbCandidates = Ints.saturatedCast(3L * nbCandidates);
            mergedCandidates.addAll(alignDicoThenDistributional(sourceLemmaSet.get(0), enlargedNbCandidates, minCandidateFrequency));
            continue;
        }

        Term lemmaTerm1 = sourceLemmaSet.get(0);
        Term lemmaTerm2 = sourceLemmaSet.get(1);
        try {
            mergedCandidates.addAll(alignCompositionalSize2(lemmaTerm1, lemmaTerm2, nbCandidates, minCandidateFrequency));
        } catch (RequiresSize2Exception e) {
            // Best effort: this lemma set simply yields no compositional candidate
            LOGGER.debug("Compositional alignment skipped for lemma set {}", sourceLemmaSet, e);
        }
        if (mergedCandidates.isEmpty()) {
            try {
                mergedCandidates.addAll(alignSemiDistributionalSize2Syntagmatic(lemmaTerm1, lemmaTerm2, nbCandidates, minCandidateFrequency));
            } catch (RequiresSize2Exception e) {
                // Best effort: this lemma set simply yields no semi-distributional candidate
                LOGGER.debug("Semi-distributional alignment skipped for lemma set {}", sourceLemmaSet, e);
            }
        }
    }
    removeDuplicatesOnTerm(mergedCandidates);
    return sortTruncateNormalize(targetTermino, nbCandidates, mergedCandidates);
}
/**
 * Sorts a copy of the given candidates, assigns 1-based ranks to all of
 * them, keeps the first <code>nbCandidates</code> ones and normalizes the
 * scores of the kept candidates.
 *
 * @param termIndex
 *            unused by this method
 * @param nbCandidates
 *            the maximum number of candidates to keep
 * @param candidatesCandidates
 *            the candidates to process; the collection itself is not reordered
 * @return
 *            a view on the first candidates of the sorted copy
 */
private List<TranslationCandidate> sortTruncateNormalize(TermIndex termIndex, int nbCandidates, Collection<TranslationCandidate> candidatesCandidates) {
    List<TranslationCandidate> ranked = Lists.newArrayList(candidatesCandidates);
    Collections.sort(ranked);

    // Ranks are assigned before truncation, on the whole sorted list
    int rank = 0;
    for (TranslationCandidate candidate : ranked) {
        rank++;
        candidate.setRank(rank);
    }

    int limit = Math.min(nbCandidates, candidatesCandidates.size());
    List<TranslationCandidate> kept = ranked.subList(0, limit);
    normalizeCandidateScores(kept);
    return kept;
}
/*
 * Rescales the score of each candidate by a bonus factor derived from
 * the WR measure of its term in the given term index. No candidate is
 * removed from the list.
 */
private void applySpecificityBonus(TermIndex termIndex, List<TranslationCandidate> list) {
    for (TranslationCandidate candidate : list) {
        double wr = termIndex.getWRMeasure().getValue(candidate.getTerm());
        double bonus = getSpecificityBonusFactor(wr);
        candidate.setScore(candidate.getScore() * bonus);
    }
}
/**
 * Maps a WR value to the multiplicative bonus applied to a candidate score.
 *
 * @param wr
 *            the WR value of a candidate term
 * @return
 *            0.5 up to 1, 1 up to 2, 1.5 up to 10, 2 up to 100, and 5 otherwise
 */
private double getSpecificityBonusFactor(double wr) {
    final double[] upperBounds = { 1, 2, 10, 100 };
    final double[] factors = { 0.5, 1, 1.5, 2 };
    for (int i = 0; i < upperBounds.length; i++) {
        if (wr <= upperBounds[i]) {
            return factors[i];
        }
    }
    // Reached for wr > 100, and for NaN since every comparison above is false
    return 5;
}
/**
 * Builds one dictionary candidate for every target term whose lower-cased
 * lemma is a dictionary translation of the source term's lemma and whose
 * context vector is computed. Each candidate is scored with the distance
 * between the translated source vector and the candidate's vector.
 *
 * @param sourceTerm
 *            the term to translate
 * @param nbCandidates
 *            unused by this method: the returned list is not truncated
 * @return
 *            the dictionary candidates, unsorted and not normalized
 */
public List<TranslationCandidate> alignDico(Term sourceTerm, int nbCandidates) {
    List<TranslationCandidate> result = Lists.newArrayList();
    Collection<String> translations = dico.getTranslations(sourceTerm.getLemma());
    ContextVector translatedVector = AlignerUtils.translateVector(
            sourceTerm.getContextVector(),
            dico,
            AlignerUtils.TRANSLATION_STRATEGY_MOST_SPECIFIC,
            targetTermino);
    for (String translatedLemma : translations) {
        for (Term target : targetTermino.getCustomIndex(TermIndexes.LEMMA_LOWER_CASE).getTerms(translatedLemma)) {
            if (!target.isContextVectorComputed()) {
                continue;
            }
            result.add(new TranslationCandidate(
                    target,
                    AlignmentMethod.DICTIONARY,
                    distance.getValue(translatedVector, target.getContextVector()),
                    Explanation.emptyExplanation()));
        }
    }
    return result;
}
/**
 * Tells whether the compositional method applies to the given term.
 *
 * @param sourceTerm
 *            the term to test
 * @return
 *            <code>true</code> if at least one single-lemma term set of
 *            the term has exactly two elements
 */
public boolean canAlignCompositional(Term sourceTerm) {
    for (List<Term> slTerms : AlignerUtils.getSingleLemmaTerms(sourceTermino, sourceTerm)) {
        if (slTerms.size() == 2) {
            return true;
        }
    }
    return false;
}
/**
 * Aligns a term with the compositional method on each of its
 * single-lemma term sets of size 2; sets of any other size are skipped.
 *
 * @param sourceTerm
 *            the term to align
 * @param nbCandidates
 *            the maximum number of candidates to return
 * @param minCandidateFrequency
 *            passed through to the size-2 compositional alignment
 * @return
 *            at most <code>nbCandidates</code> candidates, sorted by descending score
 * @throws IllegalArgumentException
 *            if {@link #canAlignCompositional(Term)} is false for the term
 */
public List<TranslationCandidate> alignCompositional(Term sourceTerm, int nbCandidates, int minCandidateFrequency) {
    Preconditions.checkArgument(canAlignCompositional(sourceTerm), "Cannot align <%s> with compositional method", sourceTerm);
    List<TranslationCandidate> collected = Lists.newArrayList();
    for (List<Term> lemmaTerms : AlignerUtils.getSingleLemmaTerms(sourceTermino, sourceTerm)) {
        if (lemmaTerms.size() != 2) {
            continue;
        }
        Term first = lemmaTerms.get(0);
        Term second = lemmaTerms.get(1);
        collected.addAll(alignCompositionalSize2(first, second, nbCandidates, minCandidateFrequency));
    }
    return sortTruncateNormalize(targetTermino, nbCandidates, collected);
}
/**
 * Tells whether the semi-distributional method applies to the given term.
 *
 * @param sourceTerm
 *            the term to test
 * @return
 *            <code>true</code> if at least one single-lemma term set of
 *            the term has exactly two elements
 */
public boolean canAlignSemiDistributional(Term sourceTerm) {
    List<List<Term>> lemmaTermSets = AlignerUtils.getSingleLemmaTerms(sourceTermino, sourceTerm);
    for (List<Term> slTerms : lemmaTermSets) {
        if (slTerms.size() == 2) {
            return true;
        }
    }
    return false;
}
/**
 * Aligns a term with the semi-distributional method on each of its
 * single-lemma term sets of size 2; sets of any other size are skipped.
 *
 * @param sourceTerm
 *            the term to align
 * @param nbCandidates
 *            the maximum number of candidates to return
 * @param minCandidateFrequency
 *            passed through to the size-2 semi-distributional alignment
 * @return
 *            at most <code>nbCandidates</code> candidates, sorted by descending score
 * @throws IllegalArgumentException
 *            if {@link #canAlignSemiDistributional(Term)} is false for the term
 */
public List<TranslationCandidate> alignSemiDistributional(Term sourceTerm, int nbCandidates, int minCandidateFrequency) {
    // Guard with the predicate and the message of THIS method, not the compositional ones
    Preconditions.checkArgument(canAlignSemiDistributional(sourceTerm), "Cannot align <%s> with semi-distributional method", sourceTerm);
    List<List<Term>> singleLemmaTermSets = AlignerUtils.getSingleLemmaTerms(sourceTermino, sourceTerm);
    List<TranslationCandidate> candidates = Lists.newArrayList();
    for (List<Term> singleLemmaTerms : singleLemmaTermSets) {
        if (singleLemmaTerms.size() == 2) {
            candidates.addAll(alignSemiDistributionalSize2Syntagmatic(
                    singleLemmaTerms.get(0),
                    singleLemmaTerms.get(1), nbCandidates, minCandidateFrequency));
        }
    }
    return sortTruncateNormalize(targetTermino, nbCandidates, candidates);
}
/**
 * Aligns a pair of single-lemma terms compositionally: the dictionary
 * candidates of both terms are combined into target terms, which are then
 * sorted, ranked, truncated and normalized.
 *
 * @param lemmaTerm1
 *            the first single-lemma term
 * @param lemmaTerm2
 *            the second single-lemma term
 * @param nbCandidates
 *            the maximum number of candidates to return
 * @param minCandidateFrequency
 *            unused by this method
 * @return
 *            at most <code>nbCandidates</code> candidates, sorted by descending score
 */
public List<TranslationCandidate> alignCompositionalSize2(Term lemmaTerm1, Term lemmaTerm2, int nbCandidates, int minCandidateFrequency) {
    List<TranslationCandidate> translations1 = alignDico(lemmaTerm1, Integer.MAX_VALUE);
    List<TranslationCandidate> translations2 = alignDico(lemmaTerm2, Integer.MAX_VALUE);
    List<TranslationCandidate> combined = Lists.newArrayList(
            combineCandidates(translations1, translations2, AlignmentMethod.COMPOSITIONAL));
    return sortTruncateNormalize(targetTermino, nbCandidates, combined);
}
/**
 * Signals that a term does not have exactly two single-word terms. The
 * message is built on demand from the term and its single-word terms.
 */
public static class RequiresSize2Exception extends RuntimeException {
    private static final long serialVersionUID = 1L;

    private final Term term;
    private final List<Term> swtTerms;

    /**
     * @param term
     *            the offending term
     * @param swtTerms
     *            the single-word terms found for the term
     */
    public RequiresSize2Exception(Term term, List<Term> swtTerms) {
        this.term = term;
        this.swtTerms = swtTerms;
    }

    @Override
    public String getMessage() {
        String joinedSwts = Joiner.on(TermSuiteConstants.COMMA).join(swtTerms);
        return String.format(MSG_REQUIRES_SIZE_2_LEMMAS, term, joinedSwts);
    }
}
/**
 * Joins two lists of swt candidates and uses the specificities (wrLog)
 * of the combined terms as the candidate scores.
 *
 * For every pair of candidates, the target terms indexed under either
 * "lemma1+lemma2" or "lemma2+lemma1" are looked up. Among them, only the
 * terms having the lowest number of lemma+lemma keys are kept, which
 * avoids retrieving too long terms.
 *
 * FIXME Bad way of scoring candidates. They should be scored by similarity of context vectors with the source context vector
 *
 * @param candidates1
 *            the candidates of the first component
 * @param candidates2
 *            the candidates of the second component
 * @param method
 *            the alignment method recorded on the created candidates
 * @return
 *            the combined candidates, without duplicates
 */
private Collection<TranslationCandidate> combineCandidates(Collection<TranslationCandidate> candidates1,
        Collection<TranslationCandidate> candidates2, AlignmentMethod method) {
    Collection<TranslationCandidate> combination = Sets.newHashSet();
    TermMeasure wrLog = targetTermino.getWRLogMeasure();
    wrLog.compute();
    for (TranslationCandidate candidate1 : candidates1) {
        for (TranslationCandidate candidate2 : candidates2) {
            /*
             * 1- create candidate combined terms
             */
            CustomTermIndex index = targetTermino.getCustomIndex(TermIndexes.WORD_COUPLE_LEMMA_LEMMA);
            String lemma1 = candidate1.getTerm().getLemma();
            String lemma2 = candidate2.getTerm().getLemma();
            // Defensive copy: never append to the list handed out by the index
            List<Term> candidateCombinedTerms = Lists.newArrayList(index.getTerms(lemma1 + "+" + lemma2));
            candidateCombinedTerms.addAll(index.getTerms(lemma2 + "+" + lemma1));
            if (candidateCombinedTerms.isEmpty())
                continue;

            /*
             * 2- Avoids retrieving too long terms by keeping the ones that have
             * the lowest number of lemma+lemma keys. A single scan finds the
             * minimum; no sort is needed since the result goes into a set.
             */
            Map<Term, Integer> nbKeysByTerm = Maps.newHashMap();
            int minimumNbClasses = Integer.MAX_VALUE;
            for (Term t : candidateCombinedTerms) {
                int nbKeys = TermValueProviders.WORD_LEMMA_LEMMA_PROVIDER.getClasses(targetTermino, t).size();
                nbKeysByTerm.put(t, nbKeys);
                minimumNbClasses = Math.min(minimumNbClasses, nbKeys);
            }

            /*
             * 3- Create candidates from the terms reaching that minimum
             */
            for (Term t : candidateCombinedTerms) {
                if (nbKeysByTerm.get(t) != minimumNbClasses)
                    continue;
                double specificity = wrLog.getValue(t);
                combination.add(new TranslationCandidate(
                        t,
                        method,
                        specificity,
                        new TextExplanation(String.format("Spécificité: %.1f", specificity))));
            }
        }
    }
    return combination;
}
/**
 * Fails with a {@link NullPointerException} carrying
 * {@link #MSG_TERM_NOT_NULL} when the given source term is null.
 *
 * @param sourceTerm
 *            the term to check
 */
private void checkNotNull(Term sourceTerm) {
    Preconditions.checkNotNull(
            sourceTerm,
            MSG_TERM_NOT_NULL);
}
/**
 * Aligns a pair of single-lemma terms semi-distributionally, in both
 * directions: each term in turn is translated with the dictionary while
 * the other one is aligned on its context vector. Candidates sharing the
 * same term are deduplicated before the final sort and truncation.
 *
 * @param lemmaTerm1
 *            the first single-lemma term
 * @param lemmaTerm2
 *            the second single-lemma term
 * @param nbCandidates
 *            the maximum number of candidates to return
 * @param minCandidateFrequency
 *            unused by this method
 * @return
 *            at most <code>nbCandidates</code> candidates, sorted by descending score
 */
public List<TranslationCandidate> alignSemiDistributionalSize2Syntagmatic(Term lemmaTerm1, Term lemmaTerm2, int nbCandidates, int minCandidateFrequency) {
    // Direction 1: dictionary on term 1, vectors on term 2
    List<TranslationCandidate> merged = Lists.newArrayList(semiDistributional(lemmaTerm1, lemmaTerm2));
    // Direction 2: dictionary on term 2, vectors on term 1
    merged.addAll(semiDistributional(lemmaTerm2, lemmaTerm1));
    removeDuplicatesOnTerm(merged);
    return sortTruncateNormalize(targetTermino, nbCandidates, merged);
}
/**
 * Removes, in place, every candidate whose term was already met earlier
 * in the list; the first candidate of each term is kept.
 *
 * @param candidates
 *            the list to deduplicate
 */
private void removeDuplicatesOnTerm(List<TranslationCandidate> candidates) {
    Set<Term> seenTerms = Sets.newHashSet();
    // Set.add returns false when the term is already present
    candidates.removeIf(candidate -> !seenTerms.add(candidate.getTerm()));
}
/**
 * Combines the dictionary candidates of one term with the
 * dictionary-then-distributional candidates of another term.
 *
 * @param dicoTerm
 *            the term translated with the dictionary only
 * @param vectorTerm
 *            the term aligned with dictionary and context vectors
 * @return
 *            the combined candidates, or an empty list when
 *            <code>dicoTerm</code> has no dictionary candidate
 */
private Collection<? extends TranslationCandidate> semiDistributional(Term dicoTerm, Term vectorTerm) {
    List<TranslationCandidate> fromDico = alignDico(dicoTerm, Integer.MAX_VALUE);
    if (fromDico.isEmpty()) {
        // Optimisation: no need to align since there is no possible combination
        return Lists.<TranslationCandidate>newArrayList();
    }
    List<TranslationCandidate> fromVectors = alignDicoThenDistributional(vectorTerm, Integer.MAX_VALUE, 1);
    return combineCandidates(fromDico, fromVectors, AlignmentMethod.SEMI_DISTRIBUTIONAL);
}
/**
 * Divides, in place, every candidate score by the sum of all scores.
 * Scores are left untouched when that sum is not strictly positive.
 *
 * @param candidates
 *            the candidates whose scores are normalized
 */
private void normalizeCandidateScores(List<TranslationCandidate> candidates) {
    double total = 0;
    for (TranslationCandidate candidate : candidates) {
        total += candidate.getScore();
    }
    // Written as a negation so that a NaN total also leaves the scores as they are
    if (!(total > 0d)) {
        return;
    }
    for (TranslationCandidate candidate : candidates) {
        candidate.setScore(candidate.getScore() / total);
    }
}
/**
 * The method by which a translation candidate was produced, with a short
 * and a long display name.
 */
public static enum AlignmentMethod {
    DICTIONARY("dico", "dictionary"),
    DISTRIBUTIONAL("dist", "distributional"),
    COMPOSITIONAL("comp", "compositional"),
    SEMI_DISTRIBUTIONAL("s-dist", "semi-distributional");

    private final String shortName;
    private final String longName;

    AlignmentMethod(String shortName, String longName) {
        this.longName = longName;
        this.shortName = shortName;
    }

    /** @return the long display name of this method */
    public String getLongName() {
        return this.longName;
    }

    /** @return the abbreviated display name of this method */
    public String getShortName() {
        return this.shortName;
    }
}
/**
 * A target term proposed as a translation, together with the method that
 * produced it, its score, its rank and an explanation of the score.
 *
 * Natural ordering is by descending score, then by term. Equality and
 * hash code are based on the term and the score, both of which can change
 * through {@link #setScore(double)}.
 */
public static class TranslationCandidate implements Comparable<TranslationCandidate> {
    private final Term term;
    private final AlignmentMethod method;
    private final IExplanation explanation;
    private double score;
    // -1 until a rank is assigned with setRank
    private int rank = -1;

    private TranslationCandidate(Term term, AlignmentMethod method, double score, IExplanation explanation) {
        this.term = term;
        this.method = method;
        this.score = score;
        this.explanation = explanation;
    }

    public Term getTerm() {
        return term;
    }

    public AlignmentMethod getMethod() {
        return method;
    }

    public IExplanation getExplanation() {
        return explanation;
    }

    public double getScore() {
        return score;
    }

    public void setScore(double score) {
        this.score = score;
    }

    public int getRank() {
        return rank;
    }

    public void setRank(int rank) {
        this.rank = rank;
    }

    @Override
    public int compareTo(TranslationCandidate o) {
        // Arguments swapped on the score so that higher scores come first
        return ComparisonChain.start()
                .compare(o.score, this.score)
                .compare(this.term, o.term)
                .result();
    }

    @Override
    public boolean equals(Object obj) {
        if (!(obj instanceof TranslationCandidate)) {
            return false;
        }
        TranslationCandidate other = (TranslationCandidate) obj;
        return Objects.equal(other.score, this.score)
                && Objects.equal(other.term, this.term);
    }

    @Override
    public int hashCode() {
        return Objects.hashCode(term, score);
    }

    @Override
    public String toString() {
        String formattedScore = String.format("%.2f", this.score);
        return MoreObjects.toStringHelper(this)
                .addValue(this.term.getGroupingKey())
                .addValue(this.method.toString())
                .add("s", formattedScore)
                .toString();
    }
}
/**
 * @return the bilingual dictionary this aligner was created with
 */
public BilingualDictionary getDico() {
    return dico;
}
}