/* * Copyright 2003-2016 JetBrains s.r.o. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package jetbrains.mps.smodel; import jetbrains.mps.smodel.ModelListenerTest.AccessCountListener1; import jetbrains.mps.smodel.ModelListenerTest.AccessCountListener2; import jetbrains.mps.smodel.ModelListenerTest.AccessCountListener3; import jetbrains.mps.smodel.ModelUndoTest.TestUndoHandler; import jetbrains.mps.smodel.TestModelFactory.TestModelAccess; import jetbrains.mps.smodel.TestModelFactory.TestRepository; import jetbrains.mps.testbench.PerformanceMessenger; import org.hamcrest.BaseMatcher; import org.hamcrest.Description; import org.jetbrains.mps.openapi.module.SRepository; import org.junit.Assert; import org.junit.Before; import org.junit.ClassRule; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ErrorCollector; import java.util.concurrent.BrokenBarrierException; import java.util.concurrent.CountDownLatch; import java.util.concurrent.CyclicBarrier; import java.util.concurrent.TimeUnit; import static org.hamcrest.CoreMatchers.equalTo; /** * Test to ensure model performance doesn't degrade * * @author Artem Tikhomirov */ public class ModelPerformanceTest { @Rule public ErrorCollector myErrors = new ErrorCollector(); @ClassRule public static PerformanceMessenger ourStats = new PerformanceMessenger("ModelReadPerformance."); private final TestModelAccess myTestModelAccess = new TestModelAccess(); private final SRepository myTestRepo = new TestRepository(myTestModelAccess); @Before public void setUp() { TestUndoHandler uh = new TestUndoHandler(); uh.needsUndo(false); // undo is not our focus here, we merely need to avoid NPE from ModelAccess.instance().isInsideCommand() UndoHelper.getInstance().setUndoHandler(uh); } /** * Ensure parallel reads are viable. * <p> * Execution time, justification for baseline value * Detached model, no listeners: 1 thread = 200 ms; 4 threads = ~265 ms per thread * Attached model, no listeners: 1 thread = 270 ms; 4 threads = ~330 ms (300 - 420) * Attached model, 3 listeners: 1 thread = 340 ms; 4 threads = 510 ms (500-540) * </p> * Note, though average time in testWalkTime for slightly smaller model is 50ms, it's rather 200, 75, 25, 20, 20, * i.e. each thread in parallel mode is executed as a 'fresh' run, so it looks like JIT optimizes per thread? */ @Test public void testParallelRead() throws Exception { final TestModelFactory m1f = new TestModelFactory(); m1f.createModel(20, 100, 10, 5); // ~120k nodes final int initialNodeCount = m1f.countModelNodes(); myTestModelAccess.enableRead(); m1f.attachTo(myTestRepo); final long baselineMillis = 500 * 2; // Use twice as much time to account for slow build agents ourStats.report("multiThreadBaselineMillis", baselineMillis); final int parallelThreads = 4; final CountDownLatch stopLatch = new CountDownLatch(3); // 1 for thread start sync, 1 for results ready sync, 1 for thread stop sync CyclicBarrier b = new CyclicBarrier(parallelThreads, new Runnable() { @Override public void run() { stopLatch.countDown(); } }); ModelReadThread[] threads = new ModelReadThread[parallelThreads]; for (int i = 0; i < parallelThreads; i++) { threads[i] = new ModelReadThread(b, m1f); threads[i].start(); } boolean finishOk = stopLatch.await(10, TimeUnit.SECONDS); if (finishOk) { final int expectedNodeCount = 3 * initialNodeCount; // average between different threads, to compare with baseline long averageElapsedMillis = 0; // min and max to see how far from baseline we could go long minElapsedMillis = Long.MAX_VALUE, maxElapsedMillis = 0; for (int i = 0; i < parallelThreads; i++) { myErrors.checkThat(threads[i].getName(), threads[i].getAllThreadListenerCount(), equalTo(expectedNodeCount * parallelThreads)); myErrors.checkThat(threads[i].getName(), threads[i].getThisThreadCount1(), equalTo(expectedNodeCount)); myErrors.checkThat(threads[i].getName(), threads[i].getThisThreadCount2(), equalTo(expectedNodeCount)); myErrors.checkThat(threads[i].getName(), threads[i].getElapsedMillis(), new BaseMatcher<Long>() { @Override public boolean matches(Object item) { if (item instanceof Long) { return ((Long) item) < baselineMillis; } return false; } @Override public void describeTo(Description description) {description.appendText(String.format("less than %d", baselineMillis)); } }); myErrors.checkThat(threads[i].getName(), threads[i].getElapsedMillis(), new BaseMatcher<Long>() { @Override public boolean matches(Object item) { if (item instanceof Long) { return ((Long) item) > baselineMillis / 4; } return false; } @Override public void describeTo(Description description) { description.appendText(String.format("greater than %d", baselineMillis/4)); } }); averageElapsedMillis += threads[i].getElapsedMillis(); if (threads[i].getElapsedMillis() < minElapsedMillis) { minElapsedMillis = threads[i].getElapsedMillis(); } if (threads[i].getElapsedMillis() > maxElapsedMillis) { maxElapsedMillis = threads[i].getElapsedMillis(); } } averageElapsedMillis /= parallelThreads; ourStats.report("multiThreadAverageMillis", averageElapsedMillis); ourStats.report("multiThreadMaxMillis", maxElapsedMillis); ourStats.report("multiThreadMinMillis", minElapsedMillis); return; } for (int i = 0; i < parallelThreads; i++) { if (threads[i].isAlive()) { Throwable th = new Throwable("Hanging thread " + threads[i].getName()); th.setStackTrace(threads[i].getStackTrace()); myErrors.addError(th); threads[i].interrupt(); } } } /** * Just a quick check iteration time over a model doesn't deviate significantly due to * changes in SModel/SNode implementation. */ @Test public void testWalkTime() { final TestModelFactory m1f = new TestModelFactory(); org.jetbrains.mps.openapi.model.SModel m1 = m1f.createModel(10, 25, 15, 5, 4); final int actualNodes = m1f.countModelNodes(); // 10, 25, 15, 5, 4 == 97760 nodes. It takes about 50 ms to walk this model in avg. I use twice as much time to account for slower build agents final long baselineMillis = 50*2; ourStats.report("singleThreadBaselineMillis", baselineMillis); final int testRuns = 10; long elapsed = 0; for (int i = 0; i < testRuns; i++) { final long start = System.nanoTime(); ModelListenerTest.readTreeNodes(m1.getRootNodes()); elapsed += System.nanoTime() - start; if (i == 0) { ourStats.report("singleThreadFirstRunMillis", elapsed / 1000000); } } long averageMillis = elapsed / 1000000 / testRuns; ourStats.report("singleThreadAvgMillis", averageMillis); if (averageMillis > baselineMillis) { final String fmt = "Walking model of %d nodes was expected to take less than %d ms. Actual average time for %d runs was %d ms"; Assert.fail(String.format(fmt, actualNodes, baselineMillis, testRuns, averageMillis)); } // guard if it's too fast if (averageMillis < baselineMillis / 5) { final String fmt = "Walking model of %d nodes took less than 20%% of baseline. Actual average time for %d runs was %d ms, while baseline is %d ms. Re-consider baseline value"; Assert.fail(String.format(fmt, actualNodes, testRuns, averageMillis, baselineMillis)); } } private static class ModelReadThread extends Thread { private final CyclicBarrier myBarrier; private final TestModelFactory myModel; private int myCountL1, myCountL2, myCountL3; private long myElapsedMillis; public ModelReadThread(CyclicBarrier barrier, TestModelFactory mf) { myBarrier = barrier; myModel = mf; } @Override public void run() { AccessCountListener1 cl1 = new AccessCountListener1(); AccessCountListener2 cl2 = new AccessCountListener2(); AccessCountListener3 cl3 = new AccessCountListener3(); myModel.attachAccessListeners(cl1, cl2, cl3); try { myBarrier.await(); final long s = System.nanoTime(); ModelListenerTest.readTreeNodes(myModel.getModel().getRootNodes()); final long e = System.nanoTime(); myBarrier.await(); myCountL1 = cl1.myVisitedNodes; myCountL2 = cl2.myVisitedNodes; myCountL3 = cl3.myVisitedNodes; myElapsedMillis = (e - s) / 1000000; myBarrier.await(); } catch (InterruptedException e) { e.printStackTrace(); throw new RuntimeException(e); } catch (BrokenBarrierException e) { e.printStackTrace(); throw new RuntimeException(e); } finally { myModel.detachAccessListeners(cl1, cl2, cl3); } } public int getAllThreadListenerCount() { return myCountL1; } public int getThisThreadCount1() { return myCountL2; } public int getThisThreadCount2() { return myCountL3; } public long getElapsedMillis() { return myElapsedMillis; } } }