package com.tesora.dve.db.mysql.portal.protocol; /* * #%L * Tesora Inc. * Database Virtualization Engine * %% * Copyright (C) 2011 - 2014 Tesora Inc. * %% * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, version 3, * as published by the Free Software Foundation. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. * #L% */ import com.tesora.dve.db.mysql.MysqlNativeConstants; import com.tesora.dve.db.mysql.common.JavaCharsetCatalog; import com.tesora.dve.db.mysql.common.SimpleCredentials; import com.tesora.dve.db.mysql.libmy.MyErrorResponse; import com.tesora.dve.db.mysql.libmy.MyHandshakeErrorResponse; import com.tesora.dve.db.mysql.libmy.MyHandshakeV10; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.util.AttributeKey; import java.net.SocketAddress; import java.nio.ByteOrder; import java.nio.charset.Charset; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import com.tesora.dve.exceptions.PECodingException; import com.tesora.dve.exceptions.PESQLException; import io.netty.util.CharsetUtil; import org.apache.log4j.Logger; public class MysqlClientAuthenticationHandler extends ByteToMessageDecoder { static final Logger log = Logger.getLogger(MysqlClientAuthenticationHandler.class); private static final long MAXIMUM_WAITTIME_MINUTES=5; //abort operation if site doesn't at least authenticate us after 5 minutes. private static final int MESSAGE_HEADER_LEN = 4; enum AuthenticationState { AWAIT_GREETING, AWAIT_ACKNOWLEGEMENT, AUTHENTICATED, FAILURE }; public static final AttributeKey<MyHandshakeV10> HANDSHAKE_KEY = new AttributeKey<MyHandshakeV10>("ServerHandshake"); CountDownLatch finished = new CountDownLatch(1); volatile AuthenticationState state = AuthenticationState.AWAIT_GREETING; //must be volatile, it is read by multiple threads that haven't sync'ed on the same monitor. private SimpleCredentials userCredentials; private int serverThreadID; JavaCharsetCatalog javaCharsetCatalog; Charset serverCharset = CharsetUtil.UTF_8; AtomicReference<Charset> targetCharset; long clientCapabilities; public MysqlClientAuthenticationHandler(SimpleCredentials userCredentials, long clientCapabilities, JavaCharsetCatalog javaCharsetCatalog, AtomicReference<Charset> targetCharset) { this.userCredentials = userCredentials; this.javaCharsetCatalog = javaCharsetCatalog; this.clientCapabilities = clientCapabilities; this.targetCharset = targetCharset; } @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception { decode(ctx, in, out, false); } @Override protected void decodeLast(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception { decode(ctx, in, out, true); } protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out, boolean isLastBytes) throws Exception { ByteBuf leBuf = in.order(ByteOrder.LITTLE_ENDIAN); leBuf.markReaderIndex(); boolean messageProcessed = false; try { if (leBuf.readableBytes() > MESSAGE_HEADER_LEN) { int payloadLen = leBuf.readMedium(); leBuf.readByte(); // seq if (leBuf.readableBytes() >= payloadLen) { ByteBuf payload = leBuf.slice(leBuf.readerIndex(), payloadLen).order(ByteOrder.LITTLE_ENDIAN); Byte protocolVersion = leBuf.readByte(); leBuf.skipBytes(payloadLen-1); messageProcessed = true; if (state == AuthenticationState.AWAIT_GREETING) processGreeting(ctx, payload, protocolVersion); else processAcknowlegement(ctx, payload); } } } catch (Throwable t){ enterState(AuthenticationState.FAILURE); log.warn("Unexpected problem on outbound mysql connection.",t); throw t; } finally { if (!messageProcessed) leBuf.resetReaderIndex(); if (isLastBytes && (state == AuthenticationState.AWAIT_ACKNOWLEGEMENT || state == AuthenticationState.AWAIT_GREETING) ){ //we are waiting for handshake packets from mysql, but no more packets will ever arrive. release blocked callers. Channel channel = ctx.channel(); SocketAddress addr = (channel == null ? null : channel.remoteAddress() ); log.warn("Socket closed in middle of authentication handshake on socket "+addr); enterState(AuthenticationState.FAILURE); } } } private void processGreeting(ChannelHandlerContext ctx, ByteBuf payload, Byte protocolVersion) throws Exception { if (protocolVersion == MyErrorResponse.ERRORPKT_FIELD_COUNT) { processErrorPacket(ctx, payload); } else { MyHandshakeV10 handshake = new MyHandshakeV10(); handshake.unmarshallMessage(payload); ctx.channel().attr(HANDSHAKE_KEY).set(handshake); serverCharset = handshake.getServerCharset( javaCharsetCatalog ); targetCharset.set(serverCharset); serverThreadID = handshake.getThreadID(); ByteBuf out = Unpooled.buffer().order(ByteOrder.LITTLE_ENDIAN); try { String userName = userCredentials.getName(); String userPassword = userCredentials.getPassword(); String salt = handshake.getSalt(); Charset charset = javaCharsetCatalog.findJavaCharsetById(handshake.getServerCharsetId()); int mysqlCharsetID = MysqlNativeConstants.MYSQL_CHARSET_UTF8; int capabilitiesFlag = (int) clientCapabilities; handshake.setServerCharset((byte) mysqlCharsetID); MSPAuthenticateV10MessageMessage.write(out, userName, userPassword, salt, charset, mysqlCharsetID, capabilitiesFlag); ctx.writeAndFlush(out); } catch (Exception e) { out.release(); log.debug("Couldn't write auth handshake to socket",e); } enterState(AuthenticationState.AWAIT_ACKNOWLEGEMENT); } } private void processAcknowlegement(ChannelHandlerContext ctx, ByteBuf payload) throws Exception { byte fieldCount = payload.getByte(payload.readerIndex()); if (fieldCount == MyErrorResponse.ERRORPKT_FIELD_COUNT) { processErrorPacket(ctx, payload); } else { ctx.pipeline().remove(this); enterState(AuthenticationState.AUTHENTICATED); } } private void processErrorPacket(ChannelHandlerContext ctx, ByteBuf payload) throws Exception { enterState(AuthenticationState.FAILURE); MyHandshakeErrorResponse errorResp = new MyHandshakeErrorResponse(serverCharset); errorResp.unmarshallMessage(payload); throw errorResp.asException(); } private synchronized void enterState(AuthenticationState newState) { boolean alreadyFinished = true; try { alreadyFinished = finished.await(0, TimeUnit.NANOSECONDS); } catch (InterruptedException e) {}//shouldn't happen, we aren't waiting, just checking status. if (alreadyFinished){ log.warn("we already unlatched, current state is " + state + " , new state is "+newState); return; } this.state = newState; if (state == AuthenticationState.AUTHENTICATED || state == AuthenticationState.FAILURE) finished.countDown(); } public void assertAuthenticated() throws PESQLException { try { boolean receivedResponse = finished.await(MAXIMUM_WAITTIME_MINUTES, TimeUnit.MINUTES); if (!receivedResponse){ throw new PESQLException("Timeout trying to authenticate as user " + userCredentials.getName()); } AuthenticationState check = state; if (check == AuthenticationState.FAILURE) throw new PESQLException("Failed to authenticate as user " + userCredentials.getName()); if (check != AuthenticationState.AUTHENTICATED) throw new PECodingException("unexpected state after unlatch, "+check); } catch (InterruptedException e) { throw new PESQLException("Interrupted while authenticating as user "+userCredentials.getName()); } } public int getThreadID() { return serverThreadID; } }