/** * Copyright (C) 2001-2017 by RapidMiner and the contributors * * Complete list of developers available at our web site: * * http://rapidminer.com * * This program is free software: you can redistribute it and/or modify it under the terms of the * GNU Affero General Public License as published by the Free Software Foundation, either version 3 * of the License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without * even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License along with this program. * If not, see http://www.gnu.org/licenses/. */ package com.rapidminer.tools.math.function.aggregation; import com.rapidminer.operator.ports.metadata.AttributeMetaData; import com.rapidminer.tools.Ontology; import java.lang.reflect.InvocationTargetException; /** * Superclass for aggregation functions providing some generic functions. * * In comparison to the more specialized functions available in the * {@link com.rapidminer.operator.preprocessing.transformation.aggregation.AggregationFunction} * these functions have a more broader use, but are limited to numerical values. * * @author Tobias Malbrecht * */ public abstract class AbstractAggregationFunction implements AggregationFunction { @SuppressWarnings("unchecked") public static final Class<? extends AggregationFunction>[] KNOWN_AGGREGATION_FUNCTIONS = new Class[] { AverageFunction.class, VarianceFunction.class, StandardDeviationFunction.class, CountFunction.class, MinFunction.class, MaxFunction.class, SumFunction.class, ModeFunction.class, MedianFunction.class, ProductFunction.class }; public static final String[] KNOWN_AGGREGATION_FUNCTION_NAMES = { "average", "variance", "standard_deviation", "count", "minimum", "maximum", "sum", "mode", "median", "product" }; public enum AggregationFunctionType { average, variance, standard_deviation, count, minimum, maximum, sum, mode, median, product, } public static final int AVERAGE = 0; public static final int VARIANCE = 1; public static final int STANDARD_DEVIATION = 2; public static final int COUNT = 3; public static final int MINIMUM = 4; public static final int MAXIMUM = 5; public static final int SUM = 6; public static final int MODE = 7; public static final int MEDIAN = 8; public static final int PRODUCT = 9; public static final boolean DEFAULT_IGNORE_MISSINGS = true; protected boolean ignoreMissings = DEFAULT_IGNORE_MISSINGS; protected boolean foundMissing = false; @SuppressWarnings("unchecked") public static AggregationFunction createAggregationFunction(String functionName, boolean ignoreMissings) throws InstantiationException, IllegalAccessException, ClassNotFoundException, NoSuchMethodException, InvocationTargetException { int typeIndex = -1; for (int i = 0; i < KNOWN_AGGREGATION_FUNCTION_NAMES.length; i++) { if (KNOWN_AGGREGATION_FUNCTION_NAMES[i].equals(functionName)) { typeIndex = i; break; } } Class<? extends AggregationFunction> clazz = null; if (typeIndex < 0) { clazz = (Class<? extends AggregationFunction>) Class.forName(functionName); } else { clazz = KNOWN_AGGREGATION_FUNCTIONS[typeIndex]; } return clazz.getConstructor(Boolean.class).newInstance(ignoreMissings); } public static AggregationFunction createAggregationFunction(String functionName) throws InstantiationException, IllegalAccessException, ClassNotFoundException, NoSuchMethodException, InvocationTargetException { return createAggregationFunction(functionName, true); } public static AggregationFunction createAggregationFunction(int typeIndex, boolean ignoreMissings) throws InstantiationException, IllegalAccessException, NoSuchMethodException, InvocationTargetException { if (typeIndex >= 0 && typeIndex < KNOWN_AGGREGATION_FUNCTION_NAMES.length) { Class<? extends AggregationFunction> clazz = KNOWN_AGGREGATION_FUNCTIONS[typeIndex]; return clazz.getConstructor(new Class[] { Boolean.class }).newInstance(ignoreMissings); } else { throw new InstantiationException(); } } public static AggregationFunction createAggregationFunction(int typeIndex) throws InstantiationException, IllegalAccessException, NoSuchMethodException, InvocationTargetException { return createAggregationFunction(typeIndex, true); } public AbstractAggregationFunction() { this(true); } public AbstractAggregationFunction(Boolean ignoreMissings) { this.ignoreMissings = ignoreMissings; this.foundMissing = false; reset(); } /** * Reset the counters. */ protected abstract void reset(); /** * Resets the counters and computes the aggregation function solely based on the given values. */ @Override public synchronized double calculate(double[] values) { reset(); for (int i = 0; i < values.length; i++) { update(values[i]); } return getValue(); } /** * Resets the counters and computes the aggregation function solely based on the given values * and the given weights. */ @Override public synchronized double calculate(double[] values, double[] weights) { reset(); if (values.length != weights.length) { return Double.NaN; } for (int i = 0; i < values.length; i++) { update(values[i], weights[i]); } return getValue(); } /** * Standard behavior is to return true for all numerical attributes */ @Override public boolean supportsAttribute(AttributeMetaData amd) { return amd.isNumerical(); } /** * Standard behavior is to return true for all numerical value types. */ @Override public boolean supportsValueType(int valueType) { return Ontology.ATTRIBUTE_VALUE_TYPE.isA(valueType, Ontology.NUMERICAL); } /** * Standard behaviour is to return inputType, i.e. same output type as input type. */ @Override public int getValueTypeOfResult(int inputType) { return inputType; } }