package org.wikibrain.spatial.cookbook.tflevaluate; import au.com.bytecode.opencsv.CSVWriter; import com.vividsolutions.jts.geom.Geometry; import com.vividsolutions.jts.geom.Point; import org.wikibrain.conf.ConfigurationException; import org.wikibrain.conf.Configurator; import org.wikibrain.core.WikiBrainException; import org.wikibrain.core.cmd.Env; import org.wikibrain.core.cmd.EnvBuilder; import org.wikibrain.core.dao.DaoException; import org.wikibrain.core.dao.LocalPageDao; import org.wikibrain.core.dao.UniversalPageDao; import org.wikibrain.core.lang.Language; import org.wikibrain.core.lang.LanguageSet; import org.wikibrain.core.model.Title; import org.wikibrain.core.model.UniversalPage; import org.wikibrain.spatial.dao.SpatialDataDao; import org.wikibrain.spatial.dao.SpatialNeighborDao; import org.wikibrain.sr.SRMetric; import org.wikibrain.sr.SRResult; import org.wikibrain.utils.ParallelForEach; import org.wikibrain.utils.Procedure; import java.io.FileWriter; import java.io.IOException; import java.util.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Created by toby on 5/17/14. */ public class KNNEvaluator { private static int WIKIDATA_CONCEPTS = 1; private static final Logger LOG = LoggerFactory.getLogger(KNNEvaluator.class); private Random random = new Random(); private final SpatialDataDao sdDao; private final LocalPageDao lpDao; private final UniversalPageDao upDao; private final SpatialNeighborDao snDao; private final List<Language> langs; private final Map<Language, SRMetric> metrics; private final DistanceMetrics distanceMetrics; private final List<UniversalPage> concepts = new ArrayList<UniversalPage>(); private final Map<Integer, Point> locations = new HashMap<Integer, Point>(); private final Env env; private CSVWriter output; private String layerName = "wikidata"; public KNNEvaluator(Env env, LanguageSet languages) throws ConfigurationException { this.env = env; //this.langs = new ArrayList<Language>(env.getLanguages().getLanguages()); langs = new ArrayList<Language>(); for(Language lang : languages.getLanguages()) langs.add(lang); // Get data access objects Configurator c = env.getConfigurator(); this.sdDao = c.get(SpatialDataDao.class); this.lpDao = c.get(LocalPageDao.class); this.upDao = c.get(UniversalPageDao.class); this.snDao = c.get(SpatialNeighborDao.class); this.distanceMetrics = new DistanceMetrics(env, c, snDao); // build SR metrics this.metrics = new HashMap<Language, SRMetric>(); for(Language lang : langs){ SRMetric m = c.get(SRMetric.class, "ensemble", "language", lang.getLangCode()); metrics.put(lang, m); } } public static <T> List<T> getRandomSubList(List<T> input, int subsetSize) { if(subsetSize > input.size()) subsetSize = input.size(); Random r = new Random(); int inputSize = input.size(); for (int i = 0; i < subsetSize; i++) { int indexToSwap = i + r.nextInt(inputSize - i); T temp = input.get(i); input.set(i, input.get(indexToSwap)); input.set(indexToSwap, temp); } return input.subList(0, subsetSize); } public static <T>T getRandomElement(List<T> input){ return getRandomSubList(input, 1).get(0); } private void writeHeader() throws IOException { String[] headerEntries = new String[5 + langs.size()]; headerEntries[0] = "ITEM_NAME_1"; headerEntries[1] = "ITEM_ID_1"; headerEntries[2] = "ITEM_NAME_2"; headerEntries[3] = "ITEM_ID_2"; headerEntries[4] = "KNN_DISTANCE"; int counter = 0; for (Language lang : langs) { headerEntries[5 + counter] = lang.getLangCode() + "_SR"; counter ++; } output.writeNext(headerEntries); output.flush(); } private void writeRow(UniversalPage c1, UniversalPage c2, Integer KNNDistance, List<SRResult> results) throws WikiBrainException, IOException { Title t1 = c1.getBestEnglishTitle(lpDao, true); Title t2 = c2.getBestEnglishTitle(lpDao, true); String[] rowEntries = new String[5 + langs.size()]; rowEntries[0] = t1.getCanonicalTitle(); rowEntries[1] = String.valueOf(c1.getUnivId()); rowEntries[2] = t2.getCanonicalTitle(); rowEntries[3] = String.valueOf(c2.getUnivId()); rowEntries[4] = String.valueOf(KNNDistance); int counter = 0; for (SRResult result : results) { rowEntries[5 + counter] = String.valueOf(result.getScore()); counter ++; } output.writeNext(rowEntries); output.flush(); //if(CSVRowCounter % 1000 == 0 // LOG.info("Finished writing to CSV Row " + CSVRowCounter); //} } /** * * @param originId Origins to start * @param k K as in "K-nearest neighbors" * @param limitPerLevel Number of samples to pick from each "K-nearest neighbors" to evaluate * @param limitBranch not used * @param maxDist Max distance (depth of search) * @param outputPath The path for the output CSV file * @throws DaoException * @throws IOException */ public void evaluate(Iterable<Integer> originId, final Integer k, final Integer limitPerLevel, Integer limitBranch, final Integer maxDist, String outputPath) throws DaoException, IOException{ //TODO: parallel this process...originId.size() should definitely be larger than the number of available threads this.output = new CSVWriter(new FileWriter(outputPath), ','); writeHeader(); ParallelForEach.iterate(originId.iterator(), new Procedure<Integer>() { @Override public void call(Integer arg) throws Exception { evaluateForOne(arg, sdDao.getGeometry(arg, layerName, "earth"), k, limitPerLevel, maxDist); } }); } //TODO: return only a limited number of pairs for each recursion public void evaluateForOne(Integer originId, Geometry originGeom, Integer k, Integer limitPerLevel, Integer maxDist) throws DaoException{ Integer CSVRowCounter = 0; Set<Integer> excludeIds = new HashSet<Integer>(); Map<Integer, Integer> evalResult = new HashMap<Integer, Integer>(); List<Integer> nodeToDiscover = new LinkedList<Integer>(); Map<Integer, Geometry> geometryMap = new HashMap<Integer, Geometry>(); evalResult.put(originId, 0); geometryMap.put(originId, originGeom); excludeIds.add(originId); nodeToDiscover.add(originId); int counter = 0; while(counter < maxDist){ counter ++; Integer nodeToExpand = getRandomElement(nodeToDiscover); Map<Integer, Geometry> thisLevel; //No need to lock as they are all Read-Read // synchronized (this){ thisLevel = snDao.getKNNeighbors(geometryMap.get(nodeToExpand), k, layerName, "earth", excludeIds); //} if(thisLevel == null || thisLevel.size() == 0) break; excludeIds.addAll(thisLevel.keySet()); List<Integer> nodesToPutInTheCSV = getRandomSubList(new ArrayList(thisLevel.keySet()), limitPerLevel); for(Integer i : nodesToPutInTheCSV){ evalResult.put(i, counter); } Integer nodeToAdd = getRandomElement(new ArrayList<Integer>(thisLevel.keySet())); nodeToDiscover.add(nodeToAdd); geometryMap.put(nodeToAdd, thisLevel.get(nodeToAdd)); } for(Integer x : evalResult.keySet()){ for(Integer y : evalResult.keySet()){ try { List<SRResult> results = new ArrayList<SRResult>(); //synchronized (this){ for (Language lang : langs) { SRMetric sr = metrics.get(lang); results.add(sr.similarity(upDao.getById(x).getLocalId(lang), upDao.getById(y).getLocalId(lang), false)); } writeRow(upDao.getById(x), upDao.getById(y), Math.abs(evalResult.get(x) - evalResult.get(y)), results); //} CSVRowCounter++; if(CSVRowCounter % 5000 == 0) LOG.info("Thread " + Thread.currentThread().getId() + " Now printing " + CSVRowCounter + " From " + x + " To " + y + " at level " + Math.abs(evalResult.get(x) - evalResult.get(y))); } catch (Exception e){ //do nothing } } } } /* public Map<Integer, Integer> evaluateRecursive(Integer originId, Geometry originGeom, Integer k, Integer limitPerLevel, Integer limitBranch, Integer maxDist) throws DaoException{ if (maxDist == 0){ return new HashMap<Integer, Integer>(); } if (maxDist < currentLevel){ currentLevel = maxDist; LOG.info("reached level " + currentLevel); } excludeIds.add(originId); Map<Integer, Geometry> thisLevel = snDao.getKNNeighbors(originGeom, k, layerName, "earth", excludeIds); excludeIds.addAll(thisLevel.keySet()); Map<Integer, Integer> thisLevelRes = new HashMap<Integer, Integer>(); if(limitBranch > thisLevel.size()) limitBranch = thisLevel.size(); List<Integer> candidateList = getRandomSubList(new LinkedList<Integer>(thisLevel.keySet()), limitBranch); for(Integer i : candidateList){ thisLevelRes.put(i, 1); Map<Integer, Integer> childLevelRes = evaluateRecursive(i, thisLevel.get(i), k, limitPerLevel, limitBranch, maxDist - 1); for(Integer q: childLevelRes.keySet()){ thisLevelRes.put(q, childLevelRes.get(q) + 1); } } UniversalPage originPage = upDao.getById(originId, WIKIDATA_CONCEPTS); for(Integer i: thisLevelRes.keySet()){ try { List<SRResult> results = new ArrayList<SRResult>(); for (Language lang : langs) { MonolingualSRMetric sr = metrics.get(lang); results.add(sr.similarity(originPage.getLocalId(lang), upDao.getById(i, WIKIDATA_CONCEPTS).getLocalId(lang), false)); } LOG.info("Now printing " + CSVRowCounter + " From " + originId + " To " + i + " at level " + maxDist); writeRow(originPage, upDao.getById(i, WIKIDATA_CONCEPTS), thisLevelRes.get(i), results); } catch (Exception e){ //do nothing } } if(limitPerLevel > thisLevelRes.size()) limitPerLevel = thisLevelRes.size(); List<Integer> returnList = getRandomSubList(new LinkedList<Integer>(thisLevelRes.keySet()), limitPerLevel); Map<Integer, Integer> returnMap = new HashMap<Integer, Integer>(); for(Integer i : returnList){ returnMap.put(i, thisLevelRes.get(i)); } return returnMap; } */ public static void main(String[] args) throws Exception { Env env = EnvBuilder.envFromArgs(args); Configurator conf = env.getConfigurator(); KNNEvaluator evaluator = new KNNEvaluator(env, new LanguageSet("simple")); SpatialDataDao sdDao = conf.get(SpatialDataDao.class); Set<Integer> originSet = new HashSet<Integer>(); originSet.add(36091);originSet.add(956);originSet.add(64);originSet.add(258);originSet.add(60);originSet.add(65);originSet.add(90);originSet.add(84);originSet.add(1490); evaluator.evaluate(originSet, 100, 5, 1, 30, "test-topo.csv"); } }