package im.actor.core.modules.encryption;

import java.util.ArrayList;

import im.actor.core.entity.encryption.PeerSession;
import im.actor.core.modules.ModuleContext;
import im.actor.core.modules.encryption.session.EncryptedSessionChain;
import im.actor.core.modules.ModuleActor;
import im.actor.runtime.*;
import im.actor.runtime.actors.ask.AskMessage;
import im.actor.runtime.actors.ask.AskResult;
import im.actor.runtime.function.Consumer;
import im.actor.runtime.function.Function;
import im.actor.runtime.promise.Promise;
import im.actor.runtime.crypto.Curve25519;
import im.actor.runtime.crypto.IntegrityException;
import im.actor.runtime.crypto.primitives.util.ByteStrings;

import static im.actor.runtime.promise.Promise.success;

/**
 * Axolotl Ratchet encryption session.
 * A session is identified by:
 * 1) Destination User's Id
 * 2) Own Key Group Id
 * 3) Own Pre Key Id
 * 4) Their Key Group Id
 * 5) Their Pre Key Id
 * <p/>
 * On start the actor obtains a reference to the Key Manager, which is later used to fetch
 * the required keys. To encrypt/decrypt messages this actor spawns encryption chains.
 * <p/>
 * The actor answers two ask-messages: {@link EncryptPackage} (replied with
 * {@link EncryptedPackageRes}) and {@link DecryptPackage} (replied with
 * {@link DecryptedPackage}).
 */
public class EncryptedSessionActor extends ModuleActor {

    // No need to keep too many decryption chains as all messages are sequenced. Newer messages
    // never intentionally use old keys, but there are cases when some messages can be sent with
    // old encryption keys right after messages with new ones. Even if we kill a chain,
    // a new one can easily be started again with the same keys.
    // TODO: Check if this can cause race condition
    private static final int MAX_DECRYPT_CHAINS = 2;

    //
    // Layout of the encrypted package header, as read by onDecrypt:
    //   [0..4)   own key group id        (currently not read)
    //   [4..12)  own ephemeral key 0 id  (currently not read)
    //   [12..20) their ephemeral key 0 id (currently not read)
    //   [20..52) sender ephemeral public key
    //   [52..84) receiver ephemeral public key
    //
    private static final int EPHEMERAL_KEY_SIZE = 32;
    private static final int SENDER_EPHEMERAL_OFFSET = 20;
    private static final int RECEIVER_EPHEMERAL_OFFSET = SENDER_EPHEMERAL_OFFSET + EPHEMERAL_KEY_SIZE;
    private static final int HEADER_SIZE = RECEIVER_EPHEMERAL_OFFSET + EPHEMERAL_KEY_SIZE;

    private final String TAG;

    //
    // Key References
    //

    private final int uid;
    private final PeerSession session;

    //
    // Key Manager reference
    //

    // Assigned in preStart(), so it is not available before the actor is started.
    private KeyManagerInt keyManager;

    //
    // Temp encryption chains
    //

    // Sender ephemeral key of the last successfully decrypted message, or the key picked for
    // the first encryption; null until one of the two has happened.
    private byte[] latestTheirEphemeralKey;
    // Newest chain is kept at index 0.
    private final ArrayList<EncryptedSessionChain> encryptionChains = new ArrayList<>();
    private final ArrayList<EncryptedSessionChain> decryptionChains = new ArrayList<>();

    //
    // Constructors and Methods
    //

    /**
     * Creates a session actor for the given peer session.
     *
     * @param context module context passed to {@link ModuleActor}
     * @param session peer session whose uid and key group ids identify this actor
     */
    public EncryptedSessionActor(ModuleContext context, PeerSession session) {
        super(context);
        this.TAG = "EncryptionSessionActor#" + session.getUid() + "_" + session.getTheirKeyGroupId();
        this.uid = session.getUid();
        this.session = session;
    }

    @Override
    public void preStart() {
        super.preStart();
        keyManager = context().getEncryption().getKeyManagerInt();
    }

    /**
     * Encrypts {@code data} for the peer of this session.
     *
     * @param data plain data to encrypt
     * @return promise of the encrypted package together with the peer's key group id
     */
    private Promise<EncryptedPackageRes> onEncrypt(final byte[] data) {

        //
        // Stage 1: Pick Their Ephemeral key. Use already received or pick random pre key.
        // Stage 2: Pick Encryption Chain
        // Stage 3: Encrypt
        //

        return success(latestTheirEphemeralKey)
                .mapIfNullPromise(keyManager.supplyUserPreKey(uid, session.getTheirKeyGroupId()))
                .map(new Function<byte[], EncryptedSessionChain>() {
                    @Override
                    public EncryptedSessionChain apply(byte[] publicKey) {
                        return pickEncryptChain(publicKey);
                    }
                })
                .map(new Function<EncryptedSessionChain, EncryptedPackageRes>() {
                    @Override
                    public EncryptedPackageRes apply(EncryptedSessionChain encryptedSessionChain) {
                        return encrypt(encryptedSessionChain, data);
                    }
                });
    }

