package com.tlabs.speechalyzer.classifier;

import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.Closeable;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.StringReader;
import java.util.Random;

import org.apache.log4j.Logger;

import com.felix.util.FileUtil;
import com.felix.util.KeyValues;
import com.felix.util.Util;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.functions.SMO;
import weka.classifiers.trees.J48;
import weka.core.Instance;
import weka.core.Instances;

/**
 * {@link IClassifier} implementation backed by a WEKA {@link Classifier}.
 * <p>
 * The classifier type is read from the configuration key "classifier";
 * {@link #trainModel()} can build the types "smo", "naiveBayes" and "j48".
 * The serialized model is stored in the file configured as "modelFile", with
 * "_" plus the classifier type added to the name via
 * {@code FileUtil.addNamePart}.
 */
public class WEKAClassifier implements IClassifier {
	private String _trainFileName;
	private String _testFileName;
	private String _modelFileName;
	private Logger _logger;
	private KeyValues _config;
	private Categories _categories;
	private Classifier _classifier = null;
	private String _classifierType;

	/**
	 * @return a short description naming the configured classifier type.
	 */
	public String getInfo() {
		return "WEKA classifier with model: " + _classifierType;
	}

	/**
	 * @return the path of the model file currently used for loading and saving.
	 */
	public String getModelFileName() {
		return _modelFileName;
	}

	/**
	 * Reads classifier type, train/test/model file locations and categories
	 * from the configuration and tries to load an existing model. Failures are
	 * logged, not thrown.
	 * 
	 * @param config
	 *            configuration providing the keys "classifier", "trainFile",
	 *            "testFile", "modelFile" and "categories".
	 */
	public WEKAClassifier(KeyValues config) {
		_config = config;
		_logger = Logger
				.getLogger("com.tlabs.speechalyzer.classifier.WEKAClassifier");
		_classifierType = _config.getString("classifier");
		try {
			_trainFileName = _config.getAbsPath("trainFile");
			_testFileName = _config.getAbsPath("testFile");
			_modelFileName = FileUtil.addNamePart(
					_config.getAbsPath("modelFile"), "_" + _classifierType);
			_categories = new Categories(_config.getString("categories"));
			loadModel();
		} catch (Exception e) {
			_logger.error(e.getMessage(), e);
		}
	}

	/**
	 * Switches to another classifier type and tries to load its model file.
	 * Does nothing if the type is unchanged; a null type is rejected with an
	 * error log entry.
	 * 
	 * @param classifierType
	 *            the new classifier type, e.g. "smo", "naiveBayes" or "j48".
	 */
	public void setClassifierType(String classifierType) {
		if (classifierType == null) {
			_logger.error("classifier type must not be null");
			return;
		}
		if (classifierType.equals(_classifierType)) {
			return;
		}
		_classifierType = classifierType;
		_modelFileName = FileUtil.addNamePart(_config.getAbsPath("modelFile"),
				"_" + _classifierType);
		// Drop the classifier of the previous type: if no model file exists
		// for the new type, trainModel() must create a fresh classifier of the
		// new type instead of retraining the old one under the new file name.
		_classifier = null;
		loadModel();
	}

	/**
	 * Classifies the first instance of the test file with the current model.
	 * The last attribute of the test data is used as class attribute.
	 * 
	 * @return the per-category results, or null if no model is loaded, the
	 *         test file contains no instance, or classification failed.
	 */
	public ClassificationResult classify() {
		if (_classifier == null) {
			_logger.error("ERROR classifying: no classifier model loaded");
			return null;
		}
		try {
			String input = FileUtil.getFileText(_testFileName);
			Instances classifyInsts = new Instances(new StringReader(input));
			if (classifyInsts.numInstances() == 0) {
				_logger.error("ERROR classifying: no instance found in "
						+ _testFileName);
				return null;
			}
			int classIndex = classifyInsts.numAttributes() - 1;
			classifyInsts.setClassIndex(classIndex);
			Instance instance = classifyInsts.instance(0);
			instance.setClassMissing();

			double[] res = _classifier.distributionForInstance(instance);
			String[] catArray = _categories.getCategoryArray();
			if (res.length < catArray.length) {
				_logger.error("ERROR classifying. Model yields " + res.length
						+ " classes but " + catArray.length
						+ " categories are configured");
				return null;
			}
			ClassificationResult cr = new ClassificationResult();
			for (int i = 0; i < catArray.length; i++) {
				cr.addResult(catArray[i], res[i]);
			}

			double cls = _classifier.classifyInstance(instance);
			instance.setClassValue(cls);
			String resId = instance.stringValue(classIndex);
			System.out.println("erg: " + resId + ", " + cls);
			return cr;
		} catch (Exception e) {
			_logger.error(
					"ERROR classifying. Perhaps model didn't fit test? : "
							+ e.getMessage(), e);
		}
		return null;
	}

