package doser.word2vec.semanticCategories;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import org.apache.log4j.Logger;
import org.jgrapht.Graph;
import org.jgrapht.UndirectedGraph;
import org.jgrapht.alg.DijkstraShortestPath;
import org.jgrapht.graph.AbstractBaseGraph;
import org.jgrapht.graph.ClassBasedEdgeFactory;
import org.jgrapht.graph.DefaultEdge;
import com.hp.hpl.jena.query.QueryException;
import com.hp.hpl.jena.query.QueryExecution;
import com.hp.hpl.jena.query.QueryExecutionFactory;
import com.hp.hpl.jena.query.QueryFactory;
import com.hp.hpl.jena.query.QuerySolution;
import com.hp.hpl.jena.query.ResultSet;
import com.hp.hpl.jena.rdf.model.Model;
import com.hp.hpl.jena.rdf.model.ModelFactory;
import com.hp.hpl.jena.rdf.model.Property;
import com.hp.hpl.jena.rdf.model.RDFNode;
import com.hp.hpl.jena.rdf.model.Resource;
import com.hp.hpl.jena.rdf.model.Statement;
import com.hp.hpl.jena.rdf.model.StmtIterator;
public class Sampling {
/** Path of the tab-separated score file read by the constructor (score in column 0, category URI in column 1). */
public static final String CATEGORYPURITY = "/home/zwicklbauer/word2vec/MSEDbpediaCategories_Min5.txt";
/** Value the per-distance sample counter in generateCandidates must reach before the next distance is started. */
public static final int MAXIMUMSAMPLENR = 5000;
/** Undirected category graph; assigned from createGraph() in the constructor. */
private Graph<String, DefaultEdge> graph;
/** Categories that passed the score filter, kept as an array so one can be drawn by random index. */
private String[] catSet;
/** The same filtered categories as a set, used for membership tests during the random walk. */
private HashSet<String> catHash;
/** Jena model read from the article-categories dump; queried for the entities of a category. */
private Model m;
/** Random source shared by category picking, walk steps and entity selection. */
private Random random;
/**
 * Loads everything the sampler needs.
 * <ol>
 * <li>Reads {@link #CATEGORYPURITY} line by line; every line is expected to
 * hold a numeric score, a tab, and a category URI. Categories whose score is
 * strictly between -2 and 0.033 are kept.</li>
 * <li>Builds the category graph through {@link #createGraph()}.</li>
 * <li>Reads the article-categories dump into the Jena model used for entity
 * lookups.</li>
 * </ol>
 * Lines without a tab-separated second column or with a non-numeric score
 * are logged and skipped instead of aborting construction. I/O problems with
 * the score file are printed and leave the category set as read so far.
 */
public Sampling() {
    super();
    List<String> catList = new LinkedList<String>();
    this.catHash = new HashSet<String>();
    BufferedReader reader = null;
    try {
        reader = new BufferedReader(
                new FileReader(new File(CATEGORYPURITY)));
        String line;
        while ((line = reader.readLine()) != null) {
            String[] fields = line.split("\t");
            if (fields.length < 2) {
                // Blank or truncated line: there is no category column to read.
                Logger.getRootLogger().warn(
                        "Skipping line without category column: " + line);
                continue;
            }
            double score;
            try {
                score = Double.parseDouble(fields[0]);
            } catch (NumberFormatException e) {
                Logger.getRootLogger().warn(
                        "Skipping line with non-numeric score: " + line);
                continue;
            }
            if (score < 0.033 && score > -2) {
                catList.add(fields[1]);
                this.catHash.add(fields[1]);
            }
        }
    } catch (FileNotFoundException e) {
        e.printStackTrace();
    } catch (IOException e) {
        e.printStackTrace();
    } finally {
        if (reader != null) {
            try {
                reader.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }
    this.catSet = catList.toArray(new String[catList.size()]);
    this.graph = createGraph();
    this.m = ModelFactory.createDefaultModel();
    this.m.read("/home/zwicklbauer/HDTGeneration/article_categories_en.nt");
    this.random = new Random();
    System.out.println(this.catHash.size());
}
/**
 * Samples entity pairs grouped by the graph distance of their categories.
 * <p>
 * For every distance {@code d} in {@code [0, maxDistance)} the method keeps
 * drawing candidate pairs until {@link #MAXIMUMSAMPLENR} pairs have actually
 * been stored for {@code d}. A pair only counts when the set accepted it, so
 * a duplicate draw no longer shortens the sample. The running count is
 * printed after every draw.
 * <p>
 * The loop for a distance does not end until enough pairs were found.
 *
 * @param maxDistance exclusive upper bound of the sampled distances
 * @return map from distance to the pairs sampled for that distance; a
 *         distance has no entry until its first pair was stored
 */
public Map<Integer, HashSet<EntityPair>> generateCandidates(int maxDistance) {
    ConcurrentMap<Integer, HashSet<EntityPair>> map = new ConcurrentHashMap<Integer, HashSet<EntityPair>>();
    for (int stepsize = 0; stepsize < maxDistance; stepsize++) {
        int counter = 0;
        while (counter < MAXIMUMSAMPLENR) {
            EntityPair pair = samplePair(stepsize);
            if (pair != null) {
                HashSet<EntityPair> set = map.get(stepsize);
                if (set == null) {
                    set = new HashSet<EntityPair>();
                    map.put(stepsize, set);
                }
                // Count only pairs the set really stored.
                if (set.add(pair)) {
                    counter++;
                }
            }
            System.out.println(counter);
        }
    }
    return map;
}

/**
 * Draws one candidate pair and returns it only if the shortest path between
 * its two categories has exactly {@code distance} edges.
 *
 * @param distance required number of edges on the shortest path
 * @return the pair, or {@code null} when this draw produced no usable pair
 */
private EntityPair samplePair(int distance) {
    String basicCat = pickCategory();
    // edgesOf() rejects vertices that are not part of the graph, so such a
    // category cannot start a walk.
    if (basicCat == null || !graph.containsVertex(basicCat)) {
        return null;
    }
    String e1 = queryEntitiesFromCategory(basicCat);
    if (e1 == null) {
        return null;
    }
    // Walk randomly until a category from the filtered set is reached.
    String randomCat = basicCat;
    do {
        randomCat = performRandomStep(randomCat);
    } while (randomCat != null && !catHash.contains(randomCat));
    if (randomCat == null) {
        return null;
    }
    String e2 = queryEntitiesFromCategory(randomCat);
    if (e2 == null || e1.equalsIgnoreCase(e2)) {
        return null;
    }
    List<DefaultEdge> path = DijkstraShortestPath.findPathBetween(graph,
            basicCat, randomCat);
    if (path == null || path.size() != distance) {
        return null;
    }
    return new EntityPair(e1, e2, basicCat, randomCat);
}
/**
 * Moves from {@code current} across one uniformly chosen incident edge.
 * <p>
 * The endpoint comparison is case-sensitive: vertices are URIs, and two
 * neighbouring URIs that differ only in letter case must not be mistaken
 * for the same vertex (which would make the step return {@code current}).
 *
 * @param current vertex to step away from; {@code null} yields {@code null}
 * @return the vertex at the other end of the chosen edge (for a self-loop
 *         this is {@code current} itself)
 */
private String performRandomStep(String current) {
    if (current == null) {
        return null;
    }
    Set<DefaultEdge> edges = graph.edgesOf(current);
    // The set has no index access, so count down to the chosen edge.
    int remaining = random.nextInt(edges.size());
    for (DefaultEdge edge : edges) {
        if (remaining == 0) {
            String source = graph.getEdgeSource(edge);
            return source.equals(current) ? graph.getEdgeTarget(edge)
                    : source;
        }
        remaining--;
    }
    return null;
}
/**
 * Draws one entry of the filtered category array at a random index.
 *
 * @return the category stored at the drawn index
 */
private String pickCategory() {
    final int drawn = this.random.nextInt(this.catSet.length);
    final String category = this.catSet[drawn];
    return category;
}
/**
 * Builds the undirected category graph from the SKOS category dump.
 * <p>
 * Every statement whose predicate is {@code skos:broader} and whose object
 * is a resource contributes one edge between subject URI and object URI.
 * Statements with an endpoint that has no URI are skipped, because the graph
 * cannot hold a {@code null} vertex. The number of processed edges is
 * printed every 10000 edges.
 * <p>
 * The temporary model and its statement iterator are closed once the graph
 * is built.
 *
 * @return the category graph
 */
public UndirectedGraph<String, DefaultEdge> createGraph() {
    final String broader = "http://www.w3.org/2004/02/skos/core#broader";
    UndirectedGraph<String, DefaultEdge> categoryGraph = new MiGrafo();
    Model model = ModelFactory.createDefaultModel();
    try {
        model.read("/home/zwicklbauer/HDTGeneration/skos_categories_en.nt");
        StmtIterator it = model.listStatements();
        try {
            int counter = 0;
            while (it.hasNext()) {
                Statement s = it.next();
                RDFNode n = s.getObject();
                if (!s.getPredicate().getURI().equalsIgnoreCase(broader)
                        || !n.isResource()) {
                    continue;
                }
                String narrowerUri = s.getSubject().getURI();
                String broaderUri = n.asResource().getURI();
                if (narrowerUri == null || broaderUri == null) {
                    // Endpoint without a URI: not usable as a vertex.
                    continue;
                }
                // addVertex leaves the graph unchanged for a known vertex.
                categoryGraph.addVertex(narrowerUri);
                categoryGraph.addVertex(broaderUri);
                categoryGraph.addEdge(narrowerUri, broaderUri);
                if (counter % 10000 == 0) {
                    System.out.println(counter);
                }
                counter++;
            }
        } finally {
            it.close();
        }
    } finally {
        model.close();
    }
    return categoryGraph;
}
/**
 * Returns a randomly chosen entity that is linked to the given category via
 * {@code dcterms:subject} in the article-categories model.
 * <p>
 * The query execution is always closed. A {@link QueryException} is logged
 * together with its stack trace and reported as "no entity".
 *
 * @param catUri URI of the category, inserted verbatim between angle
 *               brackets in the query text
 * @return URI of one matching entity, or {@code null} if the category has
 *         no entities or the query failed
 */
private String queryEntitiesFromCategory(final String catUri) {
    final String query = "SELECT ?entities WHERE{ ?entities <http://purl.org/dc/terms/subject> <"
            + catUri + ">. }";
    QueryExecution qexec = null;
    try {
        final com.hp.hpl.jena.query.Query cquery = QueryFactory
                .create(query);
        qexec = QueryExecutionFactory.create(cquery, m);
        final ResultSet results = qexec.execSelect();
        List<String> entities = new LinkedList<String>();
        while (results.hasNext()) {
            final QuerySolution sol = results.nextSolution();
            entities.add(sol.getResource("entities").getURI());
        }
        if (!entities.isEmpty()) {
            return entities.get(this.random.nextInt(entities.size()));
        }
    } catch (final QueryException e) {
        // Pass the exception itself so the logger prints the stack trace.
        Logger.getRootLogger().error(
                "Entity lookup failed for category " + catUri, e);
    } finally {
        if (qexec != null) {
            qexec.close();
        }
    }
    return null;
}
class MiGrafo extends AbstractBaseGraph<String, DefaultEdge> implements
UndirectedGraph<String, DefaultEdge> {
private static final long serialVersionUID = 1L;
MiGrafo() {
super(new ClassBasedEdgeFactory<String, DefaultEdge>(
DefaultEdge.class), true, true);
}
}
/**
 * Samples entity pairs for the distances 0 to 3 and writes them to
 * {@code /home/zwicklbauer/samplingoutput.dat}, one pair per line:
 * distance, entity 1, entity 2, category 1, category 2, separated by tabs.
 * <p>
 * {@link PrintWriter} never throws on a failed write, so its error flag is
 * checked after the last line and a failure is reported on standard error.
 *
 * @param args not used
 */
public static void main(String[] args) {
    Sampling sampling = new Sampling();
    Map<Integer, HashSet<EntityPair>> map = sampling.generateCandidates(4);
    File file = new File("/home/zwicklbauer/samplingoutput.dat");
    PrintWriter writer = null;
    try {
        writer = new PrintWriter(file);
        for (Map.Entry<Integer, HashSet<EntityPair>> entry : map.entrySet()) {
            Integer key = entry.getKey();
            for (EntityPair p : entry.getValue()) {
                writer.println(String.valueOf(key) + "\t" + p.getEntity1()
                        + "\t" + p.getEntity2() + "\t" + p.getCategory1()
                        + "\t" + p.getCategory2());
            }
        }
        // checkError() flushes and reports any write failure so far.
        if (writer.checkError()) {
            System.err.println("Could not write all samples to " + file);
        }
    } catch (FileNotFoundException e) {
        e.printStackTrace();
    } finally {
        if (writer != null) {
            writer.close();
        }
    }
}
}