/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    MessageClassifier.java
 *    Copyright (C) 2009 University of Waikato, Hamilton, New Zealand
 *
 */

package wekaexamples.book;

import weka.classifiers.Classifier;
import weka.classifiers.trees.J48;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SerializationHelper;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.StringToWordVector;

import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;

/**
 * Java program for classifying short text messages into two classes 'miss'
 * and 'hit'.
 * <p/>
 * The object keeps every training message it has been given, and lazily
 * rebuilds the word-vector filter and the classifier the next time a message
 * has to be classified after the training data changed.
 * <p/>
 * See also wiki article <a href="http://weka.wiki.sourceforge.net/MessageClassifier">MessageClassifier</a>.
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision$
 */
public class MessageClassifier
  implements Serializable {

  /** for serialization. */
  private static final long serialVersionUID = -123455813150452885L;

  /** The training data gathered so far. */
  private Instances m_Data = null;

  /** The filter used to generate the word counts. */
  private StringToWordVector m_Filter = new StringToWordVector();

  /** The actual classifier. */
  private Classifier m_Classifier = new J48();

  /**
   * Whether the filter and classifier reflect the current training data.
   * Reset by {@link #updateData(String, String)}, set by
   * {@link #classifyMessage(String)} after a rebuild.
   */
  private boolean m_UpToDate;

  /**
   * Constructs empty training dataset with a string attribute "Message" and
   * a nominal class attribute "Class" (values 'miss' and 'hit').
   */
  public MessageClassifier() {
    String nameOfDataset = "MessageClassificationProblem";

    // Create vector of attributes.
    ArrayList<Attribute> attributes = new ArrayList<Attribute>(2);

    // Add attribute for holding messages; the null value list is what this
    // Attribute constructor uses to declare a string attribute.
    attributes.add(new Attribute("Message", (ArrayList<String>) null));

    // Add class attribute.
    ArrayList<String> classValues = new ArrayList<String>(2);
    classValues.add("miss");
    classValues.add("hit");
    attributes.add(new Attribute("Class", classValues));

    // Create dataset with initial capacity of 100, and set index of class.
    m_Data = new Instances(nameOfDataset, attributes, 100);
    m_Data.setClassIndex(m_Data.numAttributes() - 1);
  }

  /**
   * Adds the given labelled message to the training data and marks the
   * model as out of date. The classifier itself is not rebuilt here.
   *
   * @param message	the message content
   * @param classValue	the class label
   */
  public void updateData(String message, String classValue) {
    // Make message into instance.
    Instance instance = makeInstance(message, m_Data);

    // Set class value for instance.
    instance.setClassValue(classValue);

    // Add instance to training data.
    m_Data.add(instance);

    m_UpToDate = false;
  }

  /**
   * Classifies a given message and prints the predicted class label to
   * stderr. Rebuilds the filter and the classifier first if training data
   * was added since the last build.
   *
   * @param message	the message content
   * @throws Exception 	if there is no training data yet or classification
   * 			fails
   */
  public void classifyMessage(String message) throws Exception {
    // Check whether classifier has been built.
    if (m_Data.numInstances() == 0) {
      throw new Exception("No classifier available.");
    }

    // Check whether classifier and filter are up to date.
    if (!m_UpToDate) {
      // Initialize filter and tell it about the input format.
      m_Filter.setInputFormat(m_Data);

      // Generate word counts from the training data.
      Instances filteredData = Filter.useFilter(m_Data, m_Filter);

      // Rebuild classifier.
      m_Classifier.buildClassifier(filteredData);

      m_UpToDate = true;
    }

    // Make separate little test set so that message
    // does not get added to string attribute in m_Data.
    Instances testset = m_Data.stringFreeStructure();

    // Make message into test instance.
    Instance instance = makeInstance(message, testset);

    // Filter instance.
    m_Filter.input(instance);
    Instance filteredInstance = m_Filter.output();

    // Get index of predicted class value.
    double predicted = m_Classifier.classifyInstance(filteredInstance);

    // Output class value.
    System.err.println("Message classified as : " +
	m_Data.classAttribute().value((int) predicted));
  }

  /**
   * Method that converts a text message into an instance. Only the
   * "Message" attribute is set; the class value is left for the caller.
   *
   * @param text	the message content to convert
   * @param data	the header information
   * @return		the generated Instance
   */
  private Instance makeInstance(String text, Instances data) {
    // Create instance of length two.
    Instance instance = new DenseInstance(2);

    // Set value for message attribute
    Attribute messageAtt = data.attribute("Message");
    instance.setValue(messageAtt, messageAtt.addStringValue(text));

    // Give instance access to attribute information from the dataset.
    instance.setDataset(data);

    return instance;
  }

  /**
   * Reads the complete content of a text file into a string. The reader is
   * closed even when reading fails part-way through.
   *
   * @param fileName	the name of the file to read
   * @return		the content of the file
   * @throws IOException	if the file cannot be opened or read
   */
  private static String readMessage(String fileName) throws IOException {
    FileReader reader = new FileReader(fileName);
    try {
      StringBuilder text = new StringBuilder();
      char[] buffer = new char[4096];
      int count;
      while ((count = reader.read(buffer)) != -1) {
	text.append(buffer, 0, count);
      }
      return text.toString();
    }
    finally {
      // Release the file handle on both the success and the failure path.
      reader.close();
    }
  }

  /**
   * Main method. The following parameters are recognized:
   * <ul>
   *   <li>
   *      <code>-m messagefile</code><br/>
   *      Points to the file containing the message to classify or use for
   *      updating the model.
   *   </li>
   *   <li>
   *      <code>-c classlabel</code><br/>
   *      The class label of the message if model is to be updated. Omit for
   *      classification of a message.
   *   </li>
   *   <li>
   *      <code>-t modelfile</code><br/>
   *      The file containing the model. If it doesn't exist, it will be
   *      created automatically.
   *   </li>
   * </ul>
   * Any exception is caught and its stack trace printed.
   *
   * @param args	the commandline options
   */
  public static void main(String[] args) {
    try {
      // Read message file into string.
      String messageName = Utils.getOption('m', args);
      if (messageName.length() == 0) {
	throw new Exception("Must provide name of message file ('-m <file>').");
      }
      String message = readMessage(messageName);

      // Check if class value is given.
      String classValue = Utils.getOption('c', args);
      boolean updateModel = (classValue.length() != 0);

      // If model file exists, read it, otherwise create new one.
      String modelName = Utils.getOption('t', args);
      if (modelName.length() == 0) {
	throw new Exception("Must provide name of model file ('-t <file>').");
      }
      MessageClassifier messageCl;
      try {
	messageCl = (MessageClassifier) SerializationHelper.read(modelName);
      }
      catch (FileNotFoundException e) {
	messageCl = new MessageClassifier();
      }

      // Check if there are any options left
      Utils.checkForRemainingOptions(args);

      // Process message.
      if (updateModel) {
	messageCl.updateData(message, classValue);

	// Save message classifier object only if it was updated. A model
	// rebuilt during classification is therefore not written back.
	SerializationHelper.write(modelName, messageCl);
      }
      else {
	messageCl.classifyMessage(message);
      }
    }
    catch (Exception e) {
      e.printStackTrace();
    }
  }
}