/* * (C) Copyright 2016 Pantheon Technologies, s.r.o. and others. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.opendaylight.yangtools.triemap; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import java.util.ArrayList; import java.util.Collection; import java.util.Iterator; import java.util.Map; import java.util.Map.Entry; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class TestMultiThreadMapIterator { private static final Logger LOG = LoggerFactory.getLogger(TestMultiThreadMapIterator.class); private static final int NTHREADS = 7; @Test public void testMultiThreadMapIterator() throws InterruptedException { final Map<Object, Object> bt = TrieMap.create(); for (int j = 0; j < 50 * 1000; j++) { for (final Object o : getObjects(j)) { bt.put(o, o); } } LOG.debug("Size of initialized map is {}", bt.size()); int count = 0; { final ExecutorService es = Executors.newFixedThreadPool(NTHREADS); for (int i = 0; i < NTHREADS; i++) { final int threadNo = i; es.execute(() -> { for (Entry<Object, Object> e : bt.entrySet()) { if (accepts(threadNo, NTHREADS, e.getKey())) { String newValue = "TEST:" + threadNo; e.setValue(newValue); } } }); } es.shutdown(); es.awaitTermination(5, TimeUnit.MINUTES); } count = 0; for (final Map.Entry<Object, Object> kv : bt.entrySet()) { assertTrue(kv.getValue() instanceof String); count++; } assertEquals(50000 + 2000 + 1000 + 100, count); final ConcurrentHashMap<Object, Object> removed = new ConcurrentHashMap<>(); { final ExecutorService es = Executors.newFixedThreadPool(NTHREADS); for (int i = 0; i < NTHREADS; i++) { final int threadNo = i; es.execute(() -> { for (final Iterator<Map.Entry<Object, Object>> it = bt.entrySet().iterator(); it.hasNext();) { final Entry<Object, Object> e = it.next(); Object key = e.getKey(); if (accepts(threadNo, NTHREADS, key)) { if (null == bt.get(key)) { LOG.error("Key {} is not present", key); } it.remove(); if (null != bt.get(key)) { LOG.error("Key {} is still present", key); } removed.put(key, key); } } }); } es.shutdown(); es.awaitTermination(5, TimeUnit.MINUTES); } count = 0; for (final Object value : bt.keySet()) { value.toString(); count++; } for (final Object o : bt.keySet()) { if (!removed.contains(bt.get(o))) { LOG.error("Not removed: {}", o); } } assertEquals(0, count); assertEquals(0, bt.size()); assertTrue(bt.isEmpty()); } protected static boolean accepts(final int threadNo, final int nrThreads, final Object key) { final int val = getKeyValue(key); return val >= 0 ? val % nrThreads == threadNo : false; } private static int getKeyValue(final Object key) { if (key instanceof Integer) { return ((Integer) key).intValue(); } else if (key instanceof Character) { return Math.abs(Character.getNumericValue((Character) key) + 1); } else if (key instanceof Short) { return ((Short) key).intValue() + 2; } else if (key instanceof Byte) { return ((Byte) key).intValue() + 3; } else { return -1; } } static Collection<Object> getObjects(final int j) { final Collection<Object> results = new ArrayList<>(4); results.add(Integer.valueOf(j)); if (j < 2000) { results.add(Character.valueOf((char) j)); } if (j < 1000) { results.add(Short.valueOf((short) j)); } if (j < 100) { results.add(Byte.valueOf((byte) j)); } return results; } }