package doser.entitydisambiguation.algorithms.collective.general; import java.util.ArrayList; import java.util.BitSet; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Set; import org.apache.commons.collections15.Factory; import org.apache.commons.collections15.functors.MapTransformer; import org.apache.commons.math.stat.descriptive.SummaryStatistics; import doser.entitydisambiguation.algorithms.SurfaceForm; import doser.entitydisambiguation.algorithms.collective.AbstractWord2VecPageRank; import doser.entitydisambiguation.algorithms.collective.Edge; import doser.entitydisambiguation.algorithms.collective.Vertex; import doser.entitydisambiguation.knowledgebases.AbstractEntityCentricKBGeneral; import edu.uci.ics.jung.algorithms.scoring.PageRankWithPriors; import edu.uci.ics.jung.graph.DirectedSparseMultigraph; class FinalEntityDisambiguation extends AbstractWord2VecPageRank { private static final int PREPROCESSINGCONTEXTSIZE = 500; public FinalEntityDisambiguation(AbstractEntityCentricKBGeneral eckb, List<SurfaceForm> rep) { super(eckb, rep); } @Override public void setup() { this.graph = new DirectedSparseMultigraph<Vertex, Edge>(); this.edgeWeights = new HashMap<Edge, Number>(); this.edgeFactory = new Factory<Integer>() { int i = 0; public Integer create() { return i++; } }; this.disambiguatedSurfaceForms = new BitSet(repList.size()); for (int i = 0; i < repList.size(); i++) { if (repList.get(i).getCandidates().size() <= 1) { this.disambiguatedSurfaceForms.set(i); } } buildMainGraph(); } @Override protected void buildMainGraph() { List<String> disambiguatedEntities = new LinkedList<String>(); // Add Vertexes for (SurfaceForm rep : repList) { List<String> arrList = rep.getCandidates(); for (String s : arrList) { int occs = this.eckb.getFeatureDefinition().getOccurrences( rep.getSurfaceForm(), s); List<String> l = new LinkedList<String>(); l.add(s); if (rep.getCandidates().size() == 1) { disambiguatedEntities.add(rep.getCandidates().get(0)); addVertex(l, rep.getSurfaceForm(), rep.getQueryNr(), true, 20000, rep.getContext()); } else { addVertex(l, rep.getSurfaceForm(), rep.getQueryNr(), true, occs, rep.getContext()); } } } // Add Document AsVertex addVertex(disambiguatedEntities, "", -1, true, 50000, ""); // Add Edges List<Vertex> vertexList = new ArrayList<Vertex>(graph.getVertices()); // Create Word2Vec Queries Set<String> w2vFormatStrings = new HashSet<String>(); for (Vertex v1 : vertexList) { for (Vertex v2 : vertexList) { if (!v1.equals(v2) && !areCandidatesofSameSF(v1, v2)) { List<String> l1 = v1.getUris(); List<String> l2 = v2.getUris(); if (l1.size() == 1 && l2.size() == 1) { String format = this.eckb.generateWord2VecFormatString( l1.get(0), l2.get(0)); w2vFormatStrings.add(format); } } } } Map<String, Float> similarityMap = this.eckb .getWord2VecSimilarities(w2vFormatStrings); for (Vertex v1 : vertexList) { for (Vertex v2 : vertexList) { if (!v1.equals(v2) && !areCandidatesofSameSF(v1, v2)) { List<String> l1 = v1.getUris(); List<String> l2 = v2.getUris(); if (l1.size() == 1 && l2.size() == 1) { double weight = similarityMap.get(this.eckb .generateWord2VecFormatString(l1.get(0), l2.get(0))); // Add Doc2Vec Local Compatibility // First experiment: Harmonic mean // double localComp = (0.8*this.d2v.getDoc2VecSimilarity( // v2.getText(), v2.getContext(), l2.get(0))); addEdge(v1, v2, edgeFactory.create(), weight); } } } } // Set Edge Probabilities Collection<Vertex> vertexes = graph.getVertices(); for (Vertex v : vertexes) { Set<Edge> edges = v.getOutgoingEdges(); for (Edge e : edges) { edgeWeights.put(e, e.getProbability()); } } } @Override protected PageRankWithPriors<Vertex, Edge> performPageRank() { PageRankWithPriors<Vertex, Edge> pr = new PageRankWithPriors<Vertex, Edge>( graph, MapTransformer.getInstance(edgeWeights), getRootPrior(graph.getVertices()), 0.13); pr.setMaxIterations(250); pr.evaluate(); return pr; } @Override public boolean analyzeResults(PageRankWithPriors<Vertex, Edge> pr) { boolean disambiguationStop = true; Collection<Vertex> vertexCol = graph.getVertices(); for (int i = 0; i < repList.size(); i++) { if (!disambiguatedSurfaceForms.get(i)) { int qryNr = repList.get(i).getQueryNr(); double maxScore = 0; SummaryStatistics stats = new SummaryStatistics(); String tempSolution = ""; List<Candidate> scores = new ArrayList<Candidate>(); for (Vertex v : vertexCol) { if (v.getEntityQuery() == qryNr && v.isCandidate()) { scores.add(new Candidate(pr.getVertexScore(v))); double score = Math.abs(pr.getVertexScore(v)); stats.addValue(score); System.out.println("Score for: "+v.getUris().get(0)+" : "+score); if (score > maxScore) { tempSolution = v.getUris().get(0); maxScore = score; } } } SurfaceForm rep = repList.get(i); Collections.sort(scores, Collections.reverseOrder()); double secondMax = scores.get(1).score; if (!Double.isInfinite(maxScore)) { double avg = stats.getMean(); double threshold = computeThreshold(avg, maxScore); // if (secondMax < threshold) { updateGraph(rep.getCandidates(), tempSolution, rep.getQueryNr()); rep.setDisambiguatedEntity(tempSolution); System.out.println("Ich disambiguiere: "+tempSolution); disambiguatedSurfaceForms.set(i); disambiguationStop = false; break; } // } } } return disambiguationStop; } /** * Threshold Computation // IMPORTANT DISAMBIGUATION PARAMETER * * @param avg * @param highest * @return */ private double computeThreshold(double avg, double highest) { double diff = highest - avg; double min = diff * 0.25; return highest - min; } class Candidate implements Comparable<Candidate> { private double score; Candidate(double score) { super(); this.score = score; } @Override public int compareTo(Candidate o) { if (score < o.score) { return -1; } else if (score > o.score) { return 1; } else { return 0; } } } }