/**
* Copyright (C) 2001-2017 by RapidMiner and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapidminer.com
*
* This program is free software: you can redistribute it and/or modify it under the terms of the
* GNU Affero General Public License as published by the Free Software Foundation, either version 3
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
* even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along with this program.
* If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.test.asserter;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import org.junit.Assert;
import org.junit.ComparisonFailure;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.table.SparseDataRow;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.IOObjectCollection;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.associations.FrequentItemSet;
import com.rapidminer.operator.learner.associations.FrequentItemSets;
import com.rapidminer.operator.learner.functions.LinearRegressionModel;
import com.rapidminer.operator.meta.ParameterSet;
import com.rapidminer.operator.meta.ParameterValue;
import com.rapidminer.operator.nio.file.FileObject;
import com.rapidminer.operator.performance.PerformanceCriterion;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.operator.visualization.dependencies.ANOVAMatrix;
import com.rapidminer.operator.visualization.dependencies.NumericalMatrix;
import com.rapidminer.test_utils.Asserter;
import com.rapidminer.test_utils.RapidAssert;
import com.rapidminer.tools.Tools;
import com.rapidminer.tools.math.AnovaCalculator.AnovaSignificanceTestResult;
import com.rapidminer.tools.math.Averagable;
import com.rapidminer.tools.math.AverageVector;
/**
 * {@link AsserterFactory} providing the {@link Asserter}s for RapidMiner core types: parameter
 * sets, performance criteria, averagables, average vectors, example sets, IOObject collections,
 * numerical matrices, performance vectors, file objects, frequent item sets, linear regression
 * models, ANOVA matrices and ANOVA significance test results.
 *
 * @author Marius Helf
 *
 */
public class AsserterFactoryRapidMiner implements AsserterFactory {

	/** Absolute tolerance used by {@link #assertDouble(String, double, double)}. */
	private static final double DEFAULT_DELTA = 1e-08;

	/** Absolute tolerance used for linear regression and ANOVA comparisons. */
	private static final double STRICT_DELTA = 1E-15;

	/**
	 * Creates one asserter per supported type.
	 *
	 * @return a new {@link LinkedList} containing the asserters
	 */
	@Override
	public List<Asserter> createAsserters() {
		List<Asserter> asserters = new LinkedList<Asserter>();
		asserters.add(createParameterSetAsserter());
		asserters.add(createPerformanceCriterionAsserter());
		asserters.add(createAveragableAsserter());
		asserters.add(createAverageVectorAsserter());
		asserters.add(createExampleSetAsserter());
		asserters.add(createIOObjectCollectionAsserter());
		asserters.add(createNumericalMatrixAsserter());
		asserters.add(createPerformanceVectorAsserter());
		asserters.add(createFileObjectAsserter());
		asserters.add(createFrequentItemSetsAsserter());
		asserters.add(createLinearRegressionModelAsserter());
		asserters.add(createAnovaMatrixAsserter());
		asserters.add(createAnovaSignificanceTestResultAsserter());
		return asserters;
	}

	/** Creates the asserter for {@link ParameterSet}: performance vector plus all parameter values in order. */
	private Asserter createParameterSetAsserter() {
		return new Asserter() {

			@Override
			public Class<?> getAssertable() {
				return ParameterSet.class;
			}

			@Override
			public void assertEquals(String message, Object expectedObj, Object actualObj) {
				ParameterSet expected = (ParameterSet) expectedObj;
				ParameterSet actual = (ParameterSet) actualObj;
				RapidAssert.assertEquals(message + " (performance vectors do not match)", expected.getPerformance(),
						actual.getPerformance());

				Iterator<ParameterValue> expectedIt = expected.getParameterValues();
				Iterator<ParameterValue> actualIt = actual.getParameterValues();
				while (expectedIt.hasNext()) {
					// check before calling next() so a length mismatch fails with a readable message
					Assert.assertTrue(message + " (expected parameter vector is longer than actual parameter vector)",
							actualIt.hasNext());
					ParameterValue expectedParValue = expectedIt.next();
					ParameterValue actualParValue = actualIt.next();
					RapidAssert.assertEquals(message + " (parameter values)", expectedParValue, actualParValue);
				}
				Assert.assertFalse(message + " (expected parameter vector is shorter than actual parameter vector)",
						actualIt.hasNext());
			}
		};
	}

