/* * RapidMiner * * Copyright (C) 2001-2011 by Rapid-I and the contributors * * Complete list of developers available at our web site: * * http://rapid-i.com * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see http://www.gnu.org/licenses/. */ package com.rapidminer.operator.learner.meta; import java.util.Iterator; import com.rapidminer.example.Attribute; import com.rapidminer.example.Attributes; import com.rapidminer.example.Example; import com.rapidminer.example.ExampleSet; import com.rapidminer.example.table.AttributeFactory; import com.rapidminer.operator.Model; import com.rapidminer.operator.OperatorCapability; import com.rapidminer.operator.OperatorDescription; import com.rapidminer.operator.OperatorException; import com.rapidminer.operator.ProcessSetupError.Severity; import com.rapidminer.operator.ports.InputPort; import com.rapidminer.operator.ports.metadata.AttributeMetaData; import com.rapidminer.operator.ports.metadata.ExampleSetMetaData; import com.rapidminer.operator.ports.metadata.MDReal; import com.rapidminer.operator.ports.metadata.MetaData; import com.rapidminer.operator.ports.metadata.PredictionModelMetaData; import com.rapidminer.operator.ports.metadata.SetRelation; import com.rapidminer.operator.ports.metadata.SimpleMetaDataError; import com.rapidminer.tools.Ontology; import com.rapidminer.tools.math.container.Range; /** * For a classified dataset (with possibly more than two classes) builds a classifier using a regression method which is * specified by the inner operator. For each class {@rapidminer.math i} a regression model is trained after setting the * label to {@rapidminer.math +1} if the label equals {@rapidminer.math i} and to {@rapidminer.math -1} if it is not. * Then the regression models are combined into a classification model. In order to determine the prediction for an * unlabeled example, all models are applied and the class belonging to the regression model which predicts the greatest * value is chosen. * * @author Ingo Mierswa, Simon Fischer */ public class ClassificationByRegression extends AbstractMetaLearner { private int numberOfClasses; public ClassificationByRegression(OperatorDescription description) { super(description); } @Override protected MetaData modifyExampleSetMetaData(ExampleSetMetaData unmodifiedMetaData) { switch (unmodifiedMetaData.hasSpecial(Attributes.LABEL_NAME)) { case NO: getTrainingSetInputPort().addError(new SimpleMetaDataError(Severity.ERROR, getTrainingSetInputPort(), "special_missing", "label")); return unmodifiedMetaData; case UNKNOWN: getTrainingSetInputPort().addError(new SimpleMetaDataError(Severity.WARNING, getTrainingSetInputPort(), "special_unknown", "label")); return unmodifiedMetaData; case YES: AttributeMetaData labelMD = unmodifiedMetaData.getLabelMetaData(); unmodifiedMetaData.removeAttribute(labelMD); AttributeMetaData transformedMD = new AttributeMetaData("regression(" + labelMD.getName() + ")", Ontology.REAL, Attributes.LABEL_NAME); transformedMD.setValueRange(new Range(-1d, 1d), SetRelation.EQUAL); transformedMD.setValueSetRelation(SetRelation.EQUAL); transformedMD.setMean(new MDReal()); unmodifiedMetaData.addAttribute(transformedMD); return unmodifiedMetaData; default: return unmodifiedMetaData; } } /** Transforms the regression label back into a classification label. */ @Override protected MetaData modifyGeneratedModelMetaData(PredictionModelMetaData unmodifiedMetaData) { InputPort in = getTrainingSetInputPort(); MetaData esetIn = in.getMetaData(); if ((esetIn != null) && (esetIn instanceof ExampleSetMetaData)) { return new PredictionModelMetaData(MultiModelByRegression.class, (ExampleSetMetaData) esetIn); } return unmodifiedMetaData; } public Model learn(ExampleSet inputSet) throws OperatorException { Attribute classLabel = inputSet.getAttributes().getLabel(); numberOfClasses = classLabel.getMapping().getValues().size(); Model[] models = new Model[numberOfClasses]; ExampleSet eSet = (ExampleSet) inputSet.clone(); Attribute tempLabel = AttributeFactory.createAttribute("regression(" + classLabel.getName() + ")", Ontology.REAL); eSet.getExampleTable().addAttribute(tempLabel); eSet.getAttributes().setLabel(tempLabel); for (int i = 0; i < numberOfClasses; i++) { // 1. Set regression labels Iterator<Example> r = eSet.iterator(); while (r.hasNext()) { Example e = r.next(); if (e.getValue(classLabel) == i) { e.setValue(tempLabel, +1.0); } else { e.setValue(tempLabel, -1.0); } } // 2. Apply learner models[i] = applyInnerLearner(eSet); inApplyLoop(); } return new MultiModelByRegression(inputSet, models); } @Override public boolean supportsCapability(OperatorCapability lc) { switch (lc) { case NUMERICAL_LABEL: case NO_LABEL: case UPDATABLE: case FORMULA_PROVIDER: return false; default: return true; } } }