/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.nifi.toolkit.tls.util; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.io.ByteArrayOutputStream; import java.io.File; import java.io.FileReader; import java.io.IOException; import java.io.InputStreamReader; import java.io.Reader; import java.lang.reflect.Constructor; import java.lang.reflect.Field; import java.security.GeneralSecurityException; import java.security.InvalidKeyException; import java.security.KeyPair; import java.security.KeyPairGenerator; import java.security.KeyStore; import java.security.KeyStoreSpi; import java.security.NoSuchAlgorithmException; import java.security.NoSuchProviderException; import java.security.Provider; import java.security.SignatureException; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.ArrayList; import java.util.Arrays; import java.util.Date; import java.util.List; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import org.apache.commons.lang3.StringUtils; import org.apache.nifi.security.util.CertificateUtils; import org.apache.nifi.toolkit.tls.configuration.TlsConfig; import org.bouncycastle.asn1.pkcs.Attribute; import org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers; import org.bouncycastle.asn1.x509.Extension; import org.bouncycastle.asn1.x509.Extensions; import org.bouncycastle.asn1.x509.GeneralName; import org.bouncycastle.asn1.x509.GeneralNames; import org.bouncycastle.cert.X509CertificateHolder; import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter; import org.bouncycastle.jce.provider.BouncyCastleProvider; import org.bouncycastle.openssl.PEMKeyPair; import org.bouncycastle.openssl.PEMParser; import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter; import org.bouncycastle.operator.OperatorCreationException; import org.bouncycastle.pkcs.jcajce.JcaPKCS10CertificationRequest; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.AdditionalMatchers; import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @RunWith(MockitoJUnitRunner.class) public class TlsHelperTest { public static final Logger logger = LoggerFactory.getLogger(TlsHelperTest.class); private static final boolean originalUnlimitedCrypto = TlsHelper.isUnlimitedStrengthCryptographyEnabled(); private int days; private int keySize; private String keyPairAlgorithm; private String signingAlgorithm; private KeyPairGenerator keyPairGenerator; private KeyStore keyStore; @Mock KeyStoreSpi keyStoreSpi; @Mock Provider keyStoreProvider; @Mock OutputStreamFactory outputStreamFactory; private ByteArrayOutputStream tmpFileOutputStream; private File file; private static void setUnlimitedCrypto(boolean value) { try { Field isUnlimitedStrengthCryptographyEnabled = TlsHelper.class.getDeclaredField("isUnlimitedStrengthCryptographyEnabled"); isUnlimitedStrengthCryptographyEnabled.setAccessible(true); isUnlimitedStrengthCryptographyEnabled.set(null, value); } catch (Exception e) { throw new RuntimeException(e); } } public static KeyPair loadKeyPair(Reader reader) throws IOException { try (PEMParser pemParser = new PEMParser(reader)) { Object object = pemParser.readObject(); assertEquals(PEMKeyPair.class, object.getClass()); return new JcaPEMKeyConverter().getKeyPair((PEMKeyPair) object); } } public static KeyPair loadKeyPair(File file) throws IOException { return loadKeyPair(new FileReader(file)); } public static X509Certificate loadCertificate(Reader reader) throws IOException, CertificateException { try (PEMParser pemParser = new PEMParser(reader)) { Object object = pemParser.readObject(); assertEquals(X509CertificateHolder.class, object.getClass()); return new JcaX509CertificateConverter().setProvider(BouncyCastleProvider.PROVIDER_NAME).getCertificate((X509CertificateHolder) object); } } public static X509Certificate loadCertificate(File file) throws IOException, CertificateException { return loadCertificate(new FileReader(file)); } @Before public void setup() throws Exception { days = 360; keySize = 2048; keyPairAlgorithm = "RSA"; signingAlgorithm = "SHA1WITHRSA"; keyPairGenerator = KeyPairGenerator.getInstance(keyPairAlgorithm); keyPairGenerator.initialize(keySize); Constructor<KeyStore> keyStoreConstructor = KeyStore.class.getDeclaredConstructor(KeyStoreSpi.class, Provider.class, String.class); keyStoreConstructor.setAccessible(true); keyStore = keyStoreConstructor.newInstance(keyStoreSpi, keyStoreProvider, "faketype"); keyStore.load(null, null); file = File.createTempFile("keystore", "file"); when(outputStreamFactory.create(file)).thenReturn(tmpFileOutputStream); } @After public void tearDown() { setUnlimitedCrypto(originalUnlimitedCrypto); file.delete(); } private Date inFuture(int days) { return new Date(System.currentTimeMillis() + TimeUnit.DAYS.toMillis(days)); } @Test public void testGenerateSelfSignedCert() throws GeneralSecurityException, IOException, OperatorCreationException { String dn = "CN=testDN,O=testOrg"; X509Certificate x509Certificate = CertificateUtils.generateSelfSignedX509Certificate(TlsHelper.generateKeyPair(keyPairAlgorithm, keySize), dn, signingAlgorithm, days); Date notAfter = x509Certificate.getNotAfter(); assertTrue(notAfter.after(inFuture(days - 1))); assertTrue(notAfter.before(inFuture(days + 1))); Date notBefore = x509Certificate.getNotBefore(); assertTrue(notBefore.after(inFuture(-1))); assertTrue(notBefore.before(inFuture(1))); assertEquals(dn, x509Certificate.getIssuerX500Principal().getName()); assertEquals(signingAlgorithm, x509Certificate.getSigAlgName()); assertEquals(keyPairAlgorithm, x509Certificate.getPublicKey().getAlgorithm()); x509Certificate.checkValidity(); } @Test public void testIssueCert() throws IOException, CertificateException, NoSuchAlgorithmException, OperatorCreationException, NoSuchProviderException, InvalidKeyException, SignatureException { X509Certificate issuer = loadCertificate(new InputStreamReader(getClass().getClassLoader().getResourceAsStream("rootCert.crt"))); KeyPair issuerKeyPair = loadKeyPair(new InputStreamReader(getClass().getClassLoader().getResourceAsStream("rootCert.key"))); String dn = "CN=testIssued, O=testOrg"; KeyPair keyPair = TlsHelper.generateKeyPair(keyPairAlgorithm, keySize); X509Certificate x509Certificate = CertificateUtils.generateIssuedCertificate(dn, keyPair.getPublic(), issuer, issuerKeyPair, signingAlgorithm, days); assertEquals(dn, x509Certificate.getSubjectX500Principal().toString()); assertEquals(issuer.getSubjectX500Principal().toString(), x509Certificate.getIssuerX500Principal().toString()); assertEquals(keyPair.getPublic(), x509Certificate.getPublicKey()); Date notAfter = x509Certificate.getNotAfter(); assertTrue(notAfter.after(inFuture(days - 1))); assertTrue(notAfter.before(inFuture(days + 1))); Date notBefore = x509Certificate.getNotBefore(); assertTrue(notBefore.after(inFuture(-1))); assertTrue(notBefore.before(inFuture(1))); assertEquals(signingAlgorithm, x509Certificate.getSigAlgName()); assertEquals(keyPairAlgorithm, x509Certificate.getPublicKey().getAlgorithm()); x509Certificate.verify(issuerKeyPair.getPublic()); } @Test public void testWriteKeyStoreSuccess() throws IOException, GeneralSecurityException { setUnlimitedCrypto(false); String testPassword = "testPassword"; assertEquals(testPassword, TlsHelper.writeKeyStore(keyStore, outputStreamFactory, file, testPassword, false)); verify(keyStoreSpi, times(1)).engineStore(eq(tmpFileOutputStream), AdditionalMatchers.aryEq(testPassword.toCharArray())); } @Test public void testWriteKeyStoreFailure() throws IOException, GeneralSecurityException { setUnlimitedCrypto(false); String testPassword = "testPassword"; IOException ioException = new IOException("Fail"); doThrow(ioException).when(keyStoreSpi).engineStore(eq(tmpFileOutputStream), AdditionalMatchers.aryEq(testPassword.toCharArray())); try { TlsHelper.writeKeyStore(keyStore, outputStreamFactory, file, testPassword, true); fail("Expected " + ioException); } catch (IOException e) { assertEquals(ioException, e); } } @Test public void testWriteKeyStoreTruncate() throws IOException, GeneralSecurityException { setUnlimitedCrypto(false); String testPassword = "testPassword"; String truncatedPassword = testPassword.substring(0, 7); IOException ioException = new IOException(TlsHelper.ILLEGAL_KEY_SIZE); doThrow(ioException).when(keyStoreSpi).engineStore(eq(tmpFileOutputStream), AdditionalMatchers.aryEq(testPassword.toCharArray())); assertEquals(truncatedPassword, TlsHelper.writeKeyStore(keyStore, outputStreamFactory, file, testPassword, true)); verify(keyStoreSpi, times(1)).engineStore(eq(tmpFileOutputStream), AdditionalMatchers.aryEq(testPassword.toCharArray())); verify(keyStoreSpi, times(1)).engineStore(eq(tmpFileOutputStream), AdditionalMatchers.aryEq(truncatedPassword.toCharArray())); } @Test public void testWriteKeyStoreUnlimitedWontTruncate() throws GeneralSecurityException, IOException { setUnlimitedCrypto(true); String testPassword = "testPassword"; IOException ioException = new IOException(TlsHelper.ILLEGAL_KEY_SIZE); doThrow(ioException).when(keyStoreSpi).engineStore(eq(tmpFileOutputStream), AdditionalMatchers.aryEq(testPassword.toCharArray())); try { TlsHelper.writeKeyStore(keyStore, outputStreamFactory, file, testPassword, true); fail("Expected " + ioException); } catch (IOException e) { assertEquals(ioException, e); } } @Test public void testWriteKeyStoreNoTruncate() throws IOException, GeneralSecurityException { setUnlimitedCrypto(false); String testPassword = "testPassword"; IOException ioException = new IOException(TlsHelper.ILLEGAL_KEY_SIZE); doThrow(ioException).when(keyStoreSpi).engineStore(eq(tmpFileOutputStream), AdditionalMatchers.aryEq(testPassword.toCharArray())); try { TlsHelper.writeKeyStore(keyStore, outputStreamFactory, file, testPassword, false); fail("Expected " + GeneralSecurityException.class); } catch (GeneralSecurityException e) { assertTrue("Expected exception to contain " + TlsHelper.JCE_URL, e.getMessage().contains(TlsHelper.JCE_URL)); } } @Test public void testWriteKeyStoreTruncateFailure() throws IOException, GeneralSecurityException { setUnlimitedCrypto(false); String testPassword = "testPassword"; String truncatedPassword = testPassword.substring(0, 7); IOException ioException = new IOException(TlsHelper.ILLEGAL_KEY_SIZE); IOException ioException2 = new IOException(TlsHelper.ILLEGAL_KEY_SIZE); doThrow(ioException).when(keyStoreSpi).engineStore(eq(tmpFileOutputStream), AdditionalMatchers.aryEq(testPassword.toCharArray())); doThrow(ioException2).when(keyStoreSpi).engineStore(eq(tmpFileOutputStream), AdditionalMatchers.aryEq(truncatedPassword.toCharArray())); try { TlsHelper.writeKeyStore(keyStore, outputStreamFactory, file, testPassword, true); fail("Expected " + ioException2); } catch (IOException e) { assertEquals(ioException2, e); } } @Test public void testShouldIncludeSANFromCSR() throws Exception { // Arrange final List<String> SAN_ENTRIES = Arrays.asList("127.0.0.1", "nifi.nifi.apache.org"); final String SAN = StringUtils.join(SAN_ENTRIES, ","); final int SAN_COUNT = SAN_ENTRIES.size(); final String DN = "CN=localhost"; KeyPair keyPair = keyPairGenerator.generateKeyPair(); logger.info("Generating CSR with DN: " + DN); // Act JcaPKCS10CertificationRequest csrWithSan = TlsHelper.generateCertificationRequest(DN, SAN, keyPair, TlsConfig.DEFAULT_SIGNING_ALGORITHM); logger.info("Created CSR with SAN: " + SAN); String testCsrPem = TlsHelper.pemEncodeJcaObject(csrWithSan); logger.info("Encoded CSR as PEM: " + testCsrPem); // Assert String subjectName = csrWithSan.getSubject().toString(); logger.info("CSR Subject Name: " + subjectName); assert subjectName.equals(DN); List<String> extractedSans = extractSanFromCsr(csrWithSan); assert extractedSans.size() == SAN_COUNT; List<String> formattedSans = SAN_ENTRIES.stream().map(s -> "DNS: " + s).collect(Collectors.toList()); assert extractedSans.containsAll(formattedSans); } private List<String> extractSanFromCsr(JcaPKCS10CertificationRequest csr) { List<String> sans = new ArrayList<>(); Attribute[] certAttributes = csr.getAttributes(); for (Attribute attribute : certAttributes) { if (attribute.getAttrType().equals(PKCSObjectIdentifiers.pkcs_9_at_extensionRequest)) { Extensions extensions = Extensions.getInstance(attribute.getAttrValues().getObjectAt(0)); GeneralNames gns = GeneralNames.fromExtensions(extensions, Extension.subjectAlternativeName); GeneralName[] names = gns.getNames(); for (GeneralName name : names) { logger.info("Type: " + name.getTagNo() + " | Name: " + name.getName()); String title = ""; if (name.getTagNo() == GeneralName.dNSName) { title = "DNS"; } else if (name.getTagNo() == GeneralName.iPAddress) { title = "IP Address"; // name.toASN1Primitive(); } else if (name.getTagNo() == GeneralName.otherName) { title = "Other Name"; } sans.add(title + ": " + name.getName()); } } } return sans; } }