package org.wikibrain.matrix; import gnu.trove.list.array.TIntArrayList; import gnu.trove.map.hash.TIntLongHashMap; import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import java.io.*; import java.nio.ByteBuffer; import java.util.Arrays; import java.util.Iterator; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class DenseMatrixWriter { public static final byte ROW_PADDING = Byte.MIN_VALUE; private static final Logger LOG = LoggerFactory.getLogger(DenseMatrixWriter.class); private File path; private TIntLongHashMap rowOffsets = new TIntLongHashMap(); private TIntArrayList rowIndexes = new TIntArrayList(); private File bodyPath; private BufferedOutputStream body; private long bodyOffset = 0; private ValueConf vconf; private int colIds[]; public DenseMatrixWriter(File path, ValueConf conf) throws IOException { this.path = path; this.vconf = conf; info("writing matrix to " + path); // write tmp matrix file this.bodyPath = File.createTempFile("matrix", null); this.bodyPath.deleteOnExit(); this.body = new BufferedOutputStream(new FileOutputStream(bodyPath)); info("writing body to tmp file at " + bodyPath); } public ValueConf getValueConf() { return vconf; } public synchronized void writeRow(DenseMatrixRow row) throws IOException { if (!row.getValueConf().almostEquals(vconf)) { throw new IllegalArgumentException("Value conf for row does not match the writer's value conf"); } if (colIds == null) { colIds = row.getColIds(); } if (!Arrays.equals(colIds, row.getColIds())) { throw new IllegalArgumentException("Column id mismatch for row " + row.getRowIndex()); } row.getBuffer().rewind(); byte[] bytes = new byte[row.getBuffer().remaining()]; row.getBuffer().get(bytes, 0, bytes.length); rowOffsets.put(row.getRowIndex(), bodyOffset); rowIndexes.add(row.getRowIndex()); body.write(bytes); bodyOffset += bytes.length; // pad rows to 8 byte offsets to speed things up. while (bodyOffset % 8 != 0) { bodyOffset++; body.write(ROW_PADDING); } } public void finish() throws IOException { body.close(); info("wrote " + bodyOffset + " bytes in body of matrix"); // write offset file info("generating header"); int sizeHeader = 16 + rowOffsets.size() * 12 + 4 + colIds.length * 4; body = new BufferedOutputStream(new FileOutputStream(path)); body.write(intToBytes(DenseMatrix.FILE_HEADER)); body.write(floatToBytes(vconf.minScore)); body.write(floatToBytes(vconf.maxScore)); body.write(intToBytes(rowOffsets.size())); body.write(intToBytes(colIds.length)); // Next write row indexes in sorted order (4 bytes per row) int sortedIndexes[] = rowIndexes.toArray(); Arrays.sort(sortedIndexes); for (int rowIndex : sortedIndexes) { body.write(intToBytes(rowIndex)); } // Next write offsets for sorted indexes. (8 bytes per row) for (int rowIndex : sortedIndexes) { long rowOffset = rowOffsets.get(rowIndex); body.write(longToBytes(rowOffset + sizeHeader)); } // Finally, write column ids for (int c : colIds) { body.write(intToBytes(c)); } InputStream r = new FileInputStream(bodyPath); // append other file IOUtils.copyLarge(r, body); r.close(); body.flush(); body.close(); info("wrote " + FileUtils.sizeOf(path) + " bytes to " + path); } private void info(String message) { LOG.info("dense matrix writer " + path + ": " + message); } public static void write(File file, Iterator<DenseMatrixRow> rows) throws IOException { write(file, rows, new ValueConf()); } public static void write(File file, Iterator<DenseMatrixRow> rows, ValueConf vconf) throws IOException { DenseMatrixWriter w = new DenseMatrixWriter(file, vconf); while (rows.hasNext()) { w.writeRow(rows.next()); } w.finish(); } private static byte[] intToBytes(int i) { return ByteBuffer.allocate(4).putInt(i).array(); } private static byte[] longToBytes(long i) { return ByteBuffer.allocate(8).putLong(i).array(); } private static byte[] floatToBytes(float f) { return ByteBuffer.allocate(4).putFloat(f).array(); } }