/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.math.random; import com.google.common.collect.HashMultiset; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import com.google.common.collect.Multiset; import org.apache.mahout.math.DenseMatrix; import org.apache.mahout.math.DenseVector; import org.apache.mahout.math.MahoutTestCase; import org.apache.mahout.math.Matrix; import org.apache.mahout.math.QRDecomposition; import org.junit.Test; import java.util.Collections; import java.util.List; import java.util.Set; public final class ChineseRestaurantTest extends MahoutTestCase { @Test public void testDepth() { List<Integer> totals = Lists.newArrayList(); for (int i = 0; i < 1000; i++) { ChineseRestaurant x = new ChineseRestaurant(10); Multiset<Integer> counts = HashMultiset.create(); for (int j = 0; j < 100; j++) { counts.add(x.sample()); } List<Integer> tmp = Lists.newArrayList(); for (Integer k : counts.elementSet()) { tmp.add(counts.count(k)); } Collections.sort(tmp, Collections.reverseOrder()); while (totals.size() < tmp.size()) { totals.add(0); } int j = 0; for (Integer k : tmp) { totals.set(j, totals.get(j) + k); j++; } } // these are empirically derived values, not principled ones assertEquals(25000.0, (double) totals.get(0), 1000); assertEquals(24000.0, (double) totals.get(1), 1000); assertEquals(8000.0, (double) totals.get(2), 200); assertEquals(1000.0, (double) totals.get(15), 50); assertEquals(1000.0, (double) totals.get(20), 40); } @Test public void testExtremeDiscount() { ChineseRestaurant x = new ChineseRestaurant(100, 1); Multiset<Integer> counts = HashMultiset.create(); for (int i = 0; i < 10000; i++) { counts.add(x.sample()); } assertEquals(10000, x.size()); for (int i = 0; i < 10000; i++) { assertEquals(1, x.count(i)); } } @Test public void testGrowth() { ChineseRestaurant s0 = new ChineseRestaurant(10, 0.0); ChineseRestaurant s5 = new ChineseRestaurant(10, 0.5); ChineseRestaurant s9 = new ChineseRestaurant(10, 0.9); Set<Double> splits = ImmutableSet.of(1.0, 1.5, 2.0, 3.0, 5.0, 8.0); double offset0 = 0; int k = 0; int i = 0; Matrix m5 = new DenseMatrix(20, 3); Matrix m9 = new DenseMatrix(20, 3); while (i <= 200000) { double n = i / Math.pow(10, Math.floor(Math.log10(i))); if (splits.contains(n)) { //System.out.printf("%d\t%d\t%d\t%d\n", i, s0.size(), s5.size(), s9.size()); if (i > 900) { double predict5 = predictSize(m5.viewPart(0, k, 0, 3), i, 0.5); assertEquals(predict5, Math.log(s5.size()), 1); double predict9 = predictSize(m9.viewPart(0, k, 0, 3), i, 0.9); assertEquals(predict9, Math.log(s9.size()), 1); //assertEquals(10.5 * Math.log(i) - offset0, s0.size(), 10); } else if (i > 50) { double x = 10.5 * Math.log(i) - s0.size(); m5.viewRow(k).assign(new double[]{Math.log(s5.size()), Math.log(i), 1}); m9.viewRow(k).assign(new double[]{Math.log(s9.size()), Math.log(i), 1}); k++; offset0 += (x - offset0) / k; } if (i > 10000) { assertEquals(0.0, (double) hapaxCount(s0) / s0.size(), 0.25); assertEquals(0.5, (double) hapaxCount(s5) / s5.size(), 0.1); assertEquals(0.9, (double) hapaxCount(s9) / s9.size(), 0.05); } } s0.sample(); s5.sample(); s9.sample(); i++; } } /** * Predict the power law growth in number of unique samples from the first few data points. * Also check that the fitted growth coefficient is about right. * * @param m * @param currentIndex Total data points seen so far. Unique values should be log(currentIndex)*expectedCoefficient + offset. * @param expectedCoefficient What slope do we expect. * @return The predicted value for log(currentIndex) */ private static double predictSize(Matrix m, int currentIndex, double expectedCoefficient) { int rows = m.rowSize(); Matrix a = m.viewPart(0, rows, 1, 2); Matrix b = m.viewPart(0, rows, 0, 1); Matrix ata = a.transpose().times(a); Matrix atb = a.transpose().times(b); QRDecomposition s = new QRDecomposition(ata); Matrix r = s.solve(atb).transpose(); assertEquals(expectedCoefficient, r.get(0, 0), 0.2); return r.times(new DenseVector(new double[]{Math.log(currentIndex), 1})).get(0); } private static int hapaxCount(ChineseRestaurant s) { int r = 0; for (int i = 0; i < s.size(); i++) { if (s.count(i) == 1) { r++; } } return r; } }