package com.brightgenerous.commons.crypto;
import static com.brightgenerous.commons.ObjectUtils.*;

import java.io.Serializable;
import java.lang.ref.SoftReference;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.spec.InvalidKeySpecException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.SecretKey;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.DESKeySpec;
import javax.crypto.spec.DESedeKeySpec;
import javax.crypto.spec.SecretKeySpec;

import com.brightgenerous.commons.EqualsUtils;
import com.brightgenerous.commons.HashCodeUtils;
import com.brightgenerous.commons.ToStringUtils;
import com.brightgenerous.lang.Args;
/**
 * Symmetric encryption helper bound to one {@link CryptoAlgorithm} and one key.
 * <p>
 * The raw key may optionally be hashed with a {@link HashAlgorithm} first. The resulting bytes are
 * then truncated or zero-padded to the length the cipher needs (DES: 8, DESede: 24, AES: 16 bytes).
 * <p>
 * Instances obtained through {@code get(...)} are cached per (algorithm, keyAlgorithm, key
 * contents) behind {@link SoftReference}s, so the JVM may reclaim them under memory pressure. Each
 * instance keeps its own copy of the key, so later changes to the caller's array do not affect it.
 */
public class CryptoUtils implements Serializable {

    private static final long serialVersionUID = 1285931778321061231L;

    /**
     * Cache key for {@link #getInstance}. The key bytes are compared by content, not by array
     * identity, so equal keys passed in distinct arrays share one cached instance.
     */
    static class InstanceKey implements Serializable {

        private static final long serialVersionUID = -5571606798438371038L;

        private final CryptoAlgorithm algorithm;

        private final HashAlgorithm keyAlgorithm;

        private final byte[] key;

        public InstanceKey(CryptoAlgorithm algorithm, HashAlgorithm keyAlgorithm, byte[] key) {
            this.algorithm = algorithm;
            this.keyAlgorithm = keyAlgorithm;
            // Copy so a caller mutating its array cannot corrupt a key already stored in the map.
            this.key = (key == null) ? null : key.clone();
        }

