/* * Copyright (C) 2014 Indeed Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the * License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either * express or implied. See the License for the specific language governing permissions and * limitations under the License. */ package com.indeed.imhotep.iql; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.indeed.util.core.Pair; import com.indeed.imhotep.ez.EZImhotepSession; import com.indeed.imhotep.ez.GroupKey; import com.indeed.imhotep.ez.StatReference; import org.apache.log4j.Logger; import java.text.DecimalFormat; import java.util.*; /** * @author jplaisance */ public final class TopKGroupingFTGSCallback extends EZImhotepSession.FTGSCallback { private static final Logger log = Logger.getLogger(TopKGroupingFTGSCallback.class); private final Map<Integer, PriorityQueue<Pair<Double, GroupStats>>> groupToTopK = Maps.newHashMap(); private final Comparator<Pair<Double, GroupStats>> comparator; private final int topK; private final boolean isBottom; private final StatReference countStat; private final List<StatReference> statRefs; private final Map<Integer, GroupKey> groupKeys; private int newGroupCount = 0; public TopKGroupingFTGSCallback(final int numStats, int topK, StatReference countStat, List<StatReference> statRefs, Map<Integer, GroupKey> groupKeys, boolean isBottom) { super(numStats); this.topK = topK; this.isBottom = isBottom; this.countStat = countStat; this.statRefs = statRefs; this.groupKeys = groupKeys; // do a custom comparator to ensure that real numbers are preferred to NaNs final Comparator<Pair<Double, GroupStats>> baseComparator = new Comparator<Pair<Double, GroupStats>>() { @Override public int compare(Pair<Double, GroupStats> o1, Pair<Double, GroupStats> o2) { Double a = o1.getFirst(); if(a.isNaN()) { a = Double.NEGATIVE_INFINITY; } Double b = o2.getFirst(); if(b.isNaN()) { b = Double.NEGATIVE_INFINITY; } return a.compareTo(b); } }; final Comparator<Pair<Double, GroupStats>> reverseComparator = new Comparator<Pair<Double, GroupStats>>() { @Override public int compare(Pair<Double, GroupStats> o1, Pair<Double, GroupStats> o2) { Double a = o1.getFirst(); if(a.isNaN()) { a = Double.POSITIVE_INFINITY; } Double b = o2.getFirst(); if(b.isNaN()) { b = Double.POSITIVE_INFINITY; } return b.compareTo(a); // reverse the result by swapping the sides } }; this.comparator = isBottom ? reverseComparator : baseComparator; } protected void intTermGroup(final String field, final long term, final int group) { termGroup(term, group); } protected void stringTermGroup(final String field, final String term, final int group) { termGroup(term, group); } private void termGroup(final Object term, final int group) { PriorityQueue<Pair<Double, GroupStats>> topTerms = groupToTopK.get(group); if (topTerms == null) { topTerms = new PriorityQueue<Pair<Double, GroupStats>>(topK, comparator); groupToTopK.put(group, topTerms); } final double count = getStat(countStat); if (topTerms.size() < topK) { topTerms.add(getStats(count, group, term)); if(++newGroupCount > EZImhotepSession.GROUP_LIMIT) { throw new IllegalArgumentException("Number of groups exceeds the limit " + new DecimalFormat("###,###").format(EZImhotepSession.GROUP_LIMIT) + ". Please simplify the query."); } } else { final Double headCount = topTerms.peek().getFirst(); if ((!isBottom && count > headCount) || (isBottom && count < headCount) || (headCount.isNaN() && !Double.isNaN(count))) { topTerms.remove(); topTerms.add(getStats(count, group, term)); } } } private Pair<Double, GroupStats> getStats(double count, int group, Object term) { final double[] stats = new double[statRefs.size()]; for (int i = 0; i < statRefs.size(); i++) { stats[i] = getStat(statRefs.get(i)); } return Pair.of(count, new GroupStats(groupKeys.get(group).add(term), stats)); } public List<GroupStats> getResults() { final List<GroupStats> ret = Lists.newArrayList(); final ArrayDeque<Pair<Double, GroupStats>> stack = new ArrayDeque<Pair<Double, GroupStats>>(); for (int group = 1; group <= groupKeys.size(); group++) { final PriorityQueue<Pair<Double, GroupStats>> pairs = groupToTopK.get(group); if (pairs != null) { while (!pairs.isEmpty()) { stack.push(pairs.remove()); } while (!stack.isEmpty()) { ret.add(stack.remove().getSecond()); } } else { // TODO: do we want these empty rows? ret.add(new GroupStats(groupKeys.get(group).add(""), new double[statRefs.size()])); } } return ret; } }