	/** Creates the asserter for {@link PerformanceCriterion}. */
	private Asserter createPerformanceCriterionAsserter() {
		return new Asserter() {

			/**
			 * Tests for equality by running all asserters registered for {@link Averagable} on the
			 * two criteria, then comparing fitness, max fitness and example count.
			 *
			 * @param message
			 *            message to display if an error occurs
			 * @param expectedObj
			 *            expected criterion
			 * @param actualObj
			 *            actual criterion
			 */
			@Override
			public void assertEquals(String message, Object expectedObj, Object actualObj) {
				PerformanceCriterion expected = (PerformanceCriterion) expectedObj;
				PerformanceCriterion actual = (PerformanceCriterion) actualObj;

				List<Asserter> averagableAsserters = RapidAssert.ASSERTER_REGISTRY
						.getAsserterForClass(Averagable.class);
				if (averagableAsserters == null) {
					throw new ComparisonFailure("Comparison of " + Averagable.class.toString() + " is not supported. ",
							expectedObj.toString(), actualObj.toString());
				}
				for (Asserter asserter : averagableAsserters) {
					asserter.assertEquals(message, expected, actual);
				}

				assertDouble(message + " (fitness is not equal)", expected.getFitness(), actual.getFitness());
				assertDouble(message + " (max fitness is not equal)", expected.getMaxFitness(),
						actual.getMaxFitness());
				assertDouble(message + " (example count is not equal)", expected.getExampleCount(),
						actual.getExampleCount());
			}

			@Override
			public Class<?> getAssertable() {
				return PerformanceCriterion.class;
			}
		};
	}

	/** Creates the asserter for {@link Averagable}. */
	private Asserter createAveragableAsserter() {
		return new Asserter() {

			/**
			 * Tests for equality by comparing all averages, standard deviations and variances,
			 * plus the average count.
			 *
			 * @param message
			 *            message to display if an error occurs
			 * @param expectedObj
			 *            expected averagable
			 * @param actualObj
			 *            actual averagable
			 */
			@Override
			public void assertEquals(String message, Object expectedObj, Object actualObj) {
				Averagable expected = (Averagable) expectedObj;
				Averagable actual = (Averagable) actualObj;
				assertDouble(message + " (average is not equal)", expected.getAverage(), actual.getAverage());
				assertDouble(message + " (makro average is not equal)", expected.getMakroAverage(),
						actual.getMakroAverage());
				assertDouble(message + " (mikro average is not equal)", expected.getMikroAverage(),
						actual.getMikroAverage());
				assertDouble(message + " (average count is not equal)", expected.getAverageCount(),
						actual.getAverageCount());
				assertDouble(message + " (makro standard deviation is not equal)", expected.getMakroStandardDeviation(),
						actual.getMakroStandardDeviation());
				assertDouble(message + " (mikro standard deviation is not equal)", expected.getMikroStandardDeviation(),
						actual.getMikroStandardDeviation());
				assertDouble(message + " (standard deviation is not equal)", expected.getStandardDeviation(),
						actual.getStandardDeviation());
				assertDouble(message + " (makro variance is not equal)", expected.getMakroVariance(),
						actual.getMakroVariance());
				assertDouble(message + " (mikro variance is not equal)", expected.getMikroVariance(),
						actual.getMikroVariance());
				assertDouble(message + " (variance is not equal)", expected.getVariance(), actual.getVariance());
			}

			@Override
			public Class<?> getAssertable() {
				return Averagable.class;
			}
		};
	}

	/** Creates the asserter for {@link AverageVector}. */
	private Asserter createAverageVectorAsserter() {
		return new Asserter() {

			/**
			 * Tests the two average vectors for equality by comparing the size and each averagable.
			 *
			 * @param message
			 *            message to display if an error occurs
			 * @param expectedObj
			 *            expected vector
			 * @param actualObj
			 *            actual vector
			 */
			@Override
			public void assertEquals(String message, Object expectedObj, Object actualObj) {
				AverageVector expected = (AverageVector) expectedObj;
				AverageVector actual = (AverageVector) actualObj;
				message = message + " - Average vectors are not equal";

				int size = expected.getSize();
				Assert.assertEquals(message + " (size of the average vector is not equal)", size, actual.getSize());
				for (int i = 0; i < size; i++) {
					RapidAssert.assertEquals(message, expected.getAveragable(i), actual.getAveragable(i));
				}
			}

			@Override
			public Class<?> getAssertable() {
				return AverageVector.class;
			}
		};
	}

