package com.personalityextractor.entity.resolver;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import com.personalityextractor.data.source.Wikiminer;
import com.personalityextractor.entity.Entity;
import com.personalityextractor.entity.WikipediaEntity;
import com.personalityextractor.entity.extractor.EntityExtractFactory;
import com.personalityextractor.entity.extractor.EntityExtractFactory.Extracter;
import com.personalityextractor.entity.extractor.IEntityExtractor;
import com.personalityextractor.evaluation.PerfMetrics;
import com.personalityextractor.evaluation.PerfMetrics.Metric;
import com.personalityextractor.store.LuceneStore;
import com.personalityextractor.store.WikiminerDB;

import cs224n.util.CounterMap;
/**
 * Resolves surface-form entity mentions (for example noun phrases extracted from a tweet)
 * to Wikipedia senses with a Viterbi search.
 *
 * <p>The mentions are arranged as a chain
 * {@code start_node -> m1 -> ... -> mn -> end_node}. The candidate senses of each mention
 * are the results of {@link LuceneStore#search} plus one {@code void_node} candidate. A
 * transition from sense {@code a} to sense {@code b} adds
 * {@code log(db.compare(a, b)) + commonness(a)} to the path score. The highest-scoring
 * path through the chain determines the returned senses.
 */
public class ViterbiResolver extends BaseEntityResolver {

    /** Text of the pseudo-node that begins every chain. */
    private static final String START_NODE = "start_node";
    /** Wikiminer ID of the start pseudo-node. */
    private static final String START_NODE_ID = "-1";
    /** Text of the pseudo-node that ends every chain. */
    private static final String END_NODE = "end_node";
    /** Wikiminer ID of the end pseudo-node. */
    private static final String END_NODE_ID = "-2";
    /** Text of the extra candidate appended to every mention's senses (presumably "no sense"). */
    private static final String VOID_NODE = "void_node";
    /** Wikiminer ID of the void candidate. */
    private static final String VOID_NODE_ID = "0";
    /** Commonness string given to the pseudo-nodes. */
    private static final String PSEUDO_COMMONNESS = "0.0000001";
    /** Placeholder relatedness stored for edges that touch a pseudo-node. */
    private static final double PSEUDO_SCORE = 0.0000001;
    /** Log-score used when the store reports zero (or an invalid) relatedness. */
    private static final double MIN_LOG_SCORE = Math.log(0.0000001);

    /** Shared store used for sense lookup ({@code search}) and relatedness ({@code compare}). */
    private static LuceneStore db;

    static {
        try {
            db = LuceneStore.getInstance();
            db.loadIndices();
        } catch (Exception e) {
            // Failure is only reported here; db may be null or only partially loaded, so
            // later lookups are likely to fail.
            e.printStackTrace();
        }
    }

    public ViterbiResolver() {
    }

    /** Best score found so far for one sense, together with the comma-joined ID path reaching it. */
    private static final class PathState {
        final double score;
        final String path;

        PathState(double score, String path) {
            this.score = score;
            this.path = path;
        }
    }

    /**
     * Pre-computes relatedness scores between the senses of every pair of chain positions.
     * Not called anywhere in this class at present; kept for the disabled pre-computation
     * step of {@link #resolve}.
     *
     * <p>Index 0 of {@code twEntities} is expected to be {@code start_node}. Index 1 is the
     * first real mention, so the list must hold at least two entries. Edges that touch a
     * pseudo-node target get {@link #PSEUDO_SCORE} instead of a store lookup.
     *
     * @param twEntities chain of mention keys, starting with {@code start_node}
     * @param tweetEntityTowikiEntities candidate senses per mention key
     * @return symmetric map of sense-ID pairs to scores
     */
    private CounterMap<String, String> populateCompareScores(
            List<String> twEntities,
            HashMap<String, ArrayList<WikipediaEntity>> tweetEntityTowikiEntities) {
        CounterMap<String, String> scores = new CounterMap<String, String>();
        // Link the start node with every sense of the first real mention.
        for (WikipediaEntity first : tweetEntityTowikiEntities.get(twEntities.get(1))) {
            scores.setCount(START_NODE_ID, first.getWikiminerID(), PSEUDO_SCORE);
            scores.setCount(first.getWikiminerID(), START_NODE_ID, PSEUDO_SCORE);
        }
        for (int i = 1; i < twEntities.size(); i++) {
            ArrayList<WikipediaEntity> sources = tweetEntityTowikiEntities.get(twEntities.get(i));
            for (WikipediaEntity source : sources) {
                String sourceId = source.getWikiminerID();
                // Score this sense against every sense of every later chain position.
                for (int k = i + 1; k < twEntities.size(); k++) {
                    for (WikipediaEntity target : tweetEntityTowikiEntities.get(twEntities.get(k))) {
                        String targetId = target.getWikiminerID();
                        double score = isPseudoNode(target)
                                ? PSEUDO_SCORE
                                : db.compare(sourceId, targetId);
                        scores.setCount(sourceId, targetId, score);
                        scores.setCount(targetId, sourceId, score);
                    }
                }
            }
        }
        return scores;
    }

