package edu.stanford.nlp.loglinear.model; import edu.stanford.nlp.loglinear.model.proto.GraphicalModelProto; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.util.*; import java.util.function.Function; /** * Created on 8/7/15. * @author keenon * <p> * A basic graphical model representation: Factors and Variables. This should be a fairly familiar interface to anybody * who's taken a basic PGM course (eg https://www.coursera.org/course/pgm). The key points: * - Stitching together feature factors * - Attaching metadata to everything, so that different sections of the program can communicate in lots of unplanned * ways. For now, the planned meta-data is a lot of routing and status information to do with LENSE. * <p> * This is really just the data structure, and inference lives elsewhere and must use public interfaces to access these * models. We just provide basic utility functions here, and barely do that, because we pass through directly to maps * wherever appropriate. */ public class GraphicalModel { public Map<String, String> modelMetaData = new HashMap<>(); public List<Map<String, String>> variableMetaData = new ArrayList<>(); public Set<Factor> factors = new HashSet<>(); /** * A single factor in this graphical model. ConcatVectorTable can be reused multiple times if the same graph (or different * ones) and this is the glue object that tells a model where the factor lives, and what it is connected to. */ public static class Factor { public ConcatVectorTable featuresTable; public int[] neigborIndices; public Map<String, String> metaData = new HashMap<>(); /** * DO NOT USE. FOR SERIALIZATION ONLY. */ private Factor() { } public Factor(ConcatVectorTable featuresTable, int[] neighborIndices) { this.featuresTable = featuresTable; this.neigborIndices = neighborIndices; } /** * @return the factor meta-data, by reference */ public Map<String, String> getMetaDataByReference() { return metaData; } /** * Does a deep comparison, using equality with tolerance checks against the vector table of values. * * @param other the factor to compare to * @param tolerance the tolerance to accept in differences * @return whether the two factors are within tolerance of one another */ public boolean valueEquals(Factor other, double tolerance) { return Arrays.equals(neigborIndices, other.neigborIndices) && metaData.equals(other.metaData) && featuresTable.valueEquals(other.featuresTable, tolerance); } public GraphicalModelProto.Factor.Builder getProtoBuilder() { GraphicalModelProto.Factor.Builder builder = GraphicalModelProto.Factor.newBuilder(); for (int neighbor : neigborIndices) { builder.addNeighbor(neighbor); } builder.setFeaturesTable(featuresTable.getProtoBuilder()); builder.setMetaData(GraphicalModel.getProtoMetaDataBuilder(metaData)); return builder; } public static Factor readFromProto(GraphicalModelProto.Factor proto) { Factor factor = new Factor(); factor.featuresTable = ConcatVectorTable.readFromProto(proto.getFeaturesTable()); factor.metaData = GraphicalModel.readMetaDataFromProto(proto.getMetaData()); factor.neigborIndices = new int[proto.getNeighborCount()]; for (int i = 0; i < factor.neigborIndices.length; i++) { factor.neigborIndices[i] = proto.getNeighbor(i); } return factor; } /** * Duplicates this factor. * * @return a copy of the factor */ public Factor cloneFactor() { Factor clone = new Factor(); clone.neigborIndices = neigborIndices.clone(); clone.featuresTable = featuresTable.cloneTable(); clone.metaData.putAll(metaData); return clone; } } /** * @return a reference to the model meta-data */ public Map<String, String> getModelMetaDataByReference() { return modelMetaData; } /** * Gets the metadata for a variable. Creates blank metadata if does not exists, then returns that. Pass by reference. * * @param variableIndex the variable number, 0 indexed, to retrieve * @return the metadata map corresponding to that variable number */ public synchronized Map<String, String> getVariableMetaDataByReference(int variableIndex) { while (variableIndex >= variableMetaData.size()) { variableMetaData.add(new HashMap<>()); } return variableMetaData.get(variableIndex); } /** * This is the preferred way to add factors to a graphical model. Specify the neighbors, their dimensions, and a * function that maps from variable assignments to ConcatVector's of features, and this function will handle the * data flow of constructing and populating a factor matching those specifications. * <p> * IMPORTANT: assignmentFeaturizer must be REPEATABLE and NOT HAVE SIDE EFFECTS * This is because it is actually stored as a lazy closure until the full featurized vector is needed, and then it * is created, used, and discarded. It CAN BE CALLED MULTIPLE TIMES, and must always return the same value in order * for behavior of downstream systems to be defined. * * @param neighborIndices the names of the variables, as indices * @param neighborDimensions the sizes of the neighbor variables, corresponding to the order in neighborIndices * @param assignmentFeaturizer a function that maps from an assignment to the variables, represented as an array of * assignments in the same order as presented in neighborIndices, to a ConcatVector of * features for that assignment. * @return a reference to the created factor. This can be safely ignored, as the factor is already saved in the model */ public Factor addFactor(int[] neighborIndices, int[] neighborDimensions, Function<int[], ConcatVector> assignmentFeaturizer) { ConcatVectorTable features = new ConcatVectorTable(neighborDimensions); for (int[] assignment : features) { features.setAssignmentValue(assignment, () -> assignmentFeaturizer.apply(assignment)); } return addFactor(features, neighborIndices); } /** * Creates an instantiated factor in this graph, with neighborIndices representing the neighbor variables by integer * index. * * @param featureTable the feature table to use to drive the value of the factor * @param neighborIndices the indices of the neighboring variables, in order * @return a reference to the created factor. This can be safely ignored, as the factor is already saved in the model */ public Factor addFactor(ConcatVectorTable featureTable, int[] neighborIndices) { assert (featureTable.getDimensions().length == neighborIndices.length); Factor factor = new Factor(featureTable, neighborIndices); factors.add(factor); return factor; } /** * @return an array of integers, indicating variable sizes given by each of the factors in the model */ public int[] getVariableSizes() { if (factors.size() == 0) { return new int[0]; } int maxVar = 0; for (Factor f : factors) { for (int n : f.neigborIndices) { if (n > maxVar) maxVar = n; } } int[] sizes = new int[maxVar + 1]; for (int i = 0; i < sizes.length; i++) { sizes[i] = -1; } for (Factor f : factors) { for (int i = 0; i < f.neigborIndices.length; i++) { sizes[f.neigborIndices[i]] = f.featuresTable.getDimensions()[i]; } } return sizes; } /** * Writes the protobuf version of this graphical model to a stream. reversible with readFromStream(). * * @param stream the output stream to write to * @throws IOException passed through from the stream */ public void writeToStream(OutputStream stream) throws IOException { getProtoBuilder().build().writeDelimitedTo(stream); } /** * Static function to deserialize a graphical model from an input stream. * * @param stream the stream to read from, assuming protobuf encoding * @return a new graphical model * @throws IOException passed through from the stream */ public static GraphicalModel readFromStream(InputStream stream) throws IOException { return readFromProto(GraphicalModelProto.GraphicalModel.parseDelimitedFrom(stream)); } /** * @return the proto builder corresponding to this GraphicalModel */ public GraphicalModelProto.GraphicalModel.Builder getProtoBuilder() { GraphicalModelProto.GraphicalModel.Builder builder = GraphicalModelProto.GraphicalModel.newBuilder(); builder.setMetaData(getProtoMetaDataBuilder(modelMetaData)); for (Map<String, String> metaData : variableMetaData) { builder.addVariableMetaData(getProtoMetaDataBuilder(metaData)); } for (Factor factor : factors) { builder.addFactor(factor.getProtoBuilder()); } return builder; } /** * Recreates an in-memory GraphicalModel from a proto serialization, recursively creating all the ConcatVectorTable's * and ConcatVector's in memory as well. * * @param proto the proto to read * @return an in-memory GraphicalModel */ public static GraphicalModel readFromProto(GraphicalModelProto.GraphicalModel proto) { if (proto == null) return null; GraphicalModel model = new GraphicalModel(); model.modelMetaData = readMetaDataFromProto(proto.getMetaData()); model.variableMetaData = new ArrayList<>(); for (int i = 0; i < proto.getVariableMetaDataCount(); i++) { model.variableMetaData.add(readMetaDataFromProto(proto.getVariableMetaData(i))); } for (int i = 0; i < proto.getFactorCount(); i++) { model.factors.add(Factor.readFromProto(proto.getFactor(i))); } return model; } /** * Check that two models are deeply value-equivalent, down to the concat vectors inside the factor tables, within * some tolerance. Mostly useful for testing. * * @param other the graphical model to compare against. * @param tolerance the tolerance to accept when comparing concat vectors for value equality. * @return whether the two models are tolerance equivalent */ public boolean valueEquals(GraphicalModel other, double tolerance) { if (!modelMetaData.equals(other.modelMetaData)) { return false; } if (!variableMetaData.equals(other.variableMetaData)) { return false; } // compare factor sets for equality Set<Factor> remaining = new HashSet<>(); remaining.addAll(factors); for (Factor otherFactor : other.factors) { Factor match = null; for (Factor factor : remaining) { if (factor.valueEquals(otherFactor, tolerance)) { match = factor; break; } } if (match == null) return false; else remaining.remove(match); } return remaining.size() <= 0; } /** * Displays a list of factors, by neighbor. * * @return a formatted list of factors, by neighbor */ @Override public String toString() { String s = "{"; for (Factor f : factors) { s += "\n\t" + Arrays.toString(f.neigborIndices) + "@" + f; } s += "\n}"; return s; } /** * The point here is to allow us to save a copy of the model with a current set of factors and metadata mappings, * which can come in super handy with gameplaying applications. The cloned model doesn't instantiate the feature * thunks inside factors, those are just taken over individually. * * @return a clone */ public GraphicalModel cloneModel() { GraphicalModel clone = new GraphicalModel(); clone.modelMetaData.putAll(modelMetaData); for (int i = 0; i < variableMetaData.size(); i++) { if (variableMetaData.get(i) != null) { clone.getVariableMetaDataByReference(i).putAll(variableMetaData.get(i)); } } for (Factor f : factors) { clone.factors.add(f.cloneFactor()); } return clone; } //////////////////////////////////////////////////////////////////////////// // PRIVATE IMPLEMENTATION //////////////////////////////////////////////////////////////////////////// private static GraphicalModelProto.MetaData.Builder getProtoMetaDataBuilder(Map<String, String> metaData) { GraphicalModelProto.MetaData.Builder builder = GraphicalModelProto.MetaData.newBuilder(); for (String key : metaData.keySet()) { builder.addKey(key); builder.addValue(metaData.get(key)); } return builder; } private static Map<String, String> readMetaDataFromProto(GraphicalModelProto.MetaData proto) { Map<String, String> metaData = new HashMap<>(); for (int i = 0; i < proto.getKeyCount(); i++) { metaData.put(proto.getKey(i), proto.getValue(i)); } return metaData; } }