package be.bagofwords.db; import be.bagofwords.db.combinator.LongCombinator; import be.bagofwords.iterator.CloseableIterator; import be.bagofwords.util.HashUtils; import be.bagofwords.util.KeyValue; import be.bagofwords.util.SafeThread; import be.bagofwords.util.Utils; import org.apache.commons.lang3.mutable.MutableLong; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import java.util.Random; @RunWith(Parameterized.class) public class TestDataInterfaceMultiThreaded extends BaseTestDataInterface { public TestDataInterfaceMultiThreaded(DatabaseCachingType type, DatabaseBackendType backendType) throws Exception { super(type, backendType); } @Test public void testStrings() { final int numOfThreads = 10; final int numOfExamples = 200; final int numOfIterations = 10000; String nameOfSubset = "testMultiThreaded_" + type; final DataInterface<Long> db = dataInterfaceFactory.createDataInterface(type, nameOfSubset, Long.class, new LongCombinator()); db.dropAllData(); final long[] numAdded = new long[numOfExamples]; final MutableLong threadsFinished = new MutableLong(); //Add multithreaded for (int i = 0; i < numOfThreads; i++) { Thread t = new Thread() { @Override public void run() { Random r = new Random(); for (int i = 0; i < numOfIterations; i++) { int nextInt = r.nextInt(numOfExamples); synchronized (numAdded) { numAdded[nextInt]++; } db.increaseCount(Integer.toString(nextInt), 1l); } threadsFinished.increment(); } }; t.start(); } while (threadsFinished.intValue() != numOfThreads) { Utils.threadSleep(100); } db.flush(); for (int i = 0; i < numOfExamples; i++) { Assert.assertEquals("Not equal on position " + i, numAdded[i], db.readCount(Integer.toString(i))); } //Delete multithreaded: for (int i = 0; i < numOfThreads; i++) { Thread t = new Thread() { @Override public void run() { Random r = new Random(); while (!allRemoved(numAdded)) { int pos = r.nextInt(numAdded.length); if (numAdded[pos] > 0) { db.write(Integer.toString(pos), null); numAdded[pos] = 0; } } } }; t.start(); } while (!allRemoved(numAdded)) { Utils.threadSleep(100); } db.flush(); for (int i = 0; i < numOfExamples; i++) { Assert.assertEquals(0, db.readCount(Integer.toString(i))); } } @Test public void testCounts() { int numOfThreads = 10; final int numOfWritesPerThread = 500; final DataInterface<Long> db = createCountDataInterface("testCounts"); db.dropAllData(); SafeThread[] threads = new SafeThread[numOfThreads]; for (int i = 0; i < threads.length; i++) { threads[i] = new SafeThread("testThread_" + i, false) { @Override protected void runInt() throws Exception { for (int i = 0; i < numOfWritesPerThread; i++) { for (int j = 0; j < 1000; j++) { db.increaseCount(HashUtils.randomDistributeHash(j)); } } } }; threads[i].start(); } SafeThread flushThread = new SafeThread("flushThread", false) { @Override protected void runInt() throws Exception { for (int j = 0; j < 50; j++) { db.flush(); Utils.threadSleep(100); } } }; flushThread.start(); for (SafeThread thread : threads) { thread.waitForFinish(); } flushThread.waitForFinish(); db.flush(); for (int j = 0; j < 1000; j++) { db.readCount(HashUtils.randomDistributeHash(j)); } CloseableIterator<KeyValue<Long>> iterator = db.iterator(); while (iterator.hasNext()) { KeyValue<Long> curr = iterator.next(); Assert.assertEquals("Incorrect total for " + curr.getKey() + " " + curr.getValue(), numOfThreads * numOfWritesPerThread, curr.getValue().longValue()); } iterator.close(); } private boolean allRemoved(long[] numAdded) { boolean allRemoved = true; for (long aNumAdded : numAdded) { allRemoved &= aNumAdded == 0; } return allRemoved; } }