/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.nifi.toolkit.tls.service; import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.nifi.security.util.KeystoreType; import org.apache.nifi.security.util.KeyStoreUtils; import org.apache.nifi.toolkit.tls.configuration.TlsClientConfig; import org.apache.nifi.toolkit.tls.configuration.TlsConfig; import org.apache.nifi.toolkit.tls.service.client.TlsCertificateAuthorityClient; import org.apache.nifi.toolkit.tls.service.client.TlsCertificateAuthorityClientCommandLine; import org.apache.nifi.toolkit.tls.service.server.TlsCertificateAuthorityService; import org.apache.nifi.toolkit.tls.standalone.TlsToolkitStandalone; import org.apache.nifi.toolkit.tls.util.InputStreamFactory; import org.apache.nifi.toolkit.tls.util.OutputStreamFactory; import org.junit.Before; import org.junit.Test; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.File; import java.io.FileNotFoundException; import java.io.IOException; import java.io.OutputStream; import java.net.ServerSocket; import java.nio.charset.StandardCharsets; import java.security.InvalidKeyException; import java.security.KeyStore; import java.security.KeyStoreException; import 
java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.Signature;
import java.security.SignatureException;
import java.security.UnrecoverableEntryException;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.AdditionalMatchers.or;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/**
 * Exercises a {@link TlsCertificateAuthorityService} together with a
 * {@link TlsCertificateAuthorityClient} and then inspects the key stores, trust store and
 * configuration files they produced.
 *
 * <p>Nothing is written to disk: the {@link OutputStreamFactory} handed to both sides is a mock
 * that returns in-memory {@link ByteArrayOutputStream}s, which the validation helpers read back.
 */
public class TlsCertificateAuthorityTest {
    private File serverConfigFile;
    private File clientConfigFile;
    private OutputStreamFactory outputStreamFactory;
    private InputStreamFactory inputStreamFactory;
    private TlsConfig serverConfig;
    private TlsClientConfig clientConfig;
    private ObjectMapper objectMapper;
    private ByteArrayOutputStream serverKeyStoreOutputStream;
    private ByteArrayOutputStream clientKeyStoreOutputStream;
    private ByteArrayOutputStream clientTrustStoreOutputStream;
    private ByteArrayOutputStream serverConfigFileOutputStream;
    private ByteArrayOutputStream clientConfigFileOutputStream;

    /**
     * Builds matching server and client configurations (same host, port and token) and wires the
     * mocked stream factories so that every store/config path maps to an in-memory stream.
     *
     * @throws FileNotFoundException declared by the stubbed factory methods
     */
    @Before
    public void setup() throws FileNotFoundException {
        objectMapper = new ObjectMapper();
        serverConfigFile = new File("fake.server.config");
        clientConfigFile = new File("fake.client.config");
        String serverKeyStore = "serverKeyStore";
        String clientKeyStore = "clientKeyStore";
        String clientTrustStore = "clientTrustStore";
        serverKeyStoreOutputStream = new ByteArrayOutputStream();
        clientKeyStoreOutputStream = new ByteArrayOutputStream();
        clientTrustStoreOutputStream = new ByteArrayOutputStream();
        serverConfigFileOutputStream = new ByteArrayOutputStream();
        clientConfigFileOutputStream = new ByteArrayOutputStream();

        String myTestTokenUseSomethingStronger = "myTestTokenUseSomethingStronger";
        int port = availablePort();

        serverConfig = new TlsConfig();
        serverConfig.setCaHostname("localhost");
        serverConfig.setToken(myTestTokenUseSomethingStronger);
        serverConfig.setKeyStore(serverKeyStore);
        serverConfig.setPort(port);
        serverConfig.setDays(5);
        serverConfig.setKeySize(2048);
        serverConfig.initDefaults();

        clientConfig = new TlsClientConfig();
        clientConfig.setCaHostname("localhost");
        clientConfig.setDn("OU=NIFI,CN=otherHostname");
        clientConfig.setKeyStore(clientKeyStore);
        clientConfig.setTrustStore(clientTrustStore);
        clientConfig.setToken(myTestTokenUseSomethingStronger);
        clientConfig.setPort(port);
        clientConfig.setKeySize(2048);
        clientConfig.initDefaults();

        outputStreamFactory = mock(OutputStreamFactory.class);
        mockReturnOutputStream(outputStreamFactory, new File(serverKeyStore), serverKeyStoreOutputStream);
        mockReturnOutputStream(outputStreamFactory, new File(clientKeyStore), clientKeyStoreOutputStream);
        mockReturnOutputStream(outputStreamFactory, new File(clientTrustStore), clientTrustStoreOutputStream);
        mockReturnOutputStream(outputStreamFactory, serverConfigFile, serverConfigFileOutputStream);
        mockReturnOutputStream(outputStreamFactory, clientConfigFile, clientConfigFileOutputStream);

        inputStreamFactory = mock(InputStreamFactory.class);
        mockReturnProperties(inputStreamFactory, serverConfigFile, serverConfig);
        mockReturnProperties(inputStreamFactory, clientConfigFile, clientConfig);
    }

