/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package gobblin.crypto; import java.io.FilterOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.nio.charset.StandardCharsets; import java.security.InvalidAlgorithmParameterException; import java.security.InvalidKeyException; import java.security.NoSuchAlgorithmException; import java.util.HashMap; import java.util.Map; import java.util.Random; import org.apache.commons.codec.binary.Base64; import org.apache.commons.io.IOUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.crypto.Cipher; import javax.crypto.CipherInputStream; import javax.crypto.CipherOutputStream; import javax.crypto.NoSuchPaddingException; import javax.crypto.SecretKey; import javax.crypto.spec.IvParameterSpec; import javax.crypto.spec.SecretKeySpec; import javax.xml.bind.DatatypeConverter; import gobblin.codec.Base64Codec; import gobblin.codec.StreamCodec; /** * Implementation of an encryption algorithm that works in the following way: * * 1. A credentialStore is provisioned with a set of AES keys * 2. When encodeOutputStream() is called, an AES key will be picked at random and a new initialization vector (IV) * will be generated. * 3. 
A header will be written: [keyId][ivLength][base64 encoded iv], where keyId is rendered with
 *    {@code %04d} (exactly four ASCII characters) and ivLength with {@code %03d}.
 * 4. Ciphertext will be base64 encoded and written out. We do not insert linebreaks.
 *
 * Keys are ignored if their id is non-numeric, if their material is not 16 bytes long, or if their id does not
 * fit in the four-character header field (i.e. lies outside -999..9999). A header written with such an id
 * could not be parsed back by the decoder.
 */
public class RotatingAESCodec implements StreamCodec {
  private static final Logger log = LoggerFactory.getLogger(RotatingAESCodec.class);

  private static final int AES_KEY_LEN = 16;
  private static final String TAG = "aes_rotating";

  // Key ids are written with "%04d"; only ids in this range render as exactly four characters,
  // which is the fixed width DecodingStreamInstance.readKey() consumes.
  private static final int MIN_KEY_ID = -999;
  private static final int MAX_KEY_ID = 9999;

  private final Random random;
  private final CredentialStore credentialStore;

  /*
   * Cache valid keys in two forms:
   *   A map for retrieving a key quickly (decode case)
   *   An array for quickly selecting a random key (encode case)
   * Both are populated together, lazily, by fillKeyRecords() and are only read from synchronized methods.
   */
  private volatile Map<Integer, KeyRecord> keyRecords_cache;
  private volatile KeyRecord[] keyRecords_cache_arr;

  /**
   * Create a new encryptor
   * @param credentialStore Credential store where keys can be found
   */
  public RotatingAESCodec(CredentialStore credentialStore) {
    this.credentialStore = credentialStore;
    this.random = new Random();
  }

  /**
   * Wrap {@code origStream} so that bytes written to the returned stream are AES/CBC encrypted with a randomly
   * selected key, base64 encoded, and preceded by the [keyId][ivLength][base64 iv] header.
   *
   * @throws IllegalStateException if the credential store contains no usable keys
   */
  @Override
  public OutputStream encodeOutputStream(OutputStream origStream) throws IOException {
    return new EncodingStreamInstance(selectRandomKey(), origStream).wrapOutputStream();
  }

  /**
   * Parse the header from {@code origStream} and return a stream that yields the decrypted plaintext.
   *
   * @throws IOException if the header cannot be parsed or names a key that is not in the credential store
   */
  @Override
  public InputStream decodeInputStream(InputStream origStream) throws IOException {
    return new DecodingStreamInstance(origStream).wrapInputStream();
  }

  /** Look up a cached key by id; returns null if no usable key has that id. */
  private synchronized KeyRecord getKey(Integer key) {
    fillKeyRecords();
    return keyRecords_cache.get(key);
  }

  /** Pick one of the usable keys uniformly at random (key choice only; IVs are generated by the Cipher). */
  private synchronized KeyRecord selectRandomKey() {
    KeyRecord[] keyRecords = getKeyRecords();
    if (keyRecords.length == 0) {
      throw new IllegalStateException("Couldn't find any valid keys in store!");
    }

    return keyRecords[random.nextInt(keyRecords.length)];
  }

  private synchronized KeyRecord[] getKeyRecords() {
    fillKeyRecords();
    return keyRecords_cache_arr;
  }

