/*
* Copyright (C) 2015 Actor LLC. <https://actor.im>
*/
package im.actor.runtime.generic.network;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import javax.net.ssl.SSLSocketFactory;
import im.actor.runtime.Log;
import im.actor.runtime.bser.DataInput;
import im.actor.runtime.bser.DataOutput;
import im.actor.runtime.mtproto.AsyncConnection;
import im.actor.runtime.mtproto.AsyncConnectionInterface;
import im.actor.runtime.mtproto.ConnectionEndpoint;
import im.actor.runtime.mtproto.ManagedConnection;
// Disabling Bounds checks for speeding up calculations
/*-[
#define J2OBJC_DISABLE_ARRAY_BOUND_CHECKS 1
]-*/
/**
 * {@link AsyncConnection} implementation on top of a blocking {@link Socket}
 * (optionally wrapped into TLS).
 * <p>
 * Connecting is performed on a dedicated single-thread executor; after the socket
 * is established one {@link ReaderThread} and one {@link WriterThread} are started.
 * State transitions ({@link #onSocketCreated(Socket)}, {@link #onRawReceived(byte[])},
 * {@link #crashConnection()}) are serialized with the instance monitor. The flags and
 * the writer reference are additionally {@code volatile} because the worker loops and
 * {@link #doSend(byte[])} read them without holding the monitor.
 */
public class AsyncTcpConnection extends AsyncConnection {

    /** Number of bytes in a package header: int index + byte header + int size. */
    private static final int HEADER_SIZE = 9;
    /** Largest value accepted for the size field of an incoming package. */
    private static final int MAX_PACKAGE_SIZE = 1024 * 1024;
    /**
     * Bytes read in addition to the declared size for every package.
     * NOTE(review): presumably a trailing checksum - confirm against the protocol.
     */
    private static final int BODY_EXTRA_BYTES = 4;

    private final ExecutorService connectExecutor = Executors.newSingleThreadExecutor();
    private final String TAG;

    private Socket socket;
    private InputStream inputStream;
    private OutputStream outputStream;

    // Read by doSend() without the monitor, hence volatile.
    private volatile WriterThread writerThread;
    private ReaderThread readerThread;

    // Read by the worker loops without the monitor, hence volatile.
    private volatile boolean isConnected = false;
    private volatile boolean isClosed = false;

    /**
     * @param id         numeric id, used only to build the log tag and thread names
     * @param endpoint   endpoint to connect to
     * @param connection callback interface passed to the super class
     */
    public AsyncTcpConnection(int id, ConnectionEndpoint endpoint, AsyncConnectionInterface connection) {
        super(endpoint, connection);
        this.TAG = "ConnectionTcp#" + id;
    }

    /**
     * Schedules the blocking connect on the connect executor. If the connection
     * was already closed (executor shut down) the request is ignored.
     */
    @Override
    public void doConnect() {
        try {
            connectExecutor.submit((Runnable) this::connectBlocking);
        } catch (java.util.concurrent.RejectedExecutionException e) {
            // The executor is shut down by crashConnection(); a socket created now
            // would be disposed immediately anyway.
            Log.w(TAG, "Connect requested after close: ignoring");
        }
    }

    /**
     * Queues data for sending. Data is dropped (with a warning) when there is no
     * active writer, i.e. before the socket is attached or after the connection closed.
     *
     * @param data raw bytes to write to the socket
     */
    @Override
    public void doSend(byte[] data) {
        WriterThread writer = writerThread;
        if (writer == null) {
            Log.w(TAG, "doSend: Not connected, dropping package");
            return;
        }
        writer.pushPackage(data);
    }

    /** Closes the connection and releases all resources. */
    @Override
    public void doClose() {
        crashConnection();
    }

