/* * Copyright 2013 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.yarn.integration.ip.mind; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.io.StringReader; import java.util.HashMap; import java.util.Map; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.core.serializer.Deserializer; import org.springframework.core.serializer.Serializer; import org.springframework.integration.ip.tcp.serializer.SoftEndOfStreamException; /** * Spring {@link Serializer} and {@link Deserializer} interfaces * for mind protocol. * * @author Janne Valkealahti * */ public class MindRpcSerializer implements Serializer<MindRpcMessageHolder>, Deserializer<MindRpcMessageHolder> { private final static Log log = LogFactory.getLog(MindRpcSerializer.class); protected int maxMessageSize = 20000; /** * @see org.springframework.core.serializer.Deserializer#deserialize(java.io.InputStream) */ @Override public MindRpcMessageHolder deserialize(InputStream inputStream) throws IOException { int lenghts[] = readHeader(inputStream); if(log.isDebugEnabled()) { log.debug("rpc lenghts: " + lenghts[0] + "/" + lenghts[1]); } Map<String, String> headers = readHeaders(inputStream, lenghts[0]); byte[] content = readBytes(inputStream, lenghts[1]); if(log.isDebugEnabled()) { log.debug("deserialize: " + content); } return new MindRpcMessageHolder(headers, content); } /** * @see org.springframework.core.serializer.Serializer#serialize(java.lang.Object, java.io.OutputStream) */ @Override public void serialize(MindRpcMessageHolder object, OutputStream outputStream) throws IOException { if(log.isDebugEnabled()) { log.debug("serialize length=" + object.toBytes().length + " :" + new String(object.toBytes())); } outputStream.write(object.toBytes()); outputStream.flush(); } /** * Sets the max message size for transport. * * @param maxMessageSize the length of max message */ public void setMaxMessageSize(int maxMessageSize) { this.maxMessageSize = maxMessageSize; } /** * Reads the first line as a protocol header and parses sizes of * underlying headers and content. * * @param inputStream the input stream * @return int array containing sizes of headers and content * @throws IOException if read error occured */ protected int[] readHeader(InputStream inputStream) throws IOException { int[] ret = new int[]{0,0}; byte[] buffer = new byte[20]; int n = 0; int bite; while (true) { bite = inputStream.read(); if (bite < 0 && n == 0) { throw new SoftEndOfStreamException("Stream closed between payloads"); } checkClosure(bite); if (n > 0 && bite == '\n' && buffer[n-1] == '\r') { break; } buffer[n++] = (byte) bite; if (n >= 20) { throw new IOException("CRLF not found before max message length: " + this.maxMessageSize); } } byte[] assembledData = new byte[n-1]; System.arraycopy(buffer, 0, assembledData, 0, n-1); String header = new String(assembledData); if(log.isDebugEnabled()) { log.debug("Mind rpc header:" + header); } String[] respBytes = header.split(" "); ret[0] = Integer.parseInt(respBytes[1]); ret[1] = Integer.parseInt(respBytes[2]); if(log.isDebugEnabled()) { log.debug("Mind rpc parsed sizes: head=" + ret[0] + " content=" + ret[1]); } return ret; } /** * Reads a message headers from inputstream with a given length. * * @param inputStream the input stream * @param length how much to read from a stream * @return Map of headers * @throws IOException if error occured */ protected Map<String, String> readHeaders(InputStream inputStream, int length) throws IOException { Map<String, String> map = new HashMap<String, String>(); byte[] bytes = readBytes(inputStream, length); BufferedReader reader = new BufferedReader(new StringReader(new String(bytes))); String line; while ((line = reader.readLine()) != null) { if(log.isDebugEnabled()) { log.debug("deserialize header: " + line); } String[] split = line.split(":"); if(split != null & split.length == 2) { map.put(split[0], split[1]); } } return map; } /** * Helper method to read bytes from a stream. * * @param inputStream the input stream * @param length how much to read * @return bytes to read * @throws IOException if error occured */ protected byte[] readBytes(InputStream inputStream, int length) throws IOException { byte[] buffer = new byte[length]; int lengthRead = 0; while (lengthRead < length) { int len; len = inputStream.read(buffer, lengthRead, length - lengthRead); if (len < 0) { throw new IOException("Stream closed after " + lengthRead + " of " + length); } lengthRead += len; } return buffer; } protected void checkClosure(int bite) throws IOException { if (bite < 0) { if(log.isDebugEnabled()) { log.debug("Socket closed during message assembly"); } throw new IOException("Socket closed during message assembly"); } } }