/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * GridSearch.java * Copyright (C) 2006 University of Waikato, Hamilton, New Zealand */ package weka.classifiers.meta; import weka.classifiers.AbstractClassifier; import weka.classifiers.Classifier; import weka.classifiers.Evaluation; import weka.classifiers.RandomizableSingleClassifierEnhancer; import weka.classifiers.functions.LinearRegression; import weka.core.*; import weka.core.Capabilities.Capability; import weka.filters.Filter; import weka.filters.supervised.attribute.PLSFilter; import weka.filters.unsupervised.attribute.MathExpression; import weka.filters.unsupervised.attribute.NumericCleaner; import weka.filters.unsupervised.instance.Resample; import java.beans.PropertyDescriptor; import java.io.File; import java.io.Serializable; import java.util.*; /** <!-- globalinfo-start --> * Performs a grid search of parameter pairs for the a classifier (Y-axis, default is LinearRegression with the "Ridge" parameter) and the PLSFilter (X-axis, "# of Components") and chooses the best pair found for the actual predicting.<br/> * <br/> * The initial grid is worked on with 2-fold CV to determine the values of the parameter pairs for the selected type of evaluation (e.g., accuracy). The best point in the grid is then taken and a 10-fold CV is performed with the adjacent parameter pairs. 
If a better pair is found, then this will act as new center and another 10-fold CV will be performed (kind of hill-climbing). This process is repeated until no better pair is found or the best pair is on the border of the grid.<br/> * In case the best pair is on the border, one can let GridSearch automatically extend the grid and continue the search. Check out the properties 'gridIsExtendable' (option '-extend-grid') and 'maxGridExtensions' (option '-max-grid-extensions <num>').<br/> * <br/> * GridSearch can handle doubles, integers (values are just cast to int) and booleans (0 is false, otherwise true). float, char and long are supported as well.<br/> * <br/> * The best filter/classifier setup can be accessed after the buildClassifier call via the getBestFilter/getBestClassifier methods.<br/> * Note on the implementation: after the data has been passed through the filter, a default NumericCleaner filter is applied to the data in order to avoid numbers that are getting too small and might produce NaNs in other schemes. * <p/> <!-- globalinfo-end --> * <!-- options-start --> * Valid options are: <p/> * * <pre> -E <CC|RMSE|RRSE|MAE|RAE|COMB|ACC|KAP> * Determines the parameter used for evaluation: * CC = Correlation coefficient * RMSE = Root mean squared error * RRSE = Root relative squared error * MAE = Mean absolute error * RAE = Root absolute error * COMB = Combined = (1-abs(CC)) + RRSE + RAE * ACC = Accuracy * KAP = Kappa * (default: CC)</pre> * * <pre> -y-property <option> * The Y option to test (without leading dash). * (default: classifier.ridge)</pre> * * <pre> -y-min <num> * The minimum for Y. * (default: -10)</pre> * * <pre> -y-max <num> * The maximum for Y. * (default: +5)</pre> * * <pre> -y-step <num> * The step size for Y. * (default: 1)</pre> * * <pre> -y-base <num> * The base for Y. * (default: 10)</pre> * * <pre> -y-expression <expr> * The expression for Y. 
* Available parameters: * BASE * FROM * TO * STEP * I - the current iteration value * (from 'FROM' to 'TO' with stepsize 'STEP') * (default: 'pow(BASE,I)')</pre> * * <pre> -filter <filter specification> * The filter to use (on X axis). Full classname of filter to include, * followed by scheme options. * (default: weka.filters.supervised.attribute.PLSFilter)</pre> * * <pre> -x-property <option> * The X option to test (without leading dash). * (default: filter.numComponents)</pre> * * <pre> -x-min <num> * The minimum for X. * (default: +5)</pre> * * <pre> -x-max <num> * The maximum for X. * (default: +20)</pre> * * <pre> -x-step <num> * The step size for X. * (default: 1)</pre> * * <pre> -x-base <num> * The base for X. * (default: 10)</pre> * * <pre> -x-expression <expr> * The expression for the X value. * Available parameters: * BASE * MIN * MAX * STEP * I - the current iteration value * (from 'FROM' to 'TO' with stepsize 'STEP') * (default: 'pow(BASE,I)')</pre> * * <pre> -extend-grid * Whether the grid can be extended. * (default: no)</pre> * * <pre> -max-grid-extensions <num> * The maximum number of grid extensions (-1 is unlimited). * (default: 3)</pre> * * <pre> -sample-size <num> * The size (in percent) of the sample to search the inital grid with. * (default: 100)</pre> * * <pre> -traversal <ROW-WISE|COLUMN-WISE> * The type of traversal for the grid. * (default: COLUMN-WISE)</pre> * * <pre> -log-file <filename> * The log file to log the messages to. * (default: none)</pre> * * <pre> -S <num> * Random number seed. * (default 1)</pre> * * <pre> -D * If set, classifier is run in debug mode and * may output additional info to the console</pre> * * <pre> -W * Full name of base classifier. * (default: weka.classifiers.functions.LinearRegression)</pre> * * <pre> * Options specific to classifier weka.classifiers.functions.LinearRegression: * </pre> * * <pre> -D * Produce debugging output. 
* (default no debugging output)</pre> * * <pre> -S <number of selection method> * Set the attribute selection method to use. 1 = None, 2 = Greedy. * (default 0 = M5' method)</pre> * * <pre> -C * Do not try to eliminate colinear attributes. * </pre> * * <pre> -R <double> * Set ridge parameter (default 1.0e-8). * </pre> * * <pre> * Options specific to filter weka.filters.supervised.attribute.PLSFilter ('-filter'): * </pre> * * <pre> -D * Turns on output of debugging information.</pre> * * <pre> -C <num> * The number of components to compute. * (default: 20)</pre> * * <pre> -U * Updates the class attribute as well. * (default: off)</pre> * * <pre> -M * Turns replacing of missing values on. * (default: off)</pre> * * <pre> -A <SIMPLS|PLS1> * The algorithm to use. * (default: PLS1)</pre> * * <pre> -P <none|center|standardize> * The type of preprocessing that is applied to the data. * (default: center)</pre> * <!-- options-end --> * * Examples: * <ul> * <li> * <b>Optimizing SMO with RBFKernel (C and gamma)</b> * <ul> * <li>Set the evaluation to <i>Accuracy</i>.</li> * <li>Set the filter to <code>weka.filters.AllFilter</code> since we * don't need any special data processing and we don't optimize the * filter in this case (data gets always passed through filter!).</li> * <li>Set <code>weka.classifiers.functions.SMO</code> as classifier * with <code>weka.classifiers.functions.supportVector.RBFKernel</code> * as kernel. * </li> * <li>Set the XProperty to "classifier.c", XMin to "1", XMax to "16", * XStep to "1" and the XExpression to "I". This will test the "C" * parameter of SMO for the values from 1 to 16.</li> * <li>Set the YProperty to "classifier.kernel.gamma", YMin to "-5", * YMax to "2", YStep to "1" YBase to "10" and YExpression to * "pow(BASE,I)". 
This will test the gamma of the RBFKernel with the * values 10^-5, 10^-4,..,10^2.</li> * </ul> * </li> * <li> * <b>Optimizing PLSFilter with LinearRegression (# of components and ridge) - default setup</b> * <ul> * <li>Set the evaluation to <i>Correlation coefficient</i>.</li> * <li>Set the filter to <code>weka.filters.supervised.attribute.PLSFilter</code>.</li> * <li>Set <code>weka.classifiers.functions.LinearRegression</code> as * classifier and use no attribute selection and no elimination of * colinear attributes.</li> * <li>Set the XProperty to "filter.numComponents", XMin to "5", XMax * to "20" (this depends heavily on your dataset, should be no more * than the number of attributes!), XStep to "1" and XExpression to * "I". This will test the number of components the PLSFilter will * produce from 5 to 20.</li> * <li>Set the YProperty to "classifier.ridge", YMin to "-10", YMax to * "5", YStep to "1" and YExpression to "pow(BASE,I)". This will * try ridge parameters from 10^-10 to 10^5.</li> * </ul> * </li> * </ul> * * General notes: * <ul> * <li>Turn the <i>debug</i> flag on in order to see some progress output in the * console</li> * <li>If you want to view the fitness landscape that GridSearch explores, * select a <i>log file</i>. This log will then contain Gnuplot data and * script block for viewing the landscape. 
Just copy paste those blocks * into files named accordingly and run Gnuplot with them.</li> * </ul> * * @author Bernhard Pfahringer (bernhard at cs dot waikato dot ac dot nz) * @author Geoff Holmes (geoff at cs dot waikato dot ac dot nz) * @author fracpete (fracpete at waikato dot ac dot nz) * @version $Revision: 5928 $ * @see PLSFilter * @see LinearRegression * @see NumericCleaner */ public class GridSearch extends RandomizableSingleClassifierEnhancer implements AdditionalMeasureProducer, Summarizable { /** * a serializable version of Point2D.Double * * @see java.awt.geom.Point2D.Double */ protected class PointDouble extends java.awt.geom.Point2D.Double implements Serializable, RevisionHandler { /** for serialization */ private static final long serialVersionUID = 7151661776161898119L; /** * the default constructor * * @param x the x value of the point * @param y the y value of the point */ public PointDouble(double x, double y) { super(x, y); } /** * Determines whether or not two points are equal. * * @param obj an object to be compared with this PointDouble * @return true if the object to be compared has the same values; * false otherwise. */ public boolean equals(Object obj) { PointDouble pd; pd = (PointDouble) obj; return (Utils.eq(this.getX(), pd.getX()) && Utils.eq(this.getY(), pd.getY())); } /** * returns a string representation of the Point * * @return the point as string */ public String toString() { return super.toString().replaceAll(".*\\[", "["); } /** * Returns the revision string. 
* * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 5928 $"); } } /** * a serializable version of Point * * @see java.awt.Point */ protected class PointInt extends java.awt.Point implements Serializable, RevisionHandler { /** for serialization */ private static final long serialVersionUID = -5900415163698021618L; /** * the default constructor * * @param x the x value of the point * @param y the y value of the point */ public PointInt(int x, int y) { super(x, y); } /** * returns a string representation of the Point * * @return the point as string */ public String toString() { return super.toString().replaceAll(".*\\[", "["); } /** * Returns the revision string. * * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 5928 $"); } } /** * for generating the parameter pairs in a grid */ protected class Grid implements Serializable, RevisionHandler { /** for serialization */ private static final long serialVersionUID = 7290732613611243139L; /** the minimum on the X axis */ protected double m_MinX; /** the maximum on the X axis */ protected double m_MaxX; /** the step size for the X axis */ protected double m_StepX; /** the label for the X axis */ protected String m_LabelX; /** the minimum on the Y axis */ protected double m_MinY; /** the maximum on the Y axis */ protected double m_MaxY; /** the step size for the Y axis */ protected double m_StepY; /** the label for the Y axis */ protected String m_LabelY; /** the number of points on the X axis */ protected int m_Width; /** the number of points on the Y axis */ protected int m_Height; /** * initializes the grid * * @param minX the minimum on the X axis * @param maxX the maximum on the X axis * @param stepX the step size for the X axis * @param minY the minimum on the Y axis * @param maxY the maximum on the Y axis * @param stepY the step size for the Y axis */ public Grid(double minX, double maxX, double stepX, double minY, double 
maxY, double stepY) { this(minX, maxX, stepX, "", minY, maxY, stepY, ""); } /** * initializes the grid * * @param minX the minimum on the X axis * @param maxX the maximum on the X axis * @param stepX the step size for the X axis * @param labelX the label for the X axis * @param minY the minimum on the Y axis * @param maxY the maximum on the Y axis * @param stepY the step size for the Y axis * @param labelY the label for the Y axis */ public Grid(double minX, double maxX, double stepX, String labelX, double minY, double maxY, double stepY, String labelY) { super(); m_MinX = minX; m_MaxX = maxX; m_StepX = stepX; m_LabelX = labelX; m_MinY = minY; m_MaxY = maxY; m_StepY = stepY; m_LabelY = labelY; m_Height = (int) StrictMath.round((m_MaxY - m_MinY) / m_StepY) + 1; m_Width = (int) StrictMath.round((m_MaxX - m_MinX) / m_StepX) + 1; // is min < max? if (m_MinX >= m_MaxX) throw new IllegalArgumentException("XMin must be smaller than XMax!"); if (m_MinY >= m_MaxY) throw new IllegalArgumentException("YMin must be smaller than YMax!"); // steps positive? if (m_StepX <= 0) throw new IllegalArgumentException("XStep must be a positive number!"); if (m_StepY <= 0) throw new IllegalArgumentException("YStep must be a positive number!"); // check borders if (!Utils.eq(m_MinX + (m_Width-1)*m_StepX, m_MaxX)) throw new IllegalArgumentException( "X axis doesn't match! Provided max: " + m_MaxX + ", calculated max via min and step size: " + (m_MinX + (m_Width-1)*m_StepX)); if (!Utils.eq(m_MinY + (m_Height-1)*m_StepY, m_MaxY)) throw new IllegalArgumentException( "Y axis doesn't match! 
Provided max: " + m_MaxY + ", calculated max via min and step size: " + (m_MinY + (m_Height-1)*m_StepY)); } /** * Tests itself against the provided grid object * * @param o the grid object to compare against * @return if the two grids have the same setup */ public boolean equals(Object o) { boolean result; Grid g; g = (Grid) o; result = (width() == g.width()) && (height() == g.height()) && (getMinX() == g.getMinX()) && (getMinY() == g.getMinY()) && (getStepX() == g.getStepX()) && (getStepY() == g.getStepY()) && getLabelX().equals(g.getLabelX()) && getLabelY().equals(g.getLabelY()); return result; } /** * returns the left border * * @return the left border */ public double getMinX() { return m_MinX; } /** * returns the right border * * @return the right border */ public double getMaxX() { return m_MaxX; } /** * returns the step size on the X axis * * @return the step size */ public double getStepX() { return m_StepX; } /** * returns the label for the X axis * * @return the label */ public String getLabelX() { return m_LabelX; } /** * returns the bottom border * * @return the bottom border */ public double getMinY() { return m_MinY; } /** * returns the top border * * @return the top border */ public double getMaxY() { return m_MaxY; } /** * returns the step size on the Y axis * * @return the step size */ public double getStepY() { return m_StepY; } /** * returns the label for the Y axis * * @return the label */ public String getLabelY() { return m_LabelY; } /** * returns the number of points in the grid on the Y axis (incl. borders) * * @return the number of points in the grid on the Y axis */ public int height() { return m_Height; } /** * returns the number of points in the grid on the X axis (incl. 
borders) * * @return the number of points in the grid on the X axis */ public int width() { return m_Width; } /** * returns the values at the given point in the grid * * @param x the x-th point on the X axis * @param y the y-th point on the Y axis * @return the value pair at the given position */ public PointDouble getValues(int x, int y) { if (x >= width()) throw new IllegalArgumentException("Index out of scope on X axis (" + x + " >= " + width() + ")!"); if (y >= height()) throw new IllegalArgumentException("Index out of scope on Y axis (" + y + " >= " + height() + ")!"); return new PointDouble(m_MinX + m_StepX*x, m_MinY + m_StepY*y); } /** * returns the closest index pair for the given value pair in the grid. * * @param values the values to get the indices for * @return the closest indices in the grid */ public PointInt getLocation(PointDouble values) { PointInt result; int x; int y; double distance; double currDistance; int i; // determine x x = 0; distance = m_StepX; for (i = 0; i < width(); i++) { currDistance = StrictMath.abs(values.getX() - getValues(i, 0).getX()); if (Utils.sm(currDistance, distance)) { distance = currDistance; x = i; } } // determine y y = 0; distance = m_StepY; for (i = 0; i < height(); i++) { currDistance = StrictMath.abs(values.getY() - getValues(0, i).getY()); if (Utils.sm(currDistance, distance)) { distance = currDistance; y = i; } } result = new PointInt(x, y); return result; } /** * checks whether the given values are on the border of the grid * * @param values the values to check * @return true if the the values are on the border */ public boolean isOnBorder(PointDouble values) { return isOnBorder(getLocation(values)); } /** * checks whether the given location is on the border of the grid * * @param location the location to check * @return true if the the location is on the border */ public boolean isOnBorder(PointInt location) { if (location.getX() == 0) return true; else if (location.getX() == width() - 1) return true; if 
(location.getY() == 0) return true; else if (location.getY() == height() - 1) return true; else return false; } /** * returns a subgrid with the same step sizes, but different borders * * @param top the top index * @param left the left index * @param bottom the bottom index * @param right the right index * @return the Sub-Grid */ public Grid subgrid(int top, int left, int bottom, int right) { return new Grid( getValues(left, top).getX(), getValues(right, top).getX(), getStepX(), getLabelX(), getValues(left, bottom).getY(), getValues(left, top).getY(), getStepY(), getLabelY()); } /** * returns an extended grid that encompasses the given point (won't be on * the border of the grid). * * @param values the point that the grid should contain * @return the extended grid */ public Grid extend(PointDouble values) { double minX; double maxX; double minY; double maxY; double distance; Grid result; // left if (Utils.smOrEq(values.getX(), getMinX())) { distance = getMinX() - values.getX(); // exactly on grid point? if (Utils.eq(distance, 0)) minX = getMinX() - getStepX() * (StrictMath.round(distance / getStepX()) + 1); else minX = getMinX() - getStepX() * (StrictMath.round(distance / getStepX())); } else { minX = getMinX(); } // right if (Utils.grOrEq(values.getX(), getMaxX())) { distance = values.getX() - getMaxX(); // exactly on grid point? if (Utils.eq(distance, 0)) maxX = getMaxX() + getStepX() * (StrictMath.round(distance / getStepX()) + 1); else maxX = getMaxX() + getStepX() * (StrictMath.round(distance / getStepX())); } else { maxX = getMaxX(); } // bottom if (Utils.smOrEq(values.getY(), getMinY())) { distance = getMinY() - values.getY(); // exactly on grid point? 
if (Utils.eq(distance, 0)) minY = getMinY() - getStepY() * (StrictMath.round(distance / getStepY()) + 1); else minY = getMinY() - getStepY() * (StrictMath.round(distance / getStepY())); } else { minY = getMinY(); } // top if (Utils.grOrEq(values.getY(), getMaxY())) { distance = values.getY() - getMaxY(); // exactly on grid point? if (Utils.eq(distance, 0)) maxY = getMaxY() + getStepY() * (StrictMath.round(distance / getStepY()) + 1); else maxY = getMaxY() + getStepY() * (StrictMath.round(distance / getStepY())); } else { maxY = getMaxY(); } result = new Grid(minX, maxX, getStepX(), getLabelX(), minY, maxY, getStepY(), getLabelY()); // did the grid really extend? if (equals(result)) throw new IllegalStateException("Grid extension failed!"); return result; } /** * returns an Enumeration over all pairs in the given row * * @param y the row to retrieve * @return an Enumeration over all pairs * @see #getValues(int, int) */ public Enumeration<PointDouble> row(int y) { Vector result; int i; result = new Vector(); for (i = 0; i < width(); i++) result.add(getValues(i, y)); return result.elements(); } /** * returns an Enumeration over all pairs in the given column * * @param x the column to retrieve * @return an Enumeration over all pairs * @see #getValues(int, int) */ public Enumeration<PointDouble> column(int x) { Vector result; int i; result = new Vector(); for (i = 0; i < height(); i++) result.add(getValues(x, i)); return result.elements(); } /** * returns a string representation of the grid * * @return a string representation */ public String toString() { String result; result = "X: " + m_MinX + " - " + m_MaxX + ", Step " + m_StepX; if (m_LabelX.length() != 0) result += " (" + m_LabelX + ")"; result += "\n"; result += "Y: " + m_MinY + " - " + m_MaxY + ", Step " + m_StepY; if (m_LabelY.length() != 0) result += " (" + m_LabelY + ")"; result += "\n"; result += "Dimensions (Rows x Columns): " + height() + " x " + width(); return result; } /** * Returns the revision string. 
* * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 5928 $"); } } /** * A helper class for storing the performance of a values-pair. * Can be sorted with the PerformanceComparator class. * * @see PerformanceComparator */ protected class Performance implements Serializable, RevisionHandler { /** for serialization */ private static final long serialVersionUID = -4374706475277588755L; /** the value pair the classifier was built with */ protected PointDouble m_Values; /** the Correlation coefficient */ protected double m_CC; /** the Root mean squared error */ protected double m_RMSE; /** the Root relative squared error */ protected double m_RRSE; /** the Mean absolute error */ protected double m_MAE; /** the Relative absolute error */ protected double m_RAE; /** the Accuracy */ protected double m_ACC; /** the kappa value */ protected double m_Kappa; /** * initializes the performance container * * @param values the values-pair * @param evaluation the evaluation to extract the performance * measures from * @throws Exception if retrieving of measures fails */ public Performance(PointDouble values, Evaluation evaluation) throws Exception { super(); m_Values = values; m_RMSE = evaluation.rootMeanSquaredError(); m_RRSE = evaluation.rootRelativeSquaredError(); m_MAE = evaluation.meanAbsoluteError(); m_RAE = evaluation.relativeAbsoluteError(); try { m_CC = evaluation.correlationCoefficient(); } catch (Exception e) { m_CC = Double.NaN; } try { m_ACC = evaluation.pctCorrect(); } catch (Exception e) { m_ACC = Double.NaN; } try { m_Kappa = evaluation.kappa(); } catch (Exception e) { m_Kappa = Double.NaN; } } /** * returns the performance measure * * @param evaluation the type of measure to return * @return the performance measure */ public double getPerformance(int evaluation) { double result; result = Double.NaN; switch (evaluation) { case EVALUATION_CC: result = m_CC; break; case EVALUATION_RMSE: result = m_RMSE; break; case 
EVALUATION_RRSE: result = m_RRSE; break; case EVALUATION_MAE: result = m_MAE; break; case EVALUATION_RAE: result = m_RAE; break; case EVALUATION_COMBINED: result = (1 - StrictMath.abs(m_CC)) + m_RRSE + m_RAE; break; case EVALUATION_ACC: result = m_ACC; break; case EVALUATION_KAPPA: result = m_Kappa; break; default: throw new IllegalArgumentException("Evaluation type '" + evaluation + "' not supported!"); } return result; } /** * returns the values-pair for this performance * * @return the values-pair */ public PointDouble getValues() { return m_Values; } /** * returns a string representation of this performance object * * @param evaluation the type of performance to return * @return a string representation */ public String toString(int evaluation) { String result; result = "Performance (" + getValues() + "): " + getPerformance(evaluation) + " (" + new SelectedTag(evaluation, TAGS_EVALUATION) + ")"; return result; } /** * returns a Gnuplot string of this performance object * * @param evaluation the type of performance to return * @return the gnuplot string (x, y, z) */ public String toGnuplot(int evaluation) { String result; result = getValues().getX() + "\t" + getValues().getY() + "\t" + getPerformance(evaluation); return result; } /** * returns a string representation of this performance object * * @return a string representation */ public String toString() { String result; int i; result = "Performance (" + getValues() + "): "; for (i = 0; i < TAGS_EVALUATION.length; i++) { if (i > 0) result += ", "; result += getPerformance(TAGS_EVALUATION[i].getID()) + " (" + new SelectedTag(TAGS_EVALUATION[i].getID(), TAGS_EVALUATION) + ")"; } return result; } /** * Returns the revision string. * * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 5928 $"); } } /** * A concrete Comparator for the Performance class. 
* * @see Performance */ protected class PerformanceComparator implements Comparator<Performance>, Serializable, RevisionHandler { /** for serialization */ private static final long serialVersionUID = 6507592831825393847L; /** the performance measure to use for comparison * @see GridSearch#TAGS_EVALUATION */ protected int m_Evaluation; /** * initializes the comparator with the given performance measure * * @param evaluation the performance measure to use * @see GridSearch#TAGS_EVALUATION */ public PerformanceComparator(int evaluation) { super(); m_Evaluation = evaluation; } /** * returns the performance measure that's used to compare the objects * * @return the performance measure * @see GridSearch#TAGS_EVALUATION */ public int getEvaluation() { return m_Evaluation; } /** * Compares its two arguments for order. Returns a negative integer, * zero, or a positive integer as the first argument is less than, * equal to, or greater than the second. * * @param o1 the first performance * @param o2 the second performance * @return the order */ public int compare(Performance o1, Performance o2) { int result; double p1; double p2; p1 = o1.getPerformance(getEvaluation()); p2 = o2.getPerformance(getEvaluation()); if (Utils.sm(p1, p2)) result = -1; else if (Utils.gr(p1, p2)) result = 1; else result = 0; // only correlation coefficient/accuracy/kappa obey to this order, for the // errors (and the combination of all three), the smaller the number the // better -> hence invert them if ( (getEvaluation() != EVALUATION_CC) && (getEvaluation() != EVALUATION_ACC) && (getEvaluation() != EVALUATION_KAPPA) ) result = -result; return result; } /** * Indicates whether some other object is "equal to" this Comparator. 
* * @param obj the object to compare with * @return true if the same evaluation type is used */ public boolean equals(Object obj) { if (!(obj instanceof PerformanceComparator)) throw new IllegalArgumentException("Must be PerformanceComparator!"); return (m_Evaluation == ((PerformanceComparator) obj).m_Evaluation); } /** * Returns the revision string. * * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 5928 $"); } } /** * Generates a 2-dim array for the performances from a grid for a certain * type. x-min/y-min is in the bottom-left corner, i.e., getTable()[0][0] * returns the performance for the x-min/y-max pair. * <pre> * x-min x-max * |-------------| * - y-max * | * | * - y-min * </pre> */ protected class PerformanceTable implements Serializable, RevisionHandler { /** for serialization */ private static final long serialVersionUID = 5486491313460338379L; /** the corresponding grid */ protected Grid m_Grid; /** the performances */ protected Vector<Performance> m_Performances; /** the type of performance the table was generated for */ protected int m_Type; /** the table with the values */ protected double[][] m_Table; /** the minimum performance */ protected double m_Min; /** the maximum performance */ protected double m_Max; /** * initializes the table * * @param grid the underlying grid * @param performances the performances * @param type the type of performance */ public PerformanceTable(Grid grid, Vector<Performance> performances, int type) { super(); m_Grid = grid; m_Type = type; m_Performances = performances; generate(); } /** * generates the table */ protected void generate() { Performance perf; int i; PointInt location; m_Table = new double[getGrid().height()][getGrid().width()]; m_Min = 0; m_Max = 0; for (i = 0; i < getPerformances().size(); i++) { perf = (Performance) getPerformances().get(i); location = getGrid().getLocation(perf.getValues()); m_Table[getGrid().height() - (int) location.getY() - 1][(int) 
location.getX()] = perf.getPerformance(getType()); // determine min/max if (i == 0) { m_Min = perf.getPerformance(m_Type); m_Max = m_Min; } else { if (perf.getPerformance(m_Type) < m_Min) m_Min = perf.getPerformance(m_Type); if (perf.getPerformance(m_Type) > m_Max) m_Max = perf.getPerformance(m_Type); } } } /** * returns the corresponding grid * * @return the underlying grid */ public Grid getGrid() { return m_Grid; } /** * returns the underlying performances * * @return the underlying performances */ public Vector<Performance> getPerformances() { return m_Performances; } /** * returns the type of performance * * @return the type of performance */ public int getType() { return m_Type; } /** * returns the generated table * * @return the performance table * @see #m_Table * @see #generate() */ public double[][] getTable() { return m_Table; } /** * the minimum performance * * @return the performance */ public double getMin() { return m_Min; } /** * the maximum performance * * @return the performance */ public double getMax() { return m_Max; } /** * returns the table as string * * @return the table as string */ public String toString() { String result; int i; int n; result = "Table (" + new SelectedTag(getType(), TAGS_EVALUATION).getSelectedTag().getReadable() + ") - " + "X: " + getGrid().getLabelX() + ", Y: " + getGrid().getLabelY() + ":\n"; for (i = 0; i < getTable().length; i++) { if (i > 0) result += "\n"; for (n = 0; n < getTable()[i].length; n++) { if (n > 0) result += ","; result += getTable()[i][n]; } } return result; } /** * returns a string containing a gnuplot script+data file * * @return the data in gnuplot format */ public String toGnuplot() { StringBuffer result; Tag type; int i; result = new StringBuffer(); type = new SelectedTag(getType(), TAGS_EVALUATION).getSelectedTag(); result.append("Gnuplot (" + type.getReadable() + "):\n"); result.append("# begin 'gridsearch.data'\n"); result.append("# " + type.getReadable() + "\n"); for (i = 0; i < 
getPerformances().size(); i++) result.append(getPerformances().get(i).toGnuplot(type.getID()) + "\n"); result.append("# end 'gridsearch.data'\n\n"); result.append("# begin 'gridsearch.plot'\n"); result.append("# " + type.getReadable() + "\n"); result.append("set data style lines\n"); result.append("set contour base\n"); result.append("set surface\n"); result.append("set title '" + m_Data.relationName() + "'\n"); result.append("set xrange [" + getGrid().getMinX() + ":" + getGrid().getMaxX() + "]\n"); result.append("set xlabel 'x (" + getFilter().getClass().getName() + ": " + getXProperty() + ")'\n"); result.append("set yrange [" + getGrid().getMinY() + ":" + getGrid().getMaxY() + "]\n"); result.append("set ylabel 'y - (" + getClassifier().getClass().getName() + ": " + getYProperty() + ")'\n"); result.append("set zrange [" + (getMin() - (getMax() - getMin())*0.1) + ":" + (getMax() + (getMax() - getMin())*0.1) + "]\n"); result.append("set zlabel 'z - " + type.getReadable() + "'\n"); result.append("set dgrid3d " + getGrid().height() + "," + getGrid().width() + ",1\n"); result.append("show contour\n"); result.append("splot 'gridsearch.data'\n"); result.append("pause -1\n"); result.append("# end 'gridsearch.plot'"); return result.toString(); } /** * Returns the revision string. * * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 5928 $"); } } /** * Represents a simple cache for performance objects. 
*/
  protected class PerformanceCache
    implements Serializable, RevisionHandler {

    /** for serialization */
    private static final long serialVersionUID = 5838863230451530252L;

    /** the cache for points in the grid that got calculated, keyed by the
     * string generated by {@link #getID(int, PointDouble)} */
    protected Hashtable<String, Performance> m_Cache = new Hashtable<String, Performance>();

    /**
     * Returns the ID string for a cache item: the number of folds and the
     * X and Y value of the point, separated by tabs.
     *
     * @param cv		the number of folds in the cross-validation
     * @param values	the point in the grid
     * @return		the ID string
     */
    protected String getID(int cv, PointDouble values) {
      return cv + "\t" + values.getX() + "\t" + values.getY();
    }

    /**
     * Checks whether the point was already calculated once.
     *
     * @param cv		the number of folds in the cross-validation
     * @param values	the point in the grid
     * @return		true if the value is already cached
     */
    public boolean isCached(int cv, PointDouble values) {
      return (get(cv, values) != null);
    }

    /**
     * Returns a cached performance object, null if not yet in the cache.
     *
     * @param cv		the number of folds in the cross-validation
     * @param values	the point in the grid
     * @return		the cached performance item, null if not in cache
     */
    public Performance get(int cv, PointDouble values) {
      return m_Cache.get(getID(cv, values));
    }

    /**
     * Adds the performance to the cache, replacing any entry stored for the
     * same number of folds and the same grid point.
     *
     * @param cv		the number of folds in the cross-validation
     * @param p		the performance object to store
     */
    public void add(int cv, Performance p) {
      m_Cache.put(getID(cv, p.getValues()), p);
    }

    /**
     * Returns a string representation of the cache.
     *
     * @return		the string representation of the cache
     */
    public String toString() {
      return m_Cache.toString();
    }

    /**
     * Returns the revision string.
     *
     * @return		the revision
     */
    public String getRevision() {
      return RevisionUtils.extract("$Revision: 5928 $");
    }
  }

  /** for serialization */
  private static final long serialVersionUID = -3034773968581595348L;

  /** evaluation via: Correlation coefficient */
  public static final int EVALUATION_CC = 0;
  /** evaluation via: Root mean squared error */
  public static final int EVALUATION_RMSE = 1;
  /** evaluation via: Root relative squared error */
  public static final int EVALUATION_RRSE = 2;
  /** evaluation via: Mean absolute error */
  public static final int EVALUATION_MAE = 3;
  /** evaluation via: Relative absolute error */
  public static final int EVALUATION_RAE = 4;
  /** evaluation via: Combined = (1-abs(CC)) + RRSE + RAE */
  public static final int EVALUATION_COMBINED = 5;
  /** evaluation via: Accuracy */
  public static final int EVALUATION_ACC = 6;
  /** evaluation via: kappa statistic */
  public static final int EVALUATION_KAPPA = 7;
  /** evaluation */
  public static final Tag[] TAGS_EVALUATION = {
    new Tag(EVALUATION_CC, "CC", "Correlation coefficient"),
    new Tag(EVALUATION_RMSE, "RMSE", "Root mean squared error"),
    new Tag(EVALUATION_RRSE, "RRSE", "Root relative squared error"),
    new Tag(EVALUATION_MAE, "MAE", "Mean absolute error"),
    // RAE is the relative (not "root") absolute error
    new Tag(EVALUATION_RAE, "RAE", "Relative absolute error"),
    new Tag(EVALUATION_COMBINED, "COMB", "Combined = (1-abs(CC)) + RRSE + RAE"),
    new Tag(EVALUATION_ACC, "ACC", "Accuracy"),
    new Tag(EVALUATION_KAPPA, "KAP", "Kappa")
  };

  /** row-wise grid traversal */
  public static final int TRAVERSAL_BY_ROW = 0;
  /** column-wise grid traversal */
  public static final int TRAVERSAL_BY_COLUMN = 1;
  /** traversal */
  public static final Tag[] TAGS_TRAVERSAL = {
    new Tag(TRAVERSAL_BY_ROW, "row-wise", "row-wise"),
    new Tag(TRAVERSAL_BY_COLUMN, "column-wise", "column-wise")
  };

  /** the prefix to indicate that the option is for the classifier */
  public final static String PREFIX_CLASSIFIER = "classifier.";

  /** the prefix to indicate that the option is for the filter */
  public final static String PREFIX_FILTER = "filter.";

  /** the Filter */
  protected Filter m_Filter;

  /** the Filter with the best setup */
  protected Filter m_BestFilter;

  /** the Classifier with the best setup */
  protected Classifier m_BestClassifier;

  /** the best values */
  protected PointDouble m_Values = null;

  /** the type of evaluation */
  protected int m_Evaluation = EVALUATION_CC;

  /** the Y option to work on (without leading dash, preceding 'classifier.'
   * means to set the option for the classifier 'filter.' for the filter) */
  protected String m_Y_Property = PREFIX_CLASSIFIER + "ridge";

  /** the minimum of Y */
  protected double m_Y_Min = -10;

  /** the maximum of Y */
  protected double m_Y_Max = +5;

  /** the step size of Y */
  protected double m_Y_Step = 1;

  /** the base for Y */
  protected double m_Y_Base = 10;

  /**
   * The expression for the Y property. Available parameters for the
   * expression:
   * <ul>
   *   <li>BASE</li>
   *   <li>FROM (= min)</li>
   *   <li>TO (= max)</li>
   *   <li>STEP</li>
   *   <li>I - the current value (from 'from' to 'to' with stepsize 'step')</li>
   * </ul>
   *
   * @see MathematicalExpression
   * @see MathExpression
   */
  protected String m_Y_Expression = "pow(BASE,I)";

  /** the X option to work on (without leading dash, preceding 'classifier.'
   * means to set the option for the classifier 'filter.' for the filter) */
  protected String m_X_Property = PREFIX_FILTER + "numComponents";

  /** the minimum of X */
  protected double m_X_Min = +5;

  /** the maximum of X */
  protected double m_X_Max = +20;

  /** the step size of X */
  protected double m_X_Step = 1;

  /** the base for X */
  protected double m_X_Base = 10;

  /**
   * The expression for the X property.
Available parameters for the * expression: * <ul> * <li>BASE</li> * <li>FROM (= min)</li> * <li>TO (= max)</li> * <li>STEP</li> * <li>I - the current value (from 'from' to 'to' with stepsize 'step')</li> * </ul> * * @see MathematicalExpression * @see MathExpression */ protected String m_X_Expression = "I"; /** whether the grid can be extended */ protected boolean m_GridIsExtendable = false; /** maximum number of grid extensions (-1 means unlimited) */ protected int m_MaxGridExtensions = 3; /** the number of extensions performed */ protected int m_GridExtensionsPerformed = 0; /** the sample size to search the initial grid with */ protected double m_SampleSize = 100; /** the traversal */ protected int m_Traversal = TRAVERSAL_BY_COLUMN; /** the log file to use */ protected File m_LogFile = new File(System.getProperty("user.dir")); /** the value-pairs grid */ protected Grid m_Grid; /** the training data */ protected Instances m_Data; /** the cache for points in the grid that got calculated */ protected PerformanceCache m_Cache; /** whether all performances in the grid are the same */ protected boolean m_UniformPerformance = false; /** * the default constructor */ public GridSearch() { super(); // classifier m_Classifier = new LinearRegression(); ((LinearRegression) m_Classifier).setAttributeSelectionMethod(new SelectedTag(LinearRegression.SELECTION_NONE, LinearRegression.TAGS_SELECTION)); ((LinearRegression) m_Classifier).setEliminateColinearAttributes(false); // filter m_Filter = new PLSFilter(); PLSFilter filter = new PLSFilter(); filter.setPreprocessing(new SelectedTag(PLSFilter.PREPROCESSING_STANDARDIZE, PLSFilter.TAGS_PREPROCESSING)); filter.setReplaceMissing(true); try { m_BestClassifier = AbstractClassifier.makeCopy(m_Classifier); } catch (Exception e) { e.printStackTrace(); } try { m_BestFilter = Filter.makeCopy(filter); } catch (Exception e) { e.printStackTrace(); } } /** * Returns a string describing classifier * * @return a description suitable for 
displaying in the * explorer/experimenter gui */ public String globalInfo() { return "Performs a grid search of parameter pairs for the a classifier " + "(Y-axis, default is LinearRegression with the \"Ridge\" parameter) " + "and the PLSFilter (X-axis, \"# of Components\") and chooses the best " + "pair found for the actual predicting.\n\n" + "The initial grid is worked on with 2-fold CV to determine the values " + "of the parameter pairs for the selected type of evaluation (e.g., " + "accuracy). The best point in the grid is then taken and a 10-fold CV " + "is performed with the adjacent parameter pairs. If a better pair is " + "found, then this will act as new center and another 10-fold CV will " + "be performed (kind of hill-climbing). This process is repeated until " + "no better pair is found or the best pair is on the border of the grid.\n" + "In case the best pair is on the border, one can let GridSearch " + "automatically extend the grid and continue the search. Check out the " + "properties 'gridIsExtendable' (option '-extend-grid') and " + "'maxGridExtensions' (option '-max-grid-extensions <num>').\n\n" + "GridSearch can handle doubles, integers (values are just cast to int) " + "and booleans (0 is false, otherwise true). float, char and long are " + "supported as well.\n\n" + "The best filter/classifier setup can be accessed after the buildClassifier " + "call via the getBestFilter/getBestClassifier methods.\n" + "Note on the implementation: after the data has been passed through " + "the filter, a default NumericCleaner filter is applied to the data in " + "order to avoid numbers that are getting too small and might produce " + "NaNs in other schemes."; } /** * String describing default classifier. * * @return the classname of the default classifier */ protected String defaultClassifierString() { return LinearRegression.class.getName(); } /** * Gets an enumeration describing the available options. * * @return an enumeration of all the available options. 
*/ public Enumeration listOptions(){ Vector result; Enumeration en; String desc; SelectedTag tag; int i; result = new Vector(); desc = ""; for (i = 0; i < TAGS_EVALUATION.length; i++) { tag = new SelectedTag(TAGS_EVALUATION[i].getID(), TAGS_EVALUATION); desc += "\t" + tag.getSelectedTag().getIDStr() + " = " + tag.getSelectedTag().getReadable() + "\n"; } result.addElement(new Option( "\tDetermines the parameter used for evaluation:\n" + desc + "\t(default: " + new SelectedTag(EVALUATION_CC, TAGS_EVALUATION) + ")", "E", 1, "-E " + Tag.toOptionList(TAGS_EVALUATION))); result.addElement(new Option( "\tThe Y option to test (without leading dash).\n" + "\t(default: " + PREFIX_CLASSIFIER + "ridge)", "y-property", 1, "-y-property <option>")); result.addElement(new Option( "\tThe minimum for Y.\n" + "\t(default: -10)", "y-min", 1, "-y-min <num>")); result.addElement(new Option( "\tThe maximum for Y.\n" + "\t(default: +5)", "y-max", 1, "-y-max <num>")); result.addElement(new Option( "\tThe step size for Y.\n" + "\t(default: 1)", "y-step", 1, "-y-step <num>")); result.addElement(new Option( "\tThe base for Y.\n" + "\t(default: 10)", "y-base", 1, "-y-base <num>")); result.addElement(new Option( "\tThe expression for Y.\n" + "\tAvailable parameters:\n" + "\t\tBASE\n" + "\t\tFROM\n" + "\t\tTO\n" + "\t\tSTEP\n" + "\t\tI - the current iteration value\n" + "\t\t(from 'FROM' to 'TO' with stepsize 'STEP')\n" + "\t(default: 'pow(BASE,I)')", "y-expression", 1, "-y-expression <expr>")); result.addElement(new Option( "\tThe filter to use (on X axis). 
Full classname of filter to include, \n" + "\tfollowed by scheme options.\n" + "\t(default: weka.filters.supervised.attribute.PLSFilter)", "filter", 1, "-filter <filter specification>")); result.addElement(new Option( "\tThe X option to test (without leading dash).\n" + "\t(default: " + PREFIX_FILTER + "numComponents)", "x-property", 1, "-x-property <option>")); result.addElement(new Option( "\tThe minimum for X.\n" + "\t(default: +5)", "x-min", 1, "-x-min <num>")); result.addElement(new Option( "\tThe maximum for X.\n" + "\t(default: +20)", "x-max", 1, "-x-max <num>")); result.addElement(new Option( "\tThe step size for X.\n" + "\t(default: 1)", "x-step", 1, "-x-step <num>")); result.addElement(new Option( "\tThe base for X.\n" + "\t(default: 10)", "x-base", 1, "-x-base <num>")); result.addElement(new Option( "\tThe expression for the X value.\n" + "\tAvailable parameters:\n" + "\t\tBASE\n" + "\t\tMIN\n" + "\t\tMAX\n" + "\t\tSTEP\n" + "\t\tI - the current iteration value\n" + "\t\t(from 'FROM' to 'TO' with stepsize 'STEP')\n" + "\t(default: 'pow(BASE,I)')", "x-expression", 1, "-x-expression <expr>")); result.addElement(new Option( "\tWhether the grid can be extended.\n" + "\t(default: no)", "extend-grid", 0, "-extend-grid")); result.addElement(new Option( "\tThe maximum number of grid extensions (-1 is unlimited).\n" + "\t(default: 3)", "max-grid-extensions", 1, "-max-grid-extensions <num>")); result.addElement(new Option( "\tThe size (in percent) of the sample to search the inital grid with.\n" + "\t(default: 100)", "sample-size", 1, "-sample-size <num>")); result.addElement(new Option( "\tThe type of traversal for the grid.\n" + "\t(default: " + new SelectedTag(TRAVERSAL_BY_COLUMN, TAGS_TRAVERSAL) + ")", "traversal", 1, "-traversal " + Tag.toOptionList(TAGS_TRAVERSAL))); result.addElement(new Option( "\tThe log file to log the messages to.\n" + "\t(default: none)", "log-file", 1, "-log-file <filename>")); en = super.listOptions(); while (en.hasMoreElements()) 
result.addElement(en.nextElement()); if (getFilter() instanceof OptionHandler) { result.addElement(new Option( "", "", 0, "\nOptions specific to filter " + getFilter().getClass().getName() + " ('-filter'):")); en = ((OptionHandler) getFilter()).listOptions(); while (en.hasMoreElements()) result.addElement(en.nextElement()); } return result.elements(); } /** * returns the options of the current setup * * @return the current options */ public String[] getOptions(){ int i; Vector result; String[] options; result = new Vector(); result.add("-E"); result.add("" + getEvaluation()); result.add("-y-property"); result.add("" + getYProperty()); result.add("-y-min"); result.add("" + getYMin()); result.add("-y-max"); result.add("" + getYMax()); result.add("-y-step"); result.add("" + getYStep()); result.add("-y-base"); result.add("" + getYBase()); result.add("-y-expression"); result.add("" + getYExpression()); result.add("-filter"); if (getFilter() instanceof OptionHandler) result.add( getFilter().getClass().getName() + " " + Utils.joinOptions(((OptionHandler) getFilter()).getOptions())); else result.add( getFilter().getClass().getName()); result.add("-x-property"); result.add("" + getXProperty()); result.add("-x-min"); result.add("" + getXMin()); result.add("-x-max"); result.add("" + getXMax()); result.add("-x-step"); result.add("" + getXStep()); result.add("-x-base"); result.add("" + getXBase()); result.add("-x-expression"); result.add("" + getXExpression()); if (getGridIsExtendable()) { result.add("-extend-grid"); result.add("-max-grid-extensions"); result.add("" + getMaxGridExtensions()); } result.add("-sample-size"); result.add("" + getSampleSizePercent()); result.add("-traversal"); result.add("" + getTraversal()); result.add("-log-file"); result.add("" + getLogFile()); options = super.getOptions(); for (i = 0; i < options.length; i++) result.add(options[i]); return (String[]) result.toArray(new String[result.size()]); } /** * Parses the options for this object. 
<p/> * <!-- options-start --> * Valid options are: <p/> * * <pre> -E <CC|RMSE|RRSE|MAE|RAE|COMB|ACC|KAP> * Determines the parameter used for evaluation: * CC = Correlation coefficient * RMSE = Root mean squared error * RRSE = Root relative squared error * MAE = Mean absolute error * RAE = Root absolute error * COMB = Combined = (1-abs(CC)) + RRSE + RAE * ACC = Accuracy * KAP = Kappa * (default: CC)</pre> * * <pre> -y-property <option> * The Y option to test (without leading dash). * (default: classifier.ridge)</pre> * * <pre> -y-min <num> * The minimum for Y. * (default: -10)</pre> * * <pre> -y-max <num> * The maximum for Y. * (default: +5)</pre> * * <pre> -y-step <num> * The step size for Y. * (default: 1)</pre> * * <pre> -y-base <num> * The base for Y. * (default: 10)</pre> * * <pre> -y-expression <expr> * The expression for Y. * Available parameters: * BASE * FROM * TO * STEP * I - the current iteration value * (from 'FROM' to 'TO' with stepsize 'STEP') * (default: 'pow(BASE,I)')</pre> * * <pre> -filter <filter specification> * The filter to use (on X axis). Full classname of filter to include, * followed by scheme options. * (default: weka.filters.supervised.attribute.PLSFilter)</pre> * * <pre> -x-property <option> * The X option to test (without leading dash). * (default: filter.numComponents)</pre> * * <pre> -x-min <num> * The minimum for X. * (default: +5)</pre> * * <pre> -x-max <num> * The maximum for X. * (default: +20)</pre> * * <pre> -x-step <num> * The step size for X. * (default: 1)</pre> * * <pre> -x-base <num> * The base for X. * (default: 10)</pre> * * <pre> -x-expression <expr> * The expression for the X value. * Available parameters: * BASE * MIN * MAX * STEP * I - the current iteration value * (from 'FROM' to 'TO' with stepsize 'STEP') * (default: 'pow(BASE,I)')</pre> * * <pre> -extend-grid * Whether the grid can be extended. * (default: no)</pre> * * <pre> -max-grid-extensions <num> * The maximum number of grid extensions (-1 is unlimited). 
* (default: 3)</pre> * * <pre> -sample-size <num> * The size (in percent) of the sample to search the inital grid with. * (default: 100)</pre> * * <pre> -traversal <ROW-WISE|COLUMN-WISE> * The type of traversal for the grid. * (default: COLUMN-WISE)</pre> * * <pre> -log-file <filename> * The log file to log the messages to. * (default: none)</pre> * * <pre> -S <num> * Random number seed. * (default 1)</pre> * * <pre> -D * If set, classifier is run in debug mode and * may output additional info to the console</pre> * * <pre> -W * Full name of base classifier. * (default: weka.classifiers.functions.LinearRegression)</pre> * * <pre> * Options specific to classifier weka.classifiers.functions.LinearRegression: * </pre> * * <pre> -D * Produce debugging output. * (default no debugging output)</pre> * * <pre> -S <number of selection method> * Set the attribute selection method to use. 1 = None, 2 = Greedy. * (default 0 = M5' method)</pre> * * <pre> -C * Do not try to eliminate colinear attributes. * </pre> * * <pre> -R <double> * Set ridge parameter (default 1.0e-8). * </pre> * * <pre> * Options specific to filter weka.filters.supervised.attribute.PLSFilter ('-filter'): * </pre> * * <pre> -D * Turns on output of debugging information.</pre> * * <pre> -C <num> * The number of components to compute. * (default: 20)</pre> * * <pre> -U * Updates the class attribute as well. * (default: off)</pre> * * <pre> -M * Turns replacing of missing values on. * (default: off)</pre> * * <pre> -A <SIMPLS|PLS1> * The algorithm to use. * (default: PLS1)</pre> * * <pre> -P <none|center|standardize> * The type of preprocessing that is applied to the data. 
* (default: center)</pre> * <!-- options-end --> * * @param options the options to use * @throws Exception if setting of options fails */ public void setOptions(String[] options) throws Exception { String tmpStr; String[] tmpOptions; tmpStr = Utils.getOption('E', options); if (tmpStr.length() != 0) setEvaluation(new SelectedTag(tmpStr, TAGS_EVALUATION)); else setEvaluation(new SelectedTag(EVALUATION_CC, TAGS_EVALUATION)); tmpStr = Utils.getOption("y-property", options); if (tmpStr.length() != 0) setYProperty(tmpStr); else setYProperty(PREFIX_CLASSIFIER + "ridge"); tmpStr = Utils.getOption("y-min", options); if (tmpStr.length() != 0) setYMin(Double.parseDouble(tmpStr)); else setYMin(-10); tmpStr = Utils.getOption("y-max", options); if (tmpStr.length() != 0) setYMax(Double.parseDouble(tmpStr)); else setYMax(10); tmpStr = Utils.getOption("y-step", options); if (tmpStr.length() != 0) setYStep(Double.parseDouble(tmpStr)); else setYStep(1); tmpStr = Utils.getOption("y-base", options); if (tmpStr.length() != 0) setYBase(Double.parseDouble(tmpStr)); else setYBase(10); tmpStr = Utils.getOption("y-expression", options); if (tmpStr.length() != 0) setYExpression(tmpStr); else setYExpression("pow(BASE,I)"); tmpStr = Utils.getOption("filter", options); tmpOptions = Utils.splitOptions(tmpStr); if (tmpOptions.length != 0) { tmpStr = tmpOptions[0]; tmpOptions[0] = ""; setFilter((Filter) Utils.forName(Filter.class, tmpStr, tmpOptions)); } tmpStr = Utils.getOption("x-property", options); if (tmpStr.length() != 0) setXProperty(tmpStr); else setXProperty(PREFIX_FILTER + "filters[0].kernel.gamma"); tmpStr = Utils.getOption("x-min", options); if (tmpStr.length() != 0) setXMin(Double.parseDouble(tmpStr)); else setXMin(-10); tmpStr = Utils.getOption("x-max", options); if (tmpStr.length() != 0) setXMax(Double.parseDouble(tmpStr)); else setXMax(10); tmpStr = Utils.getOption("x-step", options); if (tmpStr.length() != 0) setXStep(Double.parseDouble(tmpStr)); else setXStep(1); tmpStr = 
Utils.getOption("x-base", options); if (tmpStr.length() != 0) setXBase(Double.parseDouble(tmpStr)); else setXBase(10); tmpStr = Utils.getOption("x-expression", options); if (tmpStr.length() != 0) setXExpression(tmpStr); else setXExpression("pow(BASE,I)"); setGridIsExtendable(Utils.getFlag("extend-grid", options)); if (getGridIsExtendable()) { tmpStr = Utils.getOption("max-grid-extensions", options); if (tmpStr.length() != 0) setMaxGridExtensions(Integer.parseInt(tmpStr)); else setMaxGridExtensions(3); } tmpStr = Utils.getOption("sample-size", options); if (tmpStr.length() != 0) setSampleSizePercent(Double.parseDouble(tmpStr)); else setSampleSizePercent(100); tmpStr = Utils.getOption("traversal", options); if (tmpStr.length() != 0) setTraversal(new SelectedTag(tmpStr, TAGS_TRAVERSAL)); else setTraversal(new SelectedTag(TRAVERSAL_BY_ROW, TAGS_TRAVERSAL)); tmpStr = Utils.getOption("log-file", options); if (tmpStr.length() != 0) setLogFile(new File(tmpStr)); else setLogFile(new File(System.getProperty("user.dir"))); super.setOptions(options); } /** * Set the base learner. * * @param newClassifier the classifier to use. 
*/ public void setClassifier(Classifier newClassifier) { boolean numeric; boolean nominal; Capabilities cap = newClassifier.getCapabilities(); numeric = cap.handles(Capability.NUMERIC_CLASS) || cap.hasDependency(Capability.NUMERIC_CLASS); nominal = cap.handles(Capability.NOMINAL_CLASS) || cap.hasDependency(Capability.NOMINAL_CLASS) || cap.handles(Capability.BINARY_CLASS) || cap.hasDependency(Capability.BINARY_CLASS) || cap.handles(Capability.UNARY_CLASS) || cap.hasDependency(Capability.UNARY_CLASS); if ((m_Evaluation == EVALUATION_CC) && !numeric) throw new IllegalArgumentException( "Classifier needs to handle numeric class for chosen type of evaluation!"); if (((m_Evaluation == EVALUATION_ACC) || (m_Evaluation == EVALUATION_KAPPA)) && !nominal) throw new IllegalArgumentException( "Classifier needs to handle nominal class for chosen type of evaluation!"); super.setClassifier(newClassifier); try { m_BestClassifier = AbstractClassifier.makeCopy(m_Classifier); } catch (Exception e) { e.printStackTrace(); } } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String filterTipText() { return "The filter to be used (only used for setup)."; } /** * Set the kernel filter (only used for setup). * * @param value the kernel filter. */ public void setFilter(Filter value) { m_Filter = value; try { m_BestFilter = Filter.makeCopy(m_Filter); } catch (Exception e) { e.printStackTrace(); } } /** * Get the kernel filter. * * @return the kernel filter */ public Filter getFilter() { return m_Filter; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String evaluationTipText() { return "Sets the criterion for evaluating the classifier performance and " + "choosing the best one."; } /** * Sets the criterion to use for evaluating the classifier performance. 
* * @param value .the evaluation criterion */ public void setEvaluation(SelectedTag value) { if (value.getTags() == TAGS_EVALUATION) { m_Evaluation = value.getSelectedTag().getID(); } } /** * Gets the criterion used for evaluating the classifier performance. * * @return the current evaluation criterion. */ public SelectedTag getEvaluation() { return new SelectedTag(m_Evaluation, TAGS_EVALUATION); } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String YPropertyTipText() { return "The Y property to test (normally the classifier)."; } /** * Get the Y property (normally the classifier). * * @return Value of the property. */ public String getYProperty() { return m_Y_Property; } /** * Set the Y property (normally the classifier). * * @param value the Y property. */ public void setYProperty(String value) { m_Y_Property = value; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String YMinTipText() { return "The minimum of Y (normally the classifier)."; } /** * Get the value of the minimum of Y. * * @return Value of the minimum of Y. */ public double getYMin() { return m_Y_Min; } /** * Set the value of the minimum of Y. * * @param value Value to use as minimum of Y. */ public void setYMin(double value) { m_Y_Min = value; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String YMaxTipText() { return "The maximum of Y."; } /** * Get the value of the Maximum of Y. * * @return Value of the Maximum of Y. */ public double getYMax() { return m_Y_Max; } /** * Set the value of the Maximum of Y. * * @param value Value to use as Maximum of Y. 
*/ public void setYMax(double value) { m_Y_Max = value; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String YStepTipText() { return "The step size of Y."; } /** * Get the value of the step size for Y. * * @return Value of the step size for Y. */ public double getYStep() { return m_Y_Step; } /** * Set the value of the step size for Y. * * @param value Value to use as the step size for Y. */ public void setYStep(double value) { m_Y_Step = value; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String YBaseTipText() { return "The base of Y."; } /** * Get the value of the base for Y. * * @return Value of the base for Y. */ public double getYBase() { return m_Y_Base; } /** * Set the value of the base for Y. * * @param value Value to use as the base for Y. */ public void setYBase(double value) { m_Y_Base = value; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String YExpressionTipText() { return "The expression for the Y value (parameters: BASE, FROM, TO, STEP, I)."; } /** * Get the expression for the Y value. * * @return Expression for the Y value. */ public String getYExpression() { return m_Y_Expression; } /** * Set the expression for the Y value. * * @param value Expression for the Y value. */ public void setYExpression(String value) { m_Y_Expression = value; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String XPropertyTipText() { return "The X property to test (normally the filter)."; } /** * Get the X property to test (normally the filter). * * @return Value of the X property. 
*/ public String getXProperty() { return m_X_Property; } /** * Set the X property. * * @param value the X property. */ public void setXProperty(String value) { m_X_Property = value; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String XMinTipText() { return "The minimum of X."; } /** * Get the value of the minimum of X. * * @return Value of the minimum of X. */ public double getXMin() { return m_X_Min; } /** * Set the value of the minimum of X. * * @param value Value to use as minimum of X. */ public void setXMin(double value) { m_X_Min = value; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String XMaxTipText() { return "The maximum of X."; } /** * Get the value of the Maximum of X. * * @return Value of the Maximum of X. */ public double getXMax() { return m_X_Max; } /** * Set the value of the Maximum of X. * * @param value Value to use as Maximum of X. */ public void setXMax(double value) { m_X_Max = value; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String XStepTipText() { return "The step size of X."; } /** * Get the value of the step size for X. * * @return Value of the step size for X. */ public double getXStep() { return m_X_Step; } /** * Set the value of the step size for X. * * @param value Value to use as the step size for X. */ public void setXStep(double value) { m_X_Step = value; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String XBaseTipText() { return "The base of X."; } /** * Get the value of the base for X. * * @return Value of the base for X. 
*/
  public double getXBase() {
    return this.m_X_Base;
  }

  /**
   * Sets the base for X.
   *
   * @param value 	Value to use as the base for X.
   */
  public void setXBase(double value) {
    this.m_X_Base = value;
  }

  /**
   * Returns the tip text for the expression of X.
   *
   * @return 		tip text for this property suitable for
   * 			displaying in the explorer/experimenter gui
   */
  public String XExpressionTipText() {
    return "The expression for the X value (parameters: BASE, FROM, TO, STEP, I).";
  }

  /**
   * Returns the expression for the X value.
   *
   * @return 		Expression for the X value.
   */
  public String getXExpression() {
    return this.m_X_Expression;
  }

  /**
   * Sets the expression for the X value.
   *
   * @param value 	Expression for the X value.
   */
  public void setXExpression(String value) {
    this.m_X_Expression = value;
  }

  /**
   * Returns the tip text for the "grid is extendable" property.
   *
   * @return 		tip text for this property suitable for
   * 			displaying in the explorer/experimenter gui
   */
  public String gridIsExtendableTipText() {
    return "Whether the grid can be extended.";
  }

  /**
   * Returns whether the grid can be extended dynamically.
   *
   * @return 		true if the grid can be extended.
   */
  public boolean getGridIsExtendable() {
    return this.m_GridIsExtendable;
  }

  /**
   * Sets whether the grid can be extended dynamically.
   *
   * @param value 	whether the grid can be extended dynamically.
   */
  public void setGridIsExtendable(boolean value) {
    this.m_GridIsExtendable = value;
  }

  /**
   * Returns the tip text for the maximum number of grid extensions.
   *
   * @return 		tip text for this property suitable for
   * 			displaying in the explorer/experimenter gui
   */
  public String maxGridExtensionsTipText() {
    return "The maximum number of grid extensions, -1 for unlimited.";
  }

  /**
   * Returns the maximum number of grid extensions, -1 for unlimited.
   *
   * @return 		the max number of grid extensions
   */
  public int getMaxGridExtensions() {
    return this.m_MaxGridExtensions;
  }

  /**
   * Sets the maximum number of grid extensions, -1 for unlimited.
   *
   * @param value 	the maximum of grid extensions.
*/
  public void setMaxGridExtensions(int value) {
    this.m_MaxGridExtensions = value;
  }

  /**
   * Returns the tip text for the sample size.
   *
   * @return 		tip text for this property suitable for
   * 			displaying in the explorer/experimenter gui
   */
  public String sampleSizePercentTipText() {
    return "The sample size (in percent) to use in the initial grid search.";
  }

  /**
   * Returns the sample size for the initial grid search.
   *
   * @return 		the sample size.
   */
  public double getSampleSizePercent() {
    return this.m_SampleSize;
  }

  /**
   * Sets the sample size for the initial grid search.
   *
   * @param value 	the sample size for the initial grid search.
   */
  public void setSampleSizePercent(double value) {
    this.m_SampleSize = value;
  }

  /**
   * Returns the tip text for the traversal type.
   *
   * @return 		tip text for this property suitable for
   * 			displaying in the explorer/experimenter gui
   */
  public String traversalTipText() {
    return "Sets type of traversal of the grid, either by rows or columns.";
  }

  /**
   * Sets the type of traversal for the grid.
   *
   * @param value 	the traversal type; ignored if its tags are not
   * 			TAGS_TRAVERSAL
   */
  public void setTraversal(SelectedTag value) {
    if (value.getTags() != TAGS_TRAVERSAL)
      return;

    this.m_Traversal = value.getSelectedTag().getID();
  }

  /**
   * Returns the type of traversal for the grid, wrapped in a SelectedTag.
   *
   * @return 		the current traversal type.
   */
  public SelectedTag getTraversal() {
    return new SelectedTag(this.m_Traversal, TAGS_TRAVERSAL);
  }

  /**
   * Returns the tip text for the log file.
   *
   * @return 		tip text for this property suitable for
   * 			displaying in the explorer/experimenter gui
   */
  public String logFileTipText() {
    return "The log file to log the messages to.";
  }

  /**
   * Returns the current log file.
   *
   * @return 		the log file.
   */
  public File getLogFile() {
    return this.m_LogFile;
  }

  /**
   * Sets the log file to use.
   *
   * @param value 	the log file.
*/
  public void setLogFile(final File value) {
    this.m_LogFile = value;
  }

  /**
   * Gives access to the filter setup stored as the best one.
   *
   * @return		the best filter setup
   */
  public Filter getBestFilter() {
    return this.m_BestFilter;
  }

  /**
   * Gives access to the classifier setup stored as the best one.
   *
   * @return		the best classifier setup
   */
  public Classifier getBestClassifier() {
    return this.m_BestClassifier;
  }

  /**
   * Lists the names of the additional measures this scheme offers.
   *
   * @return		an enumeration of the measure names
   */
  public Enumeration enumerateMeasures() {
    final Vector<String> names = new Vector<String>();

    names.add("measureX");
    names.add("measureY");
    names.add("measureGridExtensionsPerformed");

    return names.elements();
  }

  /**
   * Looks up the value of one of the additional measures. The name is
   * compared case-insensitively.
   *
   * @param measureName	the name of the measure to query for its value
   * @return		the value of the named measure
   * @throws IllegalArgumentException	if the measure name is not supported
   */
  public double getMeasure(String measureName) {
    if (measureName.equalsIgnoreCase("measureX")) {
      return evaluate(getValues().getX(), true);
    }
    if (measureName.equalsIgnoreCase("measureY")) {
      return evaluate(getValues().getY(), false);
    }
    if (measureName.equalsIgnoreCase("measureGridExtensionsPerformed")) {
      return getGridExtensionsPerformed();
    }

    throw new IllegalArgumentException("Measure '" + measureName + "' not supported!");
  }

  /**
   * Gives access to the parameter pair that was found to work best.
   *
   * @return		the best parameter combination
   */
  public PointDouble getValues() {
    return this.m_Values;
  }

  /**
   * Tells how many grid extensions took place during the search (only
   * applicable if the grid was extendable).
   *
   * @return		the number of grid extensions that were performed
   * @see 		#getGridIsExtendable()
   */
  public int getGridExtensionsPerformed() {
    return this.m_GridExtensionsPerformed;
  }

  /**
   * Returns default capabilities of the classifier.
* * @return the capabilities of this classifier */ public Capabilities getCapabilities() { Capabilities result; Capabilities classes; Iterator iter; Capability capab; if (getFilter() == null) result = super.getCapabilities(); else result = getFilter().getCapabilities(); // only nominal and numeric classes allowed classes = result.getClassCapabilities(); iter = classes.capabilities(); while (iter.hasNext()) { capab = (Capability) iter.next(); if ( (capab != Capability.BINARY_CLASS) && (capab != Capability.NOMINAL_CLASS) && (capab != Capability.NUMERIC_CLASS) && (capab != Capability.DATE_CLASS) ) result.disable(capab); } result.enable(Capability.MISSING_CLASS_VALUES); // set dependencies for (Capability cap: Capability.values()) result.enableDependency(cap); if (result.getMinimumNumberInstances() < 1) result.setMinimumNumberInstances(1); result.setOwner(this); return result; } /** * prints the specified message to stdout if debug is on and can also dump * the message to a log file * * @param message the message to print or store in a log file */ protected void log(String message) { log(message, false); } /** * prints the specified message to stdout if debug is on and can also dump * the message to a log file * * @param message the message to print or store in a log file * @param onlyLog if true the message will only be put into the log file * but not to stdout */ protected void log(String message, boolean onlyLog) { // print to stdout? if (getDebug() && (!onlyLog)) System.out.println(message); // log file? 
if (!getLogFile().isDirectory()) Debug.writeToFile(getLogFile().getAbsolutePath(), message, true); } /** * replaces the current option in the options array with a new value * * @param options the current options * @param option the option to set a new value for * @param value the value to set * @return the updated array * @throws Exception if something goes wrong */ protected String[] updateOption(String[] options, String option, String value) throws Exception { String[] result; Vector tmpOptions; int i; // remove old option Utils.getOption(option, options); // add option with new value at the beginning (to avoid clashes with "--") tmpOptions = new Vector(); tmpOptions.add("-" + option); tmpOptions.add("" + value); // move options into vector for (i = 0; i < options.length; i++) { if (options[i].length() != 0) tmpOptions.add(options[i]); } result = (String[]) tmpOptions.toArray(new String[tmpOptions.size()]); return result; } /** * evalutes the expression for the current iteration * * @param value the current iteration value (from 'min' to 'max' with * stepsize 'step') * @param isX true if X is to be evaluated otherwise Y * @return the generated value, NaN if the evaluation fails */ protected double evaluate(double value, boolean isX) { double result; HashMap symbols; String expr; double base; double min; double max; double step; if (isX) { expr = getXExpression(); base = getXBase(); min = getXMin(); max = getXMax(); step = getXStep(); } else { expr = getYExpression(); base = getYBase(); min = getYMin(); max = getYMax(); step = getYStep(); } try { symbols = new HashMap(); symbols.put("BASE", new Double(base)); symbols.put("FROM", new Double(min)); symbols.put("TO", new Double(max)); symbols.put("STEP", new Double(step)); symbols.put("I", new Double(value)); result = MathematicalExpression.evaluate(expr, symbols); } catch (Exception e) { result = Double.NaN; } return result; } /** * tries to set the value as double, integer (just casts it to int!) 
or * boolean (false if 0, otherwise true) in the object according to the * specified path. float, char and long are also supported. * * @param o the object to modify * @param path the property path * @param value the value to set * @return the modified object * @throws Exception if neither double nor int could be set */ protected Object setValue(Object o, String path, double value) throws Exception { PropertyDescriptor desc; Class c; desc = PropertyPath.getPropertyDescriptor(o, path); c = desc.getPropertyType(); // float if ((c == Float.class) || (c == Float.TYPE)) PropertyPath.setValue(o, path, new Float((float) value)); // double else if ((c == Double.class) || (c == Double.TYPE)) PropertyPath.setValue(o, path, new Double(value)); // char else if ((c == Character.class) || (c == Character.TYPE)) PropertyPath.setValue(o, path, new Integer((char) value)); // int else if ((c == Integer.class) || (c == Integer.TYPE)) PropertyPath.setValue(o, path, new Integer((int) value)); // long else if ((c == Long.class) || (c == Long.TYPE)) PropertyPath.setValue(o, path, new Long((long) value)); // boolean else if ((c == Boolean.class) || (c == Boolean.TYPE)) PropertyPath.setValue(o, path, (value == 0 ? 
new Boolean(false) : new Boolean(true))); else throw new Exception( "Could neither set double nor integer nor boolean value for '" + path + "'!"); return o; } /** * returns a fully configures object (a copy of the provided one) * * @param original the object to create a copy from and set the parameters * @param valueX the current iteration value for X * @param valueY the current iteration value for Y * @return the configured classifier * @throws Exception if setup fails */ protected Object setup(Object original, double valueX, double valueY) throws Exception { Object result; result = new SerializedObject(original).getObject(); if (original instanceof Classifier) { if (getXProperty().startsWith(PREFIX_CLASSIFIER)) setValue( result, getXProperty().substring(PREFIX_CLASSIFIER.length()), valueX); if (getYProperty().startsWith(PREFIX_CLASSIFIER)) setValue( result, getYProperty().substring(PREFIX_CLASSIFIER.length()), valueY); } else if (original instanceof Filter) { if (getXProperty().startsWith(PREFIX_FILTER)) setValue( result, getXProperty().substring(PREFIX_FILTER.length()), valueX); if (getYProperty().startsWith(PREFIX_FILTER)) setValue( result, getYProperty().substring(PREFIX_FILTER.length()), valueY); } else { throw new IllegalArgumentException("Object must be either classifier or filter!"); } return result; } /** * generates a table string for all the performances in the grid and returns * that. 
* * @param grid the current grid to align the performances to * @param performances the performances to align * @param type the type of performance * @return the table string */ protected String logPerformances(Grid grid, Vector<Performance> performances, Tag type) { StringBuffer result; PerformanceTable table; result = new StringBuffer(type.getReadable() + ":\n"); table = new PerformanceTable(grid, performances, type.getID()); result.append(table.toString() + "\n"); result.append("\n"); result.append(table.toGnuplot() + "\n"); result.append("\n"); return result.toString(); } /** * aligns all performances in the grid and prints those tables to the log * file. * * @param grid the current grid to align the performances to * @param performances the performances to align */ protected void logPerformances(Grid grid, Vector performances) { int i; for (i = 0; i < TAGS_EVALUATION.length; i++) log("\n" + logPerformances(grid, performances, TAGS_EVALUATION[i]), true); } /** * determines the best values-pair for the given grid, using CV with * specified number of folds. 
* * @param grid the grid to work on * @param inst the data to work with * @param cv the number of folds for the cross-validation * @return the best values pair * @throws Exception if setup or training fails */ protected PointDouble determineBestInGrid(Grid grid, Instances inst, int cv) throws Exception { int i; Enumeration<PointDouble> enm; Vector<Performance> performances; PointDouble values; Instances data; Evaluation eval; PointDouble result; Classifier classifier; Filter filter; int size; boolean cached; boolean allCached; Performance p1; Performance p2; double x; double y; performances = new Vector(); log("Determining best pair with " + cv + "-fold CV in Grid:\n" + grid + "\n"); if (m_Traversal == TRAVERSAL_BY_COLUMN) size = grid.width(); else size = grid.height(); allCached = true; for (i = 0; i < size; i++) { if (m_Traversal == TRAVERSAL_BY_COLUMN) enm = grid.column(i); else enm = grid.row(i); filter = null; data = null; while (enm.hasMoreElements()) { values = enm.nextElement(); // already calculated? cached = m_Cache.isCached(cv, values); if (cached) { performances.add(m_Cache.get(cv, values)); } else { allCached = false; x = evaluate(values.getX(), true); y = evaluate(values.getY(), false); // data pass through filter if (filter == null) { filter = (Filter) setup(getFilter(), x, y); filter.setInputFormat(inst); data = Filter.useFilter(inst, filter); // make sure that the numbers don't get too small - otherwise NaNs! 
Filter cleaner = new NumericCleaner(); cleaner.setInputFormat(data); data = Filter.useFilter(data, cleaner); } // setup classifier classifier = (Classifier) setup(getClassifier(), x, y); // evaluate eval = new Evaluation(data); eval.crossValidateModel(classifier, data, cv, new Random(getSeed())); performances.add(new Performance(values, eval)); // add to cache m_Cache.add(cv, new Performance(values, eval)); } log("" + performances.get(performances.size() - 1) + ": cached=" + cached); } } if (allCached) { log("All points were already cached - abnormal state!"); throw new IllegalStateException("All points were already cached - abnormal state!"); } // sort list Collections.sort(performances, new PerformanceComparator(m_Evaluation)); result = performances.get(performances.size() - 1).getValues(); // check whether all performances are the same m_UniformPerformance = true; p1 = performances.get(0); for (i = 1; i < performances.size(); i++) { p2 = performances.get(i); if (p2.getPerformance(m_Evaluation) != p1.getPerformance(m_Evaluation)) { m_UniformPerformance = false; break; } } if (m_UniformPerformance) log("All performances are the same!"); logPerformances(grid, performances); log("\nBest performance:\n" + performances.get(performances.size() - 1)); return result; } /** * returns the best values-pair in the grid * * @return the best values pair * @throws Exception if something goes wrong */ protected PointDouble findBest() throws Exception { PointInt center; Grid neighborGrid; boolean finished; PointDouble result; PointDouble resultOld; int iteration; Instances sample; Resample resample; log("Step 1:\n"); // generate sample? 
if (getSampleSizePercent() == 100) { sample = m_Data; } else { log("Generating sample (" + getSampleSizePercent() + "%)"); resample = new Resample(); resample.setRandomSeed(getSeed()); resample.setSampleSizePercent(getSampleSizePercent()); resample.setInputFormat(m_Data); sample = Filter.useFilter(m_Data, resample); } finished = false; iteration = 0; m_GridExtensionsPerformed = 0; m_UniformPerformance = false; // find first center log("\n=== Initial grid - Start ==="); result = determineBestInGrid(m_Grid, sample, 2); log("\nResult of Step 1: " + result + "\n"); log("=== Initial grid - End ===\n"); finished = m_UniformPerformance; if (!finished) { do { iteration++; resultOld = (PointDouble) result.clone(); center = m_Grid.getLocation(result); // on border? -> finished (if it cannot be extended) if (m_Grid.isOnBorder(center)) { log("Center is on border of grid."); // can we extend grid? if (getGridIsExtendable()) { // max number of extensions reached? if (m_GridExtensionsPerformed == getMaxGridExtensions()) { log("Maximum number of extensions reached!\n"); finished = true; } else { m_GridExtensionsPerformed++; m_Grid = m_Grid.extend(result); center = m_Grid.getLocation(result); log("Extending grid (" + m_GridExtensionsPerformed + "/" + getMaxGridExtensions() + "):\n" + m_Grid + "\n"); } } else { finished = true; } } // new grid with current best one at center and immediate neighbors // around it if (!finished) { neighborGrid = m_Grid.subgrid( (int) center.getY() + 1, (int) center.getX() - 1, (int) center.getY() - 1, (int) center.getX() + 1); result = determineBestInGrid(neighborGrid, sample, 10); log("\nResult of Step 2/Iteration " + (iteration) + ":\n" + result); finished = m_UniformPerformance; // no improvement? 
if (result.equals(resultOld)) { finished = true; log("\nNo better point found."); } } } while (!finished); } log("\nFinal result: " + result); return result; } /** * builds the classifier * * @param data the training instances * @throws Exception if something goes wrong */ public void buildClassifier(Instances data) throws Exception { String strX; String strY; double x; double y; // can classifier handle the data? getCapabilities().testWithFail(data); // remove instances with missing class m_Data = new Instances(data); m_Data.deleteWithMissingClass(); m_Cache = new PerformanceCache(); if (getXProperty().startsWith(PREFIX_FILTER)) strX = m_Filter.getClass().getName(); else strX = m_Classifier.getClass().getName(); if (getYProperty().startsWith(PREFIX_CLASSIFIER)) strY = m_Classifier.getClass().getName(); else strY = m_Filter.getClass().getName(); m_Grid = new Grid(getXMin(), getXMax(), getXStep(), strX + ", property " + getXProperty() + ", expr. " + getXExpression() + ", base " + getXBase(), getYMin(), getYMax(), getYStep(), strY + ", property " + getYProperty() + ", expr. " + getYExpression() + ", base " + getYBase()); log("\n" + this.getClass().getName() + "\n" + this.getClass().getName().replaceAll(".", "=") + "\n" + "Options: " + Utils.joinOptions(getOptions()) + "\n"); // find best m_Values = findBest(); // setup best configurations x = evaluate(m_Values.getX(), true); y = evaluate(m_Values.getY(), false); m_BestFilter = (Filter) setup(getFilter(), x, y); m_BestClassifier = (Classifier) setup(getClassifier(), x, y); // process data m_Filter = (Filter) setup(getFilter(), x, y); m_Filter.setInputFormat(m_Data); Instances transformed = Filter.useFilter(m_Data, m_Filter); // train classifier m_Classifier = (Classifier) setup(getClassifier(), x, y); m_Classifier.buildClassifier(transformed); } /** * Classifies the given instance. 
* * @param instance the test instance * @return the classification * @throws Exception if classification can't be done successfully */ public double classifyInstance(Instance instance) throws Exception { // transform instance m_Filter.input(instance); m_Filter.batchFinished(); Instance transformed = m_Filter.output(); // classify instance return m_Classifier.classifyInstance(transformed); } /** * returns a string representation of the classifier * * @return a string representation of the classifier */ public String toString() { String result; result = ""; if (m_Values == null) { result = "No search performed yet."; } else { result = this.getClass().getName() + ":\n" + "Filter: " + getFilter().getClass().getName() + (getFilter() instanceof OptionHandler ? " " + Utils.joinOptions(((OptionHandler) getFilter()).getOptions()) : "") + "\n" + "Classifier: " + getClassifier().getClass().getName() + " " + Utils.joinOptions(((OptionHandler)getClassifier()).getOptions()) + "\n\n" + "X property: " + getXProperty() + "\n" + "Y property: " + getYProperty() + "\n\n" + "Evaluation: " + getEvaluation().getSelectedTag().getReadable() + "\n" + "Coordinates: " + getValues() + "\n"; if (getGridIsExtendable()) result += "Grid-Extensions: " + getGridExtensionsPerformed() + "\n"; result += "Values: " + evaluate(getValues().getX(), true) + " (X coordinate)" + ", " + evaluate(getValues().getY(), false) + " (Y coordinate)" + "\n\n" + m_Classifier.toString(); } return result; } /** * Returns a string that summarizes the object. * * @return the object summarized as a string */ public String toSummaryString() { String result; result = "Best filter: " + getBestFilter().getClass().getName() + (getBestFilter() instanceof OptionHandler ? 
" " + Utils.joinOptions(((OptionHandler) getBestFilter()).getOptions()) : "") + "\n" + "Best classifier: " + getBestClassifier().getClass().getName() + " " + Utils.joinOptions(((OptionHandler)getBestClassifier()).getOptions()); return result; } /** * Returns the revision string. * * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 5928 $"); } /** * Main method for running this classifier from commandline. * * @param args the options */ public static void main(String[] args) { runClassifier(new GridSearch(), args); } }