/*
* Copyright 2015 Kevin Herron
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.digitalpetri.opcua.stack.server.handlers;
import java.io.IOException;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import com.digitalpetri.opcua.stack.core.StatusCodes;
import com.digitalpetri.opcua.stack.core.UaException;
import com.digitalpetri.opcua.stack.core.application.services.ServiceRequest;
import com.digitalpetri.opcua.stack.core.application.services.ServiceResponse;
import com.digitalpetri.opcua.stack.core.channel.ChannelSecurity;
import com.digitalpetri.opcua.stack.core.channel.ExceptionHandler;
import com.digitalpetri.opcua.stack.core.channel.SerializationQueue;
import com.digitalpetri.opcua.stack.core.channel.ServerSecureChannel;
import com.digitalpetri.opcua.stack.core.channel.headers.HeaderDecoder;
import com.digitalpetri.opcua.stack.core.channel.headers.SymmetricSecurityHeader;
import com.digitalpetri.opcua.stack.core.channel.messages.ErrorMessage;
import com.digitalpetri.opcua.stack.core.channel.messages.MessageType;
import com.digitalpetri.opcua.stack.core.serialization.UaRequestMessage;
import com.digitalpetri.opcua.stack.core.serialization.UaResponseMessage;
import com.digitalpetri.opcua.stack.core.util.BufferUtil;
import com.digitalpetri.opcua.stack.server.tcp.UaTcpStackServer;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageCodec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Netty codec handling symmetric secure-channel traffic on the server side.
 * <p>
 * Inbound, it frames the cumulated byte stream into whole messages using the message header length.
 * {@link MessageType#SecureMessage} chunks are collected until the final ('F') chunk arrives and are then handed to
 * the {@link SerializationQueue} for decoding and dispatch to the {@link UaTcpStackServer}. Any other message type is
 * passed to the next inbound handler as a retained slice.
 * <p>
 * Outbound, {@link ServiceResponse}s are encoded into chunks on the serialization queue and the chunks are written
 * from a task submitted to the channel's executor.
 */
public class UaTcpServerSymmetricHandler extends ByteToMessageCodec<ServiceResponse> implements HeaderDecoder {

    private final Logger logger = LoggerFactory.getLogger(getClass());

    /**
     * Retained chunks of the secure message currently being received. When the final chunk arrives the whole list
     * is handed to the serialization queue and replaced with a fresh one, so the queue never sees this instance.
     */
    private List<ByteBuf> chunkBuffers;

    /** Local limit on the number of chunks per message. */
    private final int maxChunkCount;
    /** Local receive buffer size; used as the upper bound on a single chunk's size. */
    private final int maxChunkSize;

    private final UaTcpStackServer server;
    private final SerializationQueue serializationQueue;
    private final ServerSecureChannel secureChannel;

    public UaTcpServerSymmetricHandler(UaTcpStackServer server,
                                       SerializationQueue serializationQueue,
                                       ServerSecureChannel secureChannel) {
        this.server = server;
        this.serializationQueue = serializationQueue;
        this.secureChannel = secureChannel;

        maxChunkCount = serializationQueue.getParameters().getLocalMaxChunkCount();
        maxChunkSize = serializationQueue.getParameters().getLocalReceiveBufferSize();

        chunkBuffers = new ArrayList<>(maxChunkCount);
    }

    /**
     * Binds the secure channel to the Netty channel this handler was added to.
     * <p>
     * Delegates to {@code super.handlerAdded} (not {@code channelActive}). That way the codec's own lifecycle hook
     * runs and no extra channelActive event is fired down the pipeline.
     */
    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
        if (secureChannel != null) {
            secureChannel.attr(UaTcpStackServer.BoundChannelKey).set(ctx.channel());
        }

