/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package gobblin.tunnel; import java.io.IOException; import java.io.OutputStream; import java.nio.ByteBuffer; import java.nio.channels.CancelledKeyException; import java.nio.channels.SelectionKey; import java.nio.channels.Selector; import java.nio.channels.SocketChannel; import java.util.concurrent.Callable; import static java.nio.channels.SelectionKey.OP_READ; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * This class handles the relaying of data back and forth between the Client-to-Tunnel and Tunnel-to-Proxy * socket connections. This class is not thread safe. */ class ReadWriteHandler implements Callable<HandlerState> { static final Logger LOG = LoggerFactory.getLogger(Tunnel.class); private final SocketChannel proxy; private final SocketChannel client; private final Selector selector; private final ByteBuffer buffer = ByteBuffer.allocate(1000000); private HandlerState state = HandlerState.READING; ReadWriteHandler(SocketChannel proxy, ByteBuffer mixedServerResponseBuffer, SocketChannel client, Selector selector) throws IOException { this.proxy = proxy; this.client = client; this.selector = selector; // drain response that is not part of proxy's 200 OK and is part of data pushed from server, and push to client if (mixedServerResponseBuffer.limit() > mixedServerResponseBuffer.position()) { this.client.configureBlocking(true); OutputStream clientOut = this.client.socket().getOutputStream(); clientOut.write(mixedServerResponseBuffer.array(), mixedServerResponseBuffer.position(), mixedServerResponseBuffer.limit() - mixedServerResponseBuffer.position()); clientOut.flush(); } this.proxy.configureBlocking(false); this.client.configureBlocking(false); this.client.register(this.selector, OP_READ, this); this.proxy.register(this.selector, OP_READ, this); } @Override public HandlerState call() throws Exception { try { switch (this.state) { case READING: read(); break; case WRITING: write(); break; default: throw new IllegalStateException("ReadWriteHandler should never be in state " + this.state); } } catch (CancelledKeyException e) { LOG.warn("Encountered canceled key while " + this.state, e); } catch (IOException ioe) { closeChannels(); throw new IOException(String.format("Could not read/write between %s and %s", this.proxy, this.client), ioe); } catch (Exception e) { LOG.error("Unexpected exception", e); try { closeChannels(); } finally { throw e; } } return this.state; } private void write() throws IOException { SelectionKey proxyKey = this.proxy.keyFor(this.selector); SelectionKey clientKey = this.client.keyFor(this.selector); SocketChannel writeChannel = null; SocketChannel readChannel = null; SelectionKey writeKey = null; if (this.selector.selectedKeys().contains(proxyKey) && proxyKey.isValid() && proxyKey.isWritable()) { writeChannel = this.proxy; readChannel = this.client; writeKey = proxyKey; } else if (this.selector.selectedKeys().contains(clientKey) && clientKey.isValid() && clientKey.isWritable()) { writeChannel = this.client; readChannel = this.proxy; writeKey = clientKey; } if (writeKey != null) { int lastWrite, totalWrite = 0; this.buffer.flip(); int available = this.buffer.remaining(); while ((lastWrite = writeChannel.write(this.buffer)) > 0) { totalWrite += lastWrite; } LOG.debug("{} bytes written to {}", totalWrite, writeChannel == this.proxy ? "proxy" : "client"); if (totalWrite == available) { this.buffer.clear(); if(readChannel.isOpen()) { readChannel.register(this.selector, SelectionKey.OP_READ, this); writeChannel.register(this.selector, SelectionKey.OP_READ, this); } else{ writeChannel.close(); } this.state = HandlerState.READING; } else { this.buffer.compact(); } if (lastWrite == -1) { closeChannels(); } } } private void read() throws IOException { SelectionKey proxyKey = this.proxy.keyFor(this.selector); SelectionKey clientKey = this.client.keyFor(this.selector); SocketChannel readChannel = null; SocketChannel writeChannel = null; SelectionKey readKey = null; if (this.selector.selectedKeys().contains(proxyKey) && proxyKey.isReadable()) { readChannel = this.proxy; writeChannel = this.client; readKey = proxyKey; } else if (this.selector.selectedKeys().contains(clientKey) && clientKey.isReadable()) { readChannel = this.client; writeChannel = this.proxy; readKey = clientKey; } if (readKey != null) { int lastRead, totalRead = 0; while ((lastRead = readChannel.read(this.buffer)) > 0) { totalRead += lastRead; } LOG.debug("{} bytes read from {}", totalRead, readChannel == this.proxy ? "proxy":"client"); if (totalRead > 0) { readKey.cancel(); writeChannel.register(this.selector, SelectionKey.OP_WRITE, this); this.state = HandlerState.WRITING; } if (lastRead == -1) { readChannel.close(); } } } private void closeChannels() { if (this.proxy.isOpen()) { try { this.proxy.close(); } catch (IOException log) { LOG.warn("Failed to close proxy channel {}", this.proxy,log); } } if (this.client.isOpen()) { try { this.client.close(); } catch (IOException log) { LOG.warn("Failed to close client channel {}", this.client,log); } } } }