package hex.genmodel.easy;

import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.algos.deepwater.DeepwaterMojoModel;
import hex.genmodel.algos.word2vec.Word2VecMojoModel;
import hex.genmodel.algos.word2vec.WordEmbeddingModel;
import hex.genmodel.easy.exception.PredictException;
import hex.genmodel.easy.exception.PredictNumberFormatException;
import hex.genmodel.easy.exception.PredictUnknownCategoricalLevelException;
import hex.genmodel.easy.exception.PredictUnknownTypeException;
import hex.genmodel.easy.prediction.*;

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

/**
 * An easy-to-use prediction wrapper for generated models.  Instantiate as follows.  The following two are equivalent.
 *
 * <pre>
 *     EasyPredictModelWrapper model = new EasyPredictModelWrapper(rawModel);
 *
 *     EasyPredictModelWrapper model = new EasyPredictModelWrapper(
 *                                         new EasyPredictModelWrapper.Config()
 *                                             .setModel(rawModel)
 *                                             .setConvertUnknownCategoricalLevelsToNa(false));
 * </pre>
 *
 * Note that for any given model, you must use the one predict method below that matches the
 * model category.
 *
 * By default, unknown categorical levels result in a thrown PredictUnknownCategoricalLevelException.
 * The API was designed with this default to make the simplest possible setup inform the user if there are concerns
 * with the data quality.
 * An alternate behavior is to automatically convert unknown categorical levels to N/A.  To do this, use
 * setConvertUnknownCategoricalLevelsToNa(true) instead.
 *
 * If you choose to convert unknown categorical levels to N/A, you can see how many times this is happening
 * with the following methods:
 *
 * <pre>
 *     getTotalUnknownCategoricalLevelsSeen()
 *     getUnknownCategoricalLevelsSeenPerColumn()
 * </pre>
 *
 * <p></p>
 * See the top-of-tree master version of this file <a href="https://github.com/h2oai/h2o-3/blob/master/h2o-genmodel/src/main/java/hex/genmodel/easy/EasyPredictModelWrapper.java" target="_blank">here on github</a>.
 */
public class EasyPredictModelWrapper implements java.io.Serializable {
  // Matches http(s)/ftp/file URLs; anything else given for an image column is treated as a local file path.
  // Compiled once instead of on every image value.
  private static final java.util.regex.Pattern IMAGE_URL_PATTERN =
      java.util.regex.Pattern.compile("^(https?|ftp|file)://[-a-zA-Z0-9+&@#/%?=~_|!:,.;]*[-a-zA-Z0-9+&@#/%=~_|]");

  // These private members are read-only after the constructor.
  private final GenModel m;
  private final HashMap<String, Integer> modelColumnNameToIndexMap;
  private final HashMap<Integer, HashMap<String, Integer>> domainMap;

  private final boolean convertUnknownCategoricalLevelsToNa;
  private final boolean convertInvalidNumbersToNa;
  private final ConcurrentHashMap<String,AtomicLong> unknownCategoricalLevelsSeenPerColumn;

  /**
   * Configuration builder for instantiating a Wrapper.
   */
  public static class Config {
    private GenModel model;
    private boolean convertUnknownCategoricalLevelsToNa = false;
    private boolean convertInvalidNumbersToNa = false;

    /**
     * Specify model object to wrap.
     *
     * @param value model
     * @return this config object
     */
    public Config setModel(GenModel value) {
      model = value;
      return this;
    }

    /**
     * @return model object being wrapped
     */
    public GenModel getModel() {
      return model;
    }

    /**
     * Specify how to handle unknown categorical levels.
     *
     * @param value false: throw exception; true: convert to N/A
     * @return this config object
     */
    public Config setConvertUnknownCategoricalLevelsToNa(boolean value) {
      convertUnknownCategoricalLevelsToNa = value;
      return this;
    }

    /**
     * @return Setting for unknown categorical levels handling
     */
    public boolean getConvertUnknownCategoricalLevelsToNa() {
      return convertUnknownCategoricalLevelsToNa;
    }

    /**
     * Specify the default action when a string value cannot be converted to
     * a number.
     *
     * @param value if true, then an N/A value will be produced, if false an
     *              exception will be thrown.
     * @return this config object
     */
    public Config setConvertInvalidNumbersToNa(boolean value) {
      convertInvalidNumbersToNa = value;
      return this;
    }

