package io.airlift.airship.coordinator.auth;
import com.google.common.base.Charsets;
import com.google.common.base.Joiner;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList;
import io.airlift.units.Duration;
import org.apache.commons.codec.binary.Base64;
import javax.inject.Inject;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.concurrent.TimeUnit;
import static java.lang.Math.abs;
import static java.lang.String.format;
import static java.lang.System.currentTimeMillis;
import static javax.ws.rs.core.Response.Status.BAD_REQUEST;
import static javax.ws.rs.core.Response.Status.FORBIDDEN;
import static org.apache.commons.codec.digest.DigestUtils.md5Hex;
/**
 * Servlet filter that authenticates requests carrying an {@code Airship} Authorization header.
 * <p>
 * When authentication is disabled in the supplied {@link AuthConfig}, every request is passed
 * through untouched. Otherwise the request must carry a {@code Date} header no more than five
 * minutes away from the server clock and at least one {@code Authorization} header whose
 * signature is accepted by the {@link SignatureVerifier}. On success the verified
 * {@link AuthorizedKey} is stored in the request attribute {@link #AUTHORIZED_KEY_ATTRIBUTE}
 * and the wrapped request is passed down the filter chain. Failures are answered with a
 * plain-text 400 (malformed input) or 403 (no header verified) response.
 */