        @Override
        public int hashCode() {
            final int multiplier = 37;
            int result = 17;
            result = (multiplier * result) + hashCodeEscapeNull(algorithm);
            result = (multiplier * result) + hashCodeEscapeNull(keyAlgorithm);
            // Content-based hash: byte[].hashCode() is identity-based and would defeat the cache.
            result = (multiplier * result) + Arrays.hashCode(key);
            return result;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof InstanceKey)) {
                return false;
            }
            InstanceKey other = (InstanceKey) obj;
            return equalsEscapeNull(algorithm, other.algorithm)
                    && equalsEscapeNull(keyAlgorithm, other.keyAlgorithm)
                    // Content comparison, consistent with hashCode() above.
                    && Arrays.equals(key, other.key);
        }
    }

    /** Instances keyed by their construction arguments; values may be cleared by the GC. */
    private static final ConcurrentMap<InstanceKey, SoftReference<CryptoUtils>> cache = new ConcurrentHashMap<>();

    private final CryptoAlgorithm algorithm;

    /** Optional hash applied to {@link #key} before use; {@code null} means "use the raw key". */
    private final HashAlgorithm keyAlgorithm;

    /**
     * Private copy of the raw key.
     * NOTE(review): this class is Serializable, so serializing an instance writes this secret.
     */
    private final byte[] key;

    /**
     * Creates an instance that uses the raw key (no key hashing).
     *
     * @param algorithm cipher to use; must not be null
     * @param key raw key bytes; must not be null (copied)
     */
    protected CryptoUtils(CryptoAlgorithm algorithm, byte[] key) {
        this(algorithm, null, key);
    }

    /**
     * Creates an instance.
     *
     * @param algorithm cipher to use; must not be null
     * @param keyAlgorithm hash applied to the key before use, or {@code null} for the raw key
     * @param key raw key bytes; must not be null (copied)
     */
    protected CryptoUtils(CryptoAlgorithm algorithm, HashAlgorithm keyAlgorithm, byte[] key) {
        Args.notNull(algorithm, "algorithm");
        Args.notNull(key, "key");
        this.algorithm = algorithm;
        this.keyAlgorithm = keyAlgorithm;
        // Defensive copy: later changes to the caller's array must not change this instance's key.
        this.key = key.clone();
    }

    /**
     * Returns a possibly shared, cached instance that uses the raw key.
     *
     * @param algorithm cipher to use; must not be null
     * @param key raw key bytes; must not be null
     * @return an instance for the given arguments
     */
    public static CryptoUtils get(CryptoAlgorithm algorithm, byte[] key) {
        return getInstance(algorithm, null, key);
    }

    /**
     * Returns a possibly shared, cached instance.
     *
     * @param algorithm cipher to use; must not be null
     * @param keyAlgorithm hash applied to the key before use, or {@code null} for the raw key
     * @param key raw key bytes; must not be null
     * @return an instance for the given arguments
     */
    public static CryptoUtils get(CryptoAlgorithm algorithm, HashAlgorithm keyAlgorithm, byte[] key) {
        return getInstance(algorithm, keyAlgorithm, key);
    }

    /**
     * Looks up a cached instance by (algorithm, keyAlgorithm, key contents), creating and caching
     * a new one on a miss. Concurrent misses for the same arguments may each create an instance;
     * the last one stored wins, which is harmless because equal instances behave identically.
     */
    protected static CryptoUtils getInstance(CryptoAlgorithm algorithm, HashAlgorithm keyAlgorithm,
            byte[] key) {
        Args.notNull(algorithm, "algorithm");
        Args.notNull(key, "key");

        InstanceKey ik = new InstanceKey(algorithm, keyAlgorithm, key);
        SoftReference<CryptoUtils> sr = cache.get(ik);
        if (sr != null) {
            CryptoUtils cached = sr.get();
            if (cached != null) {
                return cached;
            }
        }
        // Miss, or the cached instance was reclaimed: sweep out all reclaimed entries so keys that
        // are never requested again do not accumulate in the map.
        purgeStaleEntries();
        CryptoUtils ret = new CryptoUtils(algorithm, keyAlgorithm, key);
        cache.put(ik, new SoftReference<>(ret));
        return ret;
    }

    /** Removes cache entries whose soft reference has been cleared by the GC. */
    private static void purgeStaleEntries() {
        for (Entry<InstanceKey, SoftReference<CryptoUtils>> entry : cache.entrySet()) {
            if (entry.getValue().get() == null) {
                // Conditional remove: never evict a fresh reference another thread just stored.
                cache.remove(entry.getKey(), entry.getValue());
            }
        }
    }

    /** Encrypts {@code bytes} with this instance's algorithm and key. */
    public byte[] encrypt(byte[] bytes) throws NoSuchAlgorithmException, InvalidKeyException,
            InvalidKeySpecException, NoSuchPaddingException, IllegalBlockSizeException,
            BadPaddingException {
        return encrypt(algorithm, keyAlgorithm, key, bytes);
    }

    /** Decrypts {@code bytes} with this instance's algorithm and key. */
    public byte[] decrypt(byte[] bytes) throws NoSuchAlgorithmException, InvalidKeyException,
            InvalidKeySpecException, NoSuchPaddingException, IllegalBlockSizeException,
            BadPaddingException {
        return decrypt(algorithm, keyAlgorithm, key, bytes);
    }

    /** Encrypts {@code bytes} using the raw {@code key} (no key hashing). */
    public static byte[] encrypt(CryptoAlgorithm algorithm, byte[] key, byte[] bytes)
            throws NoSuchAlgorithmException, InvalidKeyException, InvalidKeySpecException,
            NoSuchPaddingException, IllegalBlockSizeException, BadPaddingException {
        return encrypt(algorithm, null, key, bytes);
    }

    /**
     * Encrypts {@code bytes}.
     *
     * @param algorithm cipher to use; must not be null
     * @param keyAlgorithm hash applied to {@code key} first, or {@code null} for the raw key
     * @param key raw key bytes; must not be null
     * @param bytes plaintext; must not be null
     * @return the ciphertext
     */
    public static byte[] encrypt(CryptoAlgorithm algorithm, HashAlgorithm keyAlgorithm, byte[] key,
            byte[] bytes) throws NoSuchAlgorithmException, InvalidKeyException,
            InvalidKeySpecException, NoSuchPaddingException, IllegalBlockSizeException,
            BadPaddingException {
        return crypt(Cipher.ENCRYPT_MODE, algorithm, keyAlgorithm, key, bytes);
    }

    /** Decrypts {@code bytes} using the raw {@code key} (no key hashing). */
    public static byte[] decrypt(CryptoAlgorithm algorithm, byte[] key, byte[] bytes)
            throws NoSuchAlgorithmException, InvalidKeyException, InvalidKeySpecException,
            NoSuchPaddingException, IllegalBlockSizeException, BadPaddingException {
        return decrypt(algorithm, null, key, bytes);
    }

    /**
     * Decrypts {@code bytes}.
     *
     * @param algorithm cipher to use; must not be null
     * @param keyAlgorithm hash applied to {@code key} first, or {@code null} for the raw key
     * @param key raw key bytes; must not be null
     * @param bytes ciphertext; must not be null
     * @return the plaintext
     */
    public static byte[] decrypt(CryptoAlgorithm algorithm, HashAlgorithm keyAlgorithm, byte[] key,
            byte[] bytes) throws NoSuchAlgorithmException, InvalidKeyException,
            InvalidKeySpecException, NoSuchPaddingException, IllegalBlockSizeException,
            BadPaddingException {
        return crypt(Cipher.DECRYPT_MODE, algorithm, keyAlgorithm, key, bytes);
    }

    /**
     * Runs a single-shot cipher operation in the given mode; shared by encrypt and decrypt.
     *
     * @param mode {@link Cipher#ENCRYPT_MODE} or {@link Cipher#DECRYPT_MODE}
     */
    private static byte[] crypt(int mode, CryptoAlgorithm algorithm, HashAlgorithm keyAlgorithm,
            byte[] key, byte[] bytes) throws NoSuchAlgorithmException, InvalidKeyException,
            InvalidKeySpecException, NoSuchPaddingException, IllegalBlockSizeException,
            BadPaddingException {
        Args.notNull(algorithm, "algorithm");
        Args.notNull(key, "key");
        Args.notNull(bytes, "bytes");
        SecretKey secretKey = createSecretKey(algorithm, keyAlgorithm, key);
        // NOTE(review): value() is presumably a bare algorithm name, so the provider chooses the
        // mode and padding (SunJCE defaults to ECB/PKCS5Padding). ECB leaks plaintext patterns,
        // but changing the transformation would break decryption of existing ciphertext.
        Cipher cipher = Cipher.getInstance(algorithm.value());
        cipher.init(mode, secretKey);
        return cipher.doFinal(bytes);
    }

    /** Builds the {@link SecretKey} for {@code algorithm} from the derived key bytes. */
    private static SecretKey createSecretKey(CryptoAlgorithm algorithm, HashAlgorithm keyAlgorithm,
            byte[] key) throws NoSuchAlgorithmException, InvalidKeyException,
            InvalidKeySpecException {
        byte[] pssKey = getPssKey(algorithm, keyAlgorithm, key);
        switch (algorithm) {
            case DES:
                return SecretKeyFactory.getInstance(algorithm.value()).generateSecret(
                        new DESKeySpec(pssKey));
            case DESEDE:
                return SecretKeyFactory.getInstance(algorithm.value()).generateSecret(
                        new DESedeKeySpec(pssKey));
            case AES:
                return new SecretKeySpec(pssKey, algorithm.value());
            default:
                throw new IllegalStateException("Unsupported crypto algorithm: " + algorithm);
        }
    }

    /**
     * Derives the cipher key bytes. The key is optionally hashed with {@code keyAlgorithm}, then
     * truncated or zero-padded ({@link Arrays#copyOf}) to 8 (DES), 24 (DESede) or 16 (AES) bytes.
     */
    private static byte[] getPssKey(CryptoAlgorithm algorithm, HashAlgorithm keyAlgorithm,
            byte[] key) throws NoSuchAlgorithmException {
        byte[] bytes = (keyAlgorithm != null) ? getHash(keyAlgorithm, key) : key;
        switch (algorithm) {
            case DES:
                return Arrays.copyOf(bytes, 8);
            case DESEDE:
                return Arrays.copyOf(bytes, 24);
            case AES:
                return Arrays.copyOf(bytes, 16);
            default:
                throw new IllegalStateException("Unsupported crypto algorithm: " + algorithm);
        }
    }

    /**
     * Computes the message digest of {@code bytes}.
     *
     * @param algorithm digest to use; must not be null
     * @param bytes input; must not be null
     * @return the digest bytes
     * @throws NoSuchAlgorithmException if the digest is not available in this JVM
     */
    public static byte[] getHash(HashAlgorithm algorithm, byte[] bytes)
            throws NoSuchAlgorithmException {
        Args.notNull(algorithm, "algorithm");
        Args.notNull(bytes, "bytes");
        switch (algorithm) {
            case MD2:
            case MD5:
            case SHA:
            case SHA256:
            case SHA384:
            case SHA512:
                MessageDigest md = MessageDigest.getInstance(algorithm.value());
                md.update(bytes);
                return md.digest();
            default:
                throw new IllegalStateException("Unsupported hash algorithm: " + algorithm);
        }
    }

    @Override
    public int hashCode() {
        if (HashCodeUtils.resolved()) {
            return HashCodeUtils.hashCodeAlt(null, this);
        }
        return super.hashCode();
    }

    @Override
    public boolean equals(Object obj) {
        if (EqualsUtils.resolved()) {
            return EqualsUtils.equalsAlt(null, this, obj);
        }
        return super.equals(obj);
    }

    @Override
    public String toString() {
        if (ToStringUtils.resolved()) {
            return ToStringUtils.toStringAlt(this);
        }
        return super.toString();
    }
}