package quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldTree;
import com.google.common.base.Predicate;
import quickml.data.AttributesMap;
import quickml.data.instances.Instance;
import java.io.IOException;
import java.io.Serializable;
import java.util.Map;
import java.util.Set;
public abstract class OldBranch extends OldNode {
private static final long serialVersionUID = 8290012786245422175L;
public final String attribute;
public OldNode trueChild, falseChild;
private double probabilityOfTrueChild;
public OldBranch(OldNode parent, final String attribute, double probabilityOfTrueChild) {
super(parent);
this.probabilityOfTrueChild = probabilityOfTrueChild;
this.attribute = attribute;
}
public abstract boolean decide(Map<String, Serializable> attributes);
@Override
public int size() {
return 1 + trueChild.size() + falseChild.size();
}
public Predicate<Instance<AttributesMap, Serializable>> getInPredicate() {
return new Predicate<Instance<AttributesMap, Serializable>>() {
@Override
public boolean apply(final Instance<AttributesMap, Serializable> input) {
return decide(input.getAttributes());
}
};
}
public Predicate<Instance<AttributesMap, Serializable>> getOutPredicate() {
return new Predicate<Instance<AttributesMap, Serializable>>() {
@Override
public boolean apply(final Instance<AttributesMap, Serializable> input) {
return !decide(input.getAttributes());
}
};
}
@Override
public OldLeaf getLeaf(final AttributesMap attributes) {
if (decide(attributes))
return trueChild.getLeaf(attributes);
else
return falseChild.getLeaf(attributes);
}
@Override
public double getProbabilityWithoutAttributes(AttributesMap attributes, Serializable classification, Set<String> attributesToIgnore) {
//TODO[mk] - check with Alex
if (attributesToIgnore.contains(this.attribute)) {
return probabilityOfTrueChild * trueChild.getProbabilityWithoutAttributes(attributes, classification, attributesToIgnore) +
(1 - probabilityOfTrueChild) * falseChild.getProbabilityWithoutAttributes(attributes, classification, attributesToIgnore);
} else {
if (decide(attributes)) {
return trueChild.getProbabilityWithoutAttributes(attributes, classification, attributesToIgnore);
}
else {
return falseChild.getProbabilityWithoutAttributes(attributes, classification, attributesToIgnore);
}
}
}
@Override
public void dump(final int indent, final Appendable ap) {
try {
for (int x = 0; x < indent; x++) {
ap.append(' ');
}
ap.append(this+"\n");
trueChild.dump(indent + 2, ap);
for (int x = 0; x < indent; x++) {
ap.append(' ');
}
ap.append(toNotString() +"\n");
falseChild.dump(indent + 2, ap);
}
catch (IOException e) {
throw new RuntimeException();
}
}
public abstract String toNotString();
@Override
protected void calcLeafDepthStats(final LeafDepthStats stats) {
trueChild.calcLeafDepthStats(stats);
falseChild.calcLeafDepthStats(stats);
}
@Override
public boolean equals(final Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
final OldBranch oldBranch = (OldBranch) o;
if (!attribute.equals(oldBranch.attribute)) return false;
if (!falseChild.equals(oldBranch.falseChild)) return false;
if (!trueChild.equals(oldBranch.trueChild)) return false;
return true;
}
@Override
public int hashCode() {
int result = attribute.hashCode();
result = 31 * result + trueChild.hashCode();
result = 31 * result + falseChild.hashCode();
return result;
}
}