package doser.entitydisambiguation.algorithms.collective.general;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;

import org.apache.commons.collections15.Factory;
import org.apache.commons.collections15.functors.MapTransformer;
import org.apache.commons.math.stat.descriptive.SummaryStatistics;

import doser.entitydisambiguation.algorithms.SurfaceForm;
import doser.entitydisambiguation.algorithms.collective.AbstractWord2VecPageRank;
import doser.entitydisambiguation.algorithms.collective.Edge;
import doser.entitydisambiguation.algorithms.collective.Vertex;
import doser.entitydisambiguation.knowledgebases.AbstractEntityCentricKBGeneral;
import edu.uci.ics.jung.algorithms.scoring.PageRankWithPriors;
import edu.uci.ics.jung.graph.DirectedSparseMultigraph;

/**
 * PageRank-based collective disambiguator.
 *
 * <p>
 * {@link #setup()} builds the candidate graph, {@link #performPageRank()}
 * scores its vertices and {@link #analyzeResults(PageRankWithPriors)} either
 * fixes a surface form to its best-scoring candidate or narrows the candidate
 * list kept on a cloned copy of the input surface forms. The clones are what
 * {@link #getRepresentation()} returns.
 */
class Word2VecDisambiguatorGeneral extends AbstractWord2VecPageRank {

    /** Clones of the input surface forms; filled in {@link #setup()}. */
    private final List<SurfaceForm> origList;

    /** When false, surface forms are never fixed; only candidate lists are narrowed. */
    private final boolean disambiguate;

    /** Upper bound on the number of candidates kept per surface form. */
    private final int maximumcandidatespersf;

    /** Maximum number of PageRank iterations. */
    private final int iterations;

    /**
     * @param eckb
     *            knowledge base handed to the superclass
     * @param rep
     *            surface forms to process, handed to the superclass
     * @param disambiguate
     *            whether a clear winner may be set as disambiguated entity
     * @param maximumcandidatespersf
     *            maximum number of ranked candidates kept per surface form
     * @param iterations
     *            maximum number of PageRank iterations
     */
    Word2VecDisambiguatorGeneral(AbstractEntityCentricKBGeneral eckb,
            List<SurfaceForm> rep, boolean disambiguate,
            int maximumcandidatespersf, int iterations) {
        super(eckb, rep);
        this.origList = new ArrayList<SurfaceForm>();
        this.disambiguate = disambiguate;
        this.maximumcandidatespersf = maximumcandidatespersf;
        this.iterations = iterations;
    }

    /**
     * Initialises graph, edge weights and edge id factory, clones every
     * surface form into the result list, marks surface forms with at most one
     * candidate as already disambiguated and finally builds the main graph.
     */
    @Override
    public void setup() {
        this.graph = new DirectedSparseMultigraph<Vertex, Edge>();
        this.edgeWeights = new HashMap<Edge, Number>();
        // Hands out consecutive integers starting at 0.
        this.edgeFactory = new Factory<Integer>() {
            int i = 0;

            public Integer create() {
                return i++;
            }
        };

        for (SurfaceForm sf : repList) {
            SurfaceForm clone = (SurfaceForm) sf.clone();
            this.origList.add(clone);
        }

        this.disambiguatedSurfaceForms = new BitSet(repList.size());
        for (int i = 0; i < repList.size(); i++) {
            if (repList.get(i).getCandidates().size() <= 1) {
                this.disambiguatedSurfaceForms.set(i);
            }
        }
        buildMainGraph();
    }

    /**
     * Runs PageRank with priors over the current graph using the stored edge
     * weights and the root prior supplied by the superclass.
     *
     * @return the evaluated PageRank instance
     */
    @Override
    protected PageRankWithPriors<Vertex, Edge> performPageRank() {
        // NOTE(review): 0.09 is passed as the fourth constructor argument
        // (alpha in JUNG's PageRankWithPriors) - confirm the intended value.
        PageRankWithPriors<Vertex, Edge> pr = new PageRankWithPriors<Vertex, Edge>(
                graph, MapTransformer.getInstance(edgeWeights),
                getRootPrior(graph.getVertices()), 0.09);
        pr.setMaxIterations(iterations);
        pr.evaluate();
        return pr;
    }