    /**
     * Performs the blocking connect: first to the known IP (if any), then to the
     * host name. Any failure ends in {@link #crashConnection()}.
     */
    private void connectBlocking() {
        try {
            ConnectionEndpoint target = getEndpoint();
            Socket created = null;

            // Trying to connect to known ip first
            if (target.getKnownIp() != null) {
                try {
                    Log.d(TAG, "Trying to connect to " + target.getHost() + " with Known IP " + target.getKnownIp());
                    created = openSocket(new InetSocketAddress(target.getKnownIp(), target.getPort()), target);
                    Log.d(TAG, "Connection successful");
                } catch (Throwable e) {
                    // Only a failed socket creation falls back to DNS resolving
                    e.printStackTrace();
                }
            }

            // Trying to connect with DNS resolving
            if (created == null) {
                Log.d(TAG, "Trying to connect to " + target.getHost());
                created = openSocket(new InetSocketAddress(target.getHost(), target.getPort()), target);
            }

            try {
                onSocketCreated(created);
            } catch (Throwable e) {
                // Socket was rejected (e.g. closed externally meanwhile): do not leak it
                closeQuietly(created);
                throw e;
            }
        } catch (Throwable e) {
            e.printStackTrace();
            crashConnection();
        }
    }

    /**
     * Creates, configures and connects a socket, wrapping it into TLS when the
     * endpoint type is {@link ConnectionEndpoint#TYPE_TCP_TLS}. The socket is
     * closed before any failure is propagated.
     *
     * @param address  address to connect to
     * @param endpoint endpoint supplying type, host and port for the TLS wrap
     * @return connected socket with initialized streams
     * @throws IOException if configuring, connecting or wrapping fails
     */
    private Socket openSocket(InetSocketAddress address, ConnectionEndpoint endpoint) throws IOException {
        Socket current = new Socket();
        try {
            // Configure socket
            current.setKeepAlive(false);
            current.setTcpNoDelay(true);

            // Connecting
            current.connect(address, ManagedConnection.CONNECTION_TIMEOUT);

            // Converting SSL socket
            // NOTE(review): no endpoint identification algorithm is configured here,
            // so host name verification depends on the platform defaults - confirm.
            if (endpoint.getType() == ConnectionEndpoint.TYPE_TCP_TLS) {
                SSLSocketFactory socketFactory = (SSLSocketFactory) SSLSocketFactory.getDefault();
                current = socketFactory.createSocket(current,
                        endpoint.getHost(), endpoint.getPort(), true);
            }

            // Init streams
            current.getInputStream();
            current.getOutputStream();
            return current;
        } catch (Throwable e) {
            closeQuietly(current);
            throw e;
        }
    }

    /** Closes the socket ignoring errors; {@code null} is accepted. */
    private static void closeQuietly(Socket s) {
        if (s == null) {
            return;
        }
        try {
            s.close();
        } catch (IOException ignored) {
            // Best-effort cleanup of a socket that is being discarded
        }
    }

    /**
     * Attaches a connected socket, starts the reader and writer threads and
     * reports the connection as established.
     *
     * @param socket connected socket
     * @throws IOException if the connection was closed before the socket was created
     *                     or the streams can not be obtained
     */
    private synchronized void onSocketCreated(Socket socket) throws IOException {
        if (isClosed) {
            Log.w(TAG, "Socket created after external close: disposing");
            throw new IOException("Socket created after external close: disposing");
        }
        this.socket = socket;
        this.inputStream = socket.getInputStream();
        this.outputStream = socket.getOutputStream();
        this.isClosed = false;
        this.isConnected = true;

        // Threads keep their own stream reference so that crashConnection()
        // nulling the fields can not cause a NullPointerException in them.
        this.readerThread = new ReaderThread(this.inputStream);
        this.readerThread.start();
        this.writerThread = new WriterThread(this.outputStream);
        this.writerThread.start();
        onConnected();
    }

    /**
     * Forwards a complete raw package to {@code onReceived} unless the connection
     * is no longer connected.
     */
    private synchronized void onRawReceived(byte[] data) throws IOException {
        if (!isConnected) {
            Log.d(TAG, "onRawReceived: Not connected");
            return;
        }
        onReceived(data);
    }

