/* * Copyright [2012-2014] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core; import java.util.ArrayList; import java.util.Collections; import java.util.List; public class MunroPatEstimator<T extends Comparable<T>> { private static final long MAX_TOT_ELEMS = 1024L * 1024L * 1024L * 1024L; private final List<List<T>> buffer = new ArrayList<List<T>>(); private final int numQuantiles; private final int maxElementsPerBuffer; private int totalElements; private T min; private T max; public MunroPatEstimator(int numQuantiles) { this.numQuantiles = numQuantiles; this.maxElementsPerBuffer = computeMaxElementsPerBuffer(); } private int computeMaxElementsPerBuffer() { double epsilon = 1.0 / (numQuantiles - 1.0); int b = 2; while((b - 2) * (0x1L << (b - 2)) + 0.5 <= epsilon * MAX_TOT_ELEMS) { ++b; } return (int) (MAX_TOT_ELEMS / (0x1L << (b - 1))); } private void ensureBuffer(int level) { while(buffer.size() < level + 1) { buffer.add(null); } if(buffer.get(level) == null) { buffer.set(level, new ArrayList<T>()); } } private void collapse(List<T> a, List<T> b, List<T> out) { int indexA = 0, indexB = 0, count = 0; T smaller = null; while(indexA < maxElementsPerBuffer || indexB < maxElementsPerBuffer) { if(indexA >= maxElementsPerBuffer || (indexB < maxElementsPerBuffer && a.get(indexA).compareTo(b.get(indexB)) >= 0)) { smaller = b.get(indexB++); } else { smaller = a.get(indexA++); } if(count++ % 2 == 0) { out.add(smaller); } } a.clear(); b.clear(); } private void recursiveCollapse(List<T> buf, int level) { ensureBuffer(level + 1); List<T> merged; if(buffer.get(level + 1).isEmpty()) { merged = buffer.get(level + 1); } else { merged = new ArrayList<T>(maxElementsPerBuffer); } collapse(buffer.get(level), buf, merged); if(buffer.get(level + 1) != merged) { recursiveCollapse(merged, level + 1); } } public void add(T elem) { if(totalElements == 0 || elem.compareTo(min) < 0) { min = elem; } if(totalElements == 0 || max.compareTo(elem) < 0) { max = elem; } if(totalElements > 0 && totalElements % (2 * maxElementsPerBuffer) == 0) { Collections.sort(buffer.get(0)); Collections.sort(buffer.get(1)); recursiveCollapse(buffer.get(0), 1); } ensureBuffer(0); ensureBuffer(1); int index = buffer.get(0).size() < maxElementsPerBuffer ? 0 : 1; buffer.get(index).add(elem); totalElements++; } public void clear() { buffer.clear(); totalElements = 0; } public int getTotalElements() { return totalElements; } public List<T> getQuantiles() { List<T> quantiles = new ArrayList<T>(); if(min == null || max == null || buffer == null || buffer.size() == 0) { return quantiles; } quantiles.add(min); if(buffer.get(0) != null) { Collections.sort(buffer.get(0)); } if(buffer.get(1) != null) { Collections.sort(buffer.get(1)); } int[] index = new int[buffer.size()]; long S = 0; for(int i = 1; i <= numQuantiles - 2; i++) { long targetS = (long) Math.ceil(i * (totalElements / (numQuantiles - 1.0))); while(true) { T smallest = max; int minBufferId = -1; for(int j = 0; j < buffer.size(); j++) { if(buffer.get(j) != null && index[j] < buffer.get(j).size()) { if(smallest.compareTo(buffer.get(j).get(index[j])) >= 0) { smallest = buffer.get(j).get(index[j]); minBufferId = j; } } } long incrementS = minBufferId <= 1 ? 1L : (0x1L << (minBufferId - 1)); if(S + incrementS >= targetS) { quantiles.add(smallest); break; } else { index[minBufferId]++; S += incrementS; } } } quantiles.add(max); return quantiles; } }