/*
* Password Management Servlets (PWM)
* http://www.pwm-project.org
*
* Copyright (c) 2006-2009 Novell, Inc.
* Copyright (c) 2009-2017 The PWM Project
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
*/
package password.pwm.util.secure;
import org.apache.commons.io.IOUtils;
import password.pwm.PwmConstants;
import password.pwm.error.ErrorInformation;
import password.pwm.error.PwmError;
import password.pwm.error.PwmUnrecoverableException;
import password.pwm.util.java.JavaHelper;
import password.pwm.util.java.StringUtil;
import password.pwm.util.java.TimeDuration;
import password.pwm.util.logging.PwmLogger;
import javax.crypto.Cipher;
import javax.crypto.Mac;
import javax.crypto.SecretKey;
import javax.crypto.spec.GCMParameterSpec;
import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.Writer;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.Instant;
import java.util.Arrays;
import java.util.Random;
/**
* Primary static security/crypto library for app.
*/
public class SecureEngine {
private static final PwmLogger LOGGER = PwmLogger.forClass(SecureEngine.class);
private static final int HASH_BUFFER_SIZE = 1024 * 4;
private static final NonceGenerator AES_GCM_NONCE_GENERATOR = new NonceGenerator(8,8);
private SecureEngine() {
}
public enum Flag {
URL_SAFE,
}
public static String encryptToString(
final String value,
final PwmSecurityKey key,
final PwmBlockAlgorithm blockAlgorithm,
final Flag... flags
)
throws PwmUnrecoverableException {
try {
final byte[] encrypted = encryptToBytes(value, key, blockAlgorithm);
return Arrays.asList(flags).contains(Flag.URL_SAFE)
? StringUtil.base64Encode(encrypted, StringUtil.Base64Options.URL_SAFE, StringUtil.Base64Options.GZIP)
: StringUtil.base64Encode(encrypted);
} catch (Exception e) {
final String errorMsg = "unexpected error b64 encoding crypto result: " + e.getMessage();
final ErrorInformation errorInformation = new ErrorInformation(PwmError.ERROR_CRYPT_ERROR, errorMsg);
LOGGER.error(errorInformation.toDebugStr());
throw new PwmUnrecoverableException(errorInformation);
}
}
static final int GCM_TAG_LENGTH = 16; // in bytes
public static byte[] encryptToBytes(
final String value,
final PwmSecurityKey key,
final PwmBlockAlgorithm blockAlgorithm
)
throws PwmUnrecoverableException
{
try {
if (value == null || value.length() < 1) {
return null;
}
final SecretKey aesKey = key.getKey(blockAlgorithm.getBlockKey());
final byte[] nonce;
final Cipher cipher;
if (blockAlgorithm == PwmBlockAlgorithm.AES128_GCM) {
nonce = AES_GCM_NONCE_GENERATOR.nextValue();
final GCMParameterSpec spec = new GCMParameterSpec(GCM_TAG_LENGTH * 8, nonce);
cipher = Cipher.getInstance(blockAlgorithm.getAlgName());
cipher.init(Cipher.ENCRYPT_MODE, aesKey, spec);
} else {
cipher = Cipher.getInstance(blockAlgorithm.getAlgName());
cipher.init(Cipher.ENCRYPT_MODE, aesKey, cipher.getParameters());
nonce = null;
}
final byte[] encryptedBytes = cipher.doFinal(value.getBytes(PwmConstants.DEFAULT_CHARSET));
final byte[] output;
if (blockAlgorithm.getHmacAlgorithm() != null) {
final byte[] hashChecksum = computeHmacToBytes(blockAlgorithm.getHmacAlgorithm(), key, encryptedBytes);
output = appendByteArrays(blockAlgorithm.getPrefix(), hashChecksum, encryptedBytes);
} else {
if (nonce == null) {
output = appendByteArrays(blockAlgorithm.getPrefix(), encryptedBytes);
} else {
final byte[] nonceLength = new byte[1];
nonceLength[0] = (byte)nonce.length;
output = appendByteArrays(blockAlgorithm.getPrefix(), nonceLength, nonce, encryptedBytes);
}
}
return output;
} catch (Exception e) {
final String errorMsg = "unexpected error performing simple crypt operation: " + e.getMessage();
final ErrorInformation errorInformation = new ErrorInformation(PwmError.ERROR_CRYPT_ERROR, errorMsg);
LOGGER.error(errorInformation.toDebugStr());
throw new PwmUnrecoverableException(errorInformation);
}
}
public static String decryptStringValue(
final String value,
final PwmSecurityKey key,
final PwmBlockAlgorithm blockAlgorithm,
final Flag... flags
)
throws PwmUnrecoverableException {
try {
if (value == null || value.length() < 1) {
return "";
}
final byte[] decoded = Arrays.asList(flags).contains(Flag.URL_SAFE)
? StringUtil.base64Decode(value, StringUtil.Base64Options.URL_SAFE, StringUtil.Base64Options.GZIP)
: StringUtil.base64Decode(value);
return decryptBytes(decoded, key, blockAlgorithm);
} catch (Exception e) {
final String errorMsg = "unexpected error performing simple decrypt operation: " + e.getMessage();
final ErrorInformation errorInformation = new ErrorInformation(PwmError.ERROR_CRYPT_ERROR, errorMsg);
throw new PwmUnrecoverableException(errorInformation);
}
}
public static String decryptBytes(
final byte[] value,
final PwmSecurityKey key,
final PwmBlockAlgorithm blockAlgorithm
)
throws PwmUnrecoverableException {
try {
if (value == null || value.length < 1) {
return null;
}
byte[] workingValue = verifyAndStripPrefix(blockAlgorithm, value);
final SecretKey aesKey = key.getKey(blockAlgorithm.getBlockKey());
if (blockAlgorithm.getHmacAlgorithm() != null) {
final HmacAlgorithm hmacAlgorithm = blockAlgorithm.getHmacAlgorithm();
final int CHECKSUM_SIZE = hmacAlgorithm.getLength();
if (workingValue.length <= CHECKSUM_SIZE) {
throw new PwmUnrecoverableException(new ErrorInformation(PwmError.ERROR_CRYPT_ERROR, "incoming " + blockAlgorithm.toString() + " data is missing checksum"));
}
final byte[] inputChecksum = Arrays.copyOfRange(workingValue, 0, CHECKSUM_SIZE);
final byte[] inputPayload = Arrays.copyOfRange(workingValue, CHECKSUM_SIZE, workingValue.length);
final byte[] computedChecksum = computeHmacToBytes(hmacAlgorithm, key, inputPayload);
if (!Arrays.equals(inputChecksum, computedChecksum)) {
throw new PwmUnrecoverableException(new ErrorInformation(PwmError.ERROR_CRYPT_ERROR, "incoming " + blockAlgorithm.toString() + " data has incorrect checksum"));
}
workingValue = inputPayload;
}
final Cipher cipher;
if (blockAlgorithm == PwmBlockAlgorithm.AES128_GCM) {
final int nonceLength = workingValue[0];
workingValue = Arrays.copyOfRange(workingValue,1,workingValue.length);
if (workingValue.length <= nonceLength) {
throw new PwmUnrecoverableException(new ErrorInformation(PwmError.ERROR_CRYPT_ERROR, "incoming " + blockAlgorithm.toString() + " data is missing nonce"));
}
final byte[] nonce = Arrays.copyOfRange(workingValue, 0, nonceLength);
workingValue = Arrays.copyOfRange(workingValue, nonceLength, workingValue.length);
final GCMParameterSpec spec = new GCMParameterSpec(GCM_TAG_LENGTH * 8, nonce);
cipher = Cipher.getInstance(blockAlgorithm.getAlgName());
cipher.init(Cipher.DECRYPT_MODE, aesKey, spec);
} else {
cipher = Cipher.getInstance(blockAlgorithm.getAlgName());
cipher.init(Cipher.DECRYPT_MODE, aesKey);
}
final byte[] decrypted = cipher.doFinal(workingValue);
return new String(decrypted, PwmConstants.DEFAULT_CHARSET);
} catch (Exception e) {
final String errorMsg = "unexpected error performing simple decrypt operation: " + e.getMessage();
final ErrorInformation errorInformation = new ErrorInformation(PwmError.ERROR_CRYPT_ERROR, errorMsg);
throw new PwmUnrecoverableException(errorInformation);
}
}
public static String hash(
final byte[] input,
final PwmHashAlgorithm algorithm
)
throws PwmUnrecoverableException {
if (input == null || input.length < 1) {
return null;
}
return hash(new ByteArrayInputStream(input), algorithm);
}
public static String hash(
final File file,
final PwmHashAlgorithm hashAlgorithm
)
throws IOException, PwmUnrecoverableException
{
FileInputStream fileInputStream = null;
try {
final MessageDigest messageDigest = MessageDigest.getInstance(hashAlgorithm.getAlgName());
fileInputStream = new FileInputStream(file);
final FileChannel fileChannel = fileInputStream.getChannel();
final ByteBuffer byteBuffer = ByteBuffer.allocateDirect(1024 * 8);
while (fileChannel.read(byteBuffer) > 0) {
byteBuffer.flip();
messageDigest.update(byteBuffer);
byteBuffer.clear();
}
return JavaHelper.byteArrayToHexString(messageDigest.digest());
} catch (NoSuchAlgorithmException | IOException e) {
final String errorMsg = "unexpected error during file hash operation: " + e.getMessage();
final ErrorInformation errorInformation = new ErrorInformation(PwmError.ERROR_CRYPT_ERROR, errorMsg);
throw new PwmUnrecoverableException(errorInformation);
} finally {
IOUtils.closeQuietly(fileInputStream);
}
}
public static String hash(
final String input,
final PwmHashAlgorithm algorithm
)
throws PwmUnrecoverableException {
if (input == null || input.length() < 1) {
return null;
}
return hash(new ByteArrayInputStream(input.getBytes(PwmConstants.DEFAULT_CHARSET)), algorithm);
}
public static String hash(
final InputStream is,
final PwmHashAlgorithm algorithm
)
throws PwmUnrecoverableException {
return JavaHelper.byteArrayToHexString(computeHashToBytes(is, algorithm));
}
private static byte[] computeHmacToBytes(
final HmacAlgorithm hmacAlgorithm,
final PwmSecurityKey pwmSecurityKey,
final byte[] input
)
throws PwmUnrecoverableException
{
try {
final Mac mac = Mac.getInstance(hmacAlgorithm.getAlgorithmName());
final SecretKey secret_key = pwmSecurityKey.getKey(hmacAlgorithm.getKeyType());
mac.init(secret_key);
return mac.doFinal(input);
} catch (GeneralSecurityException e) {
final String errorMsg = "error during hmac operation: " + e.getMessage();
final ErrorInformation errorInformation = new ErrorInformation(PwmError.ERROR_CRYPT_ERROR, errorMsg);
throw new PwmUnrecoverableException(errorInformation);
}
}
public static byte[] computeHashToBytes(
final InputStream is,
final PwmHashAlgorithm algorithm
)
throws PwmUnrecoverableException {
final InputStream bis = is instanceof BufferedInputStream ? is : new BufferedInputStream(is);
final MessageDigest messageDigest;
try {
messageDigest = MessageDigest.getInstance(algorithm.getAlgName());
} catch (NoSuchAlgorithmException e) {
final String errorMsg = "missing hash algorithm: " + e.getMessage();
final ErrorInformation errorInformation = new ErrorInformation(PwmError.ERROR_CRYPT_ERROR, errorMsg);
throw new PwmUnrecoverableException(errorInformation);
}
try {
final byte[] buffer = new byte[HASH_BUFFER_SIZE];
int length;
while (true) {
length = bis.read(buffer, 0, buffer.length);
if (length == -1) {
break;
}
messageDigest.update(buffer, 0, length);
}
bis.close();
return messageDigest.digest();
} catch (IOException e) {
final String errorMsg = "unexpected error during hash operation: " + e.getMessage();
final ErrorInformation errorInformation = new ErrorInformation(PwmError.ERROR_CRYPT_ERROR, errorMsg);
throw new PwmUnrecoverableException(errorInformation);
}
}
private static byte[] appendByteArrays(final byte[]... input) {
if (input == null || input.length == 0) {
return new byte[0];
}
if (input.length == 1) {
return input[0];
}
int totalLength = 0;
for (final byte[] loopBa : input) {
totalLength += loopBa.length;
}
final byte[] output = new byte[totalLength];
int position = 0;
for (final byte[] loopBa : input) {
System.arraycopy(loopBa,0,output,position,loopBa.length);
position += loopBa.length;
}
return output;
}
static byte[] verifyAndStripPrefix(final PwmBlockAlgorithm blockAlgorithm, final byte[] input) throws PwmUnrecoverableException {
final byte[] definedPrefix = blockAlgorithm.getPrefix();
if (definedPrefix.length == 0) {
return input;
}
final byte[] inputPrefix = Arrays.copyOf(input, definedPrefix.length);
if (!Arrays.equals(definedPrefix, inputPrefix)) {
final String errorMsg = "value is missing valid prefix for decryption type";
final ErrorInformation errorInformation = new ErrorInformation(PwmError.ERROR_CRYPT_ERROR, errorMsg);
throw new PwmUnrecoverableException(errorInformation);
}
return Arrays.copyOfRange(input,definedPrefix.length,input.length);
}
static class NonceGenerator {
private final byte[] value;
private final int fixedComponentLength;
NonceGenerator(final int fixedComponentLength, final int counterComponentLength) {
this.fixedComponentLength = fixedComponentLength;
value = new byte[fixedComponentLength + counterComponentLength];
PwmRandom.getInstance().nextBytes(value);
}
public synchronized byte[] nextValue() {
increment(value.length - 1);
return Arrays.copyOf(value, value.length);
}
private void increment(final int index) {
if (value[index] == Byte.MAX_VALUE) {
value[index] = 0;
if(index > fixedComponentLength) {
increment(index - 1);
}
} else {
value[index]++;
}
}
}
public static void benchmark(final Writer outputData) throws PwmUnrecoverableException, IOException {
final int testIterations = 10 * 1000;
final Random random = new Random();
final byte[] noise = new byte[1024 * 10];
final PwmSecurityKey key = new PwmSecurityKey(PwmRandom.getInstance().newBytes(1024));
for (int i = 0; i < 10; i++) {
for (final PwmBlockAlgorithm alg : PwmBlockAlgorithm.values()) {
final Instant startTime = Instant.now();
for (int j = 0; j < testIterations; j++) {
random.nextBytes(noise);
SecureEngine.encryptToString(
JavaHelper.binaryArrayToHex(noise),
key,
alg
);
}
final TimeDuration executionDuration = TimeDuration.fromCurrent(startTime);
outputData.write("processed " + testIterations + " iterations using "
+ alg.toString() + " (" + alg.getLabel() + ") in "
+ executionDuration.getTotalMilliseconds() + "ms");
outputData.write("\n");
}
}
}
}