package org.limewire.rest.oauth;
import java.io.UnsupportedEncodingException;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.ConcurrentHashMap;

import javax.crypto.Mac;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;

import org.apache.commons.codec.binary.Base64;
import org.limewire.rest.RestUtils;
import org.limewire.util.StringUtils;

import com.google.inject.Inject;
import com.google.inject.assistedinject.Assisted;
/**
 * Implementation of OAuthValidator used to validate requests using the OAuth
 * protocol. At present, only the HMAC-SHA1 signature method is supported.
 *
 * <p>Replay protection keeps the latest timestamp per consumer key and the
 * recently used nonces in memory. That state is only updated after a
 * request's signature has been verified, so unsigned requests cannot affect
 * which later requests are accepted.
 */
public class OAuthValidatorImpl implements OAuthValidator {
    private static final String VERSION = "1.0";
    private static final String SIG_METHOD = "HMAC-SHA1";
    private static final String MAC_NAME = "HmacSHA1";
    /** Maximum nonce age is 2 minutes. */
    private static final long NONCE_AGE = 2 * 60 * 1000L;

    private final String baseUrl;
    private final String consumerSecret;
    private final String tokenSecret;
    /** Latest accepted request timestamp (seconds) for each consumer key. */
    private final Map<String, Long> timestamps;
    private final NonceTracker nonceTracker;

    /**
     * Constructs an OAuthValidator with the specified base URL, port number,
     * and consumer secret. By default, the token secret is an empty string
     * for use with two-legged OAuth.
     */
    @Inject
    public OAuthValidatorImpl(
            @Assisted("baseUrl") String baseUrl,
            @Assisted int port,
            @Assisted("secret") String secret) {
        this.baseUrl = createBaseUrl(baseUrl, port);
        this.consumerSecret = secret;
        this.tokenSecret = "";
        this.timestamps = new ConcurrentHashMap<String, Long>();
        this.nonceTracker = new NonceTracker();
    }

    /**
     * Creates the base URL from the specified URL and port number. OAuth
     * (Core 1.0a, section 9.1.2) requires the default port of the scheme,
     * 80 for http and 443 for https, to be excluded from the signature base
     * string. Any other port is appended to the host.
     */
    private String createBaseUrl(String baseUrl, int port) {
        // Split protocol and domain in URL string.
        int pos = baseUrl.indexOf("//");
        String protocol = (pos < 0) ? "" : baseUrl.substring(0, pos + 2);
        String domain = (pos < 0) ? baseUrl : baseUrl.substring(pos + 2);

        // Split uri from domain.
        int uriPos = domain.indexOf('/');
        String uri = (uriPos < 0) ? "" : domain.substring(uriPos);
        domain = (uriPos < 0) ? domain : domain.substring(0, uriPos);

        // Remove old port number. Only a colon after any closing bracket is a
        // port separator, so IPv6 literals such as "[::1]" stay intact.
        int portPos = domain.lastIndexOf(':');
        if (portPos > domain.lastIndexOf(']')) {
            domain = domain.substring(0, portPos);
        }

        // Add port number to domain unless it is the scheme's default.
        if (!isDefaultPort(protocol, port)) {
            domain = domain + ':' + port;
        }

        // Recreate url string.
        return protocol + domain + uri;
    }

    /**
     * Returns true if the port is the default for the scheme in the
     * specified protocol prefix (for example "http://"). Returns false
     * otherwise.
     */
    private static boolean isDefaultPort(String protocol, int port) {
        if (protocol.regionMatches(true, 0, "https:", 0, 6)) {
            return port == 443;
        }
        if (protocol.regionMatches(true, 0, "http:", 0, 5)) {
            return port == 80;
        }
        // Unknown or missing scheme: keep the historical behavior of treating
        // both 80 and 443 as default ports.
        return (port == 80) || (port == 443);
    }

    @Override
    public void validateRequest(OAuthRequest request) throws OAuthException {
        long currentMsec = System.currentTimeMillis();
        // Stateless checks first.
        validateParameters(request);
        validateVersion(request);
        validateSignatureMethod(request);
        // Authenticate before touching replay-protection state. Otherwise an
        // unsigned request with a far-future timestamp could advance the
        // stored timestamp for a consumer key and lock out genuine clients.
        validateSignature(request);
        // Stateful replay checks; these record the timestamp and the nonce.
        validateTimestamp(request);
        validateNonce(request, currentMsec);
    }

