package org.xdi.oxauth.model.authorize;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import java.security.SecureRandom;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.codec.binary.BaseNCodec;
import org.apache.commons.codec.digest.DigestUtils;
import org.apache.commons.lang.RandomStringUtils;
/**
 * Helper for PKCE (RFC 7636, "Proof Key for Code Exchange").
 * <p>
 * An instance holds a freshly generated {@code code_verifier} and the {@code code_challenge}
 * derived from it with a chosen {@link CodeChallengeMethod}. The static methods generate
 * verifiers, derive challenges and check a verifier against a previously issued challenge.
 *
 * @author Yuriy Zabrovarnyy
 * @version 0.9, 21/03/2016
 */
public class CodeVerifier {

    /** Maximum code verifier length (RFC 7636 section 4.1). */
    private static final int MAX_CODE_VERIFIER_LENGTH = 128;
    /** Minimum code verifier length (RFC 7636 section 4.1). */
    private static final int MIN_CODE_VERIFIER_LENGTH = 43;
    /** Empty line separator so the Base64 encoder never inserts line breaks. */
    private static final byte[] EMPTY_BYTE_ARRAY = new byte[0];

    /**
     * The RFC 7636 "unreserved" characters a code verifier may contain.
     * The upper-case letters are spelled out rather than derived with {@code toUpperCase()},
     * which is locale-sensitive (e.g. a Turkish locale maps 'i' to U+0130).
     */
    private static final String CODE_VERIFIER_CHARS = "abcdefghijklmnopqrstuvwxyz"
            + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "1234567890" + "-._~";

    /**
     * Source of randomness for verifier generation. The code verifier is a secret
     * (RFC 7636 section 7.1), so a cryptographically strong generator is required instead of
     * the {@code java.util.Random} that {@code RandomStringUtils} uses by default.
     */
    private static final SecureRandom SECURE_RANDOM = new SecureRandom();

    /** Code challenge transformations defined by RFC 7636 section 4.2. */
    public enum CodeChallengeMethod {
        /** {@code code_challenge = code_verifier}. */
        PLAIN("plain", ""),
        /** {@code code_challenge = BASE64URL(SHA256(code_verifier))}. */
        S256("S256", "SHA-256");

        private final String pkceString;
        private final String messageDigestString;

        CodeChallengeMethod(String pkceString, String messageDigestString) {
            this.pkceString = pkceString;
            this.messageDigestString = messageDigestString;
        }

        /**
         * @return the message digest algorithm name associated with this method
         *         ({@code "SHA-256"} for S256, empty string for PLAIN)
         */
        public String getMessageDigestString() {
            return messageDigestString;
        }

        /**
         * @return the {@code code_challenge_method} parameter value for this method
         *         ({@code "plain"} or {@code "S256"})
         */
        public String getPkceString() {
            return pkceString;
        }

        /**
         * Looks up a method by its {@code code_challenge_method} value, ignoring case.
         *
         * @param value parameter value, may be {@code null}
         * @return the matching method, or {@code null} if {@code value} is null or unknown
         */
        public static CodeChallengeMethod fromString(String value) {
            for (CodeChallengeMethod type : values()) {
                if (type.getPkceString().equalsIgnoreCase(value)) {
                    return type;
                }
            }
            return null;
        }
    }

    private final String codeVerifier;
    private final String codeChallenge;
    private final CodeChallengeMethod transformationType;

    /** Creates a new random verifier and its challenge using {@link CodeChallengeMethod#S256}. */
    public CodeVerifier() {
        this(CodeChallengeMethod.S256);
    }

    /**
     * Creates a new random verifier and derives its challenge with the given method.
     *
     * @param transformationType challenge method to apply; must not be {@code null}
     * @throws NullPointerException if {@code transformationType} is {@code null}
     */
    public CodeVerifier(CodeChallengeMethod transformationType) {
        this.codeVerifier = generateCodeVerifier();
        this.transformationType = transformationType;
        this.codeChallenge = generateCodeChallenge(transformationType, codeVerifier);
    }

    /**
     * Derives the code challenge for a verifier.
     *
     * @param codeChallengeMethod transformation to apply; must not be {@code null}
     * @param codeVerifier        verifier to transform; must not be {@code null}
     * @return the verifier itself for PLAIN, the base64url SHA-256 digest for S256
     * @throws NullPointerException if either argument is {@code null}
     */
    public static String generateCodeChallenge(CodeChallengeMethod codeChallengeMethod, String codeVerifier) {
        Preconditions.checkNotNull(codeChallengeMethod);
        Preconditions.checkNotNull(codeVerifier);
        switch (codeChallengeMethod) {
            case PLAIN:
                return codeVerifier;
            case S256:
                return s256(codeVerifier);
            default:
                // Only reachable if a constant is added to the enum without handling it here.
                throw new IllegalArgumentException("Unsupported code challenge method: " + codeChallengeMethod);
        }
    }