	/** Creates the asserter for {@link ExampleSet}. */
	private Asserter createExampleSetAsserter() {
		return new Asserter() {

			/**
			 * Tests two example sets by comparing their attributes and size, then iterating over
			 * all examples.
			 *
			 * @param message
			 *            message to display if an error occurs
			 * @param expectedObj
			 *            expected value
			 * @param actualObj
			 *            actual value
			 */
			@Override
			public void assertEquals(String message, Object expectedObj, Object actualObj) {
				ExampleSet expected = (ExampleSet) expectedObj;
				ExampleSet actual = (ExampleSet) actualObj;
				message = message + " - ExampleSets are not equal";

				// Attribute default values are compared only when the expected table is empty or
				// its first data row is a SparseDataRow.
				boolean compareAttributeDefaultValues = true;
				if (expected.getExampleTable().size() > 0) {
					compareAttributeDefaultValues = expected.getExampleTable().getDataRow(0) instanceof SparseDataRow;
				}

				RapidAssert.assertEquals(message, expected.getAttributes(), actual.getAttributes(),
						compareAttributeDefaultValues);
				Assert.assertEquals(message + " (number of examples)", expected.size(), actual.size());

				Iterator<Example> expectedIt = expected.iterator();
				Iterator<Example> actualIt = actual.iterator();
				int row = 1;
				while (expectedIt.hasNext() && actualIt.hasNext()) {
					RapidAssert.assertEquals(message + " (example number " + row + ", {0} value of {1})",
							expectedIt.next(), actualIt.next());
					row++;
				}
			}

			@Override
			public Class<?> getAssertable() {
				return ExampleSet.class;
			}
		};
	}

	/** Creates the asserter for {@link IOObjectCollection}. */
	private Asserter createIOObjectCollectionAsserter() {
		return new Asserter() {

			/**
			 * Tests two collections of IOObjects by comparing their size and their contained
			 * objects.
			 *
			 * @param message
			 *            message to display if an error occurs
			 * @param expectedObj
			 *            expected collection
			 * @param actualObj
			 *            actual collection
			 */
			@Override
			public void assertEquals(String message, Object expectedObj, Object actualObj) {
				@SuppressWarnings("unchecked")
				IOObjectCollection<IOObject> expected = (IOObjectCollection<IOObject>) expectedObj;
				@SuppressWarnings("unchecked")
				IOObjectCollection<IOObject> actual = (IOObjectCollection<IOObject>) actualObj;
				message = message + " - Collection \"" + actual.getSource() + "\" of IOObjects is not equal:";
				Assert.assertEquals(message + " (number of items)", expected.size(), actual.size());
				RapidAssert.assertEquals(message, expected.getObjects(), actual.getObjects());
			}

			@Override
			public Class<?> getAssertable() {
				return IOObjectCollection.class;
			}
		};
	}

	/** Creates the asserter for {@link NumericalMatrix}. */
	private Asserter createNumericalMatrixAsserter() {
		return new Asserter() {

			/**
			 * Tests two numerical matrices for equality: the number of columns and rows, column
			 * and row names, the symmetry flag and every value within the matrix.
			 *
			 * @param message
			 *            message to display if an error occurs
			 * @param expectedObj
			 *            expected matrix
			 * @param actualObj
			 *            actual matrix
			 */
			@Override
			public void assertEquals(String message, Object expectedObj, Object actualObj) {
				NumericalMatrix expected = (NumericalMatrix) expectedObj;
				NumericalMatrix actual = (NumericalMatrix) actualObj;
				message = message + " - Numerical matrices are not equal";

				int cols = expected.getNumberOfColumns();
				Assert.assertEquals(message + " (column number is not equal)", cols, actual.getNumberOfColumns());
				int rows = expected.getNumberOfRows();
				Assert.assertEquals(message + " (row number is not equal)", rows, actual.getNumberOfRows());

				for (int col = 0; col < cols; col++) {
					Assert.assertEquals(message + " (column name at index " + col + " is not equal)",
							expected.getColumnName(col), actual.getColumnName(col));
				}
				for (int row = 0; row < rows; row++) {
					Assert.assertEquals(message + " (row name at index " + row + " is not equal)",
							expected.getRowName(row), actual.getRowName(row));
				}
				Assert.assertEquals(message + " (matrix symmetry is not equal)", expected.isSymmetrical(),
						actual.isSymmetrical());

				for (int row = 0; row < rows; row++) {
					for (int col = 0; col < cols; col++) {
						assertDouble(message + " (value at row " + row + " and column " + col + " is not equal)",
								expected.getValue(row, col), actual.getValue(row, col));
					}
				}
			}

			@Override
			public Class<?> getAssertable() {
				return NumericalMatrix.class;
			}
		};
	}

