/** * Copyright 2009 DFKI GmbH. * All Rights Reserved. Use is subject to license terms. * * This file is part of MARY TTS. * * MARY TTS is free software: you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as published by * the Free Software Foundation, version 3 of the License. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. * */ package marytts.tools.voiceimport.traintrees; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.Future; import marytts.cart.CART; import marytts.cart.DecisionNode; import marytts.cart.DirectedGraph; import marytts.cart.DirectedGraphNode; import marytts.cart.FeatureVectorCART; import marytts.cart.LeafNode; import marytts.cart.Node; import marytts.cart.LeafNode.FeatureVectorLeafNode; import marytts.cart.impose.FeatureArrayIndexer; import marytts.features.FeatureDefinition; import marytts.features.FeatureVector; /** * @author marc * */ public class AgglomerativeClusterer { private static final float SINGLE_ITEM_IMPURITY = 0; private FeatureVector[] trainingFeatures; private FeatureVector[] testFeatures; private Map<LeafNode, Double> impurities = new HashMap<LeafNode, Double>(); private FeatureDefinition featureDefinition; private int numByteFeatures; private int[] availableFeatures; // private double globalMean; private double globalStddev; private DistanceMeasure dist; private double minFSGI, minCriterion; private int iBestFeature; private float[][] squaredDistances; private DirectedGraph graph; private int[] prevFeatureList; private double prevFSGI; private double prevTestDataDistance; private boolean canClusterMore = true; public AgglomerativeClusterer(FeatureVector[] features, FeatureDefinition featureDefinition, List<String> featuresToUse, DistanceMeasure dist) { this(features, featureDefinition, featuresToUse, dist, 0.1f); } public AgglomerativeClusterer(FeatureVector[] features, FeatureDefinition featureDefinition, List<String> featuresToUse, DistanceMeasure dist, float proportionTestData) { // Now replace all feature vectors with feature vectors whose unit index // corresponds to the distance matrix in squaredDistance: for (int i = 0; i < features.length; i++) { features[i] = new FeatureVector(features[i].getByteValuedDiscreteFeatures(), features[i].getShortValuedDiscreteFeatures(), features[i].getContinuousFeatures(), i); } this.dist = dist; this.globalStddev = Math.sqrt(((F0ContourPolynomialDistanceMeasure) dist).computeVariance(features)); System.out.println("Global stddev: " + globalStddev); /* * // Get an estimate of the global mean by sampling: estimateGlobalMean(features, dist); * * // Precompute distances and set unit index features accordingly System.out.println("Precomputing distances..."); long * startTime = System.currentTimeMillis(); squaredDistances = new float[features.length-1][]; for (int i=0; * i<features.length-1; i++) { squaredDistances[i] = new float[features.length - i -1]; for (int j=i+1; j<features.length; * j++) { squaredDistances[i][j-i-1] = dist.squaredDistance(features[i], features[j]); } } * * long endTime = System.currentTimeMillis(); * System.out.println("Computed distances between "+features.length+" items in "+(endTime-startTime)+" ms"); */ int nSkip = (int) (1 / proportionTestData); // we use every nSkip'th feature vector as test data int numTestFeatures = features.length / nSkip; if (numTestFeatures * nSkip < features.length) numTestFeatures++; this.testFeatures = new FeatureVector[numTestFeatures]; this.trainingFeatures = new FeatureVector[features.length - testFeatures.length]; int iTest = 0, iTrain = 0; for (int i = 0; i < features.length; i++) { if (i % nSkip == 0) { testFeatures[iTest++] = features[i]; } else { trainingFeatures[iTrain++] = features[i]; } } this.featureDefinition = featureDefinition; this.numByteFeatures = featureDefinition.getNumberOfByteFeatures(); if (featuresToUse != null) { availableFeatures = new int[featuresToUse.size()]; for (int i = 0; i < availableFeatures.length; i++) { availableFeatures[i] = featureDefinition.getFeatureIndex(featuresToUse.get(i)); } } else { // no features given, use all byte-valued features availableFeatures = new int[numByteFeatures]; for (int i = 0; i < numByteFeatures; i++) { availableFeatures[i] = i; } } graph = new DirectedGraph(featureDefinition); graph.setRootNode(new DirectedGraphNode(null, null)); prevFeatureList = new int[0]; prevFSGI = Double.POSITIVE_INFINITY; prevTestDataDistance = Double.POSITIVE_INFINITY; canClusterMore = true; } public DirectedGraph getGraph() { return graph; } public boolean canClusterMore() { return canClusterMore; } public DirectedGraph cluster() { if (!canClusterMore) return null; long startTime = System.currentTimeMillis(); int[] newFeatureList = new int[prevFeatureList.length + 1]; System.arraycopy(prevFeatureList, 0, newFeatureList, 0, prevFeatureList.length); // Step 1: Feature selection // We look for the feature that yields the best (=lowest) global impurity. // Stop criterion: when the best feature does not substantially add new leaves. FeatureArrayIndexer fai = new FeatureArrayIndexer(trainingFeatures, featureDefinition); // Count previous number of leaves: fai.deepSort(prevFeatureList); CART prevCART = new FeatureVectorCART(fai.getTree(), fai); int prevNLeaves = 0; for (LeafNode leaf : prevCART.getLeafNodes()) { if (leaf != null && leaf.getNumberOfData() > 0) prevNLeaves++; } iBestFeature = -1; minFSGI = Double.POSITIVE_INFINITY; minCriterion = Double.POSITIVE_INFINITY; Set<Future<?>> openJobs = new HashSet<Future<?>>(); // Loop over all unused discrete features, and compute their Global Impurity for (int f = 0; f < availableFeatures.length; f++) { int fi = availableFeatures[f]; boolean featureAlreadyUsed = false; for (int i = 0; i < prevFeatureList.length; i++) { if (prevFeatureList[i] == fi) { featureAlreadyUsed = true; break; } } if (featureAlreadyUsed) continue; newFeatureList[newFeatureList.length - 1] = fi; fai.deepSort(newFeatureList); CART testCART = new FeatureVectorCART(fai.getTree(), fai); assert testCART.getRootNode().getNumberOfData() == trainingFeatures.length; verifyFeatureQuality(fi, testCART, prevNLeaves); } newFeatureList[newFeatureList.length - 1] = iBestFeature; fai.deepSort(newFeatureList); CART bestFeatureCart = new FeatureVectorCART(fai.getTree(), fai); int nLeaves = 0; for (LeafNode leaf : bestFeatureCart.getLeafNodes()) { if (leaf != null && leaf.getNumberOfData() > 0) nLeaves++; } long featSelectedTime = System.currentTimeMillis(); // Now walk through graphSoFar and bestFeatureCart in parallel, // and add the leaves of bestFeatureCart into graphSoFar in order // to enable clustering: Node fNode = bestFeatureCart.getRootNode(); Node gNode = graph.getRootNode(); List<DirectedGraphNode> newLeavesList = new ArrayList<DirectedGraphNode>(); updateGraphFromTree((DecisionNode) fNode, (DirectedGraphNode) gNode, newLeavesList); DirectedGraphNode[] newLeaves = newLeavesList.toArray(new DirectedGraphNode[0]); System.out.printf("Level %2d: %25s (%5d leaves, gi=%7.3f -->", newFeatureList.length, featureDefinition.getFeatureName(iBestFeature), newLeaves.length, minFSGI); float[][] deltaGI = new float[newLeaves.length - 1][]; for (int i = 0; i < newLeaves.length - 1; i++) { deltaGI[i] = new float[newLeaves.length - i - 1]; for (int j = i + 1; j < newLeaves.length; j++) { deltaGI[i][j - i - 1] = (float) computeDeltaGI(newLeaves[i], newLeaves[j]); } } int numLeavesLeft = newLeaves.length; // Now cluster the leaves float minDeltaGI, threshold; int bestPair1, bestPair2; do { // threshold = 100*(float)(Math.log(numLeavesLeft)-Math.log(numLeavesLeft-1)); // threshold = (float)(Math.log(numLeavesLeft)-Math.log(numLeavesLeft-1)); threshold = 0; // threshold = 0.01f; minDeltaGI = threshold; // if we cannot find any that is better, stop. bestPair1 = bestPair2 = -1; for (int i = 0; i < newLeaves.length - 1; i++) { if (newLeaves[i] == null) continue; for (int j = i + 1; j < newLeaves.length; j++) { if (newLeaves[j] == null) continue; if (deltaGI[i][j - i - 1] < minDeltaGI) { bestPair1 = i; bestPair2 = j; minDeltaGI = deltaGI[i][j - i - 1]; } } } // System.out.printf("NumLeavesLeft=%4d, threshold=%f, minDeltaGI=%f\n", numLeavesLeft, threshold, minDeltaGI); if (minDeltaGI < threshold) { // found something to merge mergeLeaves(newLeaves[bestPair1], newLeaves[bestPair2]); numLeavesLeft--; // System.out.println("Merged leaves "+bestPair1+" and "+bestPair2+" (deltaGI: "+minDeltaGI+")"); newLeaves[bestPair2] = null; // Update deltaGI table: for (int i = 0; i < bestPair2; i++) { deltaGI[i][bestPair2 - i - 1] = Float.NaN; } for (int j = bestPair2 + 1; j < newLeaves.length; j++) { deltaGI[bestPair2][j - bestPair2 - 1] = Float.NaN; } for (int i = 0; i < bestPair1; i++) { if (newLeaves[i] != null) deltaGI[i][bestPair1 - i - 1] = (float) computeDeltaGI(newLeaves[i], newLeaves[bestPair1]); } for (int j = bestPair1 + 1; j < newLeaves.length; j++) { if (newLeaves[j] != null) deltaGI[bestPair1][j - bestPair1 - 1] = (float) computeDeltaGI(newLeaves[bestPair1], newLeaves[j]); } } } while (minDeltaGI < threshold); int nLeavesLeft = 0; List<LeafNode> survivors = new ArrayList<LeafNode>(); for (int i = 0; i < newLeaves.length; i++) { if (newLeaves[i] != null) { nLeavesLeft++; survivors.add((LeafNode) ((DirectedGraphNode) newLeaves[i]).getLeafNode()); } } long clusteredTime = System.currentTimeMillis(); System.out.printf("%5d leaves, gi=%7.3f).", nLeavesLeft, computeGlobalImpurity(survivors)); deltaGI = null; impurities.clear(); float testDist = rmsDistanceTestData(graph); System.out.printf(" Distance test data: %5.3f", testDist); System.out.printf(" | fs %5dms, cl %5dms", (featSelectedTime - startTime), (clusteredTime - featSelectedTime)); System.out.println(); // Stop criterion: stop if feature selection does not succeed in reducing global impurity further, // and at the same time, the test data approximation is getting worse. if (minFSGI > prevFSGI && testDist > prevTestDataDistance) { canClusterMore = false; } // Iteration step: prevFeatureList = newFeatureList; prevFSGI = minFSGI; prevTestDataDistance = testDist; return graph; } private void verifyFeatureQuality(int fi, CART testCART, int prevNLeaves) { List<LeafNode> leaves = new ArrayList<LeafNode>(); int nLeaves = 0; for (LeafNode leaf : testCART.getLeafNodes()) { if (leaf.isEmpty()) continue; leaves.add(leaf); nLeaves++; } if (nLeaves <= prevNLeaves) { // this feature adds no leaf return; // will not consider this further } double gi = computeGlobalImpurity(leaves, minCriterion); // More leaves cost a bit: double sizeBias = Math.log((float) nLeaves / prevNLeaves); assert sizeBias > 0; // double sizeBias = (float)nLeaves/prevNLeaves; // assert sizeBias > 1; // System.out.printf("%30s: GI=%.3f bias=%.7f\n", featureDefinition.getFeatureName(fi),gi,sizeBias); double criterion = gi; /* * if (gi > globalMean) { // The best one is the one that can reach a small gi with a small increase in number of leaves * criterion = globalMean + (gi-globalMean) * (1+sizeBias); } else { // leave as is, no size bias } */ if (criterion < minCriterion) { setMinCriterion(criterion); setMinFSGI(gi); setBestFeature(fi); } } /** * Estimate the mean of all *distances* in the training set. * * @param leaves * leaves * @return computeglobalimpurity(leaves, double.Positive_infinity) */ /* * private void estimateGlobalMean(FeatureVector[] data, DistanceMeasure dist) { int sampleSize = 100000; * System.out.println("Estimating global mean by random sampling "+sampleSize+" distances"); long startTime = * System.currentTimeMillis(); // Compute mean and stddev using recurrence relation, attributed by Donald Knuth // (The Art of * Computer Programming, Volume 2: Seminumerical Algorithms, Section 4.2.2) // to B.P. Welford, Technometrics, 4, (1962), * 419-420. // M(1) = x(1), M(k) = M(k-1) + (x(k) - M(k-1))/k // S(1) = 0, S(k) = S(k-1) + (x(k) - M(k-1))*(x(k)-M(k)) // for * 2 <= k <= n, then sigma = sqrt(S(n)/(n-1)) // globalMean = 0; Random random = new Random(); for (int k=1; k<sampleSize; * k++) { int i = random.nextInt(data.length); int j = random.nextInt(data.length); double xk = dist.distance(data[i], * data[j]); double mk = globalMean + (xk - globalMean) / k; globalMean = mk; } //globalMean = Math.sqrt(globalMean); long * endTime = System.currentTimeMillis(); * System.out.println("Computation of "+sampleSize+" distances took "+(endTime-startTime)+" ms"); * System.out.println("Global mean distance = "+globalMean); } */ private double computeGlobalImpurity(List<LeafNode> leaves) { return computeGlobalImpurity(leaves, Double.POSITIVE_INFINITY); } /** * Compute global impurity as the weighted sum of leaf impurities. stop when cutoff value is reached or surpassed. * * @param leaves * leaves * @param cutoff * cutoff * @return gi */ private double computeGlobalImpurity(List<LeafNode> leaves, double cutoff) { cutoff *= trainingFeatures.length; double gi = 0; // Global Impurity measures the average distance of an instance // to the other instances in the same leaf. // Global Impurity is computed as follows: // GI = 1/N * sum(|l| * I(l)), where // N = total number of instances (feature vectors); // |l| = the number of instances in a leaf; // I(l) = the impurity of the leaf. int numLeaves = 0; for (LeafNode leaf : leaves) { if (leaf.isEmpty()) continue; gi += leaf.getNumberOfData() * computeImpurity(leaf); numLeaves++; if (gi >= cutoff) { // too high, stop it // System.out.println("Cutoff exceeded, breaking"); break; } } gi /= trainingFeatures.length; return gi; } private double computeImpurity(LeafNode leaf) { /* * impurities.remove(leaf); double i1 = computeMutualDistanceImpurity(leaf); impurities.remove(leaf); double i2 = * computeVarianceImpurity(leaf); System.out.printf("mdi=%.3f, vi=%.3f\n", i1, i2); return i2; */ // return computeMutualDistanceImpurity(leaf); return computeVarianceImpurity(leaf); } /** * The impurity of a leaf node is computed as follows: I(l) = sqrt( 2/(|l|*(|l|-1)) * sum over all pairs(distance of pair) ), * where |l| = the number of instances in the leaf. * * @param leaf * leaf * @return impurity */ /* * private double computeMutualDistanceImpurity(LeafNode leaf) { if (!(leaf instanceof FeatureVectorLeafNode)) throw new * IllegalArgumentException("Currently only feature vector leaf nodes are supported"); if (impurities.containsKey(leaf)) * return impurities.get(leaf); FeatureVectorLeafNode l = (FeatureVectorLeafNode) leaf; FeatureVector[] fvs = * l.getFeatureVectors(); int[] leafIndices = new int[fvs.length]; for (int i=0; i<fvs.length; i++) { leafIndices[i] = * fvs[i].getUnitIndex(); } int len = fvs.length; double impurity = globalMean * Math.exp(-(len-1)); if (len < 2) return * impurity; double rmsDistance = 0; //System.out.println("Leaf has "+n+" items, computing "+(n*(n-1)/2)+" distances"); for * (int i=0; i<len; i++) { int li = leafIndices[i]; for (int j=i+1; j<len; j++) { int lj = leafIndices[j]; if (li < lj) { * rmsDistance += squaredDistances[li][lj-li-1]; } else if (lj < li) { rmsDistance += squaredDistances[lj][li-lj-1]; } } } * rmsDistance *= 2./(len*(len-1)); rmsDistance = Math.sqrt(rmsDistance); * * //System.out.println("len="+len+", baseI="+impurity+", rmsDist="+rmsDistance); impurity += rmsDistance; * * // Normalise impurity: //impurity -= globalMean; //impurity /= globalStddev; * * * // Penalty for small leaves: //impurity += (float)SINGLE_ITEM_IMPURITY/(len*len); * * impurities.put(leaf, impurity); return impurity; } */ private double computeVarianceImpurity(LeafNode leaf) { if (!(leaf instanceof FeatureVectorLeafNode)) throw new IllegalArgumentException("Currently only feature vector leaf nodes are supported"); if (impurities.containsKey(leaf)) return impurities.get(leaf); FeatureVectorLeafNode l = (FeatureVectorLeafNode) leaf; FeatureVector[] fvs = l.getFeatureVectors(); int[] leafIndices = new int[fvs.length]; for (int i = 0; i < fvs.length; i++) { leafIndices[i] = fvs[i].getUnitIndex(); } int len = fvs.length; double impurity = globalStddev * Math.exp(-(len - 1)); if (len < 2) return impurity; double variance = ((F0ContourPolynomialDistanceMeasure) dist).computeVariance(fvs); impurity += Math.sqrt(variance); impurities.put(leaf, impurity); return impurity; } private double computeDeltaGI(DirectedGraphNode dgn1, DirectedGraphNode dgn2) { // return computeMutualDistanceDeltaGI(dgn1, dgn2); return computeVarianceDeltaGI(dgn1, dgn2); } /** * The delta in global impurity that would be caused by merging the two given leaves is computed as follows. Delta GI = * (|l1|+|l2|) * I(l1 united with l2) - |l1| * I(l1) - |l2| * I(l2) = 1/N*(|l1|+|l2|-1) * (sum of all distances between items * in l1 and items in l2 - |l2| * I(l1) - |l1| * I(l2) ) where N = sum of all |l| = total number of instances in the tree, * |l| = number of instances in the leaf l * * @param dgn1 * dgn1 * @param dgn2 * dgn2 * @return deltaGI */ private double computeMutualDistanceDeltaGI(DirectedGraphNode dgn1, DirectedGraphNode dgn2) { FeatureVectorLeafNode l1 = (FeatureVectorLeafNode) dgn1.getLeafNode(); FeatureVectorLeafNode l2 = (FeatureVectorLeafNode) dgn2.getLeafNode(); FeatureVector[] fv1 = l1.getFeatureVectors(); int[] indices1 = new int[fv1.length]; for (int i = 0; i < fv1.length; i++) { indices1[i] = fv1[i].getUnitIndex(); } FeatureVector[] fv2 = l2.getFeatureVectors(); int[] indices2 = new int[fv2.length]; for (int j = 0; j < fv2.length; j++) { indices2[j] = fv2[j].getUnitIndex(); } double deltaGI = 0; int len1 = l1.getNumberOfData(); int len2 = l2.getNumberOfData(); int len12 = len1 + len2; double imp1 = computeImpurity(l1); double imp2 = computeImpurity(l2); double imp12 = len1 * (len1 - 1) / 2 * imp1 * imp1 + len2 * (len2 - 1) / 2 * imp2 * imp2; // Sum of all distances across leaf boundaries: for (int i = 0; i < fv1.length; i++) { int li = indices1[i]; for (int j = 0; j < fv2.length; j++) { int lj = indices2[j]; if (li < lj) { imp12 += squaredDistances[li][lj - li - 1]; } else if (lj < li) { imp12 += squaredDistances[lj][li - lj - 1]; } } } imp12 *= 2. / (len12 * (len12 - 1)); imp12 = Math.sqrt(imp12); deltaGI = 1. / trainingFeatures.length * (len12 * imp12 - len1 * imp1 - len2 * imp2); // Encourage small leaves to merge: // double sizeEffect = 0.01 * (1./((len1+len2)*(len1+len2)) - 1./(len1*len1) - 1./(len2*len2)); // System.out.println("len1="+len1+", len2="+len2+", sizeEffect="+sizeEffect+", deltaGI="+deltaGI); // deltaGI += sizeEffect; return deltaGI; } private double computeVarianceDeltaGI(DirectedGraphNode dgn1, DirectedGraphNode dgn2) { FeatureVectorLeafNode l1 = (FeatureVectorLeafNode) dgn1.getLeafNode(); FeatureVectorLeafNode l2 = (FeatureVectorLeafNode) dgn2.getLeafNode(); FeatureVector[] fv1 = l1.getFeatureVectors(); int[] indices1 = new int[fv1.length]; for (int i = 0; i < fv1.length; i++) { indices1[i] = fv1[i].getUnitIndex(); } FeatureVector[] fv2 = l2.getFeatureVectors(); int[] indices2 = new int[fv2.length]; for (int j = 0; j < fv2.length; j++) { indices2[j] = fv2[j].getUnitIndex(); } double deltaGI = 0; int len1 = fv1.length; int len2 = fv2.length; double imp1 = computeImpurity(l1); double imp2 = computeImpurity(l2); FeatureVector[] fv12 = new FeatureVector[fv1.length + fv2.length]; System.arraycopy(fv1, 0, fv12, 0, fv1.length); System.arraycopy(fv2, 0, fv12, fv1.length, fv2.length); int len12 = fv12.length; double imp12 = globalStddev * Math.exp(-(len12 - 1)); double variance = ((F0ContourPolynomialDistanceMeasure) dist).computeVariance(fv12); imp12 += Math.sqrt(variance); deltaGI = 1. / trainingFeatures.length * (len12 * imp12 - len1 * imp1 - len2 * imp2); // System.out.printf("deltaGI=%.3f -- I(%d)=%.3f, I(%d)=%.3f => I(%d)=%.3f\n", deltaGI, len1, imp1, len2, imp2, len12, // imp12); return deltaGI; } private void mergeLeaves(DirectedGraphNode dgn1, DirectedGraphNode dgn2) { // Copy all data from dgn2 into dgn1 FeatureVectorLeafNode l1 = (FeatureVectorLeafNode) dgn1.getLeafNode(); FeatureVectorLeafNode l2 = (FeatureVectorLeafNode) dgn2.getLeafNode(); FeatureVector[] fv1 = l1.getFeatureVectors(); FeatureVector[] fv2 = l2.getFeatureVectors(); FeatureVector[] newFV = new FeatureVector[fv1.length + fv2.length]; System.arraycopy(fv1, 0, newFV, 0, fv1.length); System.arraycopy(fv2, 0, newFV, fv1.length, fv2.length); l1.setFeatureVectors(newFV); // then update all mother/daughter relationships Set<Node> dgn2Mothers = new HashSet<Node>(dgn2.getMothers()); for (Node mother : dgn2Mothers) { if (mother instanceof DecisionNode) { DecisionNode dm = (DecisionNode) mother; dm.replaceDaughter(dgn1, dgn2.getNodeIndex(mother)); } else if (mother instanceof DirectedGraphNode) { DirectedGraphNode gm = (DirectedGraphNode) mother; gm.setLeafNode(dgn1); } dgn2.removeMother(mother); } dgn2.setLeafNode(null); l2.setMother(null, 0); // and remove impurity entries: try { impurities.remove(l1); impurities.remove(l2); } catch (NullPointerException e) { e.printStackTrace(); System.err.println("Impurities: " + impurities + ", l1:" + l1 + ", l2:" + l2); } } private void updateGraphFromTree(DecisionNode treeNode, DirectedGraphNode graphNode, List<DirectedGraphNode> newLeaves) { int treeFeatureIndex = treeNode.getFeatureIndex(); int treeNumDaughters = treeNode.getNumberOfDaugthers(); DecisionNode graphDecisionNode = graphNode.getDecisionNode(); if (graphDecisionNode != null) { // Sanity check: the two must be aligned: same feature, same number of children int graphFeatureIndex = graphDecisionNode.getFeatureIndex(); assert treeFeatureIndex == graphFeatureIndex : "Tree indices out of sync!"; assert treeNumDaughters == graphDecisionNode.getNumberOfDaugthers() : "Tree structure out of sync!"; // OK, now recursively call ourselves for all daughters for (int i = 0; i < treeNumDaughters; i++) { // We expect the next tree node to be a decision node (unless it is an empty node), // because the level just above the leaves does not exist in graph yet. Node nextTreeNode = treeNode.getDaughter(i); if (nextTreeNode == null) continue; else if (nextTreeNode instanceof LeafNode) { assert ((LeafNode) nextTreeNode).getNumberOfData() == 0; continue; } assert nextTreeNode instanceof DecisionNode; DirectedGraphNode nextGraphNode = (DirectedGraphNode) graphDecisionNode.getDaughter(i); updateGraphFromTree((DecisionNode) nextTreeNode, nextGraphNode, newLeaves); } } else { // No structure in graph yet which corresponds to tree. // This is what we actually want to do. if (featureDefinition.isByteFeature(treeFeatureIndex)) { graphDecisionNode = new DecisionNode.ByteDecisionNode(treeFeatureIndex, treeNumDaughters, featureDefinition); } else { assert featureDefinition.isShortFeature(treeFeatureIndex) : "Only support byte and short features"; graphDecisionNode = new DecisionNode.ShortDecisionNode(treeFeatureIndex, treeNumDaughters, featureDefinition); } assert treeNumDaughters == graphDecisionNode.getNumberOfDaugthers(); graphNode.setDecisionNode(graphDecisionNode); for (int i = 0; i < treeNumDaughters; i++) { // we expect the next tree node to be a leaf node LeafNode nextTreeNode = (LeafNode) treeNode.getDaughter(i); // Now create the new daughter number i of graphDecisionNode. // It is a DirectedGraphNode containing no decision tree but // a leaf node, which is itself a DirectedGraphNode with no // decision node but a leaf node: if (nextTreeNode != null && nextTreeNode.getNumberOfData() > 0) { DirectedGraphNode daughterLeafNode = new DirectedGraphNode(null, nextTreeNode); DirectedGraphNode daughterNode = new DirectedGraphNode(null, daughterLeafNode); graphDecisionNode.addDaughter(daughterNode); newLeaves.add(daughterLeafNode); } else { graphDecisionNode.addDaughter(null); } } } } private float rmsDistanceTestData(DirectedGraph graph) { // return rmsMutualDistanceTestData(graph); return rmsMeanDistanceTestData(graph); } private float rmsMeanDistanceTestData(DirectedGraph graph) { float avgDist = 0; for (int i = 0; i < testFeatures.length; i++) { int ti = testFeatures[i].getUnitIndex(); FeatureVector[] leafData = (FeatureVector[]) graph.interpret(testFeatures[i]); float[] mean = ((F0ContourPolynomialDistanceMeasure) dist).computeMean(leafData); float oneDist = ((F0ContourPolynomialDistanceMeasure) dist).squaredDistance(testFeatures[i], mean); oneDist = (float) Math.sqrt(oneDist); avgDist += oneDist; } avgDist /= testFeatures.length; return avgDist; } private float rmsMutualDistanceTestData(DirectedGraph graph) { float avgDist = 0; for (int i = 0; i < testFeatures.length; i++) { int ti = testFeatures[i].getUnitIndex(); FeatureVector[] leafData = (FeatureVector[]) graph.interpret(testFeatures[i]); float oneDist = 0; for (int j = 0; j < leafData.length; j++) { int lj = leafData[j].getUnitIndex(); if (ti < lj) { oneDist += squaredDistances[ti][lj - ti - 1]; } else if (lj < ti) { oneDist += squaredDistances[lj][ti - lj - 1]; } } oneDist /= leafData.length; oneDist = (float) Math.sqrt(oneDist); avgDist += oneDist; } avgDist /= testFeatures.length; return avgDist; } private void setMinCriterion(double value) { minCriterion = value; } private void setMinFSGI(double value) { minFSGI = value; } private void setBestFeature(int featureIndex) { iBestFeature = featureIndex; } private void debugOut(DirectedGraph graph) { for (Iterator<Node> it = graph.getNodeIterator(); it.hasNext();) { Node next = it.next(); debugOut(next); } } private void debugOut(CART graph) { Node root = graph.getRootNode(); debugOut(root); } private void debugOut(Node node) { if (node instanceof DirectedGraphNode) debugOut((DirectedGraphNode) node); else if (node instanceof LeafNode) debugOut((LeafNode) node); else debugOut((DecisionNode) node); } private void debugOut(DirectedGraphNode node) { System.out.println("DGN"); if (node.getLeafNode() != null) debugOut(node.getLeafNode()); if (node.getDecisionNode() != null) debugOut(node.getDecisionNode()); } private void debugOut(LeafNode node) { System.out.println("Leaf: " + node.getDecisionPath()); } private void debugOut(DecisionNode node) { System.out.println("DN with " + node.getNumberOfDaugthers() + " daughters: " + node.toString()); for (int i = 0; i < node.getNumberOfDaugthers(); i++) { Node daughter = node.getDaughter(i); if (daughter == null) System.out.println("null"); else debugOut(daughter); } } }