/* * Copyright (c) 2015, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it * under the terms of the GNU General Public License version 2 only, as * published by the Free Software Foundation. * * This code is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License * version 2 for more details (a copy is included in the LICENSE file that * accompanied this code). * * You should have received a copy of the GNU General Public License version * 2 along with this work; if not, write to the Free Software Foundation, * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. * * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA * or visit www.oracle.com if you need additional information or have any * questions. */ package com.oracle.truffle.api.dsl.test; import static com.oracle.truffle.api.dsl.test.TestHelper.createRoot; import static com.oracle.truffle.api.dsl.test.TestHelper.executeWith; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; import java.lang.reflect.Field; import java.util.Arrays; import java.util.concurrent.CountDownLatch; import org.junit.Test; import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.NodeChild; import com.oracle.truffle.api.dsl.NodeFactory; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.dsl.TypeSystemReference; import com.oracle.truffle.api.dsl.test.MergeSpecializationsTestFactory.TestCachedNodeFactory; import com.oracle.truffle.api.dsl.test.MergeSpecializationsTestFactory.TestNodeFactory; import com.oracle.truffle.api.dsl.test.TypeBoxingTest.TypeBoxingTypeSystem; import com.oracle.truffle.api.dsl.test.TypeSystemTest.TestRootNode; import com.oracle.truffle.api.dsl.test.TypeSystemTest.ValueNode; import com.oracle.truffle.api.nodes.Node; import com.oracle.truffle.api.test.ReflectionUtils; public class MergeSpecializationsTest { private static final int THREADS = 50; @NodeChild @SuppressWarnings("unused") @TypeSystemReference(TypeBoxingTypeSystem.class) abstract static class TestNode extends ValueNode { @Specialization int s1(int a) { return 1; } @Specialization int s2(long a) { return 2; } @Specialization int s3(double a) { return 3; } } @NodeChild @SuppressWarnings("unused") @TypeSystemReference(TypeBoxingTypeSystem.class) abstract static class TestCachedNode extends ValueNode { @Specialization(guards = "a == cachedA", limit = "3") int s1(int a, @Cached("a") int cachedA) { return 1; } @Specialization int s2(long a) { return 2; } @Specialization int s3(double a) { return 3; } } @Test public void testMultithreadedMergeInOrder() throws Exception { for (int i = 0; i < 100; i++) { multithreadedMerge(TestNodeFactory.getInstance(), new Executions(1, 1L << 32, 1.0), 1, 2, 3); } } @Test public void testMultithreadedMergeReverse() throws Exception { for (int i = 0; i < 100; i++) { multithreadedMerge(TestNodeFactory.getInstance(), new Executions(1.0, 1L << 32, 1), 3, 2, 1); } } @Test public void testMultithreadedMergeCachedInOrder() throws Exception { for (int i = 0; i < 100; i++) { multithreadedMerge(TestCachedNodeFactory.getInstance(), new Executions(1, 1L << 32, 1.0), 1, 2, 3); } } @Test public void testMultithreadedMergeCachedTwoEntries() throws Exception { for (int i = 0; i < 100; i++) { multithreadedMerge(TestCachedNodeFactory.getInstance(), new Executions(1, 2, 1.0), 1, 1, 3); } } @Test public void testMultithreadedMergeCachedThreeEntries() throws Exception { for (int i = 0; i < 100; i++) { multithreadedMerge(TestCachedNodeFactory.getInstance(), new Executions(1, 2, 3), 1, 1, 1); } } @SuppressWarnings("deprecation") private static <T extends ValueNode> void multithreadedMerge(NodeFactory<T> factory, final Executions executions, int... order) throws Exception { assertEquals(3, order.length); final TestRootNode<T> node = createRoot(factory); final CountDownLatch threadsStarted = new CountDownLatch(THREADS); final CountDownLatch beforeFirst = new CountDownLatch(1); final CountDownLatch executedFirst = new CountDownLatch(THREADS); final CountDownLatch beforeSecond = new CountDownLatch(1); final CountDownLatch executedSecond = new CountDownLatch(THREADS); final CountDownLatch beforeThird = new CountDownLatch(1); final CountDownLatch executedThird = new CountDownLatch(THREADS); Thread[] threads = new Thread[THREADS]; for (int i = 0; i < threads.length; i++) { threads[i] = new Thread(new Runnable() { public void run() { threadsStarted.countDown(); MergeSpecializationsTest.await(beforeFirst); executeWith(node, executions.firstValue); executedFirst.countDown(); MergeSpecializationsTest.await(beforeSecond); executeWith(node, executions.secondValue); executedSecond.countDown(); MergeSpecializationsTest.await(beforeThird); executeWith(node, executions.thirdValue); executedThird.countDown(); } }); threads[i].start(); } T checkedNode = node.getNode(); if (node instanceof com.oracle.truffle.api.dsl.internal.SpecializedNode) { final com.oracle.truffle.api.dsl.internal.SpecializedNode gen = (com.oracle.truffle.api.dsl.internal.SpecializedNode) checkedNode; final com.oracle.truffle.api.dsl.internal.SpecializationNode start0 = gen.getSpecializationNode(); assertEquals("UninitializedNode_", start0.getClass().getSimpleName()); await(threadsStarted); beforeFirst.countDown(); await(executedFirst); final com.oracle.truffle.api.dsl.internal.SpecializationNode start1 = gen.getSpecializationNode(); assertEquals("S" + order[0] + "Node_", start1.getClass().getSimpleName()); assertEquals("UninitializedNode_", nthChild(1, start1).getClass().getSimpleName()); beforeSecond.countDown(); await(executedSecond); final com.oracle.truffle.api.dsl.internal.SpecializationNode start2 = gen.getSpecializationNode(); Arrays.sort(order, 0, 2); assertEquals("PolymorphicNode_", start2.getClass().getSimpleName()); assertEquals("S" + order[0] + "Node_", nthChild(1, start2).getClass().getSimpleName()); assertEquals("S" + order[1] + "Node_", nthChild(2, start2).getClass().getSimpleName()); assertEquals("UninitializedNode_", nthChild(3, start2).getClass().getSimpleName()); beforeThird.countDown(); await(executedThird); final com.oracle.truffle.api.dsl.internal.SpecializationNode start3 = gen.getSpecializationNode(); Arrays.sort(order); assertEquals("PolymorphicNode_", start3.getClass().getSimpleName()); assertEquals("S" + order[0] + "Node_", nthChild(1, start3).getClass().getSimpleName()); assertEquals("S" + order[1] + "Node_", nthChild(2, start3).getClass().getSimpleName()); assertEquals("S" + order[2] + "Node_", nthChild(3, start3).getClass().getSimpleName()); assertEquals("UninitializedNode_", nthChild(4, start3).getClass().getSimpleName()); } else { assertState(checkedNode, order, 0); await(threadsStarted); beforeFirst.countDown(); await(executedFirst); assertState(checkedNode, order, 1); beforeSecond.countDown(); await(executedSecond); assertState(checkedNode, order, 2); beforeThird.countDown(); await(executedThird); assertState(checkedNode, order, 3); } for (Thread thread : threads) { try { thread.join(); } catch (InterruptedException e) { fail("interrupted"); } } } private static void assertState(Node node, int[] expectedOrder, int checkedIndices) throws IllegalArgumentException, IllegalAccessException, NoSuchFieldException, SecurityException { Field stateField = node.getClass().getDeclaredField("state_"); ReflectionUtils.setAccessible(stateField, true); int state = ((((Number) stateField.get(node))).intValue() & ~0x1) >> 1; // exclude // uninitialized Arrays.sort(expectedOrder, 0, checkedIndices); int mask = 0; for (int i = 0; i < checkedIndices; i++) { mask |= 0b1 << expectedOrder[i] - 1; } assertEquals(mask, state & 0b111); } private static class Executions { public final Object firstValue; public final Object secondValue; public final Object thirdValue; Executions(Object firstValue, Object secondValue, Object thirdValue) { this.firstValue = firstValue; this.secondValue = secondValue; this.thirdValue = thirdValue; } } private static void await(final CountDownLatch latch) { try { latch.await(); } catch (InterruptedException e) { fail("interrupted"); } } private static Node firstChild(Node node) { return node.getChildren().iterator().next(); } private static Node nthChild(int n, Node node) { if (n == 0) { return node; } else { return nthChild(n - 1, firstChild(node)); } } }