/* * JBoss, Home of Professional Open Source * Copyright 2011, Red Hat, Inc. and individual contributors * by the @authors tag. See the copyright.txt in the distribution for a * full listing of individual contributors. * * This is free software; you can redistribute it and/or modify it * under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2.1 of * the License, or (at your option) any later version. * * This software is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with this software; if not, write to the Free * Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA * 02110-1301 USA, or see the FSF site: http://www.fsf.org. */ package org.mobicents.tools.sip.balancer; import gov.nist.javax.sip.header.SIPHeader; import gov.nist.javax.sip.header.Via; import java.util.Collections; import java.util.HashMap; import java.util.SortedSet; import java.util.TreeSet; import org.apache.commons.validator.routines.InetAddressValidator; import org.apache.log4j.Logger; import javax.sip.ListeningPoint; import javax.sip.address.SipURI; import javax.sip.header.FromHeader; import javax.sip.header.ToHeader; import javax.sip.message.Message; import javax.sip.message.Request; import javax.sip.message.Response; import org.jboss.netty.handler.codec.http.HttpRequest; import org.mobicents.tools.heartbeat.api.Node; public class HeaderConsistentHashBalancerAlgorithm extends DefaultBalancerAlgorithm { private static Logger logger = Logger.getLogger(HeaderConsistentHashBalancerAlgorithm.class.getName()); protected String sipHeaderAffinityKey; protected String httpAffinityKey; // We will maintain a sorted list of the nodes so all SIP LBs will see them in the same order // no matter at what order the events arrived private SortedSet<Node> nodesV4 = (SortedSet<Node>) Collections.synchronizedSortedSet(new TreeSet<Node>()); private SortedSet<Node> nodesV6 = (SortedSet<Node>) Collections.synchronizedSortedSet(new TreeSet<Node>()); // And we also keep a copy in the array because it is faster to query by index protected Object[] nodesArrayV4; protected Object[] nodesArrayV6; protected boolean nodesAreDirty = true; public HeaderConsistentHashBalancerAlgorithm() { } private SortedSet<Node> nodes(Boolean isIpV6) { if(isIpV6) return nodesV6; else return nodesV4; } protected Object[] nodesArray(Boolean isIpV6) { if(isIpV6) return nodesArrayV6; else return nodesArrayV4; } public HeaderConsistentHashBalancerAlgorithm(String headerName) { if(headerName == null) { this.sipHeaderAffinityKey = "Call-ID"; } else { this.sipHeaderAffinityKey = headerName; } } public Node processExternalRequest(Request request,Boolean isIpV6) { if(nodesAreDirty) { // for testing only where nodes are not removed, just start advertising new version while alive synchronized(this) { syncNodes(isIpV6); } } Integer nodeIndex = hashHeader(request,isIpV6); if(nodeIndex<0) { return null; } else { try { Node node = (Node) nodesArray(isIpV6)[nodeIndex]; // if(!invocationContext.gracefulShutdownSipNodeMap(isIpV6).containsKey(new KeySip(node))) if(!invocationContext.sipNodeMap(isIpV6).get(new KeySip(node,isIpV6)).isGracefulShutdown()) return node; else return null; } catch (Exception e) { return null; } } } @Override public synchronized void nodeAdded(Node node) { Boolean isIpV6=LbUtils.isValidInet6Address(node.getIp()); nodes(isIpV6).add(node); if(isIpV6) nodesArrayV6 = nodes(true).toArray(new Object[]{}); else { nodesArrayV4 = new Object[nodes(false).size()]; nodesArrayV4 = nodes(false).toArray(nodesArrayV4); } nodesAreDirty = false; } @Override public synchronized void nodeRemoved(Node node) { Boolean isIpV6=LbUtils.isValidInet6Address(node.getIp()); nodes(isIpV6).remove(node); if(isIpV6) nodesArrayV6 = nodes(true).toArray(new Object[]{}); else { nodesArrayV4 = new Object[nodes(false).size()]; nodesArrayV4 = nodes(false).toArray(nodesArrayV4); } nodesAreDirty = false; } protected Integer hashHeader(Message message,Boolean isIpV6) { String headerValue = null; if(sipHeaderAffinityKey.equals("From")) { headerValue = ((SipURI)((FromHeader) message.getHeader(FromHeader.NAME)) .getAddress().getURI()).getUser(); } else if(sipHeaderAffinityKey.equals("To")) { headerValue = ((SipURI)((ToHeader) message.getHeader(ToHeader.NAME)) .getAddress().getURI()).getUser(); } else { headerValue = ((SIPHeader) message.getHeader(sipHeaderAffinityKey)) .getValue(); } if(nodesArray(isIpV6).length == 0) { throw new RuntimeException("No Application Servers registered. All servers are dead."); } int nodeIndex = hashAffinityKeyword(headerValue,isIpV6); if(isAlive((Node)nodesArray(isIpV6)[nodeIndex])) { return nodeIndex; } else { return -1; } } protected boolean isAlive(Node node) { //if(invocationContext.nodes.contains(node)) return true; Boolean isIpV6=LbUtils.isValidInet6Address(node.getIp()); if(invocationContext.sipNodeMap(isIpV6).containsValue(node)) return true; return false; } public Node processHttpRequest(HttpRequest request, InvocationContext context) { String affinityKeyword = getUrlParameters(request.getUri()).get(this.httpAffinityKey); if(affinityKeyword == null) { return super.processHttpRequest(request); } return (Node) nodesArrayV4[hashAffinityKeyword(affinityKeyword,false)]; } protected int hashAffinityKeyword(String keyword,Boolean isIpV6) { int nodeIndex = Math.abs(keyword.hashCode()) % nodesArray(isIpV6).length; Node computedNode = (Node) nodesArray(isIpV6)[nodeIndex]; if(!isAlive(computedNode)) { // If the computed node is dead, find a new one for(int q = 0; q<nodesArray(isIpV6).length; q++) { nodeIndex = (nodeIndex + 1) % nodesArray(isIpV6).length; if(isAlive(((Node)nodesArray(isIpV6)[nodeIndex]))) { break; } } } return nodeIndex; } HashMap<String,String> getUrlParameters(String url) { HashMap<String,String> parameters = new HashMap<String, String>(); int start = url.lastIndexOf('?'); if(start>0 && url.length() > start +1) { url = url.substring(start + 1); } else { return parameters; } String[] tokens = url.split("&"); for(String token : tokens) { String[] params = token.split("="); if(params.length<2) { parameters.put(token, ""); } else { parameters.put(params[0], params[1]); } } return parameters; } public void init() { this.httpAffinityKey = getConfiguration().getSipConfiguration().getAlgorithmConfiguration().getHttpAffinityKey(); this.sipHeaderAffinityKey = getConfiguration().getSipConfiguration().getAlgorithmConfiguration().getSipHeaderAffinityKey(); logger.info("SIP affinity key = " + sipHeaderAffinityKey + " HTTP key = " + httpAffinityKey); } public void configurationChanged() { logger.info("Configuration changed"); this.httpAffinityKey = getConfiguration().getSipConfiguration().getAlgorithmConfiguration().getHttpAffinityKey(); this.sipHeaderAffinityKey = getConfiguration().getSipConfiguration().getAlgorithmConfiguration().getSipHeaderAffinityKey(); } @Override public void processExternalResponse(Response response,Boolean isIpV6){ this.processExternalResponse(response, this.invocationContext,isIpV6); } public void processExternalResponse(Response response, InvocationContext context,Boolean isIpV6) { Via via = (Via) response.getHeader(Via.NAME); String transport = via.getTransport().toLowerCase(); Integer nodeIndex = hashHeader(response,isIpV6); String host = via.getHost(); Integer port = via.getPort(); Boolean found = false; // for(Node node : context.nodes) { // if(node.getIp().equals(host)) { // if(port.equals(node.getProperties().get(transport+"Port"))) { // found = true; // } // } // } if(context.sipNodeMap(isIpV6).containsKey(new KeySip(host, port,isIpV6))) found = true; if(logger.isDebugEnabled()) { logger.debug("external response node found ? " + found); } if(!found) { if(nodesAreDirty) { synchronized(this) { syncNodes(isIpV6); } } try { Node node = (Node) nodesArray(isIpV6)[nodeIndex]; //if(node == null || !context.nodes.contains(node)) { if(node == null || !context.sipNodeMap(isIpV6).containsValue(node)) { if(logger.isDebugEnabled()) { logger.debug("No node to handle " + via); } } else { String transportProperty = transport + "Port"; port = Integer.parseInt(node.getProperties().get(transportProperty)); if(via.getHost().equalsIgnoreCase(node.getIp()) || via.getPort() != port) { if(logger.isDebugEnabled()) { logger.debug("changing retransmission via " + via + "setting new values " + node.getIp() + ":" + port); } try { via.setHost(node.getIp()); via.setPort(port); } catch (Exception e) { throw new RuntimeException("Error setting new values " + node.getIp() + ":" + port + " on via " + via, e); } // need to reset the rport for reliable transports if(!ListeningPoint.UDP.equalsIgnoreCase(transport)) { via.setRPort(); } } } } catch (Exception e) { } } } protected void syncNodes(Boolean isIpV6) { nodes(isIpV6).clear(); nodes(isIpV6).addAll(invocationContext.sipNodeMap(isIpV6).values()); if(isIpV6) nodesArrayV6 = nodes(true).toArray(new Object[]{}); else nodesArrayV4 = nodes(false).toArray(new Object[]{}); nodesAreDirty = false; } }