/*
* RapidMiner
*
* Copyright (C) 2001-2008 by Rapid-I and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapid-i.com
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.performance.test;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.table.MemoryExampleTable;
import com.rapidminer.example.test.ExampleTestTools;
import com.rapidminer.operator.performance.AbsoluteError;
import com.rapidminer.operator.performance.AbstractPerformanceEvaluator;
import com.rapidminer.operator.performance.PerformanceCriterion;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.operator.performance.SquaredError;
import com.rapidminer.tools.att.AttributeSet;
/**
* Tests regression critetia.
*
* @author Simon Fischer, Ingo Mierswa
* @version $Id: MeasuredCriterionTest.java,v 1.16 2006/04/05 08:57:27
* ingomierswa Exp $
*/
public class MeasuredCriterionTest extends CriterionTestCase {
private ExampleSet exampleSet1, exampleSet2;
private ExampleSet createExampleSet(double[][] labelValues, double[] predictedValues) throws Exception {
Attribute label = ExampleTestTools.attributeReal();
label.setTableIndex(0);
List<Attribute> attributeList = new LinkedList<Attribute>();
attributeList.add(label);
MemoryExampleTable exampleTable = new MemoryExampleTable(attributeList, ExampleTestTools.createDataRowReader(labelValues));
AttributeSet attributeSet = new AttributeSet();
attributeSet.setSpecialAttribute("label", label);
ExampleSet exampleSet = exampleTable.createExampleSet(attributeSet);
Attribute predictedLabel = ExampleTestTools.createPredictedLabel(exampleSet);
Iterator<Example> r = exampleSet.iterator();
for (int i = 0; i < predictedValues.length; i++)
r.next().setValue(predictedLabel, predictedValues[i]);
return exampleSet;
}
public void setUp() throws Exception {
super.setUp();
exampleSet1 = createExampleSet(new double[][] { { 5.0 }, { 3.0 }, { -1.0 }, { -4.0 }, { 0.0 }, { 2.0 } }, new double[] { 6.0, 1.0, 0.0, -1.0, 3.0, -2.0 });
exampleSet2 = createExampleSet(new double[][] { { 3.0 }, { 6.0 }, { -1.0 } }, new double[] { 1.0, 8.0, -4.0 });
}
public void tearDown() throws Exception {
exampleSet1 = exampleSet2 = null;
super.tearDown();
}
/** Tests calculation, average, and clone. */
private void criterionTest(PerformanceCriterion c1, PerformanceCriterion c2, double expected1, double expected2, double expectedOverall) throws Exception {
PerformanceVector pv1 = new PerformanceVector();
pv1.addCriterion(c1);
AbstractPerformanceEvaluator.evaluate(null, exampleSet1, pv1, new LinkedList<PerformanceCriterion>(), false, true);
assertEquals(c1.getName() + " 1", expected1, c1.getAverage(), 0.00000001);
assertEquals(c1.getName() + " 1 clone", expected1, ((PerformanceCriterion) c1.clone()).getAverage(), 0.00000001);
PerformanceVector pv2 = new PerformanceVector();
pv2.addCriterion(c2);
AbstractPerformanceEvaluator.evaluate(null, exampleSet2, pv2, new LinkedList<PerformanceCriterion>(), false, true);
assertEquals(c2.getName() + " 2", expected2, c2.getAverage(), 0.00000001);
assertEquals(c2.getName() + " 2 clone", expected2, ((PerformanceCriterion) c2.clone()).getAverage(), 0.00000001);
c1.buildAverage(c2);
assertEquals(c1.getName() + " average", expectedOverall, c1.getMikroAverage(), 0.00000001);
assertEquals(c1.getName() + " makro average", (expected1 + expected2) / 2.0, c1.getMakroAverage(), 0.00000001);
}
public void testAbsoluteError() throws Exception {
criterionTest(new AbsoluteError(), new AbsoluteError(), 14.0 / 6.0, 7.0 / 3.0, 21.0 / 9.0);
}
public void testSquaredError() throws Exception {
criterionTest(new SquaredError(), new SquaredError(), 40.0 / 6.0, 17.0 / 3.0, 57.0 / 9.0);
}
// public void testScaledError() throws Exception {
// criterionTest(new ScaledError(), new ScaledError(),
// 14.0/6.0/9.0,
// 7.0/3.0/7.0,
// (14.0/9.0 + 7.0/7.0) / 9.0);
// }
}