package org.mariadb.jdbc.failover;
import org.mariadb.jdbc.internal.logging.Logger;
import org.mariadb.jdbc.internal.logging.LoggerFactory;
import java.io.*;
import java.net.BindException;
import java.net.ServerSocket;
import java.net.Socket;
public class TcpProxySocket implements Runnable {
private static Logger logger = LoggerFactory.getLogger(TcpProxy.class);
String host;
int remoteport;
int localport;
boolean stop = false;
Socket client = null;
Socket server = null;
ServerSocket ss;
/**
* Creation of proxy.
*
* @param host database host
* @param remoteport database port
* @throws IOException exception
*/
public TcpProxySocket(String host, int remoteport) throws IOException {
this.host = host;
this.remoteport = remoteport;
ss = new ServerSocket(0);
this.localport = ss.getLocalPort();
}
public int getLocalPort() {
return ss.getLocalPort();
}
public boolean isClosed() {
return ss.isClosed();
}
/**
* Kill proxy.
*/
public void kill() {
stop = true;
try {
if (server != null) {
server.close();
}
} catch (IOException e) {
//eat Exception
}
try {
if (client != null) {
client.close();
}
} catch (IOException e) {
//eat Exception
}
try {
ss.close();
} catch (IOException e) {
//eat Exception
}
}
@Override
public void run() {
logger.trace("host proxy port " + this.localport + " for " + host + " started");
stop = false;
try {
try {
if (ss.isClosed()) {
ss = new ServerSocket(localport);
}
} catch (BindException b) {
//in case for testing crash and reopen too quickly
try {
Thread.sleep(100);
} catch (InterruptedException i) {
//eat Exception
}
if (ss.isClosed()) {
ss = new ServerSocket(localport);
}
}
final byte[] request = new byte[1024];
byte[] reply = new byte[4096];
while (!stop) {
try {
client = ss.accept();
final InputStream fromClient = client.getInputStream();
final OutputStream toClient = client.getOutputStream();
try {
server = new Socket(host, remoteport);
} catch (IOException e) {
PrintWriter out = new PrintWriter(new OutputStreamWriter(toClient));
out.println("Proxy server cannot connect to " + host + ":"
+ remoteport + ":\n" + e);
out.flush();
client.close();
continue;
}
final InputStream fromServer = server.getInputStream();
final OutputStream toServer = server.getOutputStream();
new Thread() {
public void run() {
int bytesRead;
try {
while ((bytesRead = fromClient.read(request)) != -1) {
toServer.write(request, 0, bytesRead);
toServer.flush();
}
} catch (IOException e) {
//eat exception
}
try {
toServer.close();
} catch (IOException e) {
//eat exception
}
}
}.start();
int bytesRead;
try {
while ((bytesRead = fromServer.read(reply)) != -1) {
try {
Thread.sleep(1);
} catch (InterruptedException e) {
e.printStackTrace();
}
toClient.write(reply, 0, bytesRead);
toClient.flush();
}
} catch (IOException e) {
//eat exception
}
toClient.close();
} catch (IOException e) {
//System.err.println("ERROR socket : "+e);
} finally {
try {
if (server != null) {
server.close();
}
if (client != null) {
client.close();
}
} catch (IOException e) {
//eat exception
}
}
}
} catch (IOException e) {
e.printStackTrace();
}
}
}