/*
 * avenir: Predictive analytic based on Hadoop Map Reduce
 * Author: Pranab Ghosh
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you
 * may not use this file except in compliance with the License. You may
 * obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */

package org.avenir.model;

import org.chombo.util.FeatureSchema;
import org.chombo.util.Pair;

/**
 * Abstract model predictor. Holds the configuration shared by concrete
 * predictors (class labels, optional misclassification costs) and keeps
 * running counters from which overall, false positive and false negative
 * error rates are derived.
 *
 * @author pranab
 */
public abstract class PredictiveModel {
    /** Feature schema supplied through the constructor; not read by this class itself. */
    protected FeatureSchema schema;

    /** Set by {@link #enableErrorCounting}; gates the error rate getters. */
    protected boolean errorCountingEnabled;

    /** Index into {@link #items} at which {@link #countError()} reads the actual class. */
    protected int classAttributeOrd;

    /** Positive class label; a wrong prediction equal to it counts as a false positive. */
    protected String posClass;

    /** Negative class label. */
    protected String negClass;

    /** Predicted class that {@link #countError()} compares with the actual class. */
    protected String predClass;

    /** Set by {@link #enableCostBasedPrediction}. */
    protected boolean costBasedPredictionEnabled;

    /** False positive cost as passed to {@link #enableCostBasedPrediction}; not used in this class. */
    protected double falsePosCost;

    /** False negative cost as passed to {@link #enableCostBasedPrediction}; not used in this class. */
    protected double falseNegCost;

    /** Fields of the record being evaluated; read by {@link #countError()}. */
    protected String[] items;

    /** Not used in this class; presumably the predicted class and its probability - confirm in subclasses. */
    protected Pair<String, Double> predClassProb;

    private int totalCount;
    private int errorCount;
    private int falsePosErrorCount;
    private int falseNegErrorCount;

    /**
     * Creates a model without a schema.
     */
    public PredictiveModel() {
    }

    /**
     * Creates a model with the given schema.
     *
     * @param schema feature schema to store
     */
    public PredictiveModel(FeatureSchema schema) {
        this.schema = schema;
    }

    /**
     * Turns on error counting and records where the actual class is found.
     * The class labels are stored only if no positive class has been set yet.
     *
     * @param classAttributeOrd index of the actual class value in the record fields
     * @param posClass positive class label
     * @param negClass negative class label
     * @return this model, to allow call chaining
     */
    public PredictiveModel enableErrorCounting(int classAttributeOrd, String posClass, String negClass) {
        errorCountingEnabled = true;
        this.classAttributeOrd = classAttributeOrd;
        withClassValues(posClass, negClass);
        return this;
    }

    /**
     * Turns on cost based prediction and stores the misclassification costs.
     * The class labels are stored only if no positive class has been set yet.
     *
     * @param posClass positive class label
     * @param negClass negative class label
     * @param falsePosCost cost of a false positive
     * @param falseNegCost cost of a false negative
     * @return this model, to allow call chaining
     */
    public PredictiveModel enableCostBasedPrediction(String posClass, String negClass,
            double falsePosCost, double falseNegCost) {
        costBasedPredictionEnabled = true;
        withClassValues(posClass, negClass);
        this.falsePosCost = falsePosCost;
        this.falseNegCost = falseNegCost;
        return this;
    }

    /**
     * Stores the class labels unless a positive class is already set, so the
     * first call wins and later calls with different labels are ignored.
     *
     * @param posClass positive class label
     * @param negClass negative class label
     */
    private void withClassValues(String posClass, String negClass) {
        if (null == this.posClass) {
            this.posClass = posClass;
            this.negClass = negClass;
        }
    }

    /**
     * Updates the counters for the current record. The actual class is read
     * from {@code items[classAttributeOrd]} and compared with
     * {@code predClass}; on a mismatch the error is counted as a false
     * positive when the prediction equals {@code posClass}, otherwise as a
     * false negative. Both {@code items} and {@code predClass} must be set
     * before this is called, since they are dereferenced here.
     */
    protected void countError() {
        ++totalCount;
        String actualClass = items[classAttributeOrd];
        if (!actualClass.equals(predClass)) {
            if (predClass.equals(posClass)) {
                ++falsePosErrorCount;
            } else {
                ++falseNegErrorCount;
            }
            ++errorCount;
        }
    }

    /**
     * Predicts the class for a record.
     *
     * @param items fields of the record
     * @return predicted class
     */
    public abstract String predict(String[] items);

    /**
     * Predicts a class together with a numeric value for a record.
     *
     * @param items fields of the record
     * @return pair of class and value; presumably the class probability - confirm in subclasses
     */
    protected abstract Pair<String, Double> predictClassProb(String[] items);

    /**
     * Returns the fraction of counted records that were mispredicted.
     *
     * @return error rate, or 0 if no record has been counted yet
     * @throws IllegalStateException if error counting is not enabled
     */
    public double getError() {
        return errorRate(errorCount);
    }

    /**
     * Returns the fraction of counted records that were false positives.
     *
     * @return false positive rate over all counted records, or 0 if none has been counted yet
     * @throws IllegalStateException if error counting is not enabled
     */
    public double getFalsePosError() {
        return errorRate(falsePosErrorCount);
    }

    /**
     * Returns the fraction of counted records that were false negatives.
     *
     * @return false negative rate over all counted records, or 0 if none has been counted yet
     * @throws IllegalStateException if error counting is not enabled
     */
    public double getFalseNegError() {
        return errorRate(falseNegErrorCount);
    }

    /**
     * Divides a counter by the total number of counted records.
     *
     * @param count numerator, one of the error counters
     * @return {@code count / totalCount}, or 0 when nothing has been counted
     * @throws IllegalStateException if error counting is not enabled
     */
    private double errorRate(int count) {
        if (!errorCountingEnabled) {
            throw new IllegalStateException("error counting is not enabled");
        }
        // guard against 0.0 / 0, which would otherwise yield NaN
        if (totalCount == 0) {
            return 0.0;
        }
        return ((double) count) / totalCount;
    }
}