  /**
   * Load and validate keys from the credential store on first use.
   *
   * Records are built into a local map and published only once fully constructed. If the credential store
   * throws, the cache stays unset and the next call retries, rather than caching a half-built state.
   */
  private synchronized void fillKeyRecords() {
    if (keyRecords_cache != null) {
      return;
    }

    Map<Integer, KeyRecord> records = new HashMap<>();
    for (Map.Entry<String, byte[]> entry : credentialStore.getAllEncodedKeys().entrySet()) {
      byte[] encodedKey = entry.getValue();
      if (encodedKey == null) {
        log.debug("Skipping keyId {} because it has no key material", entry.getKey());
        continue;
      }
      if (encodedKey.length != AES_KEY_LEN) {
        log.debug("Skipping keyId {} because it is length {}; expected {}", entry.getKey(), encodedKey.length,
            AES_KEY_LEN);
        continue;
      }

      int keyId;
      try {
        keyId = Integer.parseInt(entry.getKey());
      } catch (NumberFormatException e) {
        log.debug("Skipping keyId {} because this algorithm can only use numeric key ids", entry.getKey());
        continue;
      }

      if (keyId < MIN_KEY_ID || keyId > MAX_KEY_ID) {
        log.debug("Skipping keyId {} because it cannot be written in the 4-character header field", keyId);
        continue;
      }

      SecretKey key = new SecretKeySpec(encodedKey, "AES");
      records.put(keyId, new KeyRecord(keyId, key));
    }

    // Publish the array first so that a non-null map always implies a non-null array.
    keyRecords_cache_arr = records.values().toArray(new KeyRecord[records.size()]);
    keyRecords_cache = records;
  }

  @Override
  public String getTag() {
    return TAG;
  }

  /**
   * Represents a set of parsed AES keys that we can choose from when encrypting.
   */
  static class KeyRecord {
    private final int keyId;
    private final SecretKey secretKey;

    KeyRecord(int keyId, SecretKey secretKey) {
      this.keyId = keyId;
      this.secretKey = secretKey;
    }

    int getKeyId() {
      return keyId;
    }

    SecretKey getSecretKey() {
      return secretKey;
    }
  }

  /**
   * Helper class that keeps state around for a wrapped output stream. Each stream will have a different
   * selected key, IV, and cipher state.
   *
   * The plaintext header is written directly to the original stream, lazily, on the first write or on close.
   * It is written on close even when nothing was written, because the cipher still emits a padding block then.
   */
  static class EncodingStreamInstance {
    private final OutputStream origStream;
    private final KeyRecord secretKey;

    private Cipher cipher;
    private String base64Iv;
    private boolean headerWritten = false;

    EncodingStreamInstance(KeyRecord secretKey, OutputStream origStream) {
      this.secretKey = secretKey;
      this.origStream = origStream;
    }

    OutputStream wrapOutputStream() throws IOException {
      initCipher();
      final OutputStream base64OutputStream = getBase64Stream(origStream);
      final CipherOutputStream encryptedStream = new CipherOutputStream(base64OutputStream, cipher);

      return new FilterOutputStream(origStream) {
        @Override
        public void write(int b) throws IOException {
          writeHeaderIfNecessary();
          encryptedStream.write(b);
        }

        @Override
        public void write(byte[] b) throws IOException {
          writeHeaderIfNecessary();
          encryptedStream.write(b);
        }

        @Override
        public void write(byte[] b, int off, int len) throws IOException {
          writeHeaderIfNecessary();
          encryptedStream.write(b, off, len);
        }

        @Override
        public void close() throws IOException {
          // Closing the cipher stream runs doFinal(), which emits at least one padding block even for an
          // empty plaintext; the header must precede it or the output cannot be decoded.
          writeHeaderIfNecessary();
          encryptedStream.close();
        }
      };
    }

    private OutputStream getBase64Stream(OutputStream origStream) throws IOException {
      return new Base64Codec().encodeOutputStream(origStream);
    }

