/* * Copyright (c) 2008 - 2017, Hazelcast, Inc. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.hazelcast.nio.tcp; import com.hazelcast.nio.Address; import com.hazelcast.nio.Connection; import com.hazelcast.nio.ConnectionListener; import com.hazelcast.nio.ConnectionManager; import com.hazelcast.nio.Packet; import com.hazelcast.spi.impl.PacketHandler; import java.util.Collections; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import static java.util.concurrent.TimeUnit.MILLISECONDS; /** * A {@link ConnectionManager} wrapper which adds firewalling capabilities. * All methods delegate to the original ConnectionManager. */ public class FirewallingConnectionManager implements ConnectionManager, PacketHandler { private final ConnectionManager delegate; private final Set<Address> blockedAddresses = Collections.newSetFromMap(new ConcurrentHashMap<Address, Boolean>()); private final ScheduledExecutorService scheduledExecutor = Executors.newSingleThreadScheduledExecutor(); private final PacketHandler packetHandler; private volatile PacketFilter droppingPacketFilter; private volatile DelayingPacketFilterWrapper delayingPacketFilter; public FirewallingConnectionManager(ConnectionManager delegate, Set<Address> initiallyBlockedAddresses) { this.delegate = delegate; this.blockedAddresses.addAll(initiallyBlockedAddresses); packetHandler = delegate instanceof PacketHandler ? (PacketHandler) delegate : null; } @Override public synchronized Connection getOrConnect(Address address) { Connection connection = getConnection(address); if (connection != null && connection.isAlive()) { return connection; } if (blockedAddresses.contains(address)) { connection = new DroppingConnection(address, this); registerConnection(address, connection); return connection; } else { return delegate.getOrConnect(address); } } @Override public synchronized Connection getOrConnect(Address address, boolean silent) { return getOrConnect(address); } public synchronized void block(Address address) { blockedAddresses.add(address); Connection connection = getConnection(address); if (connection != null) { connection.close("Blocked by connection manager", null); } } public synchronized void unblock(Address address) { blockedAddresses.remove(address); Connection connection = getConnection(address); if (connection instanceof DroppingConnection) { connection.close(null, null); } } public void setDroppingPacketFilter(PacketFilter droppingPacketFilter) { assert droppingPacketFilter != null; this.droppingPacketFilter = droppingPacketFilter; } public void removeDroppingPacketFilter() { droppingPacketFilter = null; } public void setDelayingPacketFilter(PacketFilter delayingPacketFilter, long minDelayMs, long maxDelayMs) { assert delayingPacketFilter != null; this.delayingPacketFilter = new DelayingPacketFilterWrapper(delayingPacketFilter, minDelayMs, maxDelayMs); } public void removeDelayingPacketFilter() { delayingPacketFilter = null; } private boolean isAllowed(Packet packet, Address target) { boolean allowed = true; PacketFilter filter = droppingPacketFilter; if (filter != null) { allowed = filter.allow(packet, target); } return allowed; } private long getDelayMs(Packet packet, Address target) { DelayingPacketFilterWrapper delayingFilter = delayingPacketFilter; if (delayingFilter != null) { if (!delayingFilter.packetFilter.allow(packet, target)) { return getRandomBetween(delayingFilter.maxDelayMs, delayingFilter.minDelayMs); } } return 0; } private long getRandomBetween(long max, long min) { return (long) ((max - min) * Math.random() + min); } @Override public boolean transmit(Packet packet, Connection connection) { if (connection != null) { if (!isAllowed(packet, connection.getEndPoint())) { return false; } long delayMs; if ((delayMs = getDelayMs(packet, connection.getEndPoint())) > 0) { scheduledExecutor.schedule(new DelayedPacketTask(packet, connection), delayMs, MILLISECONDS); return true; } } return delegate.transmit(packet, connection); } @Override public boolean transmit(Packet packet, Address target) { if (!isAllowed(packet, target)) { return false; } long delayMs; if ((delayMs = getDelayMs(packet, target)) > 0) { scheduledExecutor.schedule(new DelayedPacketTask(packet, target), delayMs, MILLISECONDS); return true; } return delegate.transmit(packet, target); } @Override public int getCurrentClientConnections() {return delegate.getCurrentClientConnections();} @Override public void addConnectionListener(ConnectionListener listener) {delegate.addConnectionListener(listener);} @Override public int getAllTextConnections() {return delegate.getAllTextConnections();} @Override public int getConnectionCount() {return delegate.getConnectionCount();} @Override public int getActiveConnectionCount() {return delegate.getActiveConnectionCount();} @Override public Connection getConnection(Address address) {return delegate.getConnection(address);} @Override public boolean registerConnection(Address address, Connection connection) { return delegate.registerConnection(address, connection); } @Override public void onConnectionClose(Connection connection) { delegate.onConnectionClose(connection); } @Override public void start() {delegate.start();} @Override public void stop() {delegate.stop();} @Override public void shutdown() { delegate.shutdown(); scheduledExecutor.shutdown(); } @Override public void handle(Packet packet) throws Exception { if (packetHandler == null) { throw new UnsupportedOperationException(delegate + " is not instance of PacketHandler!"); } packetHandler.handle(packet); } private class DelayedPacketTask implements Runnable { Packet packet; Connection connection; Address target; DelayedPacketTask(Packet packet, Connection connection) { assert connection != null; this.packet = packet; this.connection = connection; } DelayedPacketTask(Packet packet, Address target) { assert target != null; this.packet = packet; this.target = target; } @Override public void run() { if (connection != null) { delegate.transmit(packet, connection); } else { delegate.transmit(packet, target); } } } private static class DelayingPacketFilterWrapper { final PacketFilter packetFilter; final long minDelayMs; final long maxDelayMs; private DelayingPacketFilterWrapper(PacketFilter packetFilter, long minDelayMs, long maxDelayMs) { this.packetFilter = packetFilter; this.minDelayMs = minDelayMs; this.maxDelayMs = maxDelayMs; } } }