/*
* Copyright (c) 2012-2015 The original author or authors
* ------------------------------------------------------
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* and Apache License v2.0 which accompanies this distribution.
*
* The Eclipse Public License is available at
* http://www.eclipse.org/legal/epl-v10.html
*
* The Apache License v2.0 is available at
* http://www.opensource.org/licenses/apache2.0.php
*
* You may elect to redistribute this code under either of these licenses.
*/
package org.red5.server.mqtt.codec.parser;
import static org.red5.server.mqtt.codec.MQTTProtocol.VERSION_3_1;
import static org.red5.server.mqtt.codec.MQTTProtocol.VERSION_3_1_1;
import java.io.UnsupportedEncodingException;
import org.apache.mina.core.buffer.IoBuffer;
import org.apache.mina.core.session.IoSession;
import org.apache.mina.filter.codec.ProtocolDecoderOutput;
import org.eclipse.moquette.proto.messages.AbstractMessage;
import org.eclipse.moquette.proto.messages.ConnectMessage;
import org.red5.server.mqtt.codec.MQTTDecoder;
import org.red5.server.mqtt.codec.MQTTProtocol;
import org.red5.server.mqtt.codec.exception.CorruptedFrameException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Decodes an MQTT CONNECT control packet read from a MINA {@link IoBuffer} into a
 * {@link ConnectMessage}.
 * <p>
 * Two protocol names are accepted: {@code "MQIsdp"} (recorded on the session as
 * {@code VERSION_3_1}) and {@code "MQTT"} (recorded as {@code VERSION_3_1_1}).
 * Whenever the buffer does not yet hold enough bytes, the buffer is reset to its
 * mark and nothing is written to the output so decoding can be retried later.
 *
 * @author andrea
 * @author Paul Gregoire
 */
public class ConnectDecoder extends DemuxDecoder {

    private static final Logger log = LoggerFactory.getLogger(ConnectDecoder.class);

    /** Session attribute key under which {@code Boolean.TRUE} is stored once a 3.1.1 CONNECT has been decoded. */
    public static final String CONNECT_STATUS = "connected";

