/**
* diqube: Distributed Query Base.
*
* Copyright (C) 2015 Bastian Gloeckle
*
* This file is part of diqube.
*
* diqube is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.diqube.connection.integrity;
import java.nio.ByteBuffer;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import org.apache.thrift.TException;
import org.apache.thrift.TProcessor;
import org.apache.thrift.protocol.TMessage;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.protocol.TProtocolDecorator;
import org.apache.thrift.protocol.TProtocolFactory;
import org.apache.thrift.transport.TTransport;
import org.diqube.thrift.util.RememberingTransport;
import org.diqube.util.BouncyCastleUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* A {@link TProtocol} which writes HMACs to all messages in order to validate them again when reading a message.
*
* <p>
* The transport of the protocol that this protocol is based on needs to be a {@link RememberingTransport}!
*
* <p>
* This protocol will throw a {@link IntegrityViolatedException} when trying to read a message whose integrity is
* violated.
*
* @author Bastian Gloeckle
*/
public class IntegrityCheckingProtocol extends TProtocolDecorator {
private static final Logger logger = LoggerFactory.getLogger(IntegrityCheckingProtocol.class);
private static ThreadLocal<Boolean> integrityCheckDisabled = new ThreadLocal<>();
private RememberingTransport transport;
private Mac[] mac;
private TProtocol protocol;
/**
* Create the integrity validating protocol.
*
* @param protocol
* The protocol this protocol should be based upon.
* @param macKeys
* Secret keys to use for the HMAC algorithm. There needs to be at least 1. The first key will be used to
* sign all messages the are written by this protocol. Whn messages are read, a signature of any of the given
* keys will count the message as "valid".
*/
public IntegrityCheckingProtocol(TProtocol protocol, byte[]... macKeys) {
super(protocol);
if (!(protocol.getTransport() instanceof RememberingTransport))
throw new IllegalArgumentException("The transport needs to be a " + RememberingTransport.class.getSimpleName());
transport = (RememberingTransport) protocol.getTransport();
this.protocol = protocol;
if (macKeys.length == 0)
throw new IllegalArgumentException("Need at least one macKey!");
try {
mac = new Mac[macKeys.length];
for (int i = 0; i < macKeys.length; i++) {
mac[i] = Mac.getInstance("HmacSHA256", BouncyCastleUtil.getProvider());
mac[i].init(new SecretKeySpec(macKeys[i], "HmacSHA256"));
}
} catch (NoSuchAlgorithmException | InvalidKeyException e) {
throw new IllegalStateException("Could not find HMAC algorithm implementation or could not initialize it.", e);
}
}
@Override
public void writeMessageBegin(TMessage tMessage) throws TException {
Boolean disabled = integrityCheckDisabled.get();
if (disabled == null || !disabled)
transport.startRemeberingWriteBytes();
super.writeMessageBegin(tMessage);
}
@Override
public void writeMessageEnd() throws TException {
super.writeMessageEnd();
Boolean disabled = integrityCheckDisabled.get();
if (disabled == null || !disabled) {
byte[] msgData = transport.stopRememberingWriteBytes();
// logger.trace("Calculating integrity for message {}", msgData);
byte[] integrityData = mac[0].doFinal(msgData);
protocol.writeBinary(ByteBuffer.wrap(integrityData));
}
}
@Override
public TMessage readMessageBegin() throws TException {
Boolean disabled = integrityCheckDisabled.get();
if (disabled == null || !disabled)
transport.startRemeberingReadBytes();
return super.readMessageBegin();
}
@Override
public void readMessageEnd() throws TException {
super.readMessageEnd();
Boolean disabled = integrityCheckDisabled.get();
if (disabled == null || !disabled) {
byte[] msgData = transport.stopRememberingReadBytes();
ByteBuffer integrityBuffer = protocol.readBinary();
byte[] integrityDataActual = new byte[integrityBuffer.remaining()];
integrityBuffer.get(integrityDataActual);
// logger.trace("Validating integrity of message: {}", msgData);
// logger.trace("Integrity data provided: {}", integrityDataActual);
for (int i = 0; i < mac.length; i++) {
byte[] integrityDataExpected = mac[i].doFinal(msgData);
// logger.trace("Calculated possible valid integrity data: {}", integrityDataExpected);
if (Arrays.equals(integrityDataActual, integrityDataExpected))
return;
}
logger.error("Received a message with violated integrity!");
throw new IntegrityViolatedException("Integrity of message violated.");
}
}
/**
* Message integrity violated.
*/
public static class IntegrityViolatedException extends TException {
private static final long serialVersionUID = 1L;
public IntegrityViolatedException(String msg) {
super(msg);
}
}
/**
* {@link TProtocolFactory} for {@link IntegrityCheckingProtocol}.
*/
public static class Factory implements TProtocolFactory {
private static final long serialVersionUID = 1L;
private TProtocolFactory delegateFactory;
private byte[][] macKeys;
/**
* Create the integrity validating protocol.
*
* @param delegateFactory
* The factory to be used to create the delegate protocol.
* @param macKeys
* Secret keys to use for the HMAC algorithm. There needs to be at least 1. The first key will be used to
* sign all messages the are written by this protocol. Whn messages are read, a signature of any of the
* given keys will count the message as "valid".
*/
public Factory(TProtocolFactory delegateFactory, byte[]... macKeys) {
this.delegateFactory = delegateFactory;
this.macKeys = macKeys;
}
@Override
public TProtocol getProtocol(TTransport trans) {
if (!(trans instanceof RememberingTransport))
throw new IllegalArgumentException("The transport needs to be a " + RememberingTransport.class.getSimpleName());
TProtocol delegateProtocol = delegateFactory.getProtocol(trans);
return new IntegrityCheckingProtocol(delegateProtocol, macKeys);
}
}
/**
* A {@link TProcessor} that disables integrity checks when reading &writing messages from a
* {@link IntegrityCheckingProtocol}.
*
* <p>
* This leads to the MAC not being calculated, written and read from the input.
*/
public static class IntegrityCheckDisablingProcessor implements TProcessor {
private TProcessor delegate;
public IntegrityCheckDisablingProcessor(TProcessor delegate) {
this.delegate = delegate;
}
@Override
public boolean process(TProtocol in, TProtocol out) throws TException {
integrityCheckDisabled.set(true);
try {
return delegate.process(in, out);
} finally {
integrityCheckDisabled.remove();
}
}
}
}