/** * Copyright (c) 2015 Lemur Consulting Ltd. * <p/> * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * <p/> * http://www.apache.org/licenses/LICENSE-2.0 * <p/> * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package uk.co.flax.biosolr.pruning; import java.util.Collection; import java.util.Comparator; import java.util.Iterator; import java.util.LinkedList; import java.util.SortedSet; import java.util.TreeSet; import java.util.stream.Collectors; import uk.co.flax.biosolr.TreeFacetField; /** * Prune a facet hierarchy tree into its most significant data points, * with all other points grouped into "other". 
*
 * <p>Pruning works in passes with a steadily decreasing hit-count threshold.
 * Each pass extracts every node whose own count meets the threshold. Passes
 * stop once at least {@code datapoints} nodes have been collected, the incoming
 * trees are exhausted, or the threshold drops to zero. The remaining nodes are
 * grouped under a single node built from {@code moreLabel}.
 *
 * @author mlp
 */
public class DatapointPruner implements Pruner {

    public static final String DEFAULT_MORE_LABEL = "Others";

    /** Maximum number of extracted nodes returned, excluding the "other" node. */
    private final int datapoints;
    /** Label used when building the node that groups the remaining entries. */
    private final String moreLabel;

    /**
     * @param datapoints the maximum number of extracted nodes to return, not
     * counting the "other" node.
     * @param moreLabel the label for the node grouping all remaining entries.
     */
    public DatapointPruner(int datapoints, String moreLabel) {
        this.datapoints = datapoints;
        this.moreLabel = moreLabel;
    }

    /**
     * Prune the given trees down to at most {@code datapoints} significant
     * nodes, plus an "other" node containing the rest.
     *
     * <p>Both phases of the algorithm work on clones of {@code unprunedTrees},
     * so the caller's nodes are left untouched. This assumes
     * {@link TreeFacetField#clone()} copies the child hierarchy too, which the
     * extraction phase already relies on.
     *
     * @param unprunedTrees the facet trees to prune.
     * @return the extracted nodes, ordered by descending count (ties broken by
     * reversed value order), plus the "other" node when its total is non-zero.
     */
    @Override
    public Collection<TreeFacetField> prune(Collection<TreeFacetField> unprunedTrees) {
        Collection<TreeFacetField> prunedTrees = new TreeSet<>(Comparator.comparingLong(TreeFacetField::getCount)
                .thenComparing(TreeFacetField::getValue).reversed());

        // Work on a clone: extraction detaches nodes from their parents, and the
        // original structure is needed again below to build the "other" node.
        Collection<TreeFacetField> incoming = cloneTrees(unprunedTrees);
        long total = getNodeTotal(incoming);

        int itCount = 1;
        int prevCount = Integer.MAX_VALUE;
        while (prunedTrees.size() < datapoints && !incoming.isEmpty()) {
            int minCount = getThreshold(itCount, prevCount, total);
            if (minCount <= 0) {
                break;
            }
            prunedTrees.addAll(getNodesWithCount(incoming, minCount));
            itCount++;
            prevCount = minCount;
        }

        /* Trim the pruned trees list to the number of datapoints.
         * This leaves the incoming list copy potentially missing nodes which
         * should be in the "other" node. Since they could be anywhere, we
         * have to rebuild it from scratch. */
        if (prunedTrees.size() > datapoints) {
            // NOTE(review): a child extracted in the same pass as its parent can
            // have a lower count than the parent. If this limit drops the child
            // but keeps the parent, the child has already been detached from the
            // parent's clone. trimIncomingNodes() then removes the parent's whole
            // original subtree, so the child's hits appear to end up in neither
            // the result nor "other". Confirm against TreeFacetField.equals().
            prunedTrees = prunedTrees.stream().limit(datapoints).collect(Collectors.toList());
        }

        // Rebuild the incoming node set from a fresh clone. trimIncomingNodes
        // removes nested children and recalculates child counts, which would
        // otherwise modify the caller's trees.
        incoming = cloneTrees(unprunedTrees);
        // ...and strip the nodes already extracted to the pruned list
        trimIncomingNodes(incoming, prunedTrees, 0);

        // Build the "other" node
        TreeFacetField otherNode = buildOtherNode(incoming);
        if (otherNode.getTotal() > 0) {
            prunedTrees.add(otherNode);
        }

        return prunedTrees;
    }

    /**
     * Clone each tree in a collection into a new, mutable list.
     * @param trees the trees to clone.
     * @return a new list holding a clone of every tree, in iteration order.
     */
    private static Collection<TreeFacetField> cloneTrees(Collection<TreeFacetField> trees) {
        return trees.stream()
                .map(TreeFacetField::clone)
                .collect(Collectors.toCollection(LinkedList::new));
    }

