package com.github.projectflink.generators.tpch.generators.core; import com.google.common.base.Charsets; import com.google.common.base.Preconditions; import com.google.common.io.Resources; import io.airlift.tpch.Distribution; import io.airlift.tpch.Distributions; import io.airlift.tpch.NationGenerator; import io.airlift.tpch.PartSupplierGenerator; import io.airlift.tpch.RegionGenerator; import io.airlift.tpch.TextPool; import org.apache.flink.util.SplittableIterator; import java.lang.reflect.Constructor; import java.net.URL; import java.util.HashSet; import java.util.Iterator; import java.util.Set; import static com.google.common.base.Preconditions.checkState; import static io.airlift.tpch.DistributionLoader.loadDistribution; public class TPCHGeneratorSplittableIterator<T> extends SplittableIterator<T> { private double scale; private int degreeOfParallelism; private Class<? extends Iterable<T>> generatorClass; public TPCHGeneratorSplittableIterator(double scale, int degreeOfParallelism, Class<? extends Iterable<T>> generatorClass) { Preconditions.checkArgument(scale > 0, "Scale must be > 0"); Preconditions.checkArgument(degreeOfParallelism > 0, "Parallelism must be > 0"); this.scale = scale; this.degreeOfParallelism = degreeOfParallelism; this.generatorClass = generatorClass; } public Iterator<T>[] split(int numPartitions) { if(numPartitions > this.degreeOfParallelism) { throw new IllegalArgumentException("Too many partitions requested"); } Iterator<T>[] iters = new Iterator[numPartitions]; for(int i = 1; i <= numPartitions; i++) { iters[i - 1] = new TPCHGeneratorSplittableIterator(i, numPartitions, scale, generatorClass); } return iters; } public int getMaximumNumberOfSplits() { return this.degreeOfParallelism; } //------------------------ Iterator ----------------------------------- private static Set<Class<? extends Iterable>> fixedGenerators; private static Distributions distributions; private static TextPool smallTextPool; static { fixedGenerators = new HashSet<Class<? extends Iterable>>(); fixedGenerators.add(RegionGenerator.class); fixedGenerators.add(NationGenerator.class); try { URL resource = Resources.getResource(Distribution.class, "dists.dss"); checkState(resource != null, "Distribution file 'dists.dss' not found"); distributions = new Distributions(loadDistribution(Resources.asCharSource(resource, Charsets.UTF_8))); smallTextPool = new TextPool(1 * 1024 * 1024, distributions); // 1 MB txt pool } catch(Throwable t) { throw new RuntimeException("Unable to load distributions", t); } } private Iterator<T> iter; public TPCHGeneratorSplittableIterator(int partNo, int totalParts, double scale, Class<? extends Iterable<T>> generatorClass) { try { Constructor<? extends Iterable<T>> generatorCtor; Iterable<T> generator = null; if(fixedGenerators.contains(generatorClass)) { // use short constructor: generatorCtor = generatorClass.getConstructor(Distributions.class, TextPool.class); generator = generatorCtor.newInstance(distributions, smallTextPool); } else if(generatorClass.equals(PartSupplierGenerator.class)) { generatorCtor = generatorClass.getConstructor(double.class, int.class, int.class); generator = generatorCtor.newInstance(scale, partNo, totalParts); } else { // use full constructor generatorCtor = generatorClass.getConstructor(double.class, int.class, int.class, Distributions.class, TextPool.class); generator = generatorCtor.newInstance(scale, partNo, totalParts, distributions, smallTextPool); } iter = generator.iterator(); } catch (Throwable e) { throw new RuntimeException("Unable to create generator "+generatorClass, e); } } @Override public boolean hasNext() { return iter.hasNext(); } @Override public T next() { return iter.next(); } @Override public void remove() { throw new UnsupportedOperationException("Remove not supported on this iterator"); } }