/* * RapidMiner * * Copyright (C) 2001-2011 by Rapid-I and the contributors * * Complete list of developers available at our web site: * * http://rapid-i.com * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see http://www.gnu.org/licenses/. */ package com.rapidminer.operator.meta; import java.util.Collection; import java.util.List; import com.rapidminer.example.Attribute; import com.rapidminer.example.Example; import com.rapidminer.example.ExampleSet; import com.rapidminer.example.set.SplittedExampleSet; import com.rapidminer.operator.IOObject; import com.rapidminer.operator.Model; import com.rapidminer.operator.OperatorCapability; import com.rapidminer.operator.OperatorChain; import com.rapidminer.operator.OperatorDescription; import com.rapidminer.operator.OperatorException; import com.rapidminer.operator.OperatorVersion; import com.rapidminer.operator.ValueDouble; import com.rapidminer.operator.learner.CapabilityCheck; import com.rapidminer.operator.learner.CapabilityProvider; import com.rapidminer.operator.learner.PredictionModel; import com.rapidminer.operator.ports.InputPort; import com.rapidminer.operator.ports.OutputPort; import com.rapidminer.operator.ports.PortPairExtender; import com.rapidminer.operator.ports.metadata.CapabilityPrecondition; import com.rapidminer.operator.ports.metadata.ExampleSetMetaData; import com.rapidminer.operator.ports.metadata.ExampleSetPassThroughRule; import com.rapidminer.operator.ports.metadata.MDInteger; import com.rapidminer.operator.ports.metadata.PassThroughRule; import com.rapidminer.operator.ports.metadata.SetRelation; import com.rapidminer.operator.ports.metadata.SubprocessTransformRule; import com.rapidminer.parameter.ParameterType; import com.rapidminer.parameter.ParameterTypeBoolean; import com.rapidminer.parameter.ParameterTypeCategory; import com.rapidminer.parameter.ParameterTypeInt; import com.rapidminer.parameter.UndefinedParameterError; import com.rapidminer.parameter.conditions.BooleanParameterCondition; import com.rapidminer.tools.RandomGenerator; /** * Operator chain that splits an {@link ExampleSet} into a training and test sets similar to XValidation, but returns * the test set predictions instead of a performance vector. The inner two operators must be a learner returning a * {@link Model} and an operator or operator chain that can apply this model (usually a model applier) * * @author Stefan Rueping, Ingo Mierswa, Sebastian Land */ public class XVPrediction extends OperatorChain implements CapabilityProvider { /** The parameter name for "Number of subsets for the crossvalidation." */ public static final String PARAMETER_NUMBER_OF_VALIDATIONS = "number_of_validations"; /** * The parameter name for "Set the number of validations to the number of examples. If set to true, * number_of_validations is ignored." */ public static final String PARAMETER_LEAVE_ONE_OUT = "leave_one_out"; /** The parameter name for "Defines the sampling type of the cross validation." */ public static final String PARAMETER_SAMPLING_TYPE = "sampling_type"; private int number; private int iteration; private final InputPort exampleSetInput = getInputPorts().createPort("example set", ExampleSet.class); private final OutputPort trainingProcessExampleSource = getSubprocess(0).getInnerSources().createPort("training"); private final InputPort trainingProcessModelSink = getSubprocess(0).getInnerSinks().createPort("model"); // training -> testing private final PortPairExtender throughExtender = new PortPairExtender("through", getSubprocess(0).getInnerSinks(), getSubprocess(1).getInnerSources()); // testing private final OutputPort applyProcessModelSource = getSubprocess(1).getInnerSources().createPort("model"); private final OutputPort applyProcessExampleSource = getSubprocess(1).getInnerSources().createPort("unlabelled data"); private final InputPort applyProcessExampleInnerSink = getSubprocess(1).getInnerSinks().createPort("labelled data"); // output private final OutputPort exampleSetOutput = getOutputPorts().createPort("labelled data"); public XVPrediction(OperatorDescription description) { super(description, "Training", "Model Application"); exampleSetInput.addPrecondition(new CapabilityPrecondition(this, exampleSetInput)); throughExtender.start(); getTransformer().addRule(new ExampleSetPassThroughRule(exampleSetInput, trainingProcessExampleSource, SetRelation.EQUAL) { @Override public ExampleSetMetaData modifyExampleSet(ExampleSetMetaData metaData) throws UndefinedParameterError { try { metaData.setNumberOfExamples(getTrainingSetSize(metaData.getNumberOfExamples())); } catch (UndefinedParameterError e) { } return super.modifyExampleSet(metaData); } }); getTransformer().addRule(new ExampleSetPassThroughRule(exampleSetInput, applyProcessExampleSource, SetRelation.EQUAL) { @Override public ExampleSetMetaData modifyExampleSet(ExampleSetMetaData metaData) throws UndefinedParameterError { try { metaData.setNumberOfExamples(getTestSetSize(metaData.getNumberOfExamples())); } catch (UndefinedParameterError e) { } return super.modifyExampleSet(metaData); } }); getTransformer().addRule(new SubprocessTransformRule(getSubprocess(0))); getTransformer().addRule(new PassThroughRule(trainingProcessModelSink, applyProcessModelSource, false)); getTransformer().addRule(throughExtender.makePassThroughRule()); getTransformer().addRule(new SubprocessTransformRule(getSubprocess(1))); getTransformer().addPassThroughRule(applyProcessExampleInnerSink, exampleSetOutput); addValue(new ValueDouble("iteration", "The number of the current iteration.") { @Override public double getDoubleValue() { return iteration; } }); } @Override public void doWork() throws OperatorException { ExampleSet inputSet = exampleSetInput.getData(); // check capabilities and produce errors if they are not fulfilled CapabilityCheck check = new CapabilityCheck(this, false); check.checkLearnerCapabilities(this, inputSet); if (getParameterAsBoolean(PARAMETER_LEAVE_ONE_OUT)) { number = inputSet.size(); } else { number = getParameterAsInt(PARAMETER_NUMBER_OF_VALIDATIONS); } log("Starting " + number + "-fold cross validation prediction"); // creating predicted label ExampleSet resultSet = (ExampleSet) inputSet.clone(); Attribute predictedLabel = PredictionModel.createPredictedLabel(resultSet, inputSet.getAttributes().getLabel()); Collection<String> predictedLabelValues = null; if (predictedLabel.isNominal()) predictedLabelValues = predictedLabel.getMapping().getValues(); // Split training / test set int samplingType = getParameterAsInt(PARAMETER_SAMPLING_TYPE); SplittedExampleSet splittedSet = new SplittedExampleSet(inputSet, number, samplingType, getParameterAsBoolean(RandomGenerator.PARAMETER_USE_LOCAL_RANDOM_SEED), getParameterAsInt(RandomGenerator.PARAMETER_LOCAL_RANDOM_SEED), getCompatibilityLevel().isAtMost(SplittedExampleSet.VERSION_SAMPLING_CHANGED)); for (iteration = 0; iteration < number; iteration++) { splittedSet.selectAllSubsetsBut(iteration); trainingProcessExampleSource.deliver(splittedSet); getSubprocess(0).execute(); // IOContainer learnResult = getLearner().apply(new IOContainer(new IOObject[] { splittedSet })); splittedSet.selectSingleSubset(iteration); applyProcessExampleSource.deliver((IOObject) splittedSet); throughExtender.passDataThrough(); applyProcessModelSource.deliver(trainingProcessModelSink.getData()); getSubprocess(1).execute(); ExampleSet predictedSet = applyProcessExampleInnerSink.getData(); for (int i = 0; i < splittedSet.size(); i++) { Example predictedExample = predictedSet.getExample(i); // setting label in inputSet Example resultExample = resultSet.getExample(splittedSet.getActualParentIndex(i)); resultExample.setValue(predictedLabel, predictedExample.getPredictedLabel()); if (predictedLabel.isNominal()) { for (String s : predictedLabelValues) { resultExample.setConfidence(s, predictedExample.getConfidence(s)); } } } //PredictionModel.removePredictedLabel(predictedSet); inApplyLoop(); } exampleSetOutput.deliver(resultSet); } protected MDInteger getTestSetSize(MDInteger originalSize) throws UndefinedParameterError { if (getParameterAsBoolean(PARAMETER_LEAVE_ONE_OUT)) { return new MDInteger(1); } else { return originalSize.multiply(1d / getParameterAsDouble(PARAMETER_NUMBER_OF_VALIDATIONS)); } } protected MDInteger getTrainingSetSize(MDInteger originalSize) throws UndefinedParameterError { if (getParameterAsBoolean(PARAMETER_LEAVE_ONE_OUT)) { return originalSize.add(-1); } else { return originalSize.multiply(1d - 1d / getParameterAsDouble(PARAMETER_NUMBER_OF_VALIDATIONS)); } } @Override public List<ParameterType> getParameterTypes() { List<ParameterType> types = super.getParameterTypes(); types.add(new ParameterTypeBoolean(PARAMETER_LEAVE_ONE_OUT, "Set the number of validations to the number of examples. If set to true, number_of_validations is ignored.", false, false)); ParameterType type = new ParameterTypeInt(PARAMETER_NUMBER_OF_VALIDATIONS, "Number of subsets for the crossvalidation.", 2, Integer.MAX_VALUE, 10, false); type.registerDependencyCondition(new BooleanParameterCondition(this, PARAMETER_LEAVE_ONE_OUT, false, false)); types.add(type); types.add(new ParameterTypeCategory(PARAMETER_SAMPLING_TYPE, "Defines the sampling type of the cross validation.", SplittedExampleSet.SAMPLING_NAMES, SplittedExampleSet.STRATIFIED_SAMPLING, false)); types.addAll(RandomGenerator.getRandomGeneratorParameters(this)); return types; } @Override public boolean supportsCapability(OperatorCapability capability) { switch (capability) { case NO_LABEL: return false; case NUMERICAL_LABEL: try { return getParameterAsInt(PARAMETER_SAMPLING_TYPE) != SplittedExampleSet.STRATIFIED_SAMPLING; } catch (UndefinedParameterError e) { return false; } default: return true; } } @Override public OperatorVersion[] getIncompatibleVersionChanges() { return new OperatorVersion[] { SplittedExampleSet.VERSION_SAMPLING_CHANGED }; } }