/* * Copyright Terracotta, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.ehcache.clustered.client.internal.store; import org.ehcache.clustered.common.internal.messages.ClusterTierReconnectMessage; import org.ehcache.clustered.common.internal.messages.EhcacheEntityResponse; import org.ehcache.clustered.common.internal.messages.ServerStoreMessageFactory; import org.ehcache.clustered.common.internal.store.Chain; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.nio.ByteBuffer; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; public class StrongServerStoreProxy implements ServerStoreProxy { private static final Logger LOGGER = LoggerFactory.getLogger(StrongServerStoreProxy.class); private final CommonServerStoreProxy delegate; private final ConcurrentMap<Long, CountDownLatch> hashInvalidationsInProgress = new ConcurrentHashMap<Long, CountDownLatch>(); private final Lock invalidateAllLock = new ReentrantLock(); private volatile CountDownLatch invalidateAllLatch; private final ClusterTierClientEntity entity; private final ClusterTierClientEntity.ReconnectListener reconnectListener; private final ClusterTierClientEntity.DisconnectionListener disconnectionListener; public StrongServerStoreProxy(final String cacheId, final ServerStoreMessageFactory messageFactory, final ClusterTierClientEntity entity) { this.delegate = new CommonServerStoreProxy(cacheId, messageFactory, entity); this.entity = entity; this.reconnectListener = new SimpleClusterTierClientEntity.ReconnectListener() { @Override public void onHandleReconnect(ClusterTierReconnectMessage reconnectMessage) { Set<Long> inflightInvalidations = hashInvalidationsInProgress.keySet(); reconnectMessage.addInvalidationsInProgress(inflightInvalidations); if (invalidateAllLatch != null) { reconnectMessage.clearInProgress(); } } }; entity.setReconnectListener(reconnectListener); delegate.addResponseListeners(EhcacheEntityResponse.HashInvalidationDone.class, new SimpleClusterTierClientEntity.ResponseListener<EhcacheEntityResponse.HashInvalidationDone>() { @Override public void onResponse(EhcacheEntityResponse.HashInvalidationDone response) { long key = response.getKey(); LOGGER.debug("CLIENT: on cache {}, server notified that clients invalidated hash {}", cacheId, key); CountDownLatch countDownLatch = hashInvalidationsInProgress.remove(key); if (countDownLatch != null) { countDownLatch.countDown(); } } }); delegate.addResponseListeners(EhcacheEntityResponse.AllInvalidationDone.class, new SimpleClusterTierClientEntity.ResponseListener<EhcacheEntityResponse.AllInvalidationDone>() { @Override public void onResponse(EhcacheEntityResponse.AllInvalidationDone response) { LOGGER.debug("CLIENT: on cache {}, server notified that clients invalidated all", cacheId); CountDownLatch countDownLatch; invalidateAllLock.lock(); try { countDownLatch = invalidateAllLatch; invalidateAllLatch = null; } finally { invalidateAllLock.unlock(); } if (countDownLatch != null) { LOGGER.debug("CLIENT: on cache {}, count down", cacheId); countDownLatch.countDown(); } } }); this.disconnectionListener = new SimpleClusterTierClientEntity.DisconnectionListener() { @Override public void onDisconnection() { for (Map.Entry<Long, CountDownLatch> entry : hashInvalidationsInProgress.entrySet()) { entry.getValue().countDown(); } hashInvalidationsInProgress.clear(); invalidateAllLock.lock(); try { if (invalidateAllLatch != null) { invalidateAllLatch.countDown(); } } finally { invalidateAllLock.unlock(); } } }; entity.setDisconnectionListener(disconnectionListener); } private <T> T performWaitingForHashInvalidation(long key, NullaryFunction<T> c) throws InterruptedException, TimeoutException { CountDownLatch latch = new CountDownLatch(1); while (true) { if (!entity.isConnected()) { throw new IllegalStateException("Cluster tier manager disconnected"); } CountDownLatch countDownLatch = hashInvalidationsInProgress.putIfAbsent(key, latch); if (countDownLatch == null) { break; } awaitOnLatch(countDownLatch); } try { T result = c.apply(); LOGGER.debug("CLIENT: Waiting for invalidations on key {}", key); awaitOnLatch(latch); LOGGER.debug("CLIENT: key {} invalidated on all clients, unblocking call", key); return result; } catch (Exception ex) { hashInvalidationsInProgress.remove(key); latch.countDown(); if (ex instanceof TimeoutException) { throw (TimeoutException)ex; } throw new RuntimeException(ex); } } private <T> T performWaitingForAllInvalidation(NullaryFunction<T> c) throws InterruptedException, TimeoutException { CountDownLatch newLatch = new CountDownLatch(1); while (true) { if (!entity.isConnected()) { throw new IllegalStateException("Cluster tier manager disconnected"); } CountDownLatch existingLatch; invalidateAllLock.lock(); try { existingLatch = invalidateAllLatch; if (existingLatch == null) { invalidateAllLatch = newLatch; break; } } finally { invalidateAllLock.unlock(); } awaitOnLatch(existingLatch); } try { T result = c.apply(); awaitOnLatch(newLatch); LOGGER.debug("CLIENT: all invalidated on all clients, unblocking call"); return result; } catch (Exception ex) { invalidateAllLock.lock(); try { invalidateAllLatch = null; } finally { invalidateAllLock.unlock(); } newLatch.countDown(); if (ex instanceof TimeoutException) { throw (TimeoutException)ex; } throw new RuntimeException(ex); } } private void awaitOnLatch(CountDownLatch countDownLatch) throws InterruptedException { int totalAwaitTime = 0; int backoff = 1; while (!countDownLatch.await(backoff, TimeUnit.SECONDS)) { totalAwaitTime += backoff; backoff = (backoff >= 10) ? 10 : backoff * 2; LOGGER.debug("Waiting for the server's InvalidationDone message for {}s, backing off {}s...", totalAwaitTime, backoff); } if (!entity.isConnected()) { throw new IllegalStateException("Cluster tier manager disconnected"); } } @Override public String getCacheId() { return delegate.getCacheId(); } @Override public void addInvalidationListener(InvalidationListener listener) { delegate.addInvalidationListener(listener); } @Override public boolean removeInvalidationListener(InvalidationListener listener) { return delegate.removeInvalidationListener(listener); } @Override public void close() { delegate.close(); } @Override public Chain get(long key) throws TimeoutException { return delegate.get(key); } @Override public void append(final long key, final ByteBuffer payLoad) throws TimeoutException { try { performWaitingForHashInvalidation(key, new NullaryFunction<Void>() { @Override public Void apply() throws TimeoutException { delegate.append(key, payLoad); return null; } }); } catch (InterruptedException ie) { throw new RuntimeException(ie); } } @Override public Chain getAndAppend(final long key, final ByteBuffer payLoad) throws TimeoutException { try { return performWaitingForHashInvalidation(key, new NullaryFunction<Chain>() { @Override public Chain apply() throws TimeoutException { return delegate.getAndAppend(key, payLoad); } }); } catch (InterruptedException ie) { throw new RuntimeException(ie); } } @Override public void replaceAtHead(long key, Chain expect, Chain update) { delegate.replaceAtHead(key, expect, update); } @Override public void clear() throws TimeoutException { try { performWaitingForAllInvalidation(new NullaryFunction<Object>() { @Override public Object apply() throws TimeoutException { delegate.clear(); return null; } }); } catch (InterruptedException ie) { throw new RuntimeException(ie); } } private interface NullaryFunction<T> { T apply() throws Exception; } }