/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.ga.watchmaker.cd;
import java.util.Random;
import com.google.common.base.Preconditions;
import org.uncommons.maths.binary.BitString;
/**
* Binary classification rule of the form:
*
* <pre>
* if (condition1 && condition2 && ... ) then
* class = 1
* else
* class = 0
* </pre>
*
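* where conditioni = (wi): attributei oi vi <br>
* <ul>
* <li>wi is the weight of the condition: <br>
* {@code if (wi < a given threshold) then conditioni is not taken into consideration}
* </li>
* <li>oi is an operator: {@code '<'} or {@code '>='} for numerical attributes,
* {@code '!='} or {@code '=='} for categorical attributes</li>
* <li>vi is the value the attribute is compared against</li>
* </ul>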
*/
public final class CDRule implements Rule {

  /** Conditions whose weight is below this value are inactive (always satisfied). */
  private final double threshold;

  /** Number of conditions: one per attribute of the dataset, label excluded. */
  private final int nbConditions;

  /** Weight of each condition, indexed by condition index. */
  private final double[] weights;

  /** Operator of each condition, one bit per condition (see {@link #getO(int)}). */
  private final BitString operators;

  /** Comparison value of each condition, indexed by condition index. */
  private final double[] values;

  /**
   * Builds a rule with one condition per non-label attribute of the current
   * {@link DataSet}. All weights and values start at 0 (freshly allocated arrays).
   *
   * @param threshold
   *          condition activation threshold, must be in [0,1]
   * @throws IllegalArgumentException
   *           if {@code threshold} is outside [0,1] (NaN included)
   */
  public CDRule(double threshold) {
    // a condition is active only when its weight reaches the threshold
    Preconditions.checkArgument(threshold >= 0.0 && threshold <= 1.0, "Threshold must be in [0,1]");
    this.threshold = threshold;

    // the label is not included in the conditions
    this.nbConditions = DataSet.getDataSet().getNbAttributes() - 1;

    weights = new double[nbConditions];
    operators = new BitString(nbConditions);
    values = new double[nbConditions];
  }

  /**
   * Random rule: every condition gets a random weight, a random operator and a
   * random value drawn according to the type of its attribute.
   *
   * @param threshold
   *          condition activation threshold, must be in [0,1]
   * @param rng
   *          source of randomness
   */
  public CDRule(double threshold, Random rng) {
    this(threshold);

    DataSet dataset = DataSet.getDataSet();

    for (int condInd = 0; condInd < nbConditions; condInd++) {
      int attrInd = attributeIndex(condInd);

      setW(condInd, rng.nextDouble());
      setO(condInd, rng.nextBoolean());
      if (dataset.isNumerical(attrInd)) {
        setV(condInd, randomNumerical(dataset, attrInd, rng));
      } else {
        setV(condInd, randomCategorical(dataset, attrInd, rng));
      }
    }
  }

  /**
   * Copy Constructor. The weights, operators and values are cloned, so the new
   * rule can be modified without affecting {@code ind}.
   *
   * @param ind
   *          rule to copy
   */
  public CDRule(CDRule ind) {
    threshold = ind.threshold;
    nbConditions = ind.nbConditions;

    weights = ind.weights.clone();
    operators = ind.operators.clone();
    values = ind.values.clone();
  }

  /**
   * Draws a value between the dataset's min and max for the given attribute.
   */
  private static double randomNumerical(DataSet dataset, int attrInd, Random rng) {
    double max = dataset.getMax(attrInd);
    double min = dataset.getMin(attrInd);
    return rng.nextDouble() * (max - min) + min;
  }

  /**
   * Draws an integer in {@code [0, dataset.getNbValues(attrInd))}, returned as a double.
   */
  private static double randomCategorical(DataSet dataset, int attrInd, Random rng) {
    int nbcategories = dataset.getNbValues(attrInd);
    return rng.nextInt(nbcategories);
  }

  /**
   * if all the active conditions are met returns 1, else returns 0.
   *
   * @param dl
   *          data line to classify
   * @return 1 if no condition rejects {@code dl}, 0 otherwise
   */
  @Override
  public int classify(DataLine dl) {
    for (int condInd = 0; condInd < nbConditions; condInd++) {
      if (!condition(condInd, dl)) {
        return 0;
      }
    }
    return 1;
  }

  /**
   * Makes sure that the label is not handled by any condition: condition
   * indices at or after the label position are shifted by one.
   *
   * @param condInd
   *          condition index
   * @return attribute index
   */
  public static int attributeIndex(int condInd) {
    int labelpos = DataSet.getDataSet().getLabelIndex();
    return condInd < labelpos ? condInd : condInd + 1;
  }

  /**
   * Returns the value of the condition.
   *
   * @param condInd
   *          index of the condition
   * @param dl
   *          data line the condition is evaluated on
   * @return true if the condition is inactive (weight below the threshold) or
   *         if it is satisfied by {@code dl}
   */
  boolean condition(int condInd, DataLine dl) {
    // an inactive condition never rejects a data line
    if (getW(condInd) < threshold) {
      return true;
    }

    // the attribute type is only needed for active conditions
    int attrInd = attributeIndex(condInd);
    return DataSet.getDataSet().isNumerical(attrInd)
        ? numericalCondition(condInd, dl)
        : categoricalCondition(condInd, dl);
  }

  /**
   * Evaluates the condition as {@code attribute >= value} when the operator bit
   * is set, {@code attribute < value} otherwise.
   */
  boolean numericalCondition(int condInd, DataLine dl) {
    double attribute = dl.getAttribute(attributeIndex(condInd));
    double value = getV(condInd);
    return getO(condInd) ? attribute >= value : attribute < value;
  }

  /**
   * Evaluates the condition as {@code attribute == value} when the operator bit
   * is set, {@code attribute != value} otherwise.
   */
  boolean categoricalCondition(int condInd, DataLine dl) {
    double attribute = dl.getAttribute(attributeIndex(condInd));
    double value = getV(condInd);
    return getO(condInd) ? attribute == value : attribute != value;
  }

  /**
   * Lists the active conditions, joined by {@code " && "}. The operator is
   * printed the way it is evaluated: {@code >=} / {@code <} for numerical
   * attributes, {@code ==} / {@code !=} for categorical ones.
   */
  @Override
  public String toString() {
    DataSet dataset = DataSet.getDataSet();
    StringBuilder buffer = new StringBuilder();

    buffer.append("CDRule = [");
    boolean empty = true;
    for (int condInd = 0; condInd < nbConditions; condInd++) {
      if (getW(condInd) < threshold) {
        continue; // inactive conditions are not part of the rule
      }
      if (!empty) {
        buffer.append(" && ");
      }

      int attrInd = attributeIndex(condInd);
      buffer.append("attr").append(attrInd).append(' ');
      buffer.append(operatorSymbol(dataset.isNumerical(attrInd), getO(condInd)));
      buffer.append(' ').append(getV(condInd));

      empty = false;
    }
    buffer.append(']');

    return buffer.toString();
  }

  /**
   * Textual form of an operator bit, matching {@link #numericalCondition} and
   * {@link #categoricalCondition}.
   */
  private static String operatorSymbol(boolean numerical, boolean operator) {
    if (numerical) {
      return operator ? ">=" : "<";
    }
    return operator ? "==" : "!=";
  }

  /**
   * @return number of conditions of this rule
   */
  public int getNbConditions() {
    return nbConditions;
  }

  /**
   * @return weight of the condition at {@code index}
   */
  public double getW(int index) {
    return weights[index];
  }

  /**
   * Sets the weight of the condition at {@code index}.
   */
  public void setW(int index, double w) {
    weights[index] = w;
  }

  /**
   * operator
   *
   * @return true if '&gt;=' (numerical) or '==' (categorical); false if '&lt;'
   *         (numerical) or '!=' (categorical)
   */
  public boolean getO(int index) {
    return operators.getBit(index);
  }

  /**
   * set the operator
   *
   * @param o
   *          true if '&gt;=' (numerical) or '==' (categorical); false if '&lt;'
   *          (numerical) or '!=' (categorical)
   */
  public void setO(int index, boolean o) {
    operators.setBit(index, o);
  }

  /**
   * @return comparison value of the condition at {@code index}
   */
  public double getV(int index) {
    return values[index];
  }

  /**
   * Sets the comparison value of the condition at {@code index}.
   */
  public void setV(int index, double v) {
    values[index] = v;
  }

  /**
   * Two rules are equal when they have the same number of conditions and every
   * gene (weight, operator, value) is equal. The threshold is not compared.
   */
  @Override
  public boolean equals(Object obj) {
    if (this == obj) {
      return true;
    }
    if (!(obj instanceof CDRule)) {
      return false;
    }
    CDRule rule = (CDRule) obj;
    // without this check a shorter rule would cause an out-of-bounds access and
    // a longer one could compare equal on its prefix only
    if (nbConditions != rule.nbConditions) {
      return false;
    }
    for (int index = 0; index < nbConditions; index++) {
      if (!areGenesEqual(this, rule, index)) {
        return false;
      }
    }
    return true;
  }

  /**
   * Hash code consistent with {@link #equals(Object)}: built from the weight,
   * operator and value of every condition.
   */
  @Override
  public int hashCode() {
    int value = 0;
    for (int index = 0; index < nbConditions; index++) {
      // each component is folded separately in int arithmetic; mixing long and
      // double operands in a single "+=" would narrow a double back to int and
      // saturate at Integer.MAX_VALUE
      value = 31 * value + hashDouble(getW(index));
      value = 31 * value + (getO(index) ? 1 : 0);
      value = 31 * value + hashDouble(getV(index));
    }
    return value;
  }

  /**
   * Hashes a double so that values equal under {@code ==} hash alike.
   */
  private static int hashDouble(double d) {
    // -0.0 == 0.0 for areGenesEqual, but their bit patterns differ: normalize
    long bits = Double.doubleToLongBits(d == 0.0 ? 0.0 : d);
    return (int) (bits ^ (bits >>> 32));
  }

  /**
   * Compares a given gene between two rules
   *
   * @param rule1
   *          first rule
   * @param rule2
   *          second rule
   * @param index
   *          gene index
   * @return true if the gene is the same
   */
  public static boolean areGenesEqual(CDRule rule1, CDRule rule2, int index) {
    return rule1.getW(index) == rule2.getW(index)
        && rule1.getO(index) == rule2.getO(index)
        && rule1.getV(index) == rule2.getV(index);
  }

  /**
   * Compares two genes from this Rule
   *
   * @param index1
   *          first gene index
   * @param index2
   *          second gene index
   * @return if the genes are equal
   */
  public boolean areGenesEqual(int index1, int index2) {
    return getW(index1) == getW(index2)
        && getO(index1) == getO(index2)
        && getV(index1) == getV(index2);
  }
}