    /**
     * Stubs {@code inputStreamFactory.create(file)} to return a fresh stream over {@code tlsConfig}
     * as serialized by {@link #objectMapper}. Serialization happens on each invocation, so the
     * stream reflects the state of {@code tlsConfig} at the time it is requested.
     */
    private void mockReturnProperties(InputStreamFactory inputStreamFactory, File file, TlsConfig tlsConfig) throws FileNotFoundException {
        when(inputStreamFactory.create(eq(file))).thenAnswer(invocation -> {
            ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
            objectMapper.writeValue(byteArrayOutputStream, tlsConfig);
            return new ByteArrayInputStream(byteArrayOutputStream.toByteArray());
        });
    }

    /**
     * Stubs {@code outputStreamFactory.create(...)} to return {@code outputStream} when asked for
     * {@code file}, whether the caller passes it as given or in its absolute-path form.
     */
    private void mockReturnOutputStream(OutputStreamFactory outputStreamFactory, File file, OutputStream outputStream) throws FileNotFoundException {
        when(outputStreamFactory.create(or(eq(file), eq(new File(file.getAbsolutePath()))))).thenReturn(outputStream);
    }

    /**
     * Runs server and client with the trailing boolean flag set to {@code true} on both sides
     * (presumably "use different key and keystore passwords", going by the test name) and
     * validates the generated stores.
     */
    @Test
    public void testClientGetCertDifferentPasswordsForKeyAndKeyStore() throws Exception {
        TlsCertificateAuthorityService tlsCertificateAuthorityService = null;
        try {
            tlsCertificateAuthorityService = new TlsCertificateAuthorityService(outputStreamFactory);
            tlsCertificateAuthorityService.start(serverConfig, serverConfigFile.getAbsolutePath(), true);
            TlsCertificateAuthorityClient tlsCertificateAuthorityClient = new TlsCertificateAuthorityClient(outputStreamFactory);
            tlsCertificateAuthorityClient.generateCertificateAndGetItSigned(clientConfig, null, clientConfigFile.getAbsolutePath(), true);
            validate();
        } finally {
            // Always stop the CA service so it does not outlive the test.
            if (tlsCertificateAuthorityService != null) {
                tlsCertificateAuthorityService.shutdown();
            }
        }
    }

    /**
     * Same flow as {@link #testClientGetCertDifferentPasswordsForKeyAndKeyStore()} but with the
     * trailing boolean flag set to {@code false} on both sides.
     */
    @Test
    public void testClientGetCertSamePasswordsForKeyAndKeyStore() throws Exception {
        TlsCertificateAuthorityService tlsCertificateAuthorityService = null;
        try {
            tlsCertificateAuthorityService = new TlsCertificateAuthorityService(outputStreamFactory);
            tlsCertificateAuthorityService.start(serverConfig, serverConfigFile.getAbsolutePath(), false);
            TlsCertificateAuthorityClient tlsCertificateAuthorityClient = new TlsCertificateAuthorityClient(outputStreamFactory);
            tlsCertificateAuthorityClient.generateCertificateAndGetItSigned(clientConfig, null, clientConfigFile.getAbsolutePath(), false);
            validate();
        } finally {
            // Always stop the CA service so it does not outlive the test.
            if (tlsCertificateAuthorityService != null) {
                tlsCertificateAuthorityService.shutdown();
            }
        }
    }