    /**
     * Validates the required OAuth parameters in the specified request.
     */
    private void validateParameters(OAuthRequest request) throws OAuthException {
        if (request.getParameter(OAuthRequest.OAUTH_CONSUMER_KEY) == null) {
            throw new OAuthException("Missing " + OAuthRequest.OAUTH_CONSUMER_KEY);
        }
        if (request.getParameter(OAuthRequest.OAUTH_SIGNATURE_METHOD) == null) {
            throw new OAuthException("Missing " + OAuthRequest.OAUTH_SIGNATURE_METHOD);
        }
        if (request.getParameter(OAuthRequest.OAUTH_SIGNATURE) == null) {
            throw new OAuthException("Missing " + OAuthRequest.OAUTH_SIGNATURE);
        }
        if (request.getParameter(OAuthRequest.OAUTH_TIMESTAMP) == null) {
            throw new OAuthException("Missing " + OAuthRequest.OAUTH_TIMESTAMP);
        }
        if (request.getParameter(OAuthRequest.OAUTH_NONCE) == null) {
            throw new OAuthException("Missing " + OAuthRequest.OAUTH_NONCE);
        }
    }

    /**
     * Validates the OAuth version in the specified request. The version is
     * an optional parameter.
     */
    private void validateVersion(OAuthRequest request) throws OAuthException {
        String version = request.getParameter(OAuthRequest.OAUTH_VERSION);
        if ((version != null) && !VERSION.equalsIgnoreCase(version)) {
            throw new OAuthException("Invalid OAuth version");
        }
    }

    /**
     * Validates the timestamp in the specified request. According to OAuth
     * Core 1.0 Revision A, Section 8, the timestamp is in seconds, and must
     * be equal to or greater than the timestamp in previous requests. On
     * success the timestamp is recorded for the request's consumer key.
     */
    private void validateTimestamp(OAuthRequest request) throws OAuthException {
        // Get timestamp in seconds.
        long timestamp = request.getParameter(OAuthRequest.OAUTH_TIMESTAMP, 0);
        if (timestamp <= 0) {
            throw new OAuthException("Invalid OAuth timestamp");
        }

        // Get previous timestamp.
        String consumerKey = request.getParameter(OAuthRequest.OAUTH_CONSUMER_KEY);
        Long prevTime = timestamps.get(consumerKey);

        // Timestamp cannot be earlier than last request.
        if ((prevTime != null) && (prevTime.longValue() > timestamp)) {
            throw new OAuthException("OAuth timestamp earlier than previous");
        }
        timestamps.put(consumerKey, timestamp);
    }

    /**
     * Validates the nonce in the specified request. According to OAuth
     * Core 1.0 Revision A, Section 8, the nonce must be unique for all
     * requests with the same timestamp. Nonces older than NONCE_AGE
     * (relative to currentMsec) are then purged from the tracker.
     */
    private void validateNonce(OAuthRequest request, long currentMsec) throws OAuthException {
        // Get request parameters.
        long timestamp = request.getParameter(OAuthRequest.OAUTH_TIMESTAMP, 0);
        String consumerKey = request.getParameter(OAuthRequest.OAUTH_CONSUMER_KEY);
        String nonceStr = request.getParameter(OAuthRequest.OAUTH_NONCE);

        // Nonce must be unique.
        Nonce nonce = new Nonce(timestamp, consumerKey, nonceStr);
        boolean valid = nonceTracker.add(nonce);
        if (!valid) {
            throw new OAuthException("OAuth nonce already used");
        }

        // Remove old nonces.
        nonceTracker.removeOldNonces(currentMsec);
    }

    /**
     * Validates the OAuth signature method in the specified request. Only
     * HMAC-SHA1 is supported.
     */
    private void validateSignatureMethod(OAuthRequest request) throws OAuthException {
        String sigMethod = request.getParameter(OAuthRequest.OAUTH_SIGNATURE_METHOD);
        if (!SIG_METHOD.equalsIgnoreCase(sigMethod)) {
            throw new OAuthException("Unsupported OAuth signature method");
        }
    }

    /**
     * Validates the OAuth signature in the specified request by recomputing
     * it over the signature base string and comparing it with the
     * Base64-decoded request signature.
     */
    private void validateSignature(OAuthRequest request) throws OAuthException {
        // Retrieve request signature.
        String oauthSignature = request.getParameter(OAuthRequest.OAUTH_SIGNATURE);

        try {
            // Base64 text is ASCII, so use an explicit charset rather than the
            // platform default.
            byte[] oauthBytes = Base64.decodeBase64(StringUtils.toUTF8Bytes(oauthSignature));

            // Create base string and compute signature.
            String baseString = OAuthUtils.createSignatureBaseString(request, baseUrl);
            byte[] signatureBytes = computeSignature(baseString);

            // Compare signatures in constant time so response timing does not
            // reveal how many leading bytes of a forged signature were right.
            if (!MessageDigest.isEqual(oauthBytes, signatureBytes)) {
                throw new OAuthException("Invalid OAuth signature");
            }
        } catch (GeneralSecurityException ex) {
            throw new OAuthException(ex);
        } catch (UnsupportedEncodingException ex) {
            throw new OAuthException(ex);
        }
    }