    /** Returns true when {@code we} is one of the start/end/void pseudo-nodes (by text, case-insensitive). */
    private static boolean isPseudoNode(WikipediaEntity we) {
        String text = we.getText();
        return text.equalsIgnoreCase(VOID_NODE)
                || text.equalsIgnoreCase(START_NODE)
                || text.equalsIgnoreCase(END_NODE);
    }

    /**
     * Builds a reverse index from Wikiminer sense ID to the mention key that produced it.
     * Assumes sense IDs are unique across mentions; on a clash the last mention visited wins.
     *
     * @param tweetEntityTowikiEntities candidate senses per mention key
     * @return map of sense ID to mention key
     */
    private HashMap<String, String> buildwikiIDToTweetEntityMap(
            HashMap<String, ArrayList<WikipediaEntity>> tweetEntityTowikiEntities) {
        HashMap<String, String> wikiIDToTweetEntity = new HashMap<String, String>();
        for (Map.Entry<String, ArrayList<WikipediaEntity>> entry
                : tweetEntityTowikiEntities.entrySet()) {
            for (WikipediaEntity we : entry.getValue()) {
                wikiIDToTweetEntity.put(we.getWikiminerID(), entry.getKey());
            }
        }
        return wikiIDToTweetEntity;
    }

    /**
     * Looks up candidate senses for every mention.
     *
     * <p>The result also contains single-sense entries for {@code start_node} and
     * {@code end_node}. Every mention's list is the store's search results followed by one
     * {@code void_node} candidate, so no list is ever empty.
     *
     * @param entities mention strings
     * @return candidate senses keyed by mention string (plus the two pseudo-node keys)
     */
    private HashMap<String, ArrayList<WikipediaEntity>> getWikiSenses(List<String> entities) {
        HashMap<String, ArrayList<WikipediaEntity>> senses =
                new HashMap<String, ArrayList<WikipediaEntity>>();

        ArrayList<WikipediaEntity> start = new ArrayList<WikipediaEntity>();
        start.add(new WikipediaEntity(START_NODE, START_NODE_ID, -1, PSEUDO_COMMONNESS));
        senses.put(START_NODE, start);

        ArrayList<WikipediaEntity> end = new ArrayList<WikipediaEntity>();
        end.add(new WikipediaEntity(END_NODE, END_NODE_ID, -1, PSEUDO_COMMONNESS));
        senses.put(END_NODE, end);

        for (String entity : entities) {
            ArrayList<WikipediaEntity> candidates = new ArrayList<WikipediaEntity>();
            candidates.addAll(db.search(entity));
            candidates.add(new WikipediaEntity(VOID_NODE, VOID_NODE_ID, -1, PSEUDO_COMMONNESS));
            senses.put(entity, candidates);
        }
        return senses;
    }

    /**
     * Returns a copy of {@code l} with position {@code p} exchanged with {@code p + 1}.
     * When {@code p} is the last index, it is exchanged with position 0 instead.
     *
     * @param l source list (not modified)
     * @param p position to swap
     * @return new mutable list with the two positions exchanged
     */
    private static List<String> swap(List<String> l, int p) {
        int x = p;
        int y = (p == l.size() - 1) ? 0 : p + 1;
        List<String> swapped = new ArrayList<String>();
        for (int i = 0; i < l.size(); i++) {
            if (i == x) {
                swapped.add(l.get(y));
            } else if (i == y) {
                swapped.add(l.get(x));
            } else {
                swapped.add(l.get(i));
            }
        }
        return swapped;
    }

    /**
     * Resolves each mention in {@code entities} to a Wikipedia sense.
     *
     * <p>The caller's list is never modified. A single mention resolves to its first
     * candidate: the first search hit, or the {@code void_node} candidate when the search
     * returned nothing. With two or more mentions, the first two are swapped before the
     * chain is built, so the result follows that swapped order.
     *
     * <p>The elapsed time is recorded under {@link Metric#RESOLUTION}.
     *
     * @param entities mention strings to resolve; {@code null} yields an empty result
     * @return resolved senses in chain order; may contain {@code void_node} entries
     */
    public List<WikipediaEntity> resolve(List<String> entities) {
        List<WikipediaEntity> entityList = new ArrayList<WikipediaEntity>();
        if (entities == null) {
            return entityList;
        }
        long startMillis = System.currentTimeMillis();

        HashMap<String, ArrayList<WikipediaEntity>> tweetEntityTowikiEntities =
                getWikiSenses(entities);

        // Keep only mentions that have senses, in a private copy so the caller's
        // (possibly read-only) list is left untouched.
        List<String> twEntities = new ArrayList<String>();
        for (String entity : entities) {
            if (tweetEntityTowikiEntities.containsKey(entity)) {
                twEntities.add(entity);
            }
        }

        if (twEntities.size() == 1) {
            entityList.add(tweetEntityTowikiEntities.get(twEntities.get(0)).get(0));
            recordResolutionTime(startMillis);
            return entityList;
        }

        // NOTE(review): remnant of a disabled search over mention orderings. Only this one
        // swap of positions 0 and 1 survives; it is kept so existing results do not change.
        twEntities = swap(twEntities, 0);
        twEntities.add(0, START_NODE);
        twEntities.add(END_NODE);

        HashMap<String, String> idToWikiEntityText = new HashMap<String, String>();
        String bestPath = findBestPath(twEntities, tweetEntityTowikiEntities, idToWikiEntityText);

        // ids[0] is start_node and the last id is end_node; both are dropped.
        String[] ids = bestPath.split(",");
        for (int i = 1; i < ids.length - 1; i++) {
            entityList.add(new WikipediaEntity(idToWikiEntityText.get(ids[i]), ids[i], 1));
        }
        recordResolutionTime(startMillis);
        return entityList;
    }