    private void initCipher() {
      if (origStream == null) {
        throw new IllegalStateException("Can't initCipher stream before encodeOutputStream() has been called!");
      }

      try {
        cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
        // No IV is supplied, so the provider generates a fresh one; read it back for the header.
        cipher.init(Cipher.ENCRYPT_MODE, secretKey.getSecretKey());
        byte[] iv = cipher.getIV();
        base64Iv = DatatypeConverter.printBase64Binary(iv);
        this.headerWritten = false;
      } catch (NoSuchAlgorithmException | NoSuchPaddingException e) {
        throw new IllegalStateException("Error creating AES algorithm? Should always exist in JRE", e);
      } catch (InvalidKeyException e) {
        throw new IllegalStateException("Key " + secretKey.getKeyId() + " is illegal - please check credential store",
            e);
      }
    }

    /** Write "[keyId %04d][ivLength %03d][base64 iv]" to the original stream exactly once. */
    private void writeHeaderIfNecessary() throws IOException {
      if (!headerWritten) {
        String header = String.format("%04d%03d%s", secretKey.getKeyId(), base64Iv.length(), base64Iv);
        origStream.write(header.getBytes(StandardCharsets.UTF_8));
        this.headerWritten = true;
      }
    }
  }

  /**
   * Reads the header from an encoded stream, resolves the key and IV, and wraps the remainder in a
   * base64-decoding, AES-decrypting stream. Non-static because it looks keys up through the enclosing codec.
   */
  private class DecodingStreamInstance {
    private final InputStream origStream;
    private final byte[] buffer = new byte[32];
    private final Cipher cipher;

    DecodingStreamInstance(InputStream origStream) throws IOException {
      this.origStream = origStream;

      Integer keyId = readKey();
      KeyRecord key = getKey(keyId);
      if (key == null) {
        throw new IOException("Cannot load key " + String.valueOf(keyId) + " which is specified in input stream");
      }

      try {
        byte[] iv = readIv();
        cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
        if (iv != null) {
          IvParameterSpec ivParameterSpec = new IvParameterSpec(iv);
          cipher.init(Cipher.DECRYPT_MODE, key.getSecretKey(), ivParameterSpec);
        } else {
          cipher.init(Cipher.DECRYPT_MODE, key.getSecretKey());
        }
      } catch (NoSuchAlgorithmException | NoSuchPaddingException e) {
        throw new IllegalStateException("Failed to load AES which should never happen", e);
      } catch (InvalidKeyException e) {
        throw new IllegalStateException("Failed to parse key from keystore", e);
      } catch (InvalidAlgorithmParameterException e) {
        throw new IllegalStateException("Failed to initialize IV", e);
      }
    }

    InputStream wrapInputStream() throws IOException {
      InputStream base64Decoder = new Base64Codec().decodeInputStream(origStream);
      return new CipherInputStream(base64Decoder, cipher);
    }

    // read and parse key from the bytestream: exactly 4 ASCII characters
    private Integer readKey() throws IOException {
      IOUtils.readFully(origStream, buffer, 0, 4);
      try {
        return Integer.valueOf(new String(buffer, 0, 4, StandardCharsets.UTF_8));
      } catch (NumberFormatException e) {
        throw new IOException("Expected to be able to parse first 4 bytes of stream as an ASCII keyId", e);
      }
    }

    // read a 3-digit IV length followed by that many bytes of base64-encoded IV; null when the length is 0
    private byte[] readIv() throws IOException {
      IOUtils.readFully(origStream, buffer, 0, 3);
      Integer ivLen;
      try {
        ivLen = Integer.valueOf(new String(buffer, 0, 3, StandardCharsets.UTF_8));
      } catch (NumberFormatException e) {
        throw new IOException("Expected to parse next 3 bytes of stream as an IV len", e);
      }

      if (ivLen < 0 || ivLen > buffer.length) {
        throw new IOException(
            "Corrupted data suspected; expected IVLen to be between 0 and " + String.valueOf(buffer.length) + ", read "
                + String.valueOf(ivLen));
      }

      if (ivLen == 0) {
        return null;
      }

      // IV is separately base64 encoded -- none of the standard base64 codec instances support decoding a slice of a
      // byte[] array so create a new buffer here
      byte[] ivBuffer = new byte[ivLen];
      IOUtils.readFully(origStream, ivBuffer, 0, ivBuffer.length);
      return Base64.decodeBase64(ivBuffer);
    }
  }
}