package edu.wiki.search;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.StringReader;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;

import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.TermAttribute;

import edu.wiki.api.concept.IConceptIterator;
import edu.wiki.api.concept.IConceptVector;
import edu.wiki.api.concept.scorer.CosineScorer;
import edu.wiki.concept.ConceptVectorSimilarity;
import edu.wiki.concept.TroveConceptVector;
import edu.wiki.index.WikipediaAnalyzer;
import edu.wiki.util.HeapSort;
import gnu.trove.TIntFloatHashMap;
import gnu.trove.TIntIntHashMap;

/**
 * Performs search on the index located in the database.
 *
 * @author Cagatay Calli <ccalli@gmail.com>
 */
public class ESASearcher {
    Connection connection;
    PreparedStatement pstmtQuery;
    PreparedStatement pstmtIdfQuery;
    PreparedStatement pstmtLinks;
    Statement stmtInlink;
    WikipediaAnalyzer analyzer;

    String strTermQuery = "SELECT t.vector FROM idx t WHERE t.term = ?";
    String strIdfQuery = "SELECT t.idf FROM terms t WHERE t.term = ?";
    String strMaxConcept = "SELECT MAX(id) FROM article";
    String strInlinks = "SELECT i.target_id, i.inlink FROM inlinks i WHERE i.target_id IN ";
    String strLinks = "SELECT target_id FROM pagelinks WHERE source_id = ?";

    int maxConceptId;
    int[] ids;
    double[] values;
    HashMap<String, Integer> freqMap = new HashMap<String, Integer>(30);
    HashMap<String, Double> tfidfMap = new HashMap<String, Double>(30);
    HashMap<String, Float> idfMap = new HashMap<String, Float>(30);
    ArrayList<String> termList = new ArrayList<String>(30);

    TIntIntHashMap inlinkMap;

    static float LINK_ALPHA = 0.5f;

    ConceptVectorSimilarity sim = new ConceptVectorSimilarity(new CosineScorer());
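    /*
     * Assumed database layout, inferred from the queries above (the exact
     * schema is defined by the index-building step, not by this class):
     *   idx(term, vector)               -- serialized concept vector per term
     *   terms(term, idf)                -- inverse document frequency per term
     *   article(id, ...)                -- one row per Wikipedia concept
     *   inlinks(target_id, inlink)      -- inlink count per article
     *   pagelinks(source_id, target_id) -- the article link graph
     *
     * /config/db.conf is read by initDB() as four lines:
     * host, database, user, password.
     */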
    public void initDB() throws ClassNotFoundException, SQLException, IOException {
        // Load the MySQL JDBC driver
        String driverName = "com.mysql.jdbc.Driver";
        Class.forName(driverName);

        // read DB config
        InputStream is = ESASearcher.class.getResourceAsStream("/config/db.conf");
        BufferedReader br = new BufferedReader(new InputStreamReader(is));
        String serverName = br.readLine();
        String mydatabase = br.readLine();
        String username = br.readLine();
        String password = br.readLine();
        br.close();

        // Create a connection to the database
        String url = "jdbc:mysql://" + serverName + "/" + mydatabase; // a JDBC url
        connection = DriverManager.getConnection(url, username, password);

        pstmtQuery = connection.prepareStatement(strTermQuery);
        pstmtQuery.setFetchSize(1);
        pstmtIdfQuery = connection.prepareStatement(strIdfQuery);
        pstmtIdfQuery.setFetchSize(1);
        pstmtLinks = connection.prepareStatement(strLinks);
        pstmtLinks.setFetchSize(500);

        stmtInlink = connection.createStatement();
        stmtInlink.setFetchSize(50);

        ResultSet res = connection.createStatement().executeQuery(strMaxConcept);
        res.next();
        maxConceptId = res.getInt(1) + 1;
    }

    public void clean() {
        freqMap.clear();
        tfidfMap.clear();
        idfMap.clear();
        termList.clear();
        inlinkMap.clear();
        Arrays.fill(ids, 0);
        Arrays.fill(values, 0);
    }

    public ESASearcher() throws ClassNotFoundException, SQLException, IOException {
        initDB();
        analyzer = new WikipediaAnalyzer();
        ids = new int[maxConceptId];
        values = new double[maxConceptId];
        inlinkMap = new TIntIntHashMap(300);
    }

    @Override
    protected void finalize() throws Throwable {
        connection.close();
        super.finalize();
    }

    /**
     * Retrieves the full concept vector for regular features.
     * @param query free-text query
     * @return concept vector if results exist, otherwise null
     * @throws IOException
     * @throws SQLException
     */
    public IConceptVector getConceptVector(String query) throws IOException, SQLException {
        String strTerm;
        int numTerms = 0;
        ResultSet rs;
        int doc;
        double score;
        int vint;
        double vdouble;
        double tf;
        double vsum;
        int plen;
        boolean found = false;
        TokenStream ts = analyzer.tokenStream("contents", new StringReader(query));
        ByteArrayInputStream bais;
        DataInputStream dis;

        this.clean();
        for (int i = 0; i < ids.length; i++) {
            ids[i] = i;
        }

        ts.reset();
        while (ts.incrementToken()) {
            TermAttribute t = ts.getAttribute(TermAttribute.class);
            strTerm = t.term();

            // record term IDF
            if (!idfMap.containsKey(strTerm)) {
                pstmtIdfQuery.setBytes(1, strTerm.getBytes("UTF-8"));
                pstmtIdfQuery.execute();
                rs = pstmtIdfQuery.getResultSet();
                if (rs.next()) {
                    idfMap.put(strTerm, rs.getFloat(1));
                }
            }

            // record term counts for TF
            if (freqMap.containsKey(strTerm)) {
                vint = freqMap.get(strTerm);
                freqMap.put(strTerm, vint + 1);
            } else {
                freqMap.put(strTerm, 1);
            }

            termList.add(strTerm);
            numTerms++;
        }
        ts.end();
        ts.close();

        if (numTerms == 0) {
            return null;
        }

        // calculate TF-IDF vector (normalized)
        vsum = 0;
        for (String tk : idfMap.keySet()) {
            tf = 1.0 + Math.log(freqMap.get(tk));
            vdouble = idfMap.get(tk) * tf;
            tfidfMap.put(tk, vdouble);
            vsum += vdouble * vdouble;
        }
        vsum = Math.sqrt(vsum);

        // comment this out to cancel query normalization
        for (String tk : idfMap.keySet()) {
            vdouble = tfidfMap.get(tk);
            tfidfMap.put(tk, vdouble / vsum);
        }

        for (String tk : termList) {
            pstmtQuery.setBytes(1, tk.getBytes("UTF-8"));
            pstmtQuery.execute();

            rs = pstmtQuery.getResultSet();
            if (rs.next()) {
                found = true;
                bais = new ByteArrayInputStream(rs.getBytes(1));
                dis = new DataInputStream(bais);

                /*
                 * Serialized posting list:
                 * 4 bytes: int - number of entries, followed by
                 * (4-byte int doc id, 4-byte float tf-idf score) pairs.
                 */
                plen = dis.readInt();
                for (int k = 0; k < plen; k++) {
                    doc = dis.readInt();
                    score = dis.readFloat();
                    values[doc] += score * tfidfMap.get(tk);
                }

                bais.close();
                dis.close();
            }
        }

        // no result
        if (!found) {
            return null;
        }

        HeapSort.heapSort(values, ids);

        IConceptVector newCv = new TroveConceptVector(ids.length);
        for (int i = ids.length - 1; i >= 0 && values[i] > 0; i--) {
            newCv.set(ids[i], values[i] / numTerms);
        }

        return newCv;
    }
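    /*
     * Illustrative sketch only, not called anywhere in this class: decodes a
     * serialized posting list, as stored in the idx.vector column and read by
     * getConceptVector() above, into (concept id, score) pairs. The helper and
     * its name are additions for clarity, not part of the original API.
     */
    static TIntFloatHashMap decodePostings(byte[] blob) throws IOException {
        DataInputStream in = new DataInputStream(new ByteArrayInputStream(blob));
        int n = in.readInt(); // number of (doc, score) entries
        TIntFloatHashMap postings = new TIntFloatHashMap(n);
        for (int i = 0; i < n; i++) {
            int doc = in.readInt();       // 4-byte concept (article) id
            float score = in.readFloat(); // 4-byte tf-idf score
            postings.put(doc, score);
        }
        in.close();
        return postings;
    }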
    /**
     * Returns the trimmed form of a concept vector, keeping only the top LIMIT concepts.
     * @param cv concept vector
     * @param LIMIT maximum number of concepts to keep
     * @return trimmed concept vector, or null if cv is null
     */
    public IConceptVector getNormalVector(IConceptVector cv, int LIMIT) {
        IConceptVector cv_normal = new TroveConceptVector(LIMIT);
        IConceptIterator it;

        if (cv == null)
            return null;

        it = cv.orderedIterator();

        int count = 0;
        while (it.next()) {
            if (count >= LIMIT)
                break;
            cv_normal.set(it.getId(), it.getValue());
            count++;
        }

        return cv_normal;
    }

    private TIntIntHashMap setInlinkCounts(Collection<Integer> ids) throws SQLException {
        inlinkMap.clear();

        if (ids.isEmpty()) { // avoid generating an invalid empty IN () clause
            return inlinkMap;
        }

        StringBuilder inPart = new StringBuilder("(");
        for (int id : ids) {
            inPart.append(id).append(",");
        }
        inPart.setCharAt(inPart.length() - 1, ')');

        // collect inlink counts
        ResultSet r = stmtInlink.executeQuery(strInlinks + inPart);
        while (r.next()) {
            inlinkMap.put(r.getInt(1), r.getInt(2));
        }

        return inlinkMap;
    }

    private Collection<Integer> getLinks(int id) throws SQLException {
        ArrayList<Integer> links = new ArrayList<Integer>(100);

        pstmtLinks.setInt(1, id);
        ResultSet r = pstmtLinks.executeQuery();
        while (r.next()) {
            links.add(r.getInt(1));
        }

        return links;
    }

    public IConceptVector getLinkVector(IConceptVector cv, int limit) throws SQLException {
        if (cv == null)
            return null;
        return getLinkVector(cv, true, LINK_ALPHA, limit);
    }
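    /*
     * Second-order ("link") interpretation, sketched: each concept in the
     * input vector propagates its value to linked concepts that are more
     * general than itself. Generality is approximated by inlink count; a
     * link l of page p qualifies when
     *     log(inlinks(l)) - log(inlinks(p)) > 1,
     * i.e. l has more than e (~2.72) times as many inlinks as p. Propagated
     * values are scaled by ALPHA and the top LIMIT new concepts are returned.
     */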
    /**
     * Computes the secondary interpretation vector of regular features.
     * @param cv first-order concept vector
     * @param moreGeneral propagate only to more general concepts
     * @param ALPHA weight of propagated values
     * @param LIMIT maximum number of concepts to return
     * @return link vector, or null if cv is null
     * @throws SQLException
     */
    public IConceptVector getLinkVector(IConceptVector cv, boolean moreGeneral, double ALPHA, int LIMIT)
            throws SQLException {
        IConceptIterator it;

        if (cv == null)
            return null;

        it = cv.orderedIterator();

        ArrayList<Integer> pages = new ArrayList<Integer>();

        TIntFloatHashMap valueMap2 = new TIntFloatHashMap(1000);
        TIntFloatHashMap valueMap3 = new TIntFloatHashMap();

        ArrayList<Integer> npages = new ArrayList<Integer>();

        HashMap<Integer, Float> secondMap = new HashMap<Integer, Float>(1000);

        this.clean();

        // collect article objects
        while (it.next()) {
            pages.add(it.getId());
            valueMap2.put(it.getId(), (float) it.getValue());
        }

        // prepare inlink counts
        setInlinkCounts(pages);

        for (int pid : pages) {
            Collection<Integer> raw_links = getLinks(pid);
            if (raw_links.isEmpty()) {
                continue;
            }
            ArrayList<Integer> links = new ArrayList<Integer>(raw_links.size());

            // a missing inlink count yields log(0) = -Infinity,
            // so such a page treats every link as more general
            final double inlink_factor_p = Math.log(inlinkMap.get(pid));

            float origValue = valueMap2.get(pid);

            setInlinkCounts(raw_links);

            for (int lid : raw_links) {
                final double inlink_factor_link = Math.log(inlinkMap.get(lid));

                // check concept generality
                if (inlink_factor_link - inlink_factor_p > 1) {
                    links.add(lid);
                }
            }

            for (int lid : links) {
                if (!valueMap2.containsKey(lid)) {
                    valueMap2.put(lid, 0.0f);
                    npages.add(lid);
                }
            }

            float linkedValue = 0.0f;
            for (int lid : links) {
                if (valueMap3.containsKey(lid)) {
                    linkedValue = valueMap3.get(lid);
                    linkedValue += origValue;
                    valueMap3.put(lid, linkedValue);
                } else {
                    valueMap3.put(lid, origValue);
                }
            }
        }

        // alternative: also keep the original pages in the result
        // for (int pid : pages) {
        //     if (valueMap3.containsKey(pid)) {
        //         secondMap.put(pid, (float) (valueMap2.get(pid) + ALPHA * valueMap3.get(pid)));
        //     } else {
        //         secondMap.put(pid, valueMap2.get(pid));
        //     }
        // }

        for (int pid : npages) {
            secondMap.put(pid, (float) (ALPHA * valueMap3.get(pid)));
        }

        // sort keys by value, descending
        ArrayList<Integer> keys = new ArrayList<Integer>(secondMap.keySet());
        final Map<Integer, Float> langForComp = secondMap;
        Collections.sort(keys, new Comparator<Integer>() {
            public int compare(Integer left, Integer right) {
                return langForComp.get(left).compareTo(langForComp.get(right));
            }
        });
        Collections.reverse(keys);

        IConceptVector cv_link = new TroveConceptVector(maxConceptId);

        int c = 0;
        for (int p : keys) {
            cv_link.set(p, secondMap.get(p));
            c++;
            if (c >= LIMIT) {
                break;
            }
        }

        return cv_link;
    }

    public IConceptVector getCombinedVector(String query) throws IOException, SQLException {
        IConceptVector cvBase = getConceptVector(query);
        IConceptVector cvNormal, cvLink;

        if (cvBase == null) {
            return null;
        }

        cvNormal = getNormalVector(cvBase, 10);
        cvLink = getLinkVector(cvNormal, 5);

        cvNormal.add(cvLink);

        return cvNormal;
    }

    /**
     * Calculates semantic relatedness between two documents.
     * @param doc1
     * @param doc2
     * @return relatedness score if successful, -1 if either vector is
     *         undefined, 0 on error
     */
    public double getRelatedness(String doc1, String doc2) {
        try {
            // alternatives: getCombinedVector(doc) or
            // getNormalVector(getConceptVector(doc), 10)
            IConceptVector c1 = getConceptVector(doc1);
            IConceptVector c2 = getConceptVector(doc2);

            if (c1 == null || c2 == null) {
                return -1; // undefined
            }

            return sim.calcSimilarity(c1, c2);
        } catch (Exception e) {
            e.printStackTrace();
            return 0;
        }
    }
}
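// Minimal usage sketch (an addition for illustration; assumes the ESA index
// database is populated and /config/db.conf is on the classpath; the query
// strings below are arbitrary examples).
class ESASearcherExample {
    public static void main(String[] args) throws Exception {
        ESASearcher searcher = new ESASearcher();

        // first-order interpretation vector of a free-text query
        IConceptVector cv = searcher.getConceptVector("semantic relatedness");
        System.out.println(cv == null ? "no results" : "got concept vector");

        // cosine relatedness of two short texts; -1 if either is undefined
        System.out.println(searcher.getRelatedness("car", "automobile"));
    }
}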