package edu.stanford.nlp.loglinear.model;

import edu.stanford.nlp.loglinear.model.proto.ConcatVectorTableProto;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Arrays;
import java.util.Iterator;
import java.util.function.Supplier;

/**
 * Created on 8/9/15.
 * @author keenon
 * <p>
 * This is basically a type-specific wrapper over NDArray.
 */
public class ConcatVectorTable extends NDArray<Supplier<ConcatVector>> {
  /**
   * Constructor takes a list of neighbor variables to use for this factor. This must not change after construction,
   * and the number of states of those variables must also not change.
   *
   * @param dimensions list of neighbor variables' assignment range sizes
   */
  public ConcatVectorTable(int[] dimensions) {
    super(dimensions);
  }

  /**
   * Convenience function to write this factor directly to a stream, encoded as proto. Reversible with readFromStream.
   *
   * @param stream the stream to write to. Does not flush automatically.
   * @throws IOException passed through from the stream
   */
  public void writeToStream(OutputStream stream) throws IOException {
    getProtoBuilder().build().writeTo(stream);
  }

  /**
   * Convenience function to read a factor (assumed serialized with proto) directly from a stream.
   *
   * @param stream the stream to be read from
   * @return a new in-memory feature factor
   * @throws IOException passed through from the stream
   */
  public static ConcatVectorTable readFromStream(InputStream stream) throws IOException {
    return readFromProto(ConcatVectorTableProto.ConcatVectorTable.parseFrom(stream));
  }

  /**
   * Returns the proto builder object for this feature factor. Recursively constructs protos for all the concat
   * vectors in factorTable.
   *
   * @return proto Builder object
   */
  public ConcatVectorTableProto.ConcatVectorTable.Builder getProtoBuilder() {
    ConcatVectorTableProto.ConcatVectorTable.Builder b = ConcatVectorTableProto.ConcatVectorTable.newBuilder();
    for (int n : getDimensions()) {
      b.addDimensionSize(n);
    }
    for (int[] assignment : this) {
      b.addFactorTable(getAssignmentValue(assignment).get().getProtoBuilder());
    }
    return b;
  }

  /**
   * Creates a new in-memory feature factor from a proto serialization.
   *
   * @param proto the proto object to be turned into an in-memory feature factor
   * @return an in-memory feature factor, complete with in-memory concat vectors
   */
  public static ConcatVectorTable readFromProto(ConcatVectorTableProto.ConcatVectorTable proto) {
    int[] neighborSizes = new int[proto.getDimensionSizeCount()];
    for (int i = 0; i < neighborSizes.length; i++) {
      neighborSizes[i] = proto.getDimensionSize(i);
    }
    ConcatVectorTable factor = new ConcatVectorTable(neighborSizes);
    int i = 0;
    for (int[] assignment : factor) {
      final ConcatVector vector = ConcatVector.readFromProto(proto.getFactorTable(i));
      factor.setAssignmentValue(assignment, () -> vector);
      i++;
    }
    return factor;
  }
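
  /*
   * Illustrative round-trip sketch (not part of this class): populate a table, serialize it with
   * writeToStream, and recover it with readFromStream. The ConcatVector construction below
   * (new ConcatVector(1) plus setDenseComponent) is an assumption about that class's API; adapt it
   * to however your ConcatVectors are actually built. Also assumes java.io.ByteArrayOutputStream
   * and java.io.ByteArrayInputStream are imported, and that IOException is handled by the caller.
   *
   *   ConcatVectorTable table = new ConcatVectorTable(new int[]{2, 3});
   *   for (int[] assignment : table) {
   *     ConcatVector v = new ConcatVector(1);                                     // assumed constructor
   *     v.setDenseComponent(0, new double[]{assignment[0], assignment[1]});       // assumed setter
   *     table.setAssignmentValue(assignment, () -> v);
   *   }
   *
   *   ByteArrayOutputStream out = new ByteArrayOutputStream();
   *   table.writeToStream(out);
   *   ConcatVectorTable recovered = ConcatVectorTable.readFromStream(new ByteArrayInputStream(out.toByteArray()));
   *   assert table.valueEquals(recovered, 1e-9);
   */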

  /**
   * Deep comparison for equality of value, within a tolerance, for every ConcatVector in the table, plus dimensional
   * arrangement. This is mostly useful for testing.
   *
   * @param other the vector table to compare against
   * @param tolerance the tolerance to use in value comparisons
   * @return whether the two tables are equivalent by value
   */
  public boolean valueEquals(ConcatVectorTable other, double tolerance) {
    if (!Arrays.equals(other.getDimensions(), getDimensions())) return false;
    for (int[] assignment : this) {
      if (!getAssignmentValue(assignment).get().valueEquals(other.getAssignmentValue(assignment).get(), tolerance)) {
        return false;
      }
    }
    return true;
  }

  NDArray<Supplier<ConcatVector>> originalThunks = null;

  /**
   * This is an optimization that will fault all the ConcatVectors into memory, so that future .get() calls on the
   * Supplier objects return very quickly by reference. It works by wrapping the output of the old thunks inside new,
   * thinner closures that carry the answer around in memory. This is a no-op if vectors were already cached.
   */
  public void cacheVectors() {
    if (originalThunks != null) return;

    originalThunks = new NDArray<>(getDimensions());

    // OPTIMIZATION:
    // Rather than use the standard iterator, which creates lots of int[] arrays on the heap, which need to be GC'd,
    // we use the fast version that just mutates one array. Since this is read once for us here, this is ideal.
    Iterator<int[]> fastPassByReferenceIterator = fastPassByReferenceIterator();
    int[] assignment = fastPassByReferenceIterator.next();
    while (true) {
      Supplier<ConcatVector> originalThunk = getAssignmentValue(assignment);
      originalThunks.setAssignmentValue(assignment, originalThunk);

      // Construct a new, thinner closure around the cached value
      ConcatVector result = originalThunk.get();
      setAssignmentValue(assignment, () -> result);

      // Set the assignment arrays correctly
      if (fastPassByReferenceIterator.hasNext()) fastPassByReferenceIterator.next();
      else break;
    }
  }

  /**
   * This will release references to the cached ConcatVectors created by cacheVectors(), so that they can be cleaned
   * up by the GC. If no cache was constructed, this is a no-op.
   */
  public void releaseCache() {
    if (originalThunks != null) {
      // OPTIMIZATION:
      // Rather than use the standard iterator, which creates lots of int[] arrays on the heap, which need to be GC'd,
      // we use the fast version that just mutates one array. Since this is read once for us here, this is ideal.
      Iterator<int[]> fastPassByReferenceIterator = fastPassByReferenceIterator();
      int[] assignment = fastPassByReferenceIterator.next();
      while (true) {
        setAssignmentValue(assignment, originalThunks.getAssignmentValue(assignment));

        // Set the assignment arrays correctly
        if (fastPassByReferenceIterator.hasNext()) fastPassByReferenceIterator.next();
        else break;
      }

      // Release our replicated set of original thunks
      originalThunks = null;
    }
  }
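
  /*
   * Illustrative caching sketch (not part of this class): if the table's Suppliers are expensive
   * thunks (e.g. they recompute or deserialize features on every get()), cacheVectors() can be
   * called once before a read-heavy phase and releaseCache() afterwards so the GC can reclaim the
   * materialized vectors. The surrounding inference loop and numInferencePasses are hypothetical.
   *
   *   table.cacheVectors();
   *   for (int pass = 0; pass < numInferencePasses; pass++) {
   *     for (int[] assignment : table) {
   *       ConcatVector features = table.getAssignmentValue(assignment).get();  // fast: cached reference
   *       // ... use features ...
   *     }
   *   }
   *   table.releaseCache();
   */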

  /**
   * Clones the table, but keeps the values by reference.
   *
   * @return a new NDArray, a perfect replica of this one
   */
  public ConcatVectorTable cloneTable() {
    ConcatVectorTable copy = new ConcatVectorTable(getDimensions().clone());

    // OPTIMIZATION:
    // Rather than use the standard iterator, which creates lots of int[] arrays on the heap, which need to be GC'd,
    // we use the fast version that just mutates one array. Since this is read once for us here, this is ideal.
    Iterator<int[]> fastPassByReferenceIterator = fastPassByReferenceIterator();
    int[] assignment = fastPassByReferenceIterator.next();
    while (true) {
      copy.setAssignmentValue(assignment, getAssignmentValue(assignment));

      // Set the assignment arrays correctly
      if (fastPassByReferenceIterator.hasNext()) fastPassByReferenceIterator.next();
      else break;
    }
    return copy;
  }
}