package jwt4j; import com.google.gson.Gson; import com.google.gson.JsonObject; import jwt4j.exceptions.ExpiredTokenException; import jwt4j.exceptions.InvalidSignatureException; import jwt4j.exceptions.InvalidTokenException; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; import javax.crypto.Mac; import javax.crypto.spec.SecretKeySpec; import java.util.Arrays; import java.util.Base64; import java.util.HashMap; import java.util.Map; import java.util.StringJoiner; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Matchers.any; import static org.mockito.Mockito.doThrow; @RunWith(MockitoJUnitRunner.class) public class JWTDecoderTest { private static final String SUBJECT = "subject"; private static final String SECRET = "secret"; @Rule public ExpectedException expectedException = ExpectedException.none(); private final Gson gson = new Gson(); private final Base64.Encoder encoder = Base64.getEncoder(); @Mock private TokenChecker tokenChecker; private JWTDecoder defaultJwtDecoder; private JWTDecoder noneAlgorithmJwtDecoder; @Before public void setUp() { defaultJwtDecoder = new JWTDecoder( Algorithm.HS256, SECRET.getBytes(), gson, Arrays.asList(tokenChecker)); noneAlgorithmJwtDecoder = new JWTDecoder( Algorithm.none, "".getBytes(), gson, Arrays.asList(tokenChecker)); } @Test public void shouldFailForNullToken() { //expect expectedException.expect(InvalidTokenException.class); expectedException.expectMessage("No token"); //when defaultJwtDecoder.decode(null); } @Test public void shouldFailForEmptyToken() { //expect expectedException.expect(InvalidTokenException.class); expectedException.expectMessage("No token"); //when defaultJwtDecoder.decode(""); } @Test public void shouldFailForMalformedToken() { //expect expectedException.expect(InvalidTokenException.class); expectedException.expectMessage("Invalid token structure"); //when defaultJwtDecoder.decode("x.x"); } @Test public void shouldFailForUnsupportedAlgorithm() { //expect expectedException.expect(IllegalStateException.class); expectedException.expectMessage("not supported"); //given final String header = getTokenPart(new HashMap<String, String>() { { put(JWTConstants.ALGORITHM, Algorithm.HS512.name()); } }); final String payload = getTokenPart(new HashMap<>()); //when defaultJwtDecoder.decode(new StringJoiner(".") .add(header) .add(payload) .add(encoder.encodeToString("x".getBytes())) .toString()); } @Test public void shouldRecognizeInvalidSignature() { //expect expectedException.expect(InvalidSignatureException.class); expectedException.expectMessage("compromised"); //given final String header = getTokenPart(new HashMap<String, String>() { { put(JWTConstants.ALGORITHM, Algorithm.HS256.name()); } }); final String payload = getTokenPart(new HashMap<>()); //when defaultJwtDecoder.decode(new StringJoiner(".") .add(header) .add(payload) .add(encoder.encodeToString("x".getBytes())) .toString()); } @Test public void shouldFailAtChecker() throws Exception { //expect expectedException.expect(ExpiredTokenException.class); //given doThrow(ExpiredTokenException.class).when(tokenChecker).check(any()); final String header = getTokenPart(new HashMap<String, String>() { { put(JWTConstants.ALGORITHM, Algorithm.HS256.name()); } }); final String payload = getTokenPart(new HashMap<>()); //when defaultJwtDecoder.decode(new StringJoiner(".") .add(header) .add(payload) .add(sign(header, payload)) .toString()); } @Test public void shouldReturnRegisteredClaims() throws Exception { //given final String header = getTokenPart(new HashMap<String, String>() { { put(JWTConstants.ALGORITHM, Algorithm.HS256.name()); } }); final String payload = getTokenPart(new HashMap<String, String>() { { put(JWTConstants.SUBJECT, SUBJECT); put("rubbish", "Lorem ipsum"); } }); //when //when final Map<String, String> result = defaultJwtDecoder.decode(new StringJoiner(".") .add(header) .add(payload) .add(sign(header, payload)) .toString()); //then assertThat(result).isNotNull().hasSize(1); assertThat(result.get(JWTConstants.SUBJECT)).isEqualTo(SUBJECT); } @Test public void shouldNotVerifySignatureForNoneAlgorithm() { //given final String header = getTokenPart(new HashMap<String, String>() { { put(JWTConstants.ALGORITHM, Algorithm.none.name()); } }); final String payload = getTokenPart(new HashMap<String, String>() { { put(JWTConstants.SUBJECT, SUBJECT); put("rubbish", "Lorem ipsum"); } }); //when //when final Map<String, String> result = noneAlgorithmJwtDecoder.decode(new StringJoiner(".") .add(header) .add(payload) .add("") .toString()); //then assertThat(result).isNotNull().hasSize(1); assertThat(result.get(JWTConstants.SUBJECT)).isEqualTo(SUBJECT); } private String sign(String header, String payload) throws Exception { final Mac mac = Mac.getInstance(Algorithm.HS256.name); mac.init(new SecretKeySpec(SECRET.getBytes(), Algorithm.HS256.name)); return new String(encoder.encodeToString(mac.doFinal( new StringJoiner(".").add(header).add(payload).toString().getBytes()))); } private String getTokenPart(Map<String, String> parameters) { final JsonObject jsonObject = new JsonObject(); parameters.forEach((key, value) -> jsonObject.addProperty(key, value)); return encoder.encodeToString(gson.toJson(jsonObject).getBytes()); } }