package esl.cuenet.generative.structs;
import com.google.common.collect.*;
import com.mongodb.BasicDBObject;
import com.mongodb.util.JSON;
import esl.cuenet.algorithms.firstk.personal.accessor.Candidates;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.LineIterator;
import org.apache.log4j.Logger;
import java.io.File;
import java.io.IOException;
import java.util.*;
public class Propagate {
// log4j logger used for progress and diagnostic output of this class.
private Logger logger = Logger.getLogger(getClass());
// The context network whose event trees and instances are scored.
private final ContextNetwork network;
// Multiplier applied to every freshly propagated score in propagateOnce() and propagateOnceTable().
private final double d = 0.85;
// Current scalar score per event instance; created in prepare() and replaced by each propagateOnce().
private Map<ContextNetwork.Instance, Double> eventScoreTable;
// (event type id, event type id) -> distance, loaded from the semantic distance file in the constructor.
private final Table<Integer, Integer, Integer> semanticDistances;
// (instance, other instance) -> number of subevent type ids the two have in common; filled by findCommonSubs().
private Table<ContextNetwork.Instance, ContextNetwork.Instance, Integer> sparseSubeventCountTable;
// (instance, other instance) -> shared object count as computed by findCommonSubs().
private Table<ContextNetwork.Instance, ContextNetwork.Instance, Integer> sparseObjectCountTable;
// instance -> instances of other trees whose root location is closer than the threshold in findEventsWithinSTRange().
private Multimap<ContextNetwork.Instance, ContextNetwork.Instance> eventsWithinSpatialRange;
// instance -> instances of other trees whose root start time is within the window used in findEventsWithinSTRange().
private Multimap<ContextNetwork.Instance, ContextNetwork.Instance> eventsWithinTemporalRange;
// Difference between the largest and smallest root intervalStart over all event trees (set in the constructor).
private long timespan;
// Supplies distance(...) between two locations and getMaxDist().
private final SpaceTimeValueGenerators stGenerators;
// Largest value found in semanticDistances, or -1 when that table is empty.
private final int maxSemanticDistance;
// Largest common-subevent count seen so far by findCommonSubs().
private int maxCommonSubevents = 0;
// (instance, entity) -> score; seeded in prepare() and replaced by each propagateOnceTable().
private Table<ContextNetwork.Instance, ContextNetwork.Entity, Double> scoreTable = HashBasedTable.create();
// Every entity that participates in at least one instance; rebuilt in prepare(). (Identifier spelling kept as-is.)
private HashSet<ContextNetwork.Entity> allEntites;
/**
 * Creates a propagator over the given network.
 *
 * <p>Loads the semantic distance table from {@code semanticDistanceFile}, records the largest
 * distance found in it (-1 when the table is empty) and computes the time span covered by the
 * root events of all trees in the network.
 *
 * @param network              the context network to score
 * @param semanticDistanceFile path of a text file with one {@code <typeId> -> <json map>} entry per line
 * @param generators           source of spatial distances between locations
 */
public Propagate(ContextNetwork network, String semanticDistanceFile, SpaceTimeValueGenerators generators) {
    this.network = network;
    this.stGenerators = generators;
    this.semanticDistances = HashBasedTable.create();
    loadSemanticDistances(semanticDistanceFile);

    int maxDistance = -1;
    for (Table.Cell<Integer, Integer, Integer> cell : semanticDistances.cellSet()) {
        if (cell.getValue() > maxDistance) {
            maxDistance = cell.getValue();
        }
    }
    maxSemanticDistance = maxDistance;

    long mintime = Long.MAX_VALUE, maxtime = Long.MIN_VALUE;
    for (ContextNetwork.IndexedSubeventTree tree : network.eventTrees) {
        if (tree.root.intervalStart < mintime) mintime = tree.root.intervalStart;
        if (tree.root.intervalStart > maxtime) maxtime = tree.root.intervalStart;
    }
    // With no trees the sentinels are untouched and their difference would overflow; use 0 instead.
    timespan = (mintime <= maxtime) ? maxtime - mintime : 0;

    logger.info("Timespan = " + timespan);
    logger.info("MaxSemanticDistance = " + maxDistance);
}

/**
 * Reads the semantic distance file into {@link #semanticDistances}.
 *
 * <p>Blank lines are skipped. An I/O failure is logged and leaves the table with whatever was
 * read so far, so construction continues as it did before. The line iterator is always closed.
 */
private void loadSemanticDistances(String semanticDistanceFile) {
    LineIterator iter = null;
    try {
        iter = FileUtils.lineIterator(new File(semanticDistanceFile));
        while (iter.hasNext()) {
            String line = iter.next();
            if (line.trim().length() == 0) continue;
            String[] parts = line.split(" -> ");
            int typeId = Integer.parseInt(parts[0]);
            Map<?, ?> distMap = ((BasicDBObject) JSON.parse(parts[1])).toMap();
            for (Map.Entry<?, ?> e : distMap.entrySet()) {
                // The JSON parser may hand back Integer, Long or Double; accept any numeric type.
                int distance = ((Number) e.getValue()).intValue();
                semanticDistances.put(typeId, Integer.parseInt(e.getKey().toString()), distance);
            }
        }
    } catch (IOException e) {
        logger.error("Could not read semantic distance file " + semanticDistanceFile, e);
    } finally {
        LineIterator.closeQuietly(iter);
    }
}
/**
 * Builds all per-run indexes and seeds the score tables.
 *
 * <p>Each instance's scalar score is the number of its participants whose id is in
 * {@code entities}. Every (instance, entity) cell of the score table is then set to 0.0 and
 * finally to 1.0 where the entity participates in the instance. The score table itself is not
 * re-created here, only written to.
 *
 * @param entities ids of the entities that seed the scalar event scores
 */
public void prepare(HashSet<String> entities) {
    final int nodeCount = network.nodeCount();
    eventScoreTable = Maps.newHashMapWithExpectedSize(nodeCount);
    sparseSubeventCountTable = HashBasedTable.create(nodeCount, nodeCount);
    sparseObjectCountTable = HashBasedTable.create(nodeCount, nodeCount);
    eventsWithinSpatialRange = HashMultimap.create(nodeCount, 100);
    eventsWithinTemporalRange = HashMultimap.create(nodeCount, 100);
    allEntites = new HashSet<ContextNetwork.Entity>();

    logger.info("Initializing scores");
    for (ContextNetwork.IndexedSubeventTree eventTree : network.eventTrees) {
        for (ContextNetwork.Instance event : eventTree.instanceMap.values()) {
            double seedHits = 0;
            for (ContextNetwork.Entity participant : event.participants) {
                allEntites.add(participant);
                if (entities.contains(participant.id)) {
                    seedHits += 1;
                }
            }
            eventScoreTable.put(event, seedHits);
        }
    }

    logger.info("Creating nearby index, and counting common subevents");
    for (ContextNetwork.IndexedSubeventTree eventTree : network.eventTrees) {
        findEventsWithinSTRange(eventTree);
    }

    // First pass: every known entity starts at 0.0 for every instance.
    for (ContextNetwork.IndexedSubeventTree eventTree : network.eventTrees) {
        for (ContextNetwork.Instance event : eventTree.instanceMap.values()) {
            for (ContextNetwork.Entity known : allEntites) {
                scoreTable.put(event, known, 0.0);
            }
        }
    }

    // Second pass: actual participants are raised to 1.0.
    for (ContextNetwork.IndexedSubeventTree eventTree : network.eventTrees) {
        for (ContextNetwork.Instance event : eventTree.instanceMap.values()) {
            for (ContextNetwork.Entity participant : event.participants) {
                scoreTable.put(event, participant, 1.0);
            }
        }
    }

    logger.info(scoreTable.size());
}
/**
 * Counts the instances whose scalar score is strictly positive and logs the count.
 *
 * @return number of entries in the event score table with a value above zero
 */
private int gzCount() {
    int positives = 0;
    for (Double value : eventScoreTable.values()) {
        if (value > 0) {
            positives += 1;
        }
    }
    logger.info("Events with score > 0 = " + positives);
    return positives;
}
/**
 * Indexes, for every instance of {@code tree}, the instances of other trees that are close in
 * time or in space, and records the common-subevent/object counts for each such pair.
 *
 * <p>Temporal closeness: root start times differ by less than 50% of the overall time span.
 * Spatial closeness: the distance between the root locations is below 50.
 */
private void findEventsWithinSTRange(ContextNetwork.IndexedSubeventTree tree) {
    final long temporalWindow = 50 * timespan/100;
    for (ContextNetwork.IndexedSubeventTree candidate : network.eventTrees) {
        if (candidate != tree
                && Math.abs(tree.root.intervalStart - candidate.root.intervalStart) < temporalWindow) {
            linkInstances(tree, candidate, eventsWithinTemporalRange);
        }
    }

    final String thisloc = tree.root.location;
    for (ContextNetwork.IndexedSubeventTree candidate : network.eventTrees) {
        if (candidate == tree) {
            continue;
        }
        final String otherloc = candidate.root.location;
        if (stGenerators.distance(thisloc, otherloc) < 50) {
            linkInstances(tree, candidate, eventsWithinSpatialRange);
        }
    }
}

/**
 * Records every (instance of {@code tree}, instance of {@code otherTree}) pair in {@code index}
 * and makes sure the common-subevent counts for the pair are computed.
 */
private void linkInstances(ContextNetwork.IndexedSubeventTree tree, ContextNetwork.IndexedSubeventTree otherTree,
                           Multimap<ContextNetwork.Instance, ContextNetwork.Instance> index) {
    for (ContextNetwork.Instance mine : tree.instanceMap.values()) {
        for (ContextNetwork.Instance theirs : otherTree.instanceMap.values()) {
            index.put(mine, theirs);
            findCommonSubs(tree, mine, otherTree, theirs);
        }
    }
}
/**
 * Computes and caches, symmetrically, how many subevent types and how many objects two
 * instances have in common. Does nothing when the pair has already been processed.
 *
 * @param tree          tree that contains {@code instance}
 * @param instance      first instance of the pair
 * @param otree         tree that contains {@code otherInstance}
 * @param otherInstance second instance of the pair
 */
private void findCommonSubs(ContextNetwork.IndexedSubeventTree tree, ContextNetwork.Instance instance,
                            ContextNetwork.IndexedSubeventTree otree, ContextNetwork.Instance otherInstance) {
    // Both tables are always written together, so one membership check covers both.
    if (sparseSubeventCountTable.contains(instance, otherInstance)) return;

    Set<Integer> subtypesInstance = getSubeventTypes(tree, instance);
    Set<Integer> subtypesOtherInstance = getSubeventTypes(otree, otherInstance);
    int count = Sets.intersection(subtypesInstance, subtypesOtherInstance).size();
    if (count > maxCommonSubevents) maxCommonSubevents = count;
    sparseSubeventCountTable.put(instance, otherInstance, count);
    sparseSubeventCountTable.put(otherInstance, instance, count);

    Set<String> objectsInstance = getObjects(tree, instance);
    // Must be collected from the OTHER instance in the OTHER tree; comparing the instance with
    // itself would always yield the size of its own object set.
    Set<String> objectsOtherInstance = getObjects(otree, otherInstance);
    count = Sets.intersection(objectsInstance, objectsOtherInstance).size();
    sparseObjectCountTable.put(instance, otherInstance, count);
    sparseObjectCountTable.put(otherInstance, instance, count);
}
/**
 * Returns the cached number of subevent types shared by {@code i} and {@code j},
 * or 0 when the pair has no entry in the table.
 */
private int getCommonSubeventCount(ContextNetwork.Instance i, ContextNetwork.Instance j) {
    final Integer shared = sparseSubeventCountTable.get(i, j);
    if (shared == null) {
        return 0;
    }
    return shared;
}
/**
 * Returns the cached object count stored for the pair {@code i}, {@code j},
 * or 0 when the pair has no entry in the table.
 */
private int getCommonObjectCount(ContextNetwork.Instance i, ContextNetwork.Instance j) {
    final Integer shared = sparseObjectCountTable.get(i, j);
    if (shared == null) {
        return 0;
    }
    return shared;
}
/**
 * Collects the ids of all participants of the subevents reachable from {@code instance}
 * (children, grandchildren, ...). Participants of {@code instance} itself are not included.
 *
 * @param tree     tree whose instance map resolves subevent ids
 * @param instance instance whose descendants are traversed
 * @return ids of the participants found on the descendants
 */
private Set<String> getObjects(ContextNetwork.IndexedSubeventTree tree, ContextNetwork.Instance instance) {
    final Set<String> participantIds = Sets.newHashSet();
    final Deque<ContextNetwork.Instance> pending = new ArrayDeque<ContextNetwork.Instance>();
    pending.push(instance);
    while (!pending.isEmpty()) {
        final ContextNetwork.Instance current = pending.pop();
        for (ContextNetwork.InstanceId childId : current.immediateSubevents) {
            final ContextNetwork.Instance child = tree.instanceMap.get(childId);
            pending.push(child);
            for (ContextNetwork.Entity participant : child.participants) {
                participantIds.add(participant.id);
            }
        }
    }
    return participantIds;
}
/**
 * Collects the event type ids of all subevents reachable from {@code instance}
 * (children, grandchildren, ...).
 *
 * @param tree     tree whose instance map resolves subevent ids
 * @param instance instance whose descendants are traversed
 * @return the distinct event ids of the descendants
 */
private Set<Integer> getSubeventTypes(ContextNetwork.IndexedSubeventTree tree, ContextNetwork.Instance instance) {
    final Set<Integer> typeIds = Sets.newHashSet();
    final Deque<ContextNetwork.Instance> pending = new ArrayDeque<ContextNetwork.Instance>();
    pending.push(instance);
    while (!pending.isEmpty()) {
        final ContextNetwork.Instance current = pending.pop();
        for (ContextNetwork.InstanceId childId : current.immediateSubevents) {
            pending.push(tree.instanceMap.get(childId));
            typeIds.add(childId.eventId);
        }
    }
    return typeIds;
}
/**
 * Runs one propagation step on the (instance, entity) score table.
 *
 * <p>For every cell of the current table the new value is {@code d} times the temporally
 * propagated score of that cell (0 when temporal propagation produced none). The current table
 * is then replaced by the new one.
 *
 * @return the change between the old and the new table as reported by computeDeltaTable
 */
public double propagateOnceTable() {
    final Table<ContextNetwork.Instance, ContextNetwork.Entity, Double> temporal = temporalPropagationTable();
    final Table<ContextNetwork.Instance, ContextNetwork.Entity, Double> damped = HashBasedTable.create();

    for (Table.Cell<ContextNetwork.Instance, ContextNetwork.Entity, Double> cell : scoreTable.cellSet()) {
        final ContextNetwork.Instance row = cell.getRowKey();
        final ContextNetwork.Entity column = cell.getColumnKey();
        final Double propagated = temporal.get(row, column);
        double score;
        if (propagated == null) {
            score = 0;
        } else {
            score = propagated;
        }
        damped.put(row, column, d * score);
    }

    final double delta = computeDeltaTable(scoreTable, damped);
    scoreTable = damped;
    return delta;
}
/**
 * Measures how much two score tables differ.
 *
 * <p>For each row of {@code scoreTable} the absolute differences to {@code newScore} are summed
 * over the row's columns; the largest such row sum is returned.
 *
 * @param scoreTable the previous scores; its rows and columns drive the comparison
 * @param newScore   the updated scores; must hold a value for every cell of {@code scoreTable}
 * @return the largest per-row sum of absolute differences, or -1 when {@code scoreTable} has no rows
 */
private double computeDeltaTable(Table<ContextNetwork.Instance, ContextNetwork.Entity, Double> scoreTable,
                                 Table<ContextNetwork.Instance, ContextNetwork.Entity, Double> newScore) {
    double largestRowDelta = -1;
    for (Map.Entry<ContextNetwork.Instance, Map<ContextNetwork.Entity, Double>> row : scoreTable.rowMap().entrySet()) {
        final Map<ContextNetwork.Entity, Double> updatedRow = newScore.row(row.getKey());
        double rowDelta = 0;
        for (Map.Entry<ContextNetwork.Entity, Double> cell : row.getValue().entrySet()) {
            rowDelta += Math.abs(cell.getValue() - updatedRow.get(cell.getKey()));
        }
        largestRowDelta = Math.max(largestRowDelta, rowDelta);
    }
    return largestRowDelta;
}
/**
 * Runs one propagation step on the scalar per-instance scores.
 *
 * <p>Only temporal propagation contributes to the new score. The object, type and structural
 * propagation passes are not part of the sum, so they are no longer computed here: their results
 * were discarded, and the type pass could fail on an absent semantic distance and abort the
 * whole step for a value nobody reads.
 *
 * @return the total absolute change between the previous and the new scores
 */
public double propagateOnce() {
    Map<ContextNetwork.Instance, Double> timeScore = temporalPropagation();

    Map<ContextNetwork.Instance, Double> newScore = Maps.newHashMapWithExpectedSize(eventScoreTable.size());
    for (ContextNetwork.Instance item : eventScoreTable.keySet()) {
        // Spatial, object, type and structural scores are currently disabled contributions.
        double score = timeScore.get(item);
        newScore.put(item, d * score);
    }

    double delta = computeDelta(eventScoreTable, newScore);
    eventScoreTable = newScore;
    return delta;
}
/**
 * Propagates entity scores from each instance to the instances that start within 1% of the
 * overall time span of it.
 *
 * <p>A source instance with {@code k} temporal neighbours sends, for each of its non-zero
 * entity scores, {@code weight * score / k} to every neighbour, where {@code weight} drops
 * linearly from 1 (same start time) to 0 (at the edge of the window).
 *
 * <p>NOTE(review): contributions are written with {@code put}, so when several sources reach the
 * same (neighbour, entity) cell the last one written wins rather than being summed. This is the
 * original behaviour and is left unchanged; confirm whether accumulation was intended.
 *
 * @return the propagated scores; cells that received nothing are absent
 */
private Table<ContextNetwork.Instance, ContextNetwork.Entity, Double> temporalPropagationTable() {
    Table<ContextNetwork.Instance, ContextNetwork.Entity, Double> newScore = HashBasedTable.create();
    final long range = timespan / 100;
    final Set<ContextNetwork.Instance> instances = eventScoreTable.keySet();

    for (ContextNetwork.Instance instance : instances) {
        final int numOfNeighbors = countTemporalNeighbors(instance, instances, range);
        if (numOfNeighbors == 0) continue;

        // The row view does not depend on the neighbour, so look it up once per source instance.
        final Map<ContextNetwork.Entity, Double> scoresAtInstance = scoreTable.row(instance);

        for (ContextNetwork.Instance neighbor : instances) {
            if (instance == neighbor) continue;
            long diff = Math.abs(instance.intervalStart - neighbor.intervalStart);
            if (diff > range) continue;

            // A time span below 100 gives a zero-width window; only identical start times pass the
            // check above, and they get full weight instead of the NaN that 0/0 would produce.
            double fraction = (range == 0) ? 1.0 : 1 - (double) diff / range;

            for (Map.Entry<ContextNetwork.Entity, Double> entry : scoresAtInstance.entrySet()) {
                if (Double.compare(entry.getValue(), 0) == 0) continue;
                double ns = fraction * entry.getValue() / numOfNeighbors;
                newScore.put(neighbor, entry.getKey(), ns);
            }
        }
    }
    return newScore;
}

/** Counts the instances other than {@code instance} whose start time lies within {@code range} of it. */
private int countTemporalNeighbors(ContextNetwork.Instance instance, Set<ContextNetwork.Instance> instances, long range) {
    int n = 0;
    for (ContextNetwork.Instance neighbor : instances) {
        if (instance == neighbor) continue;
        long diff = Math.abs(instance.intervalStart - neighbor.intervalStart);
        if (diff <= range) n++;
    }
    return n;
}
/**
 * Placeholder for type-based propagation on the (instance, entity) table.
 *
 * @return a new, empty table
 */
private Table<ContextNetwork.Instance, ContextNetwork.Entity, Double> typePropagationTable() {
    final Table<ContextNetwork.Instance, ContextNetwork.Entity, Double> empty = HashBasedTable.create();
    return empty;
}
/**
 * Computes, for every instance, the score it receives from its temporally nearby instances.
 *
 * <p>Each neighbour with a non-zero score contributes its score, scaled by
 * {@code 1 - |start difference| / timespan} and divided by the neighbour's own number of
 * temporal neighbours.
 *
 * @return instance -> received score (0 when nothing was received)
 */
private Map<ContextNetwork.Instance, Double> temporalPropagation() {
    final Map<ContextNetwork.Instance, Double> received = Maps.newHashMapWithExpectedSize(eventScoreTable.size());
    for (ContextNetwork.Instance target : eventScoreTable.keySet()) {
        double total = 0;
        for (ContextNetwork.Instance source : eventsWithinTemporalRange.get(target)) {
            final double sourceScore = eventScoreTable.get(source);
            if (Double.compare(sourceScore, 0) != 0) {
                // Linear drop with the distance in time between the two start points.
                final double proximity = 1- (Math.abs((double)target.intervalStart - source.intervalStart) / timespan);
                // Normalise by the temporal fan-out of the source.
                final int fanOut = eventsWithinTemporalRange.get(source).size();
                total += sourceScore * proximity / fanOut;
            }
        }
        received.put(target, total);
    }
    return received;
}
/**
 * Computes, for every instance, the score it receives from its spatially nearby instances.
 *
 * <p>Each neighbour with a non-zero score contributes its score, scaled by one minus the ratio of
 * the distance between the two locations to the generator's maximum distance, and divided by the
 * neighbour's own number of spatial neighbours.
 *
 * @return instance -> received score (0 when nothing was received)
 */
private Map<ContextNetwork.Instance, Double> spatialPropagation() {
    final Map<ContextNetwork.Instance, Double> received = Maps.newHashMapWithExpectedSize(eventScoreTable.size());
    for (ContextNetwork.Instance target : eventScoreTable.keySet()) {
        double total = 0;
        for (ContextNetwork.Instance source : eventsWithinSpatialRange.get(target)) {
            final double sourceScore = eventScoreTable.get(source);
            if (sourceScore != 0) {
                final double closeness = 1 - (stGenerators.distance(target.location, source.location))/stGenerators.getMaxDist();
                final int fanOut = eventsWithinSpatialRange.get(source).size();
                total += sourceScore * closeness / fanOut;
            }
        }
        received.put(target, total);
    }
    return received;
}
/**
 * Computes a per-instance score from the cached object counts.
 *
 * <p>For every neighbour that has an object-count entry with the instance, the count is divided
 * by the number of entries in the neighbour's own row and added to the instance's sum.
 *
 * @return instance -> accumulated object score
 */
private Map<ContextNetwork.Instance, Double> objectPropagation() {
    final Map<ContextNetwork.Instance, Double> result = Maps.newHashMapWithExpectedSize(eventScoreTable.size());
    for (ContextNetwork.Instance target : eventScoreTable.keySet()) {
        double total = 0;
        for (Map.Entry<ContextNetwork.Instance, Integer> link : sparseObjectCountTable.row(target).entrySet()) {
            final double sharedObjects = link.getValue();
            final int neighborFanOut = sparseObjectCountTable.row(link.getKey()).size();
            total += sharedObjects / neighborFanOut;
        }
        result.put(target, total);
    }
    return result;
}
/**
 * Computes a per-instance type similarity score over the neighbours recorded in the object
 * count table.
 *
 * <p>For each neighbour the contribution is the average of two normalised terms: the number of
 * common subevent types divided by the largest such count seen, and the semantic distance
 * between the two event types divided by the largest semantic distance. A term whose normaliser
 * is not positive, or whose semantic distance is not in the table, contributes 0. Contributions
 * are accumulated over all neighbours.
 *
 * @return instance -> accumulated type score
 */
private Map<ContextNetwork.Instance, Double> typePropagation() {
    Map<ContextNetwork.Instance, Double> scores = Maps.newHashMapWithExpectedSize(eventScoreTable.size());
    for (ContextNetwork.Instance instance : eventScoreTable.keySet()) {
        double sum = 0;
        for (ContextNetwork.Instance neighbor : sparseObjectCountTable.row(instance).keySet()) {
            double scoreDiff = 0;
            if (maxCommonSubevents > 0) {
                scoreDiff = (double) getCommonSubeventCount(instance, neighbor) / maxCommonSubevents;
            }
            // The distance table need not contain every type pair; a missing entry must not unbox null.
            Integer distance = semanticDistances.get(instance.id.eventId, neighbor.id.eventId);
            if (distance != null && maxSemanticDistance > 0) {
                scoreDiff += (double) distance / maxSemanticDistance;
            }
            // Accumulate; plain assignment kept only the last neighbour's contribution.
            sum += scoreDiff / 2;
        }
        scores.put(instance, sum);
    }
    return scores;
}
/**
 * Placeholder for structural propagation: assigns 0 to every instance.
 *
 * @return instance -> 0.0 for every instance in the event score table
 */
private Map<ContextNetwork.Instance, Double> structuralPropagation() {
    final Map<ContextNetwork.Instance, Double> zeros = Maps.newHashMapWithExpectedSize(eventScoreTable.size());
    final double none = 0;
    for (ContextNetwork.Instance event : eventScoreTable.keySet()) {
        zeros.put(event, none);
    }
    return zeros;
}
/**
 * Sums the absolute score changes over all instances.
 *
 * @param eventScoreTable previous scores; its keys drive the comparison
 * @param newScore        updated scores; must hold a value for every key of {@code eventScoreTable}
 * @return the sum of {@code |old - new|} over all instances
 */
private double computeDelta(Map<ContextNetwork.Instance, Double> eventScoreTable, Map<ContextNetwork.Instance, Double> newScore) {
    double accumulated = 0;
    for (Map.Entry<ContextNetwork.Instance, Double> previous : eventScoreTable.entrySet()) {
        final double change = previous.getValue() - newScore.get(previous.getKey());
        accumulated += Math.abs(change);
    }
    return accumulated;
}
/**
 * Logs the network's count and node count, the semantic distance table and the current
 * positive scores.
 */
public void show() {
    final String sizes = network.count() + "; " + network.nodeCount();
    logger.info(sizes);
    final String distances = semanticDistances.toString();
    logger.info(distances);
    dispScores();
}
/**
 * Logs every instance whose scalar score is above zero. Does nothing when the score table has
 * not been created yet.
 */
public void dispScores() {
    if (eventScoreTable != null) {
        logger.info(" --- SCORES ---");
        for (Map.Entry<ContextNetwork.Instance, Double> scored : eventScoreTable.entrySet()) {
            final Double value = scored.getValue();
            if (value > 0) {
                logger.info(scored.getKey() + " " + value);
            }
        }
    }
}
/**
 * Ranks objects (participant ids) by the score they inherit from the events they take part in.
 *
 * <p>Each instance with at least one participant splits its scalar score evenly among its
 * participants; an object's score is the sum of the shares it receives. Objects with a score of
 * zero or less are left out.
 *
 * @return (object id, score) entries in descending score order; empty when no scores have been
 *         prepared yet or nothing scored above zero
 */
public List<Map.Entry<String,Double>> orderObjects() {
    List<Map.Entry<String, Double>> list = Lists.newArrayList();
    // Mirrors dispScores(): without prepared scores there is nothing to rank.
    if (eventScoreTable == null) return list;

    HashMap<String, Double> objectScoreTable = new HashMap<String, Double>();
    for (ContextNetwork.IndexedSubeventTree tree : network.eventTrees) {
        for (ContextNetwork.Instance instance : tree.instanceMap.values()) {
            int size = instance.participants.size();
            if (size == 0) continue;
            double share = eventScoreTable.get(instance) / size;
            for (ContextNetwork.Entity e : instance.participants) {
                Double soFar = objectScoreTable.get(e.id);
                objectScoreTable.put(e.id, soFar == null ? share : soFar + share);
            }
        }
    }

    // PriorityQueue rejects an initial capacity below 1, so never pass 0 for an empty table.
    PriorityQueue<Map.Entry<String, Double>> objBallot =
            new PriorityQueue<Map.Entry<String, Double>>(Math.max(1, objectScoreTable.size()),
                    new Comparator<Map.Entry<String, Double>>() {
                        @Override
                        public int compare(Map.Entry<String, Double> o1, Map.Entry<String, Double> o2) {
                            // Higher score first.
                            if (o2.getValue() > o1.getValue()) return 1;
                            if (o2.getValue() < o1.getValue()) return -1;
                            return 0;
                        }
                    });
    objBallot.addAll(objectScoreTable.entrySet());

    while (!objBallot.isEmpty()) {
        Map.Entry<String, Double> entry = objBallot.remove();
        if (entry.getValue() > 0) list.add(entry);
    }
    return list;
}
/**
 * Logs the entities with a positive score for one instance, highest score first.
 * Logs nothing when the instance has no row in the score table.
 *
 * @param event_id    event id of the instance to look up
 * @param instance_id instance id of the instance to look up
 */
public void printScores(int event_id, int instance_id) {
    Map<ContextNetwork.Entity, Double> map = scoreTable.row(new ContextNetwork.Instance(event_id, instance_id));
    // An unknown instance yields an empty row; PriorityQueue would reject an initial capacity of 0.
    if (map.isEmpty()) return;

    PriorityQueue<Map.Entry<ContextNetwork.Entity, Double>> objBallot =
            new PriorityQueue<Map.Entry<ContextNetwork.Entity, Double>>(map.size(),
                    new Comparator<Map.Entry<ContextNetwork.Entity, Double>>() {
                        @Override
                        public int compare(Map.Entry<ContextNetwork.Entity, Double> o1, Map.Entry<ContextNetwork.Entity, Double> o2) {
                            // Higher score first.
                            if (o2.getValue() > o1.getValue()) return 1;
                            if (o2.getValue() < o1.getValue()) return -1;
                            return 0;
                        }
                    });
    objBallot.addAll(map.entrySet());

    while (!objBallot.isEmpty()) {
        Map.Entry<ContextNetwork.Entity, Double> entry = objBallot.remove();
        if (entry.getValue() > 0) logger.info(entry.getKey() + "\t" + entry.getValue());
    }
}
/**
 * Ranks the entities of one instance by score and reports where the given objects appear in
 * that ranking.
 *
 * <p>The ranking is walked from the highest score down. Whenever the current entity's id is
 * contained in {@code objects}, its zero-based rank is appended to the result. After each rank
 * the running precision (hits / ranks seen) and recall (hits / {@code objects.size()}) are
 * recorded, and both series are logged at the end.
 *
 * @param event_id    event id of the instance to look up
 * @param instance_id instance id of the instance to look up
 * @param objects     ids of the objects to locate in the ranking
 * @return an array of length {@code objects.size()} whose leading elements are the ranks of the
 *         objects that were found, in ranking order; unused trailing slots stay 0
 */
public int[] findObjectPositions(int event_id, int instance_id, List<String> objects) {
    int[] objectsPositions = new int[objects.size()];
    Map<ContextNetwork.Entity, Double> map = scoreTable.row(new ContextNetwork.Instance(event_id, instance_id));

    // PriorityQueue rejects an initial capacity below 1; an unknown instance yields an empty row.
    PriorityQueue<Map.Entry<ContextNetwork.Entity, Double>> objBallot =
            new PriorityQueue<Map.Entry<ContextNetwork.Entity, Double>>(Math.max(1, map.size()),
                    new Comparator<Map.Entry<ContextNetwork.Entity, Double>>() {
                        @Override
                        public int compare(Map.Entry<ContextNetwork.Entity, Double> o1, Map.Entry<ContextNetwork.Entity, Double> o2) {
                            // Higher score first.
                            if (o2.getValue() > o1.getValue()) return 1;
                            if (o2.getValue() < o1.getValue()) return -1;
                            return 0;
                        }
                    });
    objBallot.addAll(map.entrySet());

    double[] precisions = new double[objBallot.size()];
    double[] recalls = new double[objBallot.size()];
    int ix = 0, jx = 0;
    while (!objBallot.isEmpty()) {
        Map.Entry<ContextNetwork.Entity, Double> entry = objBallot.remove();
        if (objects.indexOf(entry.getKey().id) >= 0) {
            objectsPositions[jx++] = ix;
        }
        // Slot ix holds the precision/recall after ix + 1 ranks have been inspected.
        precisions[ix] = (double) jx / (ix + 1);
        recalls[ix] = (double) jx / objects.size();
        ix++;
    }

    logger.info("Precision" + instance_id + " <- c" + Arrays.toString(precisions));
    logger.info("Recalls" + instance_id + " <- c" + Arrays.toString(recalls));
    return objectsPositions;
}
}