    /**
     * Decrypts an encrypted package received from the peer of this session.
     *
     * @param data full encrypted package, starting with the {@link #HEADER_SIZE}-byte header
     * @return promise of the decrypted package
     * @throws IllegalArgumentException if {@code data} is null or shorter than the header
     */
    private Promise<DecryptedPackage> onDecrypt(final byte[] data) {

        //
        // Stage 1: Parsing message header
        // Stage 2: Picking decryption chain
        // Stage 3: Decryption of message
        // Stage 4: Saving their ephemeral key
        //

        // Fail with a descriptive error instead of an array-bounds failure from the
        // header slicing below.
        if (data == null || data.length < HEADER_SIZE) {
            throw new IllegalArgumentException("Encrypted package is too short: expected at least "
                    + HEADER_SIZE + " bytes, got " + (data == null ? "null" : String.valueOf(data.length)));
        }

        // final int ownKeyGroupId = ByteStrings.bytesToInt(data, 0);
        // final long ownEphemeralKey0Id = ByteStrings.bytesToLong(data, 4);
        // final long theirEphemeralKey0Id = ByteStrings.bytesToLong(data, 12);
        final byte[] senderEphemeralKey =
                ByteStrings.substring(data, SENDER_EPHEMERAL_OFFSET, EPHEMERAL_KEY_SIZE);
        final byte[] receiverEphemeralKey =
                ByteStrings.substring(data, RECEIVER_EPHEMERAL_OFFSET, EPHEMERAL_KEY_SIZE);

        Log.d(TAG, "Sender Ephemeral " + Crypto.keyHash(senderEphemeralKey));
        Log.d(TAG, "Receiver Ephemeral " + Crypto.keyHash(receiverEphemeralKey));

        return pickDecryptChain(senderEphemeralKey, receiverEphemeralKey)
                .map(new Function<EncryptedSessionChain, DecryptedPackage>() {
                    @Override
                    public DecryptedPackage apply(EncryptedSessionChain encryptedSessionChain) {
                        return decrypt(encryptedSessionChain, data);
                    }
                })
                .then(new Consumer<DecryptedPackage>() {
                    @Override
                    public void apply(DecryptedPackage decryptedPackage) {
                        Log.d(TAG, "onDecrypted");
                        // Remember the sender's key only after a successful decryption so that
                        // the next outgoing message can use it.
                        latestTheirEphemeralKey = senderEphemeralKey;
                    }
                })
                .failure(new Consumer<Exception>() {
                    @Override
                    public void apply(Exception e) {
                        Log.d(TAG, "onError");
                    }
                });
    }

    /**
     * Returns the current encryption chain, creating one with a fresh random own private key
     * when none exists yet.
     *
     * @param ephemeralKey their ephemeral public key; stored as the latest known key if none
     *                     is known yet and used to build a new chain when needed
     * @return the chain at the head of {@code encryptionChains}
     */
    private EncryptedSessionChain pickEncryptChain(byte[] ephemeralKey) {

        if (latestTheirEphemeralKey == null) {
            latestTheirEphemeralKey = ephemeralKey;
        }

        if (encryptionChains.size() > 0) {
            return encryptionChains.get(0);
        }

        EncryptedSessionChain chain = new EncryptedSessionChain(session,
                Curve25519.keyGenPrivate(Crypto.randomBytes(32)), ephemeralKey);
        encryptionChains.add(0, chain);
        return chain;
    }

    /**
     * Encrypts {@code data} with the given chain.
     *
     * @param chain chain to encrypt with
     * @param data  plain data
     * @return encrypted bytes paired with the peer's key group id
     * @throws RuntimeException wrapping the {@link IntegrityException} raised by the chain
     */
    private EncryptedPackageRes encrypt(EncryptedSessionChain chain, byte[] data) {

        byte[] encrypted;
        try {
            encrypted = chain.encrypt(data);
        } catch (IntegrityException e) {
            e.printStackTrace();
            throw new RuntimeException(e);
        }

        Log.d(TAG, "!Sender Ephemeral " + Crypto.keyHash(Curve25519.keyGenPublic(chain.getOwnPrivateKey())));
        Log.d(TAG, "!Receiver Ephemeral " + Crypto.keyHash(chain.getTheirPublicKey()));

        return new EncryptedPackageRes(encrypted, session.getTheirKeyGroupId());
    }

