/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.tinkerpop.gremlin.process.traversal.util; import org.apache.tinkerpop.gremlin.process.traversal.Operator; import org.apache.tinkerpop.gremlin.process.traversal.Step; import org.apache.tinkerpop.gremlin.process.traversal.Traversal; import org.apache.tinkerpop.gremlin.process.traversal.TraversalSideEffects; import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.DefaultGraphTraversal; import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.__; import org.apache.tinkerpop.gremlin.process.traversal.step.TraversalParent; import org.apache.tinkerpop.gremlin.util.function.ConstantSupplier; import org.apache.tinkerpop.gremlin.util.function.HashSetSupplier; import org.hamcrest.CoreMatchers; import org.junit.Test; import java.util.Arrays; import java.util.HashSet; import java.util.List; import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.IntStream; import static org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.__.is; import static org.hamcrest.number.OrderingComparison.greaterThan; import static org.hamcrest.number.OrderingComparison.lessThan; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; /** * @author Marko A. Rodriguez (http://markorodriguez.com) */ public class DefaultTraversalTest { @Test public void shouldRespectThreadInterruption() throws Exception { final AtomicBoolean exceptionThrown = new AtomicBoolean(false); final AtomicInteger counter = new AtomicInteger(0); final CountDownLatch startedIterating = new CountDownLatch(100); final List<Integer> l = IntStream.range(0, 1000000).boxed().collect(Collectors.toList()); final Thread t = new Thread(() -> { try { __.inject(l).unfold().sideEffect(i -> { startedIterating.countDown(); counter.incrementAndGet(); }).iterate(); } catch (Exception ex) { exceptionThrown.set(ex instanceof TraversalInterruptedException); } }); t.start(); startedIterating.await(); t.interrupt(); t.join(); // ensure that some but not all of the traversal was iterated and that the right exception was tossed assertThat(counter.get(), greaterThan(0)); assertThat(counter.get(), lessThan(1000000)); assertThat(exceptionThrown.get(), CoreMatchers.is(true)); } @Test public void shouldCloneTraversalCorrectly() { final DefaultGraphTraversal<?, ?> original = new DefaultGraphTraversal<>(); original.out().groupCount("m").values("name").count(); final DefaultTraversal<?, ?> clone = (DefaultTraversal) original.clone(); assertEquals(original.hashCode(), clone.hashCode()); assertEquals(original.getSteps().size(), clone.getSteps().size()); for (int i = 0; i < original.steps.size(); i++) { assertEquals(original.getSteps().get(i), clone.getSteps().get(i)); } assertNotEquals(original.sideEffects, clone.sideEffects); original.getSideEffects().set("m", 1); assertEquals(1, original.getSideEffects().<Integer>get("m").intValue()); clone.getSideEffects().set("m", 2); assertEquals(1, original.getSideEffects().<Integer>get("m").intValue()); assertEquals(2, clone.getSideEffects().<Integer>get("m").intValue()); } @Test public void shouldBeTheSameSideEffectsThroughoutAllChildTraversals() { final DefaultTraversal.Admin<?, ?> traversal = (DefaultTraversal.Admin) __.out().repeat(__.in().groupCount("a").by(__.select("a"))).in(); final TraversalSideEffects sideEffects = traversal.getSideEffects(); sideEffects.register("a", (Supplier) HashSetSupplier.instance(), Operator.addAll); sideEffects.register("b", new ConstantSupplier<>(1), Operator.sum); assertEquals(1, sideEffects.<Integer>get("b").intValue()); assertFalse(traversal.isLocked()); traversal.applyStrategies(); assertTrue(traversal.isLocked()); sideEffects.add("b", 7); assertEquals(0, sideEffects.<Set>get("a").size()); assertEquals(8, sideEffects.<Integer>get("b").intValue()); recursiveTestTraversals(traversal, sideEffects, sideEffects.get("a"), 8); sideEffects.add("a", new HashSet<>(Arrays.asList("marko", "bob"))); sideEffects.set("b", 3); recursiveTestTraversals(traversal, sideEffects, new HashSet<>(Arrays.asList("marko", "bob")), 3); sideEffects.add("a", new HashSet<>(Arrays.asList("marko", "x", "x", "bob"))); sideEffects.add("b", 10); recursiveTestTraversals(traversal, sideEffects, new HashSet<>(Arrays.asList("marko", "bob", "x")), 13); } private void recursiveTestTraversals(final Traversal.Admin<?, ?> traversal, final TraversalSideEffects sideEffects, final Set aValue, final int bValue) { assertTrue(traversal.getSideEffects() == sideEffects); assertEquals(sideEffects.keys().size(), traversal.getSideEffects().keys().size()); assertEquals(bValue, traversal.getSideEffects().<Integer>get("b").intValue()); assertEquals(aValue.size(), traversal.getSideEffects().<Set>get("a").size()); assertFalse(aValue.stream().filter(k -> !traversal.getSideEffects().<Set>get("a").contains(k)).findAny().isPresent()); assertFalse(traversal.getSideEffects().exists("c")); for (final Step<?, ?> step : traversal.getSteps()) { assertTrue(step.getTraversal().getSideEffects() == sideEffects); assertEquals(sideEffects.keys().size(), step.getTraversal().getSideEffects().keys().size()); assertEquals(bValue, step.getTraversal().getSideEffects().<Integer>get("b").intValue()); assertEquals(aValue.size(), step.getTraversal().getSideEffects().<Set>get("a").size()); assertFalse(aValue.stream().filter(k -> !step.getTraversal().getSideEffects().<Set>get("a").contains(k)).findAny().isPresent()); assertFalse(step.getTraversal().getSideEffects().exists("c")); if (step instanceof TraversalParent) { ((TraversalParent) step).getGlobalChildren().forEach(t -> this.recursiveTestTraversals(t, sideEffects, aValue, bValue)); ((TraversalParent) step).getLocalChildren().forEach(t -> this.recursiveTestTraversals(t, sideEffects, aValue, bValue)); } } } }