package org.dcache.xrootd.door;
import com.google.common.net.HostAndPort;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.haproxy.HAProxyCommand;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import java.io.PrintWriter;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;
import dmg.cells.nucleus.CDC;
import dmg.cells.nucleus.CellCommandListener;
import dmg.cells.nucleus.CellInfo;
import dmg.cells.nucleus.CellInfoProvider;
import org.dcache.util.Args;
/**
 * Channel handler that keeps track of connected channels. Provides
 * administrative commands for listing and killing connections.
 *
 * <p>Every channel that becomes active is recorded together with the
 * session identifier returned by {@link CDC#getSession()} at that moment.
 * If a HAProxy PROXY message is seen on the channel, the proxied source
 * address is recorded as well. Both records are dropped again when the
 * channel becomes inactive.
 *
 * <p>The handler is {@link Sharable}; its state is kept in concurrent
 * collections.
 */
@Sharable
public class ConnectionTracker
        extends ChannelInboundHandlerAdapter
        implements CellCommandListener, CellInfoProvider
{
    /** Session identifier of each currently tracked channel. */
    private final Map<Channel, String> sessions = new ConcurrentHashMap<>();

    /** Proxied source address of tracked channels that announced one. */
    private final Map<Channel, HostAndPort> addresses = new ConcurrentHashMap<>();

    /** Total number of channels that have become active so far. */
    private final LongAdder counter = new LongAdder();

    /**
     * Registers the channel under the current session identifier and
     * counts it, then forwards the event.
     */
    @Override
    public void channelActive(ChannelHandlerContext ctx)
            throws Exception
    {
        Channel channel = ctx.channel();
        // NOTE(review): ConcurrentHashMap rejects null values, so this
        // assumes CDC.getSession() is non-null at this point - confirm.
        sessions.put(channel, CDC.getSession());
        counter.increment();
        super.channelActive(ctx);
    }

    /**
     * Forwards the event and then forgets the channel. The cleanup runs in
     * a {@code finally} block so it happens even if forwarding throws.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx)
            throws Exception
    {
        try {
            super.channelInactive(ctx);
        } finally {
            Channel channel = ctx.channel();
            sessions.remove(channel);
            addresses.remove(channel);
        }
    }

    /**
     * Records the proxied source address carried by a HAProxy PROXY message
     * and forwards every message unchanged.
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
    {
        if (msg instanceof HAProxyMessage) {
            HAProxyMessage proxyMessage = (HAProxyMessage) msg;
            String sourceAddress = proxyMessage.sourceAddress();
            // A PROXY message for an unknown proxied protocol carries no
            // source address; HostAndPort.fromParts would throw on null and
            // the message would never be forwarded, so skip recording then.
            if (proxyMessage.command() == HAProxyCommand.PROXY && sourceAddress != null) {
                addresses.put(ctx.channel(),
                              HostAndPort.fromParts(sourceAddress, proxyMessage.sourcePort()));
            }
        }
        ctx.fireChannelRead(msg);
    }

    /** Returns the supplied cell info unchanged. */
    @Override
    public CellInfo getCellInfo(CellInfo info)
    {
        return info;
    }

    /**
     * Prints the number of currently tracked channels and the total number
     * of channels that have become active.
     */
    @Override
    public void getInfo(PrintWriter pw)
    {
        pw.println(String.format("Active : %d", sessions.size()));
        pw.println(String.format("Created: %d", counter.longValue()));
    }

    /**
     * Lists the tracked connections, one per line: session identifier,
     * channel and, if known, the proxied source address.
     *
     * @param args command arguments (not used)
     * @return the listing; empty if no connection is tracked
     */
    public String ac_connections(Args args)
    {
        StringBuilder s = new StringBuilder();
        for (Map.Entry<Channel, String> e : sessions.entrySet()) {
            Channel channel = e.getKey();
            s.append(e.getValue()).append(' ').append(channel);
            HostAndPort hostAndPort = addresses.get(channel);
            if (hostAndPort != null) {
                s.append(' ').append(hostAndPort);
            }
            s.append('\n');
        }
        return s.toString();
    }

    /**
     * Kills the connection with the given session identifier by closing
     * its channel. Only the first matching channel is closed.
     *
     * <p>The tracking entries are deliberately not removed here: that is
     * done by {@link #channelInactive} once the channel is reported
     * inactive, so the connection stays visible until it is really gone.
     *
     * @param args command arguments; the first argument is the session
     *             identifier of the connection to kill
     * @return an empty string if a matching connection was found,
     *         otherwise {@code "No such connection"}
     */
    public String ac_kill_$_1(Args args)
    {
        String session = args.argv(0);
        for (Map.Entry<Channel, String> e : sessions.entrySet()) {
            if (e.getValue().equals(session)) {
                // Actually terminate the connection rather than merely
                // dropping it from the bookkeeping.
                e.getKey().close();
                return "";
            }
        }
        return "No such connection";
    }
}