package doser.webclassify.algorithm;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.collections15.Factory;
import org.apache.commons.collections15.functors.MapTransformer;
import org.codehaus.jettison.json.JSONArray;
import org.codehaus.jettison.json.JSONException;
import org.codehaus.jettison.json.JSONObject;
import doser.entitydisambiguation.algorithms.collective.Edge;
import doser.entitydisambiguation.algorithms.collective.Vertex;
import doser.entitydisambiguation.dpo.DisambiguatedEntity;
import doser.language.Languages;
import doser.webclassify.dpo.Paragraph;
import doser.word2vec.Word2VecJsonFormat;
import edu.uci.ics.jung.algorithms.scoring.PageRank;
import edu.uci.ics.jung.graph.DirectedGraph;
import edu.uci.ics.jung.graph.DirectedSparseMultigraph;
/**
 * Entity relevance algorithm that picks the single most significant entity of
 * a paragraph by running PageRank over a graph of the disambiguated entities.
 *
 * <p>
 * Every entity becomes a vertex and every ordered pair of vertices (including
 * a vertex paired with itself) is connected by an edge. Edge weights are
 * derived from pairwise word2vec similarities that are fetched through
 * {@link Word2VecJsonFormat#performquery} with the {@code "w2vsim"} endpoint.
 *
 * <p>
 * NOTE(review): {@link #process} reassigns the instance fields on every call,
 * so a single instance presumably must not be used by several threads at the
 * same time - confirm against the callers.
 */
