/*
* This file is part of ADDIS (Aggregate Data Drug Information System).
* ADDIS is distributed from http://drugis.org/.
* Copyright © 2009 Gert van Valkenhoef, Tommi Tervonen.
* Copyright © 2010 Gert van Valkenhoef, Tommi Tervonen, Tijs Zwinkels,
* Maarten Jacobs, Hanno Koeslag, Florin Schimbinschi, Ahmad Kamal, Daniel
* Reid.
* Copyright © 2011 Gert van Valkenhoef, Ahmad Kamal, Daniel Reid, Florin
* Schimbinschi.
* Copyright © 2012 Gert van Valkenhoef, Daniel Reid, Joël Kuiper, Wouter
* Reckman.
* Copyright © 2013 Gert van Valkenhoef, Joël Kuiper.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.drugis.addis.entities.treatment;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import org.apache.commons.lang.StringUtils;
import org.drugis.common.EqualsUtil;
import edu.uci.ics.jung.graph.DelegateTree;
import edu.uci.ics.jung.graph.DirectedGraph;
import edu.uci.ics.jung.graph.DirectedSparseGraph;
import edu.uci.ics.jung.graph.FixedObservableGraph;
import edu.uci.ics.jung.graph.Graph;
import edu.uci.ics.jung.graph.ObservableGraph;
import edu.uci.ics.jung.graph.util.Pair;
/**
 * A decision tree that classifies objects. Classification starts at the root. Every
 * {@link ChoiceNode} extracts a value from the object, and the first out-edge whose
 * {@link DecisionTreeEdge#decide} accepts that value is followed, until a {@link LeafNode}
 * is reached.
 *
 * The tree is stored in an {@link ObservableDirectedGraph} that wraps a
 * {@link DirectedSparseGraph}; it is exposed through {@link #getObservableGraph()}.
 */
public class DecisionTree extends DelegateTree<DecisionTreeNode, DecisionTreeEdge> {
	private static final long serialVersionUID = -2669529780972041770L;

	/**
	 * An observable graph that also implements {@link DirectedGraph}, so that it can be used
	 * as the delegate of a {@link DelegateTree}.
	 */
	public static class ObservableDirectedGraph<V, E> extends FixedObservableGraph<V, E> implements DirectedGraph<V, E> {
		private static final long serialVersionUID = 442135818546886998L;

		public ObservableDirectedGraph(final Graph<V, E> delegate) {
			super(delegate);
		}

		/**
		 * Returns a copy of the successors of the given vertex instead of the delegate's own
		 * collection. Returns null if the delegate returns null.
		 * NOTE(review): the copy presumably lets callers modify the graph while iterating the
		 * result — confirm.
		 */
		@Override
		public Collection<V> getSuccessors(final V vertex) {
			final Collection<V> successors = super.getSuccessors(vertex);
			// Guard: a JUNG graph may report null for a vertex it does not contain.
			return successors == null ? null : new ArrayList<V>(successors);
		}
	}

	/**
	 * Creates a tree that initially consists of only the given root node.
	 * @param rootNode The root of the tree.
	 */
	public DecisionTree(final DecisionTreeNode rootNode) {
		super(new ObservableDirectedGraph<DecisionTreeNode, DecisionTreeEdge>(new DirectedSparseGraph<DecisionTreeNode, DecisionTreeEdge>()));
		setRoot(rootNode);
	}

	/**
	 * @return The observable graph backing this tree; it can be observed for changes.
	 */
	public ObservableGraph<DecisionTreeNode, DecisionTreeEdge> getObservableGraph() {
		return (ObservableGraph<DecisionTreeNode, DecisionTreeEdge>) delegate;
	}

	/**
	 * Classify the given object.
	 * @param obj The object to classify.
	 * @return The leaf node the object ends up in.
	 * @throws IllegalStateException If the tree is incomplete, or the obj is of an incompatible type.
	 */
	public LeafNode decide(final Object obj) {
		return decide(obj, getRoot());
	}

	/**
	 * Walk down from {@code parent}, following the matching out-edge of every choice node,
	 * until a leaf node is reached.
	 */
	private LeafNode decide(final Object obj, final DecisionTreeNode parent) {
		DecisionTreeNode node = parent;
		while (node instanceof ChoiceNode) {
			final Object value = ((ChoiceNode) node).getValue(obj);
			final DecisionTreeEdge e = findMatchingEdge(node, value);
			if (e == null) {
				throw new IllegalStateException("Object " + obj + " could not be classified");
			}
			node = getEdgeTarget(e);
		}
		if (node instanceof LeafNode) {
			return (LeafNode) node;
		}
		// Neither a choice nor a leaf: the tree is incomplete or malformed. Report this as the
		// documented IllegalStateException rather than a ClassCastException.
		throw new IllegalStateException("Object " + obj + " could not be classified: reached non-leaf node " + node);
	}

	/**
	 * @param e An edge.
	 * @return The node the edge points to, or null if the edge is not part of this tree.
	 */
	public DecisionTreeNode getEdgeTarget(final DecisionTreeEdge e) {
		return containsEdge(e) ? new Pair<DecisionTreeNode>(getIncidentVertices(e)).getSecond() : null;
	}

	/**
	 * @param e An edge.
	 * @return The node the edge originates from, or null if the edge is not part of this tree.
	 */
	public DecisionTreeNode getEdgeSource(final DecisionTreeEdge e) {
		return containsEdge(e) ? new Pair<DecisionTreeNode>(getIncidentVertices(e)).getFirst() : null;
	}

