package edu.wiki.modify;

import edu.wiki.util.HeapSort;

import gnu.trove.TIntDoubleHashMap;
import gnu.trove.TIntFloatHashMap;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.text.DecimalFormat;
import java.util.HashMap;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermEnum;
import org.apache.lucene.index.TermFreqVector;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;

/**
 * Reads term frequencies (TF) and document frequencies (DF) from the Lucene index and
 * writes cosine-normalized TF-IDF vectors to the database.
 *
 * Normalization is performed as in Gabrilovich et al. (2009).
 *
 * Usage: IndexModifier <Lucene index location>
 *
 * @author Cagatay Calli <ccalli@gmail.com>
 */
public class IndexModifier {

    static Connection connection = null;
    static Statement stmtLink;
    static PreparedStatement pstmtVector;

    // static String strLoadData = "LOAD DATA LOCAL INFILE 'vector.txt' INTO TABLE idx FIELDS ENCLOSED BY \"'\"";
    static String strVectorQuery = "INSERT INTO idx (term,vector) VALUES (?,?)";
    static String strTermLoadData = "LOAD DATA LOCAL INFILE 'term.txt' INTO TABLE terms FIELDS ENCLOSED BY \"'\"";
    static String strAllInlinks = "SELECT target_id,inlink FROM inlinks";
    static String strLimitQuery = "SELECT COUNT(id) FROM article;";

    private static IndexReader reader = null;
    static int limitID;

    private static TIntDoubleHashMap inlinkMap;

    // pruning parameters: sliding window over the sorted scores of each term vector
    static int WINDOW_SIZE = 100;
    static float WINDOW_THRES = 0.005f;

    static DecimalFormat df = new DecimalFormat("#.########");

    public static void initDB() throws ClassNotFoundException, SQLException, IOException {
        // Load the JDBC driver
        String driverName = "com.mysql.jdbc.Driver"; // MySQL Connector
        Class.forName(driverName);

        // read DB config
        InputStream is = IndexModifier.class.getResourceAsStream("/config/db.conf");
        BufferedReader br = new BufferedReader(new InputStreamReader(is));
        String serverName = br.readLine();
        String mydatabase = br.readLine();
        String username = br.readLine();
        String password = br.readLine();
        br.close();

        // Create a connection to the database
        String url = "jdbc:mysql://" + serverName + "/" + mydatabase
                + "?useUnicode=yes&characterEncoding=UTF-8"; // a JDBC url
        connection = DriverManager.getConnection(url, username, password);

        stmtLink = connection.createStatement();
        stmtLink.setFetchSize(200);

        stmtLink.execute("DROP TABLE IF EXISTS idx");
        stmtLink.execute("CREATE TABLE idx ("
                + "term VARBINARY(255),"
                + "vector MEDIUMBLOB "
                + ") DEFAULT CHARSET=binary");

        stmtLink.execute("DROP TABLE IF EXISTS terms");
        stmtLink.execute("CREATE TABLE terms ("
                + "term VARBINARY(255),"
                + "idf FLOAT "
                + ") DEFAULT CHARSET=binary");

        stmtLink = connection.createStatement();
        ResultSet res = stmtLink.executeQuery(strLimitQuery);
        res.next();
        limitID = res.getInt(1);

        // read inlink counts
        inlinkMap = new TIntDoubleHashMap(limitID);
        int targetID, numInlinks;
        res = stmtLink.executeQuery(strAllInlinks);
        while (res.next()) {
            targetID = res.getInt(1);
            numInlinks = res.getInt(2);
            inlinkMap.put(targetID, Math.log(1 + Math.log(1 + numInlinks)));
        }
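
        // Each article's TF-IDF scores are later multiplied by this boost, which grows
        // very slowly with the number of incoming links:
        //   boost = ln(1 + ln(1 + inlinks))
        // Rough illustrative values: 10 inlinks -> ~1.22, 1000 inlinks -> ~2.07.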
        pstmtVector = connection.prepareStatement(strVectorQuery);
    }

    /**
     * @param args
     * @throws IOException
     * @throws SQLException
     * @throws ClassNotFoundException
     */
    public static void main(String[] args) throws IOException, ClassNotFoundException, SQLException {
        try {
            Directory fsdir = FSDirectory.open(new File(args[0]));
            reader = IndexReader.open(fsdir, true);
        } catch (Exception ex) {
            System.out.println("Cannot open index... " + ex.getMessage());
            System.exit(-1);
        }

        initDB();

        long sTime, eTime;
        sTime = System.currentTimeMillis();

        int maxid = reader.maxDoc();

        TermFreqVector tv;
        String[] terms;
        String term = "";
        Term t;
        int tfreq = 0;
        float idf;
        float tf;
        float tfidf;
        double inlinkBoost;
        double sum;
        int wikiID;
        int hashInt;

        int numDocs = reader.numDocs();
        TermEnum tnum = reader.terms();

        HashMap<String, Float> idfMap = new HashMap<String, Float>(500000);
        HashMap<String, Float> tfidfMap = new HashMap<String, Float>(5000);
        HashMap<String, Integer> termHash = new HashMap<String, Integer>(500000);

        FileOutputStream fos = new FileOutputStream("vector.txt");
        OutputStreamWriter osw = new OutputStreamWriter(fos, "UTF-8");

        // first pass over the dictionary: compute IDF and assign a numeric id to each term
        hashInt = 0;
        while (tnum.next()) {
            t = tnum.term();
            term = t.text();
            tfreq = tnum.docFreq(); // get DF for the term

            // skip rare terms
            if (tfreq < 3) {
                continue;
            }

            // idf = (float)(Math.log(numDocs/(double)(tfreq+1)) + 1.0);
            idf = (float) (Math.log(numDocs / (double) (tfreq)));
            // idf = (float)(Math.log(numDocs/(double)(tfreq)) / Math.log(2));

            idfMap.put(term, idf);
            termHash.put(term, hashInt++);
        }

        // second pass over the documents: write cosine-normalized TF-IDF entries to vector.txt
        for (int i = 0; i < maxid; i++) {
            if (!reader.isDeleted(i)) {
                // System.out.println(i);
                wikiID = Integer.valueOf(reader.document(i).getField("id").stringValue());
                inlinkBoost = inlinkMap.get(wikiID);

                tv = reader.getTermFreqVector(i, "contents");
                try {
                    terms = tv.getTerms();
                    int[] fq = tv.getTermFrequencies();

                    sum = 0.0;
                    tfidfMap.clear();

                    // for all terms of a document
                    for (int k = 0; k < terms.length; k++) {
                        term = terms[k];
                        if (!idfMap.containsKey(term))
                            continue;
                        tf = (float) (1.0 + Math.log(fq[k]));
                        // tf = (float) (1.0 + Math.log(fq[k]) / Math.log(2));
                        idf = idfMap.get(term);
                        tfidf = (float) (tf * idf);
                        tfidfMap.put(term, tfidf);
                        sum += tfidf * tfidf;
                    }

                    sum = Math.sqrt(sum);

                    // for all terms of a document
                    for (int k = 0; k < terms.length; k++) {
                        term = terms[k];
                        if (!idfMap.containsKey(term))
                            continue;
                        tfidf = (float) (tfidfMap.get(term) / sum * inlinkBoost);
                        // System.out.println(i + ": " + term + " " + fq[k] + " " + tfidf);

                        // record (termHash, term, wikiID, tfidf) to vector.txt
                        osw.write(termHash.get(term) + "\t" + term + "\t" + wikiID + "\t" + df.format(tfidf) + "\n");
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                    System.out.println("ERR: " + wikiID + " " + tv);
                    continue;
                }
            }
        }

        osw.close();
        fos.close();

        // sort tfidf entries according to terms
        String[] cmd = {"/bin/sh", "-c", "sort -S 1200M -n -t\\\t -k1 < vector.txt > vsorted.txt"};
        Process p1 = Runtime.getRuntime().exec(cmd);
        try {
            int exitV = p1.waitFor();
            if (exitV != 0) {
                System.exit(1);
            }
        } catch (InterruptedException e) {
            e.printStackTrace();
            System.exit(1);
        }

        // delete unsorted doc-score file
        p1 = Runtime.getRuntime().exec("rm vector.txt");
        try {
            int exitV = p1.waitFor();
            if (exitV != 0) {
                System.exit(1);
            }
        } catch (InterruptedException e) {
            e.printStackTrace();
            System.exit(1);
        }

        FileInputStream fis = new FileInputStream("vsorted.txt");
        InputStreamReader isr = new InputStreamReader(fis, "UTF-8");
        BufferedReader bir = new BufferedReader(isr);
        String line;
        String prevTerm = null;
        int doc;
        float score;
        TIntFloatHashMap hmap = new TIntFloatHashMap(100);
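
        // vsorted.txt is sorted numerically by the term id in column 1, so all postings of a
        // term are adjacent. The loop below accumulates one term's (docId, score) postings in
        // hmap and, whenever the term changes, prunes and stores the finished vector.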
        while ((line = bir.readLine()) != null) {
            final String[] parts = line.split("\t");
            term = parts[1];

            // term changed: prune and write the vector of the previous term
            if (prevTerm != null && !prevTerm.equals(term)) {
                pruneAndStore(prevTerm, hmap);
                hmap.clear();
            }

            doc = Integer.valueOf(parts[2]);
            score = Float.valueOf(parts[3]);
            hmap.put(doc, score);
            prevTerm = term;
        }

        // flush the vector of the last term in the file
        if (prevTerm != null && !hmap.isEmpty()) {
            pruneAndStore(prevTerm, hmap);
            hmap.clear();
        }

        bir.close();
        isr.close();
        fis.close();

        // record term IDFs
        FileOutputStream tos = new FileOutputStream("term.txt");
        OutputStreamWriter tsw = new OutputStreamWriter(tos, "UTF-8");
        for (String tk : idfMap.keySet()) {
            tsw.write("'" + tk.replace("\\", "\\\\").replace("'", "\\'") + "'\t" + idfMap.get(tk) + "\n");
        }
        tsw.close();
        tos.close();

        stmtLink.execute(strTermLoadData);
        stmtLink.execute("CREATE INDEX idx_term ON terms (term(32))");

        eTime = System.currentTimeMillis();
        System.out.println("Total TIME (sec): " + (eTime - sTime) / 1000.0);

        reader.close();
        connection.close();
    }

    /**
     * Sorts a term's postings by score, prunes the tail with a sliding-window heuristic
     * and stores the resulting vector in the idx table. The window moves over the scores
     * in descending order; once the drop across the last WINDOW_SIZE entries falls below
     * WINDOW_THRES of the highest score, the rest of the vector is truncated.
     */
    private static void pruneAndStore(String term, TIntFloatHashMap postings)
            throws IOException, SQLException {
        int[] arrDocs = postings.keys();
        float[] arrScores = postings.getValues();
        HeapSort.heapSort(arrScores, arrDocs);

        int mark = 0;
        int windowMark = 0;
        float highest = 0, first = 0, last = 0, score;
        float[] window = new float[WINDOW_SIZE];

        ByteArrayOutputStream baos = new ByteArrayOutputStream(50000);
        DataOutputStream tdos = new DataOutputStream(baos);

        // walk the postings from highest to lowest score
        for (int j = arrDocs.length - 1; j >= 0; j--) {
            score = arrScores[j];
            // sliding window
            window[windowMark] = score;

            if (mark == 0) {
                highest = score;
                first = score;
            }

            if (mark < WINDOW_SIZE) {
                // always keep the first WINDOW_SIZE entries
                tdos.writeInt(arrDocs[j]);
                tdos.writeFloat(score);
            } else if (highest * WINDOW_THRES < (first - last)) {
                // scores are still dropping fast enough: keep the entry
                tdos.writeInt(arrDocs[j]);
                tdos.writeFloat(score);
                // "first" becomes the score WINDOW_SIZE positions back
                if (windowMark < WINDOW_SIZE - 1) {
                    first = window[windowMark + 1];
                } else {
                    first = window[0];
                }
            } else {
                // window became too flat: truncate the vector here
                break;
            }

            last = score;
            mark++;
            windowMark = (windowMark + 1) % WINDOW_SIZE;
        }

        // blob layout: [int count] followed by count pairs of [int docId][float score]
        ByteArrayOutputStream dbvector = new ByteArrayOutputStream();
        DataOutputStream dbdis = new DataOutputStream(dbvector);
        dbdis.writeInt(mark);
        dbdis.flush();
        dbvector.write(baos.toByteArray());
        dbvector.flush();
        dbdis.close();

        // write to DB
        pstmtVector.setString(1, term);
        pstmtVector.setBlob(2, new ByteArrayInputStream(dbvector.toByteArray()));
        pstmtVector.execute();

        tdos.close();
        baos.close();
    }
}
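
// The class below is a minimal, hypothetical sketch (not part of the original tool) showing
// how a consumer could decode the "vector" blob that pruneAndStore() writes to the idx table:
// the payload is [int count] followed by count pairs of [int docId][float score], ordered by
// descending score. The class and method names are illustrative only.
class VectorBlobReaderSketch {
    static void dump(byte[] blob) throws java.io.IOException {
        java.io.DataInputStream dis =
                new java.io.DataInputStream(new java.io.ByteArrayInputStream(blob));
        int count = dis.readInt();             // number of (doc, score) pairs kept after pruning
        for (int i = 0; i < count; i++) {
            int docId = dis.readInt();         // Wikipedia article id
            float score = dis.readFloat();     // cosine-normalized TF-IDF * inlink boost
            System.out.println(docId + "\t" + score);
        }
        dis.close();
    }
}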