package org.jgroups.protocols;
import org.jgroups.*;
import org.jgroups.annotations.ManagedAttribute;
import org.jgroups.annotations.Property;
import org.jgroups.stack.Protocol;
import org.jgroups.util.*;
import javax.crypto.Cipher;
import java.security.Key;
import java.security.MessageDigest;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.function.BiConsumer;
import java.util.zip.Adler32;
import java.util.zip.CRC32;
import java.util.zip.Checksum;
/**
* Super class of symmetric ({@link SYM_ENCRYPT}) and asymmetric ({@link ASYM_ENCRYPT}) encryption protocols.
* @author Bela Ban
*/
public abstract class Encrypt extends Protocol {
// Default symmetric algorithm; used as the initial value of sym_algorithm
protected static final String DEFAULT_SYM_ALGO="AES";
/* ----------------------------------------- Properties -------------------------------------------------- */
// JCE provider name. When null or blank, createCipher() calls Cipher.getInstance(algorithm) without a provider
@Property(description="Cryptographic Service Provider")
protected String provider;
// Not read by this class other than through its accessors; presumably consumed by subclasses (TODO confirm)
@Property(description="Cipher engine transformation for asymmetric algorithm. Default is RSA")
protected String asym_algorithm="RSA";
// Transformation passed to initSymCiphers() by init()
@Property(description="Cipher engine transformation for symmetric algorithm. Default is AES")
protected String sym_algorithm=DEFAULT_SYM_ALGO;
// Not read by this class other than through its accessors; presumably consumed by subclasses (TODO confirm)
@Property(description="Initial public/private key length. Default is 512")
protected int asym_keylength=512;
// Not read by this class other than through its accessors; presumably consumed by subclasses (TODO confirm)
@Property(description="Initial key length for matching symmetric algorithm. Default is 128")
protected int sym_keylength=128;
// Capacity of each of the two cipher pools; init() replaces it with Util.getNextHigherPowerOfTwo(cipher_pool_size)
@Property(description="Number of ciphers in the pool to parallelize encrypt and decrypt requests",writable=false)
protected int cipher_pool_size=8;
@Property(description="If true, the entire message (including payload and headers) is encrypted, else only the payload")
protected boolean encrypt_entire_message=true;
// Only evaluated together with encrypt_entire_message (see encryptAndSend() and _decrypt())
@Property(description="If true, all messages are digitally signed by adding an encrypted checksum of the encrypted " +
"message to the header. Ignored if encrypt_entire_message is false")
protected boolean sign_msgs=true;
// Selects the Checksum implementation returned by createChecksummer()
@Property(description="When sign_msgs is true, by default CRC32 is used to create the checksum. If use_adler is " +
"true, Adler32 will be used")
protected boolean use_adler;
// Passed to the BoundedHashMap constructor when key_map is created in init()
@Property(description="Max number of keys in key_map")
protected int key_map_max_size=20;
// Set by the SET_LOCAL_ADDRESS down-event or by localAddress(); used in log messages and as the source address
// of messages that have none when the entire message is encrypted
protected volatile Address local_addr;
// Most recent view, stored by handleView()
protected volatile View view;
// Cipher pools used for encryption and decryption. Size is cipher_pool_size
protected BlockingQueue<Cipher> encoding_ciphers, decoding_ciphers;
// Version of the current secret key: the MD5 digest of the encoded key, computed in initSymCiphers().
// Shipped in every EncryptHeader and compared against the version of received messages
protected volatile byte[] sym_version;
// shared secret key to encrypt/decrypt messages
protected volatile Key secret_key;
// Ciphers of previous keys, looked up by the version found in the header, so we can decrypt some earlier
// messages if we need to. This class only reads the map; presumably subclasses populate it (TODO confirm)
protected Map<AsciiString,Cipher> key_map;
/** Returns the asymmetric key length */
public int asymKeylength() {
    return asym_keylength;
}

/** Sets the asymmetric key length and returns this protocol (for call chaining) */
public <T extends Encrypt> T asymKeylength(int len) {
    asym_keylength=len;
    return (T)this;
}

/** Returns the symmetric key length */
public int symKeylength() {
    return sym_keylength;
}

/** Sets the symmetric key length and returns this protocol (for call chaining) */
public <T extends Encrypt> T symKeylength(int len) {
    sym_keylength=len;
    return (T)this;
}

/** Returns the current shared secret key (may be null) */
public Key secretKey() {
    return secret_key;
}

/** Sets the shared secret key (the cipher pools are not touched) and returns this protocol */
public <T extends Encrypt> T secretKey(Key key) {
    secret_key=key;
    return (T)this;
}

/** Returns the symmetric cipher transformation */
public String symAlgorithm() {
    return sym_algorithm;
}

/** Sets the symmetric cipher transformation and returns this protocol */
public <T extends Encrypt> T symAlgorithm(String alg) {
    sym_algorithm=alg;
    return (T)this;
}

/** Returns the asymmetric cipher transformation */
public String asymAlgorithm() {
    return asym_algorithm;
}

/** Sets the asymmetric cipher transformation and returns this protocol */
public <T extends Encrypt> T asymAlgorithm(String alg) {
    asym_algorithm=alg;
    return (T)this;
}

/** Returns the version of the current secret key (the internal array, not a copy) */
public byte[] symVersion() {
    return sym_version;
}

/** Sets the version of the secret key to a copy of the given array and returns this protocol */
public <T extends Encrypt> T symVersion(byte[] v) {
    sym_version=v.clone();
    return (T)this;
}

/** Sets the local address and returns this protocol */
public <T extends Encrypt> T localAddress(Address addr) {
    local_addr=addr;
    return (T)this;
}

/** Returns true if entire messages (rather than only payloads) are encrypted */
public boolean encryptEntireMessage() {
    return encrypt_entire_message;
}

/** Enables or disables encryption of entire messages and returns this protocol */
public <T extends Encrypt> T encryptEntireMessage(boolean b) {
    encrypt_entire_message=b;
    return (T)this;
}

/** Returns true if messages carry an encrypted checksum */
public boolean signMessages() {
    return sign_msgs;
}

/** Enables or disables the encrypted checksum and returns this protocol */
public <T extends Encrypt> T signMessages(boolean flag) {
    sign_msgs=flag;
    return (T)this;
}

/** Returns true if Adler32 (instead of CRC32) is used for checksums */
public boolean adler() {
    return use_adler;
}

/** Selects Adler32 (true) or CRC32 (false) for checksums and returns this protocol */
public <T extends Encrypt> T adler(boolean flag) {
    use_adler=flag;
    return (T)this;
}

/** Returns the version of the current secret key as a hex string */
@ManagedAttribute
public String version() {
    return Util.byteArrayToHexString(sym_version);
}
/**
 * Adjusts cipher_pool_size to the value computed by Util.getNextHigherPowerOfTwo() (logging a warning if it
 * changed), creates the map of previous ciphers and the two cipher pools, and then fills the pools via
 * initSymCiphers(), which does nothing while secret_key is null.
 */
public void init() throws Exception {
    final int adjusted_size=Util.getNextHigherPowerOfTwo(cipher_pool_size);
    if(cipher_pool_size != adjusted_size) {
        log.warn("%s: setting cipher_pool_size (%d) to %d (power of 2) for faster modulo operation", local_addr, cipher_pool_size, adjusted_size);
        cipher_pool_size=adjusted_size;
    }

    key_map=new BoundedHashMap<>(key_map_max_size);

    // both pools have the same (adjusted) capacity
    encoding_ciphers=new ArrayBlockingQueue<>(cipher_pool_size);
    decoding_ciphers=new ArrayBlockingQueue<>(cipher_pool_size);

    initSymCiphers(sym_algorithm, secret_key);
}
/**
 * Records view changes and the local address, then passes the event on to the protocol below.
 * @param evt the event travelling down the stack
 * @return whatever the protocol below returns
 */
public Object down(Event evt) {
    final int type=evt.getType();
    if(type == Event.VIEW_CHANGE)
        handleView(evt.getArg());
    else if(type == Event.SET_LOCAL_ADDRESS)
        local_addr=evt.getArg();
    return down_prot.down(evt);
}
/**
 * Encrypts the message and sends it down. If no secret key is available, the message is discarded. Failures are
 * logged, not propagated.
 * @param msg the message to be sent
 * @return always null
 */
public Object down(Message msg) {
    try {
        if(secret_key != null)
            encryptAndSend(msg);
        else {
            final Address dest=msg.dest();
            log.trace("%s: discarded %s message to %s as secret key is null, hdrs: %s",
                      local_addr, dest == null? "mcast" : "unicast", dest, msg.printHeaders());
        }
    }
    catch(Exception e) {
        log.warn("%s: unable to send message down", local_addr, e);
    }
    return null;
}
/**
 * Records view changes, then passes the event on to the protocol above.
 * @param evt the event travelling up the stack
 * @return whatever the protocol above returns
 */
public Object up(Event evt) {
    if(evt.getType() == Event.VIEW_CHANGE)
        handleView(evt.getArg());
    return up_prot.up(evt);
}
/**
 * Handles a single received message via handleUpMessage(). Exceptions are logged and result in a null return
 * value.
 * @param msg the received message
 * @return the result of handleUpMessage(), or null if it threw an exception
 */
public Object up(Message msg) {
    Object result=null;
    try {
        result=handleUpMessage(msg);
    }
    catch(Exception e) {
        log.warn("%s: exception occurred decrypting message", local_addr, e);
    }
    return result;
}
/**
 * Decrypts all messages of a batch in place, using a single cipher borrowed from the decoding pool, and passes
 * the batch up if it still contains messages. The batch is discarded if no secret key is available or if the
 * thread is interrupted while waiting for a cipher.
 * @param batch the received batch
 */
public void up(MessageBatch batch) {
    if(secret_key == null) {
        log.trace("%s: discarded %s batch from %s as secret key is null",
                  local_addr, batch.dest() == null? "mcast" : "unicast", batch.sender());
        return;
    }

    // remember the queue the cipher was taken from, so the cipher is returned to that very queue
    final BlockingQueue<Cipher> pool=decoding_ciphers;
    Cipher cipher=null;
    try {
        cipher=pool.take();
        batch.forEach(new Decrypter(cipher));
    }
    catch(InterruptedException e) {
        // restore the interrupt status so that code further up the call chain can see the interruption
        Thread.currentThread().interrupt();
        log.error("%s: failed processing batch; discarding batch", local_addr, e);
        // we need to drop the batch if we for example have a failure fetching a cipher, or else other messages
        // in the batch might make it up the stack, bypassing decryption! This is not an issue because encryption
        // is below NAKACK2 or UNICAST3, so messages will get retransmitted
        return;
    }
    finally {
        if(cipher != null)
            pool.offer(cipher);
    }

    if(!batch.isEmpty())
        up_prot.up(batch);
}
/**
 * Initialises the ciphers for both encryption and decryption using the generated or supplied secret key: both
 * pools are cleared and refilled with cipher_pool_size ciphers each, and sym_version is set to the MD5 digest of
 * the encoded key. Does nothing if the key is null.
 * @param algorithm the transformation passed to createCipher()
 * @param secret the secret key; if null, the method returns without changing anything
 * @throws Exception if a cipher or the message digest cannot be created
 */
protected synchronized void initSymCiphers(String algorithm, Key secret) throws Exception {
    if(secret == null)
        return;

    encoding_ciphers.clear();
    decoding_ciphers.clear();
    for(int i=0; i < cipher_pool_size; i++) {
        encoding_ciphers.offer(createCipher(Cipher.ENCRYPT_MODE, secret, algorithm));
        decoding_ciphers.offer(createCipher(Cipher.DECRYPT_MODE, secret, algorithm));
    }

    // set the version: the digest of the encoded key identifies the key without revealing it
    final byte[] encoded_key=secret.getEncoded();
    sym_version=MessageDigest.getInstance("MD5").digest(encoded_key);

    // log the length of the key itself, not the length of the (fixed-size) version digest
    log.debug("%s: created %d symmetric ciphers with secret key (%d bytes)", local_addr, cipher_pool_size, encoded_key.length);
}
/**
 * Creates and initialises a cipher. The configured provider is used if it is neither null nor blank, otherwise
 * the provider is not specified.
 * @param mode the operation mode passed to Cipher.init(), e.g. Cipher.ENCRYPT_MODE
 * @param secret_key the key to initialise the cipher with
 * @param algorithm the transformation
 * @return the initialised cipher
 */
protected Cipher createCipher(int mode, Key secret_key, String algorithm) throws Exception {
    final Cipher result;
    if(provider == null || provider.trim().isEmpty())
        result=Cipher.getInstance(algorithm);
    else
        result=Cipher.getInstance(algorithm, provider);
    result.init(mode, secret_key);
    return result;
}
/**
 * Dispatches a received message based on its encrypt header: messages of type ENCRYPT are decrypted and passed
 * up, all other types are handed to handleUpEvent(). Messages without an encrypt header are dropped.
 * @param msg the received message
 * @return the result of the handler invoked, or null if the message was dropped
 */
protected Object handleUpMessage(Message msg) throws Exception {
    final EncryptHeader encrypt_hdr=msg.getHeader(this.id);
    if(encrypt_hdr == null) {
        log.error("%s: received message without encrypt header from %s; dropping it", local_addr, msg.src());
        return null;
    }
    if(encrypt_hdr.type() == EncryptHeader.ENCRYPT)
        return handleEncryptedMessage(msg);
    return handleUpEvent(msg, encrypt_hdr);
}
/**
 * Decrypts a copy of the message (unless process() rejects it) and passes the decrypted message up.
 * @param msg the encrypted message
 * @return the result of passing the decrypted message up, or null if the message was rejected or could not be
 *         decrypted
 */
protected Object handleEncryptedMessage(Message msg) throws Exception {
    if(!process(msg))
        return null;

    // decrypt a copy, as decryption modifies the buffer of the message (http://jira.jboss.com/jira/browse/JGRP-538)
    // and the original may still be needed for possible xmits
    final Message decrypted=decryptMessage(null, msg.copy());
    if(decrypted == null) {
        log.warn("%s: unrecognized cipher; discarding message from %s", local_addr, msg.src());
        return null;
    }
    return up_prot.up(decrypted);
}
/**
 * Invoked for received messages whose encrypt header has a type other than EncryptHeader.ENCRYPT. This
 * implementation ignores the message and returns null.
 */
protected Object handleUpEvent(Message msg, EncryptHeader hdr) {
return null;
}
/** Whether or not to process this received message. Consulted before decryption; this implementation always returns true */
protected boolean process(Message msg) {return true;}
/** Stores the given view as the current view */
protected void handleView(View view) {
this.view=view;
}
/**
 * Checks whether the sender is a member of the current view. If no view has been stored yet, the check passes.
 * @param sender the address to look up
 * @param error_msg the format string logged (with sender and view as arguments) if the sender is not a member
 * @return true if there is no view or the view contains the sender, false otherwise
 */
protected boolean inView(Address sender, String error_msg) {
    final View current=this.view; // read the volatile field once
    final boolean accepted=current == null || current.containsMember(sender);
    if(!accepted)
        log.error(error_msg, sender, current);
    return accepted;
}
/** Creates a new checksum instance: Adler32 if use_adler is set, CRC32 otherwise */
protected Checksum createChecksummer() {
    if(use_adler)
        return new Adler32();
    return new CRC32();
}
/**
 * Does the actual work for decrypting. If the version in the header does not match the current version, the
 * cipher registered in key_map under the header's version is used instead of the one passed in.
 * @param cipher the cipher to use when the versions match; null means a cipher is borrowed from the pool
 * @param msg the message to decrypt
 * @return the decrypted message, or null if the version is unknown or the message was dropped by _decrypt()
 */
protected Message decryptMessage(Cipher cipher, Message msg) throws Exception {
    final EncryptHeader hdr=msg.getHeader(this.id);
    final boolean current_version=Arrays.equals(hdr.version(), sym_version);
    if(!current_version) {
        cipher=key_map.get(new AsciiString(hdr.version()));
        if(cipher == null) {
            handleUnknownVersion();
            return null;
        }
        log.trace("%s: decrypting msg from %s using previous cipher version", local_addr, msg.src());
    }
    return _decrypt(cipher, msg, hdr);
}
/**
 * Verifies the checksum (if entire messages are encrypted and signed) and decrypts the message.
 * @param cipher the cipher to use; if null, a cipher is borrowed from the decoding pool
 * @param msg the message to decrypt; in payload-only mode, its buffer is replaced with the decrypted payload
 * @param hdr the encrypt header of the message
 * @return the decrypted message, or null if the checksum is missing or does not match
 */
protected Message _decrypt(final Cipher cipher, Message msg, EncryptHeader hdr) throws Exception {
    // payload-only mode: there is nothing to decrypt in an empty payload
    if(!encrypt_entire_message && msg.getLength() == 0)
        return msg;

    if(encrypt_entire_message && sign_msgs) {
        final byte[] sig=hdr.signature();
        if(sig == null) {
            log.error("%s: dropped message from %s as the header did not have a checksum", local_addr, msg.src());
            return null;
        }
        // the header carries the encrypted checksum of the encrypted buffer
        final long shipped=decryptChecksum(cipher, sig, 0, sig.length);
        final long computed=computeChecksum(msg.getRawBuffer(), msg.getOffset(), msg.getLength());
        if(shipped != computed) {
            log.error("%s: dropped message from %s as the message's checksum (%d) did not match the computed checksum (%d)",
                      local_addr, msg.src(), shipped, computed);
            return null;
        }
    }

    final byte[] plain=cipher != null?
      cipher.doFinal(msg.getRawBuffer(), msg.getOffset(), msg.getLength()) :
      code(msg.getRawBuffer(), msg.getOffset(), msg.getLength(), true);

    if(!encrypt_entire_message) {
        msg.setBuffer(plain);
        return msg;
    }

    // the decrypted buffer is a serialized message; fill in the addresses if they were not serialized
    final Message unmarshalled=Util.streamableFromBuffer(Message.class, plain, 0, plain.length);
    if(unmarshalled.getDest() == null)
        unmarshalled.setDest(msg.getDest());
    if(unmarshalled.getSrc() == null)
        unmarshalled.setSrc(msg.getSrc());
    return unmarshalled;
}
/**
 * Encrypts the message (either the entire serialized message or only its payload, depending on
 * encrypt_entire_message), adds an encrypt header carrying the current key version and passes the result down.
 * @param msg the message to encrypt; it is not sent itself, an encrypted copy is
 */
protected void encryptAndSend(Message msg) throws Exception {
    final EncryptHeader hdr=new EncryptHeader(EncryptHeader.ENCRYPT, symVersion());

    if(!encrypt_entire_message) {
        // a copy is needed because the same message (object) may be retransmitted -> prevent double encryption
        final Message payload_encrypted=msg.copy(false).putHeader(this.id, hdr);
        if(msg.getLength() > 0)
            payload_encrypted.setBuffer(code(msg.getRawBuffer(), msg.getOffset(), msg.getLength(), false));
        else { // length is 0
            final byte[] raw=msg.getRawBuffer();
            if(raw != null) // we don't encrypt empty buffers (https://issues.jboss.org/browse/JGRP-2153)
                payload_encrypted.setBuffer(raw, msg.getOffset(), msg.getLength());
        }
        down_prot.down(payload_encrypted);
        return;
    }

    if(msg.getSrc() == null)
        msg.setSrc(local_addr);

    final Buffer marshalled=Util.streamableToBuffer(msg);
    final byte[] cipher_text=code(marshalled.getBuf(), marshalled.getOffset(), marshalled.getLength(), false);

    if(sign_msgs) {
        // the header carries the encrypted checksum of the encrypted message
        final long checksum=computeChecksum(cipher_text, 0, cipher_text.length);
        hdr.signature(encryptChecksum(checksum));
    }

    // exclude existing headers, they will be seen again when we decrypt and unmarshal the msg at the receiver
    final Message wrapper=msg.copy(false, false).setBuffer(cipher_text).putHeader(this.id, hdr);
    down_prot.down(wrapper);
}
/**
 * Encrypts or decrypts a byte range with a cipher borrowed from the matching pool. Blocks until a cipher is
 * available; the cipher is returned to the pool afterwards, even if the operation fails.
 * @param buf the input buffer
 * @param offset the offset into buf
 * @param length the number of bytes to process
 * @param decode true to decrypt (decoding pool), false to encrypt (encoding pool)
 * @return the result of Cipher.doFinal()
 */
protected byte[] code(byte[] buf, int offset, int length, boolean decode) throws Exception {
    final BlockingQueue<Cipher> pool;
    if(decode)
        pool=decoding_ciphers;
    else
        pool=encoding_ciphers;

    final Cipher borrowed=pool.take();
    try {
        return borrowed.doFinal(buf, offset, length);
    }
    finally {
        pool.offer(borrowed);
    }
}
/**
 * Computes the checksum of a byte range with a fresh instance obtained from createChecksummer().
 * @param input the buffer
 * @param offset the offset into the buffer
 * @param length the number of bytes to include
 * @return the checksum value
 */
protected long computeChecksum(byte[] input, int offset, int length) {
    final Checksum sum=createChecksummer();
    sum.update(input, offset, length);
    return sum.getValue();
}
/**
 * Writes the checksum into a byte array of Global.LONG_SIZE bytes and encrypts that array with a pooled cipher.
 * @param checksum the checksum to encrypt
 * @return the encrypted checksum
 */
protected byte[] encryptChecksum(long checksum) throws Exception {
    final byte[] raw=new byte[Global.LONG_SIZE];
    Bits.writeLong(checksum, raw, 0);
    return code(raw, 0, raw.length, false);
}
/**
 * Decrypts an encrypted checksum and reads it back as a long.
 * @param cipher the cipher to use; if null, a cipher is borrowed from the decoding pool
 * @param input the buffer holding the encrypted checksum
 * @param offset the offset into the buffer
 * @param length the number of bytes to decrypt
 * @return the decrypted checksum
 */
protected long decryptChecksum(final Cipher cipher, byte[] input, int offset, int length) throws Exception {
    final byte[] plain=cipher != null? cipher.doFinal(input, offset, length) : code(input, offset, length, true);
    return Bits.readLong(plain, 0);
}
/**
 * Extracts the algorithm name from a transformation of the form "algorithm/mode/padding" (taken from the
 * original ENCRYPT).
 * @param s the transformation
 * @return the part before the first '/', or s itself if it contains no '/'
 */
protected static String getAlgorithm(String s) {
    final int slash=s.indexOf('/');
    if(slash < 0)
        return s;
    return s.substring(0, slash);
}
/**
 * Called when the version shipped in the header can't be found, i.e. it neither matches the current version nor
 * has a cipher registered in key_map. This implementation does nothing; the message is dropped by the caller.
 */
protected void handleUnknownVersion() {}
/**
 * Decrypts all messages in a batch, replacing encrypted messages in-place with their decrypted versions.
 * Messages without an encrypt header, messages rejected by process(), messages that cannot be decrypted and
 * control messages (header type other than ENCRYPT) are removed from the batch.
 */
protected class Decrypter implements BiConsumer<Message,MessageBatch> {
    /** The cipher passed to decryptMessage() for every message of the batch */
    protected final Cipher cipher;

