package com.rapidminer.operator.learner.test; import java.util.Collections; import java.util.LinkedList; import java.util.List; import java.util.Random; import junit.framework.TestCase; import com.rapidminer.example.Attribute; import com.rapidminer.example.Attributes; import com.rapidminer.example.ExampleSet; import com.rapidminer.example.set.SimpleExampleSet; import com.rapidminer.example.table.AttributeFactory; import com.rapidminer.example.table.DataRow; import com.rapidminer.example.table.DataRowReader; import com.rapidminer.example.table.DoubleArrayDataRow; import com.rapidminer.example.table.MemoryExampleTable; import com.rapidminer.operator.OperatorCapability; import com.rapidminer.operator.OperatorDescription; import com.rapidminer.operator.learner.Learner; import com.rapidminer.operator.learner.functions.kernel.RVMLearner; import com.rapidminer.test.TestContext; import com.rapidminer.tools.Ontology; import com.rapidminer.tools.OperatorService; /** Creates all learners using the {@link OperatorService} and constructs input example sets * according to their capabilities to check whether they operate without throwing exceptions / * throwing the correct exceptions. * * */ public class LearnerTest extends TestCase { private static final int NUM_EXAMPLES = 100; private OperatorDescription opDesc; public LearnerTest(OperatorDescription opDesc) { super("learnerTest"); this.opDesc = opDesc; } @Override protected void setUp() throws Exception { super.setUp(); TestContext.get().initRapidMiner(); } @Override public String getName() { return "Learner "+opDesc.getName() + " - "+opDesc.getKey()+" - "+opDesc.getOperatorClass(); //return "Test_learner_"+opDesc.getName(); // + " ("+opDesc.getKey()+", "+opDesc.getOperatorClass()+")"; } public void learnerTest() throws Exception { Learner learner = (Learner) OperatorService.createOperator(opDesc); MemoryExampleTable exTable = new MemoryExampleTable(); for (int i = 0; i < NUM_EXAMPLES; i++) { exTable.addDataRow(new DoubleArrayDataRow(new double[0])); } List<Attribute> regulars = new LinkedList<Attribute>(); if (learner.supportsCapability(OperatorCapability.BINOMINAL_ATTRIBUTES)) { regulars.add(addBinomialAttribute(exTable)); regulars.add(addBinomialAttribute(exTable)); } if (learner.supportsCapability(OperatorCapability.POLYNOMINAL_ATTRIBUTES)) { regulars.add(addNominalAttribute(exTable)); regulars.add(addNominalAttribute(exTable)); } if (learner.supportsCapability(OperatorCapability.NUMERICAL_ATTRIBUTES)) { regulars.add(addNumericalAttribute(exTable)); regulars.add(addNumericalAttribute(exTable)); } if (regulars.isEmpty()) { throw new Exception("No regular attribute type supported."); } List<Attribute> labels = new LinkedList<Attribute>(); if (learner.supportsCapability(OperatorCapability.POLYNOMINAL_LABEL)) { labels.add(addNominalAttribute(exTable)); } if (learner.supportsCapability(OperatorCapability.NUMERICAL_LABEL)) { labels.add(addNumericalAttribute(exTable)); } if (learner.supportsCapability(OperatorCapability.BINOMINAL_LABEL)) { labels.add(addBinomialAttribute(exTable)); } if (labels.isEmpty()) { throw new Exception("No label type supported."); } for (Attribute label : labels) { // Some learner need special parameters for certain label types if( !checkLearnerCapability(learner, label) ) continue; //AttributeSet attributes = new AttributeSet(regulars, Collections.singletonMap(Attributes.LABEL_NAME, label)); //exTable.createExampleSet(specialAttributes) ExampleSet exampleSet = new SimpleExampleSet(exTable, regulars, Collections.singletonMap(label, Attributes.LABEL_NAME)); //ExampleSet exampleSet = exTable.createExampleSet(attributes); learner.learn(exampleSet); } } private Attribute addBinomialAttribute(MemoryExampleTable exTable) { final Attribute att = AttributeFactory.createAttribute("binominal_"+(exTable.getNumberOfAttributes()+1), Ontology.BINOMINAL); att.getMapping().mapString("positive"); att.getMapping().mapString("negative"); exTable.addAttribute(att); DataRowReader dataRowReader = exTable.getDataRowReader(); Random random = new Random(); while (dataRowReader.hasNext()) { DataRow row = dataRowReader.next(); row.set(att, random.nextInt(2)); } return att; } private Attribute addNominalAttribute(MemoryExampleTable exTable) { final Attribute att = AttributeFactory.createAttribute("polynom_"+(exTable.getNumberOfAttributes()+1), Ontology.POLYNOMINAL); att.getMapping().mapString("one"); att.getMapping().mapString("two"); att.getMapping().mapString("three"); att.getMapping().mapString("four"); exTable.addAttribute(att); DataRowReader dataRowReader = exTable.getDataRowReader(); Random random = new Random(); while (dataRowReader.hasNext()) { DataRow row = dataRowReader.next(); row.set(att, random.nextInt(4)); } return att; } private Attribute addNumericalAttribute(MemoryExampleTable exTable) { final Attribute att = AttributeFactory.createAttribute("numeric_"+(exTable.getNumberOfAttributes()+1), Ontology.NUMERICAL); exTable.addAttribute(att); DataRowReader dataRowReader = exTable.getDataRowReader(); while (dataRowReader.hasNext()) { DataRow row = dataRowReader.next(); row.set(att, Math.random()*10d-5d); } return att; } private boolean checkLearnerCapability( Learner learner, Attribute label ) { if( learner instanceof RVMLearner ) { RVMLearner rvm = (RVMLearner) learner; if( label.isNominal() ) { rvm.setParameter(RVMLearner.PARAMETER_RVM_TYPE, "1"); } return true; } return true; } }