public class EntitySignificanceAlgorithmPR_W2V implements
        EntityRelevanceAlgorithm {

    /** URI prefix removed from entity URIs before they are used as word2vec keys. */
    private static final String DBPEDIA_RESOURCE_PREFIX = "http://dbpedia.org/resource/";

    /** Separator between the two entity names of a similarity key. */
    private static final String PAIR_SEPARATOR = "|";

    /**
     * Third argument of the JUNG PageRank constructor (alpha; per the JUNG
     * documentation the probability of a random jump).
     */
    private static final double PAGERANK_ALPHA = 0.1;

    /** Upper bound on the number of PageRank iterations. */
    private static final int PAGERANK_MAX_ITERATIONS = 100;

    /** Similarities returned by the word2vec service, keyed by the service's "ents" string. */
    private Map<String, Float> word2vecsimilarities;

    /** Weight of every edge of the most recently built graph. */
    private Map<Edge, Number> edgeWeights;

    /** Produces consecutive edge ids, starting at 0, for the most recently built graph. */
    private Factory<Integer> edgeFactory;

    /**
     * Determines the entity with the highest PageRank score.
     *
     * @param map
     *            disambiguated entities of the paragraph; only the keys are
     *            used, the mapped values are ignored
     * @param p
     *            the paragraph (not used by this implementation)
     * @param lang
     *            the language (not used by this implementation)
     * @return the first URI of the best scoring vertex, the empty string if
     *         {@code map} contains no entities, or {@code null} if no vertex
     *         obtained a score greater than 0
     */
    @Override
    public String process(Map<DisambiguatedEntity, Integer> map, Paragraph p, Languages lang) {
        // The set feeds the similarity query, the list (which keeps
        // duplicate URIs) defines the graph vertices.
        Set<String> entitySet = new HashSet<String>();
        List<String> entities = new LinkedList<String>();
        for (DisambiguatedEntity entity : map.keySet()) {
            String uri = entity.getEntityUri();
            entities.add(uri);
            entitySet.add(uri);
        }
        if (entities.isEmpty()) {
            return "";
        }

        computeWord2VecSimilarities(entitySet);
        DirectedGraph<Vertex, Edge> graph = buildGraph(entities);
        PageRank<Vertex, Edge> pr = new PageRank<Vertex, Edge>(graph,
                MapTransformer.getInstance(edgeWeights), PAGERANK_ALPHA);
        pr.setMaxIterations(PAGERANK_MAX_ITERATIONS);
        pr.evaluate();

        // Strict comparison: on equal scores the vertex seen first wins.
        String topEntity = null;
        double max = 0;
        for (Vertex v : graph.getVertices()) {
            double score = pr.getVertexScore(v);
            if (score > max) {
                topEntity = v.getUris().get(0);
                max = score;
            }
        }
        return topEntity;
    }

    /**
     * Removes the DBpedia resource prefix from an entity URI.
     *
     * <p>
     * Uses the literal {@link String#replace(CharSequence, CharSequence)}
     * rather than the regex based {@code replaceAll}, in which the dots of the
     * prefix would match arbitrary characters.
     *
     * @param uri
     *            the entity URI
     * @return the URI with every occurrence of the prefix removed
     */
    private static String stripResourcePrefix(String uri) {
        return uri.replace(DBPEDIA_RESOURCE_PREFIX, "");
    }

    /**
     * Looks up the edge weight for a pair of entities.
     *
     * <p>
     * The lookup key consists of both prefix-free entity names joined by
     * {@code "|"}, with the name that sorts first (ignoring case) on the left.
     *
     * @param source
     *            URI of the first entity
     * @param target
     *            URI of the second entity
     * @return the stored similarity plus 1, or 0 if no similarity is stored
     *         for the pair
     */
    private float getWord2VecSimilarity(String source, String target) {
        String first = stripResourcePrefix(source);
        String second = stripResourcePrefix(target);
        String key;
        if (first.compareToIgnoreCase(second) <= 0) {
            key = first + PAIR_SEPARATOR + second;
        } else {
            key = second + PAIR_SEPARATOR + first;
        }
        Float similarity = this.word2vecsimilarities.get(key);
        return similarity == null ? 0f : similarity + 1.0f;
    }

    /**
     * Queries the word2vec service for the similarities of all ordered entity
     * pairs and stores the answers in {@link #word2vecsimilarities}.
     *
     * <p>
     * Result entries that lack a readable "ents" or "sim" field are reported
     * on the error stream and skipped; the remaining entries are still
     * stored.
     *
     * @param entities
     *            URIs of the entities to compare
     */
    private void computeWord2VecSimilarities(Set<String> entities) {
        this.word2vecsimilarities = new HashMap<String, Float>();

        // Strip the prefix once per entity instead of twice per pair.
        Set<String> names = new HashSet<String>();
        for (String entity : entities) {
            names.add(stripResourcePrefix(entity));
        }
        Set<String> combinations = new HashSet<String>();
        for (String left : names) {
            for (String right : names) {
                combinations.add(left + PAIR_SEPARATOR + right);
            }
        }

        Word2VecJsonFormat format = new Word2VecJsonFormat();
        format.setData(combinations);
        JSONArray res = Word2VecJsonFormat.performquery(format, "w2vsim");
        for (int i = 0; i < res.length(); i++) {
            try {
                JSONObject obj = res.getJSONObject(i);
                String ents = obj.getString("ents");
                float sim = (float) obj.getDouble("sim");
                this.word2vecsimilarities.put(ents, sim);
            } catch (JSONException e) {
                // Best effort: a malformed entry must not discard the others.
                e.printStackTrace();
            }
        }
    }

    /**
     * Builds the graph PageRank is evaluated on and fills {@link #edgeWeights}.
     *
     * <p>
     * One vertex is created per list element. Each vertex gets an outgoing
     * edge to every vertex of the graph, itself included.
     *
     * @param entities
     *            URIs of the entities, one vertex is created per element
     * @return the populated graph
     */
    private DirectedGraph<Vertex, Edge> buildGraph(List<String> entities) {
        this.edgeWeights = new HashMap<Edge, Number>();
        this.edgeFactory = new Factory<Integer>() {
            private int nextId = 0;

            @Override
            public Integer create() {
                return nextId++;
            }
        };

        DirectedGraph<Vertex, Edge> graph = new DirectedSparseMultigraph<Vertex, Edge>();
        for (String uri : entities) {
            Vertex vertex = new Vertex();
            vertex.addUri(uri);
            graph.addVertex(vertex);
        }

        Collection<Vertex> vertexes = graph.getVertices();
        for (Vertex from : vertexes) {
            for (Vertex to : vertexes) {
                float similarity = this.getWord2VecSimilarity(
                        from.getUris().get(0), to.getUris().get(0));
                Edge edge = new Edge(this.edgeFactory.create(), to, similarity);
                from.addOutGoingEdge(edge);
                graph.addEdge(edge, from, to);
            }
        }

        // Weights are read back from the vertices' outgoing edges; presumably
        // getProbability() yields the similarity handed to the Edge
        // constructor - verify against Edge.
        for (Vertex vertex : graph.getVertices()) {
            for (Edge edge : vertex.getOutgoingEdges()) {
                edgeWeights.put(edge, edge.getProbability());
            }
        }
        return graph;
    }
}