    public Decrypter(Cipher cipher) {
        this.cipher=cipher;
    }

    /**
     * Processes one message of the batch.
     * @param msg the message to decrypt
     * @param batch the batch containing the message
     */
    public void accept(Message msg, MessageBatch batch) {
        EncryptHeader hdr;
        if((hdr=msg.getHeader(id)) == null) {
            log.error("%s: received message without encrypt header from %s; dropping it", local_addr, batch.sender());
            batch.remove(msg); // remove from batch to prevent passing the message further up as part of the batch
            return;
        }

        if(hdr.type() == EncryptHeader.ENCRYPT) {
            try {
                if(!process(msg)) {
                    batch.remove(msg);
                    return;
                }
                Message tmpMsg=decryptMessage(cipher, msg.copy()); // need to copy for possible xmits
                if(tmpMsg != null)
                    batch.replace(msg, tmpMsg);
                else
                    batch.remove(msg);
            }
            catch(Exception e) {
                // remove first: the undecrypted message must never be passed up, whatever happens while logging
                batch.remove(msg);
                // the raw buffer may be null (message without payload); don't fail with an NPE while logging
                byte[] raw=msg.getRawBuffer();
                log.error("%s: failed decrypting message from %s (offset=%d, length=%d, buf.length=%d): %s, headers are %s",
                          local_addr, msg.getSrc(), msg.getOffset(), msg.getLength(), raw != null? raw.length : 0, e, msg.printHeaders());
            }
        }
        else {
            batch.remove(msg); // a control message will get handled by ENCRYPT and should not be passed up
            handleUpEvent(msg, hdr);
        }
    }
}
}