package com.tesora.dve.db.mysql.portal.protocol;
/*
* #%L
* Tesora Inc.
* Database Virtualization Engine
* %%
* Copyright (C) 2011 - 2014 Tesora Inc.
* %%
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License, version 3,
* as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
* #L%
*/
import com.tesora.dve.db.mysql.common.MysqlHandshake;
import com.tesora.dve.db.mysql.libmy.MyErrorResponse;
import com.tesora.dve.db.mysql.libmy.MyMessage;
import com.tesora.dve.db.mysql.libmy.MyOKResponse;
import com.tesora.dve.db.mysql.libmy.MyServerGreetingErrorResponse;
import com.tesora.dve.exceptions.PECodingException;
import com.tesora.dve.exceptions.PEException;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.ReferenceCountUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Netty inbound handler that drives the server side of the MySQL protocol V10 connection
 * handshake and authentication exchange.
 *
 * <p>When the channel becomes active, it asks the subclass for a {@link MysqlHandshake} and sends
 * the server greeting. While unauthenticated, every inbound message must be the client's
 * authentication packet, which the subclass checks via {@link #doAuthenticate}. Once authenticated,
 * every subsequent message is passed through to the next inbound handler. After a failed
 * authentication, the next inbound message causes the channel to be closed.
 *
 * <p>Subclasses supply the handshake parameters ({@link #doHandshake}) and the credential check
 * ({@link #doAuthenticate}).
 */
public abstract class InboundMysqlAuthenticationHandlerV10 extends ChannelInboundHandlerAdapter {
    static final Logger logger = LoggerFactory.getLogger(InboundMysqlAuthenticationHandlerV10.class);

    /** Authentication progress of the connection this handler is attached to. */
    public enum AuthState { UNAUTHENTICATED, AUTHENTICATED, FAILED }

    /** Current authentication state; set by {@link #authenticateClient} after each attempt. */
    protected AuthState currentAuthState = AuthState.UNAUTHENTICATED;

    /**
     * Routes an inbound message according to {@link #currentAuthState}.
     *
     * <ul>
     *   <li>{@code AUTHENTICATED}: the message is forwarded to the next inbound handler, which
     *       takes over responsibility for releasing it.</li>
     *   <li>{@code FAILED}: the message is discarded and the channel is closed.</li>
     *   <li>{@code UNAUTHENTICATED}: the message must be a {@link MSPAuthenticateV10MessageMessage};
     *       it is passed to {@link #authenticateClient}.</li>
     * </ul>
     *
     * <p>Every message that is not forwarded is released exactly once, in the {@code finally}
     * block, however the state branch exits.
     *
     * @param ctx the handler context
     * @param msg the decoded inbound message
     * @throws PECodingException if an unauthenticated connection sends something other than an
     *     authentication message, or if the state is not one of the handled values
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        boolean forwarded = false;
        try {
            switch (currentAuthState) {
                case AUTHENTICATED:
                    // Already authenticated: pass the decoded message through. Mark it forwarded
                    // first so the finally block never releases a message this handler no longer owns.
                    forwarded = true;
                    ctx.fireChannelRead(msg);
                    break;
                case FAILED:
                    // Authentication already failed: drop the connection. The message is released by
                    // the finally block; releasing it here as well would release it twice.
                    ctx.channel().close();
                    break;
                case UNAUTHENTICATED: {
                    if (!(msg instanceof MSPAuthenticateV10MessageMessage)) {
                        throw new PECodingException("Expecting authentication message, received, " + msg);
                    }
                    MSPAuthenticateV10MessageMessage authMessage = (MSPAuthenticateV10MessageMessage) msg;
                    authenticateClient(ctx, authMessage);
                    break;
                }
                default:
                    throw new PECodingException("Unexpected authorization state, " + currentAuthState);
            }
        } finally {
            if (!forwarded) {
                ReferenceCountUtil.release(msg);
            }
        }
    }

    /**
     * Sends the server greeting once the connection becomes active.
     *
     * <p>If {@link #doHandshake} returns {@code null}, no greeting is sent. If building or sending
     * the greeting throws, the failure is logged. A {@link MyServerGreetingErrorResponse} built
     * from it is then written and flushed to the client instead.
     *
     * @param ctx the handler context
     */
    @Override
    public void channelActive(final ChannelHandlerContext ctx) throws Exception {
        super.channelActive(ctx);
        try {
            MysqlHandshake handshake = doHandshake(ctx);
            // if we don't have a ssCon available then we don't need to send the greeting
            if (handshake != null) {
                sendGreeting(ctx, handshake);
            }
        } catch (Exception e) {
            logger.warn("Unable to send server greeting", e);
            // write() only queues the message in the outbound buffer; flush so the client actually
            // receives the error instead of waiting for a greeting that never arrives.
            ctx.channel().writeAndFlush(new MyServerGreetingErrorResponse(e));
        }
    }

    /**
     * Writes the server greeting built from the handshake's parameters.
     *
     * @param ctx the handler context the greeting is written through
     * @param handshake source of the connection id, salt, capabilities, version, charset and
     *     plugin data sent in the greeting
     */
    protected void sendGreeting(ChannelHandlerContext ctx, MysqlHandshake handshake) {
        int connectionId = handshake.getConnectionId();
        String salt = handshake.getSalt();
        // Narrowed to int for the greeting writer; NOTE(review): confirm no capability bits above
        // bit 31 are expected here.
        int serverCapabilities = (int) handshake.getServerCapabilities();
        String serverVersion = handshake.getServerVersion();
        byte serverCharSet = handshake.getServerCharSet();
        String pluginData = handshake.getPluginData();
        MSPServerGreetingRequestMessage.write(ctx, connectionId, salt, serverCapabilities, serverVersion, serverCharSet, pluginData);
    }

    /**
     * Authenticates the client and records the outcome in {@link #currentAuthState}.
     *
     * <p>Calls {@link #doAuthenticate}, then writes and flushes the resulting response to the
     * client. A {@link MyOKResponse} moves the connection to {@code AUTHENTICATED}. Any other
     * response moves it to {@code FAILED}, including:
     *
     * <ul>
     *   <li>an error response built from an exception thrown by {@code doAuthenticate};</li>
     *   <li>the error response substituted for a {@code null} result.</li>
     * </ul>
     *
     * @param ctx the handler context the response is written through
     * @param authMessage the client's authentication packet
     * @throws PEException declared in the signature; exceptions thrown by {@code doAuthenticate}
     *     are converted into a {@link MyErrorResponse} here rather than rethrown
     */
    void authenticateClient(ChannelHandlerContext ctx, MSPAuthenticateV10MessageMessage authMessage) throws PEException {
        // Login in the SSConnection
        MyMessage mysqlResp;
        try {
            mysqlResp = doAuthenticate(ctx, authMessage);
        } catch (PEException e) {
            mysqlResp = new MyErrorResponse(e.rootCause());
        } catch (Throwable t) {
            // Unexpected failure: only the message text reaches the client, so keep the stack trace
            // in the log.
            logger.warn("Unexpected error while authenticating client", t);
            mysqlResp = new MyErrorResponse(new Exception(t.getMessage()));
        }
        if (mysqlResp == null) {
            // Netty rejects writing a null message; treat a missing response as a failed login.
            mysqlResp = new MyErrorResponse(new Exception("Authentication produced no response"));
        }

        if (mysqlResp instanceof MyOKResponse) {
            currentAuthState = AuthState.AUTHENTICATED;
        } else {
            currentAuthState = AuthState.FAILED;
        }

        ctx.writeAndFlush(mysqlResp);
    }

    /**
     * Validates the client's credentials.
     *
     * @param ctx the handler context of the connection being authenticated
     * @param authMessage the client's authentication packet
     * @return the response to send to the client; a {@link MyOKResponse} marks success and
     *     anything else marks failure
     * @throws Throwable on any failure; {@link #authenticateClient} converts it into a
     *     {@link MyErrorResponse}
     */
    protected abstract MyMessage doAuthenticate(ChannelHandlerContext ctx, MSPAuthenticateV10MessageMessage authMessage) throws Throwable;

    /**
     * Produces the parameters for the server greeting.
     *
     * @param ctx the handler context of the newly active connection
     * @return the handshake to greet the client with, or {@code null} to skip sending a greeting
     */
    protected abstract MysqlHandshake doHandshake(ChannelHandlerContext ctx);
}