    /**
     * Runs the server/client flow with both key store types set to PKCS12 and validates the
     * generated stores.
     */
    @Test
    public void testClientPkcs12() throws Exception {
        serverConfig.setKeyStoreType(KeystoreType.PKCS12.toString());
        clientConfig.setKeyStoreType(KeystoreType.PKCS12.toString());
        TlsCertificateAuthorityService tlsCertificateAuthorityService = null;
        try {
            tlsCertificateAuthorityService = new TlsCertificateAuthorityService(outputStreamFactory);
            tlsCertificateAuthorityService.start(serverConfig, serverConfigFile.getAbsolutePath(), false);
            TlsCertificateAuthorityClient tlsCertificateAuthorityClient = new TlsCertificateAuthorityClient(outputStreamFactory);
            // NOTE(review): the constructed command line object is never used; kept as-is because
            // its constructor is not visible from this file - confirm whether it can be dropped.
            new TlsCertificateAuthorityClientCommandLine(inputStreamFactory);
            tlsCertificateAuthorityClient.generateCertificateAndGetItSigned(clientConfig, null, clientConfigFile.getAbsolutePath(), true);
            validate();
        } finally {
            // Always stop the CA service so it does not outlive the test.
            if (tlsCertificateAuthorityService != null) {
                tlsCertificateAuthorityService.shutdown();
            }
        }
    }

    /**
     * Gives the server a token that differs from the client's and expects the client request to
     * fail with an {@link IOException} whose message contains {@code "forbidden"}.
     */
    @Test
    public void testTokenMismatch() throws Exception {
        serverConfig.setToken("a different token...");
        try {
            testClientGetCertSamePasswordsForKeyAndKeyStore();
            fail("Expected error with mismatching token");
        } catch (IOException e) {
            assertTrue(e.getMessage().contains("forbidden"));
        }
    }

    /**
     * Validates the server key store first, then validates the client stores against the CA
     * certificate found there.
     */
    private void validate() throws CertificateException, InvalidKeyException, NoSuchAlgorithmException, KeyStoreException, SignatureException,
            NoSuchProviderException, UnrecoverableEntryException, IOException {
        Certificate caCertificate = validateServerKeyStore();
        validateClient(caCertificate);
    }

    /**
     * Reloads the server configuration from the captured config output, opens the captured server
     * key store with it, and checks that the {@link TlsToolkitStandalone#NIFI_KEY} entry is a
     * private key entry with a single self-signed certificate matching its private key.
     *
     * @return the CA certificate read from the server key store
     */
    private Certificate validateServerKeyStore() throws KeyStoreException, CertificateException, NoSuchAlgorithmException, IOException,
            UnrecoverableEntryException, InvalidKeyException, NoSuchProviderException, SignatureException {
        // Re-read the config the service wrote out; it carries the passwords needed below.
        serverConfig = objectMapper.readValue(new ByteArrayInputStream(serverConfigFileOutputStream.toByteArray()), TlsConfig.class);

        KeyStore serverKeyStore = KeyStoreUtils.getKeyStore(serverConfig.getKeyStoreType());
        serverKeyStore.load(new ByteArrayInputStream(serverKeyStoreOutputStream.toByteArray()), serverConfig.getKeyStorePassword().toCharArray());

        // Fall back to the key store password when no dedicated key password was recorded.
        String keyPassword = serverConfig.getKeyPassword();
        KeyStore.Entry serverKeyEntry = serverKeyStore.getEntry(TlsToolkitStandalone.NIFI_KEY,
                new KeyStore.PasswordProtection(keyPassword == null ? serverConfig.getKeyStorePassword().toCharArray() : keyPassword.toCharArray()));

        assertTrue(serverKeyEntry instanceof KeyStore.PrivateKeyEntry);
        KeyStore.PrivateKeyEntry privateKeyEntry = (KeyStore.PrivateKeyEntry) serverKeyEntry;
        Certificate[] certificateChain = privateKeyEntry.getCertificateChain();
        assertEquals(1, certificateChain.length);
        Certificate caCertificate = certificateChain[0];
        // Self-signed: the certificate must verify against its own public key.
        caCertificate.verify(caCertificate.getPublicKey());
        assertPrivateAndPublicKeyMatch(privateKeyEntry.getPrivateKey(), caCertificate.getPublicKey());
        return caCertificate;
    }

