/* * Copyright 2015 Kevin Herron * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.digitalpetri.opcua.stack.client.handlers; import java.nio.ByteOrder; import java.security.KeyPair; import java.security.cert.X509Certificate; import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.TimeUnit; import com.digitalpetri.opcua.stack.client.UaTcpStackClient; import com.digitalpetri.opcua.stack.client.config.UaTcpStackClientConfig; import com.digitalpetri.opcua.stack.core.StatusCodes; import com.digitalpetri.opcua.stack.core.UaException; import com.digitalpetri.opcua.stack.core.channel.ChannelConfig; import com.digitalpetri.opcua.stack.core.channel.ChannelParameters; import com.digitalpetri.opcua.stack.core.channel.ClientSecureChannel; import com.digitalpetri.opcua.stack.core.channel.SerializationQueue; import com.digitalpetri.opcua.stack.core.channel.headers.HeaderDecoder; import com.digitalpetri.opcua.stack.core.channel.messages.AcknowledgeMessage; import com.digitalpetri.opcua.stack.core.channel.messages.ErrorMessage; import com.digitalpetri.opcua.stack.core.channel.messages.HelloMessage; import com.digitalpetri.opcua.stack.core.channel.messages.MessageType; import com.digitalpetri.opcua.stack.core.channel.messages.TcpMessageDecoder; import com.digitalpetri.opcua.stack.core.channel.messages.TcpMessageEncoder; import 
com.digitalpetri.opcua.stack.core.security.SecurityPolicy;
import com.digitalpetri.opcua.stack.core.types.builtin.StatusCode;
import com.digitalpetri.opcua.stack.core.types.enumerated.MessageSecurityMode;
import com.digitalpetri.opcua.stack.core.types.structured.EndpointDescription;
import com.digitalpetri.opcua.stack.core.util.CertificateUtil;
import com.google.common.primitives.Ints;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageCodec;
import io.netty.util.AttributeKey;
import io.netty.util.Timeout;
import org.jooq.lambda.tuple.Tuple1;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Client-side pipeline handler that performs the UA TCP Hello/Acknowledge exchange.
 * <p>
 * When the channel becomes active a Hello message is written and a hello timeout is started.
 * <ul>
 * <li>On Acknowledge, the negotiated {@link ChannelParameters} are computed. Requests written
 * while the handshake was in progress are published on the channel under
 * {@link #KEY_AWAITING_HANDSHAKE}, and a {@link UaTcpClientMessageHandler} is appended to the
 * pipeline.</li>
 * <li>On Error, the handshake future is completed exceptionally and the channel is closed.</li>
 * <li>Any other complete message is passed downstream.</li>
 * </ul>
 * Outbound {@link UaRequestFuture}s are not encoded by this handler. They are queued in
 * {@code awaitingHandshake} and nothing is written to the wire.
 */
public class UaTcpClientAcknowledgeHandler extends ByteToMessageCodec<UaRequestFuture> implements HeaderDecoder {

    /** Channel attribute under which requests queued during the handshake are exposed. */
    public static final AttributeKey<List<UaRequestFuture>> KEY_AWAITING_HANDSHAKE =
            AttributeKey.valueOf("awaiting-handshake");

    /** How long to wait for the server's Acknowledge after sending Hello. */
    private static final long HELLO_TIMEOUT_SECONDS = 5;

    private final Logger logger = LoggerFactory.getLogger(getClass());

    private final List<UaRequestFuture> awaitingHandshake = new CopyOnWriteArrayList<>();

    private volatile Timeout helloTimeout;

    private final ClientSecureChannel secureChannel;

    private final UaTcpStackClient client;
    private final CompletableFuture<ClientSecureChannel> handshakeFuture;

