/* * Copyright 2015, 2016 Tagir Valeev * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package one.util.streamex; import org.junit.ComparisonFailure; import java.util.*; import java.util.Map.Entry; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.IntConsumer; import java.util.function.Supplier; import java.util.stream.Collector; import java.util.stream.Collectors; import java.util.stream.Stream; import static one.util.streamex.StreamExInternals.TailSpliterator; import static one.util.streamex.StreamExInternals.finished; import static org.junit.Assert.*; /** * @author Tagir Valeev */ public class TestHelpers { enum Mode { NORMAL, SPLITERATOR, PARALLEL, APPEND, PREPEND, RANDOM } static class StreamSupplier<T> { final Mode mode; final Supplier<Stream<T>> base; public StreamSupplier(Supplier<Stream<T>> base, Mode mode) { this.base = base; this.mode = mode; } public Stream<T> get() { Stream<T> res = base.get(); switch (mode) { case NORMAL: case SPLITERATOR: return res.sequential(); case PARALLEL: return res.parallel(); case APPEND: // using Stream.empty() or Arrays.asList() here is optimized out // in append/prepend which is undesired return StreamEx.of(res.parallel()).append(new ConcurrentLinkedQueue<>()); case PREPEND: return StreamEx.of(res.parallel()).prepend(new ConcurrentLinkedQueue<>()); case RANDOM: return StreamEx.of(new EmptyingSpliterator<>(res.parallel().spliterator())).parallel(); default: throw new InternalError("Unsupported mode: " + mode); } } @Override public String toString() { return mode.toString(); } } static class StreamExSupplier<T> extends StreamSupplier<T> { public StreamExSupplier(Supplier<Stream<T>> base, Mode mode) { super(base, mode); } @Override public StreamEx<T> get() { if(mode == Mode.SPLITERATOR) return StreamEx.of(base.get().spliterator()); return StreamEx.of(super.get()); } } static class EntryStreamSupplier<K, V> extends StreamSupplier<Map.Entry<K, V>> { public EntryStreamSupplier(Supplier<Stream<Map.Entry<K, V>>> base, Mode mode) { super(base, mode); } @Override public EntryStream<K, V> get() { return EntryStream.of(super.get()); } } static <T> List<StreamExSupplier<T>> streamEx(Supplier<Stream<T>> base) { return StreamEx.of(Mode.values()).map(mode -> new StreamExSupplier<>(base, mode)).toList(); } /** * Run the consumer once feeding it with RNG initialized with auto-generated seed * adding the seed value to every failed assertion message * * @param cons consumer to run */ static void withRandom(Consumer<Random> cons) { long seed = ThreadLocalRandom.current().nextLong(); withRandom(seed, cons); } /** * Run the consumer once feeding it with RNG initialized with given seed * adding the seed value to every failed assertion message * * @param seed random seed to use * @param cons consumer to run */ static void withRandom(long seed, Consumer<Random> cons) { Random random = new Random(seed); withMessage("Using new Random("+seed+")", () -> cons.accept(random)); } /** * Run the runnable automatically adding given message to every failed assertion * * @param message message to prepend * @param r runnable to run */ private static void withMessage(String message, Runnable r) { try { r.run(); } catch (ComparisonFailure cmp) { ComparisonFailure ex = new ComparisonFailure(message + ": " + cmp.getMessage(), cmp.getExpected(), cmp .getActual()); ex.setStackTrace(cmp.getStackTrace()); throw ex; } catch (AssertionError err) { AssertionError ex = new AssertionError(message + ": " + err.getMessage(), err.getCause()); ex.setStackTrace(err.getStackTrace()); throw ex; } catch (RuntimeException | Error err) { throw new RuntimeException(message + ": " + err.getMessage(), err); } } static void repeat(int times, IntConsumer consumer) { for (int i = 1; i <= times; i++) { int finalI = i; withMessage("#" + i, () -> consumer.accept(finalI)); } } static <T> void streamEx(Supplier<Stream<T>> base, Consumer<StreamExSupplier<T>> consumer) { for (StreamExSupplier<T> supplier : StreamEx.of(Mode.values()).map(mode -> new StreamExSupplier<>(base, mode))) { withMessage(supplier.toString(), () -> consumer.accept(supplier)); } } static <T> void emptyStreamEx(Class<T> clazz, Consumer<StreamExSupplier<T>> consumer) { streamEx(Stream::<T>empty, consumer); } static <K, V> void entryStream(Supplier<Stream<Map.Entry<K, V>>> base, Consumer<EntryStreamSupplier<K, V>> consumer) { for (EntryStreamSupplier<K, V> supplier : StreamEx.of(Mode.values()).map( mode -> new EntryStreamSupplier<>(base, mode))) { withMessage(supplier.toString(), () -> consumer.accept(supplier)); } } /** * Spliterator which randomly inserts empty spliterators on splitting * * @author Tagir Valeev * * @param <T> type of the elements */ private static class EmptyingSpliterator<T> implements Spliterator<T> { private Spliterator<T> source; public EmptyingSpliterator(Spliterator<T> source) { this.source = Objects.requireNonNull(source); } @Override public boolean tryAdvance(Consumer<? super T> action) { return source.tryAdvance(action); } @Override public void forEachRemaining(Consumer<? super T> action) { source.forEachRemaining(action); } @Override public Comparator<? super T> getComparator() { return source.getComparator(); } @Override public Spliterator<T> trySplit() { Spliterator<T> source = this.source; switch (ThreadLocalRandom.current().nextInt(3)) { case 0: return Spliterators.emptySpliterator(); case 1: this.source = Spliterators.emptySpliterator(); return source; default: Spliterator<T> split = source.trySplit(); return split == null ? null : new EmptyingSpliterator<>(split); } } @Override public long estimateSize() { return source.estimateSize(); } @Override public int characteristics() { return source.characteristics(); } } static <T, R> void checkCollectorEmpty(String message, R expected, Collector<T, ?, R> collector) { if (finished(collector) != null) checkShortCircuitCollector(message, expected, 0, Stream::empty, collector); else checkCollector(message, expected, Stream::empty, collector); } static <T, TT extends T, R> void checkShortCircuitCollector(String message, R expected, int expectedConsumedElements, Supplier<Stream<TT>> base, Collector<T, ?, R> collector) { checkShortCircuitCollector(message, expected, expectedConsumedElements, base, collector, false); } static <T, TT extends T, R> void checkShortCircuitCollector(String message, R expected, int expectedConsumedElements, Supplier<Stream<TT>> base, Collector<T, ?, R> collector, boolean skipIdentity) { assertNotNull(message, finished(collector)); Collector<T, ?, R> withIdentity = Collectors.collectingAndThen(collector, Function.identity()); for (StreamExSupplier<TT> supplier : streamEx(base)) { AtomicInteger counter = new AtomicInteger(); assertEquals(message + ": " + supplier, expected, supplier.get().peek(t -> counter.incrementAndGet()) .collect(collector)); if (!supplier.get().isParallel()) assertEquals(message + ": " + supplier + ": consumed: ", expectedConsumedElements, counter.get()); if (!skipIdentity) assertEquals(message + ": " + supplier, expected, supplier.get().collect(withIdentity)); } } static <T, TT extends T, R> void checkCollector(String message, R expected, Supplier<Stream<TT>> base, Collector<T, ?, R> collector) { // use checkShortCircuitCollector for CancellableCollector assertNull(message, finished(collector)); for (StreamExSupplier<TT> supplier : streamEx(base)) { assertEquals(message + ": " + supplier, expected, supplier.get().collect(collector)); } } static <T> void checkSpliterator(String msg, Supplier<Spliterator<T>> supplier) { List<T> expected = new ArrayList<>(); supplier.get().forEachRemaining(expected::add); checkSpliterator(msg, expected, supplier); } /* * Tests whether spliterators produced by given supplier produce the * expected result under various splittings * * This test is single-threaded. Its behavior is randomized, but random seed * will be printed in case of failure, so the results could be reproduced */ static <T> void checkSpliterator(String msg, List<T> expected, Supplier<Spliterator<T>> supplier) { List<T> seq = new ArrayList<>(); // Test forEachRemaining Spliterator<T> sequential = supplier.get(); sequential.forEachRemaining(seq::add); assertFalse(msg, sequential.tryAdvance(t -> fail(msg + ": Advance called with " + t))); sequential.forEachRemaining(t -> fail(msg + ": Advance called with " + t)); assertEquals(msg, expected, seq); // Test tryAdvance seq.clear(); sequential = supplier.get(); while (true) { AtomicBoolean called = new AtomicBoolean(); boolean res = sequential.tryAdvance(t -> { seq.add(t); called.set(true); }); if (res != called.get()) { fail(msg + (res ? ": Consumer not called, but spliterator returned true" : ": Consumer called, but spliterator returned false")); } if (!res) break; } assertFalse(msg, sequential.tryAdvance(t -> fail(msg + ": Advance called with " + t))); assertEquals(msg, expected, seq); // Test TailSpliterator if(sequential instanceof TailSpliterator) { seq.clear(); TailSpliterator.forEachWithTail(supplier.get(), seq::add); assertEquals(msg, expected, seq); seq.clear(); sequential = supplier.get(); while(sequential != null) { sequential = TailSpliterator.tryAdvanceWithTail(sequential, seq::add); } } assertEquals(msg, expected, seq); // Test advance+remaining for (int i = 1; i < Math.min(4, expected.size() - 1); i++) { seq.clear(); sequential = supplier.get(); for(int j=0; j<i; j++) assertTrue(msg, sequential.tryAdvance(seq::add)); sequential.forEachRemaining(seq::add); assertEquals(msg, expected, seq); } // Test trySplit withRandom(r -> { repeat(500, n -> { Spliterator<T> spliterator = supplier.get(); List<Spliterator<T>> spliterators = new ArrayList<>(); spliterators.add(spliterator); int p = r.nextInt(10) + 2; for (int i = 0; i < p; i++) { int idx = r.nextInt(spliterators.size()); Spliterator<T> split = spliterators.get(idx).trySplit(); if (split != null) spliterators.add(idx, split); } List<Integer> order = IntStreamEx.ofIndices(spliterators).boxed().toList(); Collections.shuffle(order, r); List<T> list = StreamEx.of(order).mapToEntry(idx -> { Spliterator<T> s = spliterators.get(idx); Stream.Builder<T> builder = Stream.builder(); s.forEachRemaining(builder); assertFalse(msg, s.tryAdvance(t -> fail(msg + ": Advance called with " + t))); s.forEachRemaining(t -> fail(msg + ": Advance called with " + t)); return builder.build(); }).sortedBy(Entry::getKey).values().flatMap(Function.identity()).toList(); assertEquals(msg, expected, list); }); repeat(500, n -> { Spliterator<T> spliterator = supplier.get(); List<Spliterator<T>> spliterators = new ArrayList<>(); spliterators.add(spliterator); int p = r.nextInt(30) + 2; for (int i = 0; i < p; i++) { int idx = r.nextInt(spliterators.size()); Spliterator<T> split = spliterators.get(idx).trySplit(); if (split != null) spliterators.add(idx, split); } List<List<T>> results = StreamEx.<List<T>> generate(ArrayList::new).limit(spliterators.size()) .toList(); int count = spliterators.size(); while (count > 0) { int i; do { i = r.nextInt(spliterators.size()); spliterator = spliterators.get(i); } while (spliterator == null); if (!spliterator.tryAdvance(results.get(i)::add)) { spliterators.set(i, null); count--; } } List<T> list = StreamEx.of(results).flatMap(List::stream).toList(); assertEquals(msg, expected, list); }); }); } static void checkIllegalStateException(Runnable r, String key, String value1, String value2) { try { r.run(); fail("no exception"); } catch (IllegalStateException ex) { String exmsg = ex.getMessage(); if (!exmsg.equals("Duplicate entry for key '" + key + "' (attempt to merge values '" + value1 + "' and '" + value2 + "')") && !exmsg.equals("Duplicate entry for key '" + key + "' (attempt to merge values '" + value2 + "' and '" + value1 + "')") && !exmsg.equals("java.lang.IllegalStateException: Duplicate entry for key '" + key + "' (attempt to merge values '" + value1 + "' and '" + value2 + "')") && !exmsg.equals("java.lang.IllegalStateException: Duplicate entry for key '" + key + "' (attempt to merge values '" + value2 + "' and '" + value1 + "')")) fail("wrong exception message: " + exmsg); } } }