package esl.cuenet.ranking.rankers; import esl.cuenet.ranking.*; import esl.cuenet.ranking.network.NeoEntityBase; import esl.cuenet.ranking.network.OntProperties; import org.apache.log4j.Logger; import java.util.*; public class BasicRanker implements Ranker { private final EventEntityNetwork network; private final EntityBase entityBase; private static final double _THRESHOLD = 0.25; private HashMap<Long, Double> entityScoreMap = new HashMap<Long, Double>(100); private HashMap<Long, Double> eventScoreMap = new HashMap<Long, Double>(); private Logger logger = Logger.getLogger(BasicRanker.class); private HashMap<Long, Double> scoreUpdates = new HashMap<Long, Double>(); private Queue<Long> updatedQueue = new LinkedList<Long>(); private NodeEvaluator evaluator = new NodeEvaluator(); private SpatioTemporalIndex stIndex = null; public BasicRanker(EventEntityNetwork network, EntityBase entityBase) { this.network = network; this.entityBase = entityBase; this.stIndex = network.stIndex(SpatioTemporalIndex.OCCURS_DURING_IX); for (long id: this.entityBase) entityScoreMap.put(id, 0.0); Iterator<URINode> eventsIter = network.getEventsIterator(); while (eventsIter.hasNext()) eventScoreMap.put(eventsIter.next().getId(), 0.0); logger.info("Initialized BasicRanker with (" + entityScoreMap.size() + ", " + eventScoreMap.size() + ")."); } @Override public void assign(long nodeId, double score) { entityScoreMap.put(nodeId, score); updatedQueue.add(nodeId); } private int iters = 0; @Override public boolean canTerminate() { iters++; return (iters > 2); } @Override public void compute(PropagationFunction[] functions) { scoreUpdates.clear(); logger.info("UpdatedQueue contains " + updatedQueue.size() + " elements."); Queue<Long> tmp = new LinkedList<Long>(updatedQueue); for (long uId: tmp) { URINode updateNode = network.getNodeById(uId); if (evaluator.isEntity(updateNode)) propagateAlongEntity(updateNode, functions); if (evaluator.isEvent(updateNode)) propagateAloneEvent(updateNode, functions); } for (long id: scoreUpdates.keySet()) { if (eventScoreMap.containsKey(id)) eventScoreMap.put(id, eventScoreMap.get(id) + scoreUpdates.get(id)); if (entityScoreMap.containsKey(id)) entityScoreMap.put(id, entityScoreMap.get(id) + scoreUpdates.get(id)); //else logger.info("Ouch"); } int c = 0; updatedQueue.clear(); for (Map.Entry<Long, Double> entry: scoreUpdates.entrySet()) { if (entry.getValue() > 0) { updatedQueue.add(entry.getKey()); c++; } } logger.info("Touched: " + c + " nodes."); } private void propagateAloneEvent(URINode updateNode, PropagationFunction[] functions) { propagateToConnectedNodes(updateNode, functions); propagateToNeighboringEvents(updateNode, functions); } private void propagateToNeighboringEvents(URINode eventNode, PropagationFunction[] functions) { } private void propagateToConnectedNodes(URINode eventNode, PropagationFunction[] functions) { double score = eventScoreMap.get(eventNode.getId()); for (TypedEdge edge: eventNode.getAllRelationships()) { URINode otherNode = (edge.getStartNode() == eventNode) ? edge.getEndNode() : edge.getStartNode(); otherNode = unnest(otherNode); if (otherNode == null) continue; double val = propagate(eventNode, edge, otherNode, functions, score); updateScore(otherNode, val); } } private URINode unnest(URINode node) { for (TypedEdge edge: node.getAllRelationships()) { if ( !edge.hasProperty(OntProperties.TYPE) ) continue; if ( !OntProperties.IS_SAME_AS.equals(edge.getProperty(OntProperties.TYPE))) continue; return (edge.getStartNode() == node) ? edge.getEndNode() : edge.getStartNode(); } return null; } private void propagateAlongEntity(URINode entityNode, PropagationFunction[] functions) { double score = entityScoreMap.get(entityNode.getId()); for (TypedEdge edge: entityNode.getAllRelationships()) { if (edge.hasProperty(OntProperties.TYPE) && edge.getProperty(OntProperties.TYPE).equals(OntProperties.IS_SAME_AS)) { URINode entityAliasNode = (edge.getStartNode() == entityNode) ? edge.getEndNode() : edge.getStartNode(); for (TypedEdge aliasEdge: entityAliasNode.getAllRelationships()) { URINode otherNode = (aliasEdge.getStartNode() == entityAliasNode) ? aliasEdge.getEndNode() : aliasEdge.getStartNode(); double val = propagate(entityNode, aliasEdge, otherNode, functions, score); updateScore(otherNode, val); } } } } private void updateScore(URINode node, double update) { if (scoreUpdates.containsKey(node.getId())) update += scoreUpdates.get(node.getId()); scoreUpdates.put(node.getId(), update); } private double propagate(URINode startNode, TypedEdge edge, URINode endNode, PropagationFunction[] functions, double startNodeScore) { double ret = 0; for (PropagationFunction function: functions) { if (function.matchStartNode(startNode) && function.matchEdge(edge) && function.matchEndNode(endNode)) { ret += function.propagate(startNode, edge, endNode, startNodeScore); } } return ret; } @Override public Iterator<Map.Entry<URINode, Double>> results() { List<Map.Entry<Long, Double>> scores = new ArrayList<Map.Entry<Long, Double>>(entityScoreMap.entrySet()); Collections.sort(scores, new Comparator<Map.Entry<Long, Double>>() { @Override public int compare(Map.Entry<Long, Double> o1, Map.Entry<Long, Double> o2) { double diff = o1.getValue() - o2.getValue(); if (diff > 0) return -1; else if (diff < 0) return 1; return 0; } }); for (int i=0; i<10; i++) { logger.info(i + ". ID = " + scores.get(i).getKey() + "; SCORE = " + scores.get(i).getValue()); NeoEntityBase.printEntity(network.getNodeById(scores.get(i).getKey()), logger); } return null; } }