package com.github.davidmoten.rx.internal.operators;
import java.io.IOException;
import java.io.InputStream;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import com.github.davidmoten.rx.Actions;
import com.github.davidmoten.rx.Bytes;
import com.github.davidmoten.rx.Checked;
import com.github.davidmoten.rx.Checked.F0;
import com.github.davidmoten.rx.Functions;
import rx.Observable;
import rx.Observer;
import rx.functions.Action0;
import rx.functions.Action1;
import rx.functions.Action2;
import rx.functions.Func0;
import rx.functions.Func1;
import rx.observables.SyncOnSubscribe;
public final class ObservableServerSocket {
private ObservableServerSocket() {
// prevent instantiation
}
public static Observable<Observable<byte[]>> create(
final Func0<? extends ServerSocket> serverSocketFactory, final int timeoutMs,
final int bufferSize, Action0 preAcceptAction, int acceptTimeoutMs,
Func1<? super Socket, Boolean> acceptSocket) {
Func1<ServerSocket, Observable<Observable<byte[]>>> observableFactory = createObservableFactory(
timeoutMs, bufferSize, preAcceptAction, acceptSocket);
return Observable.<Observable<byte[]>, ServerSocket> using( //
createServerSocketFactory(serverSocketFactory, acceptTimeoutMs), //
observableFactory, //
new Action1<ServerSocket>() {
@Override
public void call(ServerSocket ss) {
try {
ss.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}, true);
}
private static Func0<ServerSocket> createServerSocketFactory(
final Func0<? extends ServerSocket> serverSocketFactory, final int acceptTimeoutMs) {
return Checked.f0(new F0<ServerSocket>() {
@Override
public ServerSocket call() throws Exception {
return createServerSocket(serverSocketFactory, acceptTimeoutMs);
}
});
}
private static ServerSocket createServerSocket(
Func0<? extends ServerSocket> serverSocketCreator, long timeoutMs) throws IOException {
ServerSocket s = serverSocketCreator.call();
s.setSoTimeout((int) timeoutMs);
return s;
}
private static Func1<ServerSocket, Observable<Observable<byte[]>>> createObservableFactory(
final int timeoutMs, final int bufferSize, final Action0 preAcceptAction,
final Func1<? super Socket, Boolean> acceptSocket) {
return new Func1<ServerSocket, Observable<Observable<byte[]>>>() {
@Override
public Observable<Observable<byte[]>> call(ServerSocket serverSocket) {
return createServerSocketObservable(serverSocket, timeoutMs, bufferSize,
preAcceptAction, acceptSocket);
}
};
}
private static Observable<Observable<byte[]>> createServerSocketObservable(
ServerSocket serverSocket, final long timeoutMs, final int bufferSize,
final Action0 preAcceptAction, final Func1<? super Socket, Boolean> acceptSocket) {
return Observable.create( //
SyncOnSubscribe.<ServerSocket, Observable<byte[]>> createSingleState( //
Functions.constant0(serverSocket), //
new Action2<ServerSocket, Observer<? super Observable<byte[]>>>() {
@Override
public void call(ServerSocket ss,
Observer<? super Observable<byte[]>> observer) {
acceptConnection(timeoutMs, bufferSize, ss, observer,
preAcceptAction, acceptSocket);
}
}));
}
private static void acceptConnection(long timeoutMs, int bufferSize, ServerSocket ss,
Observer<? super Observable<byte[]>> observer, Action0 preAcceptAction,
Func1<? super Socket, Boolean> acceptSocket) {
Socket socket;
while (true) {
try {
preAcceptAction.call();
socket = ss.accept();
if (!acceptSocket.call(socket)) {
closeQuietly(socket);
} else {
observer.onNext(createSocketObservable(socket, timeoutMs, bufferSize));
break;
}
} catch (SocketTimeoutException e) {
// timed out so will loop around again
} catch (IOException e) {
// unknown problem
observer.onError(e);
break;
}
}
}
private static void closeQuietly(Socket socket) {
try {
socket.close();
} catch (IOException e) {
// ignore exception
}
}
private static Observable<byte[]> createSocketObservable(final Socket socket, long timeoutMs,
final int bufferSize) {
setTimeout(socket, timeoutMs);
return Observable.using( //
Checked.f0(new F0<InputStream>() {
@Override
public InputStream call() throws Exception {
return socket.getInputStream();
}
}), //
new Func1<InputStream, Observable<byte[]>>() {
@Override
public Observable<byte[]> call(InputStream is) {
return Bytes.from(is, bufferSize);
}
}, //
Actions.close(), //
true);
}
private static void setTimeout(Socket socket, long timeoutMs) {
try {
socket.setSoTimeout((int) timeoutMs);
} catch (SocketException e) {
throw new RuntimeException(e);
}
}
}