/* * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. */ /* * XML.java * Copyright (C) 2009-2012 University of Waikato, Hamilton, New Zealand */ package weka.classifiers.evaluation.output.prediction; import weka.classifiers.Classifier; import weka.core.Attribute; import weka.core.Instance; import weka.core.Utils; import weka.core.Version; import weka.core.xml.XMLDocument; /** <!-- globalinfo-start --> * Outputs the predictions in XML.<br/> * <br/> * The following DTD is used:<br/> * <br/> * <!DOCTYPE predictions<br/> * [<br/> * <!ELEMENT predictions (prediction*)><br/> * <!ATTLIST predictions version CDATA "3.5.8"><br/> * <!ATTLIST predictions name CDATA #REQUIRED><br/> * <br/> * <!ELEMENT prediction ((actual_label,predicted_label,error,(prediction|distribution),attributes?)|(actual_value,predicted_value,error,attributes?))><br/> * <!ATTLIST prediction index CDATA #REQUIRED><br/> * <br/> * <!ELEMENT actual_label ANY><br/> * <!ATTLIST actual_label index CDATA #REQUIRED><br/> * <!ELEMENT predicted_label ANY><br/> * <!ATTLIST predicted_label index CDATA #REQUIRED><br/> * <!ELEMENT error ANY><br/> * <!ELEMENT prediction ANY><br/> * <!ELEMENT distribution (class_label+)><br/> * <!ELEMENT class_label ANY><br/> * <!ATTLIST class_label index CDATA #REQUIRED><br/> * <!ATTLIST class_label predicted (yes|no) "no"><br/> * <!ELEMENT actual_value ANY><br/> * <!ELEMENT predicted_value ANY><br/> * <!ELEMENT attributes (attribute+)><br/> * <!ELEMENT attribute ANY><br/> * <!ATTLIST attribute index CDATA #REQUIRED><br/> * <!ATTLIST attribute name CDATA #REQUIRED><br/> * <!ATTLIST attribute type (numeric|date|nominal|string|relational) #REQUIRED><br/> * ]<br/> * > * <p/> <!-- globalinfo-end --> * <!-- options-start --> * Valid options are: <p/> * * <pre> -p <range> * The range of attributes to print in addition to the classification. * (default: none)</pre> * * <pre> -distribution * Whether to turn on the output of the class distribution. * Only for nominal class attributes. * (default: off)</pre> * * <pre> -decimals <num> * The number of digits after the decimal point. * (default: 3)</pre> * * <pre> -file <path> * The file to store the output in, instead of outputting it on stdout. * Gets ignored if the supplied path is a directory. * (default: .)</pre> * * <pre> -suppress * In case the data gets stored in a file, then this flag can be used * to suppress the regular output. * (default: not suppressed)</pre> * <!-- options-end --> * * @author fracpete (fracpete at waikato dot ac dot nz) * @version $Revision: 8937 $ */ public class XML extends AbstractOutput { /** for serialization. */ private static final long serialVersionUID = -3165514277316824801L; /** the DocType definition. */ public final static String DTD_DOCTYPE = XMLDocument.DTD_DOCTYPE; /** the Element definition. */ public final static String DTD_ELEMENT = XMLDocument.DTD_ELEMENT; /** the AttList definition. */ public final static String DTD_ATTLIST = XMLDocument.DTD_ATTLIST; /** the optional marker. */ public final static String DTD_OPTIONAL = XMLDocument.DTD_OPTIONAL; /** the at least one marker. */ public final static String DTD_AT_LEAST_ONE = XMLDocument.DTD_AT_LEAST_ONE; /** the zero or more marker. */ public final static String DTD_ZERO_OR_MORE = XMLDocument.DTD_ZERO_OR_MORE; /** the option separator. */ public final static String DTD_SEPARATOR = XMLDocument.DTD_SEPARATOR; /** the CDATA placeholder. */ public final static String DTD_CDATA = XMLDocument.DTD_CDATA; /** the ANY placeholder. */ public final static String DTD_ANY = XMLDocument.DTD_ANY; /** the #PCDATA placeholder. */ public final static String DTD_PCDATA = XMLDocument.DTD_PCDATA; /** the #IMPLIED placeholder. */ public final static String DTD_IMPLIED = XMLDocument.DTD_IMPLIED; /** the #REQUIRED placeholder. */ public final static String DTD_REQUIRED = XMLDocument.DTD_REQUIRED; /** the "version" attribute. */ public final static String ATT_VERSION = XMLDocument.ATT_VERSION; /** the "name" attribute. */ public final static String ATT_NAME = XMLDocument.ATT_NAME; /** the "type" attribute. */ public final static String ATT_TYPE = "type"; /** the value "yes". */ public final static String VAL_YES = XMLDocument.VAL_YES; /** the value "no". */ public final static String VAL_NO = XMLDocument.VAL_NO; /** the predictions tag. */ public final static String TAG_PREDICTIONS = "predictions"; /** the prediction tag. */ public final static String TAG_PREDICTION = "prediction"; /** the actual_nominal tag. */ public final static String TAG_ACTUAL_LABEL = "actual_label"; /** the predicted_nominal tag. */ public final static String TAG_PREDICTED_LABEL = "predicted_label"; /** the error tag. */ public final static String TAG_ERROR = "error"; /** the distribution tag. */ public final static String TAG_DISTRIBUTION = "distribution"; /** the class_label tag. */ public final static String TAG_CLASS_LABEL = "class_label"; /** the actual_numeric tag. */ public final static String TAG_ACTUAL_VALUE = "actual_value"; /** the predicted_numeric tag. */ public final static String TAG_PREDICTED_VALUE = "predicted_value"; /** the attributes tag. */ public final static String TAG_ATTRIBUTES = "attributes"; /** the attribute tag. */ public final static String TAG_ATTRIBUTE = "attribute"; /** the index attribute. */ public final static String ATT_INDEX = "index"; /** the predicted attribute. */ public final static String ATT_PREDICTED = "predicted"; /** the DTD. */ public final static String DTD = "<!" + DTD_DOCTYPE + " " + TAG_PREDICTIONS + "\n" + "[\n" + " <!" + DTD_ELEMENT + " " + TAG_PREDICTIONS + " (" + TAG_PREDICTION + DTD_ZERO_OR_MORE + ")" + ">\n" + " <!" + DTD_ATTLIST + " " + TAG_PREDICTIONS + " " + ATT_VERSION + " " + DTD_CDATA + " \"" + Version.VERSION + "\"" + ">\n" + " <!" + DTD_ATTLIST + " " + TAG_PREDICTIONS + " " + ATT_NAME + " " + DTD_CDATA + " " + DTD_REQUIRED + ">\n" + "\n" + " <!" + DTD_ELEMENT + " " + TAG_PREDICTION + " " + "(" + "(" + TAG_ACTUAL_LABEL + "," + TAG_PREDICTED_LABEL + "," + TAG_ERROR + "," + "(" + TAG_PREDICTION + DTD_SEPARATOR + TAG_DISTRIBUTION + ")" + "," + TAG_ATTRIBUTES + DTD_OPTIONAL + ")" + DTD_SEPARATOR + "(" + TAG_ACTUAL_VALUE + "," + TAG_PREDICTED_VALUE + "," + TAG_ERROR + "," + TAG_ATTRIBUTES + DTD_OPTIONAL + ")" + ")" + ">\n" + " <!" + DTD_ATTLIST + " " + TAG_PREDICTION + " " + ATT_INDEX + " " + DTD_CDATA + " " + DTD_REQUIRED + ">\n" + "\n" + " <!" + DTD_ELEMENT + " " + TAG_ACTUAL_LABEL + " " + DTD_ANY + ">\n" + " <!" + DTD_ATTLIST + " " + TAG_ACTUAL_LABEL + " " + ATT_INDEX + " " + DTD_CDATA + " " + DTD_REQUIRED + ">\n" + " <!" + DTD_ELEMENT + " " + TAG_PREDICTED_LABEL + " " + DTD_ANY + ">\n" + " <!" + DTD_ATTLIST + " " + TAG_PREDICTED_LABEL + " " + ATT_INDEX + " " + DTD_CDATA + " " + DTD_REQUIRED + ">\n" + " <!" + DTD_ELEMENT + " " + TAG_ERROR + " " + DTD_ANY + ">\n" + " <!" + DTD_ELEMENT + " " + TAG_PREDICTION + " " + DTD_ANY + ">\n" + " <!" + DTD_ELEMENT + " " + TAG_DISTRIBUTION + " (" + TAG_CLASS_LABEL + DTD_AT_LEAST_ONE + ")" + ">\n" + " <!" + DTD_ELEMENT + " " + TAG_CLASS_LABEL + " " + DTD_ANY + ">\n" + " <!" + DTD_ATTLIST + " " + TAG_CLASS_LABEL + " " + ATT_INDEX + " " + DTD_CDATA + " " + DTD_REQUIRED + ">\n" + " <!" + DTD_ATTLIST + " " + TAG_CLASS_LABEL + " " + ATT_PREDICTED + " (" + VAL_YES + DTD_SEPARATOR + VAL_NO + ") " + "\"" + VAL_NO + "\"" + ">\n" + " <!" + DTD_ELEMENT + " " + TAG_ACTUAL_VALUE + " " + DTD_ANY + ">\n" + " <!" + DTD_ELEMENT + " " + TAG_PREDICTED_VALUE + " " + DTD_ANY + ">\n" + " <!" + DTD_ELEMENT + " " + TAG_ATTRIBUTES + " (" + TAG_ATTRIBUTE + DTD_AT_LEAST_ONE + ")" + ">\n" + " <!" + DTD_ELEMENT + " " + TAG_ATTRIBUTE + " " + DTD_ANY + ">\n" + " <!" + DTD_ATTLIST + " " + TAG_ATTRIBUTE + " " + ATT_INDEX + " " + DTD_CDATA + " " + DTD_REQUIRED + ">\n" + " <!" + DTD_ATTLIST + " " + TAG_ATTRIBUTE + " " + ATT_NAME + " " + DTD_CDATA + " " + DTD_REQUIRED + ">\n" + " <!" + DTD_ATTLIST + " " + TAG_ATTRIBUTE + " " + ATT_TYPE + " " + "(" + Attribute.typeToString(Attribute.NUMERIC) + DTD_SEPARATOR + Attribute.typeToString(Attribute.DATE) + DTD_SEPARATOR + Attribute.typeToString(Attribute.NOMINAL) + DTD_SEPARATOR + Attribute.typeToString(Attribute.STRING) + DTD_SEPARATOR + Attribute.typeToString(Attribute.RELATIONAL) + ")" + " " + DTD_REQUIRED + ">\n" + "]\n" + ">"; /** * Returns a string describing the output generator. * * @return a description suitable for * displaying in the GUI */ public String globalInfo() { return "Outputs the predictions in XML.\n\n" + "The following DTD is used:\n\n" + DTD; } /** * Returns a short display text, to be used in comboboxes. * * @return a short display text */ public String getDisplay() { return "XML"; } /** * Replaces certain characters with their XML entities. * * @param s the string to process * @return the processed string */ protected String sanitize(String s) { String result; result = s; result = result.replaceAll("&", "&"); result = result.replaceAll("<", "<"); result = result.replaceAll(">", ">"); result = result.replaceAll("\"", """); return result; } /** * Performs the actual printing of the header. */ protected void doPrintHeader() { append("<?xml version=\"1.0\" encoding=\"utf-8\"?>\n"); append("\n"); append(DTD + "\n\n"); append("<" + TAG_PREDICTIONS + " " + ATT_VERSION + "=\"" + Version.VERSION + "\"" + " " + ATT_NAME + "=\"" + sanitize(m_Header.relationName()) + "\">\n"); } /** * Builds a string listing the attribute values in a specified range of indices, * separated by commas and enclosed in brackets. * * @param instance the instance to print the values from * @return a string listing values of the attributes in the range */ protected String attributeValuesString(Instance instance) { StringBuffer text = new StringBuffer(); if (m_Attributes != null) { text.append(" <" + TAG_ATTRIBUTES + ">\n"); m_Attributes.setUpper(instance.numAttributes() - 1); for (int i=0; i<instance.numAttributes(); i++) { if (m_Attributes.isInRange(i) && i != instance.classIndex()) { text.append(" <" + TAG_ATTRIBUTE + " " + ATT_INDEX + "=\"" + (i+1) + "\"" + " " + ATT_NAME + "=\"" + sanitize(instance.attribute(i).name()) + "\"" + " " + ATT_TYPE + "=\"" + Attribute.typeToString(instance.attribute(i).type()) + "\"" + ">"); text.append(sanitize(instance.toString(i))); text.append("</" + TAG_ATTRIBUTE + ">\n"); } } text.append(" </" + TAG_ATTRIBUTES + ">\n"); } return text.toString(); } /** * Store the prediction made by the classifier as a string. * * @param dist the distribution to use * @param inst the instance to generate text from * @param index the index in the dataset * @throws Exception if something goes wrong */ protected void doPrintClassification(double[] dist, Instance inst, int index) throws Exception { int prec = m_NumDecimals; Instance withMissing = (Instance)inst.copy(); withMissing.setDataset(inst.dataset()); double predValue = 0; if (Utils.sum(dist) == 0) { predValue = Utils.missingValue(); } else { if (inst.classAttribute().isNominal()) { predValue = Utils.maxIndex(dist); } else { predValue = dist[0]; } } // opening tag append(" <" + TAG_PREDICTION + " " + ATT_INDEX + "=\"" + (index+1) + "\">\n"); if (inst.dataset().classAttribute().isNumeric()) { // actual append(" <" + TAG_ACTUAL_VALUE + ">"); if (inst.classIsMissing()) append("?"); else append(Utils.doubleToString(inst.classValue(), prec)); append("</" + TAG_ACTUAL_VALUE + ">\n"); // predicted append(" <" + TAG_PREDICTED_VALUE + ">"); if (inst.classIsMissing()) append("?"); else append(Utils.doubleToString(predValue, prec)); append("</" + TAG_PREDICTED_VALUE + ">\n"); // error append(" <" + TAG_ERROR + ">"); if (Utils.isMissingValue(predValue) || inst.classIsMissing()) append("?"); else append(Utils.doubleToString(predValue - inst.classValue(), prec)); append("</" + TAG_ERROR + ">\n"); } else { // actual append(" <" + TAG_ACTUAL_LABEL + " " + ATT_INDEX + "=\"" + ((int) inst.classValue()+1) + "\"" + ">"); append(sanitize(inst.toString(inst.classIndex()))); append("</" + TAG_ACTUAL_LABEL + ">\n"); // predicted append(" <" + TAG_PREDICTED_LABEL + " " + ATT_INDEX + "=\"" + ((int) predValue+1) + "\"" + ">"); if (Utils.isMissingValue(predValue)) append("?"); else append(sanitize(inst.dataset().classAttribute().value((int)predValue))); append("</" + TAG_PREDICTED_LABEL + ">\n"); // error? append(" <" + TAG_ERROR + ">"); if (!Utils.isMissingValue(predValue) && !inst.classIsMissing() && ((int) predValue+1 != (int) inst.classValue()+1)) append(VAL_YES); else append(VAL_NO); append("</" + TAG_ERROR + ">\n"); // prediction/distribution if (m_OutputDistribution) { append(" <" + TAG_DISTRIBUTION + ">\n"); for (int n = 0; n < dist.length; n++) { append(" <" + TAG_CLASS_LABEL + " " + ATT_INDEX + "=\"" + (n+1) + "\""); if (!Utils.isMissingValue(predValue) && (n == (int) predValue)) append(" " + ATT_PREDICTED + "=\"" + VAL_YES + "\""); append(">"); append(Utils.doubleToString(dist[n], prec)); append("</" + TAG_CLASS_LABEL + ">\n"); } append(" </" + TAG_DISTRIBUTION + ">\n"); } else { append(" <" + TAG_PREDICTION + ">"); if (Utils.isMissingValue(predValue)) append("?"); else append(Utils.doubleToString(dist[(int)predValue], prec)); append("</" + TAG_PREDICTION + ">\n"); } } // attributes if (m_Attributes != null) append(attributeValuesString(withMissing)); // closing tag append(" </" + TAG_PREDICTION + ">\n"); } /** * Store the prediction made by the classifier as a string. * * @param classifier the classifier to use * @param inst the instance to generate text from * @param index the index in the dataset * @throws Exception if something goes wrong */ protected void doPrintClassification(Classifier classifier, Instance inst, int index) throws Exception { double[] d = classifier.distributionForInstance(inst); doPrintClassification(d, inst, index); } /** * Does nothing. */ protected void doPrintFooter() { append("</" + TAG_PREDICTIONS + ">\n"); } }