/*
* Copyright 2014-2016 CyberVision, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.kaaproject.kaa.server.bootstrap.service.transport;
import org.kaaproject.kaa.common.endpoint.security.MessageEncoderDecoder;
import org.kaaproject.kaa.server.bootstrap.service.OperationsServerListService;
import org.kaaproject.kaa.server.bootstrap.service.security.KeyStoreService;
import org.kaaproject.kaa.server.sync.ClientSync;
import org.kaaproject.kaa.server.sync.ServerSync;
import org.kaaproject.kaa.server.sync.SyncStatus;
import org.kaaproject.kaa.server.sync.bootstrap.BootstrapClientSync;
import org.kaaproject.kaa.server.sync.bootstrap.BootstrapServerSync;
import org.kaaproject.kaa.server.sync.bootstrap.ProtocolConnectionData;
import org.kaaproject.kaa.server.sync.platform.PlatformEncDec;
import org.kaaproject.kaa.server.sync.platform.PlatformEncDecException;
import org.kaaproject.kaa.server.sync.platform.PlatformLookup;
import org.kaaproject.kaa.server.transport.AbstractTransportService;
import org.kaaproject.kaa.server.transport.TransportService;
import org.kaaproject.kaa.server.transport.channel.ChannelContext;
import org.kaaproject.kaa.server.transport.message.ErrorBuilder;
import org.kaaproject.kaa.server.transport.message.MessageBuilder;
import org.kaaproject.kaa.server.transport.message.MessageHandler;
import org.kaaproject.kaa.server.transport.message.SessionInitMessage;
import org.kaaproject.kaa.server.transport.session.SessionAware;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import java.security.GeneralSecurityException;
import java.security.KeyPair;
import java.security.PublicKey;
import java.text.MessageFormat;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
/**
 * Responsible for initialization and management of transport instances on the bootstrap server.
 *
 * <p>{@link #lookupAndInit()} creates a {@link BootstrapMessageHandler} that answers bootstrap sync
 * requests with the connection data returned by {@link OperationsServerListService}.
 *
 * @author Andrew Shvayka
 */