    /**
     * Inspects the PageRank scores of every relevant, not yet disambiguated
     * surface form.
     *
     * <p>
     * If disambiguation is enabled and the runner-up score lies below the
     * threshold, the best candidate is fixed, the graph is updated and the
     * scan stops immediately so that the caller can re-run PageRank on the
     * changed graph. Otherwise the clone's candidate list is replaced by the
     * top-ranked candidates (at most {@code maximumcandidatespersf}).
     *
     * <p>
     * Surface forms without any candidate vertex in the graph are skipped. A
     * surface form with exactly one candidate vertex has no runner-up and is
     * treated as unambiguous.
     *
     * @param pr
     *            evaluated PageRank instance
     * @return {@code true} if no surface form was disambiguated in this pass,
     *         {@code false} if the graph was changed
     */
    @Override
    public boolean analyzeResults(PageRankWithPriors<Vertex, Edge> pr) {
        boolean disambiguationStop = true;
        Collection<Vertex> vertexCol = graph.getVertices();
        for (int i = 0; i < repList.size(); i++) {
            SurfaceForm rep = repList.get(i);
            if (disambiguatedSurfaceForms.get(i) || !rep.isRelevant()) {
                continue;
            }

            int qryNr = rep.getQueryNr();
            double maxScore = 0;
            SummaryStatistics stats = new SummaryStatistics();
            String tempSolution = "";
            List<Candidate> scores = new ArrayList<Candidate>();
            for (Vertex v : vertexCol) {
                if (v.getEntityQuery() == qryNr && v.isCandidate()) {
                    String uri = v.getUris().get(0);
                    double raw = pr.getVertexScore(v);
                    scores.add(new Candidate(uri, raw));
                    double score = Math.abs(raw);
                    stats.addValue(score);
                    if (score > maxScore) {
                        tempSolution = uri;
                        maxScore = score;
                    }
                }
            }

            // Nothing to rank: previously this ran into scores.get(1) and
            // failed with an IndexOutOfBoundsException.
            if (scores.isEmpty()) {
                continue;
            }

            SurfaceForm clone = origList.get(i);
            Collections.sort(scores, Collections.reverseOrder());

            // A lone candidate has no runner-up; negative infinity keeps it
            // below any threshold.
            double secondMax = scores.size() > 1 ? scores.get(1).score
                    : Double.NEGATIVE_INFINITY;

            int keep = Math.min(maximumcandidatespersf, scores.size());
            List<String> newCandidates = new ArrayList<String>();
            for (int j = 0; j < keep; j++) {
                newCandidates.add(scores.get(j).can);
            }

            if (!Double.isInfinite(maxScore)) {
                double avg = stats.getMean();
                double threshold = computeThreshold(avg, maxScore);
                // tempSolution stays empty when no score exceeded 0; never
                // fix a surface form to that placeholder.
                boolean hasWinner = tempSolution.length() > 0;
                if (hasWinner && secondMax < threshold && disambiguate) {
                    updateGraph(rep.getCandidates(), tempSolution, qryNr);
                    rep.setDisambiguatedEntity(tempSolution);
                    clone.setDisambiguatedEntity(tempSolution);
                    disambiguatedSurfaceForms.set(i);
                    disambiguationStop = false;
                    // The graph changed, so the remaining scores are stale.
                    break;
                } else {
                    clone.setCandidates(newCandidates);
                }
            }
        }
        return disambiguationStop;
    }

    /**
     * Threshold computation - IMPORTANT DISAMBIGUATION PARAMETER.
     *
     * <p>
     * The threshold is the midpoint between the average and the highest
     * score: {@code highest - 0.5 * (highest - avg)}.
     *
     * @param avg
     *            mean of the absolute candidate scores
     * @param highest
     *            highest absolute candidate score
     * @return score the runner-up must stay below for the best candidate to
     *         be accepted
     */
    private double computeThreshold(double avg, double highest) {
        double diff = highest - avg;
        double min = diff * 0.5;
        return highest - min;
    }

    /**
     * @return the cloned surface forms carrying the disambiguation results
     */
    @Override
    public List<SurfaceForm> getRepresentation() {
        return this.origList;
    }

    /**
     * Candidate identifier paired with its PageRank score; the natural order
     * is ascending by score.
     */
    class Candidate implements Comparable<Candidate> {

        private final double score;
        private final String can;

        Candidate(String can, double score) {
            super();
            this.score = score;
            this.can = can;
        }

        /**
         * Orders by score using {@link Double#compare(double, double)}, which
         * gives a total order even when a score is NaN.
         */
        @Override
        public int compareTo(Candidate o) {
            return Double.compare(score, o.score);
        }
    }
}