    /**
     * @param client          the owning client; supplies configuration and channel limits.
     * @param existingChannel a secure channel to reuse, if any.
     * @param handshakeFuture completed exceptionally by this handler on timeout or Error, and
     *                        handed to the {@link UaTcpClientMessageHandler} added after Acknowledge.
     */
    public UaTcpClientAcknowledgeHandler(UaTcpStackClient client,
                                         Optional<ClientSecureChannel> existingChannel,
                                         CompletableFuture<ClientSecureChannel> handshakeFuture) {
        this.client = client;
        this.handshakeFuture = handshakeFuture;

        UaTcpStackClientConfig config = client.getConfig();

        if (existingChannel.isPresent()) {
            secureChannel = existingChannel.get();
        } else {
            // Build a secured channel only when the endpoint names a non-None policy AND the config
            // has both a key pair and a certificate. Otherwise fall through to a None/None channel.
            // NOTE(review): a missing key pair/certificate also falls back to None without logging.
            secureChannel = config.getEndpoint()
                    .flatMap(e -> {
                        SecurityPolicy securityPolicy = SecurityPolicy
                                .fromUriSafe(e.getSecurityPolicyUri())
                                .orElse(SecurityPolicy.None);

                        if (securityPolicy == SecurityPolicy.None) {
                            return Optional.empty();
                        } else {
                            return Optional.of(new Tuple1<>(e));
                        }
                    })
                    .flatMap(t1 -> config.getKeyPair().map(t1::concat))
                    .flatMap(t2 -> config.getCertificate().map(t2::concat))
                    .flatMap(t3 -> {
                        EndpointDescription endpoint = t3.v1();
                        KeyPair keyPair = t3.v2();
                        X509Certificate localCertificate = t3.v3();

                        try {
                            X509Certificate remoteCertificate = CertificateUtil
                                    .decodeCertificate(endpoint.getServerCertificate().bytes());

                            List<X509Certificate> remoteCertificateChain = CertificateUtil
                                    .decodeCertificates(endpoint.getServerCertificate().bytes());

                            SecurityPolicy securityPolicy = SecurityPolicy.fromUri(endpoint.getSecurityPolicyUri());

                            ClientSecureChannel channel = new ClientSecureChannel(
                                    keyPair,
                                    localCertificate,
                                    remoteCertificate,
                                    remoteCertificateChain,
                                    securityPolicy,
                                    endpoint.getSecurityMode()
                            );

                            return Optional.of(channel);
                        } catch (Throwable t) {
                            // Previously swallowed silently. The fallback below is an unsecured
                            // channel, so make the downgrade visible.
                            logger.warn("Failed to create secure channel for endpoint, " +
                                    "falling back to SecurityPolicy.None: {}", t.getMessage(), t);
                            return Optional.empty();
                        }
                    })
                    .orElseGet(() -> new ClientSecureChannel(SecurityPolicy.None, MessageSecurityMode.None));
        }
    }

    /**
     * Starts the hello timeout, binds the secure channel to the Netty channel and writes a Hello
     * message built from the client's {@link ChannelConfig}.
     */
    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        helloTimeout = startHelloTimeout(ctx);

        secureChannel.setChannel(ctx.channel());

        ChannelConfig channelConfig = client.getChannelConfig();

        HelloMessage hello = new HelloMessage(
                PROTOCOL_VERSION,
                channelConfig.getMaxChunkSize(),
                channelConfig.getMaxChunkSize(),
                channelConfig.getMaxMessageSize(),
                channelConfig.getMaxChunkCount(),
                client.getEndpointUrl());

        ByteBuf messageBuffer = TcpMessageEncoder.encode(hello);

        ctx.writeAndFlush(messageBuffer);

        logger.debug("Sent Hello message on channel={}.", ctx.channel());