@Service
public class BootstrapTransportService
    extends AbstractTransportService
    implements TransportService {

  /**
   * Constant LOG.
   */
  private static final Logger LOG = LoggerFactory.getLogger(BootstrapTransportService.class);

  private static final int DEFAULT_THREAD_POOL_SIZE = 1;

  private static final String TRANSPORT_CONFIG_PREFIX = "bootstrap";

  // Size of the handler's fixed worker pool; starts at DEFAULT_THREAD_POOL_SIZE and is set from
  // the injected "worker_thread_pool" property.
  @Value("#{properties[worker_thread_pool]}")
  private int threadPoolSize = DEFAULT_THREAD_POOL_SIZE;

  // When false, session-init messages that are not encrypted are rejected by the handler.
  @Value("#{properties[support_unencrypted_connection]}")
  private boolean supportUnencryptedConnection;

  @Autowired
  private OperationsServerListService operationsServerListService;

  @Autowired
  private KeyStoreService bootstrapKeyStoreService;

  @Autowired
  private Properties properties;

  /** Created in {@link #lookupAndInit()}; remains {@code null} until that method runs. */
  private BootstrapMessageHandler handler;

  public BootstrapTransportService() {
    super();
  }

  @Override
  protected String getTransportConfigPrefix() {
    return TRANSPORT_CONFIG_PREFIX;
  }

  @Override
  protected Properties getServiceProperties() {
    return properties;
  }

  /**
   * Discovers the available platform protocols and builds the bootstrap message handler.
   *
   * <p>The handler is backed by a fixed-size worker pool of {@code threadPoolSize} threads and by
   * the bootstrap key pair. Parent initialization runs after the handler exists, so
   * {@link #getMessageHandler()} returns a non-null value during it.
   */
  @Override
  public void lookupAndInit() {
    LOG.info("Lookup platform protocols");
    Set<String> platformProtocols = PlatformLookup.lookupPlatformProtocols(
        PlatformLookup.DEFAULT_PROTOCOL_LOOKUP_PACKAGE_NAME);
    LOG.info("Initializing message handler with {} worker threads", threadPoolSize);
    handler = new BootstrapMessageHandler(
        operationsServerListService,
        Executors.newFixedThreadPool(threadPoolSize),
        platformProtocols,
        new KeyPair(
            bootstrapKeyStoreService.getPublicKey(),
            bootstrapKeyStoreService.getPrivateKey()),
        supportUnencryptedConnection);
    super.lookupAndInit();
  }

  @Override
  protected MessageHandler getMessageHandler() {
    return handler;
  }

  @Override
  protected PublicKey getPublicKey() {
    return bootstrapKeyStoreService.getPublicKey();
  }

  /**
   * Delegates to the parent {@code stop()} and then shuts down the handler's worker pool.
   *
   * <p>The worker pool is shut down even if the parent {@code stop()} throws. It is skipped when
   * {@link #lookupAndInit()} never created a handler, which avoids a NullPointerException.
   */
  @Override
  public void stop() {
    try {
      super.stop();
    } finally {
      if (handler != null) {
        handler.stop();
      }
    }
  }

  /**
   * Processes bootstrap session-init messages on a dedicated executor.
   *
   * <p>For each message the handler:
   * <ul>
   *   <li>decodes the client sync, decrypting it first when the message is encrypted;</li>
   *   <li>obtains the matching connection data from the operations server list service;</li>
   *   <li>writes the encoded response back to the message's channel, signed and encrypted when
   *       the request was encrypted.</li>
   * </ul>
   * Any failure is reported through the message's {@link ErrorBuilder}.
   */
  public static class BootstrapMessageHandler implements MessageHandler {

    // Per-thread caches: each worker thread lazily creates its own instances on first use and
    // reuses them for later messages processed on the same thread.
    private static final ThreadLocal<Map<Integer, PlatformEncDec>> PLATFORM_ENC_DEC_MAP =
        new ThreadLocal<>(); //NOSONAR
    private static final ThreadLocal<MessageEncoderDecoder> CRYPT = new ThreadLocal<>(); //NOSONAR

    private final ExecutorService executor;
    private final Set<String> platformProtocols;
    private final KeyPair keyPair;
    private final boolean supportUnencryptedConnection;
    private final OperationsServerListService opsListService;

    /**
     * Create new instance of <code>BootstrapMessageHandler</code>.
     *
     * @param opsListService               the ops list service used to resolve connection data
     * @param executor                     the executor that runs message processing
     * @param platformProtocols            the platform protocols used to build per-thread
     *                                     encoder/decoder maps
     * @param keyPair                      the key pair used to decrypt requests and sign/encrypt
     *                                     responses
     * @param supportUnencryptedConnection whether unencrypted session-init messages are accepted
     */
    public BootstrapMessageHandler(OperationsServerListService opsListService,
                                   ExecutorService executor,
                                   Set<String> platformProtocols,
                                   KeyPair keyPair,
                                   boolean supportUnencryptedConnection) {
      super();
      this.opsListService = opsListService;
      this.executor = executor;
      this.platformProtocols = platformProtocols;
      this.keyPair = keyPair;
      this.supportUnencryptedConnection = supportUnencryptedConnection;
    }

    @Override
    public void process(SessionAware message) {
      // Session messages are not processed
    }

    /**
     * Submits the session-init message to the executor for asynchronous processing.
     *
     * @param message the incoming bootstrap session-init message
     */
    @Override
    public void process(final SessionInitMessage message) {
      executor.execute(new Runnable() {
        @Override
        public void run() {
          MessageEncoderDecoder crypt = getOrInitCrypt();
          Map<Integer, PlatformEncDec> platformEncDecMap = getOrInitMap(platformProtocols);
          try {
            ClientSync request = decodeRequest(message, crypt, platformEncDecMap);
            LOG.trace("Processing request {}", request);
            BootstrapClientSync bsRequest = request.getBootstrapSync();
            Set<ProtocolConnectionData> transports = opsListService.filter(bsRequest.getKeys());
            BootstrapServerSync bsResponse = new BootstrapServerSync(
                bsRequest.getRequestId(), transports);
            ServerSync response = new ServerSync();
            response.setRequestId(request.getRequestId());
            response.setStatus(SyncStatus.SUCCESS);
            response.setBootstrapSync(bsResponse);
            LOG.trace("Response {}", response);
            encodeAndForward(message, crypt, platformEncDecMap, response);
            LOG.trace("Response forwarded to specific transport {}", response);
          } catch (Exception ex) {
            // Boundary of the worker task: every failure is turned into an error response.
            processErrors(message.getChannelContext(), message.getErrorBuilder(), ex);
          }
        }

        /** Encodes the response (signing and encrypting it if the request was encrypted) and writes it. */
        private void encodeAndForward(final SessionInitMessage message,
                                      MessageEncoderDecoder crypt,
                                      Map<Integer, PlatformEncDec> platformEncDecMap,
                                      ServerSync response)
            throws PlatformEncDecException, GeneralSecurityException {
          MessageBuilder converter = message.getMessageBuilder();
          byte[] responseData = encodePlatformLevelData(
              platformEncDecMap, message.getPlatformId(), response);
          Object[] objects;
          if (message.isEncrypted()) {
            // Sign the plain payload first, then encrypt it.
            byte[] responseSignature = crypt.sign(responseData);
            responseData = crypt.encodeData(responseData);
            LOG.trace("Response signature {}", responseSignature);
            LOG.trace("Response data {}", responseData);
            objects = converter.build(responseData, responseSignature, message.isEncrypted());
          } else {
            LOG.trace("Response data {}", responseData);
            objects = converter.build(responseData, message.isEncrypted());
          }
          ChannelContext context = message.getChannelContext();
          if (objects != null && objects.length > 0) {
            for (Object object : objects) {
              context.write(object);
            }
            context.flush();
          }
        }

        /**
         * Writes the error responses built for {@code ex}.
         *
         * <p>If the builder produces none, the exception is propagated via
         * {@code fireExceptionCaught}.
         */
        private void processErrors(ChannelContext ctx, ErrorBuilder converter, Exception ex) {
          LOG.trace("Message processing failed", ex);
          Object[] responses = converter.build(ex);
          if (responses != null && responses.length > 0) {
            for (Object response : responses) {
              ctx.writeAndFlush(response);
            }
          } else {
            ctx.fireExceptionCaught(ex);
          }
        }

        /**
         * Decodes the client sync from the message.
         *
         * @throws GeneralSecurityException if the message is unencrypted but unencrypted
         *                                  connections are disabled, or if decryption fails
         * @throws IllegalArgumentException if the decoded sync carries no bootstrap part
         */
        private ClientSync decodeRequest(SessionInitMessage message,
                                         MessageEncoderDecoder crypt,
                                         Map<Integer, PlatformEncDec> platformEncDecMap)
            throws GeneralSecurityException, PlatformEncDecException {
          ClientSync syncRequest = null;
          if (message.isEncrypted()) {
            syncRequest = decodeEncryptedRequest(message, crypt, platformEncDecMap);
          } else if (supportUnencryptedConnection) {
            syncRequest = decodeUnencryptedRequest(message, platformEncDecMap);
          } else {
            LOG.warn("Received unencrypted init message, but unencrypted connection "
                + "forbidden by configuration.");
            throw new GeneralSecurityException(
                "Unencrypted connection forbidden by configuration.");
          }
          if (syncRequest.getBootstrapSync() == null) {
            throw new IllegalArgumentException("Bootstrap sync message is missing");
          }
          return syncRequest;
        }

        /** Decrypts the payload with the session key, then decodes it at platform level. */
        private ClientSync decodeEncryptedRequest(SessionInitMessage message,
                                                  MessageEncoderDecoder crypt,
                                                  Map<Integer, PlatformEncDec> platformEncDecMap)
            throws GeneralSecurityException, PlatformEncDecException {
          byte[] requestRaw = crypt.decodeData(
              message.getEncodedMessageData(), message.getEncodedSessionKey());
          LOG.trace("Request data decrypted");
          ClientSync request = decodePlatformLevelData(
              platformEncDecMap, message.getPlatformId(), requestRaw);
          LOG.trace("Request data deserialized");
          return request;
        }

        /** Decodes the raw (unencrypted) payload at platform level. */
        private ClientSync decodeUnencryptedRequest(SessionInitMessage message,
                                                    Map<Integer, PlatformEncDec> platformEncDecMap)
            throws GeneralSecurityException, PlatformEncDecException {
          byte[] requestRaw = message.getEncodedMessageData();
          LOG.trace("Try to convert raw data to SynRequest object");
          ClientSync request = decodePlatformLevelData(
              platformEncDecMap, message.getPlatformId(), requestRaw);
          LOG.trace("Request data deserialized");
          return request;
        }

        /**
         * Encodes {@code sync} with the codec registered for {@code platformId}.
         *
         * @throws PlatformEncDecException if no codec is registered for the platform
         */
        private byte[] encodePlatformLevelData(Map<Integer, PlatformEncDec> platformEncDecMap,
                                               int platformId,
                                               ServerSync sync)
            throws PlatformEncDecException {
          PlatformEncDec encDec = platformEncDecMap.get(platformId);
          if (encDec == null) {
            throw new PlatformEncDecException(
                MessageFormat.format("Encoder for platform protocol [{0}] is not defined",
                    platformId));
          }
          return encDec.encode(sync);
        }

        /**
         * Decodes {@code requestRaw} with the codec registered for {@code platformId}.
         *
         * @throws PlatformEncDecException if no codec is registered for the platform
         */
        private ClientSync decodePlatformLevelData(Map<Integer, PlatformEncDec> platformEncDecMap,
                                                   Integer platformId,
                                                   byte[] requestRaw)
            throws PlatformEncDecException {
          PlatformEncDec encDec = platformEncDecMap.get(platformId);
          if (encDec == null) {
            throw new PlatformEncDecException(MessageFormat.format(
                "Decoder for platform protocol [{0}] is not defined", platformId));
          }
          return encDec.decode(requestRaw);
        }

        /** Returns this thread's encoder/decoder, creating it from the key pair on first use. */
        private MessageEncoderDecoder getOrInitCrypt() {
          MessageEncoderDecoder current = CRYPT.get();
          if (current == null) {
            current = new MessageEncoderDecoder(keyPair.getPrivate(), keyPair.getPublic());
            CRYPT.set(current);
          }
          return current;
        }

        /** Returns this thread's platform codec map, building it on first use. */
        private Map<Integer, PlatformEncDec> getOrInitMap(Set<String> platformProtocols) {
          Map<Integer, PlatformEncDec> current = PLATFORM_ENC_DEC_MAP.get();
          if (current == null) {
            current = PlatformLookup.initPlatformProtocolMap(platformProtocols);
            PLATFORM_ENC_DEC_MAP.set(current);
          }
          return current;
        }
      });
    }

    /**
     * Initiates an orderly shutdown of the worker executor.
     *
     * <p>Previously submitted tasks still execute. This call does not wait for them to finish.
     */
    public void stop() {
      executor.shutdown();
    }
  }
}