/** * Copyright 2016 Nikita Koksharov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.redisson.client.handler; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ExecutorService; import org.redisson.client.RedisAskException; import org.redisson.client.RedisException; import org.redisson.client.RedisLoadingException; import org.redisson.client.RedisMovedException; import org.redisson.client.RedisOutOfMemoryException; import org.redisson.client.RedisPubSubConnection; import org.redisson.client.RedisTimeoutException; import org.redisson.client.RedisTryAgainException; import org.redisson.client.codec.StringCodec; import org.redisson.client.protocol.CommandData; import org.redisson.client.protocol.CommandsData; import org.redisson.client.protocol.Decoder; import org.redisson.client.protocol.QueueCommand; import org.redisson.client.protocol.RedisCommands; import org.redisson.client.protocol.RedisCommand.ValueType; import org.redisson.client.protocol.decoder.ListMultiDecoder; import org.redisson.client.protocol.decoder.MultiDecoder; import org.redisson.client.protocol.decoder.NestedMultiDecoder; import org.redisson.client.protocol.decoder.SlotsDecoder; import org.redisson.client.protocol.pubsub.Message; import org.redisson.client.protocol.pubsub.PubSubMessage; import org.redisson.client.protocol.pubsub.PubSubPatternMessage; import org.redisson.client.protocol.pubsub.PubSubStatusMessage; import org.redisson.misc.LogHelper; import org.redisson.misc.RPromise; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import io.netty.buffer.ByteBuf; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.ReplayingDecoder; import io.netty.util.CharsetUtil; import io.netty.util.internal.PlatformDependent; /** * Redis protocol command decoder * * Code parts from Sam Pullara * * @author Nikita Koksharov * */ public class CommandDecoder extends ReplayingDecoder<State> { private final Logger log = LoggerFactory.getLogger(getClass()); public static final char CR = '\r'; public static final char LF = '\n'; private static final char ZERO = '0'; // It is not needed to use concurrent map because responses are coming consecutive private final Map<String, MultiDecoder<Object>> pubSubMessageDecoders = new HashMap<String, MultiDecoder<Object>>(); private final Map<PubSubKey, CommandData<Object, Object>> pubSubChannels = PlatformDependent.newConcurrentHashMap(); private final ExecutorService executor; public CommandDecoder(ExecutorService executor) { this.executor = executor; } public void addPubSubCommand(String channel, CommandData<Object, Object> data) { String operation = data.getCommand().getName().toLowerCase(); pubSubChannels.put(new PubSubKey(channel, operation), data); } @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception { QueueCommand data = ctx.channel().attr(CommandsQueue.CURRENT_COMMAND).get(); if (log.isTraceEnabled()) { log.trace("channel: {} message: {}", ctx.channel(), in.toString(0, in.writerIndex(), CharsetUtil.UTF_8)); } if (state() == null) { boolean makeCheckpoint = data != null; if (data != null) { if (data instanceof CommandsData) { makeCheckpoint = false; } else { CommandData<Object, Object> cmd = (CommandData<Object, Object>)data; if (cmd.getCommand().getReplayMultiDecoder() != null && (NestedMultiDecoder.class.isAssignableFrom(cmd.getCommand().getReplayMultiDecoder().getClass()) || SlotsDecoder.class.isAssignableFrom(cmd.getCommand().getReplayMultiDecoder().getClass()) || ListMultiDecoder.class.isAssignableFrom(cmd.getCommand().getReplayMultiDecoder().getClass()))) { makeCheckpoint = false; } } } state(new State(makeCheckpoint)); } state().setDecoderState(null); if (data == null) { decode(in, null, null, ctx.channel()); } else if (data instanceof CommandData) { CommandData<Object, Object> cmd = (CommandData<Object, Object>)data; try { if (state().getLevels().size() > 0) { decodeFromCheckpoint(ctx, in, data, cmd); } else { decode(in, cmd, null, ctx.channel()); } } catch (Exception e) { cmd.tryFailure(e); } } else if (data instanceof CommandsData) { CommandsData commands = (CommandsData)data; try { decodeCommandBatch(ctx, in, data, commands); } catch (Exception e) { commands.getPromise().tryFailure(e); } return; } ctx.pipeline().get(CommandsQueue.class).sendNextCommand(ctx.channel()); state(null); } private void decodeFromCheckpoint(ChannelHandlerContext ctx, ByteBuf in, QueueCommand data, CommandData<Object, Object> cmd) throws IOException { if (state().getLevels().size() == 2) { StateLevel secondLevel = state().getLevels().get(1); if (secondLevel.getParts().isEmpty()) { state().getLevels().remove(1); } } if (state().getLevels().size() == 2) { StateLevel firstLevel = state().getLevels().get(0); StateLevel secondLevel = state().getLevels().get(1); decodeList(in, cmd, firstLevel.getParts(), ctx.channel(), secondLevel.getSize(), secondLevel.getParts()); Channel channel = ctx.channel(); MultiDecoder<Object> decoder = messageDecoder(cmd, firstLevel.getParts(), channel); if (decoder != null) { Object result = decoder.decode(firstLevel.getParts(), state()); if (data != null) { handleResult(cmd, null, result, true, channel); } } } if (state().getLevels().size() == 1) { StateLevel firstLevel = state().getLevels().get(0); if (firstLevel.getParts().isEmpty()) { state().resetLevel(); decode(in, cmd, null, ctx.channel()); } else { decodeList(in, cmd, null, ctx.channel(), firstLevel.getSize(), firstLevel.getParts()); } } } private void decodeCommandBatch(ChannelHandlerContext ctx, ByteBuf in, QueueCommand data, CommandsData commandBatch) { int i = state().getBatchIndex(); Throwable error = null; while (in.writerIndex() > in.readerIndex()) { CommandData<Object, Object> cmd = null; try { checkpoint(); state().setBatchIndex(i); cmd = (CommandData<Object, Object>) commandBatch.getCommands().get(i); decode(in, cmd, null, ctx.channel()); } catch (Exception e) { cmd.tryFailure(e); } i++; if (!cmd.isSuccess()) { error = cmd.cause(); } } if (commandBatch.isNoResult() || i == commandBatch.getCommands().size()) { RPromise<Void> promise = commandBatch.getPromise(); if (error != null) { if (!promise.tryFailure(error) && promise.cause() instanceof RedisTimeoutException) { log.warn("response has been skipped due to timeout! channel: {}, command: {}",ctx.channel(), LogHelper.toString(data)); } } else { if (!promise.trySuccess(null) && promise.cause() instanceof RedisTimeoutException) { log.warn("response has been skipped due to timeout! channel: {}, command: {}", ctx.channel(), LogHelper.toString(data)); } } ctx.pipeline().get(CommandsQueue.class).sendNextCommand(ctx.channel()); state(null); } else { checkpoint(); state().setBatchIndex(i); } } private void decode(ByteBuf in, CommandData<Object, Object> data, List<Object> parts, Channel channel) throws IOException { int code = in.readByte(); if (code == '+') { ByteBuf rb = in.readBytes(in.bytesBefore((byte) '\r')); try { String result = rb.toString(CharsetUtil.UTF_8); in.skipBytes(2); handleResult(data, parts, result, false, channel); } finally { rb.release(); } } else if (code == '-') { ByteBuf rb = in.readBytes(in.bytesBefore((byte) '\r')); try { String error = rb.toString(CharsetUtil.UTF_8); in.skipBytes(2); if (error.startsWith("MOVED")) { String[] errorParts = error.split(" "); int slot = Integer.valueOf(errorParts[1]); String addr = errorParts[2]; data.tryFailure(new RedisMovedException(slot, addr)); } else if (error.startsWith("ASK")) { String[] errorParts = error.split(" "); int slot = Integer.valueOf(errorParts[1]); String addr = errorParts[2]; data.tryFailure(new RedisAskException(slot, addr)); } else if (error.startsWith("TRYAGAIN")) { data.tryFailure(new RedisTryAgainException(error + ". channel: " + channel + " data: " + data)); } else if (error.startsWith("LOADING")) { data.tryFailure(new RedisLoadingException(error + ". channel: " + channel + " data: " + data)); } else if (error.startsWith("OOM")) { data.tryFailure(new RedisOutOfMemoryException(error.split("OOM ")[1] + ". channel: " + channel + " data: " + data)); } else if (error.contains("-OOM ")) { data.tryFailure(new RedisOutOfMemoryException(error.split("-OOM ")[1] + ". channel: " + channel + " data: " + data)); } else { if (data != null) { data.tryFailure(new RedisException(error + ". channel: " + channel + " command: " + data)); } else { log.error("Error: {} channel: {} data: {}", error, channel, data); } } } finally { rb.release(); } } else if (code == ':') { Long result = readLong(in); handleResult(data, parts, result, false, channel); } else if (code == '$') { ByteBuf buf = readBytes(in); Object result = null; if (buf != null) { Decoder<Object> decoder = selectDecoder(data, parts); result = decoder.decode(buf, state()); } handleResult(data, parts, result, false, channel); } else if (code == '*') { int level = state().incLevel(); long size = readLong(in); List<Object> respParts; if (state().getLevels().size()-1 >= level) { StateLevel stateLevel = state().getLevels().get(level); respParts = stateLevel.getParts(); size = stateLevel.getSize(); } else { respParts = new ArrayList<Object>(); if (state().isMakeCheckpoint()) { state().addLevel(new StateLevel(size, respParts)); } } decodeList(in, data, parts, channel, size, respParts); } else { String dataStr = in.toString(0, in.writerIndex(), CharsetUtil.UTF_8); throw new IllegalStateException("Can't decode replay: " + dataStr); } } private void decodeList(ByteBuf in, CommandData<Object, Object> data, List<Object> parts, Channel channel, long size, List<Object> respParts) throws IOException { for (int i = respParts.size(); i < size; i++) { decode(in, data, respParts, channel); if (state().isMakeCheckpoint()) { checkpoint(); } } MultiDecoder<Object> decoder = messageDecoder(data, respParts, channel); if (decoder == null) { return; } Object result = decoder.decode(respParts, state()); if (data != null) { handleResult(data, parts, result, true, channel); return; } if (result instanceof Message) { // store current message index checkpoint(); handlePublishSubscribe(data, null, channel, result); // has next messages? if (in.writerIndex() > in.readerIndex()) { decode(in, data, null, channel); } } } private void handlePublishSubscribe(CommandData<Object, Object> data, List<Object> parts, Channel channel, final Object result) { if (result instanceof PubSubStatusMessage) { String channelName = ((PubSubStatusMessage) result).getChannel(); String operation = ((PubSubStatusMessage) result).getType().name().toLowerCase(); PubSubKey key = new PubSubKey(channelName, operation); CommandData<Object, Object> d = pubSubChannels.get(key); if (Arrays.asList(RedisCommands.PSUBSCRIBE.getName(), RedisCommands.SUBSCRIBE.getName()).contains(d.getCommand().getName())) { pubSubChannels.remove(key); pubSubMessageDecoders.put(channelName, d.getMessageDecoder()); } if (Arrays.asList(RedisCommands.PUNSUBSCRIBE.getName(), RedisCommands.UNSUBSCRIBE.getName()).contains(d.getCommand().getName())) { pubSubChannels.remove(key); pubSubMessageDecoders.remove(channelName); } } final RedisPubSubConnection pubSubConnection = RedisPubSubConnection.getFrom(channel); executor.execute(new Runnable() { @Override public void run() { if (result instanceof PubSubStatusMessage) { pubSubConnection.onMessage((PubSubStatusMessage) result); } else if (result instanceof PubSubMessage) { pubSubConnection.onMessage((PubSubMessage) result); } else { pubSubConnection.onMessage((PubSubPatternMessage) result); } } }); } private void handleResult(CommandData<Object, Object> data, List<Object> parts, Object result, boolean multiResult, Channel channel) { if (data != null) { if (multiResult) { result = data.getCommand().getConvertor().convertMulti(result); } else { result = data.getCommand().getConvertor().convert(result); } } if (parts != null) { parts.add(result); } else { if (data != null && !data.getPromise().trySuccess(result) && data.cause() instanceof RedisTimeoutException) { log.warn("response has been skipped due to timeout! channel: {}, command: {}, result: {}", channel, data, result); } } } private MultiDecoder<Object> messageDecoder(CommandData<Object, Object> data, List<Object> parts, Channel channel) { if (data == null) { if (parts.isEmpty()) { return null; } String command = parts.get(0).toString(); if (Arrays.asList("subscribe", "psubscribe", "punsubscribe", "unsubscribe").contains(command)) { String channelName = parts.get(1).toString(); PubSubKey key = new PubSubKey(channelName, command); CommandData<Object, Object> commandData = pubSubChannels.get(key); if (commandData == null) { return null; } return commandData.getCommand().getReplayMultiDecoder(); } else if (parts.get(0).equals("message")) { String channelName = (String) parts.get(1); return pubSubMessageDecoders.get(channelName); } else if (parts.get(0).equals("pmessage")) { String patternName = (String) parts.get(1); return pubSubMessageDecoders.get(patternName); } } return data.getCommand().getReplayMultiDecoder(); } private Decoder<Object> selectDecoder(CommandData<Object, Object> data, List<Object> parts) { if (data == null) { if (parts != null) { if (parts.size() == 2 && "message".equals(parts.get(0))) { String channelName = (String) parts.get(1); return pubSubMessageDecoders.get(channelName); } if (parts.size() == 3 && "pmessage".equals(parts.get(0))) { String patternName = (String) parts.get(1); return pubSubMessageDecoders.get(patternName); } } return StringCodec.INSTANCE.getValueDecoder(); } Decoder<Object> decoder = data.getCommand().getReplayDecoder(); if (parts != null) { MultiDecoder<Object> multiDecoder = data.getCommand().getReplayMultiDecoder(); if (multiDecoder.isApplicable(parts.size(), state())) { decoder = multiDecoder; } } if (decoder == null) { if (data.getCommand().getOutParamType() == ValueType.MAP) { if (parts.size() % 2 != 0) { decoder = data.getCodec().getMapValueDecoder(); } else { decoder = data.getCodec().getMapKeyDecoder(); } } else if (data.getCommand().getOutParamType() == ValueType.MAP_KEY) { decoder = data.getCodec().getMapKeyDecoder(); } else if (data.getCommand().getOutParamType() == ValueType.MAP_VALUE) { decoder = data.getCodec().getMapValueDecoder(); } else { decoder = data.getCodec().getValueDecoder(); } } return decoder; } public ByteBuf readBytes(ByteBuf is) throws IOException { long l = readLong(is); if (l > Integer.MAX_VALUE) { throw new IllegalArgumentException( "Java only supports arrays up to " + Integer.MAX_VALUE + " in size"); } int size = (int) l; if (size == -1) { return null; } ByteBuf buffer = is.readSlice(size); int cr = is.readByte(); int lf = is.readByte(); if (cr != CR || lf != LF) { throw new IOException("Improper line ending: " + cr + ", " + lf); } return buffer; } public static long readLong(ByteBuf is) throws IOException { long size = 0; int sign = 1; int read = is.readByte(); if (read == '-') { read = is.readByte(); sign = -1; } do { if (read == CR) { if (is.readByte() == LF) { break; } } int value = read - ZERO; if (value >= 0 && value < 10) { size *= 10; size += value; } else { throw new IOException("Invalid character in integer"); } read = is.readByte(); } while (true); return size * sign; } }