	/** Creates the asserter for {@link PerformanceVector}. */
	private Asserter createPerformanceVectorAsserter() {
		return new Asserter() {

			/**
			 * Tests the two performance vectors for equality by comparing the size, the criteria
			 * names, the main criterion and each criterion.
			 *
			 * @param message
			 *            message to display if an error occurs
			 * @param expectedObj
			 *            expected vector
			 * @param actualObj
			 *            actual vector
			 */
			@Override
			public void assertEquals(String message, Object expectedObj, Object actualObj) {
				PerformanceVector expected = (PerformanceVector) expectedObj;
				PerformanceVector actual = (PerformanceVector) actualObj;
				message = message + " - Performance vectors are not equal";

				int size = expected.getSize();
				Assert.assertEquals(message + " (size of the performance vector is not equal)", size,
						actual.getSize());
				RapidAssert.assertArrayEquals(message, expected.getCriteriaNames(), actual.getCriteriaNames());
				RapidAssert.assertEquals(message, expected.getMainCriterion(), actual.getMainCriterion());
				for (int i = 0; i < size; i++) {
					RapidAssert.assertEquals(message, expected.getCriterion(i), actual.getCriterion(i));
				}
			}

			@Override
			public Class<?> getAssertable() {
				return PerformanceVector.class;
			}
		};
	}

	/** Creates the asserter for {@link FileObject}. */
	private Asserter createFileObjectAsserter() {
		return new Asserter() {

			/**
			 * Tests the two file objects for equality by comparing their complete byte content.
			 *
			 * @param message
			 *            message to display if an error occurs
			 * @param expectedObj
			 *            expected file object
			 * @param actualObj
			 *            actual file object
			 * @throws RuntimeException
			 *             if either file object cannot be read; the original exception is kept as
			 *             the cause
			 */
			@Override
			public void assertEquals(String message, Object expectedObj, Object actualObj) throws RuntimeException {
				FileObject expected = (FileObject) expectedObj;
				FileObject actual = (FileObject) actualObj;
				byte[] expectedData;
				byte[] actualData;
				try {
					expectedData = readFully(expected);
					actualData = readFully(actual);
				} catch (OperatorException e) {
					throw new RuntimeException("Stream Error", e);
				} catch (IOException e) {
					throw new RuntimeException("Stream Error", e);
				}
				RapidAssert.assertArrayEquals(message + " (file object data)", expectedData, actualData);
			}

			@Override
			public Class<?> getAssertable() {
				return FileObject.class;
			}
		};
	}