	/**
	 * Trains the classifier on the training file (last attribute is the class
	 * attribute) and serializes it to the model file. If no classifier exists
	 * yet, one is created according to the classifier type; an unknown type is
	 * logged and training is aborted. Failures are logged, not thrown.
	 */
	public void trainModel() {
		BufferedReader trainReader = null;
		try {
			_logger.info("training model ...");
			trainReader = new BufferedReader(new FileReader(_trainFileName));
			Instances trainInsts = new Instances(trainReader);
			trainInsts.setClassIndex(trainInsts.numAttributes() - 1);
			if (_classifier == null) {
				Classifier created = createClassifier(_classifierType);
				if (created == null) {
					_logger.error("no/wrong classifier");
					return;
				}
				_classifier = created;
			}
			_classifier.buildClassifier(trainInsts);
			_logger.info("training model finished");
			// Only report success if the model really was written.
			if (saveModel()) {
				_logger.info("model saved to file: " + _modelFileName);
			}
		} catch (Exception e) {
			_logger.error(e.getMessage(), e);
		} finally {
			closeQuietly(trainReader);
		}
	}

	/**
	 * Runs a 10-fold cross validation of the current classifier on the file
	 * configured as "trainFile".
	 * 
	 * @return summary and confusion matrix as text, or an empty string if the
	 *         evaluation failed.
	 */
	public String evaluate() {
		_logger.info("evaluating...");
		if (_classifier == null) {
			_logger.error("cannot evaluate: no classifier available");
			return "";
		}
		BufferedReader trainReader = null;
		try {
			trainReader = new BufferedReader(new FileReader(
					_config.getFileHandler("trainFile")));
			Instances trainInsts = new Instances(trainReader);
			trainInsts.setClassIndex(trainInsts.numAttributes() - 1);
			Evaluation eval = new Evaluation(trainInsts);
			eval.crossValidateModel(_classifier, trainInsts, 10, new Random());
			return eval.toSummaryString() + eval.toMatrixString();
		} catch (Exception e) {
			_logger.error(e.getMessage(), e);
		} finally {
			closeQuietly(trainReader);
		}
		return "";
	}

	/**
	 * Sets the model file name and loads the model from it.
	 * 
	 * @param filePath
	 *            path of the serialized classifier model.
	 */
	public void loadModel(String filePath) {
		_modelFileName = filePath;
		loadModel();
	}

	/**
	 * Load the model from the current model file. On failure the error is
	 * reported via {@link #reportError(String)} and the current classifier is
	 * left unchanged.
	 */
	public void loadModel() {
		_logger.info("loading model from file " + _modelFileName + "...");
		FileInputStream fis = null;
		ObjectInputStream ois = null;
		try {
			fis = new FileInputStream(_modelFileName);
			// NOTE(review): Java native deserialization executes code from the
			// stream; only load model files from trusted locations.
			ois = new ObjectInputStream(new BufferedInputStream(fis));
			_classifier = (Classifier) ois.readObject();
		} catch (Exception e) {
			reportError("problem opening classifier model: " + e.getMessage());
		} finally {
			closeQuietly(ois);
			// Covers the case where wrapping the file stream failed.
			closeQuietly(fis);
		}
	}

	/**
	 * Prints the message to standard error and logs it as error.
	 * 
	 * @param mesg
	 *            the error message.
	 */
	public void reportError(String mesg) {
		System.err.println(mesg);
		_logger.error(mesg);
	}

	/**
	 * Creates an untrained classifier for the given type name.
	 * 
	 * @return a new classifier, or null if the type is null or unknown.
	 */
	private Classifier createClassifier(String type) {
		if ("smo".equals(type)) {
			return new SMO();
		}
		if ("naiveBayes".equals(type)) {
			return new NaiveBayes();
		}
		if ("j48".equals(type)) {
			return new J48();
		}
		return null;
	}

	/**
	 * Serializes the current classifier to the model file.
	 * 
	 * @return true if the model was written, false if an error was logged.
	 */
	private boolean saveModel() {
		FileOutputStream fos = null;
		ObjectOutputStream oos = null;
		try {
			fos = new FileOutputStream(_modelFileName);
			oos = new ObjectOutputStream(fos);
			oos.writeObject(_classifier);
			// Flush explicitly so a write error surfaces here and not in the
			// quiet close below.
			oos.flush();
			return true;
		} catch (Exception e) {
			_logger.error("problem saving classifier model to "
					+ _modelFileName + ": " + e.getMessage(), e);
			return false;
		} finally {
			closeQuietly(oos);
			closeQuietly(fos);
		}
	}

	/** Closes the resource if it is not null; a close failure is only logged. */
	private void closeQuietly(Closeable resource) {
		if (resource == null) {
			return;
		}
		try {
			resource.close();
		} catch (IOException ignored) {
			_logger.warn("could not close resource: " + ignored.getMessage());
		}
	}
}