	/**
	 * Find the first out-edge of {@code parent} that accepts the given value.
	 * @param parent The node whose out-edges are inspected.
	 * @param value The value to test against each edge.
	 * @return The first matching edge (in out-edge iteration order), or null if none matches.
	 */
	public DecisionTreeEdge findMatchingEdge(final DecisionTreeNode parent, final Object value) {
		final Collection<DecisionTreeEdge> outEdges = getOutEdges(parent);
		if (outEdges == null) { // defensive: the graph may return null for unknown vertices
			return null;
		}
		for (final DecisionTreeEdge e : outEdges) {
			if (e.decide(value)) {
				return e;
			}
		}
		return null;
	}

	/**
	 * Replace the subtree that {@code edge} points to with {@code newChild}, reusing the same
	 * edge object to connect the new child to the original parent.
	 * @param edge An edge that is part of this tree.
	 * @param newChild The node that replaces the edge's current target (and its subtree).
	 * @throws IllegalArgumentException If the edge is not part of this tree.
	 */
	public void replaceChild(final DecisionTreeEdge edge, final DecisionTreeNode newChild) {
		if (!containsEdge(edge)) {
			throw new IllegalArgumentException("Tree does not contain edge " + edge);
		}
		final DecisionTreeNode parent = getEdgeSource(edge);
		removeChild(getEdgeTarget(edge));
		addChild(edge, parent, newChild);
	}

	/**
	 * Determine whether the given tree is structurally equivalent to this one. Two trees are
	 * equivalent when their roots are equivalent and every node's out-edges can be paired
	 * one-to-one with equivalent edges whose targets are, recursively, equivalent.
	 * @param obj The tree to compare with; null is never equivalent.
	 */
	public boolean equivalent(final DecisionTree obj) {
		if (obj == null) {
			return false;
		}
		return equivalent(getRoot(), obj.getRoot(), this, obj);
	}

	/**
	 * Recursively compare node {@code n1} of tree {@code t1} with node {@code n2} of tree {@code t2}.
	 */
	private static boolean equivalent(final DecisionTreeNode n1, final DecisionTreeNode n2, final DecisionTree t1, final DecisionTree t2) {
		if (!n1.equivalent(n2)) {
			return false;
		}
		final Collection<DecisionTreeEdge> n1Edges = t1.getOutEdges(n1);
		// Work on a copy so each edge of n2 can be matched at most once.
		final List<DecisionTreeEdge> unmatched = new ArrayList<DecisionTreeEdge>(t2.getOutEdges(n2));
		if (n1Edges.size() != unmatched.size()) {
			return false; // different number of children: not equivalent
		}
		for (final DecisionTreeEdge e1 : n1Edges) {
			final DecisionTreeEdge e2 = containsEquivalent(unmatched, e1);
			if (e2 == null || !equivalent(t1.getEdgeTarget(e1), t2.getEdgeTarget(e2), t1, t2)) {
				return false;
			}
			unmatched.remove(e2);
		}
		return true;
	}

	/**
	 * @return The first edge in {@code list} that is equivalent to {@code edge}, or null if none is.
	 */
	private static DecisionTreeEdge containsEquivalent(final Collection<DecisionTreeEdge> list, final DecisionTreeEdge edge) {
		for (final DecisionTreeEdge e2 : list) {
			if (e2.equivalent(edge)) {
				return e2;
			}
		}
		return null;
	}

	/**
	 * Describe the given category in terms of the decisions that lead to it.
	 *
	 * Each leaf node assigned to the category contributes one path label. Labels are sorted;
	 * when there are two or more, each is wrapped in parentheses and they are joined with
	 * " OR ".
	 * @param category The category to describe.
	 * @return The label; an empty string if no leaf node maps to the category.
	 */
	public String getLabel(final Category category) {
		final List<LeafNode> leafs = findLeafNodes(category);
		final List<String> labels = new ArrayList<String>();
		for (final LeafNode leaf : leafs) {
			labels.add(getLabel(leaf));
		}
		Collections.sort(labels);
		if (labels.size() < 2) {
			return StringUtils.join(labels, ") OR (");
		} else {
			return "(" + StringUtils.join(labels, ") OR (") + ")";
		}
	}

	/**
	 * Label the path from the root down to {@code leaf}. Each step is rendered as
	 * "{@code <parent> <edge>}", and the steps are joined with " AND " in root-to-leaf order.
	 * The root itself has no incoming steps and gets the empty label.
	 */
	private String getLabel(final DecisionTreeNode leaf) {
		final List<String> labels = new ArrayList<String>();
		DecisionTreeNode node = leaf;
		DecisionTreeNode parent = getParent(node);
		// Iterate upwards until a node without a parent (the root) is reached. This avoids the
		// unbounded recursion the previous implementation entered when leaf was the root.
		while (parent != null) {
			labels.add(parent + " " + getParentEdge(node));
			node = parent;
			parent = getParent(node);
		}
		Collections.reverse(labels); // collected leaf-to-root; present root-to-leaf
		return StringUtils.join(labels, " AND ");
	}

	/**
	 * @return All leaf nodes of this tree whose category equals the given one (null-safe comparison).
	 */
	private List<LeafNode> findLeafNodes(final Category category) {
		final List<LeafNode> leafs = new ArrayList<LeafNode>();
		for (final DecisionTreeNode node : getVertices()) {
			if (node instanceof LeafNode) {
				final LeafNode leafNode = (LeafNode) node;
				if (EqualsUtil.equal(leafNode.getCategory(), category)) {
					leafs.add(leafNode);
				}
			}
		}
		return leafs;
	}
}