package chapter3.recipe2; import static io.netty.handler.codec.http.HttpHeaders.Names.CONTENT_TYPE; import static io.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST; import static io.netty.handler.codec.http.HttpResponseStatus.INTERNAL_SERVER_ERROR; import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandler.Sharable; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.http.Cookie; import io.netty.handler.codec.http.CookieDecoder; import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.DefaultHttpResponse; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.multipart.Attribute; import io.netty.handler.codec.http.multipart.DefaultHttpDataFactory; import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder; import io.netty.handler.codec.http.multipart.InterfaceHttpData; import io.netty.handler.codec.http.multipart.InterfaceHttpData.HttpDataType; import io.netty.handler.stream.ChunkedStream; import io.netty.util.CharsetUtil; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.io.UnsupportedEncodingException; import java.nio.charset.Charset; import java.util.List; import java.util.Map.Entry; import java.util.Set; import javax.servlet.Servlet; import javax.servlet.ServletContext; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UriUtils; @Sharable public class ServletNettyChannelHandler extends SimpleChannelInboundHandler<FullHttpRequest> { private static final String UTF_8 = "UTF-8"; private final Servlet servlet; private final ServletContext servletContext; public ServletNettyChannelHandler(Servlet servlet) { this.servlet = servlet; this.servletContext = servlet.getServletConfig().getServletContext(); } private MockHttpServletRequest createHttpServletRequest(FullHttpRequest fullHttpReq) { UriComponents uriComponents = UriComponentsBuilder.fromUriString(fullHttpReq.getUri()).build(); MockHttpServletRequest servletRequest = new MockHttpServletRequest(this.servletContext); servletRequest.setRequestURI(uriComponents.getPath()); servletRequest.setPathInfo(uriComponents.getPath()); servletRequest.setMethod(fullHttpReq.getMethod().name()); servletRequest.setCharacterEncoding(UTF_8); if (uriComponents.getScheme() != null) { servletRequest.setScheme(uriComponents.getScheme()); } if (uriComponents.getHost() != null) { servletRequest.setServerName(uriComponents.getHost()); } if (uriComponents.getPort() != -1) { servletRequest.setServerPort(uriComponents.getPort()); } copyHttpHeaders(fullHttpReq, servletRequest); copyHttpBodyData(fullHttpReq, servletRequest); copyQueryParams(uriComponents, servletRequest); copyToServletCookie(fullHttpReq, servletRequest); return servletRequest; } void copyToServletCookie(FullHttpRequest fullHttpReq, MockHttpServletRequest servletRequest){ String cookieString = fullHttpReq.headers().get(HttpHeaders.Names.COOKIE); if (cookieString != null) { Set<Cookie> cookies = CookieDecoder.decode(cookieString); if (!cookies.isEmpty()) { // Reset the cookies if necessary. javax.servlet.http.Cookie[] sCookies = new javax.servlet.http.Cookie[cookies.size()]; int i = 0; for (Cookie cookie: cookies) { javax.servlet.http.Cookie sCookie = new javax.servlet.http.Cookie(cookie.getName(), cookie.getValue()); sCookie.setPath(cookie.getPath()); sCookie.setMaxAge((int) cookie.getMaxAge()); sCookies[i++] = sCookie; } servletRequest.setCookies(sCookies); } } else { servletRequest.setCookies( new javax.servlet.http.Cookie[0]); } } void copyQueryParams(UriComponents uriComponents, MockHttpServletRequest servletRequest){ try { if (uriComponents.getQuery() != null) { String query = UriUtils.decode(uriComponents.getQuery(), UTF_8); servletRequest.setQueryString(query); } for (Entry<String, List<String>> entry : uriComponents.getQueryParams().entrySet()) { for (String value : entry.getValue()) { servletRequest.addParameter( UriUtils.decode(entry.getKey(), UTF_8), UriUtils.decode(value, UTF_8)); } } } catch (UnsupportedEncodingException ex) { // shouldn't happen } } void copyHttpHeaders(FullHttpRequest fullHttpReq, MockHttpServletRequest servletRequest){ HttpHeaders headers = fullHttpReq.headers(); for (String name : headers.names()) { servletRequest.addHeader(name, headers.get(name)); } servletRequest.setContentType(headers.get(HttpHeaders.Names.CONTENT_TYPE)); } void copyHttpBodyData(FullHttpRequest fullHttpReq, MockHttpServletRequest servletRequest){ ByteBuf bbContent = fullHttpReq.content(); if(bbContent.hasArray()) { servletRequest.setContent(bbContent.array()); } else { if(fullHttpReq.getMethod().equals(HttpMethod.POST)){ HttpPostRequestDecoder decoderPostData = new HttpPostRequestDecoder(new DefaultHttpDataFactory(false), fullHttpReq); String bbContentStr = bbContent.toString(Charset.forName(UTF_8)); servletRequest.setContent(bbContentStr.getBytes()); if( ! decoderPostData.isMultipart() ){ List<InterfaceHttpData> postDatas = decoderPostData.getBodyHttpDatas(); for (InterfaceHttpData postData : postDatas) { if (postData.getHttpDataType() == HttpDataType.Attribute) { Attribute attribute = (Attribute) postData; try { servletRequest.addParameter(attribute.getName(),attribute.getValue()); } catch (IOException e) { e.printStackTrace(); } } } } } } } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { cause.printStackTrace(); System.err.println(cause.getMessage()); if (ctx.channel().isActive()) { sendError(ctx, INTERNAL_SERVER_ERROR); } } private static void sendError(ChannelHandlerContext ctx, HttpResponseStatus status) { ByteBuf content = Unpooled.copiedBuffer( "Failure: " + status.toString() + "\r\n", CharsetUtil.UTF_8); FullHttpResponse fullHttpResponse = new DefaultFullHttpResponse( HTTP_1_1, status, content ); fullHttpResponse.headers().add(CONTENT_TYPE, "text/plain; charset=UTF-8"); // Close the connection as soon as the error message is sent. ctx.write(fullHttpResponse).addListener(ChannelFutureListener.CLOSE); } @Override protected void channelRead0(ChannelHandlerContext channelHandlerContext, FullHttpRequest fullHttpRequest) throws Exception { if (!fullHttpRequest.getDecoderResult().isSuccess()) { sendError(channelHandlerContext, BAD_REQUEST); return; } MockHttpServletRequest servletRequest = createHttpServletRequest(fullHttpRequest); MockHttpServletResponse servletResponse = new MockHttpServletResponse(); this.servlet.service(servletRequest, servletResponse); HttpResponseStatus status = HttpResponseStatus.valueOf(servletResponse.getStatus()); HttpResponse response = new DefaultHttpResponse(HTTP_1_1, status); for (String name : servletResponse.getHeaderNames()) { for (String value : servletResponse.getHeaders(name)) { response.headers().add(name, value); } } // Write the initial line and the header. channelHandlerContext.write(response); InputStream contentStream = new ByteArrayInputStream(servletResponse.getContentAsByteArray()); // Write the content and flush it. ChannelFuture writeFuture = channelHandlerContext.writeAndFlush(new ChunkedStream(contentStream)); writeFuture.addListener(ChannelFutureListener.CLOSE); } }