    /**
     * @return Setting for handling of strings that cannot be parsed as numbers
     */
    public boolean getConvertInvalidNumbersToNa() {
      return convertInvalidNumbersToNa;
    }
  }

  /**
   * Create a wrapper for a generated model.
   *
   * @param config The wrapper configuration; its model must not be null
   * @throws NullPointerException if the configuration carries no model
   */
  public EasyPredictModelWrapper(Config config) {
    m = config.getModel();
    if (m == null) {
      // Same exception type the first dereference below would raise, but with an actionable message.
      throw new NullPointerException("Config.getModel() returned null; call setModel() before creating the wrapper");
    }

    // Create map of column names to index number.
    modelColumnNameToIndexMap = new HashMap<>();
    String[] modelColumnNames = m.getNames();
    for (int i = 0; i < modelColumnNames.length; i++) {
      modelColumnNameToIndexMap.put(modelColumnNames[i], i);
    }

    // How to handle unknown categorical levels.
    unknownCategoricalLevelsSeenPerColumn = new ConcurrentHashMap<>();
    convertUnknownCategoricalLevelsToNa = config.getConvertUnknownCategoricalLevelsToNa();
    convertInvalidNumbersToNa = config.getConvertInvalidNumbersToNa();
    setupConvertUnknownCategoricalLevelsToNa();

    // Create map of input variable domain information.
    // This contains the categorical string to numeric mapping.
    domainMap = new HashMap<>();
    for (int i = 0; i < m.getNumCols(); i++) {
      String[] domainValues = m.getDomainValues(i);
      if (domainValues != null) {
        // Named so that it does not shadow the model field `m`.
        HashMap<String, Integer> levelToIndex = new HashMap<>();
        for (int j = 0; j < domainValues.length; j++) {
          levelToIndex.put(domainValues[j], j);
        }
        domainMap.put(i, levelToIndex);
      }
    }
  }

  /**
   * Create a wrapper for a generated model using the default configuration.
   *
   * @param model The generated model
   */
  public EasyPredictModelWrapper(GenModel model) {
    this(new Config().setModel(model));
  }

  /**
   * Get the total number unknown categorical levels seen.
   *
   * A single prediction may contribute more than one to the count.
   * The count is only updated when setConvertUnknownCategoricalLevelsToNa is set to true.
   *
   * @return A long value.
   */
  public long getTotalUnknownCategoricalLevelsSeen() {
    ConcurrentHashMap<String, AtomicLong> map = getUnknownCategoricalLevelsSeenPerColumn();
    long total = 0;
    for (AtomicLong l : map.values()) {
      total += l.get();
    }
    return total;
  }

  /**
   * Get unknown categorical level counts.
   *
   * A single prediction may contribute to more than one count.
   * Counts are only updated when setConvertUnknownCategoricalLevelsToNa is set to true.
   *
   * @return A hash map with a per-column count of unknown categorical levels seen when making predictions.
   *         This is the wrapper's own live map, not a copy.
   */
  public ConcurrentHashMap<String, AtomicLong> getUnknownCategoricalLevelsSeenPerColumn() {
    return unknownCategoricalLevelsSeenPerColumn;
  }

  /**
   * Make a prediction on a new data point, dispatching on an explicitly given model category.
   *
   * The type of prediction returned depends on the model category.
   * The caller needs to know what type of prediction to expect.
   *
   * This call is convenient for generically automating model deployment.
   * For specific applications (where the kind of model is known and doesn't change), it is recommended to call
   * specific prediction calls like predictBinomial() directly.
   *
   * @param data A new data point.
   * @param mc The model category that selects which specific predict method is invoked.
   * @return The prediction.
   * @throws PredictException if the category is Unknown or not handled here, or if the selected predict method fails
   */
  public AbstractPrediction predict(RowData data, ModelCategory mc) throws PredictException {
    switch (mc) {
      case AutoEncoder:
        return predictAutoEncoder(data);
      case Binomial:
        return predictBinomial(data);
      case Multinomial:
        return predictMultinomial(data);
      case Clustering:
        return predictClustering(data);
      case Regression:
        return predictRegression(data);
      case DimReduction:
        return predictDimReduction(data);
      case WordEmbedding:
        return predictWord2Vec(data);

      case Unknown:
        throw new PredictException("Unknown model category");
      default:
        // Report the category that was actually switched on, which may differ from the model's own category.
        throw new PredictException("Unhandled model category (" + mc + ") in switch statement");
    }
  }

