/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.kafka.common.security.scram; import org.junit.Before; import org.junit.Test; import java.nio.charset.StandardCharsets; import javax.security.sasl.SaslException; import javax.xml.bind.DatatypeConverter; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.fail; import org.apache.kafka.common.security.scram.ScramMessages.AbstractScramMessage; import org.apache.kafka.common.security.scram.ScramMessages.ClientFinalMessage; import org.apache.kafka.common.security.scram.ScramMessages.ClientFirstMessage; import org.apache.kafka.common.security.scram.ScramMessages.ServerFinalMessage; import org.apache.kafka.common.security.scram.ScramMessages.ServerFirstMessage; public class ScramMessagesTest { private static final String[] VALID_EXTENSIONS = { "ext=val1", "anotherext=name1=value1 name2=another test value \"\'!$[]()", "first=val1,second=name1 = value ,third=123" }; private static final String[] INVALID_EXTENSIONS = { "ext1=value", "ext", "ext=value1,value2", "ext=,", "ext =value" }; private static final String[] VALID_RESERVED = { "m=reserved-value", "m=name1=value1 name2=another test value \"\'!$[]()" }; private static final String[] INVALID_RESERVED = { "m", "m=name,value", "m=," }; private ScramFormatter formatter; @Before public void setUp() throws Exception { formatter = new ScramFormatter(ScramMechanism.SCRAM_SHA_256); } @Test public void validClientFirstMessage() throws SaslException { String nonce = formatter.secureRandomString(); ClientFirstMessage m = new ClientFirstMessage("someuser", nonce); checkClientFirstMessage(m, "someuser", nonce, ""); // Default format used by Kafka client: only user and nonce are specified String str = String.format("n,,n=testuser,r=%s", nonce); m = createScramMessage(ClientFirstMessage.class, str); checkClientFirstMessage(m, "testuser", nonce, ""); m = new ClientFirstMessage(m.toBytes()); checkClientFirstMessage(m, "testuser", nonce, ""); // Username containing comma, encoded as =2C str = String.format("n,,n=test=2Cuser,r=%s", nonce); m = createScramMessage(ClientFirstMessage.class, str); checkClientFirstMessage(m, "test=2Cuser", nonce, ""); assertEquals("test,user", formatter.username(m.saslName())); // Username containing equals, encoded as =3D str = String.format("n,,n=test=3Duser,r=%s", nonce); m = createScramMessage(ClientFirstMessage.class, str); checkClientFirstMessage(m, "test=3Duser", nonce, ""); assertEquals("test=user", formatter.username(m.saslName())); // Optional authorization id specified str = String.format("n,a=testauthzid,n=testuser,r=%s", nonce); checkClientFirstMessage(createScramMessage(ClientFirstMessage.class, str), "testuser", nonce, "testauthzid"); // Optional reserved value specified for (String reserved : VALID_RESERVED) { str = String.format("n,,%s,n=testuser,r=%s", reserved, nonce); checkClientFirstMessage(createScramMessage(ClientFirstMessage.class, str), "testuser", nonce, ""); } // Optional extension specified for (String extension : VALID_EXTENSIONS) { str = String.format("n,,n=testuser,r=%s,%s", nonce, extension); checkClientFirstMessage(createScramMessage(ClientFirstMessage.class, str), "testuser", nonce, ""); } } @Test public void invalidClientFirstMessage() throws SaslException { String nonce = formatter.secureRandomString(); // Invalid entry in gs2-header String invalid = String.format("n,x=something,n=testuser,r=%s", nonce); checkInvalidScramMessage(ClientFirstMessage.class, invalid); // Invalid reserved entry for (String reserved : INVALID_RESERVED) { invalid = String.format("n,,%s,n=testuser,r=%s", reserved, nonce); checkInvalidScramMessage(ClientFirstMessage.class, invalid); } // Invalid extension for (String extension : INVALID_EXTENSIONS) { invalid = String.format("n,,n=testuser,r=%s,%s", nonce, extension); checkInvalidScramMessage(ClientFirstMessage.class, invalid); } } @Test public void validServerFirstMessage() throws SaslException { String clientNonce = formatter.secureRandomString(); String serverNonce = formatter.secureRandomString(); String nonce = clientNonce + serverNonce; String salt = randomBytesAsString(); ServerFirstMessage m = new ServerFirstMessage(clientNonce, serverNonce, toBytes(salt), 8192); checkServerFirstMessage(m, nonce, salt, 8192); // Default format used by Kafka clients, only nonce, salt and iterations are specified String str = String.format("r=%s,s=%s,i=4096", nonce, salt); m = createScramMessage(ServerFirstMessage.class, str); checkServerFirstMessage(m, nonce, salt, 4096); m = new ServerFirstMessage(m.toBytes()); checkServerFirstMessage(m, nonce, salt, 4096); // Optional reserved value for (String reserved : VALID_RESERVED) { str = String.format("%s,r=%s,s=%s,i=4096", reserved, nonce, salt); checkServerFirstMessage(createScramMessage(ServerFirstMessage.class, str), nonce, salt, 4096); } // Optional extension for (String extension : VALID_EXTENSIONS) { str = String.format("r=%s,s=%s,i=4096,%s", nonce, salt, extension); checkServerFirstMessage(createScramMessage(ServerFirstMessage.class, str), nonce, salt, 4096); } } @Test public void invalidServerFirstMessage() throws SaslException { String nonce = formatter.secureRandomString(); String salt = randomBytesAsString(); // Invalid iterations String invalid = String.format("r=%s,s=%s,i=0", nonce, salt); checkInvalidScramMessage(ServerFirstMessage.class, invalid); // Invalid salt invalid = String.format("r=%s,s=%s,i=4096", nonce, "=123"); checkInvalidScramMessage(ServerFirstMessage.class, invalid); // Invalid format invalid = String.format("r=%s,invalid,s=%s,i=4096", nonce, salt); checkInvalidScramMessage(ServerFirstMessage.class, invalid); // Invalid reserved entry for (String reserved : INVALID_RESERVED) { invalid = String.format("%s,r=%s,s=%s,i=4096", reserved, nonce, salt); checkInvalidScramMessage(ServerFirstMessage.class, invalid); } // Invalid extension for (String extension : INVALID_EXTENSIONS) { invalid = String.format("r=%s,s=%s,i=4096,%s", nonce, salt, extension); checkInvalidScramMessage(ServerFirstMessage.class, invalid); } } @Test public void validClientFinalMessage() throws SaslException { String nonce = formatter.secureRandomString(); String channelBinding = randomBytesAsString(); String proof = randomBytesAsString(); ClientFinalMessage m = new ClientFinalMessage(toBytes(channelBinding), nonce); assertNull("Invalid proof", m.proof()); m.proof(toBytes(proof)); checkClientFinalMessage(m, channelBinding, nonce, proof); // Default format used by Kafka client: channel-binding, nonce and proof are specified String str = String.format("c=%s,r=%s,p=%s", channelBinding, nonce, proof); m = createScramMessage(ClientFinalMessage.class, str); checkClientFinalMessage(m, channelBinding, nonce, proof); m = new ClientFinalMessage(m.toBytes()); checkClientFinalMessage(m, channelBinding, nonce, proof); // Optional extension specified for (String extension : VALID_EXTENSIONS) { str = String.format("c=%s,r=%s,%s,p=%s", channelBinding, nonce, extension, proof); checkClientFinalMessage(createScramMessage(ClientFinalMessage.class, str), channelBinding, nonce, proof); } } @Test public void invalidClientFinalMessage() throws SaslException { String nonce = formatter.secureRandomString(); String channelBinding = randomBytesAsString(); String proof = randomBytesAsString(); // Invalid channel binding String invalid = String.format("c=ab,r=%s,p=%s", nonce, proof); checkInvalidScramMessage(ClientFirstMessage.class, invalid); // Invalid proof invalid = String.format("c=%s,r=%s,p=123", channelBinding, nonce); checkInvalidScramMessage(ClientFirstMessage.class, invalid); // Invalid extensions for (String extension : INVALID_EXTENSIONS) { invalid = String.format("c=%s,r=%s,%s,p=%s", channelBinding, nonce, extension, proof); checkInvalidScramMessage(ClientFinalMessage.class, invalid); } } @Test public void validServerFinalMessage() throws SaslException { String serverSignature = randomBytesAsString(); ServerFinalMessage m = new ServerFinalMessage("unknown-user", null); checkServerFinalMessage(m, "unknown-user", null); m = new ServerFinalMessage(null, toBytes(serverSignature)); checkServerFinalMessage(m, null, serverSignature); // Default format used by Kafka clients for successful final message String str = String.format("v=%s", serverSignature); m = createScramMessage(ServerFinalMessage.class, str); checkServerFinalMessage(m, null, serverSignature); m = new ServerFinalMessage(m.toBytes()); checkServerFinalMessage(m, null, serverSignature); // Default format used by Kafka clients for final message with error str = "e=other-error"; m = createScramMessage(ServerFinalMessage.class, str); checkServerFinalMessage(m, "other-error", null); m = new ServerFinalMessage(m.toBytes()); checkServerFinalMessage(m, "other-error", null); // Optional extension for (String extension : VALID_EXTENSIONS) { str = String.format("v=%s,%s", serverSignature, extension); checkServerFinalMessage(createScramMessage(ServerFinalMessage.class, str), null, serverSignature); } } @Test public void invalidServerFinalMessage() throws SaslException { String serverSignature = randomBytesAsString(); // Invalid error String invalid = "e=error1,error2"; checkInvalidScramMessage(ServerFinalMessage.class, invalid); // Invalid server signature invalid = String.format("v=1=23"); checkInvalidScramMessage(ServerFinalMessage.class, invalid); // Invalid extensions for (String extension : INVALID_EXTENSIONS) { invalid = String.format("v=%s,%s", serverSignature, extension); checkInvalidScramMessage(ServerFinalMessage.class, invalid); invalid = String.format("e=unknown-user,%s", extension); checkInvalidScramMessage(ServerFinalMessage.class, invalid); } } private String randomBytesAsString() { return DatatypeConverter.printBase64Binary(formatter.secureRandomBytes()); } private byte[] toBytes(String base64Str) { return DatatypeConverter.parseBase64Binary(base64Str); }; private void checkClientFirstMessage(ClientFirstMessage message, String saslName, String nonce, String authzid) { assertEquals(saslName, message.saslName()); assertEquals(nonce, message.nonce()); assertEquals(authzid, message.authorizationId()); } private void checkServerFirstMessage(ServerFirstMessage message, String nonce, String salt, int iterations) { assertEquals(nonce, message.nonce()); assertArrayEquals(DatatypeConverter.parseBase64Binary(salt), message.salt()); assertEquals(iterations, message.iterations()); } private void checkClientFinalMessage(ClientFinalMessage message, String channelBinding, String nonce, String proof) { assertArrayEquals(DatatypeConverter.parseBase64Binary(channelBinding), message.channelBinding()); assertEquals(nonce, message.nonce()); assertArrayEquals(DatatypeConverter.parseBase64Binary(proof), message.proof()); } private void checkServerFinalMessage(ServerFinalMessage message, String error, String serverSignature) { assertEquals(error, message.error()); if (serverSignature == null) assertNull("Unexpected server signature", message.serverSignature()); else assertArrayEquals(DatatypeConverter.parseBase64Binary(serverSignature), message.serverSignature()); } @SuppressWarnings("unchecked") private <T extends AbstractScramMessage> T createScramMessage(Class<T> clazz, String message) throws SaslException { byte[] bytes = message.getBytes(StandardCharsets.UTF_8); if (clazz == ClientFirstMessage.class) return (T) new ClientFirstMessage(bytes); else if (clazz == ServerFirstMessage.class) return (T) new ServerFirstMessage(bytes); else if (clazz == ClientFinalMessage.class) return (T) new ClientFinalMessage(bytes); else if (clazz == ServerFinalMessage.class) return (T) new ServerFinalMessage(bytes); else throw new IllegalArgumentException("Unknown message type: " + clazz); } private <T extends AbstractScramMessage> void checkInvalidScramMessage(Class<T> clazz, String message) { try { createScramMessage(clazz, message); fail("Exception not throws for invalid message of type " + clazz + " : " + message); } catch (SaslException e) { // Expected exception } } }