public class AuthFilter
        implements Filter
{
    /** Request attribute under which the verified {@link AuthorizedKey} is stored. */
    public static final String AUTHORIZED_KEY_ATTRIBUTE = "AuthorizedKey";

    /** Maximum allowed difference between the request {@code Date} header and the server clock. */
    private static final Duration MAX_REQUEST_TIME_SKEW = new Duration(5, TimeUnit.MINUTES);

    private final SignatureVerifier verifier;
    private final boolean enabled;

    /**
     * @param config supplies whether authentication is enabled (read once, at construction)
     * @param verifier checks a signature against the key identified by a fingerprint
     */
    @Inject
    public AuthFilter(AuthConfig config, SignatureVerifier verifier)
    {
        this.verifier = verifier;
        this.enabled = config.isEnabled();
    }

    @Override
    public void init(FilterConfig filterConfig)
            throws ServletException
    {
    }

    /**
     * Verify authorization header:
     * <pre>
     * Authorization: Airship fingerprint:signature
     * fingerprint = hex md5 of private key
     * signature = base64 signature of [ts, method, uri, bodyMd5]
     * </pre>
     * The signed payload is the newline-joined string of the request time in Unix seconds
     * (from the {@code Date} header), the HTTP method, the request URI including any query
     * string, and the hex MD5 of the request body, encoded as UTF-8.
     * <p>
     * NOTE(review): the fingerprint is described above as derived from the private key;
     * key fingerprints are normally computed from the public key. Confirm against
     * {@code Fingerprint}.
     * <p>
     * Each Authorization header is tried in order. A malformed header aborts with 400
     * immediately. A well-formed header that fails verification moves on to the next header.
     * If none verifies, the response is 403.
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain chain)
            throws IOException, ServletException
    {
        if (!enabled) {
            chain.doFilter(servletRequest, servletResponse);
            return;
        }

        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;

        // get authorization headers
        List<String> authorizations = getHeaderValues(request, "Authorization");
        if (authorizations.isEmpty()) {
            sendError(response, BAD_REQUEST, "Missing Authorization header");
            return;
        }

        //
        // Generate message
        //

        // get unix timestamp from request time
        long millis;
        try {
            millis = request.getDateHeader("Date");
        }
        catch (IllegalArgumentException e) {
            // getDateHeader throws when the header is present but cannot be parsed as a date
            sendError(response, BAD_REQUEST, "Invalid Date header");
            return;
        }
        if (millis == -1) {
            // getDateHeader returns -1 when the header is absent
            sendError(response, BAD_REQUEST, "Missing Date header");
            return;
        }

        // reject requests whose Date is too far from the server clock, bounding the replay window
        long serverTime = currentTimeMillis();
        if (abs(serverTime - millis) > MAX_REQUEST_TIME_SKEW.toMillis()) {
            sendError(response, BAD_REQUEST, format("Request time too skewed (server time: %s)", serverTime / 1000));
            return;
        }
        long timestamp = millis / 1000;

        // get method and uri with query parameters
        String method = request.getMethod();
        String uri = getRequestUri(request);

        // wrap request to allow reading body; the same wrapper is handed down the chain,
        // presumably so downstream handlers can still read the body consumed here
        RequestWrapper requestWrapper = new RequestWrapper(request);
        String bodyMd5 = md5Hex(requestWrapper.getRequestBody());

        // compute signature payload
        String stringToSign = Joiner.on('\n').join(timestamp, method, uri, bodyMd5);
        byte[] bytesToSign = stringToSign.getBytes(Charsets.UTF_8);

        //
        // try each authorization header
        //
        for (String authorization : authorizations) {
            // parse authorization header: "Airship <fingerprint>:<signature>"
            List<String> authTokens = ImmutableList.copyOf(Splitter.on(' ').omitEmptyStrings().split(authorization));
            if ((authTokens.size() != 2) || (!authTokens.get(0).equals("Airship"))) {
                sendError(response, BAD_REQUEST, "Invalid Authorization header");
                return;
            }

            List<String> authParts = ImmutableList.copyOf(Splitter.on(':').split(authTokens.get(1)));
            if (authParts.size() != 2) {
                sendError(response, BAD_REQUEST, "Invalid Authorization token");
                return;
            }

            // parse authorization token
            String hexFingerprint = authParts.get(0);
            String base64Signature = authParts.get(1);

            Fingerprint fingerprint;
            try {
                fingerprint = Fingerprint.valueOf(hexFingerprint);
            }
            catch (IllegalArgumentException e) {
                sendError(response, BAD_REQUEST, "Invalid Authorization fingerprint");
                return;
            }

            byte[] signature;
            try {
                signature = Base64.decodeBase64(base64Signature);
            }
            catch (Exception e) {
                // NOTE(review): commons-codec's decoder may be lenient (skipping invalid
                // characters) depending on version, so this branch may rarely trigger; a bad
                // encoding then just fails verification below.
                sendError(response, BAD_REQUEST, "Invalid Authorization signature encoding");
                return;
            }

            // verify signature; a null result means this header did not verify, so try the next one
            AuthorizedKey authorizedKey = verifier.verify(fingerprint, signature, bytesToSign);
            if (authorizedKey == null) {
                continue;
            }

            request.setAttribute(AUTHORIZED_KEY_ATTRIBUTE, authorizedKey);
            chain.doFilter(requestWrapper, response);
            return;
        }

        sendError(response, FORBIDDEN, "Signature verification failed");
    }

    @Override
    public void destroy()
    {
    }

    /**
     * Returns every value of the named request header, in container order.
     * <p>
     * {@link HttpServletRequest#getHeaders(String)} may return {@code null} when the container
     * does not allow access to header information. This method treats that case the same as
     * "header absent" and returns an empty list instead of throwing.
     */
    private static List<String> getHeaderValues(HttpServletRequest request, String name)
    {
        Enumeration<String> values = request.getHeaders(name);
        if (values == null) {
            return ImmutableList.of();
        }
        return Collections.list(values);
    }

    /**
     * Discards any buffered response state and writes {@code error} as a plain-text body
     * with the given status.
     */
    private static void sendError(HttpServletResponse response, Response.Status status, String error)
            throws IOException
    {
        response.reset();
        response.setStatus(status.getStatusCode());
        response.setContentType(MediaType.TEXT_PLAIN);
        PrintWriter writer = response.getWriter();
        writer.println(error);
        writer.close();
    }

    /**
     * Returns the request URI followed by {@code ?query} when a query string is present.
     * This is the exact form included in the signed payload.
     */
    private static String getRequestUri(HttpServletRequest request)
    {
        String uri = request.getRequestURI();
        String query = request.getQueryString();
        return (query == null) ? uri : (uri + "?" + query);
    }
}