    /**
     * Runs the Viterbi recursion over {@code chain} and returns the best path.
     *
     * @param chain mention keys, starting with {@code start_node} and ending with {@code end_node}
     * @param senses candidate senses per mention key
     * @param idToText filled with sense ID to sense text for every visited sense
     * @return comma-joined sense IDs of the best path, or {@code ""} if no path scored
     */
    private String findBestPath(List<String> chain,
            HashMap<String, ArrayList<WikipediaEntity>> senses,
            HashMap<String, String> idToText) {
        Map<String, PathState> previous = new HashMap<String, PathState>();
        for (WikipediaEntity we : senses.get(chain.get(0))) {
            idToText.put(we.getWikiminerID(), we.getText());
            previous.put(we.getWikiminerID(),
                    new PathState(Double.parseDouble(we.getCommonness()), we.getWikiminerID()));
        }

        for (int i = 1; i < chain.size(); i++) {
            ArrayList<WikipediaEntity> prevSenses = senses.get(chain.get(i - 1));
            Map<String, PathState> next = new HashMap<String, PathState>();
            for (WikipediaEntity nextSense : senses.get(chain.get(i))) {
                String nextId = nextSense.getWikiminerID();
                idToText.put(nextId, nextSense.getText());

                double bestScore = Double.NEGATIVE_INFINITY;
                String bestPath = "";
                for (WikipediaEntity prevSense : prevSenses) {
                    String prevId = prevSense.getWikiminerID();
                    PathState state = previous.get(prevId);
                    double score = state.score
                            + (logRelatedness(prevId, nextId)
                                    + Double.parseDouble(prevSense.getCommonness()));
                    if (score > bestScore) {
                        bestScore = score;
                        bestPath = state.path + "," + nextId;
                    }
                }
                next.put(nextId, new PathState(bestScore, bestPath));
            }
            previous = next;
        }

        double bestScore = Double.NEGATIVE_INFINITY;
        String bestPath = "";
        for (PathState state : previous.values()) {
            if (state.score > bestScore) {
                bestScore = state.score;
                bestPath = state.path;
            }
        }
        return bestPath;
    }

    /**
     * Returns the log of the store's relatedness between two senses.
     * Zero, negative and NaN scores are clamped to {@link #MIN_LOG_SCORE}. Otherwise
     * Math.log would yield -Infinity or NaN, and NaN loses every comparison, which would
     * silently drop paths.
     */
    private static double logRelatedness(String fromId, String toId) {
        double count = db.compare(fromId, toId);
        return count > 0.0 ? Math.log(count) : MIN_LOG_SCORE;
    }

    /** Records the time elapsed since {@code startMillis} under the RESOLUTION metric. */
    private static void recordResolutionTime(long startMillis) {
        PerfMetrics.getInstance().addToMetrics(Metric.RESOLUTION,
                System.currentTimeMillis() - startMillis);
    }

    /**
     * Development driver. Reads one tweet per line from the file {@code args[0]}, extracts
     * noun-phrase entities, resolves them, and prints each resolved sense's text.
     */
    public static void main(String[] args) {
        if (args.length < 1) {
            System.err.println("usage: ViterbiResolver <tweet-file>");
            return;
        }
        ViterbiResolver vr = new ViterbiResolver();
        BufferedReader br = null;
        try {
            br = new BufferedReader(new FileReader(args[0]));
            IEntityExtractor extractor = EntityExtractFactory
                    .produceExtractor(Extracter.NOUNPHRASE);
            String line;
            while ((line = br.readLine()) != null) {
                System.out.println("tweet: " + line);
                List<String> entities = extractor.extract(line);
                System.out.println("entities: " + entities);
                List<WikipediaEntity> wes = vr.resolve(entities);
                for (WikipediaEntity we : wes) {
                    System.out.println(we.getText());
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            closeQuietly(br);
        }
    }

    /** Closes {@code reader} if it is non-null; a failed close is ignored. */
    private static void closeQuietly(BufferedReader reader) {
        if (reader == null) {
            return;
        }
        try {
            reader.close();
        } catch (IOException ignored) {
            // The file was only read, so a failed close loses no data.
        }
    }
}