/* * Copyrigth (C) 2010 Henrik Baastrup. * * Licensed under the GNU Lesser General Public License version 3; * you may not use this file except in compliance with the License. * You should have received a copy of the license together with this * file but can obtain a copy of the License at: * * http://www.gnu.org/licenses/lgpl-3.0.txt * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package javax.net.stun.services; import java.io.File; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.lang.Thread.UncaughtExceptionHandler; import java.net.InetAddress; import java.net.Socket; import java.net.SocketTimeoutException; import java.security.GeneralSecurityException; import java.security.KeyStore; import java.util.ArrayList; import java.util.List; import java.util.logging.Level; import java.util.logging.Logger; import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLServerSocket; import javax.net.ssl.SSLServerSocketFactory; import javax.net.stun.MessageAttribute; import javax.net.stun.MessageHeader; import javax.net.stun.Utils; /** * * @author Henrik Baastrup */ public class SharedSecretService implements Runnable,UncaughtExceptionHandler { private boolean running = false; private SSLServerSocket serverSocket = null; private Thread thread = null; private InetAddress address = null; private int port = 3478; private File keyStoreFile = null; private char keyStorePassword[] = null; private char keyPassword[] = null; private ArrayList<UserHolder> users = new ArrayList<UserHolder>(); private boolean debug = false; public SharedSecretService() { } public SharedSecretService(final int port) { if (port!=0) this.port = port; } public SharedSecretService(final InetAddress localAddress, final int port) { this(port); this.address = localAddress; } @Override protected void finalize() throws Throwable { stopThread(); super.finalize(); } public List<UserHolder> getUsers() { synchronized (users) { return new ArrayList<UserHolder>(users); } } public InetAddress getAddress() {return address;} public int getPort() {return port;} /** * Set the keystore to use for TLS. A call to this method will override * the javax.net.ssl.trustStore, javax.net.ssl.trustStoreType and * javax.net.ssl.trustStorePassword System properties, * javax.net.ssl.trustStore property is set to the abeolute path of * the file passed, the javax.net.ssl.keyStoreType is set to JKS and * the javax.net.ssl.trustStorePassword is set to the given password.<br> * Use this method if you use a private keystore conting the certificate for * the TLS sessions. Default keystores are:<br> * {java.home}/lib/security/jssecacerts.<br> * [java.home]/lib/security/cacerts<br> * * @param arg0 Filepath to keystore file * @param arg1 password for keystore * @param arg2 password for key */ public void setKeyStore(File arg0, char arg1[], char arg2[]) { keyStoreFile = arg0; keyStorePassword = new char[arg1.length]; for (int i=0; i<arg1.length; i++) keyStorePassword[i] = arg1[i]; keyPassword = new char[arg2.length]; for (int i=0; i<arg2.length; i++) keyPassword[i] = arg2[i]; } /** * * @return true if the local thread is running and the service is listin to a socket. */ public boolean isRunning() {return running;} public void startThread() { if (running) return; thread = new Thread(this, "Shared Secret Service Thread"); thread.setUncaughtExceptionHandler(this); thread.start(); } public void stopThread() { running = false; if (serverSocket!=null) { synchronized (serverSocket) { serverSocket.notifyAll(); } } thread = null; } public void run() { if (debug) { Logger.getLogger(SharedSecretService.class.getName()).log(Level.INFO, "Service thread start with following parameters:"); Logger.getLogger(SharedSecretService.class.getName()).log(Level.INFO, " Listein on port: "+port); if (keyStoreFile!=null) { Logger.getLogger(SharedSecretService.class.getName()).log(Level.INFO, " Using the key-store file: "+keyStoreFile.getAbsolutePath()); } else { Logger.getLogger(SharedSecretService.class.getName()).log(Level.INFO, " With no key-store file!"); } } running = true; SSLContext sslContext = null; FileInputStream keyStoreIn = null; try { if (address==null) address = Utils.getLocalAddr(); if (keyStoreFile!=null) { KeyStore keyStore = KeyStore.getInstance("JKS"); keyStoreIn = new FileInputStream(keyStoreFile); keyStore.load(keyStoreIn, keyStorePassword); keyStoreIn.close(); keyStoreIn = null; KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); kmf.init(keyStore, keyPassword); sslContext = SSLContext.getInstance("SSLv3"); sslContext.init(kmf.getKeyManagers(), null, null); } SSLServerSocketFactory sslFactory; if (sslContext != null) { sslFactory = sslContext.getServerSocketFactory(); } else { sslFactory = (SSLServerSocketFactory)SSLServerSocketFactory.getDefault(); } serverSocket = (SSLServerSocket)sslFactory.createServerSocket(port, 10, address); serverSocket.setSoTimeout(30000); // String cipherSuites[] = {"TLS_RSA_WITH_AES_128_CBC_SHA","TLS_DHE_RSA_WITH_AES_128_CBC_SHA","TLS_DHE_DSS_WITH_AES_128_CBC_SHA"}; //Only TLS Cipher Suites // serverSocket.setEnabledCipherSuites(cipherSuites); while (running) { try { Socket sock; try { sock = serverSocket.accept(); if (!running) break; read(sock); } catch (SocketTimeoutException ex) { cleanUpUsers(); continue; } } catch (RuntimeException ex) { Logger.getLogger(SharedSecretService.class.getName()).log(Level.SEVERE, null, ex); } } } catch (GeneralSecurityException ex) { Logger.getLogger(SharedSecretService.class.getName()).log(Level.SEVERE, null, ex); } catch (IOException ex) { Logger.getLogger(SharedSecretService.class.getName()).log(Level.SEVERE, null, ex); } finally { if (keyStoreIn!=null) try{keyStoreIn.close();}catch(IOException ignore){} if (serverSocket!=null) try{serverSocket.close();}catch(IOException ignore){} } running = false; } public void setDebug(boolean arg0) {debug = arg0;} public int controllMessageIntegrity(MessageHeader receivedHeader) { MessageAttribute username = receivedHeader.getMessageAttribute(MessageAttribute.MessageAttributeType.USERNAME); MessageAttribute messageIntegrity = receivedHeader.getMessageAttribute(MessageAttribute.MessageAttributeType.MESSAGE_INTEGRITY); byte password[] = getPassword(receivedHeader); if (username==null) { return 432; } else if (messageIntegrity==null) { return 401; } else if (password==null) { return 430; } int errorInt = receivedHeader.integrityCheck(password); return errorInt; } public byte[] getPassword(MessageHeader header) { MessageAttribute username = header.getMessageAttribute(MessageAttribute.MessageAttributeType.USERNAME); if (username==null) return null; UserHolder userHolder = null; List<UserHolder> userList = getUsers(); for (UserHolder uh: userList) { if (uh.username.equals(username.getUsername())) { userHolder = uh; break; } } if (userHolder==null) return null; return userHolder.password; } private void read(Socket sock) { if (debug) { InetAddress clientAddr = sock.getInetAddress(); Logger.getLogger(SharedSecretService.class.getName()).log(Level.INFO, "Recived a connect from: "+clientAddr); } InputStream in = null; OutputStream out = null; try { in = sock.getInputStream(); out = sock.getOutputStream(); byte head[] = new byte[20]; int bytesRead = 0; while (bytesRead<20) { int r = in.read(head, 0, 20); if (r < 0) return; bytesRead += r; } int length = (0x000000FF & ((int)head[2]))<<8; length += (0x000000FF & ((int)head[3])); byte buffer[] = new byte[length]; int read = 0; while (read<length) { int r = in.read(buffer, read, length); read += r; } byte headBuffer[] = new byte[length+20]; for (int i=0; i<20; i++) headBuffer[i] = head[i]; for (int i=0; i<length; i++) headBuffer[i+20] = buffer[i]; MessageHeader recHeader = MessageHeader.create(headBuffer); MessageHeader retHeader; if (recHeader.getType()!=MessageHeader.HeaderType.SHARED_SECRET_REQUEST && recHeader.getType()!=MessageHeader.HeaderType.SHARED_SECRET_VERIFY_REQUEST) { retHeader = new MessageHeader(MessageHeader.HeaderType.SHARED_SECRET_ERROR_RESPONSE); retHeader.setTransactionId(recHeader.getTransactionId()); MessageAttribute errorCode = MessageAttribute.create(MessageAttribute.MessageAttributeType.ERROR_CODE, Utils.createErrorString(400), 400); retHeader.addMessageAttribute(errorCode); out.write(retHeader.toBytes()); return; } UserHolder userHolder = UserHolder.create(); synchronized (users) { users.add(userHolder); } if (recHeader.getType()==MessageHeader.HeaderType.SHARED_SECRET_VERIFY_REQUEST) { //This is a message integrity veryfy request! int errCod = controllMessageIntegrity(recHeader); if (errCod!=0) { retHeader = new MessageHeader(MessageHeader.HeaderType.SHARED_SECRET_ERROR_RESPONSE); MessageAttribute errorCode = MessageAttribute.create(MessageAttribute.MessageAttributeType.ERROR_CODE, Utils.createErrorString(errCod), errCod); retHeader.addMessageAttribute(errorCode); } else { byte passwd[] = getPassword(recHeader); if (passwd==null) { retHeader = new MessageHeader(MessageHeader.HeaderType.SHARED_SECRET_ERROR_RESPONSE); MessageAttribute errorCode = MessageAttribute.create(MessageAttribute.MessageAttributeType.ERROR_CODE, Utils.createErrorString(430), 430); retHeader.addMessageAttribute(errorCode); } else { //If all is OK we response with a password attribute so the requesting server //can construct a Integrity message retHeader = new MessageHeader(MessageHeader.HeaderType.SHARED_SECRET_RESPONSE); MessageAttribute attr = MessageAttribute.create(MessageAttribute.MessageAttributeType.PASSWORD, passwd, 0); retHeader.addMessageAttribute(attr); } } } else { //Default response: A shared secret response with Username and //Password attributes. retHeader = new MessageHeader(MessageHeader.HeaderType.SHARED_SECRET_RESPONSE); MessageAttribute attr = MessageAttribute.create(MessageAttribute.MessageAttributeType.USERNAME, userHolder.username, 0); retHeader.addMessageAttribute(attr); attr = MessageAttribute.create(MessageAttribute.MessageAttributeType.PASSWORD, userHolder.password, 0); retHeader.addMessageAttribute(attr); } out.write(retHeader.toBytes()); } catch (IOException ex) { Logger.getLogger(SharedSecretService.class.getName()).log(Level.SEVERE, null, ex); } finally { if (in!=null) try {in.close();}catch(IOException ignore){} if (out!=null) try{out.close();}catch(IOException ignore){} try{sock.close();}catch(IOException ignore){} } } private void cleanUpUsers() { long now = System.currentTimeMillis(); synchronized (users) { ArrayList<UserHolder> usersToDelete = new ArrayList<UserHolder>(); for (UserHolder uh: users) { if ((now - uh.created)>600000) usersToDelete.add(uh); } for (UserHolder uh: usersToDelete) users.remove(uh); } } public void uncaughtException(Thread t, Throwable e) { System.err.println("Uncaught exception in thread: "+t.getName()+". The thread will die"); Logger.getLogger(SharedSecretService.class.getName()).log(Level.SEVERE, null, e); } }