    /**
     * Reloads the client configuration from the captured config output and checks that the client
     * key store holds a two-element chain (client certificate signed by {@code caCertificate},
     * followed by {@code caCertificate} itself) and that the client trust store contains
     * {@code caCertificate} under {@link TlsToolkitStandalone#NIFI_CERT}.
     *
     * @param caCertificate the CA certificate the client chain and trust store must reference
     */
    private void validateClient(Certificate caCertificate) throws IOException, KeyStoreException, CertificateException, NoSuchAlgorithmException,
            UnrecoverableEntryException, InvalidKeyException, NoSuchProviderException, SignatureException {
        // Re-read the config the client wrote out; it carries the passwords needed below.
        clientConfig = objectMapper.readValue(new ByteArrayInputStream(clientConfigFileOutputStream.toByteArray()), TlsClientConfig.class);

        KeyStore clientKeyStore = KeyStoreUtils.getKeyStore(clientConfig.getKeyStoreType());
        clientKeyStore.load(new ByteArrayInputStream(clientKeyStoreOutputStream.toByteArray()), clientConfig.getKeyStorePassword().toCharArray());

        // Fall back to the key store password when no dedicated key password was recorded.
        String keyPassword = clientConfig.getKeyPassword();
        KeyStore.Entry clientKeyStoreEntry = clientKeyStore.getEntry(TlsToolkitStandalone.NIFI_KEY,
                new KeyStore.PasswordProtection(keyPassword == null ? clientConfig.getKeyStorePassword().toCharArray() : keyPassword.toCharArray()));

        assertTrue(clientKeyStoreEntry instanceof KeyStore.PrivateKeyEntry);
        KeyStore.PrivateKeyEntry clientPrivateKeyEntry = (KeyStore.PrivateKeyEntry) clientKeyStoreEntry;
        Certificate[] certificateChain = clientPrivateKeyEntry.getCertificateChain();
        assertEquals(2, certificateChain.length);
        assertEquals(caCertificate, certificateChain[1]);
        certificateChain[0].verify(caCertificate.getPublicKey());
        assertPrivateAndPublicKeyMatch(clientPrivateKeyEntry.getPrivateKey(), certificateChain[0].getPublicKey());

        KeyStore clientTrustStore = KeyStoreUtils.getTrustStore(KeystoreType.JKS.toString());
        clientTrustStore.load(new ByteArrayInputStream(clientTrustStoreOutputStream.toByteArray()), clientConfig.getTrustStorePassword().toCharArray());
        assertEquals(caCertificate, clientTrustStore.getCertificate(TlsToolkitStandalone.NIFI_CERT));
    }

    /**
     * Asserts that {@code privateKey} and {@code publicKey} form a key pair by signing a fixed
     * message with the private key and verifying the signature with the public key, using
     * {@link TlsConfig#DEFAULT_SIGNING_ALGORITHM}.
     *
     * @param privateKey key used to produce the signature
     * @param publicKey  key expected to verify that signature
     * @throws AssertionError if the signature does not verify
     */
    public static void assertPrivateAndPublicKeyMatch(PrivateKey privateKey, PublicKey publicKey) throws NoSuchAlgorithmException, InvalidKeyException, SignatureException {
        Signature signature = Signature.getInstance(TlsConfig.DEFAULT_SIGNING_ALGORITHM);
        signature.initSign(privateKey);
        byte[] bytes = "test string".getBytes(StandardCharsets.UTF_8);
        signature.update(bytes);

        Signature verify = Signature.getInstance(TlsConfig.DEFAULT_SIGNING_ALGORITHM);
        verify.initVerify(publicKey);
        verify.update(bytes);
        // Signature.verify reports a mismatch via its return value, not an exception, so the
        // result has to be asserted explicitly for this check to be able to fail.
        assertTrue("Private key and public key do not match", verify.verify(signature.sign()));
    }

    /**
     * Determines an available port for the CA server by binding a temporary socket to port 0 and
     * reading back the port that was assigned. The socket is closed before returning.
     *
     * @return the port number that was assigned to the temporary socket
     * @throws IllegalStateException if no port could be discovered
     */
    private int availablePort() {
        ServerSocket s = null;
        try {
            s = new ServerSocket(0);
            s.setReuseAddress(true);
            return s.getLocalPort();
        } catch (Exception e) {
            throw new IllegalStateException("Failed to discover available port.", e);
        } finally {
            // s is still null when the constructor itself failed; closing it unguarded would
            // raise a NullPointerException that hides the IllegalStateException above.
            if (s != null) {
                try {
                    s.close();
                } catch (IOException ignored) {
                    // Best-effort cleanup: the port number has already been obtained.
                }
            }
        }
    }
}