    /**
     * Checks a verifier against a challenge, resolving the method by its parameter value.
     *
     * @return {@code false} if any argument is null/empty, the method is unknown,
     *         or the derived challenge differs
     */
    public static boolean matched(String codeChallenge, String codeChallengeMethod, String codeVerifier) {
        return matched(codeChallenge, CodeChallengeMethod.fromString(codeChallengeMethod), codeVerifier);
    }

    /**
     * Checks whether {@code codeVerifier}, transformed with {@code codeChallengeMethod},
     * equals {@code codeChallenge}.
     *
     * @return {@code false} if any argument is null/empty or the derived challenge differs
     */
    public static boolean matched(String codeChallenge, CodeChallengeMethod codeChallengeMethod, String codeVerifier) {
        if (Strings.isNullOrEmpty(codeChallenge) || codeChallengeMethod == null || Strings.isNullOrEmpty(codeVerifier)) {
            return false;
        }
        return generateCodeChallenge(codeChallengeMethod, codeVerifier).equals(codeChallenge);
    }

    /**
     * Computes the S256 challenge: SHA-256 of the verifier, base64url-encoded.
     *
     * @param codeVerifier verifier to hash
     * @return unpadded base64url digest
     */
    public static String s256(String codeVerifier) {
        byte[] sha256 = DigestUtils.sha256(codeVerifier);
        return base64UrlEncode(sha256);
    }

    /**
     * Encodes bytes with the URL-safe Base64 alphabet. Because the line separator is empty,
     * the output contains no line breaks.
     *
     * @param input bytes to encode
     * @return URL-safe Base64 text
     */
    public static String base64UrlEncode(byte[] input) {
        Base64 base64 = new Base64(BaseNCodec.MIME_CHUNK_SIZE, EMPTY_BYTE_ARRAY, true);
        return base64.encodeAsString(input);
    }

    /**
     * Generates a new code verifier of maximum length (128) from the unreserved character set,
     * using a cryptographically strong random source.
     *
     * @return a new verifier that satisfies {@link #isCodeVerifierValid(String)}
     */
    public static String generateCodeVerifier() {
        // Same call path as RandomStringUtils.random(count, chars), but with a SecureRandom.
        String code = RandomStringUtils.random(MAX_CODE_VERIFIER_LENGTH, 0, CODE_VERIFIER_CHARS.length(),
                false, false, CODE_VERIFIER_CHARS.toCharArray(), SECURE_RANDOM);
        Preconditions.checkState(isCodeVerifierValid(code));
        return code;
    }

    /**
     * Validates a verifier against RFC 7636 section 4.1: 43..128 characters, each one of
     * {@code [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~"}.
     *
     * @param codeVerifier verifier to check, may be {@code null}
     * @return {@code true} if the verifier is well-formed
     */
    public static boolean isCodeVerifierValid(String codeVerifier) {
        if (codeVerifier == null) {
            return false;
        }
        int length = codeVerifier.length();
        if (length > MAX_CODE_VERIFIER_LENGTH || length < MIN_CODE_VERIFIER_LENGTH) {
            return false;
        }
        for (int i = 0; i < length; i++) {
            if (CODE_VERIFIER_CHARS.indexOf(codeVerifier.charAt(i)) < 0) {
                return false;
            }
        }
        return true;
    }

    /** @return the derived code challenge */
    public String getCodeChallenge() {
        return codeChallenge;
    }

    /** @return the generated (secret) code verifier */
    public String getCodeVerifier() {
        return codeVerifier;
    }

    /** @return the method used to derive the challenge */
    public CodeChallengeMethod getTransformationType() {
        return transformationType;
    }

    /**
     * Debug representation. Note: it includes the secret code verifier, so avoid
     * writing it to production logs.
     */
    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder();
        sb.append("CodeVerifier");
        sb.append("{codeVerifier='").append(codeVerifier).append('\'');
        sb.append(", codeChallenge='").append(codeChallenge).append('\'');
        sb.append(", transformationType=").append(transformationType);
        sb.append('}');
        return sb.toString();
    }
}