/* * Copyright 2014 EMC Corporation. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. * A copy of the License is located at * * http://www.apache.org/licenses/LICENSE-2.0.txt * * or in the "license" file accompanying this file. This file is distributed * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either * express or implied. See the License for the specific language governing * permissions and limitations under the License. */ package com.emc.vipr.ribbon; import com.emc.vipr.ribbon.bean.ListDataNode; import com.netflix.client.config.IClientConfig; import com.netflix.loadbalancer.AbstractServerList; import com.netflix.loadbalancer.Server; import org.apache.commons.codec.binary.Base64; import org.apache.http.HttpResponse; import org.apache.http.client.HttpClient; import org.apache.http.client.methods.HttpGet; import org.apache.http.impl.client.DefaultHttpClient; import org.apache.http.impl.conn.PoolingClientConnectionManager; import org.apache.http.params.HttpConnectionParams; import org.apache.http.util.EntityUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.crypto.Mac; import javax.crypto.spec.SecretKeySpec; import javax.xml.bind.JAXBContext; import javax.xml.bind.JAXBException; import javax.xml.bind.Unmarshaller; import java.io.IOException; import java.io.InputStream; import java.text.SimpleDateFormat; import java.util.*; public class ViPRDataServicesServerList extends AbstractServerList<Server> { private static final Logger logger = LoggerFactory.getLogger(ViPRDataServicesServerList.class); protected final SimpleDateFormat rfc822DateFormat = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US); protected Unmarshaller unmarshaller; private String protocol; private List<Server> nodeList; private int port; private String user; private String secret; private HttpClient httpClient; private int requestCounter = 0; public ViPRDataServicesServerList() { rfc822DateFormat.setTimeZone(new SimpleTimeZone(0, "GMT")); try { unmarshaller = JAXBContext.newInstance(ListDataNode.class).createUnmarshaller(); } catch (JAXBException e) { throw new RuntimeException("can't create unmarshaller", e); } PoolingClientConnectionManager cm = new PoolingClientConnectionManager(); cm.setDefaultMaxPerRoute(10); httpClient = new DefaultHttpClient(cm); } @Override public void initWithNiwsConfig(IClientConfig clientConfig) { protocol = clientConfig.getPropertyAsString(SmartClientConfigKey.ViPRDataServicesProtocol, "").toLowerCase(); if (!Arrays.asList("http", "https").contains(protocol)) throw new IllegalArgumentException("Invalid protocol: " + protocol); String nodeStr = clientConfig.getPropertyAsString(SmartClientConfigKey.ViPRDataServicesInitialNodes, ""); if (nodeStr.trim().length() == 0) throw new IllegalStateException("No servers configured in smartConfig or NIWS config"); setNodeList(SmartClientConfig.parseServerList(nodeStr)); // pull the port from the initial node list (it will not be returned by the list-data-nodes call) port = getNodeList().get(0).getPort(); user = clientConfig.getPropertyAsString(SmartClientConfigKey.ViPRDataServicesUser, null); secret = clientConfig.getPropertyAsString(SmartClientConfigKey.ViPRDataServicesUserSecret, null); int timeout = clientConfig.getPropertyAsInteger(SmartClientConfigKey.ViPRDataServicesTimeout, SmartClientConfig.DEFAULT_TIMEOUT); HttpConnectionParams.setConnectionTimeout(httpClient.getParams(), timeout); HttpConnectionParams.setSoTimeout(httpClient.getParams(), timeout); if (logger.isDebugEnabled()) { logger.debug("Configured node enumeration"); logger.debug("--- protocol: " + protocol); logger.debug("--- nodeList: " + getNodeList()); logger.debug("--- user: " + user); logger.debug("--- secret: " + secret); logger.debug("--- timeout: " + timeout); logger.debug("--- httpClient: " + (httpClient != null)); } } @Override public List<Server> getInitialListOfServers() { return getNodeList(); } @Override public List<Server> getUpdatedListOfServers() { return pollForServers(); } protected List<Server> pollForServers() { try { List<Server> activeNodeList = getNodeList(); int activeNodeCount = activeNodeList.size(); Server server = null; List<String> hosts = null; String path = "/?endpoint"; // we want to try a different node on failure until we try every active node (HttpClient will auto-retry // 500s and some IOEs), but we don't want to start with the same node each time. for (int i = 0; i < activeNodeCount; i++) { try { // get next server in the list (trying to distribute this call among active nodes) // note: the extra modulus logic is there just in case requestCounter wraps around to a negative value server = activeNodeList.get((requestCounter++ % activeNodeCount + activeNodeCount) % activeNodeCount); HttpGet request = new HttpGet(protocol + "://" + server + path); logger.debug("endpoint query attempt #" + (i + 1) + ": trying " + server); // format date String rfcDate; synchronized (rfc822DateFormat) { rfcDate = rfc822DateFormat.format(new Date()); } // generate signature String canonicalString = "GET\n\n\n" + rfcDate + "\n" + path; String signature = getSignature(canonicalString, secret); // add date and auth headers request.addHeader("Date", rfcDate); request.addHeader("Authorization", "AWS " + user + ":" + signature); // send request HttpResponse response = httpClient.execute(request); if (response.getStatusLine().getStatusCode() > 299) { EntityUtils.consumeQuietly(response.getEntity()); throw new RuntimeException("received error response: " + response.getStatusLine()); } logger.debug("received success response: " + response.getStatusLine()); hosts = parseResponse(response); break; } catch (Exception e) { logger.warn("error polling for endpoints on " + server, e); } } if (hosts == null) throw new RuntimeException("Exhausted all nodes; no response available"); List<Server> updatedNodeList = new ArrayList<Server>(); for (String host : hosts) { updatedNodeList.add(new Server(host, port)); } setNodeList(updatedNodeList); } catch (Exception e) { logger.warn("Unable to poll for servers", e); } return getNodeList(); } protected String getSignature(String canonicalString, String secret) throws Exception { Mac mac = Mac.getInstance("HmacSHA1"); mac.init(new SecretKeySpec(secret.getBytes("UTF-8"), "HmacSHA1")); String signature = new String(Base64.encodeBase64(mac.doFinal(canonicalString.getBytes("UTF-8")))); logger.debug("canonicalString:\n" + canonicalString); logger.debug("signature:\n" + signature); return signature; } @SuppressWarnings("unchecked") protected List<String> parseResponse(HttpResponse response) throws IOException, JAXBException { InputStream contentStream = response.getEntity().getContent(); try { ListDataNode listDataNode = (ListDataNode) unmarshaller.unmarshal(contentStream); List<String> hosts = new ArrayList<String>(); for (String host : listDataNode.getDataNodes()) { hosts.add(host.trim()); } return hosts; } finally { try { contentStream.close(); } catch (RuntimeException e) { logger.warn("error closing HTTP content stream", e); } } } protected List<Server> getNodeList() { return nodeList; } protected synchronized void setNodeList(List<Server> nodeList) { this.nodeList = Collections.unmodifiableList(nodeList); } }