  /**
   * Make a prediction on a new data point using the wrapped model's own category.
   *
   * @param data A new data point.
   * @return The prediction.
   * @throws PredictException if the model category is not handled or the prediction fails
   */
  public AbstractPrediction predict(RowData data) throws PredictException {
    return predict(data, m.getModelCategory());
  }

  /**
   * Make a prediction on a new data point using an AutoEncoder model.
   *
   * Currently unimplemented: after scoring, this always throws a RuntimeException.
   *
   * @param data A new data point.
   * @return The prediction.
   * @throws PredictException if the model does not support the AutoEncoder category or the row cannot be converted
   */
  public AutoEncoderModelPrediction predictAutoEncoder(RowData data) throws PredictException {
    double[] preds = preamble(ModelCategory.AutoEncoder, data);
    throw new RuntimeException("Unimplemented " + preds.length);
  }

  /**
   * Make a prediction on a new data point using a Dimension Reduction model (PCA, GLRM)
   *
   * @param data A new data point.
   * @return The prediction.
   * @throws PredictException if the model does not support the DimReduction category or the row cannot be converted
   */
  public DimReductionModelPrediction predictDimReduction(RowData data) throws PredictException {
    double[] preds = preamble(ModelCategory.DimReduction, data);

    DimReductionModelPrediction p = new DimReductionModelPrediction();
    p.dimensions = preds;

    return p;
  }

  /**
   * Lookup word embeddings for a given word (or set of words).
   *
   * @param data RowData structure, every key with a String value will be translated to an embedding;
   *             entries whose value is not a String are skipped
   * @return The prediction
   * @throws PredictException if model is not a WordEmbedding model
   */
  public Word2VecPrediction predictWord2Vec(RowData data) throws PredictException {
    validateModelCategory(ModelCategory.WordEmbedding);

    if (! (m instanceof WordEmbeddingModel))
      throw new PredictException("Model is not of the expected type, class = " + m.getClass().getSimpleName());
    final WordEmbeddingModel weModel = (WordEmbeddingModel) m;
    final int vecSize = weModel.getVecSize();

    HashMap<String, float[]> embeddings = new HashMap<>(data.size());
    for (String wordKey : data.keySet()) {
      Object value = data.get(wordKey);
      if (value instanceof String) {
        String word = (String) value;
        embeddings.put(wordKey, weModel.transform0(word, new float[vecSize]));
      }
    }

    Word2VecPrediction p = new Word2VecPrediction();
    p.wordEmbeddings = embeddings;

    return p;
  }

  /**
   * Make a prediction on a new data point using a Binomial model.
   *
   * @param data A new data point.
   * @return The prediction.
   * @throws PredictException if the model does not support the Binomial category or the row cannot be converted
   */
  public BinomialModelPrediction predictBinomial(RowData data) throws PredictException {
    double[] preds = preamble(ModelCategory.Binomial, data);

    BinomialModelPrediction p = new BinomialModelPrediction();
    // preds[0] is the predicted label index; the class probabilities follow from preds[1].
    double d = preds[0];
    p.labelIndex = (int) d;
    String[] domainValues = m.getDomainValues(m.getResponseIdx());
    p.label = domainValues[p.labelIndex];
    p.classProbabilities = new double[m.getNumResponseClasses()];
    System.arraycopy(preds, 1, p.classProbabilities, 0, p.classProbabilities.length);
    // The uncalibrated probabilities were copied out above; when the model reports success here, the same
    // slots of preds are copied again as the calibrated values
    // (presumably the call rewrites preds in place - confirm against GenModel).
    if (m.calibrateClassProbabilities(preds)) {
      p.calibratedClassProbabilities = new double[m.getNumResponseClasses()];
      System.arraycopy(preds, 1, p.calibratedClassProbabilities, 0, p.calibratedClassProbabilities.length);
    }
    return p;
  }

  /**
   * Make a prediction on a new data point using a Multinomial model.
   *
   * @param data A new data point.
   * @return The prediction.
   * @throws PredictException if the model does not support the Multinomial category or the row cannot be converted
   */
  public MultinomialModelPrediction predictMultinomial(RowData data) throws PredictException {
    double[] preds = preamble(ModelCategory.Multinomial, data);

    MultinomialModelPrediction p = new MultinomialModelPrediction();
    p.classProbabilities = new double[m.getNumResponseClasses()];
    p.labelIndex = (int) preds[0];
    String[] domainValues = m.getDomainValues(m.getResponseIdx());
    p.label = domainValues[p.labelIndex];
    System.arraycopy(preds, 1, p.classProbabilities, 0, p.classProbabilities.length);

    return p;
  }