    /**
     * Looks up the decryption chain whose own public key equals {@code ephemeralKey}.
     *
     * @param theirEphemeralKey sender's ephemeral public key; currently unused, reserved for
     *                          creating a new chain (see TODO below)
     * @param ephemeralKey      receiver (own) ephemeral public key taken from the message header
     * @return promise of the matching chain; when no cached chain matches, the lookup fails
     *         with an {@link IllegalStateException} raised inside the promise chain
     */
    private Promise<EncryptedSessionChain> pickDecryptChain(final byte[] theirEphemeralKey,
                                                            final byte[] ephemeralKey) {
        EncryptedSessionChain pickedChain = null;
        for (EncryptedSessionChain c : decryptionChains) {
            if (ByteStrings.isEquals(Curve25519.keyGenPublic(c.getOwnPrivateKey()), ephemeralKey)) {
                pickedChain = c;
                break;
            }
        }

        return success(pickedChain)
                .flatMap(new Function<EncryptedSessionChain, Promise<EncryptedSessionChain>>() {
                    @Override
                    public Promise<EncryptedSessionChain> apply(EncryptedSessionChain src) {
                        if (src != null) {
                            return success(src);
                        }

                        // TODO: Implement!
                        // Never return null here: a flatMap function must yield a Promise.
                        // Until the pre-key lookup below is implemented, report the missing
                        // chain explicitly, the same way encrypt()/decrypt() report failures
                        // from inside the promise chain.
                        // NOTE(review): presumably the promise turns this into a failure of
                        // the resulting promise - confirm against the Promise implementation.
                        throw new IllegalStateException(
                                "No decryption chain for receiver ephemeral key "
                                        + Crypto.keyHash(ephemeralKey));

                        // return ask(context().getEncryption().getKeyManager(), new FetchOwnPreKeyByPublic(ephemeralKey))
                        //         .map(new Function<PrivateKey, EncryptedSessionChain>() {
                        //             @Override
                        //             public EncryptedSessionChain apply(PrivateKey src) {
                        //                 EncryptedSessionChain chain = new EncryptedSessionChain(session, src.getKey(), theirEphemeralKey);
                        //                 decryptionChains.add(0, chain);
                        //                 if (decryptionChains.size() > MAX_DECRYPT_CHAINS) {
                        //                     decryptionChains.remove(MAX_DECRYPT_CHAINS)
                        //                             .safeErase();
                        //                 }
                        //                 return chain;
                        //             }
                        //         });
                    }
                });
    }

    /**
     * Decrypts {@code data} with the given chain.
     *
     * @param chain chain to decrypt with
     * @param data  full encrypted package
     * @return decrypted package
     * @throws RuntimeException wrapping the {@link IntegrityException} raised by the chain
     */
    private DecryptedPackage decrypt(EncryptedSessionChain chain, byte[] data) {
        byte[] decrypted;
        try {
            decrypted = chain.decrypt(data);
        } catch (IntegrityException e) {
            e.printStackTrace();
            throw new RuntimeException(e);
        }
        return new DecryptedPackage(decrypted);
    }

    //
    // Actor Messages
    //

    /**
     * Dispatches {@link EncryptPackage} and {@link DecryptPackage} requests; everything else
     * is delegated to the superclass.
     */
    @Override
    public Promise onAsk(Object message) throws Exception {
        if (message instanceof EncryptPackage) {
            return onEncrypt(((EncryptPackage) message).getData());
        } else if (message instanceof DecryptPackage) {
            DecryptPackage decryptPackage = (DecryptPackage) message;
            return onDecrypt(decryptPackage.getData());
        } else {
            return super.onAsk(message);
        }
    }

    /**
     * Request to encrypt plain data.
     */
    public static class EncryptPackage implements AskMessage<EncryptedPackageRes> {
        private byte[] data;

        public EncryptPackage(byte[] data) {
            this.data = data;
        }

        public byte[] getData() {
            return data;
        }
    }

    /**
     * Result of encryption: encrypted bytes and the key group id they were encrypted for.
     */
    public static class EncryptedPackageRes extends AskResult {

        private byte[] data;
        private int keyGroupId;

        public EncryptedPackageRes(byte[] data, int keyGroupId) {
            this.data = data;
            this.keyGroupId = keyGroupId;
        }

        public byte[] getData() {
            return data;
        }

        public int getKeyGroupId() {
            return keyGroupId;
        }
    }

    /**
     * Request to decrypt an encrypted package.
     */
    public static class DecryptPackage implements AskMessage<DecryptedPackage> {

        private byte[] data;

        public DecryptPackage(byte[] data) {
            this.data = data;
        }

        public byte[] getData() {
            return data;
        }
    }

    /**
     * Result of decryption: the decrypted bytes.
     */
    public static class DecryptedPackage extends AskResult {

        private byte[] data;

        public DecryptedPackage(byte[] data) {
            this.data = data;
        }

        public byte[] getData() {
            return data;
        }
    }
}