    /**
     * Calculate the minimum hit count for one extraction pass.
     * @param iteration the 1-based pass number.
     * @param previous the threshold used by the previous pass
     * ({@code Integer.MAX_VALUE} before the first pass).
     * @param total the total hit count across all incoming trees.
     * @return {@code total / datapoints / iteration} (integer division), capped
     * to be strictly below {@code previous}. On the first pass a result of 0 is
     * raised to 1. A value of 0 or less tells the caller to stop.
     */
    private int getThreshold(int iteration, int previous, long total) {
        // The threshold is the truncated quotient. Compute it in long arithmetic
        // rather than passing a long to Math.round, which selects the float
        // overload and loses precision for large totals.
        long quotient = total / datapoints / iteration;
        // The cap keeps the value at or below previous - 1, so it fits in an int.
        int min = (int) Math.min(quotient, previous - 1L);
        if (min == 0 && iteration == 1) {
            // First iteration - set minCount to 1
            min = 1;
        }
        return min;
    }

    /**
     * Extract all nodes in a collection with a hit count greater or equal
     * to a given threshold. This has the side effect of modifying the
     * incoming node collection, and the hierarchies of the nodes within it.
     * An extracted node keeps whatever children were not extracted themselves.
     * @param incoming the incoming nodes. Matching nodes will be removed
     * during the processing.
     * @param threshold the minimum hit count required to be returned.
     * @return the collection of nodes whose hit count is greater than or
     * equal to the threshold.
     */
    private Collection<TreeFacetField> getNodesWithCount(Collection<TreeFacetField> incoming, long threshold) {
        Collection<TreeFacetField> retList = new LinkedList<>();

        for (Iterator<TreeFacetField> iter = incoming.iterator(); iter.hasNext(); ) {
            TreeFacetField tff = iter.next();
            if (tff.getTotal() >= threshold) {
                if (tff.getChildCount() >= threshold) {
                    // Recurse, finding the nodes with enough hits
                    retList.addAll(getNodesWithCount(tff.getHierarchy(), threshold));
                    // Recalculate the child count throughout the tree
                    tff.recalculateChildCount();
                }

                if (tff.getCount() >= threshold) {
                    // This node has enough hits - store, and remove from the
                    // incoming nodes so it's not picked again later.
                    retList.add(tff);
                    iter.remove();
                }
            }
        }

        return retList;
    }

    /**
     * Get the total node count for all trees.
     * @param trees the trees whose total count is required.
     * @return the sum of {@link TreeFacetField#getTotal()} over the trees.
     */
    private long getNodeTotal(Collection<TreeFacetField> trees) {
        return trees.stream().mapToLong(TreeFacetField::getTotal).sum();
    }

    /**
     * Remove a collection of pruned nodes from the incoming set, in place.
     * A matching node is removed together with its whole subtree. Non-matching
     * nodes are searched recursively, and top-level survivors have their child
     * counts recalculated.
     * @param incoming the set containing all nodes in the tree. Modified in
     * place, including nested hierarchies.
     * @param pruned the nodes to check for duplicates.
     * @param level the current level in the tree, starting from 0.
     */
    private void trimIncomingNodes(Collection<TreeFacetField> incoming, Collection<TreeFacetField> pruned, int level) {
        for (Iterator<TreeFacetField> it = incoming.iterator(); it.hasNext(); ) {
            TreeFacetField tff = it.next();
            if (isFacetInChildren(tff, pruned)) {
                it.remove();
            } else {
                if (tff.hasChildren()) {
                    trimIncomingNodes(tff.getHierarchy(), pruned, level + 1);
                }

                if (level == 0) {
                    // Update the child counts in the node and its children
                    tff.recalculateChildCount();
                }
            }
        }
    }

    /**
     * Check whether a particular facet exists anywhere in a collection of
     * trees, either as one of the trees or somewhere in their hierarchies.
     * @param facet the facet to check for.
     * @param trees the collection of trees to check through; may be
     * {@code null}.
     * @return <code>true</code> if an equal facet is found.
     */
    private boolean isFacetInChildren(TreeFacetField facet, Collection<TreeFacetField> trees) {
        return trees != null
                && trees.stream().anyMatch(tree -> tree.equals(facet) || isFacetInChildren(facet, tree.getHierarchy()));
    }

    /**
     * Build the node grouping all nodes that were not extracted.
     * @param otherNodes the remaining (trimmed) incoming trees.
     * @return a new node built from {@link #moreLabel}, whose hierarchy is
     * {@code otherNodes} pruned with a {@link SimplePruner} and sorted in reverse
     * natural order, with its child count recalculated.
     */
    private TreeFacetField buildOtherNode(Collection<TreeFacetField> otherNodes) {
        // Prune the other nodes - use the SimplePruner
        SortedSet<TreeFacetField> pruned = new TreeSet<>(Comparator.reverseOrder());
        pruned.addAll(new SimplePruner(SimplePruner.MIN_CHILD_COUNT).prune(otherNodes));

        TreeFacetField other = new TreeFacetField(moreLabel, "", 0, 0, pruned);
        other.recalculateChildCount();

        return other;
    }
}