    /**
     * Closes the connection once: interrupts the worker threads, closes socket and
     * streams, stops the connect executor and calls {@code onClosed}. Subsequent
     * calls only log.
     */
    private synchronized void crashConnection() {
        Log.d(TAG, "Crashing Connection");
        if (isClosed) {
            return;
        }
        isClosed = true;
        isConnected = false;

        // No further connects are possible; lets the executor thread terminate.
        // A connect task that is currently running is not interrupted by shutdown().
        connectExecutor.shutdown();

        WriterThread writer = writerThread;
        if (writer != null) {
            writer.interrupt();
        }
        if (readerThread != null) {
            readerThread.interrupt();
        }
        writerThread = null;
        readerThread = null;

        if (socket != null) {
            try {
                socket.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        if (inputStream != null) {
            try {
                inputStream.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        if (outputStream != null) {
            try {
                outputStream.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        socket = null;
        inputStream = null;
        outputStream = null;

        onClosed();
    }

    /** Thread that drains the outgoing queue into the socket output stream. */
    private class WriterThread extends Thread {

        private final ConcurrentLinkedQueue<byte[]> packages = new ConcurrentLinkedQueue<>();
        private final OutputStream out;

        /**
         * @param out stream the queued packages are written to
         */
        WriterThread(OutputStream out) {
            this.out = out;
            setName(TAG + "#Writer");
        }

        /**
         * Send package to connection
         *
         * @param p package
         */
        public void pushPackage(final byte[] p) {
            packages.add(p);
            synchronized (packages) {
                packages.notifyAll();
            }
        }

        @Override
        public void run() {
            try {
                while (isConnected) {
                    // Polling of package from queue
                    byte[] p;
                    synchronized (packages) {
                        p = packages.poll();
                        if (p == null) {
                            try {
                                packages.wait();
                            } catch (final InterruptedException e) {
                                // Interrupted by crashConnection(): stop the thread
                                return;
                            }
                            p = packages.poll();
                        }
                    }
                    if (p == null) {
                        // Woken up without data: re-check the connection state
                        continue;
                    }
                    out.write(p);
                    out.flush();
                }
            } catch (IOException e) {
                e.printStackTrace();
                crashConnection();
            }
        }
    }

    /** Thread that reads framed packages from the socket input stream. */
    private class ReaderThread extends Thread {

        private final InputStream in;

        /**
         * @param in stream the packages are read from
         */
        private ReaderThread(InputStream in) {
            this.in = in;
            setName(TAG + "#Reader");
        }

        @Override
        public void run() {
            try {
                while (isConnected) {
                    // Reading package headers
                    byte[] header = readBytes(HEADER_SIZE);
                    DataInput dataInput = new DataInput(header);
                    // Package index and header byte are not inspected here, but
                    // must be consumed to reach the size field
                    dataInput.readInt();
                    dataInput.readByte();
                    int size = dataInput.readInt();
                    // A negative size would otherwise end in NegativeArraySizeException,
                    // killing this thread without closing the connection
                    if (size < 0 || size > MAX_PACKAGE_SIZE) {
                        throw new IOException("Incorrect size");
                    }

                    // Reading package body
                    byte[] body = readBytes(size + BODY_EXTRA_BYTES);

                    // Header and body are handed over as one contiguous array
                    DataOutput dataOutput = new DataOutput();
                    dataOutput.writeBytes(header);
                    dataOutput.writeBytes(body);
                    onRawReceived(dataOutput.toByteArray());
                }
            } catch (IOException e) {
                e.printStackTrace();
                crashConnection();
            }
        }

        /**
         * Reads exactly {@code count} bytes from the stream.
         *
         * @param count number of bytes to read
         * @return array of length {@code count}
         * @throws IOException if the stream ends or fails before all bytes are read
         */
        private byte[] readBytes(int count) throws IOException {
            byte[] res = new byte[count];
            int offset = 0;
            while (offset < res.length) {
                int read = in.read(res, offset, res.length - offset);
                if (read > 0) {
                    offset += read;
                } else if (read < 0) {
                    throw new IOException("Unexpected end of stream");
                } else {
                    Thread.yield();
                }
            }
            return res;
        }
    }
}