package doser.entitydisambiguation.algorithms.collective; import java.util.ArrayList; import java.util.BitSet; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Set; import org.apache.commons.collections15.Factory; import org.apache.commons.collections15.Transformer; import org.apache.commons.collections15.functors.MapTransformer; import doser.entitydisambiguation.algorithms.SurfaceForm; import doser.entitydisambiguation.knowledgebases.AbstractEntityCentricKBGeneral; import edu.uci.ics.jung.algorithms.scoring.PageRankWithPriors; import edu.uci.ics.jung.graph.DirectedGraph; import edu.uci.ics.jung.graph.DirectedSparseMultigraph; public abstract class AbstractWord2VecPageRank { protected AbstractEntityCentricKBGeneral eckb; protected Map<Edge, Number> edgeWeights; protected DirectedGraph<Vertex, Edge> graph; protected Factory<Integer> edgeFactory; protected BitSet disambiguatedSurfaceForms; protected List<SurfaceForm> allCandidates; protected List<SurfaceForm> repList; public AbstractWord2VecPageRank(AbstractEntityCentricKBGeneral featureDefinition, List<SurfaceForm> rep) { super(); this.eckb = featureDefinition; this.repList = rep; } public void solve() { while (true) { PageRankWithPriors<Vertex, Edge> pr = performPageRank(); if (analyzeResults(pr)) { break; } } } protected PageRankWithPriors<Vertex, Edge> performPageRank() { PageRankWithPriors<Vertex, Edge> pr = new PageRankWithPriors<Vertex, Edge>(graph, MapTransformer.getInstance(edgeWeights), getRootPrior(graph.getVertices()), 0.15); pr.setMaxIterations(200); pr.evaluate(); return pr; } public void setup() { this.graph = new DirectedSparseMultigraph<Vertex, Edge>(); this.edgeWeights = new HashMap<Edge, Number>(); this.edgeFactory = new Factory<Integer>() { int i = 0; public Integer create() { return i++; } }; List<SurfaceForm> list = new LinkedList<SurfaceForm>(); for (SurfaceForm r : this.repList) { list.add((SurfaceForm) r.clone()); } Collections.sort(list); this.repList = list; this.disambiguatedSurfaceForms = new BitSet(repList.size()); for (int i = 0; i < repList.size(); i++) { if (repList.get(i).getCandidates().size() <= 1) { this.disambiguatedSurfaceForms.set(i); } } buildMainGraph(); } protected void buildMainGraph() { List<String> disambiguatedEntities = new LinkedList<String>(); // Add Vertexes for (SurfaceForm rep : repList) { List<String> arrList = rep.getCandidates(); for (String s : arrList) { int occs = eckb.getFeatureDefinition().getOccurrences(rep.getSurfaceForm(), s); List<String> l = new LinkedList<String>(); l.add(s); if (rep.getCandidates().size() == 1) { disambiguatedEntities.add(rep.getCandidates().get(0)); addVertex(l, rep.getSurfaceForm(), rep.getQueryNr(), true, 20000, rep.getContext()); } else { addVertex(l, rep.getSurfaceForm(), rep.getQueryNr(), true, occs, rep.getContext()); } } } // Add Document AsVertex addVertex(disambiguatedEntities, "", -1, true, 50000, ""); // Add Edges List<Vertex> vertexList = new ArrayList<Vertex>(graph.getVertices()); // Create Word2Vec Queries Set<String> w2vFormatStrings = new HashSet<String>(); for (Vertex v1 : vertexList) { for (Vertex v2 : vertexList) { if (!v1.equals(v2) && !areCandidatesofSameSF(v1, v2)) { List<String> l1 = v1.getUris(); List<String> l2 = v2.getUris(); if (l1.size() == 1 && l2.size() == 1) { String format = this.eckb.generateWord2VecFormatString(l1.get(0), l2.get(0)); w2vFormatStrings.add(format); } else if (l1.size() > l2.size() && l1.size() > 0 && l2.size() > 0) { String format = this.eckb.generateWord2VecFormatString(l1, l2.get(0)); w2vFormatStrings.add(format); } else if (l2.size() > l1.size() && l1.size() > 0 && l2.size() > 0) { String format = this.eckb.generateWord2VecFormatString(l2, l1.get(0)); w2vFormatStrings.add(format); } } } } Map<String, Float> similarityMap = this.eckb.getWord2VecSimilarities(w2vFormatStrings); for (Vertex v1 : vertexList) { for (Vertex v2 : vertexList) { if (!v1.equals(v2) && !areCandidatesofSameSF(v1, v2)) { List<String> l1 = v1.getUris(); List<String> l2 = v2.getUris(); if (l1.size() == 1 && l2.size() == 1) { double weight = similarityMap.get(this.eckb.generateWord2VecFormatString(l1.get(0), l2.get(0))); if (weight < 0.00000001) { System.out.println(weight + " " + l1.get(0) + " " + l2.get(0)); } // Add Doc2Vec Local Compatibility // First experiment: Harmonic mean // double localComp = super.getDoc2VecSimilarity( // v2.getText(), v2.getContext(), l2.get(0)); // double hm = 2 * (localComp * weight) // / (localComp + weight); // System.out.println(l1.get(0) + " "+l2.get(0) // +" Connection: "+ weight+ " Localcomp: "+ localComp // + "HarmonicMean: "+ hm); addEdge(v1, v2, edgeFactory.create(), weight); } else if (l1.size() > l2.size() && l1.size() > 0 && l2.size() > 0) { double weight = similarityMap.get(this.eckb.generateWord2VecFormatString(l1, l2.get(0))); addEdge(v1, v2, edgeFactory.create(), weight); } else if (l2.size() > l1.size() && l1.size() > 0 && l2.size() > 0) { double weight = similarityMap.get(this.eckb.generateWord2VecFormatString(l2, l1.get(0))); addEdge(v1, v2, edgeFactory.create(), weight); } } } } // Set Edge Probabilities Collection<Vertex> vertexes = graph.getVertices(); for (Vertex v : vertexes) { Set<Edge> edges = v.getOutgoingEdges(); for (Edge e : edges) { // System.out.println("From: " + v.getUris().get(0) + // " To: "+e.getTarget().getUris().get(0)+ // " Probability: "+e.getProbability()); edgeWeights.put(e, e.getProbability()); } } } protected void addVertex(List<String> uri, String sf, int qryNr, boolean isCandidate, int occurrences, String context) { Vertex v = new Vertex(); for (String u : uri) { v.addUri(u); } v.setCandidate(isCandidate); v.setText(sf); v.setEntityQuery(qryNr); v.setOccurrences(occurrences); v.setContext(context); graph.addVertex(v); } protected void addEdge(Vertex out, Vertex in, int edgeNr, double transition) { Edge edge = new Edge(edgeNr, in, transition); out.addOutGoingEdge(edge); graph.addEdge(edge, out, in); } protected void removeVertex(Vertex rem) { Set<Edge> outs = rem.getOutgoingEdges(); for (Edge e : outs) { removeEdge(e); } rem.removeAllOutgoingEdges(); Collection<Vertex> n = graph.getNeighbors(rem); // BugFix if (n != null) { for (Vertex v : n) { removeEdge(v, rem); } } graph.removeVertex(rem); } protected void updateGraph(List<String> candidates, String disambiguatedEntity, int entityQry) { Collection<Vertex> vertexCol = graph.getVertices(); List<Vertex> relVertexes = new ArrayList<Vertex>(); for (Vertex v : vertexCol) { if (v.getEntityQuery() == entityQry) { relVertexes.add(v); } } for (String s : candidates) { if (!s.equalsIgnoreCase(disambiguatedEntity)) { for (Vertex v : relVertexes) { if (v.getUris().get(0).equalsIgnoreCase(s)) { removeVertex(v); } } } } } public List<SurfaceForm> getRepresentation() { return this.repList; } /** * Assigns a probability of 1/<code>roots.size()</code> to each of the * elements of <code>roots</code>. * * @param <V> * the vertex type * @param roots * the vertices to be assigned nonzero prior probabilities * @return */ protected Transformer<Vertex, Double> getRootPrior(Collection<Vertex> roots) { final Collection<Vertex> inner_roots = roots; double sum = 0; for (Vertex v : inner_roots) { sum += v.getOccurrences(); } final double overallOccs = sum; Transformer<Vertex, Double> distribution = new Transformer<Vertex, Double>() { public Double transform(Vertex input) { if (inner_roots.contains(input)) { double d = new Double(input.getOccurrences() / (double) overallOccs); return d; } else { return 0.0; } } }; return distribution; } protected List<SurfaceForm> getCollectiveSFRepresentations() { return this.repList; } // protected List<String> computeSensePriorRankedList(int qryNr, int bestOf) { // List<Candidate> canList = new LinkedList<Candidate>(); // Collection<Vertex> vertexCol = graph.getVertices(); // for (Vertex c : vertexCol) { // if (c.getEntityQuery() == qryNr && c.isCandidate()) { // canList.add(new Candidate(c.getUris().get(0), c.getOccurrences())); // } // } // Collections.sort(canList, Collections.reverseOrder()); // List<String> strList = new LinkedList<String>(); // for (Candidate c : canList.subList(0, bestOf)) { // strList.add(c.candidate); // } // return strList; // } private void removeEdge(Edge e) { graph.removeEdge(e); edgeWeights.remove(e); } private void removeEdge(Vertex out, Vertex in) { Edge e = out.removeOutgoingEdge(in, edgeWeights); if (e != null) { graph.removeEdge(e); edgeWeights.remove(e); } } protected boolean areCandidatesofSameSF(Vertex v1, Vertex v2) { int qryNr1 = v1.getEntityQuery(); int qryNr2 = v2.getEntityQuery(); if (qryNr1 == -1 || qryNr2 == -1 || v1.getEntityQuery() != v2.getEntityQuery()) { return false; } return true; } public abstract boolean analyzeResults(PageRankWithPriors<Vertex, Edge> pr); protected class Candidate implements Comparable<Candidate> { private String candidate; private double score; protected Candidate(String candidate, double score) { super(); this.candidate = candidate; this.score = score; } @Override public int compareTo(Candidate o) { if (this.score < o.score) { return -1; } else if (this.score > o.score) { return 1; } else { return 0; } } protected String getCandidate() { return candidate; } protected double getScore() { return score; } } }