package quickml.supervised.inspection;
import quickml.supervised.tree.decisionTree.nodes.DTCatBranch;
import quickml.supervised.tree.decisionTree.nodes.DTNumBranch;
import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter;
import quickml.supervised.tree.nodes.Node;
import quickml.utlities.SerializationUtility;
import quickml.supervised.tree.decisionTree.DecisionTree;
import quickml.supervised.tree.nodes.NumBranch;
import quickml.supervised.ensembles.randomForest.randomDecisionForest.RandomDecisionForest;
import java.io.*;
import java.util.*;
public class RandomForestDumper {
public void summarizeForest(PrintStream out, RandomDecisionForest randomDecisionForest) {
summarizeModel(out, randomDecisionForest);
}
public void summarizeForest(PrintStream out, String file) {
SerializationUtility<RandomDecisionForest> serializationUtility = new SerializationUtility<>();
RandomDecisionForest randomDecisionForest = serializationUtility.loadObjectFromGZIPFile(file);
summarizeModel(out, randomDecisionForest);
}
public void summarizeModel(PrintStream out, RandomDecisionForest forest) {
List<TreeSummary> summaries = new ArrayList<>();
for (DecisionTree t : forest.decisionTrees) {
TreeSummary summary = new TreeSummary();
summary.summarizeNode(t.root, 0);
summaries.add(summary);
}
TreeSummary summary = new TreeSummary();
for (TreeSummary t : summaries) {
summary.splits += t.splits;
for (AttributeSummary as : t.attributes.values()) {
AttributeSummary fas = summary.attributes.get(as.name);
if (fas == null) {
fas = new AttributeSummary();
fas.name = as.name;
summary.attributes.put(as.name, fas);
}
fas.splitCount+= as.splitCount;
fas.weightedSplitCount +=as.weightedSplitCount;
fas.treeCount++;
for (int i = 0; i < as.depths.length; i++) {
fas.depths[i]+= as.depths[i];
}
}
}
// Output trees, total splits, distinct attributes
out.format("%d trees, %d total splits, %d distinct attributes\n", forest.decisionTrees.size(), summary.splits, summary.attributes.size());
// Get attributes, sort, emit:
// - name, # trees, # splits, depths
List<AttributeSummary> attributes = new ArrayList<>(summary.attributes.values());
Collections.sort(attributes);
for (AttributeSummary s : attributes) {
out.format("%s : %f weightedSplits, %d trees, %d splits\n", s.name, s.weightedSplitCount, s.treeCount, s.splitCount);
out.format(" depths = %s\n", Arrays.toString(s.depths));
}
}
public static class TreeSummary {
private int splits;
private Map<String, AttributeSummary> attributes = new HashMap<>();
private void summarizeNode(Node<ClassificationCounter> node, int currentDepth) {
if (node instanceof DTCatBranch) {
summarizeCategoricalNode((DTCatBranch)node, currentDepth);
}
else if (node instanceof NumBranch) {
summarizeNumericNode((DTNumBranch) node, currentDepth);
}
}
private void addAttribute(String name, int depth) {
AttributeSummary attrSummary = attributes.get(name);
if (attrSummary == null) {
attrSummary = new AttributeSummary();
attrSummary.name = name;
attributes.put(name, attrSummary);
}
attrSummary.splitCount++;
attrSummary.weightedSplitCount = attrSummary.weightedSplitCount + Math.max(0.00000001, 1.0/Math.pow(2, depth));
attrSummary.depths[depth]++;
}
private void summarizeCategoricalNode(DTCatBranch node, int currentDepth) {
splits++;
addAttribute(node.attribute, currentDepth);
summarizeNode(node.getTrueChild(), currentDepth+1);
summarizeNode(node.getFalseChild(), currentDepth+1);
}
private void summarizeNumericNode(DTNumBranch node, int currentDepth) {
splits++;
addAttribute(node.attribute, currentDepth);
summarizeNode(node.getTrueChild(), currentDepth+1);
summarizeNode(node.getFalseChild(), currentDepth + 1);
}
}
private static class AttributeSummary implements Comparable<AttributeSummary> {
private String name;
private int treeCount;
private int splitCount;
private double weightedSplitCount;
private int[] depths= new int[20];
public int compareTo(AttributeSummary other) {
int result = -Double.compare(weightedSplitCount, other.weightedSplitCount);
if (result == 0) {
result = -Integer.compare(treeCount, other.treeCount);
}
if (result == 0) {
result = -Integer.compare(splitCount, other.splitCount);
}
if (result == 0) {
result = name.compareTo(other.name);
}
return result;
}
}
}