package edu.isistan.daclassifier; import java.io.BufferedInputStream; import java.io.BufferedWriter; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.FileReader; import java.io.FileWriter; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.StringBufferInputStream; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Random; import java.util.Set; import edu.isistan.daclassifier.output.DomainActionNode; import mulan.classifier.InvalidDataException; import mulan.classifier.MultiLabelLearner; import mulan.classifier.MultiLabelOutput; import mulan.classifier.meta.HMC; import mulan.classifier.transformation.LabelPowerset; import mulan.data.LabelNode; import mulan.data.LabelsBuilder; import mulan.data.LabelsMetaData; import mulan.data.MultiLabelInstances; import weka.classifiers.functions.SMO; import weka.classifiers.functions.supportVector.Kernel; import weka.classifiers.functions.supportVector.RBFKernel; import weka.core.Instance; import weka.core.Instances; import weka.filters.Filter; import weka.filters.unsupervised.attribute.StringToWordVector; @SuppressWarnings("deprecation") public class MachineClassifier { public Instances instances; public Instances finstances; public MultiLabelInstances multilabelinstances; public LabelsMetaData labelsmetadata; public MultiLabelLearner learner; private boolean debugEnabled = false; public MachineClassifier() { setDebugEnabled(false); } public void setDebugEnabled(boolean debugEnabled) { this.debugEnabled = debugEnabled; } public boolean isDebugEnabled() { return debugEnabled; } public void loadFullInstancesFromCSV(String[] filenames, String xmlfilename) { try { instances = ArffGenerator.readFromCSV(filenames); Instances clone = new Instances(instances); finstances = Filter.useFilter(clone, MachineLearner.getWordFilter(clone)); multilabelinstances = new MultiLabelInstances(new StringBufferInputStream(finstances.toString()), new BufferedInputStream(new FileInputStream(xmlfilename))); labelsmetadata = LabelsBuilder.createLabels(xmlfilename); } catch (Exception e) { e.printStackTrace(); } } public void loadSubsetInstancesFromCSV(String[] filenames, String xmlfilename, int percentage) { try { Instances dataset = ArffGenerator.readFromCSV(filenames); dataset.randomize(new Random(1)); int capacity = dataset.size(); double newPercentage = ((double) percentage) / 100d; int newCapacity = (int) (newPercentage * capacity); instances = new Instances(dataset, newCapacity); for(int index = 0; index < newCapacity; index++) instances.add(dataset.get(index)); Instances clone = new Instances(instances); finstances = Filter.useFilter(clone, MachineLearner.getWordFilter(clone)); multilabelinstances = new MultiLabelInstances(new StringBufferInputStream(finstances.toString()), new BufferedInputStream(new FileInputStream(xmlfilename))); labelsmetadata = LabelsBuilder.createLabels(xmlfilename); } catch (Exception e) { e.printStackTrace(); } } public void loadInstances(String filename, String xmlfilename) { try { instances = new Instances(new FileReader(filename)); finstances = Filter.useFilter(instances, MachineLearner.getWordFilter(instances)); multilabelinstances = new MultiLabelInstances(new StringBufferInputStream(finstances.toString()), new BufferedInputStream(new FileInputStream(xmlfilename))); labelsmetadata = LabelsBuilder.createLabels(xmlfilename); } catch (Exception e) { e.printStackTrace(); } } public void loadModel(String modelfilepath) throws Exception { File modelFile = new File(modelfilepath); if(!modelFile.exists()) throw new Exception("Model does not exist"); ObjectInputStream modelReader = new ObjectInputStream(new FileInputStream(modelFile)); learner = (MultiLabelLearner) modelReader.readObject(); modelReader.close(); } public void saveArff(String sourcefilepath, String filterfilepath) throws InvalidDataException, Exception { // Saving unfiltered source file File sourceFile = new File(sourcefilepath); if(!sourceFile.exists()) sourceFile.createNewFile(); BufferedWriter sourceWriter = new BufferedWriter(new FileWriter(sourceFile)); sourceWriter.write(instances.toString()); sourceWriter.close(); // Saving filtered source file File filterFile = new File(filterfilepath); if(!filterFile.exists()) filterFile.createNewFile(); BufferedWriter filterWriter = new BufferedWriter(new FileWriter(filterFile)); filterWriter.write(finstances.toString()); filterWriter.close(); } public void saveModel(String modelfilepath) throws InvalidDataException, Exception { // Saving model file File modelFile = new File(modelfilepath); if(!modelFile.exists()) modelFile.createNewFile(); ObjectOutputStream modelWriter = new ObjectOutputStream(new FileOutputStream(modelFile)); modelWriter.writeObject(learner); modelWriter.close(); } public void trainModel() { try { double cValue = 1; double gammaValue = -5; Kernel kernelValue = new RBFKernel(); double c = Math.pow(2, cValue); double gamma = Math.pow(2, gammaValue); // SMO smo = new SMO(); smo.setKernel(kernelValue); smo.setC(c); ((RBFKernel) kernelValue).setGamma(gamma); learner = new HMC(new LabelPowerset(smo)); learner.build(multilabelinstances); } catch (Exception e) { e.printStackTrace(); } } private MultiLabelOutput classifyPredicateInternal(String p, String p_desc, String a0, String a0_desc, String a1, String a1_desc, String a2, String a2_desc) throws Exception { Instances testInstances = ArffGenerator.generateTrainingSet(); Instance testInstance = ArffGenerator.generateTestInstance(p, p_desc, a0, a0_desc, a1, a1_desc, a2, a2_desc); testInstances.add(testInstance); // StringToWordVector filter = MachineLearner.getWordFilter(instances); @SuppressWarnings("unused") Instances fsource = Filter.useFilter(instances, filter); Instances fTestInstances = Filter.useFilter(testInstances, filter); for(int i = 0; i < fTestInstances.size(); i++) fTestInstances.get(i).setDataset(fTestInstances); Instance fTestInstance = fTestInstances.get(0); // MultiLabelOutput output = learner.makePrediction(fTestInstance); // if(debugEnabled) { String prettyPrint = prettyPrint(testInstance, fTestInstance, output); System.out.println(prettyPrint); } // return output; } public List<DomainActionNode> classifyPredicate(String p, String p_desc, String a0, String a0_desc, String a1, String a1_desc, String a2, String a2_desc) throws Exception { List<DomainActionNode> rootDomainActions = new ArrayList<DomainActionNode>(); List<DomainActionNode> domainActions = new ArrayList<DomainActionNode>(); int[] labelIndices = multilabelinstances.getLabelIndices(); // MultiLabelOutput output = classifyPredicateInternal(p, p_desc, a0, a0_desc, a1, a1_desc, a2, a2_desc); boolean[] bipartitions = output.getBipartition(); double[] confidences = output.getConfidences(); int[] rankings = output.getRanking(); int size = bipartitions.length; // for(int index = 0; index < size; index++) { boolean bipartition = bipartitions[index]; if(bipartition) { String label = multilabelinstances.getDataSet().attribute(labelIndices[index]).name(); double confidence = confidences[index]; int ranking = rankings[index]; DomainActionNode domainAction = new DomainActionNode(label, confidence, ranking); if(isRoot(domainAction)) rootDomainActions.add(domainAction); else addToParent(domainActions, domainAction); domainActions.add(domainAction); } } // return rootDomainActions; } private void addToParent(List<DomainActionNode> domainActions, DomainActionNode domainAction) { String parentLabel = findParent(domainAction); DomainActionNode parentDomainAction = null; for(DomainActionNode dAction : domainActions) { if(dAction.getLabel().equalsIgnoreCase(parentLabel)) parentDomainAction = dAction; } parentDomainAction.getChildrens().add(domainAction); domainAction.setParent(parentDomainAction); } private String findParent(DomainActionNode domainAction) { String label = domainAction.getLabel(); Set<LabelNode> nodes = labelsmetadata.getRootLabels(); return findParent(label, nodes); } private String findParent(String label, Set<LabelNode> nodes) { if(nodes == null) return null; else { String parent = null; Iterator<LabelNode> iterator = nodes.iterator(); while(iterator.hasNext() && parent == null) { LabelNode node = iterator.next(); parent = findParent(label, node.getChildren()); if(parent == null && node.getName().equalsIgnoreCase(label)) parent = node.getParent().getName(); } return parent; } } private boolean isRoot(DomainActionNode domainAction) { Set<LabelNode> rootLabels = labelsmetadata.getRootLabels(); for(LabelNode rootLabel : rootLabels) if(rootLabel.getName().equalsIgnoreCase(domainAction.getLabel())) return true; return false; } private String prettyPrint(Instance instance, Instance finstance, MultiLabelOutput output) { StringBuffer stringBuffer = new StringBuffer(); int[] labelIndices = multilabelinstances.getLabelIndices(); boolean[] bipartitions = output.getBipartition(); double[] confidences = output.getConfidences(); int[] rankings = output.getRanking(); // stringBuffer.append("P: " + instance.stringValue(ArffGenerator.attributes.get(ArffGenerator.sP))); stringBuffer.append("\n"); stringBuffer.append("P_DESC: " + instance.stringValue(ArffGenerator.attributes.get(ArffGenerator.sP_DESC))); stringBuffer.append("\n"); stringBuffer.append("A0: " + instance.stringValue(ArffGenerator.attributes.get(ArffGenerator.sA0))); stringBuffer.append("\n"); stringBuffer.append("A0_DESC: " + instance.stringValue(ArffGenerator.attributes.get(ArffGenerator.sA0_DESC))); stringBuffer.append("\n"); stringBuffer.append("A1: " + instance.stringValue(ArffGenerator.attributes.get(ArffGenerator.sA1))); stringBuffer.append("\n"); stringBuffer.append("A1_DESC: " + instance.stringValue(ArffGenerator.attributes.get(ArffGenerator.sA1_DESC))); stringBuffer.append("\n"); stringBuffer.append("A2: " + instance.stringValue(ArffGenerator.attributes.get(ArffGenerator.sA2))); stringBuffer.append("\n"); stringBuffer.append("A2_DESC: " + instance.stringValue(ArffGenerator.attributes.get(ArffGenerator.sA2_DESC))); stringBuffer.append("\n"); // for(int index = 0; index < bipartitions.length; index++) { boolean bipartition = bipartitions[index]; if(bipartition) { double confidence = confidences[index]; int ranking = rankings[index]; String label = multilabelinstances.getDataSet().attribute(labelIndices[index]).name(); stringBuffer.append(index + ". [" + label + ", " + ranking + "] (" + confidence + ")"); stringBuffer.append("\n"); } } stringBuffer.append(output); stringBuffer.append("\n"); return stringBuffer.toString(); } public void loadSubsetInstances() { String filename = Utils.getSubsetArffSourceFilename(); String xmlfilename = Utils.getLabelsFilename(); loadInstances(filename, xmlfilename); } public void loadFullInstances() { String filename = Utils.getFullArffSourceFilename(); String xmlfilename = Utils.getLabelsFilename(); loadInstances(filename, xmlfilename); } public void loadSubsetModel() throws Exception { String modelfilepath = Utils.getSubsetModelFilename(); loadModel(modelfilepath); } public void loadFullModel() throws Exception { String modelfilepath = Utils.getFullModelFilename(); loadModel(modelfilepath); } public void tryClassifier() { // Model Trial try { List<DomainActionNode> domainActions = classifyPredicate( "displayed", "to present, exhibit", "", "", "The list of health units", "entity displayed", "on the employee's local display", "location"); System.out.println(domainActions); } catch (Exception e) { e.printStackTrace(); } } public void saveSubsetArffFiles() { String sourcefilepath = Utils.getSubsetArffSourceFilename(); String filterfilepath = Utils.getSubsetArffFilteredFilename(); int percentage = 40; try { loadSubsetInstancesFromCSV(Utils.getCSVFilenames(), Utils.getLabelsFilename(), percentage); saveArff(sourcefilepath, filterfilepath); } catch (Exception e) { e.printStackTrace(); } } public void saveFullArffFiles() { String sourcefilepath = Utils.getFullArffSourceFilename(); String filterfilepath = Utils.getFullArffFilteredFilename(); try { loadFullInstancesFromCSV(Utils.getCSVFilenames(), Utils.getLabelsFilename()); saveArff(sourcefilepath, filterfilepath); } catch (Exception e) { e.printStackTrace(); } } public void saveSubsetClassifier() { String modelfilepath = Utils.getSubsetModelFilename(); try { loadInstances(Utils.getSubsetArffSourceFilename(), Utils.getLabelsFilename()); saveModel(modelfilepath); } catch (Exception e) { e.printStackTrace(); } } public void saveFullClassifier() { String modelfilepath = Utils.getFullModelFilename(); try { loadInstances(Utils.getFullArffSourceFilename(), Utils.getLabelsFilename()); saveModel(modelfilepath); } catch (Exception e) { e.printStackTrace(); } } public static void main(String[] args) throws Exception { MachineClassifier classifier = new MachineClassifier(); classifier.setDebugEnabled(true); // Full // classifier.saveFullArffFiles(); classifier.loadFullInstances(); classifier.loadFullModel(); // classifier.trainModel(); // classifier.saveFullClassifier(); classifier.tryClassifier(); // Subset // classifier.saveSubsetArffFiles(); classifier.loadSubsetInstances(); classifier.loadSubsetModel(); // classifier.trainModel(); // classifier.saveSubsetClassifier(); classifier.tryClassifier(); } }