/**
* Copyright (C) 2001-2017 by RapidMiner and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapidminer.com
*
* This program is free software: you can redistribute it and/or modify it under the terms of the
* GNU Affero General Public License as published by the Free Software Foundation, either version 3
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
* even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along with this program.
* If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.learner.lazy;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.OperatorProgress;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.error.AttributeNotFoundError;
import com.rapidminer.operator.learner.PredictionModel;
/**
 * This variant of the DefaultModel sets the prediction according to another attribute given during
 * learn time. For every example the value of that source attribute is copied into the predicted
 * label; for nominal labels the confidence of the copied class value is set to 1.
 *
 * @author Sebastian Land
 *
 */
public class AttributeDefaultModel extends PredictionModel {

    private static final long serialVersionUID = 3987661566241516287L;

    /** Number of processed examples between two progress updates. */
    private static final int OPERATOR_PROGRESS_STEPS = 10_000;

    /** Name of the attribute whose value is used as the prediction. */
    private String sourceAttributeName;

    /**
     * Creates a model that predicts the value of the given attribute.
     *
     * @param trainingExampleSet
     *            the example set handed to the super constructor
     * @param sourceAttribute
     *            name of the attribute that is looked up in the example set at prediction time
     */
    protected AttributeDefaultModel(ExampleSet trainingExampleSet, String sourceAttribute) {
        super(trainingExampleSet, null, null);
        this.sourceAttributeName = sourceAttribute;
    }

    /**
     * Copies the value of the source attribute into {@code predictedLabel} for every example.
     *
     * @param exampleSet
     *            the example set to write predictions into; it must contain the source attribute
     * @param predictedLabel
     *            the attribute receiving the prediction
     * @return the same {@code exampleSet} instance that was passed in
     * @throws AttributeNotFoundError
     *             if the example set has no attribute with the source attribute's name
     * @throws UserError
     *             with code 120 if the label of the example set (when present) or the source
     *             attribute is nominal while this model's label is not, or vice versa; with code
     *             969 if, for a nominal label, the nominal mapping of the example set's label
     *             (when present) or of the source attribute differs from the mapping of this
     *             model's label
     * @throws OperatorException
     *             if updating the operator progress fails
     */
    @Override
    public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws OperatorException {
        Attribute sourceAttribute = exampleSet.getAttributes().get(sourceAttributeName);
        if (sourceAttribute == null) {
            throw new AttributeNotFoundError(null, null, sourceAttributeName);
        }

        Attribute label = getLabel();
        // The example set may carry no label at all; the label checks are then skipped instead of
        // failing with a NullPointerException.
        Attribute exampleSetLabel = exampleSet.getAttributes().getLabel();

        if (exampleSetLabel != null) {
            checkSameValueType(label, exampleSetLabel, exampleSetLabel.getName());
        }
        checkSameValueType(label, sourceAttribute, sourceAttributeName);

        OperatorProgress progress = null;
        if (getShowProgress() && getOperator() != null && getOperator().getProgress() != null) {
            progress = getOperator().getProgress();
            progress.setTotal(exampleSet.size());
        }

        final boolean nominalLabel = label.isNominal();
        int progressCounter = 0;
        for (Example example : exampleSet) {
            if (nominalLabel) {
                // NOTE(review): these comparisons look loop-invariant. They are kept per example to
                // preserve the original failure timing (e.g. no error for an empty example set) --
                // confirm before hoisting them out of the loop.
                if (exampleSetLabel != null && !exampleSetLabel.getMapping().equals(label.getMapping())) {
                    throw new UserError(null, 969);
                }
                if (!sourceAttribute.getMapping().equals(label.getMapping())) {
                    throw new UserError(null, 969);
                }
                String classValue = example.getValueAsString(sourceAttribute);
                example.setValue(predictedLabel, classValue);
                example.setConfidence(classValue, 1);
            } else {
                double classValue = example.getValue(sourceAttribute);
                example.setValue(predictedLabel, classValue);
            }
            // Only report every OPERATOR_PROGRESS_STEPS examples.
            if (progress != null && ++progressCounter % OPERATOR_PROGRESS_STEPS == 0) {
                progress.setCompleted(progressCounter);
            }
        }
        return exampleSet;
    }

    /**
     * Ensures that {@code other} is nominal exactly when {@code label} is nominal.
     *
     * @param label
     *            the label of this model
     * @param other
     *            the attribute to compare against the label
     * @param reportedName
     *            the attribute name passed as first argument of the error
     * @throws UserError
     *             with code 120, the reported name, the type of {@code other} and the type of
     *             {@code label} if one of the attributes is nominal and the other is not
     */
    private static void checkSameValueType(Attribute label, Attribute other, String reportedName) throws UserError {
        boolean nominalLabel = label.isNominal();
        if (nominalLabel != other.isNominal()) {
            String otherType = nominalLabel ? "numerical" : "nominal";
            String labelType = nominalLabel ? "nominal" : "numerical";
            throw new UserError(null, 120, reportedName, otherType, labelType);
        }
    }
}