	/** Creates the asserter for {@link FrequentItemSets}. */
	private Asserter createFrequentItemSetsAsserter() {
		return new Asserter() {

			@Override
			public Class<?> getAssertable() {
				return FrequentItemSets.class;
			}

			/**
			 * Tests two FrequentItemSets by comparing size and number of transactions, then
			 * iterating over all inner sets. Note that {@code sortSets()} is invoked on both
			 * arguments, so they may be modified by this call.
			 *
			 * @param message
			 *            message to display if an error occurs
			 * @param expectedObj
			 *            expected value
			 * @param actualObj
			 *            actual value
			 */
			@Override
			public void assertEquals(String message, Object expectedObj, Object actualObj) {
				FrequentItemSets expected = (FrequentItemSets) expectedObj;
				FrequentItemSets actual = (FrequentItemSets) actualObj;
				message = message + " - FrequentItemSet \"" + actual.getSource() + "\" does not match the expected Set";

				Assert.assertEquals(
						message + " : size is not equal(expected <" + expected.size() + "> was <" + actual.size() + ">)",
						expected.size(), actual.size());
				Assert.assertEquals(
						message + " : number of transactions is not equal(expected <" + expected.getNumberOfTransactions()
								+ "> was <" + actual.getNumberOfTransactions() + ">)",
						expected.getNumberOfTransactions(), actual.getNumberOfTransactions());

				// bring both into a canonical order before the pairwise comparison
				expected.sortSets();
				actual.sortSets();
				Iterator<FrequentItemSet> expectedIt = expected.iterator();
				Iterator<FrequentItemSet> actualIt = actual.iterator();
				int index = 0;
				while (expectedIt.hasNext() && actualIt.hasNext()) {
					Assert.assertTrue(message + " : item set at index " + index + " differs",
							expectedIt.next().compareTo(actualIt.next()) == 0);
					index++;
				}
			}
		};
	}

	/** Creates the asserter for {@link LinearRegressionModel}. */
	private Asserter createLinearRegressionModelAsserter() {
		return new Asserter() {

			@Override
			public Class<?> getAssertable() {
				return LinearRegressionModel.class;
			}

			/**
			 * Tests two linear regression models by comparing all exposed value arrays.
			 *
			 * @param message
			 *            message to display if an error occurs
			 * @param expectedObj
			 *            expected value
			 * @param actualObj
			 *            actual value
			 */
			@Override
			public void assertEquals(String message, Object expectedObj, Object actualObj) {
				LinearRegressionModel expected = (LinearRegressionModel) expectedObj;
				LinearRegressionModel actual = (LinearRegressionModel) actualObj;
				message = message + " - Linear Regression Model \"" + actual.getSource()
						+ "\" does not match the expected Model";

				Assert.assertArrayEquals(message + " : coefficients are not equal", expected.getCoefficients(),
						actual.getCoefficients(), STRICT_DELTA);
				Assert.assertArrayEquals(message + " : probabilities are not equal", expected.getProbabilities(),
						actual.getProbabilities(), STRICT_DELTA);
				Assert.assertArrayEquals(message + " : selected attribute names are not equal",
						expected.getSelectedAttributeNames(), actual.getSelectedAttributeNames());
				Assert.assertArrayEquals(message + " : selected attributes are not equal",
						expected.getSelectedAttributes(), actual.getSelectedAttributes());
				Assert.assertArrayEquals(message + " : standard errors are not equal", expected.getStandardErrors(),
						actual.getStandardErrors(), STRICT_DELTA);
				Assert.assertArrayEquals(message + " : standardized coefficients are not equal",
						expected.getStandardizedCoefficients(), actual.getStandardizedCoefficients(), STRICT_DELTA);
				Assert.assertArrayEquals(message + " : tolerances are not equal", expected.getTolerances(),
						actual.getTolerances(), STRICT_DELTA);
				Assert.assertArrayEquals(message + " : t statistics are not equal", expected.getTStats(),
						actual.getTStats(), STRICT_DELTA);
			}
		};
	}

	/** Creates the asserter for {@link ANOVAMatrix}. */
	private Asserter createAnovaMatrixAsserter() {
		return new Asserter() {

			@Override
			public Class<?> getAssertable() {
				return ANOVAMatrix.class;
			}

			/**
			 * Tests two ANOVA matrices by comparing the dimensions and all values of the
			 * probability matrix, and the significance level.
			 *
			 * @param message
			 *            message to display if an error occurs
			 * @param expectedObj
			 *            expected value
			 * @param actualObj
			 *            actual value
			 */
			@Override
			public void assertEquals(String message, Object expectedObj, Object actualObj) {
				ANOVAMatrix expected = (ANOVAMatrix) expectedObj;
				ANOVAMatrix actual = (ANOVAMatrix) actualObj;
				message = message + " - ANOVA-Matrix \"" + actual.getSource() + "\" does not match the expected Matrix";

				double[][] expectedProbabilities = expected.getProbabilities();
				double[][] actualProbabilities = actual.getProbabilities();
				// compare dimensions first so a size mismatch fails as an assertion instead of an
				// ArrayIndexOutOfBoundsException (or is silently ignored for larger actual matrices)
				Assert.assertEquals(message + " : number of probability rows is not equal",
						expectedProbabilities.length, actualProbabilities.length);
				for (int i = 0; i < expectedProbabilities.length; i++) {
					// assertArrayEquals also checks the row length
					Assert.assertArrayEquals(message + " : probabilities in row " + i + " are not equal",
							expectedProbabilities[i], actualProbabilities[i], STRICT_DELTA);
				}
				Assert.assertEquals(message + " : significance levels are not equal", expected.getSignificanceLevel(),
						actual.getSignificanceLevel(), STRICT_DELTA);
			}
		};
	}

