package org.wikibrain.matrix; import org.apache.commons.lang3.ArrayUtils; import java.nio.ByteBuffer; import java.nio.IntBuffer; import java.nio.ShortBuffer; import java.util.LinkedHashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * A single dense matrix row backed by a byte buffer. The row contains: * - a row id (int), * - a set of n columns, each with a value (float packed into two bytes) * * Since the matrix is dense, the row assumes that a single copy of column ids is * stored somewhere in the container matrix. * * The row can either be created from the component data, or from a byte buffer. * This means that the object can wrap data from an mmap'd file in the correct format. */ public final class DenseMatrixRow extends BaseMatrixRow implements MatrixRow { private static final Logger LOG = LoggerFactory.getLogger(DenseMatrixRow.class); public static final Float MIN_SCORE = -1.1f; public static final Float MAX_SCORE = 1.1f; public static final Float SCORE_RANGE = (MAX_SCORE - MIN_SCORE); public static final int PACKED_RANGE = (Short.MAX_VALUE - Short.MIN_VALUE); public static final int HEADER = 0xfefefefa; private final float c1; private final float c2; /** * The main "source" buffer. */ private ByteBuffer buffer; /** * A view buffer that points to the header. */ private IntBuffer headerBuffer; /** * A view buffer that points to the ids. */ private int[] colIds; /** * A view buffer that points to the values. */ private ShortBuffer valBuffer; private ValueConf vconf; public DenseMatrixRow(ValueConf vconf, int rowIndex, LinkedHashMap<Integer, Float> row) { this(vconf, rowIndex, ArrayUtils.toPrimitive(row.keySet().toArray(new Integer[] {})), ArrayUtils.toPrimitive(row.values().toArray(new Float[]{})) ); } public DenseMatrixRow(ValueConf vconf, int rowIndex, int colIds[], float colVals[]) { if (!SparseMatrixUtils.isIncreasing(colIds)) { throw new IllegalArgumentException("Columns must be sorted by id"); } this.vconf = vconf; this.c1 = vconf.c1; this.c2 = vconf.c2; this.colIds = colIds; short packed[] = new short[colVals.length]; for (int i = 0; i < colVals.length; i++) { packed[i] = vconf.pack(colVals[i]); } createBuffer(rowIndex, colIds, packed); } public void createBuffer(int rowIndex, int colIds[], short colVals[]) { assert(colIds.length == colVals.length); this.colIds = colIds; buffer = ByteBuffer.allocate( 4 + // header 4 + // row index 2 * colVals.length // col values ); createViewBuffers(colVals.length); headerBuffer.put(0, HEADER); headerBuffer.put(1, rowIndex); valBuffer.put(colVals, 0, colVals.length); } private void createViewBuffers(int numColumns) { buffer.position(0); headerBuffer = buffer.asIntBuffer(); buffer.position(2 * 4); valBuffer = buffer.asShortBuffer(); } /** * Wrap an existing byte buffer that contains a row. * @param colIds * @param buffer */ public DenseMatrixRow(ValueConf vconf, int colIds[], ByteBuffer buffer) { if (!SparseMatrixUtils.isIncreasing(colIds)) { throw new IllegalArgumentException("Columns must be sorted by id"); } this.vconf = vconf; this.colIds = colIds; this.buffer = buffer; this.c1 = vconf.c1; this.c2 = vconf.c2; if (this.buffer.getInt(0) != HEADER) { throw new IllegalArgumentException("Invalid header in byte buffer"); } createViewBuffers(buffer.getInt(8)); } public final double dot(float [] vector) { if (vector.length != colIds.length) throw new IllegalArgumentException(); double sum = 0.0; for (int i = 0; i < vector.length; i++) { sum += vector[i] * (c1 * valBuffer.get(i) + c2); } return sum; } public final double dot(DenseMatrixRow X) { double sum = 0.0; for (int i = 0; i < X.colIds.length; i++) { sum += (c1 * X.valBuffer.get(i) + c2) * (c1 * valBuffer.get(i) + c2); } return sum; } @Override public final int getColIndex(int i) { return colIds[i]; } @Override public final float getColValue(int i) { return vconf.unpack(valBuffer.get(i)); } public final short getPackedColValue(int i) { return valBuffer.get(i); } @Override public final int getRowIndex() { return headerBuffer.get(1); } @Override public final int getNumCols() { return colIds.length; } public ByteBuffer getBuffer() { return buffer; } public ValueConf getValueConf() { return vconf; } protected int[] getColIds() { return colIds; } public float[] getValues() { float vals[] = new float[colIds.length]; for (int i = 0; i < vals.length; i++) { vals[i] = getColValue(i); } return vals; } }