/*
* Copyright 2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.gradle.internal.remote.internal.inet;
import com.google.common.base.Objects;
import org.gradle.internal.UncheckedException;
import org.gradle.internal.concurrent.CompositeStoppable;
import org.gradle.internal.remote.internal.RecoverableMessageIOException;
import org.gradle.internal.serialize.FlushableEncoder;
import org.gradle.internal.serialize.ObjectReader;
import org.gradle.internal.serialize.ObjectWriter;
import org.gradle.internal.serialize.StatefulSerializer;
import org.gradle.internal.remote.internal.MessageIOException;
import org.gradle.internal.remote.internal.MessageSerializer;
import org.gradle.internal.remote.internal.RemoteConnection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.*;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedSelectorException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
public class SocketConnection<T> implements RemoteConnection<T> {
private static final Logger LOGGER = LoggerFactory.getLogger(SocketConnection.class);
private final SocketChannel socket;
private final SocketInetAddress localAddress;
private final SocketInetAddress remoteAddress;
private final ObjectWriter<T> objectWriter;
private final ObjectReader<T> objectReader;
private final InputStream instr;
private final OutputStream outstr;
private final FlushableEncoder encoder;
public SocketConnection(SocketChannel socket, MessageSerializer streamSerializer, StatefulSerializer<T> messageSerializer) {
this.socket = socket;
try {
// NOTE: we use non-blocking IO as there is no reliable way when using blocking IO to shutdown reads while
// keeping writes active. For example, Socket.shutdownInput() does not work on Windows.
socket.configureBlocking(false);
outstr = new SocketOutputStream(socket);
instr = new SocketInputStream(socket);
} catch (IOException e) {
throw UncheckedException.throwAsUncheckedException(e);
}
InetSocketAddress localSocketAddress = (InetSocketAddress) socket.socket().getLocalSocketAddress();
localAddress = new SocketInetAddress(localSocketAddress.getAddress(), localSocketAddress.getPort());
InetSocketAddress remoteSocketAddress = (InetSocketAddress) socket.socket().getRemoteSocketAddress();
remoteAddress = new SocketInetAddress(remoteSocketAddress.getAddress(), remoteSocketAddress.getPort());
objectReader = messageSerializer.newReader(streamSerializer.newDecoder(instr));
encoder = streamSerializer.newEncoder(outstr);
objectWriter = messageSerializer.newWriter(encoder);
}
@Override
public String toString() {
return "socket connection from " + localAddress + " to " + remoteAddress;
}
public T receive() throws MessageIOException {
try {
return objectReader.read();
} catch (EOFException e) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Discarding EOFException: {}", e.toString());
}
return null;
} catch (ObjectStreamException e) {
throw new RecoverableMessageIOException(String.format("Could not read message from '%s'.", remoteAddress), e);
} catch (ClassNotFoundException e) {
throw new RecoverableMessageIOException(String.format("Could not read message from '%s'.", remoteAddress), e);
} catch (IOException e) {
throw new RecoverableMessageIOException(String.format("Could not read message from '%s'.", remoteAddress), e);
} catch (Exception e) {
throw new MessageIOException(String.format("Could not read message from '%s'.", remoteAddress), e);
}
}
private static boolean isEndOfStream(Exception e) {
if (e instanceof EOFException) {
return true;
}
if (e instanceof IOException) {
if (Objects.equal(e.getMessage(), "An existing connection was forcibly closed by the remote host")) {
return true;
}
if (Objects.equal(e.getMessage(), "An established connection was aborted by the software in your host machine")) {
return true;
}
if (Objects.equal(e.getMessage(), "Connection reset by peer")) {
return true;
}
}
return false;
}
public void dispatch(T message) throws MessageIOException {
try {
objectWriter.write(message);
} catch (ObjectStreamException e) {
throw new RecoverableMessageIOException(String.format("Could not write message %s to '%s'.", message, remoteAddress), e);
} catch (ClassNotFoundException e) {
throw new RecoverableMessageIOException(String.format("Could not write message %s to '%s'.", message, remoteAddress), e);
} catch (IOException e) {
throw new RecoverableMessageIOException(String.format("Could not write message %s to '%s'.", message, remoteAddress), e);
} catch (Exception e) {
throw new MessageIOException(String.format("Could not write message %s to '%s'.", message, remoteAddress), e);
}
}
@Override
public void flush() throws MessageIOException {
try {
encoder.flush();
outstr.flush();
} catch (Exception e) {
throw new MessageIOException(String.format("Could not write '%s'.", remoteAddress), e);
}
}
public void stop() {
CompositeStoppable.stoppable(new Closeable() {
@Override
public void close() throws IOException {
flush();
}
}, instr, outstr, socket).stop();
}
private static class SocketInputStream extends InputStream {
private final Selector selector;
private final ByteBuffer buffer;
private final SocketChannel socket;
private final byte[] readBuffer = new byte[1];
public SocketInputStream(SocketChannel socket) throws IOException {
this.socket = socket;
selector = Selector.open();
socket.register(selector, SelectionKey.OP_READ);
buffer = ByteBuffer.allocateDirect(4096);
buffer.limit(0);
}
@Override
public int read() throws IOException {
int nread = read(readBuffer, 0, 1);
if (nread <= 0) {
return nread;
}
return readBuffer[0];
}
@Override
public int read(byte[] dest, int offset, int max) throws IOException {
if (max == 0) {
return 0;
}
if (buffer.remaining() == 0) {
try {
selector.select();
} catch (ClosedSelectorException e) {
return -1;
}
if (!selector.isOpen()) {
return -1;
}
buffer.clear();
int nread;
try {
nread = socket.read(buffer);
} catch (IOException e) {
if (isEndOfStream(e)) {
buffer.position(0);
buffer.limit(0);
return -1;
}
throw e;
}
buffer.flip();
if (nread < 0) {
return -1;
}
}
int count = Math.min(buffer.remaining(), max);
buffer.get(dest, offset, count);
return count;
}
@Override
public void close() throws IOException {
selector.close();
}
}
private static class SocketOutputStream extends OutputStream {
private static final int RETRIES_WHEN_BUFFER_FULL = 2;
private Selector selector;
private final SocketChannel socket;
private final ByteBuffer buffer;
private final byte[] writeBuffer = new byte[1];
public SocketOutputStream(SocketChannel socket) throws IOException {
this.socket = socket;
buffer = ByteBuffer.allocateDirect(32 * 1024);
}
@Override
public void write(int b) throws IOException {
writeBuffer[0] = (byte) b;
write(writeBuffer);
}
@Override
public void write(byte[] src, int offset, int max) throws IOException {
int remaining = max;
int currentPos = offset;
while (remaining > 0) {
int count = Math.min(remaining, buffer.remaining());
if (count > 0) {
buffer.put(src, currentPos, count);
remaining -= count;
currentPos += count;
}
while (buffer.remaining() == 0) {
writeBufferToChannel();
}
}
}
@Override
public void flush() throws IOException {
while (buffer.position() > 0) {
writeBufferToChannel();
}
}
private void writeBufferToChannel() throws IOException {
buffer.flip();
int count = writeWithNonBlockingRetry();
if (count == 0) {
// buffer was still full after non-blocking retries, now block
waitForWriteBufferToDrain();
}
buffer.compact();
}
private int writeWithNonBlockingRetry() throws IOException {
int count = 0;
int retryCount = 0;
while (count == 0 && retryCount++ < RETRIES_WHEN_BUFFER_FULL) {
count = socket.write(buffer);
if (count < 0) {
throw new EOFException();
} else if (count == 0) {
// buffer was full, just call Thread.yield
Thread.yield();
}
}
return count;
}
private void waitForWriteBufferToDrain() throws IOException {
if (selector == null) {
selector = Selector.open();
}
SelectionKey key = socket.register(selector, SelectionKey.OP_WRITE);
// block until ready for write operations
selector.select();
// cancel OP_WRITE selection
key.cancel();
// complete cancelling key
selector.selectNow();
}
@Override
public void close() throws IOException {
if (selector != null) {
selector.close();
selector = null;
}
}
}
}