/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * M5PExample.java
 * Copyright (C) 2009 University of Waikato, Hamilton, New Zealand
 *
 */

package wekaexamples.classifiers;

import weka.classifiers.Classifier;
import weka.classifiers.trees.M5P;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SerializationHelper;
import weka.experiment.InstanceQuery;

import java.util.Vector;

/**
 * Example that trains an M5P model on data retrieved from a database,
 * serializes the model together with the (empty) training header, and then
 * reloads both to output predictions for data retrieved from the database.
 * <p>
 * The connection settings and the output file name are placeholders and have
 * to be adapted before the example can be run.
 *
 * @author FracPete (fracpete at waikato dot ac dot nz)
 * @version $Revision$
 */
public class M5PExample {

  /** the file the model and the dataset header get serialized to. */
  public final static String FILENAME = "/some/where/m5pexample.save";

  /** the JDBC URL of the database to load the data from. */
  public final static String URL = "jdbc_url";

  /** the database user. */
  public final static String USER = "the_user";

  /** the password of the database user. */
  public final static String PASSWORD = "the_password";

  /**
   * The 0-based index of the class attribute in the retrieved table. Applied
   * to both the training and the prediction data, so that the two code paths
   * cannot drift apart.
   */
  public final static int CLASS_INDEX = 13;

  /**
   * Retrieves the result of the given query from the database configured via
   * {@link #URL}, {@link #USER} and {@link #PASSWORD} and sets
   * {@link #CLASS_INDEX} as class index.
   *
   * @param sql		the query to execute
   * @return		the retrieved data, with the class index set
   * @throws Exception	if retrieving the data or setting the class index fails
   */
  private Instances loadData(String sql) throws Exception {
    InstanceQuery query = new InstanceQuery();
    query.setDatabaseURL(URL);
    query.setUsername(USER);
    query.setPassword(PASSWORD);
    query.setQuery(sql);

    Instances data = query.retrieveInstances();
    data.setClassIndex(CLASS_INDEX);
    return data;
  }

  /**
   * Builds an instance that is compatible with the training header from a
   * row of the current dataset. Attributes are matched by name; attributes of
   * the header that are not present in the current dataset are not set.
   *
   * @param curr	the row to convert
   * @param data	the dataset that the row belongs to
   * @param header	the header of the data the classifier was trained with
   * @return		the new instance, with the header as its dataset
   * @throws IllegalStateException	if a matched attribute is neither
   * 				nominal nor numeric
   */
  private Instance toHeaderInstance(Instance curr, Instances data, Instances header) {
    Instance inst = new DenseInstance(header.numAttributes());
    inst.setDataset(header);

    for (int n = 0; n < header.numAttributes(); n++) {
      Attribute att = data.attribute(header.attribute(n).name());
      // attribute from the training data is missing in the current dataset
      if (att == null) {
        continue;
      }

      if (att.isNominal()) {
        // is this label also in the original data?
        // Note:
        // "numValues() > 0" is only used to avoid problems with nominal
        // attributes that have 0 labels, which can easily happen with
        // data loaded from a database
        if ((header.attribute(n).numValues() > 0) && (att.numValues() > 0)) {
          String label = curr.stringValue(att);
          int index = header.attribute(n).indexOfValue(label);
          // labels unknown to the training header are not set
          if (index != -1) {
            inst.setValue(n, index);
          }
        }
      }
      else if (att.isNumeric()) {
        inst.setValue(n, curr.value(att));
      }
      else {
        throw new IllegalStateException("Unhandled attribute type!");
      }
    }

    return inst;
  }

  /**
   * Loads the training data from the database, builds an M5P model on it and
   * serializes the model and the dataset header to {@link #FILENAME}.
   *
   * @throws Exception	if loading the data, building the model or writing
   * 			the file fails
   */
  public void train() throws Exception {
    System.out.println("Training...");

    // load training data from database
    Instances data = loadData("select * from some_table");

    // train M5P
    M5P cl = new M5P();
    // further options...
    cl.buildClassifier(data);

    // save model + header (an empty copy of the training data)
    Vector<Object> v = new Vector<Object>();
    v.add(cl);
    v.add(new Instances(data, 0));
    SerializationHelper.write(FILENAME, v);

    System.out.println("Training finished!");
  }

  /**
   * Loads the data to predict from the database, restores the model and the
   * header from {@link #FILENAME} and prints one "actual -> predicted" line
   * per row.
   *
   * @throws Exception	if loading the data, reading the file or making a
   * 			prediction fails
   */
  public void predict() throws Exception {
    System.out.println("Predicting...");

    // load data from database that needs predicting
    // (retrieves the same table only for simplicity reasons)
    Instances data = loadData("select * from some_table");

    // read model and header, stored in this order by train()
    Vector<?> v = (Vector<?>) SerializationHelper.read(FILENAME);
    Classifier cl = (Classifier) v.get(0);
    Instances header = (Instances) v.get(1);

    // output predictions
    System.out.println("actual -> predicted");
    for (int i = 0; i < data.numInstances(); i++) {
      // create an instance for the classifier that fits the training data
      // Instances object returned here might differ slightly from the one
      // used during training the classifier, e.g., different order of
      // nominal values, different number of attributes.
      Instance inst = toHeaderInstance(data.instance(i), data, header);

      // predict class
      double pred = cl.classifyInstance(inst);
      System.out.println(inst.classValue() + " -> " + pred);
    }

    System.out.println("Predicting finished!");
  }

  /**
   * Trains the model and then outputs the predictions.
   *
   * @param args	ignored
   * @throws Exception	if training or predicting fails
   */
  public static void main(String[] args) throws Exception {
    M5PExample m = new M5PExample();
    m.train();
    m.predict();
  }
}