  /**
   * Pair each domain value with its probability and sort the pairs in descending order.
   */
  private SortedClassProbability[] sortByDescendingClassProbability(String[] domainValues, double[] classProbabilities) {
    assert (classProbabilities.length == domainValues.length);
    SortedClassProbability[] arr = new SortedClassProbability[domainValues.length];
    for (int i = 0; i < domainValues.length; i++) {
      arr[i] = new SortedClassProbability();
      arr[i].name = domainValues[i];
      arr[i].probability = classProbabilities[i];
    }
    Arrays.sort(arr, Collections.reverseOrder());
    return arr;
  }

  /**
   * A helper function to return an array of binomial class probabilities for a prediction in sorted order.
   * The returned array has the most probable class in position 0.
   *
   * @param p The prediction.
   * @return An array with sorted class probabilities.
   */
  public SortedClassProbability[] sortByDescendingClassProbability(BinomialModelPrediction p) {
    String[] domainValues = m.getDomainValues(m.getResponseIdx());
    double[] classProbabilities = p.classProbabilities;
    return sortByDescendingClassProbability(domainValues, classProbabilities);
  }

  /**
   * A helper function to return an array of multinomial class probabilities for a prediction in sorted order.
   * The returned array has the most probable class in position 0.
   *
   * @param p The prediction.
   * @return An array with sorted class probabilities.
   */
  public SortedClassProbability[] sortByDescendingClassProbability(MultinomialModelPrediction p) {
    String[] domainValues = m.getDomainValues(m.getResponseIdx());
    double[] classProbabilities = p.classProbabilities;
    return sortByDescendingClassProbability(domainValues, classProbabilities);
  }

  /**
   * Make a prediction on a new data point using a Clustering model.
   *
   * @param data A new data point.
   * @return The prediction.
   * @throws PredictException if the model does not support the Clustering category or the row cannot be converted
   */
  public ClusteringModelPrediction predictClustering(RowData data) throws PredictException {
    double[] preds = preamble(ModelCategory.Clustering, data);

    ClusteringModelPrediction p = new ClusteringModelPrediction();
    p.cluster = (int) preds[0];

    return p;
  }

  /**
   * Make a prediction on a new data point using a Regression model.
   *
   * @param data A new data point.
   * @return The prediction.
   * @throws PredictException if the model does not support the Regression category or the row cannot be converted
   */
  public RegressionModelPrediction predictRegression(RowData data) throws PredictException {
    double[] preds = preamble(ModelCategory.Regression, data);

    RegressionModelPrediction p = new RegressionModelPrediction();
    p.value = preds[0];

    return p;
  }

  /**
   * Make a prediction on a new data point using a k-LIME model.
   *
   * The model is validated and scored as a Regression model; the scored array is then split into the
   * predicted value (position 0), the cluster (position 1) and the reason codes (all remaining positions).
   *
   * @param data A new data point.
   * @return The prediction.
   * @throws PredictException if the model does not support the Regression category or the row cannot be converted
   */
  public KLimeModelPrediction predictKLime(RowData data) throws PredictException {
    double[] preds = preamble(ModelCategory.Regression, data);

    KLimeModelPrediction p = new KLimeModelPrediction();
    p.value = preds[0];
    p.cluster = (int) preds[1];
    p.reasonCodes = new double[preds.length - 2];
    System.arraycopy(preds, 2, p.reasonCodes, 0, p.reasonCodes.length);

    return p;
  }

  //----------------------------------------------------------------------
  // Transparent methods passed through to GenModel.
  //----------------------------------------------------------------------

  /**
   * Get the category (type) of model.
   * @return The category.
   */
  public ModelCategory getModelCategory() {
    return m.getModelCategory();
  }

  /**
   * Get the array of levels for the response column.
   * "Domain" just means list of level names for a categorical (aka factor, enum) column.
   * If the response column is numerical and not categorical, this will return null.
   *
   * @return The array.
   */
  public String[] getResponseDomainValues() {
    return m.getDomainValues(m.getResponseIdx());
  }