	/** Creates the asserter for {@link AnovaSignificanceTestResult}. */
	private Asserter createAnovaSignificanceTestResultAsserter() {
		return new Asserter() {

			@Override
			public Class<?> getAssertable() {
				return AnovaSignificanceTestResult.class;
			}

			/**
			 * Tests two ANOVA significance test results by comparing all exposed statistics.
			 *
			 * @param message
			 *            message to display if an error occurs
			 * @param expectedObj
			 *            expected value
			 * @param actualObj
			 *            actual value
			 */
			@Override
			public void assertEquals(String message, Object expectedObj, Object actualObj) {
				AnovaSignificanceTestResult expected = (AnovaSignificanceTestResult) expectedObj;
				AnovaSignificanceTestResult actual = (AnovaSignificanceTestResult) actualObj;
				message = message + " - ANOVA significance test result \"" + actual.getSource()
						+ "\" does not match the expected result";

				Assert.assertEquals(message + " : alpha values are not equal", expected.getAlpha(), actual.getAlpha(),
						STRICT_DELTA);
				Assert.assertEquals(message + " : first degrees of freedom are not equal", expected.getDf1(),
						actual.getDf1(), STRICT_DELTA);
				Assert.assertEquals(message + " : second degrees of freedom are not equal", expected.getDf2(),
						actual.getDf2(), STRICT_DELTA);
				Assert.assertEquals(message + " : F-values are not equal", expected.getFValue(), actual.getFValue(),
						STRICT_DELTA);
				Assert.assertEquals(message + " : mean square residual values are not equal",
						expected.getMeanSquaresResiduals(), actual.getMeanSquaresResiduals(), STRICT_DELTA);
				Assert.assertEquals(message + " : mean square between values are not equal",
						expected.getMeanSquaresBetween(), actual.getMeanSquaresBetween(), STRICT_DELTA);
				Assert.assertEquals(message + " : probabilities are not equal", expected.getProbability(),
						actual.getProbability(), STRICT_DELTA);
				Assert.assertEquals(message + " : sum squares between values are not equal",
						expected.getSumSquaresBetween(), actual.getSumSquaresBetween(), STRICT_DELTA);
				Assert.assertEquals(message + " : sum squares residuals values are not equal",
						expected.getSumSquaresResiduals(), actual.getSumSquaresResiduals(), STRICT_DELTA);
			}
		};
	}

	/**
	 * Reads the complete content of a file object into memory. The opened stream is always closed;
	 * a failure while closing is ignored because the data has already been read, or the read
	 * failure is being propagated instead.
	 *
	 * @param fileObject
	 *            the file object to read
	 * @return the bytes of the file object
	 */
	private static byte[] readFully(FileObject fileObject) throws OperatorException, IOException {
		InputStream in = fileObject.openStream();
		try {
			ByteArrayOutputStream buffer = new ByteArrayOutputStream();
			Tools.copyStreamSynchronously(in, buffer, true);
			return buffer.toByteArray();
		} finally {
			if (in != null) {
				try {
					in.close();
				} catch (IOException ignored) {
					// best effort close, see method comment
				}
			}
		}
	}

	/**
	 * Asserts that two doubles are equal within an absolute tolerance of {@value #DEFAULT_DELTA}.
	 *
	 * @param message
	 *            message to display if the values differ
	 * @param expected
	 *            expected value
	 * @param result
	 *            actual value
	 */
	private void assertDouble(String message, double expected, double result) {
		Assert.assertEquals(message, expected, result, DEFAULT_DELTA);
	}
}