package org.wikipedia.miner.db;

import gnu.trove.map.hash.THashMap;
import gnu.trove.set.hash.TIntHashSet;

import java.io.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.Map;
import java.util.TreeMap;

import org.apache.hadoop.record.CsvRecordInput;
import org.apache.hadoop.record.CsvRecordOutput;
import org.apache.log4j.Logger;
import org.wikipedia.miner.db.struct.DbLabel;
import org.wikipedia.miner.db.struct.DbSenseForLabel;
import org.wikipedia.miner.util.ProgressTracker;
import org.wikipedia.miner.util.WikipediaConfiguration;
import org.wikipedia.miner.util.text.TextProcessor;

import com.sleepycat.bind.tuple.StringBinding;
import com.sleepycat.je.Database;
import com.sleepycat.je.DatabaseEntry;

/**
 * A {@link WDatabase} for associating Strings with statistics about the articles (senses) each string could refer to.
 */
public class LabelDatabase extends WDatabase<String, DbLabel> {

    //TODO: it would be extremely cool to cache labels in a Patricia trie rather than a HashMap.
    // This would give us:
    // - iteration in lexicographical order
    // - prefix-based search (i.e. autocomplete queries)
    // - fast generation of labels within a given edit distance (i.e. spelling correction)
    // - faster detection of labels within texts, because we wouldn't need to check every single n-gram.

    private TextProcessor textProcessor ;

    /**
     * Creates or connects to a database whose name and type will be {@link WDatabase.DatabaseType#label}.
     * This will index label statistics according to their raw, unprocessed texts.
     *
     * @param env the WEnvironment surrounding this database
     */
    public LabelDatabase(WEnvironment env) {
        super(
            env,
            DatabaseType.label,
            new StringBinding(),
            new RecordBinding<DbLabel>() {
                public DbLabel createRecordInstance() {
                    return new DbLabel() ;
                }
            }
        ) ;

        textProcessor = null ;
    }

    /**
     * Creates or connects to a database whose type will be {@link WDatabase.DatabaseType#label} and whose name will be
     * {@link WDatabase.DatabaseType#label} concatenated with {@link TextProcessor#getName()}.
     * This will index label statistics according to their texts, after processing with the given {@link TextProcessor}.
     *
     * @param env the WEnvironment surrounding this database
     * @param tp a text processor to apply to texts before indexing
     */
    public LabelDatabase(WEnvironment env, TextProcessor tp) {
        super(
            env,
            DatabaseType.label,
            "label" + tp.getName(),
            new StringBinding(),
            new RecordBinding<DbLabel>() {
                public DbLabel createRecordInstance() {
                    return new DbLabel() ;
                }
            }
        ) ;

        textProcessor = tp ;
    }

    /**
     * Returns the text processor used to modify texts before they are used to index documents (may be null).
     *
     * @return the text processor used to modify texts before they are used to index documents (may be null)
     */
    public TextProcessor getTextProcessor() {
        return textProcessor ;
    }

    /**
     * If this database uses a text processor, you must prepare it (by calling {@link #prepare(File, int)}) before use.
     * This returns true if that call has been made.
     *
     * @return true if the database has been prepared for use, otherwise false
     */
    public boolean isPrepared() {
        return getDatabase(true) != null ;
    }
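    /*
     * Illustrative usage sketch (not part of the original class). It assumes an
     * already-opened WEnvironment `env`, a writable temporary directory, and a
     * text processor such as org.wikipedia.miner.util.text.CaseFolder:
     *
     *   LabelDatabase labelDb = new LabelDatabase(env, new CaseFolder()) ;
     *
     *   if (!labelDb.isPrepared())
     *       labelDb.prepare(new File("/tmp/wm-labels"), 3) ;
     *
     *   // pass the raw surface form; the key is processed internally
     *   DbLabel label = labelDb.retrieve("New Zealand") ;
     */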
    /**
     * Retrieves the label statistics associated with the given text key.
     *
     * <p><b>Note:</b> you should NOT apply text processors to the key; that will be done internally within this method.</p>
     *
     * @param key the text whose label statistics should be retrieved
     * @return the label statistics associated with the given text, or null if none exist
     */
    @Override
    public DbLabel retrieve(String key) {
        if (textProcessor == null)
            return super.retrieve(key) ;
        else
            return super.retrieve(textProcessor.processText(key)) ;
    }

    @Override
    public DbLabel filterCacheEntry(WEntry<String, DbLabel> e, WikipediaConfiguration conf) {
        TIntHashSet validIds = conf.getArticlesOfInterest() ;

        DbLabel label = e.getValue() ;

        //discard labels that are only rarely used as links
        if ((float) label.getLinkDocCount() / label.getTextDocCount() < conf.getMinLinkProbability())
            return null ;

        ArrayList<DbSenseForLabel> newSenses = new ArrayList<DbSenseForLabel>() ;

        for (DbSenseForLabel sense : label.getSenses()) {

            if (validIds != null && !validIds.contains(sense.getId()))
                continue ;

            //discard rare senses, unless they come from an article title or redirect
            if (!sense.getFromRedirect() && !sense.getFromTitle()) {
                if ((float) sense.getLinkDocCount() / label.getLinkDocCount() < conf.getMinSenseProbability())
                    continue ;
            }

            newSenses.add(sense) ;
        }

        if (newSenses.isEmpty())
            return null ;

        label.setSenses(newSenses) ;
        return label ;
    }
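    /*
     * Worked example for the filtering above (numbers invented for illustration):
     * if a label occurs in 1000 documents and is used as link text in 40 of them,
     * its link probability is 40/1000 = 0.04, so a minLinkProbability of 0.05
     * discards the whole label. Among surviving labels, a sense that accounts for
     * 1 of the 40 link documents has a sense probability of 1/40 = 0.025 and is
     * discarded under a minSenseProbability of 0.05, unless it comes from an
     * article title or redirect.
     */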
    @Override
    public WEntry<String, DbLabel> deserialiseCsvRecord(CsvRecordInput record) throws IOException {
        String text = record.readString(null) ;

        DbLabel l = new DbLabel() ;
        l.deserialize(record) ;

        return new WEntry<String, DbLabel>(text, l) ;
    }

    /**
     * If this database uses a text processor, then you must prepare it before use. This involves copying all labels and
     * statistics from the original label database (the one with no text processor), re-indexing all entries, and merging
     * statistics whose processed texts collide with each other.
     *
     * This is done via an external sort, to avoid memory overflow.
     *
     * @param tempDir a directory for writing temporary files. Any files created will be deleted, but directories will not.
     * @param passes the number of passes to break the task into (more = slower, but less memory required)
     * @throws IOException if the temporary directory is not writable
     */
    public void prepare(File tempDir, int passes) throws IOException {

        if (textProcessor == null)
            return ;

        WDatabase<String, DbLabel> originalLabels = env.getDbLabel(null) ;
        long labelCount = originalLabels.getDatabaseSize() ;

        tempDir.mkdirs() ;

        ProgressTracker tracker = new ProgressTracker((2 * passes) + 1, LabelDatabase.class) ;

        for (int pass = 0 ; pass < passes ; pass++) {

            tracker.startTask(labelCount, "Gathering and processing labels (pass " + (pass + 1) + " of " + passes + ")") ;

            //gather and merge the subset of labels whose processed text hashes into this pass
            TreeMap<String, DbLabel> tmpProcessedLabels = new TreeMap<String, DbLabel>() ;

            WIterator<String, DbLabel> dbIter = originalLabels.getIterator() ;
            while (dbIter.hasNext()) {
                WEntry<String, DbLabel> e = dbIter.next() ;

                String processedText = textProcessor.processText(e.getKey()) ;

                if (Math.abs(processedText.hashCode()) % passes == pass) {
                    DbLabel storedLabel = tmpProcessedLabels.get(processedText) ;

                    if (storedLabel == null)
                        tmpProcessedLabels.put(processedText, e.getValue()) ;
                    else
                        tmpProcessedLabels.put(processedText, mergeLabels(storedLabel, e.getValue())) ;
                }
                tracker.update() ;
            }
            dbIter.close() ;

            //dump gathered labels into a temporary file; TreeMap iteration keeps each file sorted by key
            tracker.startTask(tmpProcessedLabels.size(), "Dumping processed labels (pass " + (pass + 1) + " of " + passes + ")") ;

            Iterator<Map.Entry<String, DbLabel>> mapIter = tmpProcessedLabels.entrySet().iterator() ;

            File tempFile = new File(tempDir.getPath() + File.separator + "tmpLabels" + pass + ".csv") ;
            tempFile.deleteOnExit() ;

            //write explicitly in UTF-8, so reading back below is not at the mercy of the platform charset
            BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(tempFile), "UTF-8")) ;

            while (mapIter.hasNext()) {
                tracker.update() ;

                Map.Entry<String, DbLabel> e = mapIter.next() ;

                ByteArrayOutputStream outStream = new ByteArrayOutputStream() ;
                CsvRecordOutput cro = new CsvRecordOutput(outStream) ;

                cro.writeString(e.getKey(), null) ;
                e.getValue().serialize(cro) ;

                writer.write(outStream.toString("UTF-8")) ;
            }
            writer.close() ;
        }

        Database db = getDatabase(false) ;

        long bytesToRead = 0 ;
        long bytesRead = 0 ;

        BufferedReader[] readers = new BufferedReader[passes] ;
        String[] currKeys = new String[passes] ;
        DbLabel[] currValues = new DbLabel[passes] ;
        String line ;

        File[] tempFiles = new File[passes] ;

        //open a reader for each temp file, and prime it with its first entry
        for (int pass = 0 ; pass < passes ; pass++) {
            File tempFile = new File(tempDir.getPath() + File.separator + "tmpLabels" + pass + ".csv") ;
            tempFiles[pass] = tempFile ;
            bytesToRead = bytesToRead + tempFile.length() ;

            readers[pass] = new BufferedReader(new InputStreamReader(new FileInputStream(tempFile), "UTF-8")) ;

            if ((line = readers[pass].readLine()) != null) {
                bytesRead = bytesRead + line.length() + 1 ;
                line = line + "\n" ;

                try {
                    CsvRecordInput cri = new CsvRecordInput(new ByteArrayInputStream(line.getBytes("UTF-8"))) ;
                    currKeys[pass] = cri.readString(null) ;
                    currValues[pass] = new DbLabel() ;
                    currValues[pass].deserialize(cri) ;
                } catch (Exception e) {
                    Logger.getLogger(LabelDatabase.class).error("Could not parse '" + line + "'") ;
                    currKeys[pass] = null ;
                    currValues[pass] = null ;
                }
            } else {
                currKeys[pass] = null ;
                currValues[pass] = null ;
            }
        }

        tracker.startTask(bytesToRead, "Storing processed labels") ;

        //k-way merge: repeatedly store the lexicographically lowest key across all readers
        while (true) {
            String lowestKey = null ;
            int pass = -1 ;

            for (int i = 0 ; i < passes ; i++) {
                if (currKeys[i] != null && (lowestKey == null || currKeys[i].compareTo(lowestKey) < 0)) {
                    lowestKey = currKeys[i] ;
                    pass = i ;
                }
            }

            if (pass < 0) {
                //all readers are finished
                break ;
            }

            //save lowestKey and its associated value
            DatabaseEntry k = new DatabaseEntry() ;
            keyBinding.objectToEntry(lowestKey, k) ;

            DatabaseEntry v = new DatabaseEntry() ;
            valueBinding.objectToEntry(currValues[pass], v) ;

            db.put(null, k, v) ;

            //advance the reader we just consumed from
            if ((line = readers[pass].readLine()) != null) {
                bytesRead = bytesRead + line.length() + 1 ;
                tracker.update(bytesRead) ;

                line = line + "\n" ;

                try {
                    CsvRecordInput cri = new CsvRecordInput(new ByteArrayInputStream(line.getBytes("UTF-8"))) ;
                    currKeys[pass] = cri.readString(null) ;
                    currValues[pass] = new DbLabel() ;
                    currValues[pass].deserialize(cri) ;
                } catch (Exception e) {
                    Logger.getLogger(LabelDatabase.class).error("Could not parse '" + line + "'") ;
                    currKeys[pass] = null ;
                    currValues[pass] = null ;
                }
            } else {
                currKeys[pass] = null ;
                currValues[pass] = null ;
            }
        }

        for (BufferedReader r : readers)
            r.close() ;

        for (File tempFile : tempFiles)
            tempFile.delete() ;

        env.cleanAndCheckpoint() ;
        getDatabase(true) ;
    }
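    /*
     * Sketch of how the external sort above behaves (illustrative): with
     * passes = 3, each label is gathered in exactly one pass, chosen by
     * Math.abs(processedText.hashCode()) % 3, so only roughly a third of the
     * merged labels are held in memory at any time. Each pass writes its own
     * sorted temp file (tmpLabels0.csv, tmpLabels1.csv, tmpLabels2.csv), and
     * the final k-way merge stores entries in the database in key order.
     * One caveat: Math.abs(Integer.MIN_VALUE) is still negative, so a processed
     * text hashing to exactly Integer.MIN_VALUE may be silently skipped when
     * passes > 1; in practice this is vanishingly rare.
     */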
    /**
     * Merges two sets of statistics that have collided onto the same processed text:
     * label-level counts are summed, and senses sharing the same id have their counts
     * combined. The merged senses are sorted by descending prominence.
     */
    private DbLabel mergeLabels(DbLabel lblA, DbLabel lblB) {

        THashMap<Integer, DbSenseForLabel> senseHash = new THashMap<Integer, DbSenseForLabel>() ;

        if (lblA.getSenses() != null) {
            for (DbSenseForLabel s : lblA.getSenses())
                senseHash.put(s.getId(), s) ;
        }

        if (lblB.getSenses() != null) {
            for (DbSenseForLabel s1 : lblB.getSenses()) {
                DbSenseForLabel s2 = senseHash.get(s1.getId()) ;

                if (s2 == null) {
                    senseHash.put(s1.getId(), s1) ;
                } else {
                    //this sense appears under both labels, so combine its counts
                    DbSenseForLabel s3 = new DbSenseForLabel() ;
                    s3.setId(s1.getId()) ;
                    s3.setLinkDocCount(s1.getLinkDocCount() + s2.getLinkDocCount()) ;
                    s3.setLinkOccCount(s1.getLinkOccCount() + s2.getLinkOccCount()) ;
                    s3.setFromRedirect(s1.getFromRedirect() || s2.getFromRedirect()) ;
                    s3.setFromTitle(s1.getFromTitle() || s2.getFromTitle()) ;

                    senseHash.put(s3.getId(), s3) ;
                }
            }
        }

        ArrayList<DbSenseForLabel> mergedSenses = new ArrayList<DbSenseForLabel>(senseHash.values()) ;

        //sort by descending link occurrence count, then descending link document count, then ascending id
        Collections.sort(mergedSenses, new Comparator<DbSenseForLabel>() {
            public int compare(DbSenseForLabel a, DbSenseForLabel b) {
                int cmp = Long.compare(b.getLinkOccCount(), a.getLinkOccCount()) ;
                if (cmp != 0)
                    return cmp ;

                cmp = Long.compare(b.getLinkDocCount(), a.getLinkDocCount()) ;
                if (cmp != 0)
                    return cmp ;

                return Integer.compare(a.getId(), b.getId()) ;
            }
        }) ;

        DbLabel mergedLabel = new DbLabel() ;
        mergedLabel.setLinkDocCount(lblA.getLinkDocCount() + lblB.getLinkDocCount()) ;
        mergedLabel.setLinkOccCount(lblA.getLinkOccCount() + lblB.getLinkOccCount()) ;
        mergedLabel.setTextDocCount(lblA.getTextDocCount() + lblB.getTextDocCount()) ;
        mergedLabel.setTextOccCount(lblA.getTextOccCount() + lblB.getTextOccCount()) ;
        mergedLabel.setSenses(mergedSenses) ;

        return mergedLabel ;
    }
}