  /**
   * Get the header string reported by the wrapped model (a straight pass-through to GenModel.getHeader()).
   * NOTE(review): this appears to be the CSV header used for autoencoder output - confirm against GenModel.
   *
   * @return CSV header for autoencoder.
   */
  public String getHeader() {
    return m.getHeader();
  }

  //----------------------------------------------------------------------
  // Private methods below this line.
  //----------------------------------------------------------------------

  /**
   * When conversion of unknown levels is enabled, register a zeroed counter for every categorical column
   * of the model; otherwise make sure no counters are registered.
   */
  private void setupConvertUnknownCategoricalLevelsToNa() {
    if (convertUnknownCategoricalLevelsToNa) {
      for (int i = 0; i < m.getNumCols(); i++) {
        String[] domainValues = m.getDomainValues(i);
        if (domainValues != null) {
          String columnName = m.getNames()[i];
          unknownCategoricalLevelsSeenPerColumn.put(columnName, new AtomicLong());
        }
      }
    }
    else {
      unknownCategoricalLevelsSeenPerColumn.clear();
    }
  }

  /**
   * Reject prediction requests for a category the wrapped model does not list among its categories.
   */
  private void validateModelCategory(ModelCategory c) throws PredictException {
    if (!m.getModelCategories().contains(c))
      throw new PredictException(c + " prediction type is not supported for this model.");
  }

  // This should have been called predict(), because that's what it does
  private double[] preamble(ModelCategory c, RowData data) throws PredictException {
    validateModelCategory(c);
    return predict(data, new double[m.getPredsSize(c)]);
  }

  /**
   * Describe the runtime type of a row value for error messages without failing on null.
   */
  private static String typeName(Object o) {
    return o == null ? "null" : o.getClass().getName();
  }

  /**
   * Build a PredictException that keeps the underlying failure attached as its cause.
   */
  private static PredictException predictFailure(String message, Throwable cause) {
    PredictException failure = new PredictException(message);
    failure.initCause(cause);
    return failure;
  }

  /**
   * Load an image from a URL (http, https, ftp or file scheme) or, for any other string, from a local file path.
   * May return null when ImageIO finds no reader able to decode the content.
   */
  private static BufferedImage readImage(String location) throws PredictException {
    boolean isURL = IMAGE_URL_PATTERN.matcher(location).matches();
    try {
      return isURL ? ImageIO.read(new URL(location)) : ImageIO.read(new File(location));
    } catch (IOException e) {
      throw predictFailure("Couldn't read image from " + location, e);
    }
  }

  /**
   * Decode an image from its raw encoded bytes.
   * May return null when ImageIO finds no reader able to decode the content.
   */
  private static BufferedImage decodeImage(byte[] bytes) throws PredictException {
    InputStream is = new ByteArrayInputStream(bytes);
    try {
      return ImageIO.read(is);
    } catch (IOException e) {
      throw predictFailure("Couldn't interpret raw bytes as an image.", e);
    }
  }

  /**
   * Convert an image into the Deepwater model's input vector (width * height * channels values).
   */
  private double[] vectorizeImage(BufferedImage img) throws PredictException {
    DeepwaterMojoModel dwm = (DeepwaterMojoModel) m;
    int W = dwm._width;
    int H = dwm._height;
    int C = dwm._channels;
    float[] pixels = new float[W * H * C];
    try {
      GenModel.img2pixels(img, W, H, C, pixels, 0, dwm._meanImageData);
    } catch (IOException e) {
      throw predictFailure("Couldn't vectorize image.", e);
    }
    double[] vectorized = new double[pixels.length];
    for (int i = 0; i < vectorized.length; ++i)
      vectorized[i] = pixels[i];
    return vectorized;
  }

  /**
   * Parse the string form of a numeric column value; an unparseable string becomes NaN when
   * convertInvalidNumbersToNa is set and raises PredictNumberFormatException otherwise.
   */
  private double parseNumber(String s, String dataColumnName) throws PredictException {
    try {
      return Double.parseDouble(s);
    } catch (NumberFormatException nfe) {
      if (!convertInvalidNumbersToNa)
        throw new PredictNumberFormatException("Unable to parse value: " + s + ", from column: "+ dataColumnName + ", as Double; " + nfe.getMessage());
      return Double.NaN;
    }
  }

