package splar.core.heuristics;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Map.Entry;

import splar.core.constraints.CNFFormula;
import splar.core.fm.FeatureGroup;
import splar.core.fm.FeatureModel;
import splar.core.fm.FeatureTreeNode;

/**
 * Variable-ordering heuristic that walks the feature tree bottom-up and places
 * every node at the (rounded) average position of its direct children inside
 * the merged ordering of its subtree.
 *
 * <p>While the ordering is being built, feature-group nodes are tagged with a
 * leading {@code "*"} so that they can be positioned like any other node and
 * then stripped from the final result by {@link #runHeuristic(CNFFormula)}.
 */
public class FTAverageOrderTraversalHeuristic extends FTTraversalHeuristic {

	/** Prefix that tags feature-group entries in the intermediate ordering. */
	private static final String GROUP_MARKER = "*";

	/**
	 * Creates the heuristic; both arguments are forwarded to the superclass.
	 *
	 * @param name         name of the heuristic
	 * @param featureModel feature model whose tree is traversed
	 */
	public FTAverageOrderTraversalHeuristic(String name, FeatureModel featureModel) {
		super(name, featureModel);
	}

	/**
	 * Computes the variable order for the feature model.
	 *
	 * <p>The order is derived from the feature tree only; {@code cnf} is not read
	 * by this method. The runPreProcessing(cnf) / runPostProcessing(cnf) calls
	 * were commented out in the original implementation and are still not
	 * invoked here.
	 *
	 * @param cnf formula the order is requested for (unused here)
	 * @return IDs of all non-group nodes in heuristic order
	 */
	protected String[] runHeuristic(CNFFormula cnf) {
		List<String> variableOrder = recursiveRun(featureModel.getRoot());

		// Feature groups only take part in positioning; drop every marked entry.
		for (Iterator<String> it = variableOrder.iterator(); it.hasNext();) {
			String nodeID = it.next();
			if (nodeID.startsWith(GROUP_MARKER)) {
				it.remove();
			}
		}
		return variableOrder.toArray(new String[0]);
	}

	/**
	 * Builds the ordering of the subtree rooted at {@code node}.
	 *
	 * <p>The orderings of the children are concatenated (in tree order for up to
	 * two children, otherwise in the order given by
	 * {@link #balanceWeights(Map)}), and {@code node} itself is then inserted at
	 * the rounded average of its direct children's positions in that list.
	 *
	 * @param node root of the subtree to order
	 * @return mutable list of node IDs; feature groups carry the {@code "*"} prefix
	 */
	protected List<String> recursiveRun(FeatureTreeNode node) {
		List<String> nodeList = new ArrayList<String>();

		int childCount = node.getChildCount();
		if (childCount == 0) {
			// Use the same identifier the parent will look up; a childless feature
			// group must keep its marker, otherwise the parent's indexOf() misses it
			// (returning -1) and the group leaks into the final variable order.
			nodeList.add(orderingId(node));
			return nodeList;
		}

		// Ordering of every child subtree, keyed by the child's identifier.
		// LinkedHashMap keeps the children in tree order.
		Map<String, List<String>> nodeChildrenMap = new LinkedHashMap<String, List<String>>();
		for (int i = 0; i < childCount; i++) {
			FeatureTreeNode childNode = (FeatureTreeNode) node.getChildAt(i);
			List<String> childList = recursiveRun(childNode);
			nodeChildrenMap.put(orderingId(childNode), childList);
		}
		childCount = nodeChildrenMap.size();

		if (childCount <= 2) {
			// One or two children: nothing to balance, keep tree order.
			for (List<String> childNodeList : nodeChildrenMap.values()) {
				nodeList.addAll(childNodeList);
			}
		} else {
			// More than two children: concatenate in weight-balanced order.
			for (String childNodeName : balanceWeights(nodeChildrenMap)) {
				nodeList.addAll(nodeChildrenMap.get(childNodeName));
			}
		}

		// Place 'node' at the average position of its direct children.
		int childNodePosSum = 0;
		for (String childNodeName : nodeChildrenMap.keySet()) {
			childNodePosSum += nodeList.indexOf(childNodeName);
		}
		int nodePos = Math.round(childNodePosSum / (1.0f * childCount));
		nodeList.add(nodePos, orderingId(node));

		return nodeList;
	}

	/**
	 * Orders the keys of {@code weightsMap} so that heavy and light entries are
	 * split into two groups of roughly equal total weight, where the weight of
	 * an entry is the size of its value list.
	 *
	 * <p>Entries are sorted by weight; repeatedly the heaviest remaining entry
	 * goes to the "right" group and the next-heaviest entries go to the "left"
	 * group until their combined weight reaches that of the entry just placed
	 * on the right. The result is the left group followed by the right group.
	 *
	 * @param weightsMap map from key to the list whose size is the key's weight;
	 *                   not modified
	 * @return all keys of {@code weightsMap}, left group first, then right group
	 */
	public List<String> balanceWeights(Map<String, List<String>> weightsMap) {
		// Ascending by weight. Collections.sort is stable, so entries of equal
		// weight keep the iteration order of the map.
		List<Entry<String, List<String>>> entriesList =
				new ArrayList<Entry<String, List<String>>>(weightsMap.entrySet());
		Collections.sort(entriesList, new Comparator<Entry<String, List<String>>>() {
			public int compare(Entry<String, List<String>> entry1, Entry<String, List<String>> entry2) {
				int weight1 = entry1.getValue().size();
				int weight2 = entry2.getValue().size();
				return weight1 < weight2 ? -1 : (weight1 > weight2 ? 1 : 0);
			}
		});

		List<String> leftList = new ArrayList<String>();
		List<String> rightList = new ArrayList<String>();

		while (!entriesList.isEmpty()) {
			// Heaviest remaining entry goes to the right group.
			Entry<String, List<String>> maxEntry = entriesList.remove(entriesList.size() - 1);
			rightList.add(maxEntry.getKey());
			int targetWeight = maxEntry.getValue().size();

			// Counterbalance it on the left with the next-heaviest entries.
			int curWeight = 0;
			while (!entriesList.isEmpty() && curWeight < targetWeight) {
				Entry<String, List<String>> curEntry = entriesList.remove(entriesList.size() - 1);
				curWeight += curEntry.getValue().size();
				leftList.add(curEntry.getKey());
			}
		}

		leftList.addAll(rightList);
		return leftList;
	}

	/**
	 * Returns the identifier under which {@code node} appears in intermediate
	 * orderings: its ID, prefixed with {@code "*"} when it is a feature group.
	 */
	private static String orderingId(FeatureTreeNode node) {
		String nodeID = node.getID();
		if (node instanceof FeatureGroup) {
			nodeID = GROUP_MARKER + nodeID;
		}
		return nodeID;
	}
}