    /**
     * Computes the signature for the specified base string. The HMAC-SHA1
     * signature method is used. The key is the percent-encoded consumer
     * secret and token secret, joined by '&'.
     */
    private byte[] computeSignature(String baseString)
            throws GeneralSecurityException, UnsupportedEncodingException {
        // Create key.
        String keyString = RestUtils.percentEncode(consumerSecret) + '&' + RestUtils.percentEncode(tokenSecret);
        byte[] keyBytes = StringUtils.toUTF8Bytes(keyString);
        SecretKey key = new SecretKeySpec(keyBytes, MAC_NAME);

        // Compute signature using HmacSHA1.
        Mac mac = Mac.getInstance(MAC_NAME);
        mac.init(key);
        byte[] text = StringUtils.toUTF8Bytes(baseString);
        return mac.doFinal(text);
    }

    /**
     * Representation of a nonce. Each timestamp/consumer key must use a
     * unique nonce string. Equality ignores the creation time.
     */
    private static class Nonce {
        private final long creationTime;
        private final long timestamp;
        private final String consumerKey;
        private final String nonce;

        public Nonce(long timestamp, String consumerKey, String nonce) {
            this.creationTime = System.currentTimeMillis();
            this.timestamp = timestamp;
            this.consumerKey = consumerKey;
            this.nonce = nonce;
        }

        public long getCreationTime() {
            return creationTime;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj instanceof Nonce) {
                Nonce n2 = (Nonce) obj;
                return (timestamp == n2.timestamp) &&
                    consumerKey.equals(n2.consumerKey) &&
                    nonce.equals(n2.nonce);
            }
            return false;
        }

        @Override
        public int hashCode() {
            int result = 17;
            result = 31 * result + (int) (timestamp ^ (timestamp >>> 32));
            result = 31 * result + consumerKey.hashCode();
            result = 31 * result + nonce.hashCode();
            return result;
        }
    }

    /**
     * Tracker to maintain Nonce values. Each Nonce must be unique. We also
     * order nonces by their creation time so we can easily remove old values.
     * Both sets are guarded by synchronizing on {@code nonces}.
     */
    private static class NonceTracker {
        /**
         * Orders nonces by creation time. Ties are broken on the fields used
         * by Nonce.equals, so distinct nonces created in the same millisecond
         * are not treated as duplicates by the TreeSet. Without the
         * tie-breaker they would be dropped from orderedNonces and never
         * purged from nonces.
         */
        private static final Comparator<Nonce> CREATION_ORDER = new Comparator<Nonce>() {
            @Override
            public int compare(Nonce o1, Nonce o2) {
                int result = compareLongs(o1.getCreationTime(), o2.getCreationTime());
                if (result == 0) {
                    result = compareLongs(o1.timestamp, o2.timestamp);
                }
                if (result == 0) {
                    result = o1.consumerKey.compareTo(o2.consumerKey);
                }
                if (result == 0) {
                    result = o1.nonce.compareTo(o2.nonce);
                }
                return result;
            }
        };

        private final Set<Nonce> nonces;
        private final Set<Nonce> orderedNonces;

        public NonceTracker() {
            nonces = new HashSet<Nonce>();
            orderedNonces = new TreeSet<Nonce>(CREATION_ORDER);
        }

        /** Returns -1, 0 or 1 as a is less than, equal to or greater than b. */
        private static int compareLongs(long a, long b) {
            return (a < b) ? -1 : ((a > b) ? 1 : 0);
        }

        /**
         * Adds the specified nonce. Returns false if an equal nonce is
         * already tracked, true otherwise.
         */
        public boolean add(Nonce nonce) {
            synchronized(nonces) {
                boolean valid = nonces.add(nonce);
                if (valid) {
                    orderedNonces.add(nonce);
                }
                return valid;
            }
        }

        /**
         * Removes nonces whose creation time is at or before
         * currentMsec - NONCE_AGE.
         */
        public void removeOldNonces(long currentMsec) {
            // Calculate oldest creation time.
            long minMsec = currentMsec - NONCE_AGE;

            synchronized(nonces) {
                // Remove old nonces from cache. Nonces are stored in order of
                // creation time so we can easily remove items that are too old.
                for (Iterator<Nonce> iter = orderedNonces.iterator(); iter.hasNext();) {
                    Nonce nonce = iter.next();
                    if (minMsec < nonce.getCreationTime()) {
                        break;
                    }
                    nonces.remove(nonce);
                    iter.remove();
                }
            }
        }
    }
}