    /**
     * Decodes one CONNECT packet starting at the buffer's current mark.
     * <p>
     * NOTE(review): this method calls {@code in.reset()} first, so it relies on a mark
     * having been set on the buffer beforehand - presumably by the calling decoder.
     *
     * @param session the I/O session; receives the {@code MQTTDecoder.PROTOCOL_VERSION}
     *                attribute and, for 3.1.1 clients, the {@link #CONNECT_STATUS} attribute
     * @param in      the buffer to read the packet from
     * @param out     receives the decoded {@link ConnectMessage} when the packet is complete
     * @throws UnsupportedEncodingException if the UTF-8 charset is unavailable
     * @throws CorruptedFrameException      if the protocol name, fixed header flags, connect flags
     *                                      or will QoS are invalid, or if a 3.1.1 client sends a
     *                                      second CONNECT on the same session
     */
    @Override
    public void decode(IoSession session, IoBuffer in, ProtocolDecoderOutput out) throws UnsupportedEncodingException, CorruptedFrameException {
        in.reset();
        //Common decoding part
        ConnectMessage message = new ConnectMessage();
        if (!decodeCommonHeader(message, 0x00, in)) {
            in.reset();
            return;
        }
        int remainingLength = message.getRemainingLength();
        log.trace("remainingLength: {}", remainingLength);
        // Position of the first variable-header byte; bytes consumed from here on are
        // what must be compared against the remaining length.
        int start = in.position();
        // The protocol name length prefix itself needs 2 bytes
        if (in.remaining() < 2) {
            in.reset();
            return;
        }
        int protocolNameLen = in.getUnsignedShort();
        log.trace("protocolNameLen: {}", protocolNameLen);
        byte[] encProtoName;
        String protoName;
        switch (protocolNameLen) {
            case 6:
                //MQTT version 3.1 "MQIsdp"
                //6 name bytes + version (1) + connect flags (1) + keep alive (2)
                if (in.remaining() < 10) {
                    in.reset();
                    return;
                }
                encProtoName = new byte[6];
                in.get(encProtoName);
                protoName = new String(encProtoName, "UTF-8");
                if (!"MQIsdp".equals(protoName)) {
                    in.reset();
                    throw new CorruptedFrameException("Invalid protoName: " + protoName);
                }
                message.setProtocolName(protoName);
                session.setAttribute(MQTTDecoder.PROTOCOL_VERSION, (int) VERSION_3_1);
                break;
            case 4:
                //MQTT version 3.1.1 "MQTT"
                //4 name bytes + version (1) + connect flags (1) + keep alive (2)
                if (in.remaining() < 8) {
                    in.reset();
                    return;
                }
                encProtoName = new byte[4];
                in.get(encProtoName);
                protoName = new String(encProtoName, "UTF-8");
                if (!"MQTT".equals(protoName)) {
                    in.reset();
                    throw new CorruptedFrameException("Invalid protoName: " + protoName);
                }
                message.setProtocolName(protoName);
                session.setAttribute(MQTTDecoder.PROTOCOL_VERSION, (int) VERSION_3_1_1);
                break;
            default:
                //protocol broken
                throw new CorruptedFrameException("Invalid protoName size: " + protocolNameLen);
        }
        log.trace("protoName: {}", protoName);
        //ProtocolVersion 1 byte (value 0x03 for 3.1, 0x04 for 3.1.1)
        message.setProcotolVersion(in.get());
        // Set when this is the first 3.1.1 CONNECT seen on the session. The session attribute
        // is only written once the message is actually emitted, so that a retry after an
        // incomplete payload is not mistaken for a second CONNECT.
        boolean firstConnect = false;
        if (message.getProcotolVersion() == VERSION_3_1_1) {
            //if 3.1.1, check the flags (dup, retain and qos == 0)
            if (message.isDupFlag() || message.isRetainFlag() || message.getQos() != AbstractMessage.QOSType.MOST_ONE) {
                throw new CorruptedFrameException("Received a CONNECT with fixed header flags != 0");
            }
            //check if this is another connect from the same client on the same session
            Boolean alreadyConnected = (Boolean) session.getAttribute(ConnectDecoder.CONNECT_STATUS);
            if (alreadyConnected == null) {
                //never set
                firstConnect = true;
            } else if (alreadyConnected) {
                throw new CorruptedFrameException("Received a second CONNECT on the same network connection");
            }
        }
        //Connection flag
        byte connFlags = in.get();
        if (message.getProcotolVersion() == VERSION_3_1_1) {
            if ((connFlags & 0x01) != 0) { //bit(0) of connection flags is != 0
                throw new CorruptedFrameException("Received a CONNECT with connectionFlags[0(bit)] != 0");
            }
        }
        boolean cleanSession = ((connFlags & 0x02) >> 1) == 1;
        boolean willFlag = ((connFlags & 0x04) >> 2) == 1;
        byte willQos = (byte) ((connFlags & 0x18) >> 3);
        if (willQos > 2) {
            in.reset();
            throw new CorruptedFrameException("Expected will QoS in range 0..2 but found: " + willQos);
        }
        boolean willRetain = ((connFlags & 0x20) >> 5) == 1;
        boolean passwordFlag = ((connFlags & 0x40) >> 6) == 1;
        boolean userFlag = ((connFlags & 0x80) >> 7) == 1;
        //the password flag may only be set when the user flag is set
        if (!userFlag && passwordFlag) {
            in.reset();
            throw new CorruptedFrameException("Expected password flag to true if the user flag is true but was: " + passwordFlag);
        }
        message.setCleanSession(cleanSession);
        message.setWillFlag(willFlag);
        message.setWillQos(willQos);
        message.setWillRetain(willRetain);
        message.setPasswordFlag(passwordFlag);
        message.setUserFlag(userFlag);
        //Keep Alive timer 2 bytes
        int keepAlive = in.getUnsignedShort();
        message.setKeepAlive(keepAlive);
        //Variable header only, no payload: 12 bytes for 3.1, 10 bytes for 3.1.1
        if ((remainingLength == 12 && message.getProcotolVersion() == VERSION_3_1) || (remainingLength == 10 && message.getProcotolVersion() == VERSION_3_1_1)) {
            completeConnect(session, out, message, firstConnect);
            return;
        }
        //Decode the ClientID
        String clientID = MQTTProtocol.decodeString(in);
        if (clientID == null) {
            in.reset();
            return;
        }
        message.setClientID(clientID);
        if (willFlag) {
            //Decode willTopic
            String willTopic = MQTTProtocol.decodeString(in);
            if (willTopic == null) {
                in.reset();
                return;
            }
            message.setWillTopic(willTopic);
            //Decode willMessage
            String willMessage = MQTTProtocol.decodeString(in);
            if (willMessage == null) {
                in.reset();
                return;
            }
            message.setWillMessage(willMessage);
        }
        //Compatibility check with v3.0, remaining length has precedence over
        //the user and password flags
        int readed = in.position() - start;
        if (readed == remainingLength) {
            completeConnect(session, out, message, firstConnect);
            return;
        }
        //Decode username
        if (userFlag) {
            String userName = MQTTProtocol.decodeString(in);
            if (userName == null) {
                in.reset();
                return;
            }
            message.setUsername(userName);
        }
        readed = in.position() - start;
        if (readed == remainingLength) {
            completeConnect(session, out, message, firstConnect);
            return;
        }
        //Decode password
        if (passwordFlag) {
            String password = MQTTProtocol.decodeString(in);
            if (password == null) {
                in.reset();
                return;
            }
            message.setPassword(password);
        }
        completeConnect(session, out, message, firstConnect);
    }

    /**
     * Emits a fully decoded CONNECT message and, when requested, records on the session
     * that a CONNECT has been received.
     *
     * @param session       the session on which {@link #CONNECT_STATUS} is stored
     * @param out           the decoder output the message is written to
     * @param message       the fully decoded message
     * @param markConnected whether to set {@link #CONNECT_STATUS} to {@code Boolean.TRUE}
     */
    private static void completeConnect(IoSession session, ProtocolDecoderOutput out, ConnectMessage message, boolean markConnected) {
        if (markConnected) {
            session.setAttribute(ConnectDecoder.CONNECT_STATUS, Boolean.TRUE);
        }
        out.write(message);
    }
}