package org.wikibrain.core.dao.matrix;
import com.google.code.externalsorting.ExternalSort;
import com.typesafe.config.Config;
import gnu.trove.function.TDoubleFunction;
import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.TIntDoubleMap;
import gnu.trove.map.hash.TIntDoubleHashMap;
import gnu.trove.procedure.TIntProcedure;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang.StringEscapeUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.core.dao.DaoFilter;
import org.wikibrain.core.dao.LocalLinkDao;
import org.wikibrain.core.lang.Language;
import org.wikibrain.core.lang.LanguageSet;
import org.wikibrain.core.lang.LocalId;
import org.wikibrain.core.model.LocalLink;
import org.wikibrain.matrix.*;
import org.wikibrain.utils.*;
import java.io.*;
import java.nio.charset.Charset;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* <p>This class wraps a local link dao delegate and builds a fast, sparse, matrix
* and its transpose to speed up graph lookups.</p>
*
* <p>
* Four API calls are partially supported:
* 1. The three-argument version of getLinks()
* 2. get() if a) a language and b) either a src or dest is specified.
* 3. count() for the same requirements as 2.
* 4. PageRank values (beware that PageRank estimates are lazily calculated
* the first time a pagerank value is requested.)
* </p>
*
* <p>
* All other calls are delegated to the passed-in delegate.
* Note that this dao also loads the links into the delegate.
* </p>
*
* @author Shilad Sen
*/
public class MatrixLocalLinkDao implements LocalLinkDao {
/** Logger shared by all instances of this dao. */
private static final Logger LOG = LoggerFactory.getLogger(MatrixLocalLinkDao.class);
/** Directory that holds the matrix, transpose, and PageRank files. */
private final File dir;
/** Wrapped dao; rebuild() temporarily sets this to null while it replays the delegate's links. */
private LocalLinkDao delegate;
/** Adjacency matrix (row = packed source id, columns = packed destination ids); null until loaded or built. */
private SparseMatrix matrix = null;
/** Transpose of {@link #matrix}; null until loaded or built. */
private SparseMatrix transpose = null;
/** Per-language PageRank estimates keyed by page id; null until read from disk or computed. */
private Map<Language, TIntDoubleMap> pageRanks = null;
/** Temporary files that receive "src dest" lines while links are being loaded. */
private Set<File> allWriterFiles = Collections.newSetFromMap(new ConcurrentHashMap<File, Boolean>());
/** Every writer created by getSortingWriter(), tracked so they can all be closed before sorting. */
private Set<BufferedWriter> allWriters = Collections.newSetFromMap(
new ConcurrentHashMap<BufferedWriter, Boolean>());
/** The calling thread's writer, created on first use by getSortingWriter(). */
private ThreadLocal<BufferedWriter> writers = new ThreadLocal<BufferedWriter>();
/**
 * Creates a matrix-backed link dao that wraps {@code delegate}.
 *
 * <p>The directory is created if necessary and any existing matrix files are loaded.
 * If no matrix could be loaded but the delegate already contains links, the matrix
 * is rebuilt from the delegate.</p>
 *
 * @param delegate the dao that handles every call this class cannot answer
 * @param dir directory holding the matrix, transpose, and PageRank files
 * @throws DaoException if the existing files cannot be read or the rebuild fails
 */
public MatrixLocalLinkDao(LocalLinkDao delegate, File dir) throws DaoException {
    this.delegate = delegate;
    this.dir = dir;
    dir.mkdirs();
    try {
        load();
    } catch (IOException e) {
        throw new DaoException(e);
    }
    if (matrix != null || transpose != null) {
        return;
    }

    // Nothing on disk: probe the delegate for at least one link.
    boolean delegateHasData = false;
    try {
        delegateHasData = delegate.get(new DaoFilter().setLimit(1)).iterator().hasNext();
    } catch (Exception e) {
        // Best effort only, but keep the cause so a misconfigured delegate can be diagnosed.
        LOG.warn("Error occurred while trying to fetch links from " + delegate +
                ". Assuming it is empty and continuing.", e);
    }
    if (delegateHasData) {
        LOG.warn("MatrixLocalLinkDao empty, but delegate is not. Attempting to rebuild...");
        rebuild();
    }
}
/**
 * Rebuilds the matrix and its transpose from the links stored in the delegate.
 *
 * <p>The delegate is detached while the links are replayed so that
 * {@link #beginLoad()}, {@link #save(LocalLink)} and {@link #endLoad()} only
 * touch the matrix files and do not write the links back into the delegate.
 * It is always re-attached, even if the rebuild fails.</p>
 *
 * @throws DaoException if reading from the delegate or writing the matrix fails
 */
public void rebuild() throws DaoException {
    final LocalLinkDao source = delegate;
    this.delegate = null;
    try {
        // Drop the current matrices so beginLoad() does not re-seed from them.
        IOUtils.closeQuietly(matrix);
        IOUtils.closeQuietly(transpose);
        matrix = null;
        transpose = null;
        FileUtils.deleteQuietly(getMatrixFile());
        FileUtils.deleteQuietly(getTransposeFile());

        beginLoad();
        ParallelForEach.iterate(source.get(new DaoFilter()).iterator(), new Procedure<LocalLink>() {
            @Override
            public void call(LocalLink ll) throws Exception {
                save(ll);
            }
        });
        endLoad();
    } finally {
        // Without this a failed rebuild would leave the dao with no delegate at all.
        this.delegate = source;
    }
}
/**
 * Opens the matrix and transpose if both files exist, and reads the cached
 * PageRanks if their file is newer than the matrix file.
 *
 * @throws IOException if one of the files exists but cannot be read
 */
private void load() throws IOException {
    final File matrixFile = getMatrixFile();
    final File transposeFile = getTransposeFile();

    if (matrixFile.isFile() && transposeFile.isFile()) {
        matrix = new SparseMatrix(matrixFile);
        transpose = new SparseMatrix(transposeFile);
    } else {
        // Report the first missing file: the matrix takes precedence over the transpose.
        final File missing = matrixFile.isFile() ? transposeFile : matrixFile;
        LOG.warn("Matrix" + missing + " missing, disabling fast lookups.");
    }

    final File ranksFile = getPageRanksFile();
    final boolean ranksAreFresh = ranksFile.isFile()
            && ranksFile.lastModified() > matrixFile.lastModified();
    if (ranksAreFresh) {
        pageRanks = (Map<Language, TIntDoubleMap>) WpIOUtils.readObjectFromFile(ranksFile);
    }
}
/**
 * Looks up a single link; always answered by the delegate.
 */
@Override
public LocalLink getLink(Language language, int sourceId, int destId) throws DaoException {
    final LocalLink fromDelegate = delegate.getLink(language, sourceId, destId);
    return fromDelegate;
}
/**
 * Starts a load: resets the temporary sort files and, if a matrix is already
 * open, seeds the sort files with its existing links so they survive the reload.
 *
 * @throws DaoException if the delegate fails to begin its own load
 */
@Override
public void beginLoad() throws DaoException {
    // Discard everything left over from a previous load. Otherwise the old temp
    // files would be merged a second time by endLoad(), and threads would keep
    // reusing writers that sortFiles() has already closed.
    for (BufferedWriter stale : allWriters) {
        IOUtils.closeQuietly(stale);
    }
    allWriters.clear();
    for (File stale : allWriterFiles) {
        FileUtils.deleteQuietly(stale);
    }
    allWriterFiles.clear();
    writers = new ThreadLocal<BufferedWriter>();

    if (delegate != null) {
        delegate.beginLoad();
    }

    // Initialize object database with existing links
    if (matrix != null) {
        ParallelForEach.iterate(matrix.iterator(), new Procedure<SparseMatrixRow>() {
            @Override
            public void call(SparseMatrixRow row) throws Exception {
                BufferedWriter writer = getSortingWriter();
                for (int i = 0; i < row.getNumCols(); i++) {
                    writer.write(row.getRowIndex() + " " + row.getColIndex(i) + "\n");
                }
            }
        });
    }
}
/**
 * Returns the PageRank estimate for a page.
 * Currently only implemented by the MatrixLocalLinkDao.
 * Estimates are computed lazily: the first call that finds no cached values
 * runs the full computation and writes the result to the PageRank file, so it
 * is very expensive; later calls read the cached values.
 *
 * @param language language of the page
 * @param pageId id of the page within that language
 * @return An estimate of the pageRank, or 0.0 if the page is unknown. The sum of
 * PageRank values for all pages will approximately sum to 1.0.
 */
@Override
public double getPageRank(Language language, int pageId) {
    ensurePageRanks();
    final TIntDoubleMap forLanguage = pageRanks.get(language);
    if (forLanguage == null || !forLanguage.containsKey(pageId)) {
        return 0.0;
    }
    return forLanguage.get(pageId);
}

/**
 * Computes and persists the PageRanks if they have not been loaded yet.
 */
private void ensurePageRanks() {
    // NOTE(review): pageRanks is not declared volatile, so this unsynchronized
    // first check may not be safe under the Java memory model -- confirm.
    if (pageRanks != null) {
        return;
    }
    synchronized (this) {
        if (pageRanks != null) {
            return;
        }
        pageRanks = computePageRanks();
        try {
            WpIOUtils.writeObjectToFile(getPageRanksFile(), pageRanks);
        } catch (IOException e) {
            throw new IllegalStateException("Unexpected exception:", e);
        }
    }
}
/**
 * Returns the PageRank estimate for the page identified by {@code localId}.
 * Convenience overload that unpacks the id and forwards to
 * {@link #getPageRank(Language, int)}, so the same lazy computation applies.
 *
 * @param localId language and page id of the page
 * @return An estimate of the pageRank. The sum of PageRank values for all pages will
 * approximately sum to 1.0.
 */
@Override
public double getPageRank(LocalId localId) {
    final Language language = localId.getLanguage();
    final int pageId = localId.getId();
    return getPageRank(language, pageId);
}
/**
 * Returns the calling thread's writer for unsorted "src dest" lines, creating
 * it (and registering it and its temp file for later sorting) on first use.
 *
 * @return the writer owned by the current thread
 * @throws IOException if the temporary file or writer cannot be created
 */
private BufferedWriter getSortingWriter() throws IOException {
    BufferedWriter mine = writers.get();
    if (mine != null) {
        return mine;
    }
    final File scratch = File.createTempFile("links-sorter", ".txt");
    scratch.deleteOnExit();
    scratch.delete();
    mine = WpIOUtils.openWriter(scratch);
    writers.set(mine);
    allWriters.add(mine);
    allWriterFiles.add(scratch);
    return mine;
}
/** Damping factor applied to the link-propagated weight in each PageRank iteration. */
private static final double DAMPING_FACTOR = 0.85;
/** Working state of the PageRank computation for a single language. */
private static class LangRanks {
/** Current rank estimate for each page id. */
TIntDoubleMap pageSums = new TIntDoubleHashMap();
/** Weight accumulated from inbound links during the iteration in progress. */
TIntDoubleMap nextSums = new TIntDoubleHashMap();
}
/**
 * Estimates PageRank for every language found in the adjacency matrix.
 *
 * <p>Each page that has a row in the matrix starts with weight 1/n, where n is
 * the number of such pages in its language. Twenty iterations are then run; in
 * each one a page splits its weight evenly over its out-links, and only links
 * whose destination is in the same language are counted.</p>
 *
 * @return a map from language to (page id, PageRank estimate)
 * @throws IllegalStateException if the adjacency matrix has not been loaded
 */
private Map<Language, TIntDoubleMap> computePageRanks() {
    if (matrix == null) {
        // load() permits running without a matrix; fail with a clear message instead of an NPE.
        throw new IllegalStateException(
                "Cannot compute PageRanks because " + getMatrixFile() + " has not been loaded.");
    }
    Map<Language, LangRanks> ranks = new HashMap<Language, LangRanks>();

    // Set initial weights
    for (SparseMatrixRow row : matrix) {
        LocalId src = LocalId.fromInt(row.getRowIndex());
        LangRanks lr = ranks.get(src.getLanguage());
        if (lr == null) {
            lr = new LangRanks();
            ranks.put(src.getLanguage(), lr);
        }
        lr.pageSums.put(src.getId(), 1.0);
    }

    // normalize (divide by num pages)
    for (final LangRanks lr : ranks.values()) {
        final int n = lr.pageSums.size();
        lr.pageSums.transformValues(new TDoubleFunction() {
            @Override
            public double execute(double v) {
                return 1.0 / n;
            }
        });
    }

    // perform iterations
    for (int i = 0; i < 20; i++) {
        // Push each page's current weight along its out-links.
        for (SparseMatrixRow row : matrix) {
            int ncols = row.getNumCols();
            if (ncols == 0) continue;
            LocalId src = LocalId.fromInt(row.getRowIndex());
            LangRanks lr = ranks.get(src.getLanguage());
            double w = lr.pageSums.get(src.getId()) / ncols;
            for (int j = 0; j < ncols; j++) {
                LocalId dest = LocalId.fromInt(row.getColIndex(j));
                if (dest.getLanguage() == src.getLanguage()) {
                    lr.nextSums.adjustOrPutValue(dest.getId(), w, w);
                }
            }
        }

        // update values, measure change
        // (single-element array so the anonymous class below can accumulate into it)
        final double[] delta = {0.0};
        for (final LangRanks lr : ranks.values()) {
            final int n = lr.nextSums.size();
            lr.nextSums.forEachKey(new TIntProcedure() {
                @Override
                public boolean execute(int id) {
                    double ps = lr.pageSums.get(id);
                    double ns = (1.0 - DAMPING_FACTOR) / n + DAMPING_FACTOR * lr.nextSums.get(id);
                    delta[0] += Math.abs(ps - ns);
                    lr.nextSums.put(id, 0);   // reset the accumulator for the next iteration
                    lr.pageSums.put(id, ns);
                    return true;
                }
            });
        }
        // Log the accumulated value itself, not the array that holds it.
        LOG.info("change in pageranks at iteration {} is {}.", i, delta[0]);
    }

    Map<Language, TIntDoubleMap> result = new HashMap<Language, TIntDoubleMap>();
    for (Map.Entry<Language, LangRanks> entry : ranks.entrySet()) {
        result.put(entry.getKey(), entry.getValue().pageSums);
    }
    return result;
}
/**
 * Saves a link to the delegate (when attached) and appends it to the calling
 * thread's sort file. Links with a negative source or destination id, or whose
 * ids cannot be packed into an int, are passed to the delegate only.
 *
 * @param item the link to store
 * @throws DaoException if the delegate fails or the sort file cannot be written
 */
@Override
public void save(LocalLink item) throws DaoException {
    if (delegate != null) {
        delegate.save(item);
    }

    // skip red links
    final boolean resolved = item.getDestId() >= 0 && item.getSourceId() >= 0;
    if (!resolved) {
        return;
    }

    final LocalId from = new LocalId(item.getLanguage(), item.getSourceId());
    final LocalId to = new LocalId(item.getLanguage(), item.getDestId());
    if (!(from.canPackInInt() && to.canPackInInt())) {
        return;
    }

    try {
        final String line = from.toInt() + " " + to.toInt() + "\n";
        getSortingWriter().write(line);
    } catch (IOException e) {
        throw new DaoException(e);
    }
}
/**
 * @return the adjacency matrix file inside this dao's directory
 */
public File getMatrixFile() {
    final String fileName = "links.matrix";
    return new File(dir, fileName);
}
/**
 * @return the file inside this dao's directory that caches the PageRank estimates
 */
public File getPageRanksFile() {
    final String fileName = "pageRanks.bin";
    return new File(dir, fileName);
}
/**
 * @return the transposed adjacency matrix file inside this dao's directory
 */
public File getTransposeFile() {
    final String fileName = "links-transpose.matrix";
    return new File(dir, fileName);
}
/**
 * Removes all links from the delegate and discards the matrix, its transpose,
 * and the cached PageRanks, both in memory and on disk.
 *
 * @throws DaoException if the delegate cannot be cleared
 */
@Override
public void clear() throws DaoException {
    delegate.clear();

    // Release and forget the open matrices. If they stayed attached, the next
    // beginLoad() would copy every "cleared" link straight back into the new matrix.
    IOUtils.closeQuietly(matrix);
    IOUtils.closeQuietly(transpose);
    matrix = null;
    transpose = null;

    // PageRanks describe the link graph that was just removed.
    pageRanks = null;

    FileUtils.deleteQuietly(getMatrixFile());
    FileUtils.deleteQuietly(getTransposeFile());
    FileUtils.deleteQuietly(getPageRanksFile());
}
/** Number of worker threads used when sorting the per-thread link files. */
private static final int MAX_SORT_THREADS = 4;
/**
 * Closes every per-thread writer, sorts each of their files in place, and
 * merges the sorted files into one new temporary file.
 *
 * @return the merged file; lines are in {@link String#compareTo} order
 * @throws IOException if closing, sorting, or merging fails
 */
private File sortFiles() throws IOException {
    Iterator<BufferedWriter> open = allWriters.iterator();
    while (open.hasNext()) {
        open.next().close();
    }

    Procedure<File> sortOne = new Procedure<File>() {
        @Override
        public void call(File file) throws Exception {
            sort(file, MAX_SORT_THREADS);
        }
    };
    ParallelForEach.iterate(allWriterFiles.iterator(), MAX_SORT_THREADS, 10, sortOne, 10);

    File merged = File.createTempFile("local-links-sorted.", ".txt");
    merged.deleteOnExit();
    Comparator<String> lineOrder = new Comparator<String>() {
        public int compare(String left, String right) {
            return left.compareTo(right);
        }
    };
    LOG.info("merging all sorted files to " + merged);
    List<File> inputs = new ArrayList<File>(allWriterFiles);
    ExternalSort.mergeSortedFiles(inputs, merged, lineOrder, Charset.forName("utf-8"));
    return merged;
}
/** Upper bound on intermediate sort files, shared between concurrent sorts. */
private static final int SORT_FILES_MAX = 100;
/** Memory budget for sorting: one fifth of the max heap, divided by the sort thread count. */
private static final long SORT_MEMORY_MAX = (Runtime.getRuntime().maxMemory() / MAX_SORT_THREADS / 5);
/**
 * Sorts the lines of {@code file} in place with an external merge sort.
 *
 * @param file the file to sort; it is overwritten with the sorted lines
 * @param concurrentThreads number of sorts assumed to run at once; the memory
 *                          and temp-file budgets are divided by this value
 * @throws IOException if the file cannot be read, sorted, or rewritten
 */
private void sort(File file, long concurrentThreads) throws IOException {
    final long memoryBudget = SORT_MEMORY_MAX / concurrentThreads;
    final long filesByThreads = SORT_FILES_MAX / concurrentThreads;
    final int filesBySize = (int) (file.length() / (memoryBudget / 2));
    final int maxFiles = (int) Math.max(filesByThreads, filesBySize);
    LOG.info("sorting " + file + " using max of " + maxFiles);

    final Charset utf8 = Charset.forName("utf-8");
    final Comparator<String> lineOrder = new Comparator<String>() {
        public int compare(String left, String right) {
            return left.compareTo(right);
        }
    };

    List<File> batches = ExternalSort.sortInBatch(
            WpIOUtils.openBufferedReader(file),
            file.length(),
            lineOrder,
            maxFiles,
            memoryBudget,
            utf8,
            null,
            true,
            0,
            false);

    LOG.info("merging " + file);
    ExternalSort.mergeSortedFiles(batches, file, lineOrder, Charset.forName("utf-8"));
    LOG.info("finished sorting" + file);
}
/**
 * Finishes a load: sorts all saved links, writes them as the adjacency matrix,
 * writes the transpose, and reopens both. Any cached PageRanks are discarded
 * because they describe the previous matrix.
 *
 * @throws DaoException if the delegate fails or any file cannot be read or written
 */
@Override
public void endLoad() throws DaoException {
    if (delegate != null) delegate.endLoad();
    try {
        // close the old matrix and transpose
        LOG.info("closing existing matrix and transpose.");
        if (matrix != null) IOUtils.closeQuietly(matrix);
        if (transpose != null) IOUtils.closeQuietly(transpose);

        LOG.info("sorting files");
        File file = sortFiles();

        LOG.info("writing adjacency matrix rows");
        ValueConf vconf = new ValueConf();  // unused because there are no values.
        SparseMatrixWriter writer = new SparseMatrixWriter(getMatrixFile(), vconf);
        BufferedReader reader = WpIOUtils.openBufferedReader(file);
        try {
            TIntList packedDest = new TIntArrayList();
            int cellCount = 0;
            int rowCount = 0;
            LocalId lastSrc = null;
            String line;
            while ((line = reader.readLine()) != null) {
                String tokens[] = line.trim().split(" ");
                if (tokens.length != 2) {
                    LOG.info("Invalid line: '" + StringEscapeUtils.escapeJava(line) + "'");
                    continue;
                }
                cellCount++;
                LocalId src = LocalId.fromInt(Integer.parseInt(tokens[0]));
                LocalId dest = LocalId.fromInt(Integer.parseInt(tokens[1]));
                // The input is sorted, so a change of source means the previous row is complete.
                if (lastSrc != null && !src.equals(lastSrc)) {
                    if (++rowCount % 100000 == 0) {
                        LOG.info("writing adjacency matrix row " + rowCount
                                + ", found " + cellCount + " links");
                    }
                    writeAdjacencyRow(writer, vconf, lastSrc, packedDest);
                    packedDest.clear();
                }
                packedDest.add(dest.toInt());
                lastSrc = src;
            }
            // Flush the final row, which no later source change will trigger.
            if (packedDest.size() > 0) {
                writeAdjacencyRow(writer, vconf, lastSrc, packedDest);
            }
        } finally {
            // The merged file is only needed for this pass; release the handle and the disk space.
            IOUtils.closeQuietly(reader);
            FileUtils.deleteQuietly(file);
        }
        LOG.info("finalizing adjacency matrix");
        writer.finish();

        LOG.info("loading adjacency matrix");
        matrix = new SparseMatrix(getMatrixFile());

        LOG.info("writing transpose of adjacency matrix");
        SparseMatrixTransposer transposer = new SparseMatrixTransposer(matrix, getTransposeFile());
        transposer.transpose();

        LOG.info("loading transpose of adjacency matrix");
        transpose = new SparseMatrix(getTransposeFile());

        // PageRanks cached in memory were computed from the previous matrix.
        pageRanks = null;
    } catch (IOException e) {
        throw new DaoException(e);
    }
}

/**
 * Writes one adjacency row: {@code src} linking to every packed id in {@code packedDest}.
 * Cell values are all zero because the matrix only records which links exist.
 */
private static void writeAdjacencyRow(SparseMatrixWriter writer, ValueConf vconf,
                                      LocalId src, TIntList packedDest) throws IOException {
    SparseMatrixRow row = new SparseMatrixRow(
            vconf,
            src.toInt(),
            packedDest.toArray(),
            new short[packedDest.size()]
    );
    writer.writeRow(row);
}
/**
 * Returns links filtered by parseability and location type; always answered by
 * the delegate.
 */
@Override
public Iterable<LocalLink> getLinks(Language language, int localId, boolean outlinks, boolean isParseable, LocalLink.LocationType locationType) throws DaoException {
    final Iterable<LocalLink> fromDelegate =
            delegate.getLinks(language, localId, outlinks, isParseable, locationType);
    return fromDelegate;
}
/**
 * Returns the links leaving (or arriving at) a page, answered from the matrix
 * when possible.
 *
 * <p>The delegate is used instead when the page id cannot be packed into an int
 * or when the required matrix has not been loaded. Links built from the matrix
 * carry only language, source id, and destination id; the anchor text is null
 * and the location type is {@code NONE}.</p>
 *
 * @param language language of the page
 * @param localId id of the page
 * @param outlinks true for links from the page, false for links to it
 * @return the matching links; empty if the page has no row in the matrix
 * @throws DaoException if the matrix cannot be read or the delegate fails
 */
@Override
public Iterable<LocalLink> getLinks(Language language, int localId, boolean outlinks) throws DaoException {
    LocalId id = new LocalId(language, localId);
    // Out-links are rows of the matrix; in-links are rows of its transpose.
    SparseMatrix lookup = outlinks ? matrix : transpose;
    if (!id.canPackInInt() || lookup == null) {
        // load() allows the matrices to be absent ("disabling fast lookups").
        return delegate.getLinks(language, localId, outlinks);
    }
    List<LocalLink> links = new ArrayList<LocalLink>();
    try {
        SparseMatrixRow row = lookup.getRow(id.toInt());
        if (row == null) {
            return links;
        }
        for (int i = 0; i < row.getNumCols(); i++) {
            LocalId lid = LocalId.fromInt(row.getColIndex(i));
            int srcId = outlinks ? localId : lid.getId();
            int destId = outlinks ? lid.getId() : localId;
            LocalLink ll = new LocalLink(
                    lid.getLanguage(),
                    null,
                    srcId,
                    destId,
                    outlinks,
                    0,
                    true,
                    LocalLink.LocationType.NONE
            );
            links.add(ll);
        }
        return links;
    } catch (IOException e) {
        throw new DaoException(e);
    }
}
/**
 * Returns the links matching a filter, answered from the matrix when the filter
 * names languages and exactly one of source ids or destination ids, and does
 * not restrict location type or parseability. Every other filter, and every
 * call made while the matrices are not loaded, goes to the delegate.
 *
 * @param daoFilter the filter to apply
 * @return at most {@code daoFilter.getLimitOrInfinity()} matching links
 * @throws DaoException if the matrix cannot be read or the delegate fails
 */
@Override
public Iterable<LocalLink> get(DaoFilter daoFilter) throws DaoException {
    boolean bySource = daoFilter.getSourceIds() != null;
    boolean byDest = daoFilter.getDestIds() != null;

    // there must be languages
    if (daoFilter.getLangIds() == null) {
        return delegate.get(daoFilter);
    }
    // exactly one of source ids or dest ids must be set
    if (bySource == byDest) {
        return delegate.get(daoFilter);
    }
    // we don't handle location types
    if (daoFilter.getLocTypes() != null || daoFilter.isParseable() != null) {
        return delegate.get(daoFilter);
    }
    // fast lookups are disabled when the matrices are missing
    if (matrix == null || transpose == null) {
        return delegate.get(daoFilter);
    }

    // collect link set
    List<LocalLink> links = new ArrayList<LocalLink>();
    int limit = daoFilter.getLimitOrInfinity();
    Collection<Integer> ids = bySource ? daoFilter.getSourceIds() : daoFilter.getDestIds();
    for (int langId : daoFilter.getLangIds()) {
        Language language = Language.getById(langId);
        for (int id : ids) {
            for (LocalLink ll : getLinks(language, id, bySource)) {
                links.add(ll);
                // Return rather than break: a break would only leave the innermost
                // loop and later ids would push the result past the limit.
                if (links.size() >= limit) {
                    return links;
                }
            }
        }
    }
    return links;
}
/**
 * Counts the links matching a filter, answered from the matrix under the same
 * conditions as {@link #get(DaoFilter)}: languages plus exactly one of source
 * ids or destination ids, and no location-type or parseability restriction.
 * Other filters, ids that cannot be packed into an int, and calls made while
 * the matrices are not loaded go to the delegate.
 *
 * @param daoFilter the filter to apply
 * @return the number of matching links
 * @throws DaoException if the matrix cannot be read or the delegate fails
 */
@Override
public int getCount(DaoFilter daoFilter) throws DaoException {
    boolean bySource = daoFilter.getSourceIds() != null;
    boolean byDest = daoFilter.getDestIds() != null;

    // there must be languages
    if (daoFilter.getLangIds() == null) {
        return delegate.getCount(daoFilter);
    }
    // exactly one of source ids or dest ids must be set
    if (bySource == byDest) {
        return delegate.getCount(daoFilter);
    }
    // we don't handle location types
    if (daoFilter.getLocTypes() != null || daoFilter.isParseable() != null) {
        return delegate.getCount(daoFilter);
    }

    // Out-link counts come from the matrix, in-link counts from its transpose.
    SparseMatrix lookup = bySource ? matrix : transpose;
    if (lookup == null) {
        // fast lookups are disabled when the matrices are missing
        return delegate.getCount(daoFilter);
    }
    List<Integer> packed = getPackedIds(daoFilter);
    if (packed == null) {
        return delegate.getCount(daoFilter);
    }

    // collect link count
    try {
        int count = 0;
        for (int key : packed) {
            SparseMatrixRow row = lookup.getRow(key);
            count += (row == null) ? 0 : row.getNumCols();
        }
        return count;
    } catch (IOException e) {
        throw new DaoException(e);
    }
}
/**
 * Returns the languages reported as loaded by the delegate.
 */
@Override
public LanguageSet getLoadedLanguages() throws DaoException {
    final LanguageSet loaded = delegate.getLoadedLanguages();
    return loaded;
}
/**
 * @return the adjacency matrix, or null if it has not been loaded or built
 */
public SparseMatrix getMatrix() {
    return this.matrix;
}
/**
 * @return the transposed adjacency matrix, or null if it has not been loaded or built
 */
public SparseMatrix getTranspose() {
    return this.transpose;
}
/**
 * Packs every (language, id) combination named by the filter into an int.
 * The ids are the filter's source ids if set, otherwise its destination ids.
 *
 * @param filter a filter with languages and exactly one of source or destination ids
 * @return the packed ids, or null as soon as one combination cannot be packed
 * @throws IllegalArgumentException if both or neither of source and destination ids are set
 */
private List<Integer> getPackedIds(DaoFilter filter) {
    final boolean hasSources = filter.getSourceIds() != null;
    final boolean hasDests = filter.getDestIds() != null;
    if (hasSources && hasDests) {
        throw new IllegalArgumentException();
    }
    if (!hasSources && !hasDests) {
        throw new IllegalArgumentException();
    }
    final Collection<Integer> pageIds = hasSources ? filter.getSourceIds() : filter.getDestIds();

    final List<Integer> result = new ArrayList<Integer>();
    for (int langId : filter.getLangIds()) {
        for (int pageId : pageIds) {
            final LocalId candidate = new LocalId(Language.getById(langId), pageId);
            if (candidate.canPackInInt()) {
                result.add(candidate.toInt());
            } else {
                return null;
            }
        }
    }
    return result;
}
/**
 * Builds a {@link MatrixLocalLinkDao} from configuration entries under
 * "dao.localLink" whose "type" is "matrix".
 */
public static class Provider extends org.wikibrain.conf.Provider<LocalLinkDao> {
    public Provider(Configurator configurator, Configuration config) throws ConfigurationException {
        super(configurator, config);
    }

    @Override
    public Class<LocalLinkDao> getType() {
        return LocalLinkDao.class;
    }

    @Override
    public String getPath() {
        return "dao.localLink";
    }

    /**
     * @return a dao wrapping the configured "delegate" and stored at "path",
     *         or null if the entry's type is not "matrix"
     */
    @Override
    public MatrixLocalLinkDao get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException {
        final String type = config.getString("type");
        if (!type.equals("matrix")) {
            return null;
        }
        final String delegateName = config.getString("delegate");
        final LocalLinkDao wrapped = getConfigurator().get(LocalLinkDao.class, delegateName);
        final File matrixDir = new File(config.getString("path"));
        try {
            return new MatrixLocalLinkDao(wrapped, matrixDir);
        } catch (DaoException e) {
            throw new ConfigurationException(e);
        }
    }
}
}