  /**
   * Translate the value supplied for a categorical column into its numeric level index.
   * A level is looked up by its plain name first and then as "column.level". An unknown level either
   * becomes NaN (and is counted) or raises PredictUnknownCategoricalLevelException, depending on
   * convertUnknownCategoricalLevelsToNa.
   */
  private double categoricalValue(String dataColumnName, Integer index, Object o) throws PredictException {
    if (o instanceof String) {
      String levelName = (String) o;
      HashMap<String, Integer> columnDomainMap = domainMap.get(index);
      Integer levelIndex = columnDomainMap.get(levelName);
      if (levelIndex == null) {
        levelIndex = columnDomainMap.get(dataColumnName + "." + levelName);
      }
      if (levelIndex != null) {
        return levelIndex;
      }
      if (!convertUnknownCategoricalLevelsToNa) {
        throw new PredictUnknownCategoricalLevelException("Unknown categorical level (" + dataColumnName + "," + levelName + ")", dataColumnName, levelName);
      }
      unknownCategoricalLevelsSeenPerColumn.get(dataColumnName).incrementAndGet();
      return Double.NaN;
    }
    if (o instanceof Double && Double.isNaN((Double) o)) {
      return (Double) o; //Missing factor is the only Double value allowed
    }
    // A null value lands here too and is reported as a PredictException rather than a NullPointerException.
    throw new PredictUnknownTypeException(
            "Unexpected object type " + typeName(o) + " for categorical column " + dataColumnName);
  }

  /**
   * Copy the values of a row into the model's numeric input array.
   *
   * Columns unknown to the model, and columns whose index lies beyond the input array (the response),
   * are skipped. For a Deepwater image model the first successfully decoded image replaces the whole
   * input array with the image's pixel vector and ends the conversion.
   *
   * @param data the row to convert
   * @param rawData the input array to fill, indexed by model column
   * @return the filled array, or a newly allocated pixel vector for image models
   * @throws PredictException if a value has an unsupported type (including null), cannot be parsed,
   *         is an unknown categorical level, or an image cannot be read
   */
  private double[] fillRawData(RowData data, double[] rawData) throws PredictException {

    // TODO: refactor
    boolean isImage = m instanceof DeepwaterMojoModel && ((DeepwaterMojoModel) m)._problem_type.equals("image");
    boolean isText = m instanceof DeepwaterMojoModel && ((DeepwaterMojoModel) m)._problem_type.equals("text");

    for (String dataColumnName : data.keySet()) {
      Integer index = modelColumnNameToIndexMap.get(dataColumnName);

      // Skip column names that are not known.
      // Skip the "response" column which should not be included in `rawData`
      if (index == null || index >= rawData.length) {
        continue;
      }

      Object o = data.get(dataColumnName);
      if (m.getDomainValues(index) != null) {
        // Column has categorical value.
        rawData[index] = categoricalValue(dataColumnName, index, o);
        continue;
      }

      // Column is either numeric or a string (for images or text)
      BufferedImage img = null;
      double value = Double.NaN;
      if (o instanceof String) {
        String s = ((String) o).trim();
        if (isImage) {
          // Url to an image, or a local path, given
          img = readImage(s);
        } else if (isText) {
          // TODO: use model-specific vectorization of text
          throw new PredictException("MOJO scoring for text classification is not yet implemented.");
        } else {
          // numeric
          value = parseNumber(s, dataColumnName);
        }
      } else if (o instanceof Double) {
        value = (Double) o;
      } else if (o instanceof byte[] && isImage) {
        // Read the image from raw bytes
        img = decodeImage((byte[]) o);
      } else {
        // A null value lands here too and is reported as a PredictException rather than a NullPointerException.
        throw new PredictUnknownTypeException(
                "Unexpected object type " + typeName(o) + " for numeric column " + dataColumnName);
      }

      if (isImage && img != null) {
        return vectorizeImage(img);
      }

      // An image that could not be decoded (img == null) falls through and leaves NaN in the slot.
      rawData[index] = value;
    }
    return rawData;
  }

  /**
   * Convert a row to the model's raw input (every slot starts as NaN, i.e. missing) and score it.
   */
  private double[] predict(RowData data, double[] preds) throws PredictException {
    double[] rawData = new double[m.nfeatures()];
    Arrays.fill(rawData, Double.NaN);
    rawData = fillRawData(data, rawData);
    preds = m.score0(rawData, preds);
    return preds;
  }
}