        super.handlerAdded(ctx);
    }

    /**
     * Unbinds the secure channel from the Netty channel and releases the chunks of any partially received message.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        if (secureChannel != null) {
            secureChannel.attr(UaTcpStackServer.BoundChannelKey).remove();
        }

        try {
            super.channelInactive(ctx);
        } finally {
            // Release after super: the decoder may run a last decode pass over buffered bytes. Any chunks still
            // pending belong to a message whose final chunk can no longer arrive; they were retained in
            // onSecureMessage and would otherwise leak.
            releaseChunkBuffers();
        }
    }

    /**
     * Encodes a {@link ServiceResponse} into symmetric chunks on the serialization queue.
     * <p>
     * Nothing is written to {@code out}. Instead, the resulting chunks are written and flushed from a task
     * submitted to the channel's executor. If encoding fails with a {@link UaException}, the channel is closed.
     */
    @Override
    protected void encode(ChannelHandlerContext ctx, ServiceResponse message, ByteBuf out) throws Exception {
        serializationQueue.encode((binaryEncoder, chunkEncoder) -> {
            ByteBuf messageBuffer = BufferUtil.buffer();

            try {
                binaryEncoder.setBuffer(messageBuffer);
                binaryEncoder.encodeMessage(null, message.getResponse());

                final List<ByteBuf> chunks = chunkEncoder.encodeSymmetric(
                        secureChannel,
                        MessageType.SecureMessage,
                        messageBuffer,
                        message.getRequestId()
                );

                // Write and flush from the channel's executor rather than from whatever thread runs the
                // serialization queue.
                ctx.executor().execute(() -> {
                    chunks.forEach(c -> ctx.write(c, ctx.voidPromise()));
                    ctx.flush();
                });
            } catch (UaException e) {
                logger.error("Error encoding {}: {}", message.getResponse().getClass(), e.getMessage(), e);
                ctx.close();
            } finally {
                // The intermediate message buffer is released whether or not encoding succeeded.
                messageBuffer.release();
            }
        });
    }

    /**
     * Consumes every complete message currently in {@code buffer}.
     * <p>
     * Secure messages are processed by {@link #onSecureMessage}. All other message types are emitted unchanged as
     * retained slices for the next inbound handler. A trailing partial message is left in the buffer until more
     * bytes arrive.
     */
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List<Object> out) throws Exception {
        buffer = buffer.order(ByteOrder.LITTLE_ENDIAN);

        while (buffer.readableBytes() >= HEADER_LENGTH &&
                buffer.readableBytes() >= getMessageLength(buffer)) {

            int messageLength = getMessageLength(buffer);
            MessageType messageType = MessageType.fromMediumInt(buffer.getMedium(buffer.readerIndex()));

            switch (messageType) {
                case SecureMessage:
                    onSecureMessage(ctx, buffer.readSlice(messageLength), out);
                    break;

                default:
                    out.add(buffer.readSlice(messageLength).retain());
            }
        }
    }

    /**
     * Processes a single secure message chunk.
     * <p>
     * An 'A' (abort) chunk discards the chunks collected so far. Any other chunk is checked against the secure
     * channel id and the size/count limits, then retained. An 'F' (final) chunk triggers decoding of the collected
     * chunks on the serialization queue and dispatch of the resulting request to the server's executor.
     *
     * @param ctx    the channel context; closed if decoding on the serialization queue fails.
     * @param buffer a slice containing exactly one chunk, positioned at its start.
     * @param out    unused; secure messages are dispatched directly rather than emitted.
     * @throws UaException if the secure channel id is wrong or the chunk size/count limits are exceeded.
     */
    private void onSecureMessage(ChannelHandlerContext ctx, ByteBuf buffer, List<Object> out) throws UaException {
        buffer.skipBytes(3); // Skip messageType

        char chunkType = (char) buffer.readByte();

        if (chunkType == 'A') {
            releaseChunkBuffers();
            return;
        }

        buffer.skipBytes(4); // Skip messageSize

        long secureChannelId = buffer.readUnsignedInt();
        if (secureChannelId != secureChannel.getChannelId()) {
            throw new UaException(StatusCodes.Bad_SecureChannelIdInvalid,
                    "invalid secure channel id: " + secureChannelId);
        }

        // Rewind to the start of the chunk: the size check covers the whole chunk, and later decoding expects to
        // read the chunk from its header.
        int chunkSize = buffer.readerIndex(0).readableBytes();
        if (chunkSize > maxChunkSize) {
            throw new UaException(StatusCodes.Bad_TcpMessageTooLarge,
                    String.format("max chunk size exceeded (%s)", maxChunkSize));
        }

        // The slice shares memory with the cumulation buffer, so keep a reference until it has been decoded.
        chunkBuffers.add(buffer.retain());

        if (chunkBuffers.size() > maxChunkCount) {
            throw new UaException(StatusCodes.Bad_TcpMessageTooLarge,
                    String.format("max chunk count exceeded (%s)", maxChunkCount));
        }

        if (chunkType == 'F') {
            // Detach the completed list so the serialization queue works on its own copy while new chunks go into
            // a fresh list.
            final List<ByteBuf> buffersToDecode = chunkBuffers;
            chunkBuffers = new ArrayList<>(maxChunkCount);

            serializationQueue.decode((binaryDecoder, chunkDecoder) -> {
                try {
                    validateChunkHeaders(buffersToDecode);
                } catch (UaException e) {
                    // The chunks have not reached the chunk decoder yet, so the references retained above are
                    // still ours. Nothing else can release them, because this list is no longer the field.
                    buffersToDecode.forEach(ByteBuf::release);
                    buffersToDecode.clear();

                    logger.error("Error decoding symmetric message: {}", e.getMessage(), e);
                    ctx.close();
                    return;
                }

                ByteBuf messageBuffer = null;

                try {
                    messageBuffer = chunkDecoder.decodeSymmetric(secureChannel, buffersToDecode);

                    binaryDecoder.setBuffer(messageBuffer);
                    UaRequestMessage request = binaryDecoder.decodeMessage(null);

                    ServiceRequest<UaRequestMessage, UaResponseMessage> serviceRequest = new ServiceRequest<>(
                            request,
                            chunkDecoder.getLastRequestId(),
                            server,
                            secureChannel
                    );

                    server.getExecutorService().execute(() -> server.receiveRequest(serviceRequest));

                    buffersToDecode.clear();
                } catch (UaException e) {
                    logger.error("Error decoding symmetric message: {}", e.getMessage(), e);
                    ctx.close();
                } finally {
                    // Release the assembled message even when decoding the request itself failed.
                    if (messageBuffer != null) {
                        messageBuffer.release();
                    }
                }
            });
        }
    }

    /**
     * Verifies that every chunk carries either the current or the previous security token id of the channel.
     * <p>
     * Each chunk's reader index is reset to 0 after its header has been checked.
     *
     * @param chunkBuffers the chunks of one message, each positioned at its start.
     * @throws UaException with {@code Bad_SecureChannelTokenUnknown} if a chunk uses any other token id.
     */
    private void validateChunkHeaders(List<ByteBuf> chunkBuffers) throws UaException {
        ChannelSecurity channelSecurity = secureChannel.getChannelSecurity();
        long currentTokenId = channelSecurity.getCurrentToken().getTokenId().longValue();
        // -1 when there is no previous token; presumably never a valid token id.
        long previousTokenId = channelSecurity.getPreviousToken()
                .map(t -> t.getTokenId().longValue())
                .orElse(-1L);

        for (ByteBuf chunkBuffer : chunkBuffers) {
            chunkBuffer.skipBytes(3 + 1 + 4 + 4); // skip messageType, chunkType, messageSize, secureChannelId

            SymmetricSecurityHeader securityHeader = SymmetricSecurityHeader.decode(chunkBuffer);

            if (securityHeader.getTokenId() != currentTokenId) {
                if (securityHeader.getTokenId() != previousTokenId) {
                    String message = String.format(
                            "received unknown secure channel token. " +
                                    "tokenId=%s, currentTokenId=%s, previousTokenId=%s",
                            securityHeader.getTokenId(), currentTokenId, previousTokenId);

                    throw new UaException(StatusCodes.Bad_SecureChannelTokenUnknown, message);
                }
            }

            chunkBuffer.readerIndex(0);
        }
    }

    /**
     * Releases any partially received message and reports the failure.
     * <p>
     * An {@link IOException} just closes the channel. Any other cause is reported to the peer through
     * {@link ExceptionHandler#sendErrorMessage} and logged.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        releaseChunkBuffers();

        if (cause instanceof IOException) {
            logger.debug("[remote={}] IOException caught; channel closed", ctx.channel().remoteAddress());
            ctx.close();
        } else {
            ErrorMessage errorMessage = ExceptionHandler.sendErrorMessage(ctx, cause);

            if (cause instanceof UaException) {
                logger.debug("[remote={}] UaException caught; sent {}",
                        ctx.channel().remoteAddress(), errorMessage, cause);
            } else {
                logger.error("[remote={}] Exception caught; sent {}",
                        ctx.channel().remoteAddress(), errorMessage, cause);
            }
        }
    }

    /** Releases every retained chunk of the in-progress message and empties the list. */
    private void releaseChunkBuffers() {
        chunkBuffers.forEach(ByteBuf::release);
        chunkBuffers.clear();
    }

}