package edu.washington.cs.publickey.ssl.server; import java.io.BufferedWriter; import java.io.DataInputStream; import java.io.File; import java.io.IOException; import java.io.OutputStreamWriter; import java.security.KeyManagementException; import java.security.KeyStoreException; import java.security.NoSuchAlgorithmException; import java.security.UnrecoverableKeyException; import java.security.cert.Certificate; import java.security.cert.CertificateException; import java.util.Arrays; import java.util.Date; import java.util.HashMap; import java.util.HashSet; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Properties; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.zip.GZIPOutputStream; import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLServerSocket; import javax.net.ssl.SSLSocket; import edu.washington.cs.publickey.PublicKeyFriend; import edu.washington.cs.publickey.Tools; import edu.washington.cs.publickey.storage.PersistentStorage; public class PublicKeySSLServer { static final String KEY_SSL_SERVER_KEYSTORE = "ssl_server_keystore"; final static String KEY_SSL_PORT = "ssl_server_port"; private static final int NUM_THREADS = 500; private static final int MAX_DB_QUEUE = 20; /* * to protect against dos, only allow 2 connection attempts/second from the * same ip and max 50 concurrent connections */ private static final long MIN_MS_BETWEEN_CONNECT_ATTEMPTS_PER_IP = 500; private static final Integer MAX_CONNECTION_PER_IP = 10; private final PersistentStorage storage; private final int serverPort; private volatile boolean quit = false; private final ExecutorService threadPool; private SSLServerSocket serverSocket; private HashMap<String, Integer> activeConnections = new HashMap<String, Integer>(); private HashMap<String, Long> lastConnectAttempt = new HashMap<String, Long>(); private volatile int queueLength = 0; public PublicKeySSLServer(Properties props, PersistentStorage storage, char[] keystorePassword) throws IOException, KeyManagementException, NoSuchAlgorithmException, KeyStoreException, CertificateException, UnrecoverableKeyException, InterruptedException { this.serverPort = Integer.parseInt((String) props.get(KEY_SSL_PORT)); this.storage = storage; this.threadPool = Executors.newFixedThreadPool(NUM_THREADS); File keyStoreFile = new File(props.getProperty(KEY_SSL_SERVER_KEYSTORE)); SSLKeyManager sslManager = new SSLKeyManager(keyStoreFile, keystorePassword); serverSocket = sslManager.createServerSocket(serverPort); serverSocket.setNeedClientAuth(true); Thread t = new Thread(new Runnable() { public void run() { System.out.println("SSL server: listening on port " + serverPort); while (!quit) { try { SSLSocket csocket = (SSLSocket) serverSocket.accept(); String remoteIp = csocket.getInetAddress().getHostAddress(); if (isConnectionAllowed(remoteIp)) { queueLength++; // System.out.println("connection from: " + remoteIp // + " queue=" + queueLength); initiatingConnection(remoteIp); PublicKeySSLServerProtocol publicKeySSLServerProtocol = new PublicKeySSLServerProtocol(csocket); threadPool.execute(publicKeySSLServerProtocol); } } catch (IOException e) { if (e instanceof java.net.SocketException && e.getMessage().equals("Socket closed")) { System.out.println("SSL Server: closing socket"); } else { e.printStackTrace(); } } } } }); t.setName("SSL Server accept thread"); t.start(); } private void initiatingConnection(String remoteIp) { synchronized (lastConnectAttempt) { lastConnectAttempt.put(remoteIp, System.currentTimeMillis()); Integer active = activeConnections.get(remoteIp); if (active == null) { activeConnections.put(remoteIp, 1); } else { activeConnections.put(remoteIp, active + 1); } } } private void closingConnection(String remoteIp) { synchronized (lastConnectAttempt) { Integer active = activeConnections.get(remoteIp); if (active <= 1) { activeConnections.remove(remoteIp); } else { activeConnections.put(remoteIp, active - 1); } } } private boolean isConnectionAllowed(String remoteIp) { synchronized (lastConnectAttempt) { Long lastAttempt = lastConnectAttempt.get(remoteIp); if (lastAttempt != null) { long timeSince = System.currentTimeMillis() - lastAttempt; if (timeSince < MIN_MS_BETWEEN_CONNECT_ATTEMPTS_PER_IP) { System.err.println(new Date() + ": connection from '" + remoteIp + "' denied, " + " to high connect frequency (" + timeSince + "ms<" + MIN_MS_BETWEEN_CONNECT_ATTEMPTS_PER_IP + "ms)"); return false; } } Integer numActiveConnections = activeConnections.get(remoteIp); if (numActiveConnections != null && numActiveConnections > MAX_CONNECTION_PER_IP) { System.err.println(new Date() + ": connection from '" + remoteIp + "' denied, " + " to many connections (" + numActiveConnections + ">" + MAX_CONNECTION_PER_IP + ")"); return false; } return true; } } class PublicKeySSLServerProtocol implements Runnable { private final SSLSocket socket; private final String remoteIp; long timeInclNet; public PublicKeySSLServerProtocol(SSLSocket socket) { this.socket = socket; this.remoteIp = socket.getInetAddress().getHostAddress(); timeInclNet = System.currentTimeMillis(); } public void run() { try { int dbQueueLength = storage.getDbQueueLength(); if (dbQueueLength > MAX_DB_QUEUE) { long t = System.currentTimeMillis(); DataInputStream in = new DataInputStream(socket.getInputStream()); BufferedWriter out = new BufferedWriter(new OutputStreamWriter(new GZIPOutputStream(socket.getOutputStream()))); // read to not confuse the client readIgnoreList(in); out.write(PublicKeyFriend.serialize(new PublicKeyFriend[0])); out.close(); in.close(); // log("dropping incoming connection, db_queue: " + // dbQueueLength + " conn_queue: " + queueLength + " took: " // + (System.currentTimeMillis() - t) + "ms"); } else { Certificate[] remoteCerts; remoteCerts = socket.getSession().getPeerCertificates(); if (remoteCerts.length > 0) { byte[] remoteKey = remoteCerts[0].getPublicKey().getEncoded(); // log("accepting incoming connection, db_queue: " + // dbQueueLength + " conn_queue: " + queueLength); DataInputStream in = new DataInputStream(socket.getInputStream()); BufferedWriter out = new BufferedWriter(new OutputStreamWriter(new GZIPOutputStream(socket.getOutputStream()))); HashSet<Integer> keysToIgnore = readIgnoreList(in); long timeExclNet = System.currentTimeMillis(); // ok, now we know what to ignore // lets try to find new friends PublicKeyFriend f = new PublicKeyFriend(); f.setPublicKey(remoteKey); f.setPublicKeySha1(Tools.getSha1(remoteKey)); long dbStartTime = System.currentTimeMillis(); List<PublicKeyFriend> allFriends = storage.getFriendsUsingPublicKey(f); long preLastSeenUpdate = System.currentTimeMillis(); storage.updateUserLastSeen(f); long dbTime = System.currentTimeMillis() - dbStartTime; long lastSeenOverhead = System.currentTimeMillis() - preLastSeenUpdate; Map<Integer, PublicKeyFriend> newFriends = new HashMap<Integer, PublicKeyFriend>(); // add them all for (PublicKeyFriend friend : allFriends) { // ignore friends the user already knows of int friendHash = Arrays.hashCode(friend.getPublicKeySha1()); if (!keysToIgnore.contains(friendHash)) { // check if we already added this one if (!newFriends.containsKey(friendHash)) { newFriends.put(friendHash, friend); } } } List<PublicKeyFriend> friendsArray = new LinkedList<PublicKeyFriend>(); friendsArray.addAll(newFriends.values()); /* * for debugging, just return an empty list */ String serialized = PublicKeyFriend.serialize(new PublicKeyFriend[0]); // String serialized = // PublicKeyFriend.serialize(friendsArray.toArray(new // PublicKeyFriend[newFriends.size()])); timeExclNet = System.currentTimeMillis() - timeExclNet; out.write(serialized); out.close(); in.close(); timeInclNet = System.currentTimeMillis() - timeInclNet; log("done, returned: " + friendsArray.size() + " ignored: " + keysToIgnore.size() + " time: (queries: " + timeExclNet + " lastSeen: " + lastSeenOverhead + " total: " + timeInclNet + " db=" + dbTime + ")"); } } } catch (java.io.EOFException e) { System.err.println(remoteIp + ": " + "EOF to early"); } catch (SSLPeerUnverifiedException e) { System.err.println(remoteIp + ": no cert, closing conn"); } catch (java.net.SocketException e) { if ("Connection timed out".equals(e.getMessage())) { System.err.println(remoteIp + ": other side timed out"); } else if ("Connection reset".equals(e.getMessage())) { System.err.println(remoteIp + ": other side closed socket"); } else if ("Broken pipe".equals(e.getMessage())) { System.err.println(remoteIp + ": broken pipe"); } else { e.printStackTrace(); } } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } finally { if (socket != null) { try { socket.close(); } catch (IOException e1) { } } closingConnection(remoteIp); queueLength--; } } private HashSet<Integer> readIgnoreList(DataInputStream in) throws IOException { int numToIgnore = in.readInt(); if (numToIgnore > 100000) { System.err.println("warning: user specified more that 100000 friends (!!!???), closing conn"); socket.close(); throw new IOException("user specified invalid data"); } HashSet<Integer> keysToIgnore = new HashSet<Integer>(); byte[] pubKeySha = new byte[20]; for (int i = 0; i < numToIgnore; i++) { // read 20 bytes in.readFully(pubKeySha); int hash = Arrays.hashCode(pubKeySha); keysToIgnore.add(hash); } return keysToIgnore; } private void log(String msg) { System.out.println(remoteIp + ": " + msg); } } public void shutdown() { quit = true; try { serverSocket.close(); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } threadPool.shutdown(); storage.shutdown(); } }