package io.nucleo.net; import io.nucleo.net.proto.ControlMessage; import io.nucleo.net.proto.HELOMessage; import io.nucleo.net.proto.IDMessage; import io.nucleo.net.proto.Message; import io.nucleo.net.proto.exceptions.ConnectionException; import io.nucleo.net.proto.exceptions.ProtocolViolationException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.EOFException; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.net.ServerSocket; import java.net.Socket; import java.net.SocketException; import java.net.UnknownHostException; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.regex.Pattern; public class Node { /** * Use this whenever to flush the socket header over the socket! * * @param socket the socket to construct an objectOutputStream from * @return the outputstream from the socket * @throws IOException in case something goes wrong, duh! */ static ObjectOutputStream prepareOOSForSocket(Socket socket) throws IOException { ObjectOutputStream out = new ObjectOutputStream(socket.getOutputStream()); out.flush(); return out; } private static final Logger log = LoggerFactory.getLogger(Node.class); private final ServiceDescriptor descriptor; private final HashMap<String, Connection> connections; @SuppressWarnings("rawtypes") private final TorNode tor; private final AtomicBoolean serverRunning; public Node(TCPServiceDescriptor descriptor) { this(null, descriptor); } public Node(HiddenServiceDescriptor descriptor, TorNode<?, ?> tor) { this(tor, descriptor); } private Node(TorNode<?, ?> tor, ServiceDescriptor descriptor) { this.connections = new HashMap<>(); this.descriptor = descriptor; this.tor = tor; this.serverRunning = new AtomicBoolean(false); } public String getLocalName() { return descriptor.getFullAddress(); } public Connection connect(String peer, Collection<ConnectionListener> listeners) throws NumberFormatException, IOException { if (!serverRunning.get()) { throw new IOException("This node has not been started yet!"); } if (peer.equals(descriptor.getFullAddress())) throw new IOException("If you find yourself talking to yourself too often, you should really seek help!"); synchronized (connections) { if (connections.containsKey(peer)) throw new IOException("Already connected to " + peer); } final Socket sock = connectToService(peer); return new OutgoingConnection(peer, sock, listeners); } private Socket connectToService(String hostname, int port) throws IOException, UnknownHostException, SocketException { final Socket sock; if (tor != null) sock = tor.connectToHiddenService(hostname, port); else sock = new Socket(hostname, port); sock.setSoTimeout(60000); return sock; } private Socket connectToService(String peer) throws IOException, UnknownHostException, SocketException { final String[] split = peer.split(Pattern.quote(":")); return connectToService(split[0], Integer.parseInt(split[1])); } public synchronized Server startListening(ServerConnectListener listener) throws IOException { if (serverRunning.getAndSet(true)) throw new IOException("This node is already listening!"); final Server server = new Server(descriptor.getServerSocket(), listener); server.start(); return server; } public Connection getConnection(String peerAddress) { synchronized (connections) { return connections.get(peerAddress); } } public Set<Connection> getConnections() { synchronized (connections) { return new HashSet<Connection>(connections.values()); } } public class Server extends Thread { private boolean running; private final ServerSocket serverSocket; private final ExecutorService executorService; private final ServerConnectListener serverConnectListener; private Server(ServerSocket serverSocket, ServerConnectListener listener) { super("Server"); this.serverSocket = descriptor.getServerSocket(); this.serverConnectListener = listener; running = true; executorService = Executors.newCachedThreadPool(); } public void shutdown() throws IOException { running = false; synchronized (connections) { final Set<Connection> conns = new HashSet<Connection>(connections.values()); for (Connection con : conns) { con.close(); } } serverSocket.close(); try { executorService.awaitTermination(2, TimeUnit.SECONDS); } catch (InterruptedException e) { e.printStackTrace(); } Node.this.serverRunning.set(false); log.debug("Server successfully shutdown"); } @Override public void run() { try { while (running) { final Socket socket = serverSocket.accept(); log.debug("Accepting Client on port " + socket.getLocalPort()); executorService.submit(new Acceptor(socket)); } } catch (IOException e) { if (running) e.printStackTrace(); } } private boolean verifyIdentity(HELOMessage helo, ObjectInputStream in) throws IOException { log.debug("Verifying HELO msg"); final Socket sock = connectToService(helo.getHostname(), helo.getPort()); log.debug("Connected to advertised client " + helo.getPeer()); ObjectOutputStream out = prepareOOSForSocket(sock); final IDMessage challenge = new IDMessage(descriptor); out.writeObject(challenge); log.debug("Sent IDMessage to"); out.flush(); // wait for other side to close try { while (sock.getInputStream().read() != -1) ; } catch (IOException e) { // no matter } out.close(); sock.close(); log.debug("Closed socket after sending IDMessage"); try { log.debug("Waiting for response of challenge"); IDMessage response = (IDMessage) in.readObject(); log.debug("Got response for challenge"); final boolean verified = challenge.verify(response); log.debug("Response verified correctly!"); return verified; } catch (ClassNotFoundException e) { new ProtocolViolationException(e).printStackTrace(); } return false; } private class Acceptor implements Runnable { private final Socket socket; private Acceptor(Socket socket) { this.socket = socket; } @Override public void run() { { try { socket.setSoTimeout((int) TimeUnit.SECONDS.toMillis(60)); } catch (SocketException e2) { e2.printStackTrace(); try { socket.close(); } catch (IOException e) { } return; } ObjectInputStream objectInputStream = null; ObjectOutputStream out = null; // get incoming data try { out = prepareOOSForSocket(socket); // LookAheadObjectInputStream not needed here as the class it not used in Bitsquare (used to test the library) objectInputStream = new ObjectInputStream(socket.getInputStream()); } catch (EOFException e) { log.debug("Got bogus incoming connection"); } catch (IOException e) { e.printStackTrace(); try { socket.close(); } catch (IOException e1) { } return; } String peer = null; try { log.debug("Waiting for HELO or Identification"); final Message helo = (Message) objectInputStream.readObject(); if (helo instanceof HELOMessage) { peer = ((HELOMessage) helo).getPeer(); log.debug("Got HELO from " + peer); boolean alreadyConnected; synchronized (connections) { alreadyConnected = connections.containsKey(peer); } if (alreadyConnected || !verifyIdentity((HELOMessage) helo, objectInputStream)) { log.debug(alreadyConnected ? ("already connected to " + peer) : "verification failed"); out.writeObject(alreadyConnected ? ControlMessage.ALREADY_CONNECTED : ControlMessage.HANDSHAKE_FAILED); out.writeObject(ControlMessage.DISCONNECT); out.flush(); out.close(); objectInputStream.close(); socket.close(); return; } log.debug("Verification of " + peer + " successful"); } else if (helo instanceof IDMessage) { peer = ((IDMessage) helo).getPeer(); log.debug("got IDMessage from " + peer); final Connection client = connections.get(peer); if (client != null) { log.debug("Got preexisting connection for " + peer); client.sendMsg(((IDMessage) helo).reply()); log.debug("Sent response for challenge"); } else { log.debug("Got IDMessage for unknown connection to " + peer); } out.flush(); out.close(); objectInputStream.close(); socket.close(); log.debug("Closed socket for identification"); return; } else throw new ClassNotFoundException("First Message was neither HELO, nor ID"); } catch (ClassNotFoundException e) { new ProtocolViolationException(e); } catch (IOException e) { try { objectInputStream.close(); out.close(); socket.close(); } catch (IOException e1) { } return; } // Here we go log.debug("Incoming Connection ready!"); try { // TODO: listeners are only added afterwards, so messages can be lost! IncomingConnection incomingConnection = new IncomingConnection(peer, socket, out, objectInputStream); serverConnectListener.onConnect(incomingConnection); } catch (IOException e) { e.printStackTrace(); } } } } } private class IncomingConnection extends Connection { private IncomingConnection(String peer, Socket socket, ObjectOutputStream out, ObjectInputStream in) throws IOException { super(peer, socket, out, in); synchronized (connections) { connections.put(peer, this); } sendMsg(ControlMessage.AVAILABLE); } @Override public void listen() throws ConnectionException { super.listen(); onReady(); } @Override protected void onMessage(Message msg) throws IOException { if ((msg instanceof ControlMessage) && (ControlMessage.HEARTBEAT == msg)) { log.debug("RX+REPLY HEARTBEAT"); try { sendMsg(ControlMessage.HEARTBEAT); } catch (IOException e) { onError(e); } } else super.onMessage(msg); } @Override public void onDisconnect() { synchronized (connections) { connections.remove(getPeer()); } } @Override public boolean isIncoming() { return true; } } private class OutgoingConnection extends Connection { private OutgoingConnection(String peer, Socket socket, Collection<ConnectionListener> listeners) throws IOException { super(peer, socket); synchronized (connections) { connections.put(peer, this); } setConnectionListeners(listeners); try { listen(); } catch (ConnectionException e) { // Never happens } log.debug("Sending HELO"); sendMsg(new HELOMessage(descriptor)); log.debug("Sent HELO"); } @Override public void onDisconnect() { synchronized (connections) { connections.remove(getPeer()); } } @Override public boolean isIncoming() { return false; } } }