        super.channelActive(ctx);
    }

    /**
     * Schedules a task on the client's wheel timer. If no Acknowledge cancels it first, the task
     * fails the handshake with Bad_Timeout and closes the channel.
     */
    private Timeout startHelloTimeout(ChannelHandlerContext ctx) {
        return client.getConfig().getWheelTimer().newTimeout(
                timeout -> {
                    if (!timeout.isCancelled()) {
                        handshakeFuture.completeExceptionally(
                                new UaException(StatusCodes.Bad_Timeout,
                                        "timed out waiting for acknowledge"));
                        ctx.close();
                    }
                },
                HELLO_TIMEOUT_SECONDS, TimeUnit.SECONDS);
    }

    /** Queues outbound requests until the handshake completes; nothing is written to {@code out}. */
    @Override
    protected void encode(ChannelHandlerContext ctx, UaRequestFuture message, ByteBuf out) throws Exception {
        awaitingHandshake.add(message);
    }

    /**
     * Splits the inbound stream into complete messages using the length in each header.
     * <ul>
     * <li>Acknowledge and Error messages are handled here.</li>
     * <li>Any other message is retained and passed downstream.</li>
     * <li>A message whose declared length is smaller than the header fails the handshake and
     * closes the channel.</li>
     * </ul>
     */
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List<Object> out) throws Exception {
        buffer = buffer.order(ByteOrder.LITTLE_ENDIAN);

        while (buffer.readableBytes() >= HEADER_LENGTH) {
            int messageLength = getMessageLength(buffer);

            if (messageLength < HEADER_LENGTH) {
                // The length comes from the remote peer. A value of 0 would make readSlice(0)
                // consume nothing and spin this loop forever; negative values make readSlice throw.
                // Discard what we have and fail the handshake instead.
                buffer.skipBytes(buffer.readableBytes());
                failHandshake(ctx, new UaException(StatusCodes.Bad_DecodingError,
                        "invalid message length: " + messageLength));
                return;
            }

            if (buffer.readableBytes() < messageLength) {
                break; // wait for the rest of the message to arrive
            }

            MessageType messageType = MessageType.fromMediumInt(buffer.getMedium(buffer.readerIndex()));

            switch (messageType) {
                case Acknowledge:
                    onAcknowledge(ctx, buffer.readSlice(messageLength));
                    break;

                case Error:
                    onError(ctx, buffer.readSlice(messageLength));
                    break;

                default:
                    out.add(buffer.readSlice(messageLength).retain());
            }
        }
    }

    /** Cancels any pending hello timeout, fails the handshake with {@code cause}, and closes. */
    private void failHandshake(ChannelHandlerContext ctx, UaException cause) {
        Timeout timeout = helloTimeout;
        if (timeout != null) {
            timeout.cancel();
        }
        handshakeFuture.completeExceptionally(cause);
        ctx.close();
    }

    /**
     * Handles an Acknowledge message.
     * <p>
     * If the hello timeout can no longer be cancelled (it already fired), the handshake is failed
     * and the channel closed. Otherwise:
     * <ol>
     * <li>Decode the server's limits.</li>
     * <li>Derive {@link ChannelParameters}.</li>
     * <li>Expose the queued requests on the channel.</li>
     * <li>On the channel's executor, append a {@link UaTcpClientMessageHandler} to the pipeline.</li>
     * </ol>
     */
    private void onAcknowledge(ChannelHandlerContext ctx, ByteBuf buffer) {
        if (helloTimeout != null && !helloTimeout.cancel()) {
            helloTimeout = null;
            handshakeFuture.completeExceptionally(
                    new UaException(StatusCodes.Bad_Timeout,
                            "timed out waiting for acknowledge"));
            ctx.close();
            return;
        }

        logger.debug("Received Acknowledge message on channel={}.", ctx.channel());

        buffer.skipBytes(3 + 1 + 4); // Skip messageType, chunkType, and messageSize

        AcknowledgeMessage acknowledge = AcknowledgeMessage.decode(buffer);

        long remoteProtocolVersion = acknowledge.getProtocolVersion();
        long remoteReceiveBufferSize = acknowledge.getReceiveBufferSize();
        long remoteSendBufferSize = acknowledge.getSendBufferSize();
        long remoteMaxMessageSize = acknowledge.getMaxMessageSize();
        long remoteMaxChunkCount = acknowledge.getMaxChunkCount();

        if (PROTOCOL_VERSION > remoteProtocolVersion) {
            logger.warn("Client protocol version ({}) does not match server protocol version ({}).",
                    PROTOCOL_VERSION, remoteProtocolVersion);
        }

        ChannelConfig config = client.getChannelConfig();

        /* Our receive buffer size is determined by the remote send buffer size. */
        long localReceiveBufferSize = Math.min(remoteSendBufferSize, config.getMaxChunkSize());

        /* Our send buffer size is determined by the remote receive buffer size. */
        long localSendBufferSize = Math.min(remoteReceiveBufferSize, config.getMaxChunkSize());

        /* Max message size the remote can send us; not influenced by remote configuration. */
        long localMaxMessageSize = config.getMaxMessageSize();

        /* Max chunk count the remote can send us; not influenced by remote configuration. */
        long localMaxChunkCount = config.getMaxChunkCount();

        // Values advertised by the server are unsigned 32-bit; clamp them into int range.
        ChannelParameters parameters = new ChannelParameters(
                Ints.saturatedCast(localMaxMessageSize),
                Ints.saturatedCast(localReceiveBufferSize),
                Ints.saturatedCast(localSendBufferSize),
                Ints.saturatedCast(localMaxChunkCount),
                Ints.saturatedCast(remoteMaxMessageSize),
                Ints.saturatedCast(remoteReceiveBufferSize),
                Ints.saturatedCast(remoteSendBufferSize),
                Ints.saturatedCast(remoteMaxChunkCount)
        );

        ctx.channel().attr(KEY_AWAITING_HANDSHAKE).set(awaitingHandshake);

        ctx.executor().execute(() -> {
            int maxArrayLength = client.getChannelConfig().getMaxArrayLength();
            int maxStringLength = client.getChannelConfig().getMaxStringLength();

            SerializationQueue serializationQueue = new SerializationQueue(
                    client.getConfig().getExecutor(),
                    parameters,
                    maxArrayLength,
                    maxStringLength
            );

            UaTcpClientMessageHandler handler = new UaTcpClientMessageHandler(
                    client,
                    secureChannel,
                    serializationQueue,
                    handshakeFuture
            );

            ctx.pipeline().addLast(handler);
        });
    }

    /**
     * Handles an Error message.
     * <p>
     * Resets the secure channel id for secure-channel related status codes, then fails the
     * handshake with the server's status and reason. The channel is always closed afterwards,
     * including when the Error message itself cannot be decoded.
     */
    private void onError(ChannelHandlerContext ctx, ByteBuf buffer) {
        try {
            ErrorMessage errorMessage = TcpMessageDecoder.decodeError(buffer);
            StatusCode statusCode = errorMessage.getError();
            long errorCode = statusCode.getValue();

            boolean secureChannelError =
                    errorCode == StatusCodes.Bad_SecurityChecksFailed ||
                            errorCode == StatusCodes.Bad_TcpSecureChannelUnknown ||
                            errorCode == StatusCodes.Bad_SecureChannelIdInvalid;

            if (secureChannelError) {
                secureChannel.setChannelId(0);
            }

            logger.error("Received error message: {}", errorMessage);

            handshakeFuture.completeExceptionally(new UaException(statusCode, errorMessage.getReason()));
        } catch (UaException e) {
            logger.error("An exception occurred while decoding an error message: {}", e.getMessage(), e);

            handshakeFuture.completeExceptionally(e);
        } finally {
            ctx.close();
        }
    }

}