/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.fpm.pfpgrowth.fpgrowth;
import java.util.HashSet;
import java.util.PriorityQueue;
import java.util.Set;
import org.apache.mahout.math.map.OpenLongObjectHashMap;
/** keeps top K Attributes in a TreeSet */
public final class FrequentPatternMaxHeap {
private int count;
private Pattern least;
private final int maxSize;
private final boolean subPatternCheck;
private final OpenLongObjectHashMap<Set<Pattern>> patternIndex;
private final PriorityQueue<Pattern> queue;
public FrequentPatternMaxHeap(int numResults, boolean subPatternCheck) {
maxSize = numResults;
queue = new PriorityQueue<Pattern>(maxSize);
this.subPatternCheck = subPatternCheck;
patternIndex = new OpenLongObjectHashMap<Set<Pattern>>();
for (Pattern p : queue) {
Long index = p.support();
Set<Pattern> patternList;
if (!patternIndex.containsKey(index)) {
patternList = new HashSet<Pattern>();
patternIndex.put(index, patternList);
}
patternList = patternIndex.get(index);
patternList.add(p);
}
}
public boolean addable(long support) {
return count < maxSize || least.support() <= support;
}
public PriorityQueue<Pattern> getHeap() {
if (subPatternCheck) {
PriorityQueue<Pattern> ret = new PriorityQueue<Pattern>(maxSize);
for (Pattern p : queue) {
if (patternIndex.get(p.support()).contains(p)) {
ret.add(p);
}
}
return ret;
}
return queue;
}
public void addAll(FrequentPatternMaxHeap patterns,
int attribute,
long attributeSupport) {
for (Pattern pattern : patterns.getHeap()) {
long support = Math.min(attributeSupport, pattern.support());
if (this.addable(support)) {
pattern.add(attribute, support);
this.insert(pattern);
}
}
}
public void insert(Pattern frequentPattern) {
if (frequentPattern.length() == 0) {
return;
}
if (count == maxSize) {
if (frequentPattern.compareTo(least) > 0 && addPattern(frequentPattern)) {
Pattern evictedItem = queue.poll();
least = queue.peek();
if (subPatternCheck) {
patternIndex.get(evictedItem.support()).remove(evictedItem);
}
}
} else {
if (addPattern(frequentPattern)) {
count++;
if (least == null) {
least = frequentPattern;
} else {
if (least.compareTo(frequentPattern) < 0) {
least = frequentPattern;
}
}
}
}
}
public int count() {
return count;
}
public boolean isFull() {
return count == maxSize;
}
public long leastSupport() {
if (least == null) {
return 0;
}
return least.support();
}
private boolean addPattern(Pattern frequentPattern) {
if (subPatternCheck) {
Long index = frequentPattern.support();
if (patternIndex.containsKey(index)) {
Set<Pattern> indexSet = patternIndex.get(index);
boolean replace = false;
Pattern replacablePattern = null;
for (Pattern p : indexSet) {
if (frequentPattern.isSubPatternOf(p)) {
return false;
} else if (p.isSubPatternOf(frequentPattern)) {
replace = true;
replacablePattern = p;
break;
}
}
if (replace) {
indexSet.remove(replacablePattern);
if (!indexSet.contains(frequentPattern) && queue.add(frequentPattern)) {
indexSet.add(frequentPattern);
}
return false;
}
queue.add(frequentPattern);
indexSet.add(frequentPattern);
} else {
queue.add(frequentPattern);
Set<Pattern> patternList;
if (!patternIndex.containsKey(index)) {
patternList = new HashSet<Pattern>();
patternIndex.put(index, patternList);
}
patternList = patternIndex.get(index);
patternList.add(frequentPattern);
}
} else {
queue.add(frequentPattern);
}
return true;
}
}