// // ======================================================================== // Copyright (c) 1995-2017 Mort Bay Consulting Pty. Ltd. // ------------------------------------------------------------------------ // All rights reserved. This program and the accompanying materials // are made available under the terms of the Eclipse Public License v1.0 // and Apache License v2.0 which accompanies this distribution. // // The Eclipse Public License is available at // http://www.eclipse.org/legal/epl-v10.html // // The Apache License v2.0 is available at // http://www.opensource.org/licenses/apache2.0.php // // You may elect to redistribute this code under either of these licenses. // ======================================================================== // package org.eclipse.jetty.websocket.common.test; import static org.hamcrest.Matchers.*; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.io.OutputStream; import java.net.Socket; import java.net.SocketException; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.regex.Matcher; import java.util.regex.Pattern; import org.eclipse.jetty.io.ByteBufferPool; import org.eclipse.jetty.io.MappedByteBufferPool; import org.eclipse.jetty.util.BufferUtil; import org.eclipse.jetty.util.IO; import org.eclipse.jetty.util.log.Log; import org.eclipse.jetty.util.log.Logger; import org.eclipse.jetty.websocket.api.BatchMode; import org.eclipse.jetty.websocket.api.WebSocketPolicy; import org.eclipse.jetty.websocket.api.WriteCallback; import org.eclipse.jetty.websocket.api.extensions.ExtensionConfig; import org.eclipse.jetty.websocket.api.extensions.Frame; import org.eclipse.jetty.websocket.api.extensions.IncomingFrames; import org.eclipse.jetty.websocket.api.extensions.OutgoingFrames; import org.eclipse.jetty.websocket.api.extensions.Frame.Type; import org.eclipse.jetty.websocket.common.AcceptHash; import org.eclipse.jetty.websocket.common.CloseInfo; import org.eclipse.jetty.websocket.common.Generator; import org.eclipse.jetty.websocket.common.OpCode; import org.eclipse.jetty.websocket.common.Parser; import org.eclipse.jetty.websocket.common.WebSocketFrame; import org.eclipse.jetty.websocket.common.extensions.ExtensionStack; import org.eclipse.jetty.websocket.common.extensions.WebSocketExtensionFactory; import org.eclipse.jetty.websocket.common.frames.CloseFrame; import org.eclipse.jetty.websocket.common.scopes.SimpleContainerScope; import org.junit.Assert; public class BlockheadServerConnection implements IncomingFrames, OutgoingFrames, Runnable, IBlockheadServerConnection { private static final Logger LOG = Log.getLogger(BlockheadServerConnection.class); private final int BUFFER_SIZE = 8192; private final Socket socket; private final ByteBufferPool bufferPool; private final WebSocketPolicy policy; private final IncomingFramesCapture incomingFrames; private final Parser parser; private final Generator generator; private final AtomicInteger parseCount; private final WebSocketExtensionFactory extensionRegistry; private final AtomicBoolean echoing = new AtomicBoolean(false); private Thread echoThread; /** Set to true to disable timeouts (for debugging reasons) */ private boolean debug = false; private OutputStream out; private InputStream in; private Map<String, String> extraResponseHeaders = new HashMap<>(); private OutgoingFrames outgoing = this; public BlockheadServerConnection(Socket socket) { this.socket = socket; this.incomingFrames = new IncomingFramesCapture(); this.policy = WebSocketPolicy.newServerPolicy(); this.policy.setMaxBinaryMessageSize(100000); this.policy.setMaxTextMessageSize(100000); // This is a blockhead server connection, no point tracking leaks on this object. this.bufferPool = new MappedByteBufferPool(BUFFER_SIZE); this.parser = new Parser(policy,bufferPool); this.parseCount = new AtomicInteger(0); this.generator = new Generator(policy,bufferPool,false); this.extensionRegistry = new WebSocketExtensionFactory(new SimpleContainerScope(policy,bufferPool)); } /** * Add an extra header for the upgrade response (from the server). No extra work is done to ensure the key and value are sane for http. * @param rawkey the raw key * @param rawvalue the raw value */ public void addResponseHeader(String rawkey, String rawvalue) { extraResponseHeaders.put(rawkey,rawvalue); } /* (non-Javadoc) * @see org.eclipse.jetty.websocket.common.test.IBlockheadServerConnection#close() */ @Override public void close() throws IOException { write(new CloseFrame()); flush(); } /* (non-Javadoc) * @see org.eclipse.jetty.websocket.common.test.IBlockheadServerConnection#close(int) */ @Override public void close(int statusCode) throws IOException { CloseInfo close = new CloseInfo(statusCode); write(close.asFrame()); flush(); } public void disconnect() { LOG.debug("disconnect"); IO.close(in); IO.close(out); if (socket != null) { try { socket.close(); } catch (IOException ignore) { /* ignore */ } } } public void echoMessage(int expectedFrames, int timeoutDuration, TimeUnit timeoutUnit) throws IOException, TimeoutException { LOG.debug("Echo Frames [expecting {}]",expectedFrames); IncomingFramesCapture cap = readFrames(expectedFrames,timeoutDuration,timeoutUnit); // now echo them back. for (Frame frame : cap.getFrames()) { write(WebSocketFrame.copy(frame).setMasked(false)); } } public void flush() throws IOException { getOutputStream().flush(); } public ByteBufferPool getBufferPool() { return bufferPool; } public IncomingFramesCapture getIncomingFrames() { return incomingFrames; } public InputStream getInputStream() throws IOException { if (in == null) { in = socket.getInputStream(); } return in; } private OutputStream getOutputStream() throws IOException { if (out == null) { out = socket.getOutputStream(); } return out; } public Parser getParser() { return parser; } public WebSocketPolicy getPolicy() { return policy; } @Override public void incomingError(Throwable e) { incomingFrames.incomingError(e); } @Override public void incomingFrame(Frame frame) { LOG.debug("incoming({})",frame); int count = parseCount.incrementAndGet(); if ((count % 10) == 0) { LOG.info("Server parsed {} frames",count); } incomingFrames.incomingFrame(WebSocketFrame.copy(frame)); if (frame.getOpCode() == OpCode.CLOSE) { CloseInfo close = new CloseInfo(frame); LOG.debug("Close frame: {}",close); } Type type = frame.getType(); if (echoing.get() && (type.isData() || type.isContinuation())) { try { write(WebSocketFrame.copy(frame).setMasked(false)); } catch (IOException e) { LOG.warn(e); } } } @Override public void outgoingFrame(Frame frame, WriteCallback callback, BatchMode batchMode) { ByteBuffer headerBuf = generator.generateHeaderBytes(frame); if (LOG.isDebugEnabled()) { LOG.debug("writing out: {}",BufferUtil.toDetailString(headerBuf)); } try { BufferUtil.writeTo(headerBuf,out); if (frame.hasPayload()) BufferUtil.writeTo(frame.getPayload(),out); out.flush(); if (callback != null) { callback.writeSuccess(); } if (frame.getOpCode() == OpCode.CLOSE) { disconnect(); } } catch (Throwable t) { if (callback != null) { callback.writeFailed(t); } } } public List<ExtensionConfig> parseExtensions(List<String> requestLines) { List<ExtensionConfig> extensionConfigs = new ArrayList<>(); List<String> hits = regexFind(requestLines, "^Sec-WebSocket-Extensions: (.*)$"); for (String econf : hits) { // found extensions ExtensionConfig config = ExtensionConfig.parse(econf); extensionConfigs.add(config); } return extensionConfigs; } public String parseWebSocketKey(List<String> requestLines) { List<String> hits = regexFind(requestLines,"^Sec-WebSocket-Key: (.*)$"); if (hits.size() <= 0) { return null; } Assert.assertThat("Number of Sec-WebSocket-Key headers", hits.size(), is(1)); String key = hits.get(0); return key; } public int read(ByteBuffer buf) throws IOException { int len = 0; while ((in.available() > 0) && (buf.remaining() > 0)) { buf.put((byte)in.read()); len++; } return len; } public IncomingFramesCapture readFrames(int expectedCount, int timeoutDuration, TimeUnit timeoutUnit) throws IOException, TimeoutException { LOG.debug("Read: waiting for {} frame(s) from client",expectedCount); int startCount = incomingFrames.size(); ByteBuffer buf = bufferPool.acquire(BUFFER_SIZE,false); BufferUtil.clearToFill(buf); try { long msDur = TimeUnit.MILLISECONDS.convert(timeoutDuration,timeoutUnit); long now = System.currentTimeMillis(); long expireOn = now + msDur; LOG.debug("Now: {} - expireOn: {} ({} ms)",now,expireOn,msDur); int len = 0; while (incomingFrames.size() < (startCount + expectedCount)) { BufferUtil.clearToFill(buf); len = read(buf); if (len > 0) { LOG.debug("Read {} bytes",len); BufferUtil.flipToFlush(buf,0); parser.parse(buf); } try { TimeUnit.MILLISECONDS.sleep(20); } catch (InterruptedException gnore) { /* ignore */ } if (!debug && (System.currentTimeMillis() > expireOn)) { incomingFrames.dump(); throw new TimeoutException(String.format("Timeout reading all %d expected frames. (managed to only read %d frame(s))",expectedCount, incomingFrames.size())); } } } finally { bufferPool.release(buf); } return incomingFrames; } public String readRequest() throws IOException { LOG.debug("Reading client request"); StringBuilder request = new StringBuilder(); BufferedReader in = new BufferedReader(new InputStreamReader(getInputStream())); for (String line = in.readLine(); line != null; line = in.readLine()) { if (line.length() == 0) { break; } request.append(line).append("\r\n"); LOG.debug("read line: {}",line); } LOG.debug("Client Request:{}{}","\n",request); return request.toString(); } public List<String> readRequestLines() throws IOException { LOG.debug("Reading client request header"); List<String> lines = new ArrayList<>(); BufferedReader in = new BufferedReader(new InputStreamReader(getInputStream())); for (String line = in.readLine(); line != null; line = in.readLine()) { if (line.length() == 0) { break; } lines.add(line); } return lines; } public List<String> regexFind(List<String> lines, String pattern) { List<String> hits = new ArrayList<>(); Pattern patKey = Pattern.compile(pattern,Pattern.CASE_INSENSITIVE); Matcher mat; for (String line : lines) { mat = patKey.matcher(line); if (mat.matches()) { if (mat.groupCount() >= 1) { hits.add(mat.group(1)); } else { hits.add(mat.group(0)); } } } return hits; } public void respond(String rawstr) throws IOException { LOG.debug("respond(){}{}","\n",rawstr); getOutputStream().write(rawstr.getBytes()); flush(); } @Override public void run() { LOG.debug("Entering echo thread"); ByteBuffer buf = bufferPool.acquire(BUFFER_SIZE,false); BufferUtil.clearToFill(buf); long readBytes = 0; try { while (echoing.get()) { BufferUtil.clearToFill(buf); long len = read(buf); if (len > 0) { readBytes += len; LOG.debug("Read {} bytes",len); BufferUtil.flipToFlush(buf,0); parser.parse(buf); } try { TimeUnit.MILLISECONDS.sleep(20); } catch (InterruptedException gnore) { /* ignore */ } } } catch (IOException e) { LOG.debug("Exception during echo loop",e); } finally { LOG.debug("Read {} bytes",readBytes); bufferPool.release(buf); } } public void setSoTimeout(int ms) throws SocketException { socket.setSoTimeout(ms); } public void startEcho() { if (echoThread != null) { throw new IllegalStateException("Echo thread already declared!"); } echoThread = new Thread(this,"BlockheadServer/Echo"); echoing.set(true); echoThread.start(); } public void stopEcho() { echoing.set(false); } public List<String> upgrade() throws IOException { List<String> requestLines = readRequestLines(); List<ExtensionConfig> extensionConfigs = parseExtensions(requestLines); String key = parseWebSocketKey(requestLines); LOG.debug("Client Request Extensions: {}",extensionConfigs); LOG.debug("Client Request Key: {}",key); Assert.assertThat("Request: Sec-WebSocket-Key",key,notNullValue()); // collect extensions configured in response header ExtensionStack extensionStack = new ExtensionStack(extensionRegistry); extensionStack.negotiate(extensionConfigs); // Start with default routing extensionStack.setNextIncoming(this); extensionStack.setNextOutgoing(this); // Configure Parser / Generator extensionStack.configure(parser); extensionStack.configure(generator); // Start Stack try { extensionStack.start(); } catch (Exception e) { throw new IOException("Unable to start Extension Stack"); } // Configure Parser parser.setIncomingFramesHandler(extensionStack); // Setup Response StringBuilder resp = new StringBuilder(); resp.append("HTTP/1.1 101 Upgrade\r\n"); resp.append("Connection: upgrade\r\n"); resp.append("Content-Length: 0\r\n"); resp.append("Sec-WebSocket-Accept: "); resp.append(AcceptHash.hashKey(key)).append("\r\n"); if (extensionStack.hasNegotiatedExtensions()) { // Respond to used extensions resp.append("Sec-WebSocket-Extensions: "); boolean delim = false; for (ExtensionConfig ext : extensionStack.getNegotiatedExtensions()) { if (delim) { resp.append(", "); } resp.append(ext.getParameterizedName()); delim = true; } resp.append("\r\n"); } if (extraResponseHeaders.size() > 0) { for (Map.Entry<String, String> xheader : extraResponseHeaders.entrySet()) { resp.append(xheader.getKey()); resp.append(": "); resp.append(xheader.getValue()); resp.append("\r\n"); } } resp.append("\r\n"); // Write Response LOG.debug("Response: {}",resp.toString()); write(resp.toString().getBytes()); return requestLines; } private void write(byte[] bytes) throws IOException { getOutputStream().write(bytes); } public void write(byte[] buf, int offset, int length) throws IOException { getOutputStream().write(buf,offset,length); } /* (non-Javadoc) * @see org.eclipse.jetty.websocket.common.test.IBlockheadServerConnection#write(org.eclipse.jetty.websocket.api.extensions.Frame) */ @Override public void write(Frame frame) throws IOException { LOG.debug("write(Frame->{}) to {}",frame,outgoing); outgoing.outgoingFrame(frame,null,BatchMode.OFF); } public void write(int b) throws IOException { getOutputStream().write(b); } public void write(ByteBuffer buf) throws IOException { byte arr[] = BufferUtil.toArray(buf); if ((arr != null) && (arr.length > 0)) { getOutputStream().write(arr); } } }